From 800acd45ff554dd63f8de365e111c625c8923be6 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Sat, 7 Mar 2026 05:47:22 +0800 Subject: [PATCH 01/24] Enhance G2P processing by implementing batch input handling in _g2p function, improving efficiency. Update prepare_onnx_input to utilize caching for tokenization and add optional parameters for character ID mapping and phoneme masks. Refactor G2PWOnnxConverter to streamline model loading and configuration management. --- GPT_SoVITS/text/chinese2.py | 17 ++- GPT_SoVITS/text/g2pw/dataset.py | 107 ++++++++++++++---- GPT_SoVITS/text/g2pw/onnx_api.py | 181 +++++++++++++++++++++++-------- 3 files changed, 233 insertions(+), 72 deletions(-) diff --git a/GPT_SoVITS/text/chinese2.py b/GPT_SoVITS/text/chinese2.py index dcce0d96..acfebfe2 100644 --- a/GPT_SoVITS/text/chinese2.py +++ b/GPT_SoVITS/text/chinese2.py @@ -180,10 +180,15 @@ def _merge_erhua(initials: list[str], finals: list[str], word: str, pos: str) -> def _g2p(segments): phones_list = [] word2ph = [] - for seg in segments: + g2pw_batch_results = [] + g2pw_batch_cursor = 0 + processed_segments = [re.sub("[a-zA-Z]+", "", seg) for seg in segments] + if is_g2pw: + batch_inputs = [seg for seg in processed_segments if seg] + g2pw_batch_results = g2pw._g2pw(batch_inputs) if batch_inputs else [] + + for seg in processed_segments: pinyins = [] - # Replace all English words in the sentence - seg = re.sub("[a-zA-Z]+", "", seg) seg_cut = psg.lcut(seg) seg_cut = tone_modifier.pre_merge_for_modify(seg_cut) initials = [] @@ -204,8 +209,10 @@ def _g2p(segments): finals = sum(finals, []) print("pypinyin结果", initials, finals) else: - # g2pw采用整句推理 - pinyins = g2pw.lazy_pinyin(seg, neutral_tone_with_five=True, style=Style.TONE3) + # g2pw采用整句推理(批量推理,逐句取结果) + if seg: + pinyins = g2pw_batch_results[g2pw_batch_cursor] + g2pw_batch_cursor += 1 pre_word_length = 0 for word, pos in seg_cut: diff --git a/GPT_SoVITS/text/g2pw/dataset.py b/GPT_SoVITS/text/g2pw/dataset.py index ff09cbc2..e464c29a 100644 --- a/GPT_SoVITS/text/g2pw/dataset.py +++ b/GPT_SoVITS/text/g2pw/dataset.py @@ -18,6 +18,7 @@ Credits from typing import Dict from typing import List +from typing import Optional from typing import Tuple import numpy as np @@ -37,6 +38,8 @@ def prepare_onnx_input( use_mask: bool = False, window_size: int = None, max_len: int = 512, + char2id: Optional[Dict[str, int]] = None, + char_phoneme_masks: Optional[Dict[str, List[int]]] = None, ) -> Dict[str, np.array]: if window_size is not None: truncated_texts, truncated_query_ids = _truncate_texts( @@ -48,33 +51,88 @@ def prepare_onnx_input( phoneme_masks = [] char_ids = [] position_ids = [] + tokenized_cache = {} + + if char2id is None: + char2id = {char: idx for idx, char in enumerate(chars)} + if use_mask: + if char_phoneme_masks is None: + char_phoneme_masks = { + char: [1 if i in char2phonemes[char] else 0 for i in range(len(labels))] + for char in char2phonemes + } + else: + full_phoneme_mask = [1] * len(labels) for idx in range(len(texts)): text = (truncated_texts if window_size else texts)[idx].lower() query_id = (truncated_query_ids if window_size else query_ids)[idx] - try: - tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text) - except Exception: - print(f'warning: text "{text}" is invalid') - return {} + cached = tokenized_cache.get(text) + if cached is None: + try: + tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text) + except Exception: + print(f'warning: text "{text}" is invalid') + return {} - text, query_id, tokens, text2token, token2text = _truncate( - max_len=max_len, text=text, query_id=query_id, tokens=tokens, text2token=text2token, token2text=token2text - ) + if len(tokens) <= max_len - 2: + processed_tokens = ["[CLS]"] + tokens + ["[SEP]"] + shared_input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens))) + shared_token_type_id = list(np.zeros((len(processed_tokens),), dtype=int)) + shared_attention_mask = list(np.ones((len(processed_tokens),), dtype=int)) + cached = { + "is_short": True, + "tokens": tokens, + "text2token": text2token, + "token2text": token2text, + "input_id": shared_input_id, + "token_type_id": shared_token_type_id, + "attention_mask": shared_attention_mask, + } + else: + cached = { + "is_short": False, + "tokens": tokens, + "text2token": text2token, + "token2text": token2text, + } + tokenized_cache[text] = cached - processed_tokens = ["[CLS]"] + tokens + ["[SEP]"] + if cached["is_short"]: + text_for_query = text + query_id_for_query = query_id + text2token_for_query = cached["text2token"] + input_id = cached["input_id"] + token_type_id = cached["token_type_id"] + attention_mask = cached["attention_mask"] + else: + ( + text_for_query, + query_id_for_query, + tokens_for_query, + text2token_for_query, + _token2text_for_query, + ) = _truncate( + max_len=max_len, + text=text, + query_id=query_id, + tokens=cached["tokens"], + text2token=cached["text2token"], + token2text=cached["token2text"], + ) + processed_tokens = ["[CLS]"] + tokens_for_query + ["[SEP]"] + input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens))) + token_type_id = list(np.zeros((len(processed_tokens),), dtype=int)) + attention_mask = list(np.ones((len(processed_tokens),), dtype=int)) - input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens))) - token_type_id = list(np.zeros((len(processed_tokens),), dtype=int)) - attention_mask = list(np.ones((len(processed_tokens),), dtype=int)) - - query_char = text[query_id] - phoneme_mask = ( - [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] if use_mask else [1] * len(labels) - ) - char_id = chars.index(query_char) - position_id = text2token[query_id] + 1 # [CLS] token locate at first place + query_char = text_for_query[query_id_for_query] + if use_mask: + phoneme_mask = char_phoneme_masks[query_char] + else: + phoneme_mask = full_phoneme_mask + char_id = char2id[query_char] + position_id = text2token_for_query[query_id_for_query] + 1 # [CLS] token locate at first place input_ids.append(input_id) token_type_ids.append(token_type_id) @@ -83,10 +141,15 @@ def prepare_onnx_input( char_ids.append(char_id) position_ids.append(position_id) + max_token_length = max(len(seq) for seq in input_ids) + + def _pad_sequences(sequences, pad_value=0): + return [seq + [pad_value] * (max_token_length - len(seq)) for seq in sequences] + outputs = { - "input_ids": np.array(input_ids).astype(np.int64), - "token_type_ids": np.array(token_type_ids).astype(np.int64), - "attention_masks": np.array(attention_masks).astype(np.int64), + "input_ids": np.array(_pad_sequences(input_ids, pad_value=0)).astype(np.int64), + "token_type_ids": np.array(_pad_sequences(token_type_ids, pad_value=0)).astype(np.int64), + "attention_masks": np.array(_pad_sequences(attention_masks, pad_value=0)).astype(np.int64), "phoneme_masks": np.array(phoneme_masks).astype(np.float32), "char_ids": np.array(char_ids).astype(np.int64), "position_ids": np.array(position_ids).astype(np.int64), diff --git a/GPT_SoVITS/text/g2pw/onnx_api.py b/GPT_SoVITS/text/g2pw/onnx_api.py index 1d5e4231..5fcf1ae2 100644 --- a/GPT_SoVITS/text/g2pw/onnx_api.py +++ b/GPT_SoVITS/text/g2pw/onnx_api.py @@ -10,7 +10,6 @@ from typing import Any, Dict, List, Tuple import numpy as np import onnxruntime import requests -import torch from opencc import OpenCC from pypinyin import Style, pinyin from transformers.models.auto.tokenization_auto import AutoTokenizer @@ -22,9 +21,8 @@ from .utils import load_config onnxruntime.set_default_logger_severity(3) try: onnxruntime.preload_dlls() -except: +except Exception: pass - # traceback.print_exc() warnings.filterwarnings("ignore") model_version = "1.1" @@ -55,6 +53,24 @@ def predict(session, onnx_input: Dict[str, Any], labels: List[str]) -> Tuple[Lis return all_preds, all_confidences +def _load_json_from_candidates(filename: str, candidate_dirs: List[str]) -> Dict[str, Any]: + for candidate_dir in candidate_dirs: + if not candidate_dir: + continue + json_path = os.path.join(candidate_dir, filename) + if os.path.exists(json_path): + with open(json_path, "r", encoding="utf-8") as fr: + return json.load(fr) + raise FileNotFoundError(f"Cannot locate {filename} in candidate dirs: {candidate_dirs}") + + +def _find_first_existing_file(*paths: str) -> str: + for path in paths: + if path and os.path.exists(path): + return path + raise FileNotFoundError(f"Files not found: {paths}") + + def download_and_decompress(model_dir: str = "G2PWModel/"): if not os.path.exists(model_dir): parent_directory = os.path.dirname(model_dir) @@ -62,7 +78,7 @@ def download_and_decompress(model_dir: str = "G2PWModel/"): extract_dir = os.path.join(parent_directory, "G2PWModel_1.1") extract_dir_new = os.path.join(parent_directory, "G2PWModel") print("Downloading g2pw model...") - modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip" # "https://paddlespeech.cdn.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip" + modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip" with requests.get(modelscope_url, stream=True) as r: r.raise_for_status() with open(zip_dir, "wb") as f: @@ -79,7 +95,7 @@ def download_and_decompress(model_dir: str = "G2PWModel/"): return model_dir -class G2PWOnnxConverter: +class _G2PWBaseOnnxConverter: def __init__( self, model_dir: str = "G2PWModel/", @@ -87,33 +103,16 @@ class G2PWOnnxConverter: model_source: str = None, enable_non_tradional_chinese: bool = False, ): - uncompress_path = download_and_decompress(model_dir) - - sess_options = onnxruntime.SessionOptions() - sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL - sess_options.intra_op_num_threads = 2 if torch.cuda.is_available() else 0 - if "CUDAExecutionProvider" in onnxruntime.get_available_providers(): - self.session_g2pW = onnxruntime.InferenceSession( - os.path.join(uncompress_path, "g2pW.onnx"), - sess_options=sess_options, - providers=["CUDAExecutionProvider", "CPUExecutionProvider"], - ) - else: - self.session_g2pW = onnxruntime.InferenceSession( - os.path.join(uncompress_path, "g2pW.onnx"), - sess_options=sess_options, - providers=["CPUExecutionProvider"], - ) - self.config = load_config(config_path=os.path.join(uncompress_path, "config.py"), use_default=True) + self.model_dir = download_and_decompress(model_dir) + self.config = load_config(config_path=os.path.join(self.model_dir, "config.py"), use_default=True) self.model_source = model_source if model_source else self.config.model_source self.enable_opencc = enable_non_tradional_chinese - self.tokenizer = AutoTokenizer.from_pretrained(self.model_source) - polyphonic_chars_path = os.path.join(uncompress_path, "POLYPHONIC_CHARS.txt") - monophonic_chars_path = os.path.join(uncompress_path, "MONOPHONIC_CHARS.txt") + polyphonic_chars_path = os.path.join(self.model_dir, "POLYPHONIC_CHARS.txt") + monophonic_chars_path = os.path.join(self.model_dir, "MONOPHONIC_CHARS.txt") + self.polyphonic_chars = [ line.split("\t") for line in open(polyphonic_chars_path, encoding="utf-8").read().strip().split("\n") ] @@ -149,31 +148,45 @@ class G2PWOnnxConverter: ) self.chars = sorted(list(self.char2phonemes.keys())) + self.char2id = {char: idx for idx, char in enumerate(self.chars)} + self.char_phoneme_masks = ( + { + char: [1 if i in self.char2phonemes[char] else 0 for i in range(len(self.labels))] + for char in self.char2phonemes + } + if self.config.use_mask + else None + ) self.polyphonic_chars_new = set(self.chars) for char in self.non_polyphonic: - if char in self.polyphonic_chars_new: - self.polyphonic_chars_new.remove(char) + self.polyphonic_chars_new.discard(char) self.monophonic_chars_dict = {char: phoneme for char, phoneme in self.monophonic_chars} for char in self.non_monophonic: - if char in self.monophonic_chars_dict: - self.monophonic_chars_dict.pop(char) + self.monophonic_chars_dict.pop(char, None) - self.pos_tags = ["UNK", "A", "C", "D", "I", "N", "P", "T", "V", "DE", "SHI"] + default_asset_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "G2PWModel")) + candidate_asset_dirs = [self.model_dir, default_asset_dir] + self.bopomofo_convert_dict = _load_json_from_candidates( + "bopomofo_to_pinyin_wo_tune_dict.json", candidate_asset_dirs + ) + self.char_bopomofo_dict = _load_json_from_candidates("char_bopomofo_dict.json", candidate_asset_dirs) - with open(os.path.join(uncompress_path, "bopomofo_to_pinyin_wo_tune_dict.json"), "r", encoding="utf-8") as fr: - self.bopomofo_convert_dict = json.load(fr) self.style_convert_func = { "bopomofo": lambda x: x, "pinyin": self._convert_bopomofo_to_pinyin, }[style] - with open(os.path.join(uncompress_path, "char_bopomofo_dict.json"), "r", encoding="utf-8") as fr: - self.char_bopomofo_dict = json.load(fr) - if self.enable_opencc: self.cc = OpenCC("s2tw") + self.enable_sentence_dedup = os.getenv("g2pw_sentence_dedup", "true").strip().lower() in { + "1", + "true", + "yes", + "y", + "on", + } def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str: tone = bopomofo[-1] @@ -181,9 +194,8 @@ class G2PWOnnxConverter: component = self.bopomofo_convert_dict.get(bopomofo[:-1]) if component: return component + tone - else: - print(f'Warning: "{bopomofo}" cannot convert to pinyin') - return None + print(f'Warning: "{bopomofo}" cannot convert to pinyin') + return None def __call__(self, sentences: List[str]) -> List[List[str]]: if isinstance(sentences, str): @@ -199,10 +211,9 @@ class G2PWOnnxConverter: texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences) if len(texts) == 0: - # sentences no polyphonic words return partial_results - onnx_input = prepare_onnx_input( + model_input = prepare_onnx_input( tokenizer=self.tokenizer, labels=self.labels, char2phonemes=self.char2phonemes, @@ -211,9 +222,18 @@ class G2PWOnnxConverter: query_ids=query_ids, use_mask=self.config.use_mask, window_size=None, + char2id=self.char2id, + char_phoneme_masks=self.char_phoneme_masks, ) - preds, confidences = predict(session=self.session_g2pW, onnx_input=onnx_input, labels=self.labels) + if not model_input: + return partial_results + + if self.enable_sentence_dedup: + preds, _confidences = self._predict_with_sentence_dedup(model_input=model_input, texts=texts) + else: + preds, _confidences = self._predict(model_input=model_input) + if self.config.use_char_phoneme: preds = [pred.split(" ")[1] for pred in preds] @@ -226,7 +246,6 @@ class G2PWOnnxConverter: def _prepare_data(self, sentences: List[str]) -> Tuple[List[str], List[int], List[int], List[List[str]]]: texts, query_ids, sent_ids, partial_results = [], [], [], [] for sent_id, sent in enumerate(sentences): - # pypinyin works well for Simplified Chinese than Traditional Chinese sent_s = tranditional_to_simplified(sent) pypinyin_result = pinyin(sent_s, neutral_tone_with_five=True, style=Style.TONE3) partial_result = [None] * len(sent) @@ -239,9 +258,81 @@ class G2PWOnnxConverter: partial_result[i] = self.style_convert_func(self.monophonic_chars_dict[char]) elif char in self.char_bopomofo_dict: partial_result[i] = pypinyin_result[i][0] - # partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0]) else: partial_result[i] = pypinyin_result[i][0] partial_results.append(partial_result) return texts, query_ids, sent_ids, partial_results + + def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]: + raise NotImplementedError + + def _predict_with_sentence_dedup( + self, model_input: Dict[str, Any], texts: List[str] + ) -> Tuple[List[str], List[float]]: + if len(texts) <= 1: + return self._predict(model_input=model_input) + + grouped_indices: Dict[str, List[int]] = {} + for idx, text in enumerate(texts): + grouped_indices.setdefault(text, []).append(idx) + + if all(len(indices) == 1 for indices in grouped_indices.values()): + return self._predict(model_input=model_input) + + preds: List[str] = [""] * len(texts) + confidences: List[float] = [0.0] * len(texts) + for indices in grouped_indices.values(): + group_input = {name: value[indices] for name, value in model_input.items()} + if len(indices) > 1: + for name in ("input_ids", "token_type_ids", "attention_masks"): + group_input[name] = group_input[name][:1] + + group_preds, group_confidences = self._predict(model_input=group_input) + for output_idx, pred, confidence in zip(indices, group_preds, group_confidences): + preds[output_idx] = pred + confidences[output_idx] = confidence + + return preds, confidences + + +class G2PWOnnxConverter(_G2PWBaseOnnxConverter): + def __init__( + self, + model_dir: str = "G2PWModel/", + style: str = "bopomofo", + model_source: str = None, + enable_non_tradional_chinese: bool = False, + ): + super().__init__( + model_dir=model_dir, + style=style, + model_source=model_source, + enable_non_tradional_chinese=enable_non_tradional_chinese, + ) + + sess_options = onnxruntime.SessionOptions() + sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL + sess_options.intra_op_num_threads = 2 + + onnx_path = _find_first_existing_file( + os.path.join(self.model_dir, "g2pW.onnx"), + os.path.join(self.model_dir, "g2pw.onnx"), + ) + + if "CUDAExecutionProvider" in onnxruntime.get_available_providers(): + self.session_g2pw = onnxruntime.InferenceSession( + onnx_path, + sess_options=sess_options, + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], + ) + else: + self.session_g2pw = onnxruntime.InferenceSession( + onnx_path, + sess_options=sess_options, + providers=["CPUExecutionProvider"], + ) + + def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]: + return predict(session=self.session_g2pw, onnx_input=model_input, labels=self.labels) From b250e62402893b4f355d58599cd946150880ef7f Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Sun, 8 Mar 2026 03:01:20 +0800 Subject: [PATCH 02/24] Enhance G2PW model input handling by introducing polyphonic context character support and updating the data preparation method to return additional query IDs. This improves the processing of polyphonic characters in sentences. --- GPT_SoVITS/text/g2pw/onnx_api.py | 37 ++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/GPT_SoVITS/text/g2pw/onnx_api.py b/GPT_SoVITS/text/g2pw/onnx_api.py index 5fcf1ae2..3c2b0169 100644 --- a/GPT_SoVITS/text/g2pw/onnx_api.py +++ b/GPT_SoVITS/text/g2pw/onnx_api.py @@ -187,6 +187,8 @@ class _G2PWBaseOnnxConverter: "y", "on", } + # 聚焦到多音字附近上下文,默认左右各16字;设为0表示关闭裁剪(整句)。 + self.polyphonic_context_chars = max(0, int(os.getenv("g2pw_polyphonic_context_chars", "16"))) def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str: tone = bopomofo[-1] @@ -209,7 +211,7 @@ class _G2PWBaseOnnxConverter: translated_sentences.append(translated_sent) sentences = translated_sentences - texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences) + texts, model_query_ids, result_query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences) if len(texts) == 0: return partial_results @@ -219,7 +221,7 @@ class _G2PWBaseOnnxConverter: char2phonemes=self.char2phonemes, chars=self.chars, texts=texts, - query_ids=query_ids, + query_ids=model_query_ids, use_mask=self.config.use_mask, window_size=None, char2id=self.char2id, @@ -238,22 +240,23 @@ class _G2PWBaseOnnxConverter: preds = [pred.split(" ")[1] for pred in preds] results = partial_results - for sent_id, query_id, pred in zip(sent_ids, query_ids, preds): + for sent_id, query_id, pred in zip(sent_ids, result_query_ids, preds): results[sent_id][query_id] = self.style_convert_func(pred) return results - def _prepare_data(self, sentences: List[str]) -> Tuple[List[str], List[int], List[int], List[List[str]]]: - texts, query_ids, sent_ids, partial_results = [], [], [], [] + def _prepare_data( + self, sentences: List[str] + ) -> Tuple[List[str], List[int], List[int], List[int], List[List[str]]]: + texts, model_query_ids, result_query_ids, sent_ids, partial_results = [], [], [], [], [] for sent_id, sent in enumerate(sentences): sent_s = tranditional_to_simplified(sent) pypinyin_result = pinyin(sent_s, neutral_tone_with_five=True, style=Style.TONE3) partial_result = [None] * len(sent) + polyphonic_indices: List[int] = [] for i, char in enumerate(sent): if char in self.polyphonic_chars_new: - texts.append(sent) - query_ids.append(i) - sent_ids.append(sent_id) + polyphonic_indices.append(i) elif char in self.monophonic_chars_dict: partial_result[i] = self.style_convert_func(self.monophonic_chars_dict[char]) elif char in self.char_bopomofo_dict: @@ -261,8 +264,24 @@ class _G2PWBaseOnnxConverter: else: partial_result[i] = pypinyin_result[i][0] + if polyphonic_indices: + if self.polyphonic_context_chars > 0: + left = max(0, polyphonic_indices[0] - self.polyphonic_context_chars) + right = min(len(sent), polyphonic_indices[-1] + self.polyphonic_context_chars + 1) + sent_for_predict = sent[left:right] + query_offset = left + else: + sent_for_predict = sent + query_offset = 0 + + for index in polyphonic_indices: + texts.append(sent_for_predict) + model_query_ids.append(index - query_offset) + result_query_ids.append(index) + sent_ids.append(sent_id) + partial_results.append(partial_result) - return texts, query_ids, sent_ids, partial_results + return texts, model_query_ids, result_query_ids, sent_ids, partial_results def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]: raise NotImplementedError From 30a4557d8db2ebbac91dc264a063cc7175172932 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Sun, 8 Mar 2026 23:08:27 +0800 Subject: [PATCH 03/24] Implement last inference statistics tracking in Text2SemanticDecoder and enhance TTS class with prompt semantic extraction. This includes methods for setting and retrieving inference stats, as well as improvements to audio processing and feature extraction in TTS. --- GPT_SoVITS/AR/models/t2s_model.py | 76 +++++++++++++++++++++++++++++-- GPT_SoVITS/TTS_infer_pack/TTS.py | 38 ++++++++++++++++ 2 files changed, 111 insertions(+), 3 deletions(-) diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index ac905f4b..f55b7508 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -351,6 +351,13 @@ class Text2SemanticDecoder(nn.Module): blocks.append(block) self.t2s_transformer = T2STransformer(self.num_layers, blocks) + self.last_infer_stats = {} + + def _set_last_infer_stats(self, stats): + self.last_infer_stats = stats + + def get_last_infer_stats(self): + return dict(self.last_infer_stats) def make_input_data(self, x, x_lens, y, y_lens, bert_feature): x = self.ar_text_embedding(x) @@ -593,7 +600,19 @@ class Text2SemanticDecoder(nn.Module): repetition_penalty: float = 1.35, **kwargs, ): + requested_enable_mask_free_fastpath = bool(kwargs.get("enable_mask_free_fastpath", True)) if prompts is None: + self._set_last_infer_stats( + { + "infer_mode": "batch_infer_prompt_free_fallback", + "requested_enable_mask_free_fastpath": requested_enable_mask_free_fastpath, + "batch_size": int(len(x)), + "prefill_after_mask_all_visible": None, + "fastpath_hit": False, + "generated_token_count": 0, + "generated_token_count_list": [], + } + ) print("Warning: Prompt free is not supported batch_infer! switch to naive_infer") return self.infer_panel_naive_batched( x, @@ -608,6 +627,7 @@ class Text2SemanticDecoder(nn.Module): ) max_len = kwargs.get("max_len", x_lens.max()) + enable_mask_free_fastpath = requested_enable_mask_free_fastpath x_list = [] for x_item, bert_item in zip(x, bert_feature): # max_len = max(max_len, x_item.shape[0], bert_item.shape[1]) @@ -698,17 +718,30 @@ class Text2SemanticDecoder(nn.Module): y_list = [None] * y.shape[0] batch_idx_map = list(range(y.shape[0])) idx_list = [None] * y.shape[0] + decode_attn_mask = attn_mask + prefill_after_mask_all_visible = None + fastpath_hit = False for idx in tqdm(range(1500)): if idx == 0: xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None) else: - xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask) + xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token( + xy_pos, k_cache, v_cache, decode_attn_mask + ) logits = self.ar_predict_layer(xy_dec[:, -1]) if idx == 0: attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False) + prefill_after_mask_all_visible = not attn_mask.any().item() + if enable_mask_free_fastpath and y.shape[0] == 1 and prefill_after_mask_all_visible: + decode_attn_mask = None + fastpath_hit = True + else: + decode_attn_mask = attn_mask else: - attn_mask = F.pad(attn_mask, (0, 1), value=False) + if decode_attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1), value=False) + decode_attn_mask = attn_mask if idx < 11: ###至少预测出10个token不然不给停止(0.4s) logits = logits[:, :-1] @@ -740,7 +773,9 @@ class Text2SemanticDecoder(nn.Module): if reserved_idx_of_batch_for_y is not None: # index = torch.LongTensor(batch_idx_map).to(y.device) y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y) - attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y) + if decode_attn_mask is not None: + attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y) + decode_attn_mask = attn_mask if k_cache is not None: for i in range(len(k_cache)): k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y) @@ -775,6 +810,18 @@ class Text2SemanticDecoder(nn.Module): if idx_list[i] is None: idx_list[i] = 1500 - 1 ###如果没有生成到EOS,就用最大长度代替 + self._set_last_infer_stats( + { + "infer_mode": "batch_infer", + "requested_enable_mask_free_fastpath": enable_mask_free_fastpath, + "batch_size": int(len(x)), + "prefill_after_mask_all_visible": prefill_after_mask_all_visible, + "fastpath_hit": fastpath_hit, + "generated_token_count": int(sum(idx_list)), + "generated_token_count_list": [int(item) for item in idx_list], + "max_len": int(max_len), + } + ) if ref_free: return y_list, [0] * x.shape[0] # print(idx_list) @@ -811,6 +858,17 @@ class Text2SemanticDecoder(nn.Module): y_list.append(y[0]) idx_list.append(idx) + self._set_last_infer_stats( + { + "infer_mode": "naive_batched", + "requested_enable_mask_free_fastpath": bool(kwargs.get("enable_mask_free_fastpath", True)), + "batch_size": int(len(x)), + "prefill_after_mask_all_visible": None, + "fastpath_hit": False, + "generated_token_count": int(sum(idx_list)), + "generated_token_count_list": [int(item) for item in idx_list], + } + ) return y_list, idx_list def infer_panel_naive( @@ -957,6 +1015,18 @@ class Text2SemanticDecoder(nn.Module): if not streaming_mode: + generated_token_count = max(int(y.shape[1] - prefix_len), 0) + self._set_last_infer_stats( + { + "infer_mode": "naive", + "requested_enable_mask_free_fastpath": bool(kwargs.get("enable_mask_free_fastpath", True)), + "batch_size": int(x.shape[0]), + "prefill_after_mask_all_visible": True if prompts is not None else None, + "fastpath_hit": True if prompts is not None else False, + "generated_token_count": generated_token_count, + "generated_token_count_list": [generated_token_count], + } + ) if ref_free: yield y, 0 yield y, idx diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 9c8344b0..e1efd973 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1227,6 +1227,9 @@ class TTS: ###### inference ###### t_34 = 0.0 t_45 = 0.0 + t2s_observe_batch_count = 0 + t2s_observe_fastpath_hits = 0 + t2s_observe_generated_tokens = 0 audio = [] is_first_package = True output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"] @@ -1280,6 +1283,29 @@ class TTS: ) t4 = time.perf_counter() t_34 += t4 - t3 + if hasattr(self.t2s_model.model, "get_last_infer_stats"): + t2s_stats = self.t2s_model.model.get_last_infer_stats() + if t2s_stats: + generated_token_count = int(t2s_stats.get("generated_token_count", 0)) + t2s_total_ms = (t4 - t3) * 1000.0 + avg_decode_ms_per_token = ( + t2s_total_ms / generated_token_count if generated_token_count > 0 else 0.0 + ) + t2s_observe_batch_count += 1 + t2s_observe_generated_tokens += generated_token_count + if bool(t2s_stats.get("fastpath_hit", False)): + t2s_observe_fastpath_hits += 1 + print( + "[t2s_observe] " + f"mode={t2s_stats.get('infer_mode')} " + f"batch_size={t2s_stats.get('batch_size')} " + f"tokens={generated_token_count} " + f"t2s_ms={t2s_total_ms:.3f} " + f"avg_decode_ms_per_token={avg_decode_ms_per_token:.3f} " + f"requested_fastpath={t2s_stats.get('requested_enable_mask_free_fastpath')} " + f"prefill_all_visible={t2s_stats.get('prefill_after_mask_all_visible')} " + f"fastpath_hit={t2s_stats.get('fastpath_hit')}" + ) batch_audio_fragment = [] @@ -1500,6 +1526,18 @@ class TTS: if not (return_fragment or streaming_mode): print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45)) + if t2s_observe_batch_count > 0: + request_avg_decode_ms_per_token = ( + (t_34 * 1000.0) / t2s_observe_generated_tokens if t2s_observe_generated_tokens > 0 else 0.0 + ) + print( + "[t2s_request_observe] " + f"batches={t2s_observe_batch_count} " + f"fastpath_hits={t2s_observe_fastpath_hits} " + f"generated_tokens={t2s_observe_generated_tokens} " + f"t2s_total_ms={t_34 * 1000.0:.3f} " + f"avg_decode_ms_per_token={request_avg_decode_ms_per_token:.3f}" + ) if len(audio) == 0: yield output_sr, np.zeros(int(output_sr), dtype=np.int16) return From dc37b0b9ef9f29b852fa37e6cd661369b60d6f86 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Mon, 9 Mar 2026 00:22:59 +0800 Subject: [PATCH 04/24] Add WebAPI documentation and implement TTS API with endpoints for text-to-speech inference, control commands, and model switching. Enhance TTS class with methods for extracting prompt semantics and reference audio specifications. Introduce a scheduler prototype for managing T2S requests. --- GPT_SoVITS/TTS_infer_pack/TTS.py | 190 +++- GPT_SoVITS/TTS_infer_pack/__init__.py | 12 +- GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py | 631 +++++++++++ api_v3.py | 1170 ++++++++++++++++++++ tools/t2s_scheduler_prototype.py | 180 +++ 5 files changed, 2147 insertions(+), 36 deletions(-) create mode 100644 GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py create mode 100644 api_v3.py create mode 100644 tools/t2s_scheduler_prototype.py diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index e1efd973..bd4953df 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -759,21 +759,35 @@ class TTS: self._set_ref_spec(ref_audio_path) self._set_ref_audio_path(ref_audio_path) - def _set_ref_audio_path(self, ref_audio_path): - self.prompt_cache["ref_audio_path"] = ref_audio_path + def extract_prompt_semantic(self, ref_wav_path: str): + zero_wav = np.zeros( + int(self.configs.sampling_rate * 0.3), + dtype=np.float16 if self.configs.is_half else np.float32, + ) + with torch.no_grad(): + wav16k, sr = librosa.load(ref_wav_path, sr=16000) + if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000: + raise OSError(i18n("参考音频在3~10秒范围外,请更换!")) + wav16k = torch.from_numpy(wav16k) + zero_wav_torch = torch.from_numpy(zero_wav) + wav16k = wav16k.to(self.configs.device) + zero_wav_torch = zero_wav_torch.to(self.configs.device) + if self.configs.is_half: + wav16k = wav16k.half() + zero_wav_torch = zero_wav_torch.half() - def _set_ref_spec(self, ref_audio_path): - spec_audio = self._get_ref_spec(ref_audio_path) - if self.prompt_cache["refer_spec"] in [[], None]: - self.prompt_cache["refer_spec"] = [spec_audio] - else: - self.prompt_cache["refer_spec"][0] = spec_audio + wav16k = torch.cat([wav16k, zero_wav_torch]) + hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose( + 1, 2 + ) # .float() + codes = self.vits_model.extract_latent(hubert_feature) - def _get_ref_spec(self, ref_audio_path): + prompt_semantic = codes[0, 0].to(self.configs.device) + return prompt_semantic + + def extract_ref_spec(self, ref_audio_path: str): raw_audio, raw_sr = torchaudio.load(ref_audio_path) raw_audio = raw_audio.to(self.configs.device).float() - self.prompt_cache["raw_audio"] = raw_audio - self.prompt_cache["raw_sr"] = raw_sr if raw_sr != self.configs.sampling_rate: audio = raw_audio.to(self.configs.device) @@ -804,33 +818,30 @@ class TTS: audio = audio.half() else: audio = None + return spec, audio, raw_audio, raw_sr + + def extract_text_features(self, text: str, language: str): + return self.text_preprocessor.segment_and_extract_feature_for_text(text, language, self.configs.version) + + def _set_ref_audio_path(self, ref_audio_path): + self.prompt_cache["ref_audio_path"] = ref_audio_path + + def _set_ref_spec(self, ref_audio_path): + spec_audio = self._get_ref_spec(ref_audio_path) + if self.prompt_cache["refer_spec"] in [[], None]: + self.prompt_cache["refer_spec"] = [spec_audio] + else: + self.prompt_cache["refer_spec"][0] = spec_audio + + def _get_ref_spec(self, ref_audio_path): + spec, audio, raw_audio, raw_sr = self.extract_ref_spec(ref_audio_path) + self.prompt_cache["raw_audio"] = raw_audio + self.prompt_cache["raw_sr"] = raw_sr return spec, audio def _set_prompt_semantic(self, ref_wav_path: str): - zero_wav = np.zeros( - int(self.configs.sampling_rate * 0.3), - dtype=np.float16 if self.configs.is_half else np.float32, - ) - with torch.no_grad(): - wav16k, sr = librosa.load(ref_wav_path, sr=16000) - if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000: - raise OSError(i18n("参考音频在3~10秒范围外,请更换!")) - wav16k = torch.from_numpy(wav16k) - zero_wav_torch = torch.from_numpy(zero_wav) - wav16k = wav16k.to(self.configs.device) - zero_wav_torch = zero_wav_torch.to(self.configs.device) - if self.configs.is_half: - wav16k = wav16k.half() - zero_wav_torch = zero_wav_torch.half() - - wav16k = torch.cat([wav16k, zero_wav_torch]) - hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose( - 1, 2 - ) # .float() - codes = self.vits_model.extract_latent(hubert_feature) - - prompt_semantic = codes[0, 0].to(self.configs.device) - self.prompt_cache["prompt_semantic"] = prompt_semantic + prompt_semantic = self.extract_prompt_semantic(ref_wav_path) + self.prompt_cache["prompt_semantic"] = prompt_semantic def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length: int = None): seq = sequences[0] @@ -1701,6 +1712,115 @@ class TTS: return audio + def using_vocoder_synthesis_request_local( + self, + semantic_tokens: torch.Tensor, + phones: torch.Tensor, + prompt_semantic: torch.Tensor, + prompt_phones: torch.Tensor, + refer_audio_spec: torch.Tensor, + raw_audio: torch.Tensor, + raw_sr: int, + speed: float = 1.0, + sample_steps: int = 32, + ): + prompt_semantic_tokens = prompt_semantic.unsqueeze(0).unsqueeze(0).to(self.configs.device) + prompt_phones = prompt_phones.unsqueeze(0).to(self.configs.device) + refer_audio_spec = refer_audio_spec.to(dtype=self.precision, device=self.configs.device) + + fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec) + ref_audio = raw_audio.to(self.configs.device).float() + if ref_audio.shape[0] == 2: + ref_audio = ref_audio.mean(0).unsqueeze(0) + + tgt_sr = 24000 if self.configs.version == "v3" else 32000 + if raw_sr != tgt_sr: + ref_audio = resample(ref_audio, raw_sr, tgt_sr, self.configs.device) + + mel_spec_fn = mel_fn if self.configs.version == "v3" else mel_fn_v4 + mel2 = mel_spec_fn(ref_audio) + mel2 = norm_spec(mel2) + T_min = min(mel2.shape[2], fea_ref.shape[2]) + mel2 = mel2[:, :, :T_min] + fea_ref = fea_ref[:, :, :T_min] + T_ref = self.vocoder_configs["T_ref"] + T_chunk = self.vocoder_configs["T_chunk"] + if T_min > T_ref: + mel2 = mel2[:, :, -T_ref:] + fea_ref = fea_ref[:, :, -T_ref:] + T_min = T_ref + chunk_len = T_chunk - T_min + + mel2 = mel2.to(self.precision) + fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed) + + cfm_resss = [] + idx = 0 + while 1: + fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len] + if fea_todo_chunk.shape[-1] == 0: + break + idx += chunk_len + fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1) + + cfm_res = self.vits_model.cfm.inference( + fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0 + ) + cfm_res = cfm_res[:, :, mel2.shape[2] :] + + mel2 = cfm_res[:, :, -T_min:] + fea_ref = fea_todo_chunk[:, :, -T_min:] + + cfm_resss.append(cfm_res) + cfm_res = torch.cat(cfm_resss, 2) + cfm_res = denorm_spec(cfm_res) + + with torch.inference_mode(): + wav_gen = self.vocoder(cfm_res) + audio = wav_gen[0][0] + + return audio + + def synthesize_audio_request_local( + self, + semantic_tokens: torch.Tensor, + phones: torch.Tensor, + prompt_semantic: torch.Tensor, + prompt_phones: torch.Tensor, + refer_spec: tuple, + raw_audio: torch.Tensor, + raw_sr: int, + speed: float = 1.0, + sample_steps: int = 32, + ): + refer_audio_spec, audio_tensor = refer_spec + if not self.configs.use_vocoder: + refer_audio_spec_list = [refer_audio_spec.to(dtype=self.precision, device=self.configs.device)] + sv_emb = None + if self.is_v2pro: + if audio_tensor is None: + raise ValueError(i18n("v2Pro request-local synthesis 缺少 16k 参考音频")) + sv_emb = self.sv_model.compute_embedding3(audio_tensor).to(self.configs.device) + return self.vits_model.decode( + semantic_tokens, + phones, + refer_audio_spec_list, + speed=speed, + sv_emb=sv_emb, + ).detach()[0, 0, :] + + return self.using_vocoder_synthesis_request_local( + semantic_tokens=semantic_tokens, + phones=phones, + prompt_semantic=prompt_semantic, + prompt_phones=prompt_phones, + refer_audio_spec=refer_audio_spec, + raw_audio=raw_audio, + raw_sr=raw_sr, + speed=speed, + sample_steps=sample_steps, + ) + def using_vocoder_synthesis_batched_infer( self, idx_list: List[int], diff --git a/GPT_SoVITS/TTS_infer_pack/__init__.py b/GPT_SoVITS/TTS_infer_pack/__init__.py index 8579a632..09a257b2 100644 --- a/GPT_SoVITS/TTS_infer_pack/__init__.py +++ b/GPT_SoVITS/TTS_infer_pack/__init__.py @@ -1 +1,11 @@ -from . import TTS, text_segmentation_method +from __future__ import annotations + +import importlib + +__all__ = ["TTS", "TextPreprocessor", "text_segmentation_method", "t2s_scheduler"] + + +def __getattr__(name: str): + if name in __all__: + return importlib.import_module(f"{__name__}.{name}") + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py new file mode 100644 index 00000000..e94a72c7 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -0,0 +1,631 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +import time +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import torch +import torch.nn.functional as F + +from AR.models.utils import make_pad_mask_left, sample + + +def _sync_device(device: Any) -> None: + try: + device_str = str(device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + +@dataclass +class SchedulerRequestSpec: + request_id: str + ref_audio_path: Path + prompt_text: str + prompt_lang: str + text: str + text_lang: str + top_k: int + top_p: float + temperature: float + repetition_penalty: float + early_stop_num: int + ready_step: int = 0 + + +@dataclass +class T2SRequestState: + request_id: str + ref_audio_path: Path + prompt_text: str + prompt_lang: str + text: str + text_lang: str + norm_prompt_text: str + norm_text: str + phones: torch.LongTensor + prompt_phones: torch.LongTensor + all_phones: torch.LongTensor + all_bert_features: torch.Tensor + prompt_semantic: torch.LongTensor + refer_spec: Tuple[torch.Tensor, Optional[torch.Tensor]] + raw_audio: torch.Tensor + raw_sr: int + top_k: int + top_p: float + temperature: float + repetition_penalty: float + early_stop_num: int + ready_step: int + prepare_profile: Dict[str, float] + + +@dataclass +class T2SRunningRequest: + state: T2SRequestState + y_sequence: torch.LongTensor + prefix_len: int + decode_attn_mask: torch.Tensor + k_cache: List[torch.Tensor] + v_cache: List[torch.Tensor] + step_idx: int + + +@dataclass +class T2SFinishedItem: + request_id: str + semantic_tokens: torch.LongTensor + finish_idx: int + finish_reason: str + + +@dataclass +class T2SActiveBatch: + request_ids: List[str] + states: List[T2SRequestState] + x: torch.Tensor + x_lens: torch.LongTensor + y_sequences: List[torch.LongTensor] + prefix_lens: torch.LongTensor + xy_pos: torch.Tensor + prefill_attn_mask: torch.Tensor + decode_attn_mask: Optional[torch.Tensor] + k_cache: Optional[List[torch.Tensor]] + v_cache: Optional[List[torch.Tensor]] + step_idx: int + prefill_done: bool + + +def normalize_sentence(text: str, language: str) -> str: + text = text.strip("\n").strip() + if not text: + return text + if text[-1] not in {",", ".", "?", "!", ",", "。", "?", "!", "…", ";", ";", ":"}: + text += "。" if language != "en" else "." + return text + + +def prepare_request_state( + tts: Any, + spec: SchedulerRequestSpec, +) -> T2SRequestState: + device = tts.configs.device + prepare_start = time.perf_counter() + _sync_device(device) + prepare_sync_start = time.perf_counter() + prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) + text = spec.text.strip("\n") + + _sync_device(device) + prompt_text_features_start = time.perf_counter() + prompt_phones, prompt_bert_features, prompt_norm_text = tts.extract_text_features(prompt_text, spec.prompt_lang) + _sync_device(device) + prompt_text_features_ms = (time.perf_counter() - prompt_text_features_start) * 1000.0 + + _sync_device(device) + text_features_start = time.perf_counter() + phones, bert_features, norm_text = tts.extract_text_features(text, spec.text_lang) + _sync_device(device) + text_features_ms = (time.perf_counter() - text_features_start) * 1000.0 + if phones is None: + raise ValueError(f"{spec.request_id} text preprocessing returned no phones") + + _sync_device(device) + prompt_semantic_start = time.perf_counter() + prompt_semantic = tts.extract_prompt_semantic(str(spec.ref_audio_path)).long() + _sync_device(device) + prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0 + + _sync_device(device) + ref_spec_start = time.perf_counter() + spec_audio, audio_16k, raw_audio, raw_sr = tts.extract_ref_spec(str(spec.ref_audio_path)) + _sync_device(device) + ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0 + + _sync_device(device) + tensorize_start = time.perf_counter() + phones_tensor = torch.LongTensor(phones).to(tts.configs.device) + prompt_phones_tensor = torch.LongTensor(prompt_phones).to(tts.configs.device) + all_phones = torch.LongTensor(prompt_phones + phones).to(tts.configs.device) + all_bert_features = torch.cat([prompt_bert_features, bert_features], dim=1).to( + dtype=tts.precision, device=tts.configs.device + ) + _sync_device(device) + tensorize_ms = (time.perf_counter() - tensorize_start) * 1000.0 + + _sync_device(device) + prepare_profile = { + "prompt_text_features_ms": prompt_text_features_ms, + "text_features_ms": text_features_ms, + "prompt_semantic_ms": prompt_semantic_ms, + "ref_spec_ms": ref_spec_ms, + "tensorize_ms": tensorize_ms, + "total_ms": (time.perf_counter() - prepare_sync_start) * 1000.0, + "wall_total_ms": (time.perf_counter() - prepare_start) * 1000.0, + } + return T2SRequestState( + request_id=spec.request_id, + ref_audio_path=spec.ref_audio_path, + prompt_text=prompt_text, + prompt_lang=spec.prompt_lang, + text=text, + text_lang=spec.text_lang, + norm_prompt_text=prompt_norm_text, + norm_text=norm_text, + phones=phones_tensor, + prompt_phones=prompt_phones_tensor, + all_phones=all_phones, + all_bert_features=all_bert_features, + prompt_semantic=prompt_semantic, + refer_spec=(spec_audio, audio_16k), + raw_audio=raw_audio, + raw_sr=int(raw_sr), + top_k=spec.top_k, + top_p=spec.top_p, + temperature=spec.temperature, + repetition_penalty=spec.repetition_penalty, + early_stop_num=spec.early_stop_num, + ready_step=spec.ready_step, + prepare_profile=prepare_profile, + ) + + +def _left_pad_hidden(hidden: torch.Tensor, target_len: int) -> torch.Tensor: + if hidden.shape[0] >= target_len: + return hidden + return F.pad(hidden, (0, 0, target_len - hidden.shape[0], 0), value=0) + + +def _ensure_audio_pe(model: Any, max_position: int, dtype: torch.dtype, device: torch.device) -> None: + required_len = max_position + 1 + if model.ar_audio_position.pe is not None and model.ar_audio_position.pe.size(1) >= required_len: + if model.ar_audio_position.pe.dtype != dtype or model.ar_audio_position.pe.device != device: + model.ar_audio_position.pe = model.ar_audio_position.pe.to(dtype=dtype, device=device) + return + model.ar_audio_position.extend_pe( + torch.zeros(1, required_len, model.ar_audio_position.embedding_dim, device=device, dtype=dtype) + ) + + +def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SActiveBatch: + x_items: List[torch.Tensor] = [] + y_pos_items: List[torch.Tensor] = [] + x_lens: List[int] = [] + prefix_lens: List[int] = [] + y_sequences: List[torch.LongTensor] = [] + + for state in states: + text_emb = model.ar_text_embedding(state.all_phones.unsqueeze(0)) + bert_proj = model.bert_proj(state.all_bert_features.transpose(0, 1).unsqueeze(0)) + x_pos = model.ar_text_position(text_emb + bert_proj).squeeze(0) + y_emb = model.ar_audio_embedding(state.prompt_semantic.unsqueeze(0)) + y_pos = model.ar_audio_position(y_emb).squeeze(0) + x_items.append(x_pos) + y_pos_items.append(y_pos) + x_lens.append(x_pos.shape[0]) + prefix_lens.append(y_pos.shape[0]) + y_sequences.append(state.prompt_semantic.clone()) + + max_x_len = max(x_lens) + max_prefix_len = max(prefix_lens) + x_batch = torch.stack([_left_pad_hidden(item, max_x_len) for item in x_items], dim=0) + y_pos_batch = torch.stack([_left_pad_hidden(item, max_prefix_len) for item in y_pos_items], dim=0) + xy_pos = torch.cat([x_batch, y_pos_batch], dim=1) + + device = x_batch.device + x_lens_tensor = torch.LongTensor(x_lens).to(device) + prefix_lens_tensor = torch.LongTensor(prefix_lens).to(device) + batch_size = len(states) + src_len = max_x_len + max_prefix_len + + x_padding_mask = make_pad_mask_left(x_lens_tensor, max_x_len) + y_padding_mask = make_pad_mask_left(prefix_lens_tensor, max_prefix_len) + padding_mask = torch.cat([x_padding_mask, y_padding_mask], dim=1) + x_mask = F.pad(torch.zeros(max_x_len, max_x_len, dtype=torch.bool, device=device), (0, max_prefix_len), value=True) + y_mask = F.pad( + torch.triu(torch.ones(max_prefix_len, max_prefix_len, dtype=torch.bool, device=device), diagonal=1), + (max_x_len, 0), + value=False, + ) + causal_mask = torch.cat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(batch_size, 1, 1) + padding_mask = padding_mask.view(batch_size, 1, src_len).repeat(1, src_len, 1) + attn_mask = causal_mask.logical_or(padding_mask) + attn_mask = attn_mask.unsqueeze(1).expand(-1, model.num_head, -1, -1).bool() + + return T2SActiveBatch( + request_ids=[state.request_id for state in states], + states=list(states), + x=x_batch, + x_lens=x_lens_tensor, + y_sequences=y_sequences, + prefix_lens=prefix_lens_tensor, + xy_pos=xy_pos, + prefill_attn_mask=attn_mask, + decode_attn_mask=None, + k_cache=None, + v_cache=None, + step_idx=0, + prefill_done=False, + ) + + +def build_next_xy_pos(model: Any, y_sequences: Sequence[torch.LongTensor]) -> torch.Tensor: + last_tokens = torch.stack([seq[-1:] for seq in y_sequences], dim=0) + y_emb = model.ar_audio_embedding(last_tokens) + position_ids = torch.LongTensor([int(seq.shape[0] - 1) for seq in y_sequences]).to(y_emb.device) + _ensure_audio_pe(model, int(position_ids.max().item()), y_emb.dtype, y_emb.device) + pos_emb = model.ar_audio_position.pe[0].index_select(0, position_ids).unsqueeze(1) + return y_emb * model.ar_audio_position.x_scale + model.ar_audio_position.alpha * pos_emb.to( + dtype=y_emb.dtype, device=y_emb.device + ) + + +def _sample_per_request( + model: Any, + active_batch: T2SActiveBatch, + logits: torch.Tensor, + max_steps: int, +) -> Tuple[List[T2SFinishedItem], List[int], List[torch.LongTensor]]: + finished_items: List[T2SFinishedItem] = [] + keep_indices: List[int] = [] + updated_sequences: List[torch.LongTensor] = [] + + step_idx = active_batch.step_idx + for batch_index, state in enumerate(active_batch.states): + logits_i = logits[batch_index : batch_index + 1].clone() + current_history = active_batch.y_sequences[batch_index] + sampled = sample( + logits_i, + current_history.unsqueeze(0), + top_k=state.top_k, + top_p=state.top_p, + repetition_penalty=state.repetition_penalty, + temperature=state.temperature, + )[0] + sampled_token = int(sampled[0, 0].item()) + argmax_token = int(torch.argmax(logits[batch_index], dim=-1).item()) + new_history = torch.cat([current_history, sampled.view(-1)], dim=0) + + finish_reason: Optional[str] = None + if state.early_stop_num != -1 and (new_history.shape[0] - int(active_batch.prefix_lens[batch_index].item())) > state.early_stop_num: + finish_reason = "early_stop" + elif step_idx + 1 >= max_steps: + finish_reason = "max_step" + elif sampled_token == model.EOS: + finish_reason = "eos_sample" + elif argmax_token == model.EOS: + finish_reason = "eos_argmax" + + if finish_reason is not None: + finished_items.append( + T2SFinishedItem( + request_id=state.request_id, + semantic_tokens=new_history[:-1].clone(), + finish_idx=step_idx, + finish_reason=finish_reason, + ) + ) + else: + keep_indices.append(batch_index) + updated_sequences.append(new_history) + + return finished_items, keep_indices, updated_sequences + + +def decode_one_step( + model: Any, + active_batch: T2SActiveBatch, + max_steps: int, +) -> Tuple[Optional[T2SActiveBatch], List[T2SFinishedItem]]: + if not active_batch.prefill_done: + xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.process_prompt( + active_batch.xy_pos, active_batch.prefill_attn_mask, None + ) + active_batch.decode_attn_mask = F.pad( + active_batch.prefill_attn_mask[:, :, -1].unsqueeze(-2), + (0, 1), + value=False, + ) + active_batch.prefill_done = True + else: + xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.decode_next_token( + active_batch.xy_pos, + active_batch.k_cache, + active_batch.v_cache, + active_batch.decode_attn_mask, + ) + if active_batch.decode_attn_mask is not None: + active_batch.decode_attn_mask = F.pad(active_batch.decode_attn_mask, (0, 1), value=False) + + logits = model.ar_predict_layer(xy_dec[:, -1]) + if active_batch.step_idx < 11: + logits = logits[:, :-1] + + finished_items, keep_indices, updated_sequences = _sample_per_request(model, active_batch, logits, max_steps=max_steps) + if len(keep_indices) == 0: + return None, finished_items + + device = logits.device + keep_tensor = torch.LongTensor(keep_indices).to(device) + active_batch.request_ids = [active_batch.request_ids[i] for i in keep_indices] + active_batch.states = [active_batch.states[i] for i in keep_indices] + active_batch.y_sequences = updated_sequences + active_batch.prefix_lens = torch.index_select(active_batch.prefix_lens, dim=0, index=keep_tensor) + + if active_batch.decode_attn_mask is not None: + active_batch.decode_attn_mask = torch.index_select(active_batch.decode_attn_mask, dim=0, index=keep_tensor) + if active_batch.k_cache is not None and active_batch.v_cache is not None: + for cache_index in range(len(active_batch.k_cache)): + active_batch.k_cache[cache_index] = torch.index_select(active_batch.k_cache[cache_index], dim=0, index=keep_tensor) + active_batch.v_cache[cache_index] = torch.index_select(active_batch.v_cache[cache_index], dim=0, index=keep_tensor) + + active_batch.xy_pos = build_next_xy_pos(model, active_batch.y_sequences) + active_batch.step_idx += 1 + return active_batch, finished_items + + +def run_scheduler_batch( + model: Any, + states: Sequence[T2SRequestState], + max_steps: int, +) -> List[T2SFinishedItem]: + return run_scheduler_continuous(model, states, max_steps=max_steps) + + +def _pad_cache_left(cache: torch.Tensor, target_len: int) -> torch.Tensor: + pad_len = target_len - cache.shape[1] + if pad_len <= 0: + return cache + return F.pad(cache, (0, 0, pad_len, 0), value=0) + + +def _pad_decode_mask_left(mask: torch.Tensor, target_len: int) -> torch.Tensor: + pad_len = target_len - mask.shape[-1] + if pad_len <= 0: + return mask + return F.pad(mask, (pad_len, 0), value=True) + + +def run_prefill_step( + model: Any, + states: Sequence[T2SRequestState], + max_steps: int, +) -> Tuple[List[T2SRunningRequest], List[T2SFinishedItem]]: + if not states: + return [], [] + + active_batch = build_prefill_batch(model, states) + xy_dec, k_cache, v_cache = model.t2s_transformer.process_prompt(active_batch.xy_pos, active_batch.prefill_attn_mask, None) + decode_attn_mask = F.pad( + active_batch.prefill_attn_mask[:, :, -1].unsqueeze(-2), + (0, 1), + value=False, + ) + logits = model.ar_predict_layer(xy_dec[:, -1]) + + running_requests: List[T2SRunningRequest] = [] + finished_items: List[T2SFinishedItem] = [] + + for batch_index, state in enumerate(states): + logits_i = logits[batch_index : batch_index + 1].clone() + if 0 < 11: + logits_i = logits_i[:, :-1] + current_history = active_batch.y_sequences[batch_index] + sampled = sample( + logits_i, + current_history.unsqueeze(0), + top_k=state.top_k, + top_p=state.top_p, + repetition_penalty=state.repetition_penalty, + temperature=state.temperature, + )[0] + sampled_token = int(sampled[0, 0].item()) + argmax_token = int(torch.argmax(logits_i[0], dim=-1).item()) + new_history = torch.cat([current_history, sampled.view(-1)], dim=0) + prefix_len = int(active_batch.prefix_lens[batch_index].item()) + + finish_reason: Optional[str] = None + if state.early_stop_num != -1 and (new_history.shape[0] - prefix_len) > state.early_stop_num: + finish_reason = "early_stop" + elif 1 >= max_steps: + finish_reason = "max_step" + elif sampled_token == model.EOS: + finish_reason = "eos_sample" + elif argmax_token == model.EOS: + finish_reason = "eos_argmax" + + if finish_reason is not None: + finished_items.append( + T2SFinishedItem( + request_id=state.request_id, + semantic_tokens=new_history[:-1].clone(), + finish_idx=0, + finish_reason=finish_reason, + ) + ) + continue + + real_kv_len = int(active_batch.x_lens[batch_index].item()) + prefix_len + request_k_cache = [layer[batch_index : batch_index + 1, -real_kv_len:, :].clone() for layer in k_cache] + request_v_cache = [layer[batch_index : batch_index + 1, -real_kv_len:, :].clone() for layer in v_cache] + + running_requests.append( + T2SRunningRequest( + state=state, + y_sequence=new_history, + prefix_len=prefix_len, + decode_attn_mask=decode_attn_mask[batch_index : batch_index + 1].clone(), + k_cache=request_k_cache, + v_cache=request_v_cache, + step_idx=1, + ) + ) + + return running_requests, finished_items + + +def _build_decode_batch_from_running( + model: Any, + running_requests: Sequence[T2SRunningRequest], +) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor], torch.Tensor]: + xy_pos = build_next_xy_pos(model, [item.y_sequence for item in running_requests]) + max_kv_len = max(item.k_cache[0].shape[1] for item in running_requests) + max_mask_len = max(item.decode_attn_mask.shape[-1] for item in running_requests) + num_layers = len(running_requests[0].k_cache) + + batched_k_cache: List[torch.Tensor] = [] + batched_v_cache: List[torch.Tensor] = [] + for layer_index in range(num_layers): + batched_k_cache.append( + torch.cat([_pad_cache_left(item.k_cache[layer_index], max_kv_len) for item in running_requests], dim=0) + ) + batched_v_cache.append( + torch.cat([_pad_cache_left(item.v_cache[layer_index], max_kv_len) for item in running_requests], dim=0) + ) + + batched_decode_attn_mask = torch.cat( + [_pad_decode_mask_left(item.decode_attn_mask, max_mask_len) for item in running_requests], + dim=0, + ) + return xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask + + +def run_decode_step_for_running( + model: Any, + running_requests: Sequence[T2SRunningRequest], + max_steps: int, +) -> Tuple[List[T2SRunningRequest], List[T2SFinishedItem]]: + if not running_requests: + return [], [] + + xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask = _build_decode_batch_from_running( + model, running_requests + ) + xy_dec, next_k_cache, next_v_cache = model.t2s_transformer.decode_next_token( + xy_pos, + batched_k_cache, + batched_v_cache, + batched_decode_attn_mask, + ) + logits = model.ar_predict_layer(xy_dec[:, -1]) + + next_running: List[T2SRunningRequest] = [] + finished_items: List[T2SFinishedItem] = [] + + for batch_index, running_request in enumerate(running_requests): + current_idx = running_request.step_idx + logits_i = logits[batch_index : batch_index + 1].clone() + if current_idx < 11: + logits_i = logits_i[:, :-1] + sampled = sample( + logits_i, + running_request.y_sequence.unsqueeze(0), + top_k=running_request.state.top_k, + top_p=running_request.state.top_p, + repetition_penalty=running_request.state.repetition_penalty, + temperature=running_request.state.temperature, + )[0] + sampled_token = int(sampled[0, 0].item()) + argmax_token = int(torch.argmax(logits_i[0], dim=-1).item()) + new_history = torch.cat([running_request.y_sequence, sampled.view(-1)], dim=0) + + finish_reason: Optional[str] = None + if running_request.state.early_stop_num != -1 and (new_history.shape[0] - running_request.prefix_len) > running_request.state.early_stop_num: + finish_reason = "early_stop" + elif current_idx + 1 >= max_steps: + finish_reason = "max_step" + elif sampled_token == model.EOS: + finish_reason = "eos_sample" + elif argmax_token == model.EOS: + finish_reason = "eos_argmax" + + if finish_reason is not None: + finished_items.append( + T2SFinishedItem( + request_id=running_request.state.request_id, + semantic_tokens=new_history[:-1].clone(), + finish_idx=current_idx, + finish_reason=finish_reason, + ) + ) + continue + + real_next_kv_len = running_request.k_cache[0].shape[1] + 1 + request_k_cache = [layer[batch_index : batch_index + 1, -real_next_kv_len:, :].clone() for layer in next_k_cache] + request_v_cache = [layer[batch_index : batch_index + 1, -real_next_kv_len:, :].clone() for layer in next_v_cache] + next_running.append( + T2SRunningRequest( + state=running_request.state, + y_sequence=new_history, + prefix_len=running_request.prefix_len, + decode_attn_mask=F.pad(running_request.decode_attn_mask, (0, 1), value=False), + k_cache=request_k_cache, + v_cache=request_v_cache, + step_idx=current_idx + 1, + ) + ) + + return next_running, finished_items + + +def run_scheduler_continuous( + model: Any, + states: Sequence[T2SRequestState], + max_steps: int, +) -> List[T2SFinishedItem]: + pending = sorted(states, key=lambda item: (item.ready_step, item.request_id)) + running_requests: List[T2SRunningRequest] = [] + finished: List[T2SFinishedItem] = [] + current_tick = 0 + + while pending or running_requests: + admitted: List[T2SRequestState] = [] + while pending and pending[0].ready_step <= current_tick: + admitted.append(pending.pop(0)) + + admitted_running, admitted_finished = run_prefill_step(model, admitted, max_steps=max_steps) + finished.extend(admitted_finished) + + if running_requests: + running_requests, step_finished = run_decode_step_for_running( + model, + running_requests, + max_steps=max_steps, + ) + finished.extend(step_finished) + + running_requests.extend(admitted_running) + + if not running_requests and pending: + current_tick = max(current_tick + 1, pending[0].ready_step) + continue + + current_tick += 1 + + finished.sort(key=lambda item: item.request_id) + return finished diff --git a/api_v3.py b/api_v3.py new file mode 100644 index 00000000..9d250119 --- /dev/null +++ b/api_v3.py @@ -0,0 +1,1170 @@ +""" +# WebAPI文档 + +` python api_v2.py -a 127.0.0.1 -p 9880 -c GPT_SoVITS/configs/tts_infer.yaml ` + +## 执行参数: + `-a` - `绑定地址, 默认"127.0.0.1"` + `-p` - `绑定端口, 默认9880` + `-c` - `TTS配置文件路径, 默认"GPT_SoVITS/configs/tts_infer.yaml"` + +## 调用: + +### 推理 + +endpoint: `/tts` +GET: +``` +http://127.0.0.1:9880/tts?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_lang=zh&ref_audio_path=archive_jingyuan_1.wav&prompt_lang=zh&prompt_text=我是「罗浮」云骑将军景元。不必拘谨,「将军」只是一时的身份,你称呼我景元便可&text_split_method=cut5&batch_size=1&media_type=wav&streaming_mode=true +``` + +POST: +```json +{ + "text": "", # str.(required) text to be synthesized + "text_lang: "", # str.(required) language of the text to be synthesized + "ref_audio_path": "", # str.(required) reference audio path + "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion + "prompt_text": "", # str.(optional) prompt text for the reference audio + "prompt_lang": "", # str.(required) language of the prompt text for the reference audio + "top_k": 15, # int. top k sampling + "top_p": 1, # float. top p sampling + "temperature": 1, # float. temperature for sampling + "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details. + "batch_size": 1, # int. batch size for inference + "batch_threshold": 0.75, # float. threshold for batch splitting. + "split_bucket": True, # bool. whether to split the batch into multiple buckets. + "speed_factor":1.0, # float. control the speed of the synthesized audio. + "fragment_interval":0.3, # float. to control the interval of the audio fragment. + "seed": -1, # int. random seed for reproducibility. + "parallel_infer": True, # bool. whether to use parallel inference. + "repetition_penalty": 1.35, # float. repetition penalty for T2S model. + "sample_steps": 32, # int. number of sampling steps for VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed ) + "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. + "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) +} +``` + +RESP: +成功: 直接返回 wav 音频流, http code 200 +失败: 返回包含错误信息的 json, http code 400 + +### 命令控制 + +endpoint: `/control` + +command: +"restart": 重新运行 +"exit": 结束运行 + +GET: +``` +http://127.0.0.1:9880/control?command=restart +``` +POST: +```json +{ + "command": "restart" +} +``` + +RESP: 无 + + +### 切换GPT模型 + +endpoint: `/set_gpt_weights` + +GET: +``` +http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt +``` +RESP: +成功: 返回"success", http code 200 +失败: 返回包含错误信息的 json, http code 400 + + +### 切换Sovits模型 + +endpoint: `/set_sovits_weights` + +GET: +``` +http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/s2G488k.pth +``` + +RESP: +成功: 返回"success", http code 200 +失败: 返回包含错误信息的 json, http code 400 + +""" + +import asyncio +import os +import sys +import time +import traceback +import uuid +from dataclasses import dataclass +from pathlib import Path +from typing import Generator, List, Union + +now_dir = os.getcwd() +sys.path.append(now_dir) +sys.path.append("%s/GPT_SoVITS" % (now_dir)) + +import argparse +import subprocess +import wave +import signal +import numpy as np +import soundfile as sf +import torch +from fastapi import FastAPI, Response +from fastapi.responses import StreamingResponse, JSONResponse +import uvicorn +from io import BytesIO +from tools.i18n.i18n import I18nAuto +from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( + SchedulerRequestSpec, + T2SFinishedItem, + T2SRunningRequest, + T2SRequestState, + prepare_request_state, + run_decode_step_for_running, + run_prefill_step, + run_scheduler_continuous, +) +from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names +from pydantic import BaseModel +import threading + +# print(sys.path) +i18n = I18nAuto() +cut_method_names = get_cut_method_names() + +parser = argparse.ArgumentParser(description="GPT-SoVITS api") +parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径") +parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1") +parser.add_argument("-p", "--port", type=int, default="9880", help="default: 9880") +args = parser.parse_args() +config_path = args.tts_config +# device = args.device +port = args.port +host = args.bind_addr +argv = sys.argv + +if config_path in [None, ""]: + config_path = "GPT-SoVITS/configs/tts_infer.yaml" + +tts_config = TTS_Config(config_path) +print(tts_config) +tts_pipeline = TTS(tts_config) + +APP = FastAPI() + + +class TTS_Request(BaseModel): + text: str = None + text_lang: str = None + ref_audio_path: str = None + aux_ref_audio_paths: list = None + prompt_lang: str = None + prompt_text: str = "" + top_k: int = 15 + top_p: float = 1 + temperature: float = 1 + text_split_method: str = "cut5" + batch_size: int = 1 + batch_threshold: float = 0.75 + split_bucket: bool = True + speed_factor: float = 1.0 + fragment_interval: float = 0.3 + seed: int = -1 + media_type: str = "wav" + streaming_mode: Union[bool, int] = False + parallel_infer: bool = True + repetition_penalty: float = 1.35 + sample_steps: int = 32 + super_sampling: bool = False + overlap_length: int = 2 + min_chunk_length: int = 16 + + +class Scheduler_Debug_Request_Item(BaseModel): + request_id: str | None = None + text: str + text_lang: str + ref_audio_path: str + prompt_lang: str + prompt_text: str = "" + top_k: int = 15 + top_p: float = 1 + temperature: float = 1 + repetition_penalty: float = 1.35 + early_stop_num: int = -1 + ready_step: int = 0 + + +class Scheduler_Debug_Request(BaseModel): + requests: List[Scheduler_Debug_Request_Item] + max_steps: int = 1500 + seed: int = -1 + + +class Scheduler_Submit_Request(BaseModel): + request_id: str | None = None + text: str + text_lang: str + ref_audio_path: str + prompt_lang: str + prompt_text: str = "" + top_k: int = 15 + top_p: float = 1 + temperature: float = 1 + repetition_penalty: float = 1.35 + early_stop_num: int = -1 + speed_factor: float = 1.0 + sample_steps: int = 32 + media_type: str = "wav" + timeout_sec: float = 30.0 + + +@dataclass +class SchedulerPendingJob: + request_id: str + state: T2SRequestState + done_event: threading.Event + enqueue_time: float + speed_factor: float + sample_steps: int + media_type: str + prepare_ms: float = 0.0 + prepare_wall_ms: float = 0.0 + first_schedule_time: float | None = None + prefill_ms: float = 0.0 + decode_ms: float = 0.0 + synth_ms: float = 0.0 + pack_ms: float = 0.0 + decode_steps: int = 0 + result: dict | None = None + sample_rate: int | None = None + audio_data: np.ndarray | None = None + error: str | None = None + + +class SchedulerDebugWorker: + def __init__(self, tts: TTS, max_steps: int = 1500, micro_batch_wait_ms: int = 5): + self.tts = tts + self.max_steps = max_steps + self.micro_batch_wait_s = micro_batch_wait_ms / 1000.0 + self.prepare_lock = threading.Lock() + self.condition = threading.Condition() + self.pending_jobs: List[SchedulerPendingJob] = [] + self.running_requests: List[T2SRunningRequest] = [] + self.job_map: dict[str, SchedulerPendingJob] = {} + self.total_finished = 0 + self.total_submitted = 0 + self.worker_thread = threading.Thread(target=self._run_loop, name="t2s-scheduler-debug-worker", daemon=True) + self.worker_thread.start() + + def _sync_device(self) -> None: + try: + device_str = str(self.tts.configs.device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(self.tts.configs.device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + def prepare_state(self, spec: SchedulerRequestSpec) -> T2SRequestState: + with self.prepare_lock: + return prepare_request_state(self.tts, spec) + + def submit( + self, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_ms: float, + prepare_wall_ms: float, + ) -> SchedulerPendingJob: + job = SchedulerPendingJob( + request_id=state.request_id, + state=state, + done_event=threading.Event(), + enqueue_time=time.perf_counter(), + speed_factor=float(speed_factor), + sample_steps=int(sample_steps), + media_type=media_type, + prepare_ms=float(prepare_ms), + prepare_wall_ms=float(prepare_wall_ms), + ) + with self.condition: + self.pending_jobs.append(job) + self.job_map[job.request_id] = job + self.total_submitted += 1 + self.condition.notify_all() + return job + + def _mark_prefill_started(self, jobs: List[SchedulerPendingJob], started_at: float) -> None: + with self.condition: + for job in jobs: + tracked_job = self.job_map.get(job.request_id) + if tracked_job is not None and tracked_job.first_schedule_time is None: + tracked_job.first_schedule_time = started_at + + def _add_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None: + elapsed_ms = elapsed_s * 1000.0 + with self.condition: + for job in jobs: + tracked_job = self.job_map.get(job.request_id) + if tracked_job is not None: + tracked_job.prefill_ms += elapsed_ms + + def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + elapsed_ms = elapsed_s * 1000.0 + with self.condition: + for request_id in request_ids: + job = self.job_map.get(request_id) + if job is not None: + job.decode_ms += elapsed_ms + job.decode_steps += 1 + + def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]: + semantic_tokens = item.semantic_tokens.unsqueeze(0).unsqueeze(0).to(self.tts.configs.device) + phones = job.state.phones.unsqueeze(0).to(self.tts.configs.device) + audio_fragment = self.tts.synthesize_audio_request_local( + semantic_tokens=semantic_tokens, + phones=phones, + prompt_semantic=job.state.prompt_semantic, + prompt_phones=job.state.prompt_phones, + refer_spec=job.state.refer_spec, + raw_audio=job.state.raw_audio, + raw_sr=job.state.raw_sr, + speed=float(job.speed_factor), + sample_steps=int(job.sample_steps), + ) + output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] + return self.tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=float(job.speed_factor), + split_bucket=False, + fragment_interval=0.0, + super_sampling=False, + ) + + def get_state(self) -> dict: + with self.condition: + return { + "pending_jobs": len(self.pending_jobs), + "running_requests": len(self.running_requests), + "tracked_jobs": len(self.job_map), + "total_submitted": self.total_submitted, + "total_finished": self.total_finished, + "max_steps": self.max_steps, + "micro_batch_wait_ms": int(self.micro_batch_wait_s * 1000), + } + + def _finalize_finished(self, items: List[T2SFinishedItem]) -> None: + if not items: + return + jobs_to_finalize: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + with self.condition: + for item in items: + job = self.job_map.get(item.request_id) + if job is not None: + jobs_to_finalize.append((job, item)) + + for job, item in jobs_to_finalize: + try: + self._sync_device() + synth_start = time.perf_counter() + sample_rate, audio_data = self._synthesize_finished_audio(job, item) + self._sync_device() + synth_ms = (time.perf_counter() - synth_start) * 1000.0 + except Exception as exc: + self._finalize_error([item.request_id], str(exc)) + continue + + finished_at = time.perf_counter() + with self.condition: + if self.job_map.get(item.request_id) is not job: + continue + queue_wait_ms = 0.0 + if job.first_schedule_time is not None: + queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) + worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) + job.synth_ms += synth_ms + job.sample_rate = int(sample_rate) + job.audio_data = audio_data + prepare_profile = dict(job.state.prepare_profile) + job.result = { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + "prepare_ms": job.prepare_ms, + "prepare_wall_ms": job.prepare_wall_ms, + "prepare_profile": prepare_profile, + "queue_wait_ms": queue_wait_ms, + "prefill_ms": job.prefill_ms, + "decode_ms": job.decode_ms, + "synth_ms": job.synth_ms, + "worker_total_ms": worker_total_ms, + "decode_steps": int(job.decode_steps), + "sample_rate": int(sample_rate), + "media_type": job.media_type, + } + job.done_event.set() + self.job_map.pop(item.request_id, None) + self.total_finished += 1 + + def _finalize_error(self, request_ids: List[str], error: str) -> None: + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_map.get(request_id) + if job is None: + continue + job.error = error + job.done_event.set() + self.job_map.pop(request_id, None) + self.total_finished += 1 + + def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs and not self.running_requests: + self.condition.wait(timeout=self.micro_batch_wait_s) + elif wait_for_batch and self.pending_jobs: + self.condition.wait(timeout=self.micro_batch_wait_s) + if not self.pending_jobs: + return [] + pending = list(self.pending_jobs) + self.pending_jobs.clear() + return pending + + def _run_loop(self) -> None: + while True: + wait_for_batch = len(self.running_requests) == 0 + pending_jobs = self._take_pending_snapshot(wait_for_batch=wait_for_batch) + + if pending_jobs: + try: + self._sync_device() + prefill_start = time.perf_counter() + self._mark_prefill_started(pending_jobs, prefill_start) + admitted_running, admitted_finished = run_prefill_step( + self.tts.t2s_model.model, + [job.state for job in pending_jobs], + max_steps=self.max_steps, + ) + self._sync_device() + self._add_prefill_time(pending_jobs, time.perf_counter() - prefill_start) + self._finalize_finished(admitted_finished) + self.running_requests.extend(admitted_running) + except Exception as exc: + self._finalize_error([job.request_id for job in pending_jobs], str(exc)) + + if self.running_requests: + try: + active_request_ids = [item.state.request_id for item in self.running_requests] + self._sync_device() + decode_start = time.perf_counter() + self.running_requests, step_finished = run_decode_step_for_running( + self.tts.t2s_model.model, + self.running_requests, + max_steps=self.max_steps, + ) + self._sync_device() + self._add_decode_time(active_request_ids, time.perf_counter() - decode_start) + self._finalize_finished(step_finished) + except Exception as exc: + self._finalize_error(active_request_ids, str(exc)) + self.running_requests = [] + continue + + if not pending_jobs: + time.sleep(self.micro_batch_wait_s) + + +scheduler_debug_worker = SchedulerDebugWorker(tts_pipeline) + + +def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int): + # Author: AkagawaTsurunaki + # Issue: + # Stack overflow probabilistically occurs + # when the function `sf_writef_short` of `libsndfile_64bit.dll` is called + # using the Python library `soundfile` + # Note: + # This is an issue related to `libsndfile`, not this project itself. + # It happens when you generate a large audio tensor (about 499804 frames in my PC) + # and try to convert it to an ogg file. + # Related: + # https://github.com/RVC-Boss/GPT-SoVITS/issues/1199 + # https://github.com/libsndfile/libsndfile/issues/1023 + # https://github.com/bastibe/python-soundfile/issues/396 + # Suggestion: + # Or split the whole audio data into smaller audio segment to avoid stack overflow? + + def handle_pack_ogg(): + with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: + audio_file.write(data) + + + + # See: https://docs.python.org/3/library/threading.html + # The stack size of this thread is at least 32768 + # If stack overflow error still occurs, just modify the `stack_size`. + # stack_size = n * 4096, where n should be a positive integer. + # Here we chose n = 4096. + stack_size = 4096 * 4096 + try: + threading.stack_size(stack_size) + pack_ogg_thread = threading.Thread(target=handle_pack_ogg) + pack_ogg_thread.start() + pack_ogg_thread.join() + except RuntimeError as e: + # If changing the thread stack size is unsupported, a RuntimeError is raised. + print("RuntimeError: {}".format(e)) + print("Changing the thread stack size is unsupported.") + except ValueError as e: + # If the specified stack size is invalid, a ValueError is raised and the stack size is unmodified. + print("ValueError: {}".format(e)) + print("The specified stack size is invalid.") + + return io_buffer + + +def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int): + io_buffer.write(data.tobytes()) + return io_buffer + + +def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int): + io_buffer = BytesIO() + sf.write(io_buffer, data, rate, format="wav") + return io_buffer + + +def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int): + process = subprocess.Popen( + [ + "ffmpeg", + "-f", + "s16le", # 输入16位有符号小端整数PCM + "-ar", + str(rate), # 设置采样率 + "-ac", + "1", # 单声道 + "-i", + "pipe:0", # 从管道读取输入 + "-c:a", + "aac", # 音频编码器为AAC + "-b:a", + "192k", # 比特率 + "-vn", # 不包含视频 + "-f", + "adts", # 输出AAC数据流格式 + "pipe:1", # 将输出写入管道 + ], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + out, _ = process.communicate(input=data.tobytes()) + io_buffer.write(out) + return io_buffer + + +def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str): + if media_type == "ogg": + io_buffer = pack_ogg(io_buffer, data, rate) + elif media_type == "aac": + io_buffer = pack_aac(io_buffer, data, rate) + elif media_type == "wav": + io_buffer = pack_wav(io_buffer, data, rate) + else: + io_buffer = pack_raw(io_buffer, data, rate) + io_buffer.seek(0) + return io_buffer + + +# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py +def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000): + # This will create a wave header then append the frame input + # It should be first on a streaming wav file + # Other frames better should not have it (else you will hear some artifacts each chunk start) + wav_buf = BytesIO() + with wave.open(wav_buf, "wb") as vfout: + vfout.setnchannels(channels) + vfout.setsampwidth(sample_width) + vfout.setframerate(sample_rate) + vfout.writeframes(frame_input) + + wav_buf.seek(0) + return wav_buf.read() + + +def handle_control(command: str): + if command == "restart": + os.execl(sys.executable, sys.executable, *argv) + elif command == "exit": + os.kill(os.getpid(), signal.SIGTERM) + exit(0) + + +def check_params(req: dict): + text: str = req.get("text", "") + text_lang: str = req.get("text_lang", "") + ref_audio_path: str = req.get("ref_audio_path", "") + streaming_mode: bool = req.get("streaming_mode", False) + media_type: str = req.get("media_type", "wav") + prompt_lang: str = req.get("prompt_lang", "") + text_split_method: str = req.get("text_split_method", "cut5") + + if ref_audio_path in [None, ""]: + return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"}) + if text in [None, ""]: + return JSONResponse(status_code=400, content={"message": "text is required"}) + if text_lang in [None, ""]: + return JSONResponse(status_code=400, content={"message": "text_lang is required"}) + elif text_lang.lower() not in tts_config.languages: + return JSONResponse( + status_code=400, + content={"message": f"text_lang: {text_lang} is not supported in version {tts_config.version}"}, + ) + if prompt_lang in [None, ""]: + return JSONResponse(status_code=400, content={"message": "prompt_lang is required"}) + elif prompt_lang.lower() not in tts_config.languages: + return JSONResponse( + status_code=400, + content={"message": f"prompt_lang: {prompt_lang} is not supported in version {tts_config.version}"}, + ) + if media_type not in ["wav", "raw", "ogg", "aac"]: + return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"}) + # elif media_type == "ogg" and not streaming_mode: + # return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"}) + + if text_split_method not in cut_method_names: + return JSONResponse( + status_code=400, content={"message": f"text_split_method:{text_split_method} is not supported"} + ) + + return None + + +def set_scheduler_seed(seed: int): + if seed in ["", None]: + return + seed = int(seed) + if seed < 0: + return + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def build_scheduler_request_specs(request_items: List[Scheduler_Debug_Request_Item]) -> List[SchedulerRequestSpec]: + specs: List[SchedulerRequestSpec] = [] + for index, item in enumerate(request_items): + payload = item.dict() + req = { + "text": payload["text"], + "text_lang": payload["text_lang"].lower(), + "ref_audio_path": payload["ref_audio_path"], + "aux_ref_audio_paths": None, + "prompt_text": payload["prompt_text"], + "prompt_lang": payload["prompt_lang"].lower(), + "top_k": payload["top_k"], + "top_p": payload["top_p"], + "temperature": payload["temperature"], + "text_split_method": "cut5", + "batch_size": 1, + "batch_threshold": 0.75, + "speed_factor": 1.0, + "split_bucket": False, + "fragment_interval": 0.3, + "seed": -1, + "media_type": "wav", + "streaming_mode": False, + "parallel_infer": False, + "repetition_penalty": payload["repetition_penalty"], + "sample_steps": 32, + "super_sampling": False, + "overlap_length": 2, + "min_chunk_length": 16, + } + check_res = check_params(req) + if check_res is not None: + detail = check_res.body.decode("utf-8") if hasattr(check_res, "body") else str(check_res) + raise ValueError(f"request[{index}] 参数非法: {detail}") + specs.append( + SchedulerRequestSpec( + request_id=payload["request_id"] or f"req_{index:03d}", + ref_audio_path=Path(payload["ref_audio_path"]), + prompt_text=payload["prompt_text"], + prompt_lang=payload["prompt_lang"].lower(), + text=payload["text"], + text_lang=payload["text_lang"].lower(), + top_k=int(payload["top_k"]), + top_p=float(payload["top_p"]), + temperature=float(payload["temperature"]), + repetition_penalty=float(payload["repetition_penalty"]), + early_stop_num=int(payload["early_stop_num"]), + ready_step=int(payload["ready_step"]), + ) + ) + return specs + + +def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: + return [ + { + "request_id": state.request_id, + "ready_step": int(state.ready_step), + "ref_audio_path": str(state.ref_audio_path), + "prompt_semantic_len": int(state.prompt_semantic.shape[0]), + "all_phone_len": int(state.all_phones.shape[0]), + "bert_len": int(state.all_bert_features.shape[-1]), + "norm_text": state.norm_text, + } + for state in states + ] + + +def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: + return [ + { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + } + for item in items + ] + + +def prepare_scheduler_states_batch(specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: + return [scheduler_debug_worker.prepare_state(spec) for spec in specs] + + +def build_scheduler_submit_spec(request: Scheduler_Submit_Request) -> SchedulerRequestSpec: + payload = request.dict() + request_id = payload["request_id"] or f"job_{uuid.uuid4().hex[:12]}" + req = { + "text": payload["text"], + "text_lang": payload["text_lang"].lower(), + "ref_audio_path": payload["ref_audio_path"], + "aux_ref_audio_paths": None, + "prompt_text": payload["prompt_text"], + "prompt_lang": payload["prompt_lang"].lower(), + "top_k": payload["top_k"], + "top_p": payload["top_p"], + "temperature": payload["temperature"], + "text_split_method": "cut5", + "batch_size": 1, + "batch_threshold": 0.75, + "speed_factor": float(payload["speed_factor"]), + "split_bucket": False, + "fragment_interval": 0.3, + "seed": -1, + "media_type": payload["media_type"], + "streaming_mode": False, + "parallel_infer": False, + "repetition_penalty": payload["repetition_penalty"], + "sample_steps": int(payload["sample_steps"]), + "super_sampling": False, + "overlap_length": 2, + "min_chunk_length": 16, + } + check_res = check_params(req) + if check_res is not None: + detail = check_res.body.decode("utf-8") if hasattr(check_res, "body") else str(check_res) + raise ValueError(f"request 参数非法: {detail}") + return SchedulerRequestSpec( + request_id=request_id, + ref_audio_path=Path(payload["ref_audio_path"]), + prompt_text=payload["prompt_text"], + prompt_lang=payload["prompt_lang"].lower(), + text=payload["text"], + text_lang=payload["text_lang"].lower(), + top_k=int(payload["top_k"]), + top_p=float(payload["top_p"]), + temperature=float(payload["temperature"]), + repetition_penalty=float(payload["repetition_penalty"]), + early_stop_num=int(payload["early_stop_num"]), + ready_step=0, + ) + + +async def tts_scheduler_debug_handle(request: Scheduler_Debug_Request): + try: + set_scheduler_seed(request.seed) + specs = build_scheduler_request_specs(request.requests) + states = await asyncio.to_thread(prepare_scheduler_states_batch, specs) + finished = run_scheduler_continuous(tts_pipeline.t2s_model.model, states, max_steps=int(request.max_steps)) + return JSONResponse( + status_code=200, + content={ + "message": "success", + "request_count": len(states), + "max_steps": int(request.max_steps), + "requests": summarize_scheduler_states(states), + "finished": summarize_scheduler_finished(finished), + }, + ) + except Exception as e: + return JSONResponse( + status_code=400, + content={"message": "scheduler debug failed", "Exception": str(e)}, + ) + + +async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): + try: + request_start = time.perf_counter() + spec = build_scheduler_submit_spec(request) + prepare_start = time.perf_counter() + state = await asyncio.to_thread(scheduler_debug_worker.prepare_state, spec) + prepare_wall_ms = (time.perf_counter() - prepare_start) * 1000.0 + prepare_ms = float(state.prepare_profile.get("total_ms", prepare_wall_ms)) + job = scheduler_debug_worker.submit( + state, + speed_factor=float(request.speed_factor), + sample_steps=int(request.sample_steps), + media_type=request.media_type, + prepare_ms=prepare_ms, + prepare_wall_ms=prepare_wall_ms, + ) + timeout_ok = await asyncio.to_thread(job.done_event.wait, float(request.timeout_sec)) + if not timeout_ok: + return JSONResponse( + status_code=202, + content={ + "message": "queued", + "request_id": job.request_id, + "timings": { + "prepare_ms": prepare_ms, + "prepare_wall_ms": prepare_wall_ms, + "request_elapsed_ms": max(0.0, (time.perf_counter() - request_start) * 1000.0), + }, + "worker_state": scheduler_debug_worker.get_state(), + }, + ) + if job.error is not None: + return JSONResponse( + status_code=400, + content={"message": "scheduler submit failed", "request_id": job.request_id, "Exception": job.error}, + ) + if job.audio_data is None or job.sample_rate is None: + return JSONResponse( + status_code=500, + content={ + "message": "scheduler submit failed", + "request_id": job.request_id, + "Exception": "job finished without audio payload", + }, + ) + pack_start = time.perf_counter() + audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue() + pack_ms = (time.perf_counter() - pack_start) * 1000.0 + job.pack_ms = pack_ms + request_total_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0) + headers = { + "X-Request-Id": job.request_id, + "X-Semantic-Len": str(job.result["semantic_len"]) if job.result is not None else "0", + "X-Finish-Reason": job.result["finish_reason"] if job.result is not None else "unknown", + "X-Queue-Wait-Ms": ( + f"{float(job.result['queue_wait_ms']):.3f}" if job.result is not None else "0.000" + ), + "X-Prepare-Ms": f"{prepare_ms:.3f}", + "X-Prepare-Wall-Ms": f"{prepare_wall_ms:.3f}", + "X-Prefill-Ms": f"{float(job.result['prefill_ms']):.3f}" if job.result is not None else "0.000", + "X-Decode-Ms": f"{float(job.result['decode_ms']):.3f}" if job.result is not None else "0.000", + "X-Synth-Ms": f"{float(job.result['synth_ms']):.3f}" if job.result is not None else "0.000", + "X-Pack-Ms": f"{pack_ms:.3f}", + "X-Worker-Total-Ms": ( + f"{float(job.result['worker_total_ms']):.3f}" if job.result is not None else "0.000" + ), + "X-Request-Total-Ms": f"{request_total_ms:.3f}", + "X-Decode-Steps": str(job.result["decode_steps"]) if job.result is not None else "0", + } + if job.result is not None: + prepare_profile = job.result.get("prepare_profile", {}) + headers.update( + { + "X-Prepare-Prompt-Text-Ms": f"{float(prepare_profile.get('prompt_text_features_ms', 0.0)):.3f}", + "X-Prepare-Target-Text-Ms": f"{float(prepare_profile.get('text_features_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-Ms": f"{float(prepare_profile.get('prompt_semantic_ms', 0.0)):.3f}", + "X-Prepare-Ref-Spec-Ms": f"{float(prepare_profile.get('ref_spec_ms', 0.0)):.3f}", + "X-Prepare-Tensorize-Ms": f"{float(prepare_profile.get('tensorize_ms', 0.0)):.3f}", + "X-Prepare-Profile-Wall-Ms": f"{float(prepare_profile.get('wall_total_ms', 0.0)):.3f}", + } + ) + return Response(audio_data, media_type=f"audio/{job.media_type}", headers=headers) + except Exception as e: + return JSONResponse( + status_code=400, + content={"message": "scheduler submit failed", "Exception": str(e)}, + ) + + +async def tts_handle(req: dict): + """ + Text to speech handler. + + Args: + req (dict): + { + "text": "", # str.(required) text to be synthesized + "text_lang: "", # str.(required) language of the text to be synthesized + "ref_audio_path": "", # str.(required) reference audio path + "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion + "prompt_text": "", # str.(optional) prompt text for the reference audio + "prompt_lang": "", # str.(required) language of the prompt text for the reference audio + "top_k": 15, # int. top k sampling + "top_p": 1, # float. top p sampling + "temperature": 1, # float. temperature for sampling + "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details. + "batch_size": 1, # int. batch size for inference + "batch_threshold": 0.75, # float. threshold for batch splitting. + "split_bucket": True, # bool. whether to split the batch into multiple buckets. + "speed_factor":1.0, # float. control the speed of the synthesized audio. + "fragment_interval":0.3, # float. to control the interval of the audio fragment. + "seed": -1, # int. random seed for reproducibility. + "parallel_infer": True, # bool. whether to use parallel inference. + "repetition_penalty": 1.35, # float. repetition penalty for T2S model. + "sample_steps": 32, # int. number of sampling steps for VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed ) + "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. + "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) + } + returns: + StreamingResponse: audio stream response. + """ + + streaming_mode = req.get("streaming_mode", False) + return_fragment = req.get("return_fragment", False) + media_type = req.get("media_type", "wav") + + check_res = check_params(req) + if check_res is not None: + return check_res + + if streaming_mode == 0: + streaming_mode = False + return_fragment = False + fixed_length_chunk = False + elif streaming_mode == 1: + streaming_mode = False + return_fragment = True + fixed_length_chunk = False + elif streaming_mode == 2: + streaming_mode = True + return_fragment = False + fixed_length_chunk = False + elif streaming_mode == 3: + streaming_mode = True + return_fragment = False + fixed_length_chunk = True + + else: + return JSONResponse(status_code=400, content={"message": f"the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)"}) + + req["streaming_mode"] = streaming_mode + req["return_fragment"] = return_fragment + req["fixed_length_chunk"] = fixed_length_chunk + + print(f"{streaming_mode} {return_fragment} {fixed_length_chunk}") + + streaming_mode = streaming_mode or return_fragment + + + try: + tts_generator = tts_pipeline.run(req) + + if streaming_mode: + + def streaming_generator(tts_generator: Generator, media_type: str): + if_frist_chunk = True + for sr, chunk in tts_generator: + if if_frist_chunk and media_type == "wav": + yield wave_header_chunk(sample_rate=sr) + media_type = "raw" + if_frist_chunk = False + yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue() + + # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}" + return StreamingResponse( + streaming_generator( + tts_generator, + media_type, + ), + media_type=f"audio/{media_type}", + ) + + else: + sr, audio_data = next(tts_generator) + audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() + return Response(audio_data, media_type=f"audio/{media_type}") + except Exception as e: + return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(e)}) + + +@APP.get("/control") +async def control(command: str = None): + if command is None: + return JSONResponse(status_code=400, content={"message": "command is required"}) + handle_control(command) + + +@APP.get("/tts") +async def tts_get_endpoint( + text: str = None, + text_lang: str = None, + ref_audio_path: str = None, + aux_ref_audio_paths: list = None, + prompt_lang: str = None, + prompt_text: str = "", + top_k: int = 15, + top_p: float = 1, + temperature: float = 1, + text_split_method: str = "cut5", + batch_size: int = 1, + batch_threshold: float = 0.75, + split_bucket: bool = True, + speed_factor: float = 1.0, + fragment_interval: float = 0.3, + seed: int = -1, + media_type: str = "wav", + parallel_infer: bool = True, + repetition_penalty: float = 1.35, + sample_steps: int = 32, + super_sampling: bool = False, + streaming_mode: Union[bool, int] = False, + overlap_length: int = 2, + min_chunk_length: int = 16, +): + req = { + "text": text, + "text_lang": text_lang.lower(), + "ref_audio_path": ref_audio_path, + "aux_ref_audio_paths": aux_ref_audio_paths, + "prompt_text": prompt_text, + "prompt_lang": prompt_lang.lower(), + "top_k": top_k, + "top_p": top_p, + "temperature": temperature, + "text_split_method": text_split_method, + "batch_size": int(batch_size), + "batch_threshold": float(batch_threshold), + "speed_factor": float(speed_factor), + "split_bucket": split_bucket, + "fragment_interval": fragment_interval, + "seed": seed, + "media_type": media_type, + "streaming_mode": streaming_mode, + "parallel_infer": parallel_infer, + "repetition_penalty": float(repetition_penalty), + "sample_steps": int(sample_steps), + "super_sampling": super_sampling, + "overlap_length": int(overlap_length), + "min_chunk_length": int(min_chunk_length), + } + return await tts_handle(req) + + +@APP.post("/tts") +async def tts_post_endpoint(request: TTS_Request): + req = request.dict() + return await tts_handle(req) + + +@APP.post("/tts_scheduler_debug") +async def tts_scheduler_debug_endpoint(request: Scheduler_Debug_Request): + return await tts_scheduler_debug_handle(request) + + +@APP.post("/tts_scheduler_submit") +async def tts_scheduler_submit_endpoint(request: Scheduler_Submit_Request): + return await tts_scheduler_submit_handle(request) + + +@APP.get("/tts_scheduler_state") +async def tts_scheduler_state_endpoint(): + return JSONResponse(status_code=200, content={"message": "success", "worker_state": scheduler_debug_worker.get_state()}) + + +@APP.get("/set_refer_audio") +async def set_refer_aduio(refer_audio_path: str = None): + try: + tts_pipeline.set_ref_audio(refer_audio_path) + except Exception as e: + return JSONResponse(status_code=400, content={"message": "set refer audio failed", "Exception": str(e)}) + return JSONResponse(status_code=200, content={"message": "success"}) + + +# @APP.post("/set_refer_audio") +# async def set_refer_aduio_post(audio_file: UploadFile = File(...)): +# try: +# # 检查文件类型,确保是音频文件 +# if not audio_file.content_type.startswith("audio/"): +# return JSONResponse(status_code=400, content={"message": "file type is not supported"}) + +# os.makedirs("uploaded_audio", exist_ok=True) +# save_path = os.path.join("uploaded_audio", audio_file.filename) +# # 保存音频文件到服务器上的一个目录 +# with open(save_path , "wb") as buffer: +# buffer.write(await audio_file.read()) + +# tts_pipeline.set_ref_audio(save_path) +# except Exception as e: +# return JSONResponse(status_code=400, content={"message": f"set refer audio failed", "Exception": str(e)}) +# return JSONResponse(status_code=200, content={"message": "success"}) + + +@APP.get("/set_gpt_weights") +async def set_gpt_weights(weights_path: str = None): + try: + if weights_path in ["", None]: + return JSONResponse(status_code=400, content={"message": "gpt weight path is required"}) + tts_pipeline.init_t2s_weights(weights_path) + except Exception as e: + return JSONResponse(status_code=400, content={"message": "change gpt weight failed", "Exception": str(e)}) + + return JSONResponse(status_code=200, content={"message": "success"}) + + +@APP.get("/set_sovits_weights") +async def set_sovits_weights(weights_path: str = None): + try: + if weights_path in ["", None]: + return JSONResponse(status_code=400, content={"message": "sovits weight path is required"}) + tts_pipeline.init_vits_weights(weights_path) + except Exception as e: + return JSONResponse(status_code=400, content={"message": "change sovits weight failed", "Exception": str(e)}) + return JSONResponse(status_code=200, content={"message": "success"}) + + +if __name__ == "__main__": + try: + if host == "None": # 在调用时使用 -a None 参数,可以让api监听双栈 + host = None + uvicorn.run(app=APP, host=host, port=port, workers=1) + except Exception: + traceback.print_exc() + os.kill(os.getpid(), signal.SIGTERM) + exit(0) diff --git a/tools/t2s_scheduler_prototype.py b/tools/t2s_scheduler_prototype.py new file mode 100644 index 00000000..cd4b9c6d --- /dev/null +++ b/tools/t2s_scheduler_prototype.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import argparse +import json +import random +import sys +from pathlib import Path +from typing import Any, Dict, List + +import numpy as np +import torch + +ROOT_DIR = Path(__file__).resolve().parents[1] +if str(ROOT_DIR) not in sys.path: + sys.path.append(str(ROOT_DIR)) +gpt_sovits_dir = ROOT_DIR / "GPT_SoVITS" +if str(gpt_sovits_dir) not in sys.path: + sys.path.append(str(gpt_sovits_dir)) + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( # noqa: E402 + SchedulerRequestSpec, + T2SFinishedItem, + T2SRequestState, + prepare_request_state, + run_scheduler_continuous, +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="T2S request-local scheduler prototype.") + parser.add_argument("--config", type=Path, default=ROOT_DIR / "GPT_SoVITS/configs/tts_infer.yaml") + parser.add_argument("--request-manifest", type=Path, default=None) + parser.add_argument("--ref-audio", type=Path, default=ROOT_DIR / "test.wav") + parser.add_argument("--prompt-text", type=str, default="是啊,主要是因为有调研需求的学者少了。") + parser.add_argument("--prompt-lang", type=str, default="zh") + parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_en.txt") + parser.add_argument("--text", type=str, default=None) + parser.add_argument("--text-lang", type=str, default="en") + parser.add_argument("--top-k", type=int, default=15) + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--repetition-penalty", type=float, default=1.35) + parser.add_argument("--early-stop-num", type=int, default=-1) + parser.add_argument("--max-steps", type=int, default=1500) + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--output-dir", type=Path, default=ROOT_DIR / "TEMP/t2s_scheduler/output_run") + return parser.parse_args() + + +def set_seed(seed: int, use_cuda: bool) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if use_cuda and torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def load_pipeline(config_path: Path): + try: + from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "缺少运行依赖,请先在 GPT-SoVITS 推理环境中安装 requirements 后再运行该脚本。" + ) from exc + tts_config = TTS_Config(str(config_path)) + print(tts_config) + return TTS(tts_config) + + +def load_request_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]: + if args.request_manifest is not None: + payload = json.loads(args.request_manifest.read_text(encoding="utf-8")) + raw_requests = payload["requests"] if isinstance(payload, dict) else payload + specs: List[SchedulerRequestSpec] = [] + for index, item in enumerate(raw_requests): + text = item.get("text") + text_file = item.get("text_file") + if text is None and text_file is None: + raise ValueError(f"request[{index}] must provide text or text_file") + if text is None: + text = Path(text_file).read_text(encoding="utf-8") + specs.append( + SchedulerRequestSpec( + request_id=item.get("request_id", f"req_{index:03d}"), + ref_audio_path=Path(item["ref_audio_path"]), + prompt_text=item["prompt_text"], + prompt_lang=item.get("prompt_lang", "zh"), + text=text, + text_lang=item.get("text_lang", "zh"), + top_k=int(item.get("top_k", args.top_k)), + top_p=float(item.get("top_p", args.top_p)), + temperature=float(item.get("temperature", args.temperature)), + repetition_penalty=float(item.get("repetition_penalty", args.repetition_penalty)), + early_stop_num=int(item.get("early_stop_num", args.early_stop_num)), + ready_step=int(item.get("ready_step", 0)), + ) + ) + return specs + + text = args.text if args.text is not None else args.text_file.read_text(encoding="utf-8") + return [ + SchedulerRequestSpec( + request_id="req_000", + ref_audio_path=args.ref_audio, + prompt_text=args.prompt_text, + prompt_lang=args.prompt_lang, + text=text, + text_lang=args.text_lang, + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + repetition_penalty=args.repetition_penalty, + early_stop_num=args.early_stop_num, + ready_step=0, + ) + ] + + +def summarise_requests(states: List[T2SRequestState]) -> List[Dict[str, Any]]: + return [ + { + "request_id": state.request_id, + "ready_step": int(state.ready_step), + "ref_audio_path": str(state.ref_audio_path), + "prompt_semantic_len": int(state.prompt_semantic.shape[0]), + "all_phone_len": int(state.all_phones.shape[0]), + "bert_len": int(state.all_bert_features.shape[-1]), + "norm_text": state.norm_text, + } + for state in states + ] + + +def summarise_finished(items: List[T2SFinishedItem]) -> List[Dict[str, Any]]: + return [ + { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + } + for item in items + ] + + +def main() -> None: + args = parse_args() + args.output_dir.mkdir(parents=True, exist_ok=True) + + tts = load_pipeline(args.config) + model = tts.t2s_model.model + use_cuda = str(tts.configs.device).startswith("cuda") + set_seed(args.seed, use_cuda) + + request_specs = load_request_specs(args) + states = [prepare_request_state(tts, spec) for spec in request_specs] + finished = run_scheduler_continuous(model, states, max_steps=args.max_steps) + + summary = { + "request_count": len(states), + "max_steps": args.max_steps, + "requests": summarise_requests(states), + "finished": summarise_finished(finished), + } + output_path = args.output_dir / "scheduler_prototype_summary.json" + output_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8") + print(json.dumps(summary, ensure_ascii=False, indent=2)) + print(f"[saved] {output_path}") + + +if __name__ == "__main__": + try: + main() + except ModuleNotFoundError as exc: + print(f"[error] {exc}") + raise SystemExit(1) from None From d245eb169c57c50ea55444bdf104ded247058872 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Mon, 9 Mar 2026 01:42:04 +0800 Subject: [PATCH 05/24] Refactor T2S scheduler and inference handling to improve attention mask management and memory tracking. Update T2SRunningRequest and T2SActiveBatch classes to include optional key padding masks. Introduce new benchmarking tools for API performance and memory usage analysis, enhancing overall system efficiency. --- GPT_SoVITS/TTS_infer_pack/TTS.py | 1 + GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py | 82 +- tools/bench_api_v3_scheduler_submit.py | 250 ++++++ tools/t2s_memory_breakdown.py | 887 +++++++++++++++++++++ 4 files changed, 1192 insertions(+), 28 deletions(-) create mode 100644 tools/bench_api_v3_scheduler_submit.py create mode 100644 tools/t2s_memory_breakdown.py diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index bd4953df..2fd0df35 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1781,6 +1781,7 @@ class TTS: return audio + @torch.inference_mode() def synthesize_audio_request_local( self, semantic_tokens: torch.Tensor, diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py index e94a72c7..c8643991 100644 --- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -70,7 +70,7 @@ class T2SRunningRequest: state: T2SRequestState y_sequence: torch.LongTensor prefix_len: int - decode_attn_mask: torch.Tensor + decode_attn_mask: Optional[torch.Tensor] k_cache: List[torch.Tensor] v_cache: List[torch.Tensor] step_idx: int @@ -93,6 +93,7 @@ class T2SActiveBatch: y_sequences: List[torch.LongTensor] prefix_lens: torch.LongTensor xy_pos: torch.Tensor + key_padding_mask: torch.Tensor prefill_attn_mask: torch.Tensor decode_attn_mask: Optional[torch.Tensor] k_cache: Optional[List[torch.Tensor]] @@ -110,6 +111,7 @@ def normalize_sentence(text: str, language: str) -> str: return text +@torch.inference_mode() def prepare_request_state( tts: Any, spec: SchedulerRequestSpec, @@ -212,6 +214,7 @@ def _ensure_audio_pe(model: Any, max_position: int, dtype: torch.dtype, device: ) +@torch.inference_mode() def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SActiveBatch: x_items: List[torch.Tensor] = [] y_pos_items: List[torch.Tensor] = [] @@ -240,22 +243,19 @@ def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SAct device = x_batch.device x_lens_tensor = torch.LongTensor(x_lens).to(device) prefix_lens_tensor = torch.LongTensor(prefix_lens).to(device) - batch_size = len(states) src_len = max_x_len + max_prefix_len x_padding_mask = make_pad_mask_left(x_lens_tensor, max_x_len) y_padding_mask = make_pad_mask_left(prefix_lens_tensor, max_prefix_len) - padding_mask = torch.cat([x_padding_mask, y_padding_mask], dim=1) + key_padding_mask = torch.cat([x_padding_mask, y_padding_mask], dim=1).bool() x_mask = F.pad(torch.zeros(max_x_len, max_x_len, dtype=torch.bool, device=device), (0, max_prefix_len), value=True) y_mask = F.pad( torch.triu(torch.ones(max_prefix_len, max_prefix_len, dtype=torch.bool, device=device), diagonal=1), (max_x_len, 0), value=False, ) - causal_mask = torch.cat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(batch_size, 1, 1) - padding_mask = padding_mask.view(batch_size, 1, src_len).repeat(1, src_len, 1) - attn_mask = causal_mask.logical_or(padding_mask) - attn_mask = attn_mask.unsqueeze(1).expand(-1, model.num_head, -1, -1).bool() + causal_mask = torch.cat([x_mask, y_mask], dim=0).unsqueeze(0) + attn_mask = causal_mask.logical_or(key_padding_mask.unsqueeze(1)).unsqueeze(1) return T2SActiveBatch( request_ids=[state.request_id for state in states], @@ -265,6 +265,7 @@ def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SAct y_sequences=y_sequences, prefix_lens=prefix_lens_tensor, xy_pos=xy_pos, + key_padding_mask=key_padding_mask, prefill_attn_mask=attn_mask, decode_attn_mask=None, k_cache=None, @@ -322,10 +323,11 @@ def _sample_per_request( finish_reason = "eos_argmax" if finish_reason is not None: + prefix_len = int(active_batch.prefix_lens[batch_index].item()) finished_items.append( T2SFinishedItem( request_id=state.request_id, - semantic_tokens=new_history[:-1].clone(), + semantic_tokens=new_history[prefix_len:-1].clone(), finish_idx=step_idx, finish_reason=finish_reason, ) @@ -346,11 +348,7 @@ def decode_one_step( xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.process_prompt( active_batch.xy_pos, active_batch.prefill_attn_mask, None ) - active_batch.decode_attn_mask = F.pad( - active_batch.prefill_attn_mask[:, :, -1].unsqueeze(-2), - (0, 1), - value=False, - ) + active_batch.decode_attn_mask = F.pad(active_batch.key_padding_mask.unsqueeze(1).unsqueeze(1), (0, 1), value=False) active_batch.prefill_done = True else: xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.decode_next_token( @@ -411,6 +409,18 @@ def _pad_decode_mask_left(mask: torch.Tensor, target_len: int) -> torch.Tensor: return F.pad(mask, (pad_len, 0), value=True) +def _materialize_decode_mask_for_request(running_request: T2SRunningRequest) -> torch.Tensor: + if running_request.decode_attn_mask is not None: + return running_request.decode_attn_mask + current_mask_len = running_request.k_cache[0].shape[1] + 1 + return torch.zeros( + (1, 1, 1, current_mask_len), + dtype=torch.bool, + device=running_request.k_cache[0].device, + ) + + +@torch.inference_mode() def run_prefill_step( model: Any, states: Sequence[T2SRequestState], @@ -421,11 +431,9 @@ def run_prefill_step( active_batch = build_prefill_batch(model, states) xy_dec, k_cache, v_cache = model.t2s_transformer.process_prompt(active_batch.xy_pos, active_batch.prefill_attn_mask, None) - decode_attn_mask = F.pad( - active_batch.prefill_attn_mask[:, :, -1].unsqueeze(-2), - (0, 1), - value=False, - ) + decode_attn_mask = F.pad(active_batch.key_padding_mask.unsqueeze(1).unsqueeze(1), (0, 1), value=False) + if len(states) == 1 and not decode_attn_mask.any().item(): + decode_attn_mask = None logits = model.ar_predict_layer(xy_dec[:, -1]) running_requests: List[T2SRunningRequest] = [] @@ -463,7 +471,7 @@ def run_prefill_step( finished_items.append( T2SFinishedItem( request_id=state.request_id, - semantic_tokens=new_history[:-1].clone(), + semantic_tokens=new_history[prefix_len:-1].clone(), finish_idx=0, finish_reason=finish_reason, ) @@ -479,7 +487,11 @@ def run_prefill_step( state=state, y_sequence=new_history, prefix_len=prefix_len, - decode_attn_mask=decode_attn_mask[batch_index : batch_index + 1].clone(), + decode_attn_mask=( + None + if decode_attn_mask is None + else decode_attn_mask[batch_index : batch_index + 1].clone() + ), k_cache=request_k_cache, v_cache=request_v_cache, step_idx=1, @@ -492,10 +504,9 @@ def run_prefill_step( def _build_decode_batch_from_running( model: Any, running_requests: Sequence[T2SRunningRequest], -) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor], torch.Tensor]: +) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor], Optional[torch.Tensor]]: xy_pos = build_next_xy_pos(model, [item.y_sequence for item in running_requests]) max_kv_len = max(item.k_cache[0].shape[1] for item in running_requests) - max_mask_len = max(item.decode_attn_mask.shape[-1] for item in running_requests) num_layers = len(running_requests[0].k_cache) batched_k_cache: List[torch.Tensor] = [] @@ -508,13 +519,19 @@ def _build_decode_batch_from_running( torch.cat([_pad_cache_left(item.v_cache[layer_index], max_kv_len) for item in running_requests], dim=0) ) - batched_decode_attn_mask = torch.cat( - [_pad_decode_mask_left(item.decode_attn_mask, max_mask_len) for item in running_requests], - dim=0, - ) + if all(item.decode_attn_mask is None for item in running_requests): + batched_decode_attn_mask = None + else: + materialized_masks = [_materialize_decode_mask_for_request(item) for item in running_requests] + max_mask_len = max(mask.shape[-1] for mask in materialized_masks) + batched_decode_attn_mask = torch.cat( + [_pad_decode_mask_left(mask, max_mask_len) for mask in materialized_masks], + dim=0, + ) return xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask +@torch.inference_mode() def run_decode_step_for_running( model: Any, running_requests: Sequence[T2SRunningRequest], @@ -568,7 +585,7 @@ def run_decode_step_for_running( finished_items.append( T2SFinishedItem( request_id=running_request.state.request_id, - semantic_tokens=new_history[:-1].clone(), + semantic_tokens=new_history[running_request.prefix_len:-1].clone(), finish_idx=current_idx, finish_reason=finish_reason, ) @@ -578,12 +595,20 @@ def run_decode_step_for_running( real_next_kv_len = running_request.k_cache[0].shape[1] + 1 request_k_cache = [layer[batch_index : batch_index + 1, -real_next_kv_len:, :].clone() for layer in next_k_cache] request_v_cache = [layer[batch_index : batch_index + 1, -real_next_kv_len:, :].clone() for layer in next_v_cache] + if batched_decode_attn_mask is None: + next_decode_attn_mask = None + else: + current_decode_mask_len = running_request.k_cache[0].shape[1] + 1 + current_decode_attn_mask = batched_decode_attn_mask[ + batch_index : batch_index + 1, :, :, -current_decode_mask_len: + ] + next_decode_attn_mask = F.pad(current_decode_attn_mask, (0, 1), value=False) next_running.append( T2SRunningRequest( state=running_request.state, y_sequence=new_history, prefix_len=running_request.prefix_len, - decode_attn_mask=F.pad(running_request.decode_attn_mask, (0, 1), value=False), + decode_attn_mask=next_decode_attn_mask, k_cache=request_k_cache, v_cache=request_v_cache, step_idx=current_idx + 1, @@ -593,6 +618,7 @@ def run_decode_step_for_running( return next_running, finished_items +@torch.inference_mode() def run_scheduler_continuous( model: Any, states: Sequence[T2SRequestState], diff --git a/tools/bench_api_v3_scheduler_submit.py b/tools/bench_api_v3_scheduler_submit.py new file mode 100644 index 00000000..c16468e1 --- /dev/null +++ b/tools/bench_api_v3_scheduler_submit.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import argparse +import asyncio +import json +import subprocess +import threading +import time +import wave +from pathlib import Path +from typing import Any, Dict, List, Optional + +import httpx + +ROOT_DIR = Path(__file__).resolve().parents[1] + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Benchmark api_v3 /tts_scheduler_submit concurrency and GPU memory.") + parser.add_argument("--base-url", type=str, default="http://127.0.0.1:9880") + parser.add_argument("--endpoint", type=str, default="/tts_scheduler_submit") + parser.add_argument("--concurrency", type=int, required=True) + parser.add_argument("--timeout-sec", type=float, default=120.0) + parser.add_argument("--server-pid", type=int, default=None) + parser.add_argument("--poll-interval-sec", type=float, default=0.1) + parser.add_argument("--text-lang", type=str, default="zh") + parser.add_argument("--prompt-lang", type=str, default="zh") + parser.add_argument("--media-type", type=str, default="wav") + parser.add_argument("--top-k", type=int, default=15) + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--repetition-penalty", type=float, default=1.35) + parser.add_argument("--sample-steps", type=int, default=32) + parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_cn.txt") + parser.add_argument("--wav-dir", type=Path, default=ROOT_DIR / "testwav") + parser.add_argument("--output-dir", type=Path, default=ROOT_DIR / "TEMP/api_v3_bench") + return parser.parse_args() + + +def load_requests(args: argparse.Namespace) -> List[Dict[str, Any]]: + wav_paths_all = sorted(args.wav_dir.glob("*.wav")) + wav_paths: List[Path] = [] + for wav_path in wav_paths_all: + with wave.open(str(wav_path), "rb") as handle: + duration = handle.getnframes() / float(handle.getframerate()) + if 3.0 <= duration <= 10.0: + wav_paths.append(wav_path) + if not wav_paths: + raise FileNotFoundError(f"没有找到 3-10 秒合法 wav: {args.wav_dir}") + text_lines = [line.strip() for line in args.text_file.read_text(encoding="utf-8").splitlines() if line.strip()] + if not text_lines: + raise ValueError(f"没有找到有效文本行: {args.text_file}") + + requests: List[Dict[str, Any]] = [] + for index in range(args.concurrency): + wav_path = wav_paths[index % len(wav_paths)] + lab_path = wav_path.with_suffix(".lab") + if not lab_path.exists(): + raise FileNotFoundError(f"缺少参考文本: {lab_path}") + requests.append( + { + "request_id": f"bench_{args.concurrency:03d}_{index:03d}", + "text": text_lines[index % len(text_lines)], + "text_lang": args.text_lang, + "ref_audio_path": str(wav_path), + "prompt_lang": args.prompt_lang, + "prompt_text": lab_path.read_text(encoding="utf-8").strip(), + "top_k": int(args.top_k), + "top_p": float(args.top_p), + "temperature": float(args.temperature), + "repetition_penalty": float(args.repetition_penalty), + "sample_steps": int(args.sample_steps), + "media_type": args.media_type, + "timeout_sec": float(args.timeout_sec), + } + ) + return requests + + +class GpuMemoryPoller: + def __init__(self, server_pid: Optional[int], interval_sec: float): + self.server_pid = server_pid + self.interval_sec = interval_sec + self._stop = threading.Event() + self.samples: List[Dict[str, Any]] = [] + self.thread: Optional[threading.Thread] = None + + def _query_memory_mb(self) -> Optional[int]: + try: + result = subprocess.run( + [ + "nvidia-smi", + "--query-compute-apps=pid,used_gpu_memory", + "--format=csv,noheader,nounits", + ], + check=True, + capture_output=True, + text=True, + ) + except Exception: + return None + total = 0 + found = False + for line in result.stdout.splitlines(): + line = line.strip() + if not line: + continue + parts = [item.strip() for item in line.split(",")] + if len(parts) != 2: + continue + try: + pid = int(parts[0]) + used_mb = int(parts[1]) + except ValueError: + continue + if self.server_pid is None or pid == self.server_pid: + total += used_mb + found = True + if self.server_pid is None: + return total + return total if found else 0 + + def _run(self) -> None: + while not self._stop.is_set(): + used_mb = self._query_memory_mb() + self.samples.append({"ts": time.time(), "used_mb": used_mb}) + self._stop.wait(self.interval_sec) + + def start(self) -> None: + self.thread = threading.Thread(target=self._run, daemon=True) + self.thread.start() + + def stop(self) -> None: + self._stop.set() + if self.thread is not None: + self.thread.join(timeout=2.0) + + def summary(self) -> Dict[str, Any]: + valid = [item for item in self.samples if item["used_mb"] is not None] + peak = max(valid, key=lambda item: item["used_mb"]) if valid else None + first = valid[0] if valid else None + last = valid[-1] if valid else None + return { + "server_pid": self.server_pid, + "sample_count": int(len(self.samples)), + "start_used_mb": None if first is None else int(first["used_mb"]), + "peak_used_mb": None if peak is None else int(peak["used_mb"]), + "peak_delta_mb": None if peak is None or first is None else int(peak["used_mb"] - first["used_mb"]), + "end_used_mb": None if last is None else int(last["used_mb"]), + "peak_ts": None if peak is None else float(peak["ts"]), + "samples": self.samples, + } + + +async def submit_one(client: httpx.AsyncClient, url: str, payload: Dict[str, Any]) -> Dict[str, Any]: + started = time.perf_counter() + try: + response = await client.post(url, json=payload) + elapsed_ms = (time.perf_counter() - started) * 1000.0 + item = { + "request_id": payload["request_id"], + "status_code": int(response.status_code), + "elapsed_ms": float(elapsed_ms), + "content_type": response.headers.get("content-type"), + "audio_bytes": int(len(response.content)), + "headers": {key: value for key, value in response.headers.items() if key.lower().startswith("x-")}, + } + if response.status_code != 200: + try: + item["error_body"] = response.json() + except Exception: + item["error_body"] = response.text + return item + except Exception as exc: + return { + "request_id": payload["request_id"], + "status_code": -1, + "elapsed_ms": float((time.perf_counter() - started) * 1000.0), + "exception": repr(exc), + } + + +async def run_benchmark(args: argparse.Namespace) -> Dict[str, Any]: + payloads = load_requests(args) + url = args.base_url.rstrip("/") + args.endpoint + poller = GpuMemoryPoller(server_pid=args.server_pid, interval_sec=args.poll_interval_sec) + + limits = httpx.Limits(max_connections=args.concurrency, max_keepalive_connections=args.concurrency) + timeout = httpx.Timeout(connect=10.0, read=args.timeout_sec + 10.0, write=10.0, pool=10.0) + + started = time.perf_counter() + poller.start() + try: + async with httpx.AsyncClient(limits=limits, timeout=timeout) as client: + results = await asyncio.gather(*[submit_one(client, url, payload) for payload in payloads]) + finally: + poller.stop() + wall_ms = (time.perf_counter() - started) * 1000.0 + + ok_results = [item for item in results if item["status_code"] == 200] + failed_results = [item for item in results if item["status_code"] != 200] + request_total_ms = [] + worker_total_ms = [] + for item in ok_results: + headers = item.get("headers", {}) + if "x-request-total-ms" in headers: + request_total_ms.append(float(headers["x-request-total-ms"])) + if "x-worker-total-ms" in headers: + worker_total_ms.append(float(headers["x-worker-total-ms"])) + + return { + "concurrency": int(args.concurrency), + "server_pid": args.server_pid, + "request_count": int(len(payloads)), + "wall_ms": float(wall_ms), + "success_count": int(len(ok_results)), + "failure_count": int(len(failed_results)), + "request_total_ms_avg": float(sum(request_total_ms) / len(request_total_ms)) if request_total_ms else None, + "request_total_ms_max": float(max(request_total_ms)) if request_total_ms else None, + "worker_total_ms_avg": float(sum(worker_total_ms) / len(worker_total_ms)) if worker_total_ms else None, + "worker_total_ms_max": float(max(worker_total_ms)) if worker_total_ms else None, + "gpu_memory": poller.summary(), + "results": results, + } + + +def main() -> None: + args = parse_args() + output_dir = args.output_dir / f"concurrency_{args.concurrency:02d}" + output_dir.mkdir(parents=True, exist_ok=True) + summary = asyncio.run(run_benchmark(args)) + summary_path = output_dir / "summary.json" + summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8") + print(json.dumps({ + "concurrency": summary["concurrency"], + "success_count": summary["success_count"], + "failure_count": summary["failure_count"], + "wall_ms": summary["wall_ms"], + "gpu_peak_used_mb": summary["gpu_memory"]["peak_used_mb"], + "request_total_ms_avg": summary["request_total_ms_avg"], + "request_total_ms_max": summary["request_total_ms_max"], + "summary_path": str(summary_path), + }, ensure_ascii=False, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/tools/t2s_memory_breakdown.py b/tools/t2s_memory_breakdown.py new file mode 100644 index 00000000..18127953 --- /dev/null +++ b/tools/t2s_memory_breakdown.py @@ -0,0 +1,887 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import argparse +import gc +import contextlib +import json +import random +import sys +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import numpy as np +import torch + +ROOT_DIR = Path(__file__).resolve().parents[1] +if str(ROOT_DIR) not in sys.path: + sys.path.append(str(ROOT_DIR)) +gpt_sovits_dir = ROOT_DIR / "GPT_SoVITS" +if str(gpt_sovits_dir) not in sys.path: + sys.path.append(str(gpt_sovits_dir)) + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config # noqa: E402 +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( # noqa: E402 + SchedulerRequestSpec, + T2SRequestState, + T2SRunningRequest, + _build_decode_batch_from_running, + build_prefill_batch, + prepare_request_state, + run_decode_step_for_running, + run_prefill_step, +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Break down T2S CUDA memory by stage and tensor groups.") + parser.add_argument("--config", type=Path, default=ROOT_DIR / "GPT_SoVITS/configs/tts_infer.yaml") + parser.add_argument("--request-manifest", type=Path, default=None) + parser.add_argument("--scenario", type=str, default="auto4", choices=["auto4", "single"]) + parser.add_argument("--auto-count", type=int, default=4) + parser.add_argument("--auto-wav-dir", type=Path, default=ROOT_DIR / "testwav") + parser.add_argument("--auto-text-file", type=Path, default=ROOT_DIR / "test_cn.txt") + parser.add_argument("--ref-audio", type=Path, default=ROOT_DIR / "test.wav") + parser.add_argument("--prompt-text", type=str, default="是啊,主要是因为有调研需求的学者少了。") + parser.add_argument("--prompt-lang", type=str, default="zh") + parser.add_argument("--text", type=str, default=None) + parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_en.txt") + parser.add_argument("--text-lang", type=str, default="zh") + parser.add_argument("--top-k", type=int, default=15) + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--repetition-penalty", type=float, default=1.35) + parser.add_argument("--early-stop-num", type=int, default=-1) + parser.add_argument("--max-steps", type=int, default=1500) + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--warmup", action="store_true", default=False) + parser.add_argument("--worker-rounds", type=int, default=1) + parser.add_argument("--worker-grad-mode", type=str, default="default", choices=["default", "inference_mode"]) + parser.add_argument("--compare-worker-grad-modes", action="store_true", default=False) + parser.add_argument( + "--output-dir", + type=Path, + default=ROOT_DIR / "TEMP/t2s_memory_breakdown/run1", + ) + return parser.parse_args() + + +def set_seed(seed: int, use_cuda: bool) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if use_cuda and torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def _sync_device(device: Any) -> None: + try: + device_str = str(device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + +def bytes_to_mb(num_bytes: int) -> float: + return float(num_bytes) / (1024.0 * 1024.0) + + +def tensor_nbytes(tensor: Optional[torch.Tensor]) -> int: + if tensor is None: + return 0 + return int(tensor.numel() * tensor.element_size()) + + +def tensor_list_nbytes(items: Sequence[torch.Tensor]) -> int: + return int(sum(tensor_nbytes(item) for item in items)) + + +def model_nbytes(module: torch.nn.Module) -> int: + total = 0 + for parameter in module.parameters(): + total += tensor_nbytes(parameter) + for buffer in module.buffers(): + total += tensor_nbytes(buffer) + return int(total) + + +def build_module_weight_summary(tts: TTS) -> Dict[str, Any]: + modules = { + "t2s_model": tts.t2s_model, + "t2s_core": tts.t2s_model.model if tts.t2s_model is not None else None, + "vits_model": tts.vits_model, + "bert_model": tts.bert_model, + "cnhuhbert_model": tts.cnhuhbert_model, + "vocoder": tts.vocoder, + "sv_model": tts.sv_model, + } + by_module = {} + total_bytes = 0 + for name, module in modules.items(): + module_bytes = model_nbytes(module) if module is not None else 0 + by_module[name] = { + "bytes": int(module_bytes), + "mb": bytes_to_mb(module_bytes), + } + total_bytes += module_bytes + return { + "by_module": by_module, + "total_bytes": int(total_bytes), + "total_mb": bytes_to_mb(total_bytes), + } + + +def snapshot_live_cuda_tensors(top_k: int = 40) -> Dict[str, Any]: + storages: Dict[int, Dict[str, Any]] = {} + tensor_views: List[Dict[str, Any]] = [] + for obj in gc.get_objects(): + try: + tensor = None + if torch.is_tensor(obj): + tensor = obj + elif hasattr(obj, "data") and torch.is_tensor(obj.data): + tensor = obj.data + if tensor is None or not tensor.is_cuda: + continue + storage = tensor.untyped_storage() + storage_ptr = int(storage.data_ptr()) + if storage_ptr not in storages: + storages[storage_ptr] = { + "storage_ptr": storage_ptr, + "storage_bytes": int(storage.nbytes()), + "dtype": str(tensor.dtype), + "shape": list(tensor.shape), + "device": str(tensor.device), + } + tensor_views.append( + { + "shape": list(tensor.shape), + "dtype": str(tensor.dtype), + "bytes": tensor_nbytes(tensor), + "device": str(tensor.device), + } + ) + except Exception: + continue + storage_list = sorted(storages.values(), key=lambda item: item["storage_bytes"], reverse=True) + tensor_views.sort(key=lambda item: item["bytes"], reverse=True) + return { + "unique_storage_count": int(len(storage_list)), + "unique_storage_total_bytes": int(sum(item["storage_bytes"] for item in storage_list)), + "unique_storage_total_mb": bytes_to_mb(sum(item["storage_bytes"] for item in storage_list)), + "top_storages": storage_list[:top_k], + "top_tensor_views": tensor_views[:top_k], + } + + +def build_single_spec(args: argparse.Namespace) -> List[SchedulerRequestSpec]: + text = args.text if args.text is not None else args.text_file.read_text(encoding="utf-8").strip() + return [ + SchedulerRequestSpec( + request_id="req_000", + ref_audio_path=args.ref_audio, + prompt_text=args.prompt_text, + prompt_lang=args.prompt_lang, + text=text, + text_lang=args.text_lang, + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + repetition_penalty=args.repetition_penalty, + early_stop_num=args.early_stop_num, + ready_step=0, + ) + ] + + +def build_auto_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]: + wav_paths = sorted(args.auto_wav_dir.glob("*.wav"))[: args.auto_count] + if len(wav_paths) < args.auto_count: + raise ValueError(f"auto wav count不足,目录 {args.auto_wav_dir} 只有 {len(wav_paths)} 条 wav") + text_lines = [line.strip() for line in args.auto_text_file.read_text(encoding="utf-8").splitlines() if line.strip()] + if len(text_lines) < args.auto_count: + raise ValueError(f"auto text lines不足,文件 {args.auto_text_file} 只有 {len(text_lines)} 行有效文本") + specs: List[SchedulerRequestSpec] = [] + for index, wav_path in enumerate(wav_paths): + lab_path = wav_path.with_suffix(".lab") + if not lab_path.exists(): + raise FileNotFoundError(f"找不到参考文本 {lab_path}") + specs.append( + SchedulerRequestSpec( + request_id=f"req_{index:03d}", + ref_audio_path=wav_path, + prompt_text=lab_path.read_text(encoding="utf-8").strip(), + prompt_lang="zh", + text=text_lines[index], + text_lang=args.text_lang, + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + repetition_penalty=args.repetition_penalty, + early_stop_num=args.early_stop_num, + ready_step=0, + ) + ) + return specs + + +def load_request_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]: + if args.request_manifest is not None: + payload = json.loads(args.request_manifest.read_text(encoding="utf-8")) + raw_requests = payload["requests"] if isinstance(payload, dict) else payload + specs: List[SchedulerRequestSpec] = [] + for index, item in enumerate(raw_requests): + text = item.get("text") + text_file = item.get("text_file") + if text is None and text_file is None: + raise ValueError(f"request[{index}] must provide text or text_file") + if text is None: + text = Path(text_file).read_text(encoding="utf-8").strip() + specs.append( + SchedulerRequestSpec( + request_id=item.get("request_id", f"req_{index:03d}"), + ref_audio_path=Path(item["ref_audio_path"]), + prompt_text=item["prompt_text"], + prompt_lang=item.get("prompt_lang", "zh"), + text=text, + text_lang=item.get("text_lang", "zh"), + top_k=int(item.get("top_k", args.top_k)), + top_p=float(item.get("top_p", args.top_p)), + temperature=float(item.get("temperature", args.temperature)), + repetition_penalty=float(item.get("repetition_penalty", args.repetition_penalty)), + early_stop_num=int(item.get("early_stop_num", args.early_stop_num)), + ready_step=int(item.get("ready_step", 0)), + ) + ) + return specs + if args.scenario == "single": + return build_single_spec(args) + return build_auto_specs(args) + + +def load_pipeline(config_path: Path) -> TTS: + tts_config = TTS_Config(str(config_path)) + print(tts_config) + return TTS(tts_config) + + +def cuda_mem_snapshot(device: Any) -> Dict[str, float]: + if not (str(device).startswith("cuda") and torch.cuda.is_available()): + return { + "allocated_mb": 0.0, + "reserved_mb": 0.0, + "max_allocated_mb": 0.0, + "max_reserved_mb": 0.0, + } + _sync_device(device) + return { + "allocated_mb": bytes_to_mb(torch.cuda.memory_allocated(device)), + "reserved_mb": bytes_to_mb(torch.cuda.memory_reserved(device)), + "max_allocated_mb": bytes_to_mb(torch.cuda.max_memory_allocated(device)), + "max_reserved_mb": bytes_to_mb(torch.cuda.max_memory_reserved(device)), + } + + +def stage_run(device: Any, fn) -> Tuple[Any, Dict[str, float]]: + if str(device).startswith("cuda") and torch.cuda.is_available(): + gc.collect() + _sync_device(device) + torch.cuda.reset_peak_memory_stats(device) + before = cuda_mem_snapshot(device) + started = time.perf_counter() + result = fn() + _sync_device(device) + elapsed_ms = (time.perf_counter() - started) * 1000.0 + after = cuda_mem_snapshot(device) + after["elapsed_ms"] = float(elapsed_ms) + after["delta_allocated_mb"] = float(after["allocated_mb"] - before["allocated_mb"]) + after["delta_reserved_mb"] = float(after["reserved_mb"] - before["reserved_mb"]) + after["stage_peak_over_before_mb"] = float(max(after["max_allocated_mb"] - before["allocated_mb"], 0.0)) + return result, after + + +class GlobalPeakRecorder: + def __init__(self, device: Any): + self.device = device + self.checkpoints: List[Dict[str, Any]] = [] + if str(device).startswith("cuda") and torch.cuda.is_available(): + gc.collect() + _sync_device(device) + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats(device) + + def record(self, label: str, **extra: Any) -> None: + snapshot = cuda_mem_snapshot(self.device) + snapshot["label"] = label + snapshot.update(extra) + self.checkpoints.append(snapshot) + + def summary(self) -> Dict[str, Any]: + peak = max(self.checkpoints, key=lambda item: item["max_allocated_mb"]) if self.checkpoints else None + return { + "peak_allocated_mb": 0.0 if peak is None else float(peak["max_allocated_mb"]), + "peak_reserved_mb": 0.0 if peak is None else float(peak["max_reserved_mb"]), + "peak_label": None if peak is None else peak["label"], + "checkpoints": self.checkpoints, + } + + +def summarise_state_tensors(states: Sequence[T2SRequestState]) -> Dict[str, Any]: + per_request = [] + total = { + "phones_bytes": 0, + "prompt_phones_bytes": 0, + "all_phones_bytes": 0, + "all_bert_features_bytes": 0, + "prompt_semantic_bytes": 0, + "refer_spec_bytes": 0, + "raw_audio_bytes": 0, + "audio_16k_bytes": 0, + } + for state in states: + spec_audio, audio_16k = state.refer_spec + item = { + "request_id": state.request_id, + "prompt_semantic_len": int(state.prompt_semantic.shape[0]), + "phones_len": int(state.phones.shape[0]), + "all_phones_len": int(state.all_phones.shape[0]), + "bert_frames": int(state.all_bert_features.shape[-1]), + "phones_bytes": tensor_nbytes(state.phones), + "prompt_phones_bytes": tensor_nbytes(state.prompt_phones), + "all_phones_bytes": tensor_nbytes(state.all_phones), + "all_bert_features_bytes": tensor_nbytes(state.all_bert_features), + "prompt_semantic_bytes": tensor_nbytes(state.prompt_semantic), + "refer_spec_bytes": tensor_nbytes(spec_audio), + "audio_16k_bytes": tensor_nbytes(audio_16k), + "raw_audio_bytes": tensor_nbytes(state.raw_audio), + } + for key in total: + total[key] += int(item[key]) + per_request.append(item) + total["total_bytes"] = int(sum(total.values())) + total["total_mb"] = bytes_to_mb(total["total_bytes"]) + return {"per_request": per_request, "total": total} + + +def summarise_prefill_batch(active_batch: Any) -> Dict[str, Any]: + y_sequence_bytes = int(sum(tensor_nbytes(item) for item in active_batch.y_sequences)) + fields = { + "x_bytes": tensor_nbytes(active_batch.x), + "x_lens_bytes": tensor_nbytes(active_batch.x_lens), + "prefix_lens_bytes": tensor_nbytes(active_batch.prefix_lens), + "xy_pos_bytes": tensor_nbytes(active_batch.xy_pos), + "key_padding_mask_bytes": tensor_nbytes(active_batch.key_padding_mask), + "prefill_attn_mask_bytes": tensor_nbytes(active_batch.prefill_attn_mask), + "y_sequence_bytes": y_sequence_bytes, + } + fields["total_bytes"] = int(sum(fields.values())) + fields["total_mb"] = bytes_to_mb(fields["total_bytes"]) + fields["batch_size"] = int(len(active_batch.states)) + fields["max_x_len"] = int(active_batch.x.shape[1]) + fields["src_len"] = int(active_batch.xy_pos.shape[1]) + fields["prefill_attn_mask_shape"] = list(active_batch.prefill_attn_mask.shape) + return fields + + +def summarise_running_requests(running_requests: Sequence[T2SRunningRequest]) -> Dict[str, Any]: + per_request = [] + total_private_k_bytes = 0 + total_private_v_bytes = 0 + total_decode_mask_bytes = 0 + total_y_sequence_bytes = 0 + for item in running_requests: + k_bytes = tensor_list_nbytes(item.k_cache) + v_bytes = tensor_list_nbytes(item.v_cache) + mask_bytes = tensor_nbytes(item.decode_attn_mask) + y_bytes = tensor_nbytes(item.y_sequence) + total_private_k_bytes += k_bytes + total_private_v_bytes += v_bytes + total_decode_mask_bytes += mask_bytes + total_y_sequence_bytes += y_bytes + per_request.append( + { + "request_id": item.state.request_id, + "step_idx": int(item.step_idx), + "prefix_len": int(item.prefix_len), + "history_len": int(item.y_sequence.shape[0]), + "kv_len": int(item.k_cache[0].shape[1]), + "k_cache_bytes": k_bytes, + "v_cache_bytes": v_bytes, + "decode_mask_bytes": mask_bytes, + "y_sequence_bytes": y_bytes, + } + ) + total_bytes = total_private_k_bytes + total_private_v_bytes + total_decode_mask_bytes + total_y_sequence_bytes + return { + "per_request": per_request, + "totals": { + "private_k_cache_bytes": int(total_private_k_bytes), + "private_v_cache_bytes": int(total_private_v_bytes), + "private_kv_cache_bytes": int(total_private_k_bytes + total_private_v_bytes), + "decode_mask_bytes": int(total_decode_mask_bytes), + "y_sequence_bytes": int(total_y_sequence_bytes), + "total_bytes": int(total_bytes), + "total_mb": bytes_to_mb(total_bytes), + }, + } + + +def summarise_decode_batch( + xy_pos: torch.Tensor, + batched_k_cache: Sequence[torch.Tensor], + batched_v_cache: Sequence[torch.Tensor], + batched_decode_attn_mask: Optional[torch.Tensor], + running_requests: Sequence[T2SRunningRequest], +) -> Dict[str, Any]: + private_k_bytes = int(sum(tensor_list_nbytes(item.k_cache) for item in running_requests)) + private_v_bytes = int(sum(tensor_list_nbytes(item.v_cache) for item in running_requests)) + batched_k_bytes = tensor_list_nbytes(batched_k_cache) + batched_v_bytes = tensor_list_nbytes(batched_v_cache) + batched_mask_bytes = tensor_nbytes(batched_decode_attn_mask) + xy_pos_bytes = tensor_nbytes(xy_pos) + total_bytes = batched_k_bytes + batched_v_bytes + batched_mask_bytes + xy_pos_bytes + return { + "batch_size": int(len(running_requests)), + "xy_pos_bytes": int(xy_pos_bytes), + "batched_k_cache_bytes": int(batched_k_bytes), + "batched_v_cache_bytes": int(batched_v_bytes), + "batched_kv_cache_bytes": int(batched_k_bytes + batched_v_bytes), + "batched_decode_mask_bytes": int(batched_mask_bytes), + "private_kv_cache_bytes_reference": int(private_k_bytes + private_v_bytes), + "kv_padding_overhead_bytes": int((batched_k_bytes + batched_v_bytes) - (private_k_bytes + private_v_bytes)), + "total_bytes": int(total_bytes), + "total_mb": bytes_to_mb(total_bytes), + "xy_pos_shape": list(xy_pos.shape), + "batched_decode_mask_shape": None if batched_decode_attn_mask is None else list(batched_decode_attn_mask.shape), + "layer_k_cache_shape": list(batched_k_cache[0].shape), + } + + +def summarise_decode_outputs( + xy_dec: torch.Tensor, + next_k_cache: Sequence[torch.Tensor], + next_v_cache: Sequence[torch.Tensor], +) -> Dict[str, Any]: + xy_dec_bytes = tensor_nbytes(xy_dec) + next_k_bytes = tensor_list_nbytes(next_k_cache) + next_v_bytes = tensor_list_nbytes(next_v_cache) + total_bytes = xy_dec_bytes + next_k_bytes + next_v_bytes + return { + "xy_dec_bytes": int(xy_dec_bytes), + "next_k_cache_bytes": int(next_k_bytes), + "next_v_cache_bytes": int(next_v_bytes), + "next_kv_cache_bytes": int(next_k_bytes + next_v_bytes), + "total_bytes": int(total_bytes), + "total_mb": bytes_to_mb(total_bytes), + "xy_dec_shape": list(xy_dec.shape), + "layer_next_k_cache_shape": list(next_k_cache[0].shape), + } + + +def top_rankings(summary: Dict[str, Any]) -> List[Dict[str, Any]]: + ranking = [ + ("request_state_total", summary["prepare_stage"]["request_state"]["total"]["total_bytes"]), + ("prefill_batch_total", summary["prefill_batch"]["tensor_bytes"]["total_bytes"]), + ("running_private_kv", summary["prefill_step"]["running_requests"]["totals"]["private_kv_cache_bytes"]), + ("decode_batched_kv", summary["decode_batch"]["tensor_bytes"]["batched_kv_cache_bytes"]), + ("decode_kv_padding_overhead", summary["decode_batch"]["tensor_bytes"]["kv_padding_overhead_bytes"]), + ("decode_outputs_next_kv", summary["decode_outputs"]["tensor_bytes"]["next_kv_cache_bytes"]), + ("prefill_attn_mask", summary["prefill_batch"]["tensor_bytes"]["prefill_attn_mask_bytes"]), + ] + ranking.sort(key=lambda item: item[1], reverse=True) + return [{"name": name, "bytes": int(value), "mb": bytes_to_mb(int(value))} for name, value in ranking] + + +def synthesize_finished_item(tts: TTS, state: T2SRequestState, semantic_tokens: torch.Tensor) -> Tuple[int, np.ndarray]: + semantic_tokens = semantic_tokens.unsqueeze(0).unsqueeze(0).to(tts.configs.device) + phones = state.phones.unsqueeze(0).to(tts.configs.device) + audio_fragment = tts.synthesize_audio_request_local( + semantic_tokens=semantic_tokens, + phones=phones, + prompt_semantic=state.prompt_semantic, + prompt_phones=state.prompt_phones, + refer_spec=state.refer_spec, + raw_audio=state.raw_audio, + raw_sr=state.raw_sr, + speed=1.0, + sample_steps=32, + ) + output_sr = tts.configs.sampling_rate if not tts.configs.use_vocoder else tts.vocoder_configs["sr"] + return tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=1.0, + split_bucket=False, + fragment_interval=0.0, + super_sampling=False, + ) + + +def simulate_worker_end_to_end( + tts: TTS, + specs: Sequence[SchedulerRequestSpec], + max_steps: int, + rounds: int, + grad_mode: str = "default", +) -> Dict[str, Any]: + device = tts.configs.device + recorder = GlobalPeakRecorder(device) + recorder.record("after_model_load") + + state_map: Dict[str, T2SRequestState] = {} + per_round: List[Dict[str, Any]] = [] + + for round_index in range(rounds): + grad_context = torch.inference_mode if grad_mode == "inference_mode" else contextlib.nullcontext + with grad_context(): + states = [prepare_request_state(tts, spec) for spec in specs] + state_map = {state.request_id: state for state in states} + recorder.record( + "after_prepare_states", + round_index=int(round_index), + request_count=int(len(states)), + grad_mode=grad_mode, + ) + + pending = list(states) + running_requests: List[T2SRunningRequest] = [] + round_events: List[Dict[str, Any]] = [] + current_tick = 0 + + while pending or running_requests: + admitted = pending + pending = [] + + if admitted: + recorder.record( + "before_prefill", + round_index=int(round_index), + tick=int(current_tick), + admitted_count=int(len(admitted)), + running_count=int(len(running_requests)), + grad_mode=grad_mode, + ) + with grad_context(): + admitted_running, admitted_finished = run_prefill_step(tts.t2s_model.model, admitted, max_steps=max_steps) + recorder.record( + "after_prefill", + round_index=int(round_index), + tick=int(current_tick), + admitted_running_count=int(len(admitted_running)), + admitted_finished_count=int(len(admitted_finished)), + running_count=int(len(running_requests)), + grad_mode=grad_mode, + ) + round_events.append( + { + "tick": int(current_tick), + "event": "prefill", + "admitted_count": int(len(admitted)), + "admitted_running_count": int(len(admitted_running)), + "admitted_finished_count": int(len(admitted_finished)), + } + ) + for item in admitted_finished: + recorder.record( + "before_synth_prefill_finished", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + finished_request_id=item.request_id, + semantic_len=int(item.semantic_tokens.shape[0]), + grad_mode=grad_mode, + ) + with grad_context(): + sample_rate, audio_data = synthesize_finished_item(tts, state_map[item.request_id], item.semantic_tokens) + recorder.record( + "after_synth_prefill_finished", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + finished_request_id=item.request_id, + sample_rate=int(sample_rate), + audio_samples=int(audio_data.shape[0]), + grad_mode=grad_mode, + ) + running_requests.extend(admitted_running) + recorder.record( + "after_extend_running", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + grad_mode=grad_mode, + ) + + if running_requests: + recorder.record( + "before_decode", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + grad_mode=grad_mode, + ) + with grad_context(): + running_requests, step_finished = run_decode_step_for_running( + tts.t2s_model.model, + running_requests, + max_steps=max_steps, + ) + recorder.record( + "after_decode", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + finished_count=int(len(step_finished)), + grad_mode=grad_mode, + ) + round_events.append( + { + "tick": int(current_tick), + "event": "decode", + "running_count_after_decode": int(len(running_requests)), + "finished_count": int(len(step_finished)), + } + ) + for item in step_finished: + recorder.record( + "before_synth_decode_finished", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + finished_request_id=item.request_id, + semantic_len=int(item.semantic_tokens.shape[0]), + grad_mode=grad_mode, + ) + with grad_context(): + sample_rate, audio_data = synthesize_finished_item(tts, state_map[item.request_id], item.semantic_tokens) + recorder.record( + "after_synth_decode_finished", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + finished_request_id=item.request_id, + sample_rate=int(sample_rate), + audio_samples=int(audio_data.shape[0]), + grad_mode=grad_mode, + ) + current_tick += 1 + + recorder.record( + "after_round_complete", + round_index=int(round_index), + running_count=0, + grad_mode=grad_mode, + ) + per_round.append( + { + "round_index": int(round_index), + "events": round_events, + } + ) + + return { + "grad_mode": grad_mode, + "rounds": per_round, + "timeline": recorder.summary(), + } + + +def main() -> None: + args = parse_args() + args.output_dir.mkdir(parents=True, exist_ok=True) + + tts = load_pipeline(args.config) + model = tts.t2s_model.model + device = tts.configs.device + use_cuda = str(device).startswith("cuda") and torch.cuda.is_available() + set_seed(args.seed, use_cuda) + + specs = load_request_specs(args) + if args.early_stop_num == -1: + for spec in specs: + spec.early_stop_num = int(tts.configs.hz * tts.configs.max_sec) + + if args.warmup and specs: + warmup_spec = specs[:1] + _ = [prepare_request_state(tts, spec) for spec in warmup_spec] + gc.collect() + if use_cuda: + torch.cuda.empty_cache() + _sync_device(device) + + states, prepare_mem = stage_run(device, lambda: [prepare_request_state(tts, spec) for spec in specs]) + request_state_summary = summarise_state_tensors(states) + + active_batch, prefill_batch_mem = stage_run(device, lambda: build_prefill_batch(model, states)) + prefill_batch_tensor_summary = summarise_prefill_batch(active_batch) + + prefill_result, prefill_step_mem = stage_run(device, lambda: run_prefill_step(model, states, max_steps=args.max_steps)) + running_requests, finished_items = prefill_result + running_requests_summary = summarise_running_requests(running_requests) + finished_after_prefill_summary = [ + { + "request_id": item.request_id, + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + } + for item in finished_items + ] + + if not running_requests: + raise RuntimeError(f"prefill 后没有 running requests,全部在首步结束: {[item.request_id for item in finished_items]}") + + decode_batch_result, decode_batch_mem = stage_run( + device, + lambda: _build_decode_batch_from_running(model, running_requests), + ) + xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask = decode_batch_result + decode_batch_tensor_summary = summarise_decode_batch( + xy_pos, + batched_k_cache, + batched_v_cache, + batched_decode_attn_mask, + running_requests, + ) + + decode_out_result, decode_step_mem = stage_run( + device, + lambda: model.t2s_transformer.decode_next_token( + xy_pos, + batched_k_cache, + batched_v_cache, + batched_decode_attn_mask, + ), + ) + xy_dec, next_k_cache, next_v_cache = decode_out_result + decode_output_tensor_summary = summarise_decode_outputs(xy_dec, next_k_cache, next_v_cache) + del active_batch + del running_requests + del finished_items + del xy_pos + del batched_k_cache + del batched_v_cache + del batched_decode_attn_mask + del xy_dec + del next_k_cache + del next_v_cache + gc.collect() + if use_cuda: + _sync_device(device) + torch.cuda.empty_cache() + end_to_end_worker = simulate_worker_end_to_end( + tts=tts, + specs=specs, + max_steps=args.max_steps, + rounds=args.worker_rounds, + grad_mode=args.worker_grad_mode, + ) + live_cuda_tensors_after_worker = snapshot_live_cuda_tensors() + worker_inference_mode = None + if args.compare_worker_grad_modes: + gc.collect() + if use_cuda: + _sync_device(device) + torch.cuda.empty_cache() + worker_inference_mode = simulate_worker_end_to_end( + tts=tts, + specs=specs, + max_steps=args.max_steps, + rounds=args.worker_rounds, + grad_mode="inference_mode", + ) + + summary = { + "meta": { + "scenario": args.scenario if args.request_manifest is None else "manifest", + "seed": int(args.seed), + "device": str(device), + "dtype": str(next(model.parameters()).dtype), + "request_count": int(len(specs)), + "num_layers": int(model.num_layers), + "num_heads": int(model.num_head), + "model_dim": int(model.model_dim), + "model_weights_mb": bytes_to_mb(model_nbytes(model)), + }, + "loaded_module_weights": build_module_weight_summary(tts), + "requests": [ + { + "request_id": spec.request_id, + "ref_audio_path": str(spec.ref_audio_path), + "prompt_text": spec.prompt_text, + "text": spec.text, + } + for spec in specs + ], + "prepare_stage": { + "memory": prepare_mem, + "request_state": request_state_summary, + }, + "prefill_batch": { + "memory": prefill_batch_mem, + "tensor_bytes": prefill_batch_tensor_summary, + }, + "prefill_step": { + "memory": prefill_step_mem, + "running_requests": running_requests_summary, + "finished_after_prefill": finished_after_prefill_summary, + }, + "decode_batch": { + "memory": decode_batch_mem, + "tensor_bytes": decode_batch_tensor_summary, + }, + "decode_outputs": { + "memory": decode_step_mem, + "tensor_bytes": decode_output_tensor_summary, + }, + "end_to_end_worker": end_to_end_worker, + "live_cuda_tensors_after_worker": live_cuda_tensors_after_worker, + "end_to_end_worker_inference_mode": worker_inference_mode, + } + summary["top_rankings"] = top_rankings(summary) + + summary_path = args.output_dir / "t2s_memory_breakdown_summary.json" + summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8") + + print(json.dumps(summary["meta"], ensure_ascii=False, indent=2)) + print("[top_rankings]") + for item in summary["top_rankings"]: + print(f"- {item['name']}: {item['mb']:.3f} MB") + print("[worker_peak]") + print( + json.dumps( + { + "peak_label": summary["end_to_end_worker"]["timeline"]["peak_label"], + "peak_allocated_mb": summary["end_to_end_worker"]["timeline"]["peak_allocated_mb"], + "peak_reserved_mb": summary["end_to_end_worker"]["timeline"]["peak_reserved_mb"], + }, + ensure_ascii=False, + indent=2, + ) + ) + if worker_inference_mode is not None: + print("[worker_peak_inference_mode]") + print( + json.dumps( + { + "peak_label": worker_inference_mode["timeline"]["peak_label"], + "peak_allocated_mb": worker_inference_mode["timeline"]["peak_allocated_mb"], + "peak_reserved_mb": worker_inference_mode["timeline"]["peak_reserved_mb"], + }, + ensure_ascii=False, + indent=2, + ) + ) + print(f"[summary] {summary_path}") + + +if __name__ == "__main__": + main() From 845b181360b5c5f4a5ed827ca01cfd5dfd8f58b1 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Mon, 9 Mar 2026 05:19:28 +0800 Subject: [PATCH 06/24] Implement batch processing for BERT and reference semantic tasks in TTS. Introduce StageLimiter for managing concurrent processing and enhance the TTS class with new methods for handling audio and semantic extraction. Update profiling metrics for better performance tracking during inference. --- GPT_SoVITS/TTS_infer_pack/TTS.py | 261 ++++++++++++++--- GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py | 259 +++++++++++------ .../prepare_bert_batch_worker.py | 197 +++++++++++++ .../prepare_ref_semantic_batch_worker.py | 262 ++++++++++++++++++ GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py | 121 ++++++-- api_v3.py | 64 ++++- 6 files changed, 1024 insertions(+), 140 deletions(-) create mode 100644 GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py create mode 100644 GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 2fd0df35..9c259662 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -5,6 +5,7 @@ import random import sys import time import traceback +from concurrent.futures import ThreadPoolExecutor from copy import deepcopy import torchaudio @@ -33,7 +34,12 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer from tools.audio_sr import AP_BWE from tools.i18n.i18n import I18nAuto, scan_language_list from TTS_infer_pack.text_segmentation_method import splits -from TTS_infer_pack.TextPreprocessor import TextPreprocessor +from TTS_infer_pack.TextPreprocessor import TextPreprocessor, StageLimiter +from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker +from TTS_infer_pack.prepare_ref_semantic_batch_worker import ( + PrepareRefSemanticBatchWorker, + prepare_prompt_semantic_wav16k, +) from sv import SV resample_transform_dict = {} @@ -442,11 +448,56 @@ class TTS: "upsample_rate": None, "overlapped_len": None, } + self.prepare_bert_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_BERT_SLOTS", "1"))) + self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "2"))) + self.prepare_bert_batch_worker = None + self.prepare_ref_semantic_batch_worker = None + default_text_cpu_workers = 16 + self.prepare_text_cpu_workers = max( + 0, + int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_WORKERS", str(default_text_cpu_workers))), + ) + self.prepare_text_cpu_executor = None + if self.prepare_text_cpu_workers > 0: + self.prepare_text_cpu_executor = ThreadPoolExecutor( + max_workers=self.prepare_text_cpu_workers, + thread_name_prefix="prepare-text-cpu", + ) self._init_models() + if os.environ.get("GPTSOVITS_PREPARE_BERT_BATCHING", "1") != "0": + self.prepare_bert_batch_worker = PrepareBertBatchWorker( + bert_model=self.bert_model, + tokenizer=self.bert_tokenizer, + device=self.configs.device, + stage_limiter=self.prepare_bert_stage_limiter, + batch_window_ms=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_WINDOW_MS", "5")), + max_batch_items=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_ITEMS", "16")), + max_batch_tokens=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_TOKENS", "4096")), + ) + if os.environ.get("GPTSOVITS_PREPARE_REF_BATCHING", "0") != "0": + ref_max_batch_samples = os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_SAMPLES") + if ref_max_batch_samples is None: + ref_max_batch_samples = os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_FRAMES", "960000") + self.prepare_ref_semantic_batch_worker = PrepareRefSemanticBatchWorker( + ssl_model=self.cnhuhbert_model, + vits_model=self.vits_model, + device=self.configs.device, + is_half=self.configs.is_half, + zero_wav_samples=int(self.configs.sampling_rate * 0.3), + stage_limiter=self.prepare_ref_audio_stage_limiter, + batch_window_ms=int(os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_WINDOW_MS", "5")), + max_batch_items=int(os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_ITEMS", "8")), + max_batch_samples=int(ref_max_batch_samples), + ) + self.text_preprocessor: TextPreprocessor = TextPreprocessor( - self.bert_model, self.bert_tokenizer, self.configs.device + self.bert_model, + self.bert_tokenizer, + self.configs.device, + bert_stage_limiter=self.prepare_bert_stage_limiter, + bert_batch_worker=self.prepare_bert_batch_worker, ) self.prompt_cache: dict = { @@ -755,47 +806,52 @@ class TTS: Args: ref_audio_path: str, the path of the reference audio. """ - self._set_prompt_semantic(ref_audio_path) - self._set_ref_spec(ref_audio_path) + bundle = self.extract_ref_audio_bundle(ref_audio_path) + if self.prompt_cache["refer_spec"] in [[], None]: + self.prompt_cache["refer_spec"] = [bundle["refer_spec"]] + else: + self.prompt_cache["refer_spec"][0] = bundle["refer_spec"] + self.prompt_cache["prompt_semantic"] = bundle["prompt_semantic"] + self.prompt_cache["raw_audio"] = bundle["raw_audio"] + self.prompt_cache["raw_sr"] = bundle["raw_sr"] self._set_ref_audio_path(ref_audio_path) - def extract_prompt_semantic(self, ref_wav_path: str): - zero_wav = np.zeros( - int(self.configs.sampling_rate * 0.3), - dtype=np.float16 if self.configs.is_half else np.float32, - ) - with torch.no_grad(): - wav16k, sr = librosa.load(ref_wav_path, sr=16000) - if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000: - raise OSError(i18n("参考音频在3~10秒范围外,请更换!")) - wav16k = torch.from_numpy(wav16k) - zero_wav_torch = torch.from_numpy(zero_wav) - wav16k = wav16k.to(self.configs.device) - zero_wav_torch = zero_wav_torch.to(self.configs.device) - if self.configs.is_half: - wav16k = wav16k.half() - zero_wav_torch = zero_wav_torch.half() - - wav16k = torch.cat([wav16k, zero_wav_torch]) - hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose( - 1, 2 - ) # .float() - codes = self.vits_model.extract_latent(hubert_feature) - - prompt_semantic = codes[0, 0].to(self.configs.device) - return prompt_semantic - - def extract_ref_spec(self, ref_audio_path: str): + def _load_ref_audio_raw(self, ref_audio_path: str): raw_audio, raw_sr = torchaudio.load(ref_audio_path) - raw_audio = raw_audio.to(self.configs.device).float() + return raw_audio.float(), int(raw_sr) + + @torch.inference_mode() + def _extract_prompt_semantic_from_prepared_wav16k(self, wav16k: torch.Tensor): + wav16k = wav16k.to(self.configs.device) + if self.configs.is_half: + wav16k = wav16k.half() + hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) + codes = self.vits_model.extract_latent(hubert_feature) + return codes[0, 0].to(self.configs.device) + + @torch.inference_mode() + def _extract_prompt_semantic_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + wav16k = prepare_prompt_semantic_wav16k( + raw_audio=raw_audio, + raw_sr=raw_sr, + zero_wav_samples=int(self.configs.sampling_rate * 0.3), + ) + return self._extract_prompt_semantic_from_prepared_wav16k(wav16k) + + def extract_prompt_semantic(self, ref_wav_path: str): + raw_audio, raw_sr = self._load_ref_audio_raw(ref_wav_path) + return self._extract_prompt_semantic_from_raw(raw_audio, raw_sr) + + def _extract_ref_spec_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + raw_audio_device = raw_audio.to(self.configs.device).float() if raw_sr != self.configs.sampling_rate: - audio = raw_audio.to(self.configs.device) + audio = raw_audio_device if audio.shape[0] == 2: audio = audio.mean(0).unsqueeze(0) audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device) else: - audio = raw_audio.to(self.configs.device) + audio = raw_audio_device if audio.shape[0] == 2: audio = audio.mean(0).unsqueeze(0) @@ -820,8 +876,141 @@ class TTS: audio = None return spec, audio, raw_audio, raw_sr - def extract_text_features(self, text: str, language: str): - return self.text_preprocessor.segment_and_extract_feature_for_text(text, language, self.configs.version) + def extract_ref_spec(self, ref_audio_path: str): + raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path) + return self._extract_ref_spec_from_raw(raw_audio, raw_sr) + + def extract_ref_audio_bundle(self, ref_audio_path: str): + load_start = time.perf_counter() + raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path) + load_ms = (time.perf_counter() - load_start) * 1000.0 + if self.prepare_ref_semantic_batch_worker is None: + with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats: + prompt_semantic_start = time.perf_counter() + prompt_semantic = self._extract_prompt_semantic_from_raw(raw_audio, raw_sr) + prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0 + ref_spec_start = time.perf_counter() + refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2] + ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0 + audio_stage_wait_ms = float(limiter_stats["wait_ms"]) + audio_stage_slots = float(limiter_stats["slots"]) + audio_stage_inflight_peak = float(limiter_stats["peak_inflight"]) + prompt_semantic_profile = { + "prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]), + "prompt_semantic_cpu_prepare_ms": 0.0, + "prompt_semantic_forward_ms": prompt_semantic_ms, + "prompt_semantic_scatter_ms": 0.0, + "prompt_semantic_stage_slots": float(limiter_stats["slots"]), + "prompt_semantic_stage_inflight_peak": float(limiter_stats["peak_inflight"]), + "prompt_semantic_batch_size": 1.0, + "prompt_semantic_batch_samples": 0.0, + } + ref_spec_wait_ms = 0.0 + return { + "prompt_semantic": prompt_semantic, + "refer_spec": refer_spec, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + "audio_load_ms": load_ms, + "audio_stage_wait_ms": audio_stage_wait_ms, + "audio_stage_slots": audio_stage_slots, + "audio_stage_inflight_peak": audio_stage_inflight_peak, + "prompt_semantic_ms": prompt_semantic_ms, + "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_ms": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) + ), + "prompt_semantic_forward_ms": float( + prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0) + ), + "prompt_semantic_scatter_ms": float( + prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0) + ), + "prompt_semantic_stage_slots": float( + prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0) + ), + "prompt_semantic_stage_inflight_peak": float( + prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0) + ), + "prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)), + "prompt_semantic_batch_samples": float( + prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0) + ), + "ref_spec_wait_ms": ref_spec_wait_ms, + "ref_spec_ms": ref_spec_ms, + "bundle_total_ms": load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_ms, + }, + } + + prompt_semantic_profile = { + "prompt_semantic_wait_ms": 0.0, + "prompt_semantic_cpu_prepare_ms": 0.0, + "prompt_semantic_forward_ms": 0.0, + "prompt_semantic_scatter_ms": 0.0, + "prompt_semantic_stage_slots": 0.0, + "prompt_semantic_stage_inflight_peak": 0.0, + "prompt_semantic_batch_size": 1.0, + "prompt_semantic_batch_samples": 0.0, + } + if self.prepare_ref_semantic_batch_worker is not None: + prompt_semantic, worker_profile = self.prepare_ref_semantic_batch_worker.submit(raw_audio, raw_sr) + prompt_semantic_profile.update(worker_profile) + prompt_semantic_ms = ( + float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)) + ) + with self.prepare_ref_audio_stage_limiter.enter() as ref_spec_limiter_stats: + ref_spec_start = time.perf_counter() + refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2] + ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0 + audio_stage_wait_ms = float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)) + float( + ref_spec_limiter_stats["wait_ms"] + ) + audio_stage_slots = max( + float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + float(ref_spec_limiter_stats["slots"]), + ) + audio_stage_inflight_peak = max( + float(prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)), + float(ref_spec_limiter_stats["peak_inflight"]), + ) + return { + "prompt_semantic": prompt_semantic, + "refer_spec": refer_spec, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + "audio_load_ms": load_ms, + "audio_stage_wait_ms": audio_stage_wait_ms, + "audio_stage_slots": audio_stage_slots, + "audio_stage_inflight_peak": audio_stage_inflight_peak, + "prompt_semantic_ms": prompt_semantic_ms, + "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_ms": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) + ), + "prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)), + "prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)), + "prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + "prompt_semantic_stage_inflight_peak": float( + prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0) + ), + "prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)), + "prompt_semantic_batch_samples": float( + prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0) + ), + "ref_spec_wait_ms": float(ref_spec_limiter_stats["wait_ms"]), + "ref_spec_ms": ref_spec_ms, + "bundle_total_ms": load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_ms, + }, + } + + def extract_text_features(self, text: str, language: str, profile: dict | None = None): + return self.text_preprocessor.segment_and_extract_feature_for_text( + text, language, self.configs.version, profile=profile + ) def _set_ref_audio_path(self, ref_audio_path): self.prompt_cache["ref_audio_path"] = ref_audio_path diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py index 283e91c3..15b3c322 100644 --- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py +++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py @@ -1,6 +1,8 @@ import os import sys import threading +import time +from contextlib import contextmanager from tqdm import tqdm @@ -16,6 +18,7 @@ from text.cleaner import clean_text from text import cleaned_text_to_sequence from transformers import AutoModelForMaskedLM, AutoTokenizer from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method +from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker from tools.i18n.i18n import I18nAuto, scan_language_list @@ -49,12 +52,60 @@ def merge_short_text_in_array(texts: str, threshold: int) -> list: return result +class StageLimiter: + def __init__(self, slots: int): + self.slots = max(1, int(slots)) + self.semaphore = threading.BoundedSemaphore(self.slots) + self.lock = threading.Lock() + self.inflight = 0 + self.peak_inflight = 0 + + @contextmanager + def enter(self): + wait_start = time.perf_counter() + self.semaphore.acquire() + wait_ms = (time.perf_counter() - wait_start) * 1000.0 + with self.lock: + self.inflight += 1 + current_inflight = self.inflight + if current_inflight > self.peak_inflight: + self.peak_inflight = current_inflight + peak_inflight = self.peak_inflight + try: + yield { + "wait_ms": wait_ms, + "inflight": current_inflight, + "peak_inflight": peak_inflight, + "slots": self.slots, + } + finally: + with self.lock: + self.inflight = max(0, self.inflight - 1) + self.semaphore.release() + + def snapshot(self) -> Dict[str, int]: + with self.lock: + return { + "slots": self.slots, + "inflight": self.inflight, + "peak_inflight": self.peak_inflight, + } + + class TextPreprocessor: - def __init__(self, bert_model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, device: torch.device): + def __init__( + self, + bert_model: AutoModelForMaskedLM, + tokenizer: AutoTokenizer, + device: torch.device, + bert_stage_limiter: StageLimiter | None = None, + bert_batch_worker: PrepareBertBatchWorker | None = None, + ): self.bert_model = bert_model self.tokenizer = tokenizer self.device = device - self.bert_lock = threading.RLock() + self.bert_stage_limiter = bert_stage_limiter + self.bert_batch_worker = bert_batch_worker def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]: print(f"############ {i18n('切分文本')} ############") @@ -115,86 +166,136 @@ class TextPreprocessor: return texts def segment_and_extract_feature_for_text( - self, text: str, language: str, version: str = "v1" + self, text: str, language: str, version: str = "v1", profile: Dict | None = None ) -> Tuple[list, torch.Tensor, str]: - return self.get_phones_and_bert(text, language, version) + return self.get_phones_and_bert(text, language, version, profile=profile) - def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False): - with self.bert_lock: - text = re.sub(r' {2,}', ' ', text) - textlist = [] - langlist = [] - if language == "all_zh": - for tmp in LangSegmenter.getTexts(text,"zh"): + def _split_text_by_language(self, text: str, language: str) -> Tuple[List[str], List[str]]: + textlist = [] + langlist = [] + if language == "all_zh": + for tmp in LangSegmenter.getTexts(text, "zh"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_yue": + for tmp in LangSegmenter.getTexts(text, "zh"): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_ja": + for tmp in LangSegmenter.getTexts(text, "ja"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_ko": + for tmp in LangSegmenter.getTexts(text, "ko"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "en": + langlist.append("en") + textlist.append(text) + elif language == "auto": + for tmp in LangSegmenter.getTexts(text): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "auto_yue": + for tmp in LangSegmenter.getTexts(text): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + else: + for tmp in LangSegmenter.getTexts(text): + if langlist: + same_group = (tmp["lang"] == "en" and langlist[-1] == "en") or ( + tmp["lang"] != "en" and langlist[-1] != "en" + ) + if same_group: + textlist[-1] += tmp["text"] + continue + if tmp["lang"] == "en": langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "all_yue": - for tmp in LangSegmenter.getTexts(text,"zh"): - if tmp["lang"] == "zh": - tmp["lang"] = "yue" - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "all_ja": - for tmp in LangSegmenter.getTexts(text,"ja"): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "all_ko": - for tmp in LangSegmenter.getTexts(text,"ko"): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "en": - langlist.append("en") - textlist.append(text) - elif language == "auto": - for tmp in LangSegmenter.getTexts(text): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "auto_yue": - for tmp in LangSegmenter.getTexts(text): - if tmp["lang"] == "zh": - tmp["lang"] = "yue" - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - else: - for tmp in LangSegmenter.getTexts(text): - if langlist: - if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"): - textlist[-1] += tmp["text"] - continue - if tmp["lang"] == "en": - langlist.append(tmp["lang"]) - else: - # 因无法区别中日韩文汉字,以用户输入为准 - langlist.append(language) - textlist.append(tmp["text"]) - # print(textlist) - # print(langlist) - phones_list = [] - bert_list = [] - norm_text_list = [] - for i in range(len(textlist)): - lang = langlist[i] - phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version) - bert = self.get_bert_inf(phones, word2ph, norm_text, lang) - phones_list.append(phones) - norm_text_list.append(norm_text) - bert_list.append(bert) - bert = torch.cat(bert_list, dim=1) - phones = sum(phones_list, []) - norm_text = "".join(norm_text_list) + else: + langlist.append(language) + textlist.append(tmp["text"]) + return textlist, langlist - if not final and len(phones) < 6: - return self.get_phones_and_bert("." + text, language, version, final=True) + def get_phones_and_bert( + self, text: str, language: str, version: str, final: bool = False, profile: Dict | None = None + ): + text = re.sub(r' {2,}', ' ', text) + textlist, langlist = self._split_text_by_language(text, language) + phones_list = [] + bert_list = [] + norm_text_list = [] + for segment_text, segment_lang in zip(textlist, langlist): + phones, word2ph, norm_text = self.clean_text_inf(segment_text, segment_lang, version) + bert = self.get_bert_inf(phones, word2ph, norm_text, segment_lang, profile=profile) + phones_list.append(phones) + norm_text_list.append(norm_text) + bert_list.append(bert) + bert = torch.cat(bert_list, dim=1) + phones = sum(phones_list, []) + norm_text = "".join(norm_text_list) - return phones, bert, norm_text + if not final and len(phones) < 6: + return self.get_phones_and_bert("." + text, language, version, final=True, profile=profile) - def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor: - with torch.no_grad(): - inputs = self.tokenizer(text, return_tensors="pt") - for i in inputs: - inputs[i] = inputs[i].to(self.device) - res = self.bert_model(**inputs, output_hidden_states=True) - res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + return phones, bert, norm_text + + def _accumulate_profile(self, profile: Dict | None, key: str, value: float) -> None: + if profile is None: + return + profile[key] = float(profile.get(key, 0.0)) + float(value) + + def _update_profile_peak(self, profile: Dict | None, key: str, value: float) -> None: + if profile is None: + return + profile[key] = float(max(float(profile.get(key, 0.0)), float(value))) + + def get_bert_feature(self, text: str, word2ph: list, profile: Dict | None = None) -> torch.Tensor: + if self.bert_batch_worker is not None: + feature, worker_profile = self.bert_batch_worker.submit(text, word2ph) + self._accumulate_profile(profile, "bert_wait_ms", worker_profile.get("bert_wait_ms", 0.0)) + self._accumulate_profile(profile, "bert_forward_ms", worker_profile.get("bert_forward_ms", 0.0)) + self._accumulate_profile(profile, "bert_tokenize_ms", worker_profile.get("bert_tokenize_ms", 0.0)) + self._accumulate_profile(profile, "bert_scatter_ms", worker_profile.get("bert_scatter_ms", 0.0)) + self._accumulate_profile(profile, "bert_calls", worker_profile.get("bert_calls", 1.0)) + self._update_profile_peak( + profile, "bert_stage_inflight_peak", worker_profile.get("bert_stage_inflight_peak", 0.0) + ) + self._update_profile_peak(profile, "bert_batch_size_peak", worker_profile.get("bert_batch_size", 0.0)) + self._update_profile_peak(profile, "bert_batch_tokens_peak", worker_profile.get("bert_batch_tokens", 0.0)) + if profile is not None: + profile["bert_stage_slots"] = float(worker_profile.get("bert_stage_slots", 0.0)) + return feature + + limiter_stats = {"wait_ms": 0.0, "inflight": 1, "peak_inflight": 1, "slots": 0} + if self.bert_stage_limiter is None: + forward_start = time.perf_counter() + with torch.no_grad(): + inputs = self.tokenizer(text, return_tensors="pt") + for i in inputs: + inputs[i] = inputs[i].to(self.device) + res = self.bert_model(**inputs, output_hidden_states=True) + res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + else: + with self.bert_stage_limiter.enter() as limiter_stats: + forward_start = time.perf_counter() + with torch.no_grad(): + inputs = self.tokenizer(text, return_tensors="pt") + for i in inputs: + inputs[i] = inputs[i].to(self.device) + res = self.bert_model(**inputs, output_hidden_states=True) + res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + self._accumulate_profile(profile, "bert_wait_ms", limiter_stats["wait_ms"]) + self._accumulate_profile(profile, "bert_forward_ms", forward_ms) + self._accumulate_profile(profile, "bert_calls", 1.0) + self._update_profile_peak(profile, "bert_stage_inflight_peak", limiter_stats["peak_inflight"]) + if profile is not None: + profile["bert_stage_slots"] = float(limiter_stats["slots"]) assert len(word2ph) == len(text) phone_level_feature = [] for i in range(len(word2ph)): @@ -209,10 +310,10 @@ class TextPreprocessor: phones = cleaned_text_to_sequence(phones, version) return phones, word2ph, norm_text - def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str): + def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str, profile: Dict | None = None): language = language.replace("all_", "") if language == "zh": - feature = self.get_bert_feature(norm_text, word2ph).to(self.device) + feature = self.get_bert_feature(norm_text, word2ph, profile=profile).to(self.device) else: feature = torch.zeros( (1024, len(phones)), @@ -236,4 +337,4 @@ class TextPreprocessor: punctuations = "".join(re.escape(p) for p in punctuation) pattern = f"([{punctuations}])([{punctuations}])+" result = re.sub(pattern, r"\1", text) - return result \ No newline at end of file + return result diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py new file mode 100644 index 00000000..b1ede3d8 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py @@ -0,0 +1,197 @@ +import threading +import time +import uuid +from collections import deque +from dataclasses import dataclass, field +from typing import Deque, Dict, List, Tuple + +import torch + + +@dataclass +class BertFeatureTask: + norm_text: str + word2ph: List[int] + task_id: str = field(default_factory=lambda: uuid.uuid4().hex) + created_at: float = field(default_factory=time.perf_counter) + done_event: threading.Event = field(default_factory=threading.Event) + result_feature: torch.Tensor | None = None + error: Exception | None = None + profile: Dict[str, float] = field(default_factory=dict) + + +class PrepareBertBatchWorker: + def __init__( + self, + bert_model, + tokenizer, + device, + stage_limiter=None, + batch_window_ms: int = 5, + max_batch_items: int = 16, + max_batch_tokens: int = 4096, + ): + self.bert_model = bert_model + self.tokenizer = tokenizer + self.device = device + self.stage_limiter = stage_limiter + self.batch_window_s = max(0.0, float(batch_window_ms) / 1000.0) + self.max_batch_items = max(1, int(max_batch_items)) + self.max_batch_tokens = max(16, int(max_batch_tokens)) + + self.condition = threading.Condition() + self.pending_tasks: Deque[BertFeatureTask] = deque() + self.pending_peak = 0 + self.total_submitted = 0 + self.total_finished = 0 + self.total_batches = 0 + self.active_batch_size = 0 + self.active_batch_peak = 0 + self.worker_thread = threading.Thread(target=self._run_loop, name="prepare-bert-batch-worker", daemon=True) + self.worker_thread.start() + + def _estimate_task_tokens(self, task: BertFeatureTask) -> int: + return max(1, len(task.norm_text) + 2) + + def submit(self, norm_text: str, word2ph: List[int]) -> Tuple[torch.Tensor, Dict[str, float]]: + task = BertFeatureTask(norm_text=str(norm_text), word2ph=list(word2ph)) + with self.condition: + self.pending_tasks.append(task) + self.total_submitted += 1 + if len(self.pending_tasks) > self.pending_peak: + self.pending_peak = len(self.pending_tasks) + self.condition.notify_all() + task.done_event.wait() + if task.error is not None: + raise task.error + assert task.result_feature is not None + return task.result_feature, dict(task.profile) + + def snapshot(self) -> Dict[str, int]: + with self.condition: + return { + "pending": len(self.pending_tasks), + "pending_peak": self.pending_peak, + "total_submitted": self.total_submitted, + "total_finished": self.total_finished, + "total_batches": self.total_batches, + "active_batch_size": self.active_batch_size, + "active_batch_peak": self.active_batch_peak, + "batch_window_ms": int(self.batch_window_s * 1000.0), + "max_batch_items": self.max_batch_items, + "max_batch_tokens": self.max_batch_tokens, + } + + def _collect_batch(self) -> List[BertFeatureTask]: + with self.condition: + while not self.pending_tasks: + self.condition.wait() + + batch: List[BertFeatureTask] = [self.pending_tasks.popleft()] + batch_tokens = self._estimate_task_tokens(batch[0]) + deadline = time.perf_counter() + self.batch_window_s + + while len(batch) < self.max_batch_items: + remaining = deadline - time.perf_counter() + if remaining <= 0: + break + if not self.pending_tasks: + self.condition.wait(timeout=remaining) + continue + next_task = self.pending_tasks[0] + next_tokens = self._estimate_task_tokens(next_task) + if len(batch) >= self.max_batch_items or (batch_tokens + next_tokens) > self.max_batch_tokens: + break + batch.append(self.pending_tasks.popleft()) + batch_tokens += next_tokens + + self.active_batch_size = len(batch) + if self.active_batch_size > self.active_batch_peak: + self.active_batch_peak = self.active_batch_size + return batch + + def _finalize_batch(self, batch: List[BertFeatureTask]) -> None: + with self.condition: + self.active_batch_size = 0 + self.total_batches += 1 + self.total_finished += len(batch) + + def _run_batch(self, batch: List[BertFeatureTask]) -> None: + batch_started = time.perf_counter() + texts = [task.norm_text for task in batch] + batch_tokens = sum(self._estimate_task_tokens(task) for task in batch) + + limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0} + if self.stage_limiter is None: + tokenize_start = time.perf_counter() + inputs = self.tokenizer(texts, return_tensors="pt", padding=True) + tokenize_ms = (time.perf_counter() - tokenize_start) * 1000.0 + attention_mask_cpu = inputs["attention_mask"].cpu() + for key in inputs: + inputs[key] = inputs[key].to(self.device) + forward_start = time.perf_counter() + with torch.no_grad(): + outputs = self.bert_model(**inputs, output_hidden_states=True) + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + else: + with self.stage_limiter.enter() as limiter_stats: + tokenize_start = time.perf_counter() + inputs = self.tokenizer(texts, return_tensors="pt", padding=True) + tokenize_ms = (time.perf_counter() - tokenize_start) * 1000.0 + attention_mask_cpu = inputs["attention_mask"].cpu() + for key in inputs: + inputs[key] = inputs[key].to(self.device) + forward_start = time.perf_counter() + with torch.no_grad(): + outputs = self.bert_model(**inputs, output_hidden_states=True) + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + + hidden = outputs["hidden_states"][-3].detach().cpu() + scatter_start = time.perf_counter() + for batch_index, task in enumerate(batch): + try: + text_len = len(task.word2ph) + if text_len != len(task.norm_text): + raise AssertionError( + f"word2ph/text length mismatch: task={task.task_id} word2ph={text_len} text={len(task.norm_text)}" + ) + seq_len = int(attention_mask_cpu[batch_index].sum().item()) + char_features = hidden[batch_index, 1 : seq_len - 1] + if char_features.shape[0] != text_len: + raise AssertionError( + f"bert token length mismatch: task={task.task_id} token_len={char_features.shape[0]} text_len={text_len}" + ) + phone_level_feature = [] + for char_index, repeat_count in enumerate(task.word2ph): + phone_level_feature.append(char_features[char_index].repeat(repeat_count, 1)) + task.result_feature = torch.cat(phone_level_feature, dim=0).T + task.profile = { + "bert_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]), + "bert_forward_ms": float(forward_ms), + "bert_tokenize_ms": float(tokenize_ms), + "bert_scatter_ms": 0.0, + "bert_calls": 1.0, + "bert_stage_slots": float(limiter_stats["slots"]), + "bert_stage_inflight_peak": float(limiter_stats["peak_inflight"]), + "bert_batch_size": float(len(batch)), + "bert_batch_tokens": float(batch_tokens), + } + except Exception as exc: # noqa: PERF203 + task.error = exc + scatter_ms = (time.perf_counter() - scatter_start) * 1000.0 + for task in batch: + if task.result_feature is not None: + task.profile["bert_scatter_ms"] = float(scatter_ms) + task.done_event.set() + + def _run_loop(self) -> None: + while True: + batch = self._collect_batch() + try: + self._run_batch(batch) + except Exception as exc: # noqa: PERF203 + for task in batch: + task.error = exc + task.done_event.set() + finally: + self._finalize_batch(batch) diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py new file mode 100644 index 00000000..7a1f9a53 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py @@ -0,0 +1,262 @@ +import threading +import time +import uuid +from collections import deque +from dataclasses import dataclass, field +from typing import Deque, Dict, List, Tuple + +import librosa +import numpy as np +import torch + + +REF_AUDIO_MIN_SAMPLES_16K = 48000 +REF_AUDIO_MAX_SAMPLES_16K = 160000 + + +def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wav_samples: int) -> torch.Tensor: + wav_mono = raw_audio + if wav_mono.dim() == 2 and wav_mono.shape[0] != 1: + wav_mono = wav_mono.mean(0, keepdim=True) + wav16k = wav_mono.squeeze(0).cpu().numpy() + if raw_sr != 16000: + wav16k = librosa.resample(wav16k, orig_sr=raw_sr, target_sr=16000) + if wav16k.shape[0] > REF_AUDIO_MAX_SAMPLES_16K or wav16k.shape[0] < REF_AUDIO_MIN_SAMPLES_16K: + raise OSError("参考音频在3~10秒范围外,请更换!") + wav16k = np.ascontiguousarray(wav16k, dtype=np.float32) + if zero_wav_samples > 0: + wav16k = np.concatenate([wav16k, np.zeros(int(zero_wav_samples), dtype=np.float32)], axis=0) + return torch.from_numpy(wav16k) + + +def conv1d_output_lengths(input_lengths: torch.Tensor, conv1d: torch.nn.Conv1d | None) -> torch.Tensor: + if conv1d is None: + return input_lengths.to(dtype=torch.long) + kernel_size = int(conv1d.kernel_size[0]) + stride = int(conv1d.stride[0]) + padding = int(conv1d.padding[0]) + dilation = int(conv1d.dilation[0]) + output_lengths = torch.div( + input_lengths + 2 * padding - dilation * (kernel_size - 1) - 1, + stride, + rounding_mode="floor", + ) + 1 + return torch.clamp(output_lengths, min=0).to(dtype=torch.long) + + +@dataclass +class RefSemanticTask: + raw_audio: torch.Tensor + raw_sr: int + task_id: str = field(default_factory=lambda: uuid.uuid4().hex) + created_at: float = field(default_factory=time.perf_counter) + done_event: threading.Event = field(default_factory=threading.Event) + result_prompt_semantic: torch.Tensor | None = None + error: Exception | None = None + profile: Dict[str, float] = field(default_factory=dict) + + +class PrepareRefSemanticBatchWorker: + def __init__( + self, + ssl_model, + vits_model, + device, + is_half: bool, + zero_wav_samples: int, + stage_limiter=None, + batch_window_ms: int = 5, + max_batch_items: int = 8, + max_batch_samples: int = 960000, + ): + self.ssl_model = ssl_model + self.vits_model = vits_model + self.device = device + self.is_half = bool(is_half) + self.zero_wav_samples = max(0, int(zero_wav_samples)) + self.stage_limiter = stage_limiter + self.batch_window_s = max(0.0, float(batch_window_ms) / 1000.0) + self.max_batch_items = max(1, int(max_batch_items)) + self.max_batch_samples = max(REF_AUDIO_MIN_SAMPLES_16K + self.zero_wav_samples, int(max_batch_samples)) + + self.condition = threading.Condition() + self.pending_tasks: Deque[RefSemanticTask] = deque() + self.pending_peak = 0 + self.total_submitted = 0 + self.total_finished = 0 + self.total_batches = 0 + self.active_batch_size = 0 + self.active_batch_peak = 0 + self.active_batch_samples = 0 + self.active_batch_samples_peak = 0 + self.worker_thread = threading.Thread( + target=self._run_loop, + name="prepare-ref-semantic-batch-worker", + daemon=True, + ) + self.worker_thread.start() + + def _estimate_task_samples(self, task: RefSemanticTask) -> int: + raw_len = int(task.raw_audio.shape[-1]) if task.raw_audio.dim() > 0 else 0 + base = int(round(raw_len * 16000.0 / max(1, int(task.raw_sr)))) + return max(REF_AUDIO_MIN_SAMPLES_16K, base) + self.zero_wav_samples + + def submit(self, raw_audio: torch.Tensor, raw_sr: int) -> Tuple[torch.Tensor, Dict[str, float]]: + task = RefSemanticTask(raw_audio=raw_audio, raw_sr=int(raw_sr)) + with self.condition: + self.pending_tasks.append(task) + self.total_submitted += 1 + if len(self.pending_tasks) > self.pending_peak: + self.pending_peak = len(self.pending_tasks) + self.condition.notify_all() + task.done_event.wait() + if task.error is not None: + raise task.error + assert task.result_prompt_semantic is not None + return task.result_prompt_semantic, dict(task.profile) + + def snapshot(self) -> Dict[str, int]: + with self.condition: + return { + "pending": len(self.pending_tasks), + "pending_peak": self.pending_peak, + "total_submitted": self.total_submitted, + "total_finished": self.total_finished, + "total_batches": self.total_batches, + "active_batch_size": self.active_batch_size, + "active_batch_peak": self.active_batch_peak, + "active_batch_samples": self.active_batch_samples, + "active_batch_samples_peak": self.active_batch_samples_peak, + "batch_window_ms": int(self.batch_window_s * 1000.0), + "max_batch_items": self.max_batch_items, + "max_batch_samples": self.max_batch_samples, + } + + def _collect_batch(self) -> List[RefSemanticTask]: + with self.condition: + while not self.pending_tasks: + self.condition.wait() + + batch: List[RefSemanticTask] = [self.pending_tasks.popleft()] + batch_samples = self._estimate_task_samples(batch[0]) + deadline = time.perf_counter() + self.batch_window_s + + while len(batch) < self.max_batch_items: + remaining = deadline - time.perf_counter() + if remaining <= 0: + break + if not self.pending_tasks: + self.condition.wait(timeout=remaining) + continue + next_task = self.pending_tasks[0] + next_samples = self._estimate_task_samples(next_task) + if len(batch) >= self.max_batch_items or (batch_samples + next_samples) > self.max_batch_samples: + break + batch.append(self.pending_tasks.popleft()) + batch_samples += next_samples + + self.active_batch_size = len(batch) + self.active_batch_samples = batch_samples + if self.active_batch_size > self.active_batch_peak: + self.active_batch_peak = self.active_batch_size + if self.active_batch_samples > self.active_batch_samples_peak: + self.active_batch_samples_peak = self.active_batch_samples + return batch + + def _finalize_batch(self, batch: List[RefSemanticTask]) -> None: + with self.condition: + self.active_batch_size = 0 + self.active_batch_samples = 0 + self.total_batches += 1 + self.total_finished += len(batch) + + def _get_hidden_lengths(self, attention_mask: torch.Tensor, hidden_length: int) -> torch.Tensor: + model = self.ssl_model.model + if hasattr(model, "_get_feature_vector_attention_mask"): + feature_mask = model._get_feature_vector_attention_mask(hidden_length, attention_mask) + return feature_mask.to(dtype=torch.long).sum(dim=1) + raw_lengths = attention_mask.to(dtype=torch.long).sum(dim=1) + if hasattr(model, "_get_feat_extract_output_lengths"): + return model._get_feat_extract_output_lengths(raw_lengths).to(dtype=torch.long) + return torch.full((attention_mask.shape[0],), int(hidden_length), dtype=torch.long, device=attention_mask.device) + + @torch.inference_mode() + def _run_batch(self, batch: List[RefSemanticTask]) -> None: + batch_started = time.perf_counter() + prepared_start = time.perf_counter() + prepared_wavs = [ + prepare_prompt_semantic_wav16k(task.raw_audio, int(task.raw_sr), self.zero_wav_samples) for task in batch + ] + cpu_prepare_ms = (time.perf_counter() - prepared_start) * 1000.0 + wav_lengths = torch.tensor([int(wav.shape[0]) for wav in prepared_wavs], dtype=torch.long) + batch_samples = int(wav_lengths.sum().item()) + max_wav_len = int(wav_lengths.max().item()) + + input_values_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.float32) + attention_mask_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.long) + for batch_index, wav in enumerate(prepared_wavs): + wav_len = int(wav.shape[0]) + input_values_cpu[batch_index, :wav_len] = wav + attention_mask_cpu[batch_index, :wav_len] = 1 + + limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0} + if self.stage_limiter is None: + input_values = input_values_cpu.to(self.device) + attention_mask = attention_mask_cpu.to(self.device) + if self.is_half: + input_values = input_values.half() + forward_start = time.perf_counter() + outputs = self.ssl_model.model(input_values, attention_mask=attention_mask) + hubert_feature = outputs["last_hidden_state"].transpose(1, 2) + hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1])) + codes = self.vits_model.extract_latent(hubert_feature) + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + else: + with self.stage_limiter.enter() as limiter_stats: + input_values = input_values_cpu.to(self.device) + attention_mask = attention_mask_cpu.to(self.device) + if self.is_half: + input_values = input_values.half() + forward_start = time.perf_counter() + outputs = self.ssl_model.model(input_values, attention_mask=attention_mask) + hubert_feature = outputs["last_hidden_state"].transpose(1, 2) + hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1])) + codes = self.vits_model.extract_latent(hubert_feature) + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + + code_lengths = conv1d_output_lengths(hidden_lengths.detach().cpu(), getattr(self.vits_model, "ssl_proj", None)) + scatter_start = time.perf_counter() + for batch_index, task in enumerate(batch): + try: + code_len = int(code_lengths[batch_index].item()) + task.result_prompt_semantic = codes[batch_index, 0, :code_len].detach().clone() + task.profile = { + "prompt_semantic_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]), + "prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms), + "prompt_semantic_forward_ms": float(forward_ms), + "prompt_semantic_scatter_ms": 0.0, + "prompt_semantic_calls": 1.0, + "prompt_semantic_stage_slots": float(limiter_stats["slots"]), + "prompt_semantic_stage_inflight_peak": float(limiter_stats["peak_inflight"]), + "prompt_semantic_batch_size": float(len(batch)), + "prompt_semantic_batch_samples": float(batch_samples), + } + except Exception as exc: # noqa: PERF203 + task.error = exc + scatter_ms = (time.perf_counter() - scatter_start) * 1000.0 + for task in batch: + if task.result_prompt_semantic is not None: + task.profile["prompt_semantic_scatter_ms"] = float(scatter_ms) + task.done_event.set() + + def _run_loop(self) -> None: + while True: + batch = self._collect_batch() + try: + self._run_batch(batch) + except Exception as exc: # noqa: PERF203 + for task in batch: + task.error = exc + task.done_event.set() + finally: + self._finalize_batch(batch) diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py index c8643991..de498573 100644 --- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -1,5 +1,6 @@ from __future__ import annotations +from concurrent.futures import Future from dataclasses import dataclass from pathlib import Path import time @@ -123,31 +124,58 @@ def prepare_request_state( prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) text = spec.text.strip("\n") - _sync_device(device) - prompt_text_features_start = time.perf_counter() - prompt_phones, prompt_bert_features, prompt_norm_text = tts.extract_text_features(prompt_text, spec.prompt_lang) - _sync_device(device) - prompt_text_features_ms = (time.perf_counter() - prompt_text_features_start) * 1000.0 + prompt_text_profile: Dict[str, float] = {} + text_features_profile: Dict[str, float] = {} + text_feature_pair_start = time.perf_counter() + prompt_future: Future | None = None + + def _extract_prompt_features(): + _sync_device(device) + prompt_start = time.perf_counter() + result = tts.extract_text_features(prompt_text, spec.prompt_lang, profile=prompt_text_profile) + _sync_device(device) + return result, (time.perf_counter() - prompt_start) * 1000.0 + + if getattr(tts, "prepare_text_cpu_executor", None) is not None: + prompt_future = tts.prepare_text_cpu_executor.submit(_extract_prompt_features) _sync_device(device) text_features_start = time.perf_counter() - phones, bert_features, norm_text = tts.extract_text_features(text, spec.text_lang) + phones, bert_features, norm_text = tts.extract_text_features(text, spec.text_lang, profile=text_features_profile) _sync_device(device) text_features_ms = (time.perf_counter() - text_features_start) * 1000.0 + + if prompt_future is None: + _sync_device(device) + prompt_text_features_start = time.perf_counter() + prompt_phones, prompt_bert_features, prompt_norm_text = tts.extract_text_features( + prompt_text, spec.prompt_lang, profile=prompt_text_profile + ) + _sync_device(device) + prompt_text_features_ms = (time.perf_counter() - prompt_text_features_start) * 1000.0 + prompt_text_profile["parallel_future_wait_ms"] = 0.0 + else: + prompt_wait_start = time.perf_counter() + (prompt_phones, prompt_bert_features, prompt_norm_text), prompt_text_features_ms = prompt_future.result() + prompt_text_profile["parallel_future_wait_ms"] = (time.perf_counter() - prompt_wait_start) * 1000.0 + + text_feature_pair_ms = (time.perf_counter() - text_feature_pair_start) * 1000.0 if phones is None: raise ValueError(f"{spec.request_id} text preprocessing returned no phones") _sync_device(device) - prompt_semantic_start = time.perf_counter() - prompt_semantic = tts.extract_prompt_semantic(str(spec.ref_audio_path)).long() + ref_audio_bundle_start = time.perf_counter() + ref_audio_bundle = tts.extract_ref_audio_bundle(str(spec.ref_audio_path)) + prompt_semantic = ref_audio_bundle["prompt_semantic"].long() + spec_audio, audio_16k = ref_audio_bundle["refer_spec"] + raw_audio = ref_audio_bundle["raw_audio"] + raw_sr = int(ref_audio_bundle["raw_sr"]) _sync_device(device) - prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0 - - _sync_device(device) - ref_spec_start = time.perf_counter() - spec_audio, audio_16k, raw_audio, raw_sr = tts.extract_ref_spec(str(spec.ref_audio_path)) - _sync_device(device) - ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0 + ref_audio_bundle_ms = (time.perf_counter() - ref_audio_bundle_start) * 1000.0 + bundle_profile = ref_audio_bundle.get("profile", {}) + prompt_semantic_ms = float(bundle_profile.get("prompt_semantic_ms", ref_audio_bundle_ms)) + ref_spec_ms = float(bundle_profile.get("ref_spec_ms", 0.0)) + audio_load_ms = float(bundle_profile.get("audio_load_ms", 0.0)) _sync_device(device) tensorize_start = time.perf_counter() @@ -164,8 +192,43 @@ def prepare_request_state( prepare_profile = { "prompt_text_features_ms": prompt_text_features_ms, "text_features_ms": text_features_ms, + "prompt_text_bert_wait_ms": float(prompt_text_profile.get("bert_wait_ms", 0.0)), + "prompt_text_bert_forward_ms": float(prompt_text_profile.get("bert_forward_ms", 0.0)), + "prompt_text_bert_tokenize_ms": float(prompt_text_profile.get("bert_tokenize_ms", 0.0)), + "prompt_text_bert_scatter_ms": float(prompt_text_profile.get("bert_scatter_ms", 0.0)), + "prompt_text_bert_calls": float(prompt_text_profile.get("bert_calls", 0.0)), + "prompt_text_bert_stage_slots": float(prompt_text_profile.get("bert_stage_slots", 0.0)), + "prompt_text_bert_stage_inflight_peak": float(prompt_text_profile.get("bert_stage_inflight_peak", 0.0)), + "prompt_text_bert_batch_size_peak": float(prompt_text_profile.get("bert_batch_size_peak", 0.0)), + "prompt_text_bert_batch_tokens_peak": float(prompt_text_profile.get("bert_batch_tokens_peak", 0.0)), + "prompt_text_parallel_future_wait_ms": float(prompt_text_profile.get("parallel_future_wait_ms", 0.0)), + "text_bert_wait_ms": float(text_features_profile.get("bert_wait_ms", 0.0)), + "text_bert_forward_ms": float(text_features_profile.get("bert_forward_ms", 0.0)), + "text_bert_tokenize_ms": float(text_features_profile.get("bert_tokenize_ms", 0.0)), + "text_bert_scatter_ms": float(text_features_profile.get("bert_scatter_ms", 0.0)), + "text_bert_calls": float(text_features_profile.get("bert_calls", 0.0)), + "text_bert_stage_slots": float(text_features_profile.get("bert_stage_slots", 0.0)), + "text_bert_stage_inflight_peak": float(text_features_profile.get("bert_stage_inflight_peak", 0.0)), + "text_bert_batch_size_peak": float(text_features_profile.get("bert_batch_size_peak", 0.0)), + "text_bert_batch_tokens_peak": float(text_features_profile.get("bert_batch_tokens_peak", 0.0)), + "text_feature_pair_ms": text_feature_pair_ms, + "text_cpu_parallel_workers": float(getattr(tts, "prepare_text_cpu_workers", 0)), + "audio_load_ms": audio_load_ms, + "audio_stage_wait_ms": float(bundle_profile.get("audio_stage_wait_ms", 0.0)), + "audio_stage_slots": float(bundle_profile.get("audio_stage_slots", 0.0)), + "audio_stage_inflight_peak": float(bundle_profile.get("audio_stage_inflight_peak", 0.0)), "prompt_semantic_ms": prompt_semantic_ms, + "prompt_semantic_wait_ms": float(bundle_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), + "prompt_semantic_forward_ms": float(bundle_profile.get("prompt_semantic_forward_ms", 0.0)), + "prompt_semantic_scatter_ms": float(bundle_profile.get("prompt_semantic_scatter_ms", 0.0)), + "prompt_semantic_stage_slots": float(bundle_profile.get("prompt_semantic_stage_slots", 0.0)), + "prompt_semantic_stage_inflight_peak": float(bundle_profile.get("prompt_semantic_stage_inflight_peak", 0.0)), + "prompt_semantic_batch_size": float(bundle_profile.get("prompt_semantic_batch_size", 0.0)), + "prompt_semantic_batch_samples": float(bundle_profile.get("prompt_semantic_batch_samples", 0.0)), + "ref_spec_wait_ms": float(bundle_profile.get("ref_spec_wait_ms", 0.0)), "ref_spec_ms": ref_spec_ms, + "ref_audio_bundle_ms": ref_audio_bundle_ms, "tensorize_ms": tensorize_ms, "total_ms": (time.perf_counter() - prepare_sync_start) * 1000.0, "wall_total_ms": (time.perf_counter() - prepare_start) * 1000.0, @@ -186,7 +249,7 @@ def prepare_request_state( prompt_semantic=prompt_semantic, refer_spec=(spec_audio, audio_16k), raw_audio=raw_audio, - raw_sr=int(raw_sr), + raw_sr=raw_sr, top_k=spec.top_k, top_p=spec.top_p, temperature=spec.temperature, @@ -409,9 +472,18 @@ def _pad_decode_mask_left(mask: torch.Tensor, target_len: int) -> torch.Tensor: return F.pad(mask, (pad_len, 0), value=True) +def _fit_decode_mask_length(mask: torch.Tensor, target_len: int) -> torch.Tensor: + if mask.shape[-1] > target_len: + return mask[:, :, :, -target_len:] + if mask.shape[-1] < target_len: + return _pad_decode_mask_left(mask, target_len) + return mask + + def _materialize_decode_mask_for_request(running_request: T2SRunningRequest) -> torch.Tensor: + expected_mask_len = running_request.k_cache[0].shape[1] + 1 if running_request.decode_attn_mask is not None: - return running_request.decode_attn_mask + return _fit_decode_mask_length(running_request.decode_attn_mask, expected_mask_len) current_mask_len = running_request.k_cache[0].shape[1] + 1 return torch.zeros( (1, 1, 1, current_mask_len), @@ -481,17 +553,19 @@ def run_prefill_step( real_kv_len = int(active_batch.x_lens[batch_index].item()) + prefix_len request_k_cache = [layer[batch_index : batch_index + 1, -real_kv_len:, :].clone() for layer in k_cache] request_v_cache = [layer[batch_index : batch_index + 1, -real_kv_len:, :].clone() for layer in v_cache] + request_decode_attn_mask = None + if decode_attn_mask is not None: + request_decode_attn_mask = decode_attn_mask[batch_index : batch_index + 1].clone() + request_decode_attn_mask = _fit_decode_mask_length(request_decode_attn_mask, real_kv_len + 1) + if not request_decode_attn_mask.any().item(): + request_decode_attn_mask = None running_requests.append( T2SRunningRequest( state=state, y_sequence=new_history, prefix_len=prefix_len, - decode_attn_mask=( - None - if decode_attn_mask is None - else decode_attn_mask[batch_index : batch_index + 1].clone() - ), + decode_attn_mask=request_decode_attn_mask, k_cache=request_k_cache, v_cache=request_v_cache, step_idx=1, @@ -603,6 +677,9 @@ def run_decode_step_for_running( batch_index : batch_index + 1, :, :, -current_decode_mask_len: ] next_decode_attn_mask = F.pad(current_decode_attn_mask, (0, 1), value=False) + next_decode_attn_mask = _fit_decode_mask_length(next_decode_attn_mask, real_next_kv_len + 1) + if not next_decode_attn_mask.any().item(): + next_decode_attn_mask = None next_running.append( T2SRunningRequest( state=running_request.state, diff --git a/api_v3.py b/api_v3.py index 9d250119..92f9a3b9 100644 --- a/api_v3.py +++ b/api_v3.py @@ -261,8 +261,9 @@ class SchedulerDebugWorker: self.tts = tts self.max_steps = max_steps self.micro_batch_wait_s = micro_batch_wait_ms / 1000.0 - self.prepare_lock = threading.Lock() self.condition = threading.Condition() + self.prepare_inflight = 0 + self.prepare_peak_inflight = 0 self.pending_jobs: List[SchedulerPendingJob] = [] self.running_requests: List[T2SRunningRequest] = [] self.job_map: dict[str, SchedulerPendingJob] = {} @@ -282,8 +283,20 @@ class SchedulerDebugWorker: pass def prepare_state(self, spec: SchedulerRequestSpec) -> T2SRequestState: - with self.prepare_lock: - return prepare_request_state(self.tts, spec) + with self.condition: + self.prepare_inflight += 1 + prepare_inflight_on_enter = self.prepare_inflight + if self.prepare_inflight > self.prepare_peak_inflight: + self.prepare_peak_inflight = self.prepare_inflight + prepare_peak_inflight = self.prepare_peak_inflight + try: + state = prepare_request_state(self.tts, spec) + state.prepare_profile["worker_prepare_inflight_on_enter"] = float(prepare_inflight_on_enter) + state.prepare_profile["worker_prepare_peak_inflight"] = float(prepare_peak_inflight) + return state + finally: + with self.condition: + self.prepare_inflight = max(0, self.prepare_inflight - 1) def submit( self, @@ -363,9 +376,28 @@ class SchedulerDebugWorker: def get_state(self) -> dict: with self.condition: + bert_stage = self.tts.prepare_bert_stage_limiter.snapshot() + ref_audio_stage = self.tts.prepare_ref_audio_stage_limiter.snapshot() + bert_batch_worker = ( + None + if self.tts.prepare_bert_batch_worker is None + else self.tts.prepare_bert_batch_worker.snapshot() + ) + ref_semantic_batch_worker = ( + None + if self.tts.prepare_ref_semantic_batch_worker is None + else self.tts.prepare_ref_semantic_batch_worker.snapshot() + ) return { "pending_jobs": len(self.pending_jobs), "running_requests": len(self.running_requests), + "prepare_inflight": self.prepare_inflight, + "prepare_peak_inflight": self.prepare_peak_inflight, + "prepare_text_cpu_workers": int(getattr(self.tts, "prepare_text_cpu_workers", 0)), + "prepare_bert_stage": bert_stage, + "prepare_bert_batch_worker": bert_batch_worker, + "prepare_ref_audio_stage": ref_audio_stage, + "prepare_ref_semantic_batch_worker": ref_semantic_batch_worker, "tracked_jobs": len(self.job_map), "total_submitted": self.total_submitted, "total_finished": self.total_finished, @@ -907,10 +939,36 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): { "X-Prepare-Prompt-Text-Ms": f"{float(prepare_profile.get('prompt_text_features_ms', 0.0)):.3f}", "X-Prepare-Target-Text-Ms": f"{float(prepare_profile.get('text_features_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_wait_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Wait-Ms": f"{float(prepare_profile.get('text_bert_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Forward-Ms": f"{float(prepare_profile.get('prompt_text_bert_forward_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Forward-Ms": f"{float(prepare_profile.get('text_bert_forward_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Batch-Size-Peak": str( + int(prepare_profile.get("prompt_text_bert_batch_size_peak", 0.0)) + ), + "X-Prepare-Target-Bert-Batch-Size-Peak": str( + int(prepare_profile.get("text_bert_batch_size_peak", 0.0)) + ), + "X-Prepare-Text-Pair-Wall-Ms": f"{float(prepare_profile.get('text_feature_pair_ms', 0.0)):.3f}", + "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))), + "X-Prepare-Audio-Load-Ms": f"{float(prepare_profile.get('audio_load_ms', 0.0)):.3f}", + "X-Prepare-Audio-Stage-Wait-Ms": f"{float(prepare_profile.get('audio_stage_wait_ms', 0.0)):.3f}", "X-Prepare-Prompt-Semantic-Ms": f"{float(prepare_profile.get('prompt_semantic_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-Wait-Ms": f"{float(prepare_profile.get('prompt_semantic_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-CPU-Ms": f"{float(prepare_profile.get('prompt_semantic_cpu_prepare_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-Forward-Ms": f"{float(prepare_profile.get('prompt_semantic_forward_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-Batch-Size": str( + int(prepare_profile.get("prompt_semantic_batch_size", 0.0)) + ), "X-Prepare-Ref-Spec-Ms": f"{float(prepare_profile.get('ref_spec_ms', 0.0)):.3f}", + "X-Prepare-Ref-Spec-Wait-Ms": f"{float(prepare_profile.get('ref_spec_wait_ms', 0.0)):.3f}", + "X-Prepare-Ref-Bundle-Ms": f"{float(prepare_profile.get('ref_audio_bundle_ms', 0.0)):.3f}", "X-Prepare-Tensorize-Ms": f"{float(prepare_profile.get('tensorize_ms', 0.0)):.3f}", "X-Prepare-Profile-Wall-Ms": f"{float(prepare_profile.get('wall_total_ms', 0.0)):.3f}", + "X-Prepare-Inflight-On-Enter": str( + int(prepare_profile.get("worker_prepare_inflight_on_enter", 0.0)) + ), + "X-Prepare-Inflight-Peak": str(int(prepare_profile.get("worker_prepare_peak_inflight", 0.0))), } ) return Response(audio_data, media_type=f"audio/{job.media_type}", headers=headers) From a45e171ff50a372f1ac1ce567fd95659050b453b Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Mon, 9 Mar 2026 21:24:16 +0800 Subject: [PATCH 07/24] Enhance sampling functions in TTS by adding support for previous token masks in logits_to_probs. Implement batch processing for sampling with padded token sequences and contiguous sampling groups. Refactor sampling logic in T2S scheduler to utilize new functionalities, improving efficiency and flexibility in token generation. --- GPT_SoVITS/AR/models/utils.py | 37 ++++- GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py | 170 ++++++++++++++++----- 2 files changed, 165 insertions(+), 42 deletions(-) diff --git a/GPT_SoVITS/AR/models/utils.py b/GPT_SoVITS/AR/models/utils.py index cc4f24d8..4b564ed8 100644 --- a/GPT_SoVITS/AR/models/utils.py +++ b/GPT_SoVITS/AR/models/utils.py @@ -147,6 +147,7 @@ def multinomial_sample_one_no_sync( def logits_to_probs( logits, previous_tokens: Optional[torch.Tensor] = None, + previous_token_mask: Optional[torch.Tensor] = None, temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[int] = None, @@ -158,13 +159,27 @@ def logits_to_probs( # pdb.set_trace() if previous_tokens is not None and repetition_penalty != 1.0: previous_tokens = previous_tokens.long() - score = torch.gather(logits, dim=1, index=previous_tokens) - score = torch.where( - score < 0, - score * repetition_penalty, - score / repetition_penalty, - ) - logits.scatter_(dim=1, index=previous_tokens, src=score) + if previous_token_mask is None: + score = torch.gather(logits, dim=1, index=previous_tokens) + score = torch.where( + score < 0, + score * repetition_penalty, + score / repetition_penalty, + ) + logits.scatter_(dim=1, index=previous_tokens, src=score) + else: + previous_token_mask = previous_token_mask.to(dtype=torch.bool, device=logits.device) + if previous_token_mask.any(): + batch_index = torch.arange(logits.size(0), device=logits.device).unsqueeze(1).expand_as(previous_tokens) + valid_batch_index = batch_index[previous_token_mask] + valid_token_index = previous_tokens[previous_token_mask] + score = logits[valid_batch_index, valid_token_index] + score = torch.where( + score < 0, + score * repetition_penalty, + score / repetition_penalty, + ) + logits[valid_batch_index, valid_token_index] = score if top_p is not None and top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=True) @@ -192,9 +207,15 @@ def logits_to_probs( def sample( logits, previous_tokens: Optional[torch.Tensor] = None, + previous_token_mask: Optional[torch.Tensor] = None, **sampling_kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: - probs = logits_to_probs(logits=logits, previous_tokens=previous_tokens, **sampling_kwargs) + probs = logits_to_probs( + logits=logits, + previous_tokens=previous_tokens, + previous_token_mask=previous_token_mask, + **sampling_kwargs, + ) idx_next = multinomial_sample_one_no_sync(probs) return idx_next, probs diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py index de498573..b7118a72 100644 --- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple import torch import torch.nn.functional as F -from AR.models.utils import make_pad_mask_left, sample +from AR.models.utils import logits_to_probs, make_pad_mask_left, multinomial_sample_one_no_sync, sample def _sync_device(device: Any) -> None: @@ -277,6 +277,90 @@ def _ensure_audio_pe(model: Any, max_position: int, dtype: torch.dtype, device: ) +def _pad_token_sequences( + token_sequences: Sequence[torch.LongTensor], +) -> Tuple[torch.LongTensor, torch.BoolTensor]: + if not token_sequences: + raise ValueError("token_sequences 不能为空") + device = token_sequences[0].device + max_len = max(int(sequence.shape[0]) for sequence in token_sequences) + padded = torch.zeros((len(token_sequences), max_len), dtype=token_sequences[0].dtype, device=device) + mask = torch.zeros((len(token_sequences), max_len), dtype=torch.bool, device=device) + for row_index, sequence in enumerate(token_sequences): + seq_len = int(sequence.shape[0]) + padded[row_index, :seq_len] = sequence + mask[row_index, :seq_len] = True + return padded, mask + + +def _sampling_group_key( + top_k: int, + top_p: float, + temperature: float, + repetition_penalty: float, + trim_eos: bool, +) -> Tuple[int, float, float, float, bool]: + return ( + int(top_k), + float(top_p), + float(temperature), + float(repetition_penalty), + bool(trim_eos), + ) + + +def _iter_contiguous_sampling_groups( + sampling_keys: Sequence[Tuple[int, float, float, float, bool]], +) -> List[Tuple[Tuple[int, float, float, float, bool], List[int]]]: + groups: List[Tuple[Tuple[int, float, float, float, bool], List[int]]] = [] + if not sampling_keys: + return groups + current_key = sampling_keys[0] + current_indices: List[int] = [0] + for index in range(1, len(sampling_keys)): + key = sampling_keys[index] + if key == current_key: + current_indices.append(index) + continue + groups.append((current_key, current_indices)) + current_key = key + current_indices = [index] + groups.append((current_key, current_indices)) + return groups + + +def _batched_sample_by_group( + logits: torch.Tensor, + histories: Sequence[torch.LongTensor], + sampling_keys: Sequence[Tuple[int, float, float, float, bool]], +) -> Tuple[List[torch.Tensor], List[int]]: + sampled_list: List[Optional[torch.Tensor]] = [None] * len(histories) + argmax_list: List[Optional[int]] = [None] * len(histories) + for group_key, group_indices in _iter_contiguous_sampling_groups(sampling_keys): + top_k, top_p, temperature, repetition_penalty, trim_eos = group_key + index_tensor = torch.tensor(group_indices, dtype=torch.long, device=logits.device) + group_logits = torch.index_select(logits, dim=0, index=index_tensor) + if trim_eos: + group_logits = group_logits[:, :-1] + group_histories = [histories[index] for index in group_indices] + padded_histories, history_mask = _pad_token_sequences(group_histories) + probs = logits_to_probs( + logits=group_logits, + previous_tokens=padded_histories, + previous_token_mask=history_mask, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + ) + argmax_tokens = torch.argmax(group_logits, dim=-1) + for local_index, global_index in enumerate(group_indices): + sampled_list[global_index] = multinomial_sample_one_no_sync(probs[local_index : local_index + 1]) + argmax_list[global_index] = int(argmax_tokens[local_index].item()) + + return [item for item in sampled_list if item is not None], [int(item) for item in argmax_list if item is not None] + + @torch.inference_mode() def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SActiveBatch: x_items: List[torch.Tensor] = [] @@ -360,19 +444,26 @@ def _sample_per_request( updated_sequences: List[torch.LongTensor] = [] step_idx = active_batch.step_idx - for batch_index, state in enumerate(active_batch.states): - logits_i = logits[batch_index : batch_index + 1].clone() - current_history = active_batch.y_sequences[batch_index] - sampled = sample( - logits_i, - current_history.unsqueeze(0), + sampling_keys = [ + _sampling_group_key( top_k=state.top_k, top_p=state.top_p, - repetition_penalty=state.repetition_penalty, temperature=state.temperature, - )[0] + repetition_penalty=state.repetition_penalty, + trim_eos=False, + ) + for state in active_batch.states + ] + sampled_items, argmax_tokens = _batched_sample_by_group( + logits=logits, + histories=active_batch.y_sequences, + sampling_keys=sampling_keys, + ) + for batch_index, state in enumerate(active_batch.states): + current_history = active_batch.y_sequences[batch_index] + sampled = sampled_items[batch_index] sampled_token = int(sampled[0, 0].item()) - argmax_token = int(torch.argmax(logits[batch_index], dim=-1).item()) + argmax_token = argmax_tokens[batch_index] new_history = torch.cat([current_history, sampled.view(-1)], dim=0) finish_reason: Optional[str] = None @@ -507,25 +598,30 @@ def run_prefill_step( if len(states) == 1 and not decode_attn_mask.any().item(): decode_attn_mask = None logits = model.ar_predict_layer(xy_dec[:, -1]) + sampling_keys = [ + _sampling_group_key( + top_k=state.top_k, + top_p=state.top_p, + temperature=state.temperature, + repetition_penalty=state.repetition_penalty, + trim_eos=True, + ) + for state in states + ] + sampled_items, argmax_tokens = _batched_sample_by_group( + logits=logits, + histories=active_batch.y_sequences, + sampling_keys=sampling_keys, + ) running_requests: List[T2SRunningRequest] = [] finished_items: List[T2SFinishedItem] = [] for batch_index, state in enumerate(states): - logits_i = logits[batch_index : batch_index + 1].clone() - if 0 < 11: - logits_i = logits_i[:, :-1] current_history = active_batch.y_sequences[batch_index] - sampled = sample( - logits_i, - current_history.unsqueeze(0), - top_k=state.top_k, - top_p=state.top_p, - repetition_penalty=state.repetition_penalty, - temperature=state.temperature, - )[0] + sampled = sampled_items[batch_index] sampled_token = int(sampled[0, 0].item()) - argmax_token = int(torch.argmax(logits_i[0], dim=-1).item()) + argmax_token = argmax_tokens[batch_index] new_history = torch.cat([current_history, sampled.view(-1)], dim=0) prefix_len = int(active_batch.prefix_lens[batch_index].item()) @@ -624,25 +720,31 @@ def run_decode_step_for_running( batched_decode_attn_mask, ) logits = model.ar_predict_layer(xy_dec[:, -1]) + sampling_keys = [ + _sampling_group_key( + top_k=running_request.state.top_k, + top_p=running_request.state.top_p, + temperature=running_request.state.temperature, + repetition_penalty=running_request.state.repetition_penalty, + trim_eos=running_request.step_idx < 11, + ) + for running_request in running_requests + ] + histories = [running_request.y_sequence for running_request in running_requests] + sampled_items, argmax_tokens = _batched_sample_by_group( + logits=logits, + histories=histories, + sampling_keys=sampling_keys, + ) next_running: List[T2SRunningRequest] = [] finished_items: List[T2SFinishedItem] = [] for batch_index, running_request in enumerate(running_requests): current_idx = running_request.step_idx - logits_i = logits[batch_index : batch_index + 1].clone() - if current_idx < 11: - logits_i = logits_i[:, :-1] - sampled = sample( - logits_i, - running_request.y_sequence.unsqueeze(0), - top_k=running_request.state.top_k, - top_p=running_request.state.top_p, - repetition_penalty=running_request.state.repetition_penalty, - temperature=running_request.state.temperature, - )[0] + sampled = sampled_items[batch_index] sampled_token = int(sampled[0, 0].item()) - argmax_token = int(torch.argmax(logits_i[0], dim=-1).item()) + argmax_token = argmax_tokens[batch_index] new_history = torch.cat([running_request.y_sequence, sampled.view(-1)], dim=0) finish_reason: Optional[str] = None From 827d6ea47c5d03aac006597f2f8e41761b284b13 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Tue, 10 Mar 2026 06:58:53 +0800 Subject: [PATCH 08/24] Refactor TTS and scheduler components to enhance text processing and batching capabilities. Introduce PrepareCoordinator for managing text feature preparation asynchronously, and update SchedulerDebugWorker to support new finalize task management. Implement batch processing in PrepareBertBatchWorker with improved admission control and profiling metrics. Add text CPU preprocessing utilities for better text segmentation and normalization. --- GPT_SoVITS/TTS_infer_pack/TTS.py | 160 ++++- GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py | 232 ++++++- .../prepare_bert_batch_worker.py | 181 +++++- .../TTS_infer_pack/prepare_coordinator.py | 294 +++++++++ GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py | 489 +++++++++++---- .../TTS_infer_pack/text_cpu_preprocess.py | 100 ++++ GPT_SoVITS/module/models.py | 62 ++ api_v3.py | 565 +++++++++++++++--- 8 files changed, 1811 insertions(+), 272 deletions(-) create mode 100644 GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py create mode 100644 GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 9c259662..d475b804 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1,26 +1,27 @@ import gc +import concurrent.futures import math import os import random import sys import time import traceback -from concurrent.futures import ThreadPoolExecutor from copy import deepcopy -import torchaudio -from tqdm import tqdm - now_dir = os.getcwd() sys.path.append(now_dir) -import os from typing import List, Tuple, Union +from runtime_preload import preload_text_runtime_deps + +preload_text_runtime_deps() + import ffmpeg import librosa import numpy as np import torch import torch.nn.functional as F +import torchaudio import yaml from AR.models.t2s_lightning_module import Text2SemanticLightningModule from BigVGAN.bigvgan import BigVGAN @@ -30,6 +31,7 @@ from module.models import SynthesizerTrn, SynthesizerTrnV3, Generator from peft import LoraConfig, get_peft_model from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new from transformers import AutoModelForMaskedLM, AutoTokenizer +from tqdm import tqdm from tools.audio_sr import AP_BWE from tools.i18n.i18n import I18nAuto, scan_language_list @@ -449,20 +451,21 @@ class TTS: "overlapped_len": None, } self.prepare_bert_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_BERT_SLOTS", "1"))) - self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "2"))) + self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4"))) self.prepare_bert_batch_worker = None self.prepare_ref_semantic_batch_worker = None - default_text_cpu_workers = 16 self.prepare_text_cpu_workers = max( 0, - int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_WORKERS", str(default_text_cpu_workers))), + int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_WORKERS", "0")), ) - self.prepare_text_cpu_executor = None - if self.prepare_text_cpu_workers > 0: - self.prepare_text_cpu_executor = ThreadPoolExecutor( + self.prepare_text_cpu_executor = ( + concurrent.futures.ThreadPoolExecutor( max_workers=self.prepare_text_cpu_workers, thread_name_prefix="prepare-text-cpu", ) + if self.prepare_text_cpu_workers > 0 + else None + ) self._init_models() @@ -475,6 +478,20 @@ class TTS: batch_window_ms=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_WINDOW_MS", "5")), max_batch_items=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_ITEMS", "16")), max_batch_tokens=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_TOKENS", "4096")), + max_pending_tasks=int(os.environ.get("GPTSOVITS_PREPARE_BERT_MAX_PENDING_TASKS", "0")), + admission_poll_ms=int(os.environ.get("GPTSOVITS_PREPARE_BERT_ADMISSION_POLL_MS", "1")), + high_pressure_pending_threshold=int( + os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_PENDING_THRESHOLD", "0") + ), + high_pressure_batch_window_ms=int( + os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_BATCH_WINDOW_MS", "1") + ), + high_pressure_max_batch_items=int( + os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_MAX_ITEMS", "32") + ), + high_pressure_max_batch_tokens=int( + os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_MAX_TOKENS", "8192") + ), ) if os.environ.get("GPTSOVITS_PREPARE_REF_BATCHING", "0") != "0": ref_max_batch_samples = os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_SAMPLES") @@ -830,13 +847,23 @@ class TTS: return codes[0, 0].to(self.configs.device) @torch.inference_mode() - def _extract_prompt_semantic_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + def _extract_prompt_semantic_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + cpu_prepare_start = time.perf_counter() wav16k = prepare_prompt_semantic_wav16k( raw_audio=raw_audio, raw_sr=raw_sr, zero_wav_samples=int(self.configs.sampling_rate * 0.3), ) - return self._extract_prompt_semantic_from_prepared_wav16k(wav16k) + cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0 + forward_start = time.perf_counter() + prompt_semantic = self._extract_prompt_semantic_from_prepared_wav16k(wav16k) + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + return prompt_semantic, cpu_prepare_ms, forward_ms + + @torch.inference_mode() + def _extract_prompt_semantic_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + prompt_semantic, _, _ = self._extract_prompt_semantic_profile_from_raw(raw_audio, raw_sr) + return prompt_semantic def extract_prompt_semantic(self, ref_wav_path: str): raw_audio, raw_sr = self._load_ref_audio_raw(ref_wav_path) @@ -887,7 +914,9 @@ class TTS: if self.prepare_ref_semantic_batch_worker is None: with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats: prompt_semantic_start = time.perf_counter() - prompt_semantic = self._extract_prompt_semantic_from_raw(raw_audio, raw_sr) + prompt_semantic, prompt_semantic_cpu_prepare_ms, prompt_semantic_forward_ms = ( + self._extract_prompt_semantic_profile_from_raw(raw_audio, raw_sr) + ) prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0 ref_spec_start = time.perf_counter() refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2] @@ -897,8 +926,8 @@ class TTS: audio_stage_inflight_peak = float(limiter_stats["peak_inflight"]) prompt_semantic_profile = { "prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]), - "prompt_semantic_cpu_prepare_ms": 0.0, - "prompt_semantic_forward_ms": prompt_semantic_ms, + "prompt_semantic_cpu_prepare_ms": float(prompt_semantic_cpu_prepare_ms), + "prompt_semantic_forward_ms": float(prompt_semantic_forward_ms), "prompt_semantic_scatter_ms": 0.0, "prompt_semantic_stage_slots": float(limiter_stats["slots"]), "prompt_semantic_stage_inflight_peak": float(limiter_stats["peak_inflight"]), @@ -1012,6 +1041,32 @@ class TTS: text, language, self.configs.version, profile=profile ) + def prepare_text_segments(self, text: str, language: str): + return self.text_preprocessor.preprocess_text_segments(text, language, self.configs.version) + + def build_text_features_from_segments(self, prepared_segments, profile: dict | None = None): + return self.text_preprocessor.build_phones_and_bert_from_segments(prepared_segments, profile=profile) + + async def build_text_features_from_segments_async(self, prepared_segments, profile: dict | None = None): + return await self.text_preprocessor.build_phones_and_bert_from_segments_async( + prepared_segments, + profile=profile, + ) + + async def build_text_feature_pair_from_segments_async( + self, + prompt_segments, + target_segments, + prompt_profile: dict | None = None, + target_profile: dict | None = None, + ): + return await self.text_preprocessor.build_phones_and_bert_pair_from_segments_async( + prompt_segments, + target_segments, + prompt_profile=prompt_profile, + target_profile=target_profile, + ) + def _set_ref_audio_path(self, ref_audio_path): self.prompt_cache["ref_audio_path"] = ref_audio_path @@ -2011,6 +2066,79 @@ class TTS: sample_steps=sample_steps, ) + @torch.inference_mode() + def synthesize_audio_requests_local_batched( + self, + semantic_tokens_list: List[torch.Tensor], + phones_list: List[torch.Tensor], + refer_specs: List[tuple], + speeds: List[float] | None = None, + sample_steps_list: List[int] | None = None, + ) -> List[torch.Tensor]: + batch_size = len(semantic_tokens_list) + if batch_size == 0: + return [] + if len(phones_list) != batch_size or len(refer_specs) != batch_size: + raise ValueError("batched request-local synthesis 输入长度不一致") + if speeds is None: + speeds = [1.0] * batch_size + if sample_steps_list is None: + sample_steps_list = [32] * batch_size + if len(speeds) != batch_size or len(sample_steps_list) != batch_size: + raise ValueError("batched request-local synthesis 参数长度不一致") + first_speed = float(speeds[0]) + first_sample_steps = int(sample_steps_list[0]) + if any(abs(float(item) - first_speed) > 1e-6 for item in speeds): + raise ValueError("batched request-local synthesis 目前要求 speed 一致") + if any(int(item) != first_sample_steps for item in sample_steps_list): + raise ValueError("batched request-local synthesis 目前要求 sample_steps 一致") + if self.configs.use_vocoder: + raise NotImplementedError("request-local batched VITS synthesis 暂不支持 vocoder 模型") + + device = self.configs.device + max_semantic_len = max(int(item.shape[-1]) for item in semantic_tokens_list) + max_phone_len = max(int(item.shape[-1]) for item in phones_list) + semantic_batch = torch.zeros((1, batch_size, max_semantic_len), dtype=torch.long, device=device) + phone_batch = torch.zeros((batch_size, max_phone_len), dtype=torch.long, device=device) + semantic_lengths = [] + phone_lengths = [] + refer_audio_specs: List[torch.Tensor] = [] + sv_emb_batch = None + sv_emb_list: List[torch.Tensor] = [] + + for batch_index, semantic_tokens in enumerate(semantic_tokens_list): + semantic_len = int(semantic_tokens.shape[-1]) + phone_len = int(phones_list[batch_index].shape[-1]) + semantic_batch[0, batch_index, :semantic_len] = semantic_tokens.to(device=device, dtype=torch.long) + phone_batch[batch_index, :phone_len] = phones_list[batch_index].to(device=device, dtype=torch.long) + semantic_lengths.append(semantic_len) + phone_lengths.append(phone_len) + + refer_audio_spec, audio_tensor = refer_specs[batch_index] + refer_audio_specs.append(refer_audio_spec.to(dtype=self.precision, device=device)) + if self.is_v2pro: + if audio_tensor is None: + raise ValueError(i18n("v2Pro request-local batched synthesis 缺少 16k 参考音频")) + sv_emb_list.append(self.sv_model.compute_embedding3(audio_tensor).to(device)) + + if self.is_v2pro: + sv_emb_batch = torch.cat(sv_emb_list, dim=0) + + audio_batch, audio_lengths = self.vits_model.decode_batched_request_local( + codes=semantic_batch, + code_lengths=torch.LongTensor(semantic_lengths).to(device), + text=phone_batch, + text_lengths=torch.LongTensor(phone_lengths).to(device), + refer_list=refer_audio_specs, + speed=first_speed, + sv_emb=sv_emb_batch, + ) + audios: List[torch.Tensor] = [] + for batch_index in range(batch_size): + audio_len = int(audio_lengths[batch_index].item()) + audios.append(audio_batch[batch_index, 0, :audio_len].detach()) + return audios + def using_vocoder_synthesis_batched_infer( self, idx_list: List[int], diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py index 15b3c322..6bee49be 100644 --- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py +++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py @@ -1,8 +1,10 @@ +import asyncio import os import sys import threading import time from contextlib import contextmanager +from dataclasses import dataclass from tqdm import tqdm @@ -13,12 +15,13 @@ import re import torch from text.LangSegmenter import LangSegmenter from text import chinese -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple from text.cleaner import clean_text from text import cleaned_text_to_sequence from transformers import AutoModelForMaskedLM, AutoTokenizer from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker +from TTS_infer_pack.text_cpu_preprocess import preprocess_text_segments_payload from tools.i18n.i18n import I18nAuto, scan_language_list @@ -92,6 +95,14 @@ class StageLimiter: } +@dataclass +class PreparedTextSegment: + language: str + phones: List[int] + word2ph: Optional[List[int]] + norm_text: str + + class TextPreprocessor: def __init__( self, @@ -149,7 +160,7 @@ class TextPreprocessor: # 解决输入目标文本的空行导致报错的问题 if len(text.strip()) == 0: continue - if not re.sub("\W+", "", text): + if not re.sub(r"\W+", "", text): # 检测一下,如果是纯符号,就跳过。 continue if text[-1] not in splits: @@ -168,7 +179,8 @@ class TextPreprocessor: def segment_and_extract_feature_for_text( self, text: str, language: str, version: str = "v1", profile: Dict | None = None ) -> Tuple[list, torch.Tensor, str]: - return self.get_phones_and_bert(text, language, version, profile=profile) + prepared_segments = self.preprocess_text_segments(text, language, version) + return self.build_phones_and_bert_from_segments(prepared_segments, profile=profile) def _split_text_by_language(self, text: str, language: str) -> Tuple[List[str], List[str]]: textlist = [] @@ -223,24 +235,49 @@ class TextPreprocessor: def get_phones_and_bert( self, text: str, language: str, version: str, final: bool = False, profile: Dict | None = None ): - text = re.sub(r' {2,}', ' ', text) - textlist, langlist = self._split_text_by_language(text, language) - phones_list = [] - bert_list = [] - norm_text_list = [] - for segment_text, segment_lang in zip(textlist, langlist): - phones, word2ph, norm_text = self.clean_text_inf(segment_text, segment_lang, version) - bert = self.get_bert_inf(phones, word2ph, norm_text, segment_lang, profile=profile) - phones_list.append(phones) - norm_text_list.append(norm_text) + prepared_segments = self.preprocess_text_segments(text, language, version, final=final) + return self.build_phones_and_bert_from_segments(prepared_segments, profile=profile) + + def preprocess_text_segments( + self, + text: str, + language: str, + version: str, + final: bool = False, + ) -> List[PreparedTextSegment]: + payloads = preprocess_text_segments_payload(text, language, version, final=final) + return [ + PreparedTextSegment( + language=str(payload["language"]), + phones=list(payload["phones"]), + word2ph=None if payload["word2ph"] is None else list(payload["word2ph"]), + norm_text=str(payload["norm_text"]), + ) + for payload in payloads + ] + + def build_phones_and_bert_from_segments( + self, + prepared_segments: List[PreparedTextSegment], + profile: Dict | None = None, + ) -> Tuple[list, torch.Tensor, str]: + phones_list: List[List[int]] = [] + bert_list: List[torch.Tensor] = [] + norm_text_list: List[str] = [] + for segment in prepared_segments: + bert = self.get_bert_inf( + segment.phones, + segment.word2ph, + segment.norm_text, + segment.language, + profile=profile, + ) + phones_list.append(segment.phones) + norm_text_list.append(segment.norm_text) bert_list.append(bert) bert = torch.cat(bert_list, dim=1) phones = sum(phones_list, []) norm_text = "".join(norm_text_list) - - if not final and len(phones) < 6: - return self.get_phones_and_bert("." + text, language, version, final=True, profile=profile) - return phones, bert, norm_text def _accumulate_profile(self, profile: Dict | None, key: str, value: float) -> None: @@ -253,21 +290,41 @@ class TextPreprocessor: return profile[key] = float(max(float(profile.get(key, 0.0)), float(value))) + def _merge_bert_worker_profile(self, profile: Dict | None, worker_profile: Dict[str, float]) -> None: + self._accumulate_profile(profile, "bert_wait_ms", worker_profile.get("bert_wait_ms", 0.0)) + self._accumulate_profile(profile, "bert_admission_wait_ms", worker_profile.get("bert_admission_wait_ms", 0.0)) + self._accumulate_profile(profile, "bert_queue_wait_ms", worker_profile.get("bert_queue_wait_ms", 0.0)) + self._accumulate_profile( + profile, + "bert_batch_collect_wait_ms", + worker_profile.get("bert_batch_collect_wait_ms", 0.0), + ) + self._accumulate_profile(profile, "bert_forward_ms", worker_profile.get("bert_forward_ms", 0.0)) + self._accumulate_profile(profile, "bert_tokenize_ms", worker_profile.get("bert_tokenize_ms", 0.0)) + self._accumulate_profile(profile, "bert_scatter_ms", worker_profile.get("bert_scatter_ms", 0.0)) + self._accumulate_profile(profile, "bert_calls", worker_profile.get("bert_calls", 1.0)) + self._update_profile_peak(profile, "bert_stage_inflight_peak", worker_profile.get("bert_stage_inflight_peak", 0.0)) + self._update_profile_peak(profile, "bert_batch_size_peak", worker_profile.get("bert_batch_size", 0.0)) + self._update_profile_peak(profile, "bert_batch_tokens_peak", worker_profile.get("bert_batch_tokens", 0.0)) + self._update_profile_peak( + profile, + "bert_pending_depth_on_enqueue_peak", + worker_profile.get("bert_pending_depth_on_enqueue", 0.0), + ) + self._update_profile_peak( + profile, + "bert_pending_depth_on_collect_peak", + worker_profile.get("bert_pending_depth_on_collect", 0.0), + ) + self._update_profile_peak(profile, "bert_high_pressure_mode_peak", worker_profile.get("bert_high_pressure_mode", 0.0)) + if profile is not None: + profile["bert_stage_slots"] = float(worker_profile.get("bert_stage_slots", 0.0)) + profile["bert_batch_window_ms"] = float(worker_profile.get("bert_batch_window_ms", 0.0)) + def get_bert_feature(self, text: str, word2ph: list, profile: Dict | None = None) -> torch.Tensor: if self.bert_batch_worker is not None: feature, worker_profile = self.bert_batch_worker.submit(text, word2ph) - self._accumulate_profile(profile, "bert_wait_ms", worker_profile.get("bert_wait_ms", 0.0)) - self._accumulate_profile(profile, "bert_forward_ms", worker_profile.get("bert_forward_ms", 0.0)) - self._accumulate_profile(profile, "bert_tokenize_ms", worker_profile.get("bert_tokenize_ms", 0.0)) - self._accumulate_profile(profile, "bert_scatter_ms", worker_profile.get("bert_scatter_ms", 0.0)) - self._accumulate_profile(profile, "bert_calls", worker_profile.get("bert_calls", 1.0)) - self._update_profile_peak( - profile, "bert_stage_inflight_peak", worker_profile.get("bert_stage_inflight_peak", 0.0) - ) - self._update_profile_peak(profile, "bert_batch_size_peak", worker_profile.get("bert_batch_size", 0.0)) - self._update_profile_peak(profile, "bert_batch_tokens_peak", worker_profile.get("bert_batch_tokens", 0.0)) - if profile is not None: - profile["bert_stage_slots"] = float(worker_profile.get("bert_stage_slots", 0.0)) + self._merge_bert_worker_profile(profile, worker_profile) return feature limiter_stats = {"wait_ms": 0.0, "inflight": 1, "peak_inflight": 1, "slots": 0} @@ -310,9 +367,18 @@ class TextPreprocessor: phones = cleaned_text_to_sequence(phones, version) return phones, word2ph, norm_text - def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str, profile: Dict | None = None): + def get_bert_inf( + self, + phones: list, + word2ph: Optional[list], + norm_text: str, + language: str, + profile: Dict | None = None, + ): language = language.replace("all_", "") if language == "zh": + if word2ph is None: + raise ValueError("中文文本缺少 word2ph,无法提取 BERT 特征") feature = self.get_bert_feature(norm_text, word2ph, profile=profile).to(self.device) else: feature = torch.zeros( @@ -322,6 +388,112 @@ class TextPreprocessor: return feature + async def build_phones_and_bert_from_segments_async( + self, + prepared_segments: List[PreparedTextSegment], + profile: Dict | None = None, + ) -> Tuple[list, torch.Tensor, str]: + segment_jobs = self._build_async_segment_jobs(prepared_segments, profile) + pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = [] + for segment_index, segment in enumerate(prepared_segments): + if segment.language.replace("all_", "") != "zh" or self.bert_batch_worker is None: + continue + if segment.word2ph is None: + raise ValueError("中文文本缺少 word2ph,无法提取 BERT 特征") + pending_items.append( + ( + segment_jobs["bert_list"], + segment_index, + profile, + self.bert_batch_worker.submit_async(segment.norm_text, segment.word2ph), + ) + ) + + if pending_items: + pending_results = await asyncio.gather(*[future for _, _, _, future in pending_items]) + for (bert_list, bert_index, item_profile, _), (feature, worker_profile) in zip(pending_items, pending_results): + self._merge_bert_worker_profile(item_profile, worker_profile) + bert_list[bert_index] = feature.to(self.device) + + return self._finalize_async_segment_jobs(segment_jobs) + + def _build_async_segment_jobs( + self, + prepared_segments: List[PreparedTextSegment], + profile: Dict | None, + ) -> Dict[str, List]: + phones_list: List[List[int]] = [] + bert_list: List[torch.Tensor | None] = [] + norm_text_list: List[str] = [] + + for segment in prepared_segments: + phones_list.append(segment.phones) + norm_text_list.append(segment.norm_text) + segment_language = segment.language.replace("all_", "") + if segment_language == "zh" and self.bert_batch_worker is not None: + if segment.word2ph is None: + raise ValueError("中文文本缺少 word2ph,无法提取 BERT 特征") + bert_list.append(None) + continue + bert_list.append( + self.get_bert_inf( + segment.phones, + segment.word2ph, + segment.norm_text, + segment.language, + profile=profile, + ) + ) + return { + "phones_list": phones_list, + "bert_list": bert_list, + "norm_text_list": norm_text_list, + } + + @staticmethod + def _finalize_async_segment_jobs(segment_jobs: Dict[str, List]) -> Tuple[list, torch.Tensor, str]: + bert = torch.cat([feature for feature in segment_jobs["bert_list"] if feature is not None], dim=1) + phones = sum(segment_jobs["phones_list"], []) + norm_text = "".join(segment_jobs["norm_text_list"]) + return phones, bert, norm_text + + async def build_phones_and_bert_pair_from_segments_async( + self, + prompt_segments: List[PreparedTextSegment], + target_segments: List[PreparedTextSegment], + prompt_profile: Dict | None = None, + target_profile: Dict | None = None, + ) -> Tuple[Tuple[list, torch.Tensor, str], Tuple[list, torch.Tensor, str]]: + prompt_jobs = self._build_async_segment_jobs(prompt_segments, prompt_profile) + target_jobs = self._build_async_segment_jobs(target_segments, target_profile) + pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = [] + + for segment_jobs, prepared_segments, profile in ( + (prompt_jobs, prompt_segments, prompt_profile), + (target_jobs, target_segments, target_profile), + ): + for segment_index, segment in enumerate(prepared_segments): + if segment.language.replace("all_", "") != "zh" or self.bert_batch_worker is None: + continue + if segment.word2ph is None: + raise ValueError("中文文本缺少 word2ph,无法提取 BERT 特征") + pending_items.append( + ( + segment_jobs["bert_list"], + segment_index, + profile, + self.bert_batch_worker.submit_async(segment.norm_text, segment.word2ph), + ) + ) + + if pending_items: + pending_results = await asyncio.gather(*[future for _, _, _, future in pending_items]) + for (bert_list, bert_index, profile, _), (feature, worker_profile) in zip(pending_items, pending_results): + self._merge_bert_worker_profile(profile, worker_profile) + bert_list[bert_index] = feature.to(self.device) + + return self._finalize_async_segment_jobs(prompt_jobs), self._finalize_async_segment_jobs(target_jobs) + def filter_text(self, texts): _text = [] if all(text in [None, " ", "\n", ""] for text in texts): diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py index b1ede3d8..1ac77faa 100644 --- a/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py +++ b/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py @@ -1,3 +1,4 @@ +import asyncio import threading import time import uuid @@ -14,7 +15,12 @@ class BertFeatureTask: word2ph: List[int] task_id: str = field(default_factory=lambda: uuid.uuid4().hex) created_at: float = field(default_factory=time.perf_counter) + enqueued_at: float = 0.0 + admission_wait_ms: float = 0.0 + pending_depth_on_enqueue: int = 0 done_event: threading.Event = field(default_factory=threading.Event) + done_loop: asyncio.AbstractEventLoop | None = None + done_future: asyncio.Future | None = None result_feature: torch.Tensor | None = None error: Exception | None = None profile: Dict[str, float] = field(default_factory=dict) @@ -30,14 +36,37 @@ class PrepareBertBatchWorker: batch_window_ms: int = 5, max_batch_items: int = 16, max_batch_tokens: int = 4096, + max_pending_tasks: int = 0, + admission_poll_ms: int = 1, + high_pressure_pending_threshold: int = 0, + high_pressure_batch_window_ms: int | None = None, + high_pressure_max_batch_items: int | None = None, + high_pressure_max_batch_tokens: int | None = None, ): self.bert_model = bert_model self.tokenizer = tokenizer self.device = device self.stage_limiter = stage_limiter - self.batch_window_s = max(0.0, float(batch_window_ms) / 1000.0) + self.batch_window_ms = max(0, int(batch_window_ms)) + self.batch_window_s = float(self.batch_window_ms) / 1000.0 self.max_batch_items = max(1, int(max_batch_items)) self.max_batch_tokens = max(16, int(max_batch_tokens)) + self.max_pending_tasks = max(0, int(max_pending_tasks)) + self.admission_poll_s = max(0.0005, float(max(1, int(admission_poll_ms))) / 1000.0) + + self.high_pressure_pending_threshold = max( + 0, + int(high_pressure_pending_threshold) + if int(high_pressure_pending_threshold) > 0 + else max(self.max_batch_items * 2, 32), + ) + hp_window_ms = self.batch_window_ms if high_pressure_batch_window_ms is None else int(high_pressure_batch_window_ms) + hp_items = self.max_batch_items if high_pressure_max_batch_items is None else int(high_pressure_max_batch_items) + hp_tokens = self.max_batch_tokens if high_pressure_max_batch_tokens is None else int(high_pressure_max_batch_tokens) + self.high_pressure_batch_window_ms = max(0, hp_window_ms) + self.high_pressure_batch_window_s = float(self.high_pressure_batch_window_ms) / 1000.0 + self.high_pressure_max_batch_items = max(self.max_batch_items, hp_items) + self.high_pressure_max_batch_tokens = max(self.max_batch_tokens, hp_tokens) self.condition = threading.Condition() self.pending_tasks: Deque[BertFeatureTask] = deque() @@ -47,26 +76,70 @@ class PrepareBertBatchWorker: self.total_batches = 0 self.active_batch_size = 0 self.active_batch_peak = 0 + self.active_batch_tokens = 0 + self.active_batch_tokens_peak = 0 + self.high_pressure_batches = 0 + self.admission_wait_total_ms = 0.0 + self.admission_wait_peak_ms = 0.0 self.worker_thread = threading.Thread(target=self._run_loop, name="prepare-bert-batch-worker", daemon=True) self.worker_thread.start() def _estimate_task_tokens(self, task: BertFeatureTask) -> int: return max(1, len(task.norm_text) + 2) + def _can_enqueue_locked(self) -> bool: + if self.max_pending_tasks <= 0: + return True + return (len(self.pending_tasks) + self.active_batch_size) < self.max_pending_tasks + + def _record_enqueue_locked(self, task: BertFeatureTask, admission_wait_ms: float) -> None: + task.admission_wait_ms = float(max(0.0, admission_wait_ms)) + task.enqueued_at = time.perf_counter() + task.pending_depth_on_enqueue = int(len(self.pending_tasks)) + self.pending_tasks.append(task) + self.total_submitted += 1 + self.admission_wait_total_ms += task.admission_wait_ms + self.admission_wait_peak_ms = max(self.admission_wait_peak_ms, task.admission_wait_ms) + if len(self.pending_tasks) > self.pending_peak: + self.pending_peak = len(self.pending_tasks) + self.condition.notify_all() + + def _enqueue_task(self, task: BertFeatureTask) -> None: + admission_started = time.perf_counter() + with self.condition: + while not self._can_enqueue_locked(): + self.condition.wait(timeout=self.admission_poll_s) + self._record_enqueue_locked(task, (time.perf_counter() - admission_started) * 1000.0) + + async def _enqueue_task_async(self, task: BertFeatureTask) -> None: + admission_started = time.perf_counter() + while True: + with self.condition: + if self._can_enqueue_locked(): + self._record_enqueue_locked(task, (time.perf_counter() - admission_started) * 1000.0) + return + await asyncio.sleep(self.admission_poll_s) + def submit(self, norm_text: str, word2ph: List[int]) -> Tuple[torch.Tensor, Dict[str, float]]: task = BertFeatureTask(norm_text=str(norm_text), word2ph=list(word2ph)) - with self.condition: - self.pending_tasks.append(task) - self.total_submitted += 1 - if len(self.pending_tasks) > self.pending_peak: - self.pending_peak = len(self.pending_tasks) - self.condition.notify_all() + self._enqueue_task(task) task.done_event.wait() if task.error is not None: raise task.error assert task.result_feature is not None return task.result_feature, dict(task.profile) + async def submit_async(self, norm_text: str, word2ph: List[int]) -> Tuple[torch.Tensor, Dict[str, float]]: + loop = asyncio.get_running_loop() + task = BertFeatureTask( + norm_text=str(norm_text), + word2ph=list(word2ph), + done_loop=loop, + done_future=loop.create_future(), + ) + await self._enqueue_task_async(task) + return await task.done_future + def snapshot(self) -> Dict[str, int]: with self.condition: return { @@ -77,21 +150,57 @@ class PrepareBertBatchWorker: "total_batches": self.total_batches, "active_batch_size": self.active_batch_size, "active_batch_peak": self.active_batch_peak, + "active_batch_tokens": self.active_batch_tokens, + "active_batch_tokens_peak": self.active_batch_tokens_peak, "batch_window_ms": int(self.batch_window_s * 1000.0), "max_batch_items": self.max_batch_items, "max_batch_tokens": self.max_batch_tokens, + "max_pending_tasks": self.max_pending_tasks, + "high_pressure_pending_threshold": self.high_pressure_pending_threshold, + "high_pressure_batch_window_ms": self.high_pressure_batch_window_ms, + "high_pressure_max_batch_items": self.high_pressure_max_batch_items, + "high_pressure_max_batch_tokens": self.high_pressure_max_batch_tokens, + "high_pressure_batches": self.high_pressure_batches, + "admission_wait_total_ms": self.admission_wait_total_ms, + "admission_wait_peak_ms": self.admission_wait_peak_ms, } - def _collect_batch(self) -> List[BertFeatureTask]: + def _select_batch_policy_locked(self) -> Tuple[float, int, int, bool, int]: + pending_depth = len(self.pending_tasks) + use_high_pressure = ( + self.high_pressure_pending_threshold > 0 + and pending_depth >= self.high_pressure_pending_threshold + ) + if use_high_pressure: + return ( + self.high_pressure_batch_window_s, + self.high_pressure_max_batch_items, + self.high_pressure_max_batch_tokens, + True, + pending_depth, + ) + return ( + self.batch_window_s, + self.max_batch_items, + self.max_batch_tokens, + False, + pending_depth, + ) + + def _collect_batch(self) -> Tuple[List[BertFeatureTask], Dict[str, float]]: with self.condition: while not self.pending_tasks: self.condition.wait() + collect_started = time.perf_counter() + batch_window_s, max_batch_items, max_batch_tokens, use_high_pressure, pending_depth_on_collect = ( + self._select_batch_policy_locked() + ) batch: List[BertFeatureTask] = [self.pending_tasks.popleft()] batch_tokens = self._estimate_task_tokens(batch[0]) - deadline = time.perf_counter() + self.batch_window_s + deadline = time.perf_counter() + batch_window_s - while len(batch) < self.max_batch_items: + while len(batch) < max_batch_items: remaining = deadline - time.perf_counter() if remaining <= 0: break @@ -100,26 +209,39 @@ class PrepareBertBatchWorker: continue next_task = self.pending_tasks[0] next_tokens = self._estimate_task_tokens(next_task) - if len(batch) >= self.max_batch_items or (batch_tokens + next_tokens) > self.max_batch_tokens: + if len(batch) >= max_batch_items or (batch_tokens + next_tokens) > max_batch_tokens: break batch.append(self.pending_tasks.popleft()) batch_tokens += next_tokens self.active_batch_size = len(batch) + self.active_batch_tokens = batch_tokens if self.active_batch_size > self.active_batch_peak: self.active_batch_peak = self.active_batch_size - return batch + if self.active_batch_tokens > self.active_batch_tokens_peak: + self.active_batch_tokens_peak = self.active_batch_tokens + if use_high_pressure: + self.high_pressure_batches += 1 + return batch, { + "collect_wait_ms": (time.perf_counter() - collect_started) * 1000.0, + "batch_tokens": float(batch_tokens), + "pending_depth_on_collect": float(pending_depth_on_collect), + "high_pressure_mode": 1.0 if use_high_pressure else 0.0, + "batch_window_ms": float(self.high_pressure_batch_window_ms if use_high_pressure else self.batch_window_ms), + } def _finalize_batch(self, batch: List[BertFeatureTask]) -> None: with self.condition: self.active_batch_size = 0 + self.active_batch_tokens = 0 self.total_batches += 1 self.total_finished += len(batch) + self.condition.notify_all() - def _run_batch(self, batch: List[BertFeatureTask]) -> None: + def _run_batch(self, batch: List[BertFeatureTask], batch_meta: Dict[str, float]) -> None: batch_started = time.perf_counter() texts = [task.norm_text for task in batch] - batch_tokens = sum(self._estimate_task_tokens(task) for task in batch) + batch_tokens = int(batch_meta["batch_tokens"]) limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0} if self.stage_limiter is None: @@ -167,6 +289,9 @@ class PrepareBertBatchWorker: task.result_feature = torch.cat(phone_level_feature, dim=0).T task.profile = { "bert_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]), + "bert_admission_wait_ms": float(task.admission_wait_ms), + "bert_queue_wait_ms": max(0.0, (batch_started - task.enqueued_at) * 1000.0), + "bert_batch_collect_wait_ms": float(batch_meta["collect_wait_ms"]), "bert_forward_ms": float(forward_ms), "bert_tokenize_ms": float(tokenize_ms), "bert_scatter_ms": 0.0, @@ -175,6 +300,10 @@ class PrepareBertBatchWorker: "bert_stage_inflight_peak": float(limiter_stats["peak_inflight"]), "bert_batch_size": float(len(batch)), "bert_batch_tokens": float(batch_tokens), + "bert_pending_depth_on_enqueue": float(task.pending_depth_on_enqueue), + "bert_pending_depth_on_collect": float(batch_meta["pending_depth_on_collect"]), + "bert_high_pressure_mode": float(batch_meta["high_pressure_mode"]), + "bert_batch_window_ms": float(batch_meta["batch_window_ms"]), } except Exception as exc: # noqa: PERF203 task.error = exc @@ -183,15 +312,35 @@ class PrepareBertBatchWorker: if task.result_feature is not None: task.profile["bert_scatter_ms"] = float(scatter_ms) task.done_event.set() + self._notify_done_future(task) + + @staticmethod + def _resolve_done_future(task: BertFeatureTask) -> None: + if task.done_future is None or task.done_future.done(): + return + if task.error is not None: + task.done_future.set_exception(task.error) + return + assert task.result_feature is not None + task.done_future.set_result((task.result_feature, dict(task.profile))) + + def _notify_done_future(self, task: BertFeatureTask) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_done_future, task) + except RuntimeError: + pass def _run_loop(self) -> None: while True: - batch = self._collect_batch() + batch, batch_meta = self._collect_batch() try: - self._run_batch(batch) + self._run_batch(batch, batch_meta) except Exception as exc: # noqa: PERF203 for task in batch: task.error = exc task.done_event.set() + self._notify_done_future(task) finally: self._finalize_batch(batch) diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py new file mode 100644 index 00000000..1fdf95c5 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py @@ -0,0 +1,294 @@ +from __future__ import annotations + +import asyncio +import concurrent.futures +import os +import threading +import time +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( + PreparedTextFeatures, + SchedulerRequestSpec, + T2SRequestState, + build_request_state_from_parts, + normalize_sentence, +) + + +@dataclass +class ProfiledResult: + result: Any + submit_at: float + started_at: float + finished_at: float + + @property + def queue_ms(self) -> float: + return max(0.0, (self.started_at - self.submit_at) * 1000.0) + + @property + def run_ms(self) -> float: + return max(0.0, (self.finished_at - self.started_at) * 1000.0) + + +class PrepareCoordinator: + def __init__(self, tts: Any): + self.tts = tts + self.lock = threading.Lock() + self.inflight = 0 + self.peak_inflight = 0 + self.use_async_text_feature_path = bool( + getattr(tts, "prepare_bert_batch_worker", None) is not None + and os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_DIRECT", "0") != "0" + ) + self.max_inflight = max(0, int(os.environ.get("GPTSOVITS_PREPARE_MAX_INFLIGHT", "0"))) + self._inflight_semaphore = asyncio.Semaphore(self.max_inflight) if self.max_inflight > 0 else None + self.text_feature_workers = 0 + self.text_feature_executor = None + if not self.use_async_text_feature_path: + text_feature_default_workers = max(1, int(getattr(tts, "prepare_text_cpu_workers", 16) or 16)) + self.text_feature_workers = max( + 1, + int(os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_WORKERS", str(text_feature_default_workers))), + ) + self.text_feature_executor = concurrent.futures.ThreadPoolExecutor( + max_workers=self.text_feature_workers, + thread_name_prefix="prepare-text-feature", + ) + ref_audio_default_workers = max(1, int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4"))) + self.ref_audio_workers = max( + 1, + int(os.environ.get("GPTSOVITS_PREPARE_REF_ASYNC_WORKERS", str(ref_audio_default_workers))), + ) + self.ref_audio_executor = concurrent.futures.ThreadPoolExecutor( + max_workers=self.ref_audio_workers, + thread_name_prefix="prepare-ref-audio", + ) + + def _mark_enter(self) -> Tuple[int, int]: + with self.lock: + self.inflight += 1 + current_inflight = self.inflight + if current_inflight > self.peak_inflight: + self.peak_inflight = current_inflight + return current_inflight, self.peak_inflight + + def _mark_leave(self) -> None: + with self.lock: + self.inflight = max(0, self.inflight - 1) + + def snapshot(self) -> Dict[str, int]: + with self.lock: + return { + "inflight": int(self.inflight), + "peak_inflight": int(self.peak_inflight), + "max_inflight": int(self.max_inflight), + "text_feature_workers": int(self.text_feature_workers), + "ref_audio_workers": int(self.ref_audio_workers), + } + + @staticmethod + def _run_profiled(fn, submit_at: float, *args) -> ProfiledResult: + started_at = time.perf_counter() + result = fn(*args) + finished_at = time.perf_counter() + return ProfiledResult( + result=result, + submit_at=float(submit_at), + started_at=float(started_at), + finished_at=float(finished_at), + ) + + def _prepare_text_cpu(self, text: str, language: str): + return self.tts.prepare_text_segments(text, language) + + def _build_text_features(self, prepared_segments, language: str, cpu_run_ms: float) -> PreparedTextFeatures: + profile: Dict[str, float] = {"cpu_preprocess_ms": float(cpu_run_ms)} + branch_start = time.perf_counter() + phones, bert_features, norm_text = self.tts.build_text_features_from_segments(prepared_segments, profile=profile) + total_ms = float(cpu_run_ms + (time.perf_counter() - branch_start) * 1000.0) + profile["bert_total_ms"] = max(0.0, total_ms - float(cpu_run_ms)) + return PreparedTextFeatures( + phones=phones, + bert_features=bert_features, + norm_text=norm_text, + profile=profile, + total_ms=total_ms, + cpu_preprocess_ms=float(cpu_run_ms), + ) + + async def _run_on_executor(self, executor, fn, *args) -> ProfiledResult: + loop = asyncio.get_running_loop() + submit_at = time.perf_counter() + return await loop.run_in_executor(executor, self._run_profiled, fn, float(submit_at), *args) + + async def _run_text_cpu_stage(self, text: str, language: str) -> ProfiledResult: + executor = getattr(self.tts, "prepare_text_cpu_executor", None) + if executor is None: + submit_at = time.perf_counter() + return self._run_profiled(self._prepare_text_cpu, submit_at, text, language) + return await self._run_on_executor(executor, self._prepare_text_cpu, text, language) + + async def _run_text_feature_stage(self, prepared_segments, language: str, cpu_run_ms: float) -> ProfiledResult: + return await self._run_on_executor(self.text_feature_executor, self._build_text_features, prepared_segments, language, cpu_run_ms) + + @staticmethod + def _estimate_text_feature_run_ms(profile: Dict[str, float]) -> float: + return float( + profile.get("bert_wait_ms", 0.0) + + profile.get("bert_tokenize_ms", 0.0) + + profile.get("bert_forward_ms", 0.0) + + profile.get("bert_scatter_ms", 0.0) + ) + + async def _run_text_feature_pair_stage( + self, + prompt_segments, + target_segments, + prompt_cpu_run_ms: float, + target_cpu_run_ms: float, + ) -> tuple[ProfiledResult, ProfiledResult]: + if self.text_feature_executor is not None: + prompt_feature_task = asyncio.create_task( + self._run_text_feature_stage(prompt_segments, None, prompt_cpu_run_ms) + ) + target_feature_task = asyncio.create_task( + self._run_text_feature_stage(target_segments, None, target_cpu_run_ms) + ) + return await asyncio.gather(prompt_feature_task, target_feature_task) + + prompt_profile: Dict[str, float] = {"cpu_preprocess_ms": float(prompt_cpu_run_ms)} + target_profile: Dict[str, float] = {"cpu_preprocess_ms": float(target_cpu_run_ms)} + submit_at = time.perf_counter() + started_at = float(submit_at) + prompt_result_raw, target_result_raw = await self.tts.build_text_feature_pair_from_segments_async( + prompt_segments, + target_segments, + prompt_profile=prompt_profile, + target_profile=target_profile, + ) + finished_at = time.perf_counter() + + prompt_result = PreparedTextFeatures( + phones=prompt_result_raw[0], + bert_features=prompt_result_raw[1], + norm_text=prompt_result_raw[2], + profile=prompt_profile, + total_ms=float(prompt_cpu_run_ms + self._estimate_text_feature_run_ms(prompt_profile)), + cpu_preprocess_ms=float(prompt_cpu_run_ms), + ) + target_result = PreparedTextFeatures( + phones=target_result_raw[0], + bert_features=target_result_raw[1], + norm_text=target_result_raw[2], + profile=target_profile, + total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)), + cpu_preprocess_ms=float(target_cpu_run_ms), + ) + prompt_profiled = ProfiledResult( + result=prompt_result, + submit_at=float(submit_at), + started_at=started_at, + finished_at=float(submit_at + self._estimate_text_feature_run_ms(prompt_profile) / 1000.0), + ) + target_profiled = ProfiledResult( + result=target_result, + submit_at=float(submit_at), + started_at=started_at, + finished_at=float(submit_at + self._estimate_text_feature_run_ms(target_profile) / 1000.0), + ) + if finished_at > prompt_profiled.finished_at: + prompt_result.profile["bert_total_ms"] = max( + self._estimate_text_feature_run_ms(prompt_profile), + (finished_at - submit_at) * 1000.0, + ) + target_result.profile["bert_total_ms"] = max( + self._estimate_text_feature_run_ms(target_profile), + (finished_at - submit_at) * 1000.0, + ) + else: + prompt_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(prompt_profile) + target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile) + return prompt_profiled, target_profiled + + async def _run_ref_audio_stage(self, ref_audio_path: str) -> ProfiledResult: + return await self._run_on_executor(self.ref_audio_executor, self.tts.extract_ref_audio_bundle, ref_audio_path) + + async def prepare_state_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> tuple[T2SRequestState, float, float]: + admission_start = time.perf_counter() + if self._inflight_semaphore is not None: + await self._inflight_semaphore.acquire() + prepare_admission_wait_ms = max(0.0, (time.perf_counter() - admission_start) * 1000.0) + current_inflight, peak_inflight = self._mark_enter() + prepare_start = time.perf_counter() + prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) + text = spec.text.strip("\n") + try: + text_pair_start = time.perf_counter() + prompt_cpu_task = asyncio.create_task(self._run_text_cpu_stage(prompt_text, spec.prompt_lang)) + target_cpu_task = asyncio.create_task(self._run_text_cpu_stage(text, spec.text_lang)) + ref_audio_task = asyncio.create_task(self._run_ref_audio_stage(str(spec.ref_audio_path))) + prompt_cpu_profiled, target_cpu_profiled = await asyncio.gather(prompt_cpu_task, target_cpu_task) + text_feature_pair_task = asyncio.create_task( + self._run_text_feature_pair_stage( + prompt_cpu_profiled.result, + target_cpu_profiled.result, + prompt_cpu_profiled.run_ms, + target_cpu_profiled.run_ms, + ) + ) + (prompt_feature_profiled, target_feature_profiled), ref_audio_profiled = await asyncio.gather( + text_feature_pair_task, + ref_audio_task, + ) + text_pair_end = time.perf_counter() + state = build_request_state_from_parts( + tts=self.tts, + spec=spec, + prompt_text=prompt_text, + text=text, + prompt_result=prompt_feature_profiled.result, + target_result=target_feature_profiled.result, + ref_audio_bundle=ref_audio_profiled.result, + prepare_start=prepare_start, + prepare_sync_start=prepare_start, + profile_overrides={ + "executor_queue_ms": max(0.0, (prepare_start - prepare_submit_at) * 1000.0), + "prepare_admission_wait_ms": prepare_admission_wait_ms, + "executor_run_wall_ms": max(0.0, (time.perf_counter() - prepare_start) * 1000.0), + "text_feature_pair_ms": max(0.0, (text_pair_end - text_pair_start) * 1000.0), + "prompt_text_parallel_future_wait_ms": 0.0, + "prompt_text_parallel_future_executor_queue_ms": 0.0, + "prompt_text_parallel_future_run_ms": 0.0, + "prompt_text_parallel_future_finish_after_submit_ms": 0.0, + "prompt_text_parallel_future_queue_tail_after_target_ms": 0.0, + "prompt_text_parallel_future_run_tail_after_target_ms": 0.0, + "prompt_text_cpu_queue_ms": prompt_cpu_profiled.queue_ms, + "prompt_text_cpu_run_ms": prompt_cpu_profiled.run_ms, + "prompt_text_feature_queue_ms": prompt_feature_profiled.queue_ms, + "prompt_text_feature_run_ms": prompt_feature_profiled.run_ms, + "text_cpu_queue_ms": target_cpu_profiled.queue_ms, + "text_cpu_run_ms": target_cpu_profiled.run_ms, + "text_feature_queue_ms": target_feature_profiled.queue_ms, + "text_feature_run_ms": target_feature_profiled.run_ms, + "ref_audio_task_queue_ms": ref_audio_profiled.queue_ms, + "ref_audio_task_run_ms": ref_audio_profiled.run_ms, + "worker_prepare_inflight_on_enter": float(current_inflight), + "worker_prepare_peak_inflight": float(peak_inflight), + }, + ) + prepare_exec_finished_at = time.perf_counter() + state.prepare_profile["executor_run_wall_ms"] = max( + 0.0, (prepare_exec_finished_at - prepare_start) * 1000.0 + ) + return state, prepare_start, prepare_exec_finished_at + finally: + self._mark_leave() + if self._inflight_semaphore is not None: + self._inflight_semaphore.release() diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py index b7118a72..8aabd286 100644 --- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -1,6 +1,5 @@ from __future__ import annotations -from concurrent.futures import Future from dataclasses import dataclass from pathlib import Path import time @@ -89,20 +88,31 @@ class T2SFinishedItem: class T2SActiveBatch: request_ids: List[str] states: List[T2SRequestState] - x: torch.Tensor - x_lens: torch.LongTensor + x: Optional[torch.Tensor] + x_lens: Optional[torch.LongTensor] y_sequences: List[torch.LongTensor] prefix_lens: torch.LongTensor xy_pos: torch.Tensor - key_padding_mask: torch.Tensor - prefill_attn_mask: torch.Tensor + key_padding_mask: Optional[torch.Tensor] + prefill_attn_mask: Optional[torch.Tensor] decode_attn_mask: Optional[torch.Tensor] k_cache: Optional[List[torch.Tensor]] v_cache: Optional[List[torch.Tensor]] - step_idx: int + kv_lens: Optional[torch.LongTensor] + step_indices: torch.LongTensor prefill_done: bool +@dataclass +class PreparedTextFeatures: + phones: List[int] + bert_features: torch.Tensor + norm_text: str + profile: Dict[str, float] + total_ms: float + cpu_preprocess_ms: float + + def normalize_sentence(text: str, language: str) -> str: text = text.strip("\n").strip() if not text: @@ -113,105 +123,125 @@ def normalize_sentence(text: str, language: str) -> str: @torch.inference_mode() -def prepare_request_state( +def prepare_text_features( + tts: Any, + text: str, + language: str, +) -> PreparedTextFeatures: + device = tts.configs.device + profile: Dict[str, float] = {} + branch_start = time.perf_counter() + _sync_device(device) + cpu_start = time.perf_counter() + prepared_segments = tts.prepare_text_segments(text, language) + _sync_device(device) + cpu_preprocess_ms = (time.perf_counter() - cpu_start) * 1000.0 + profile["cpu_preprocess_ms"] = float(cpu_preprocess_ms) + bert_start = time.perf_counter() + phones, bert_features, norm_text = tts.build_text_features_from_segments(prepared_segments, profile=profile) + _sync_device(device) + profile["bert_total_ms"] = (time.perf_counter() - bert_start) * 1000.0 + total_ms = (time.perf_counter() - branch_start) * 1000.0 + return PreparedTextFeatures( + phones=phones, + bert_features=bert_features, + norm_text=norm_text, + profile=profile, + total_ms=float(total_ms), + cpu_preprocess_ms=float(cpu_preprocess_ms), + ) + + +@torch.inference_mode() +def build_request_state_from_parts( tts: Any, spec: SchedulerRequestSpec, + prompt_text: str, + text: str, + prompt_result: PreparedTextFeatures, + target_result: PreparedTextFeatures, + ref_audio_bundle: Dict[str, Any], + prepare_start: float, + prepare_sync_start: float, + profile_overrides: Optional[Dict[str, float]] = None, ) -> T2SRequestState: device = tts.configs.device - prepare_start = time.perf_counter() _sync_device(device) - prepare_sync_start = time.perf_counter() - prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) - text = spec.text.strip("\n") - - prompt_text_profile: Dict[str, float] = {} - text_features_profile: Dict[str, float] = {} - text_feature_pair_start = time.perf_counter() - prompt_future: Future | None = None - - def _extract_prompt_features(): - _sync_device(device) - prompt_start = time.perf_counter() - result = tts.extract_text_features(prompt_text, spec.prompt_lang, profile=prompt_text_profile) - _sync_device(device) - return result, (time.perf_counter() - prompt_start) * 1000.0 - - if getattr(tts, "prepare_text_cpu_executor", None) is not None: - prompt_future = tts.prepare_text_cpu_executor.submit(_extract_prompt_features) - - _sync_device(device) - text_features_start = time.perf_counter() - phones, bert_features, norm_text = tts.extract_text_features(text, spec.text_lang, profile=text_features_profile) - _sync_device(device) - text_features_ms = (time.perf_counter() - text_features_start) * 1000.0 - - if prompt_future is None: - _sync_device(device) - prompt_text_features_start = time.perf_counter() - prompt_phones, prompt_bert_features, prompt_norm_text = tts.extract_text_features( - prompt_text, spec.prompt_lang, profile=prompt_text_profile - ) - _sync_device(device) - prompt_text_features_ms = (time.perf_counter() - prompt_text_features_start) * 1000.0 - prompt_text_profile["parallel_future_wait_ms"] = 0.0 - else: - prompt_wait_start = time.perf_counter() - (prompt_phones, prompt_bert_features, prompt_norm_text), prompt_text_features_ms = prompt_future.result() - prompt_text_profile["parallel_future_wait_ms"] = (time.perf_counter() - prompt_wait_start) * 1000.0 - - text_feature_pair_ms = (time.perf_counter() - text_feature_pair_start) * 1000.0 - if phones is None: - raise ValueError(f"{spec.request_id} text preprocessing returned no phones") - - _sync_device(device) - ref_audio_bundle_start = time.perf_counter() - ref_audio_bundle = tts.extract_ref_audio_bundle(str(spec.ref_audio_path)) + ref_audio_bundle_ms = float(ref_audio_bundle.get("profile", {}).get("bundle_total_ms", 0.0)) + bundle_profile = ref_audio_bundle.get("profile", {}) prompt_semantic = ref_audio_bundle["prompt_semantic"].long() spec_audio, audio_16k = ref_audio_bundle["refer_spec"] raw_audio = ref_audio_bundle["raw_audio"] raw_sr = int(ref_audio_bundle["raw_sr"]) - _sync_device(device) - ref_audio_bundle_ms = (time.perf_counter() - ref_audio_bundle_start) * 1000.0 - bundle_profile = ref_audio_bundle.get("profile", {}) prompt_semantic_ms = float(bundle_profile.get("prompt_semantic_ms", ref_audio_bundle_ms)) ref_spec_ms = float(bundle_profile.get("ref_spec_ms", 0.0)) audio_load_ms = float(bundle_profile.get("audio_load_ms", 0.0)) _sync_device(device) tensorize_start = time.perf_counter() - phones_tensor = torch.LongTensor(phones).to(tts.configs.device) - prompt_phones_tensor = torch.LongTensor(prompt_phones).to(tts.configs.device) - all_phones = torch.LongTensor(prompt_phones + phones).to(tts.configs.device) - all_bert_features = torch.cat([prompt_bert_features, bert_features], dim=1).to( + phones_tensor = torch.LongTensor(target_result.phones).to(tts.configs.device) + prompt_phones_tensor = torch.LongTensor(prompt_result.phones).to(tts.configs.device) + all_phones = torch.LongTensor(prompt_result.phones + target_result.phones).to(tts.configs.device) + all_bert_features = torch.cat([prompt_result.bert_features, target_result.bert_features], dim=1).to( dtype=tts.precision, device=tts.configs.device ) _sync_device(device) tensorize_ms = (time.perf_counter() - tensorize_start) * 1000.0 - _sync_device(device) prepare_profile = { - "prompt_text_features_ms": prompt_text_features_ms, - "text_features_ms": text_features_ms, - "prompt_text_bert_wait_ms": float(prompt_text_profile.get("bert_wait_ms", 0.0)), - "prompt_text_bert_forward_ms": float(prompt_text_profile.get("bert_forward_ms", 0.0)), - "prompt_text_bert_tokenize_ms": float(prompt_text_profile.get("bert_tokenize_ms", 0.0)), - "prompt_text_bert_scatter_ms": float(prompt_text_profile.get("bert_scatter_ms", 0.0)), - "prompt_text_bert_calls": float(prompt_text_profile.get("bert_calls", 0.0)), - "prompt_text_bert_stage_slots": float(prompt_text_profile.get("bert_stage_slots", 0.0)), - "prompt_text_bert_stage_inflight_peak": float(prompt_text_profile.get("bert_stage_inflight_peak", 0.0)), - "prompt_text_bert_batch_size_peak": float(prompt_text_profile.get("bert_batch_size_peak", 0.0)), - "prompt_text_bert_batch_tokens_peak": float(prompt_text_profile.get("bert_batch_tokens_peak", 0.0)), - "prompt_text_parallel_future_wait_ms": float(prompt_text_profile.get("parallel_future_wait_ms", 0.0)), - "text_bert_wait_ms": float(text_features_profile.get("bert_wait_ms", 0.0)), - "text_bert_forward_ms": float(text_features_profile.get("bert_forward_ms", 0.0)), - "text_bert_tokenize_ms": float(text_features_profile.get("bert_tokenize_ms", 0.0)), - "text_bert_scatter_ms": float(text_features_profile.get("bert_scatter_ms", 0.0)), - "text_bert_calls": float(text_features_profile.get("bert_calls", 0.0)), - "text_bert_stage_slots": float(text_features_profile.get("bert_stage_slots", 0.0)), - "text_bert_stage_inflight_peak": float(text_features_profile.get("bert_stage_inflight_peak", 0.0)), - "text_bert_batch_size_peak": float(text_features_profile.get("bert_batch_size_peak", 0.0)), - "text_bert_batch_tokens_peak": float(text_features_profile.get("bert_batch_tokens_peak", 0.0)), - "text_feature_pair_ms": text_feature_pair_ms, + "prompt_text_features_ms": float(prompt_result.total_ms), + "text_features_ms": float(target_result.total_ms), + "prompt_text_cpu_preprocess_ms": float(prompt_result.cpu_preprocess_ms), + "text_cpu_preprocess_ms": float(target_result.cpu_preprocess_ms), + "prompt_text_bert_wait_ms": float(prompt_result.profile.get("bert_wait_ms", 0.0)), + "prompt_text_bert_admission_wait_ms": float(prompt_result.profile.get("bert_admission_wait_ms", 0.0)), + "prompt_text_bert_queue_wait_ms": float(prompt_result.profile.get("bert_queue_wait_ms", 0.0)), + "prompt_text_bert_batch_collect_wait_ms": float(prompt_result.profile.get("bert_batch_collect_wait_ms", 0.0)), + "prompt_text_bert_forward_ms": float(prompt_result.profile.get("bert_forward_ms", 0.0)), + "prompt_text_bert_tokenize_ms": float(prompt_result.profile.get("bert_tokenize_ms", 0.0)), + "prompt_text_bert_scatter_ms": float(prompt_result.profile.get("bert_scatter_ms", 0.0)), + "prompt_text_bert_calls": float(prompt_result.profile.get("bert_calls", 0.0)), + "prompt_text_bert_stage_slots": float(prompt_result.profile.get("bert_stage_slots", 0.0)), + "prompt_text_bert_stage_inflight_peak": float(prompt_result.profile.get("bert_stage_inflight_peak", 0.0)), + "prompt_text_bert_batch_size_peak": float(prompt_result.profile.get("bert_batch_size_peak", 0.0)), + "prompt_text_bert_batch_tokens_peak": float(prompt_result.profile.get("bert_batch_tokens_peak", 0.0)), + "prompt_text_bert_pending_depth_on_enqueue_peak": float( + prompt_result.profile.get("bert_pending_depth_on_enqueue_peak", 0.0) + ), + "prompt_text_bert_pending_depth_on_collect_peak": float( + prompt_result.profile.get("bert_pending_depth_on_collect_peak", 0.0) + ), + "prompt_text_bert_high_pressure_mode_peak": float( + prompt_result.profile.get("bert_high_pressure_mode_peak", 0.0) + ), + "prompt_text_bert_batch_window_ms": float(prompt_result.profile.get("bert_batch_window_ms", 0.0)), + "prompt_text_parallel_future_wait_ms": 0.0, + "prompt_text_parallel_future_executor_queue_ms": 0.0, + "prompt_text_parallel_future_run_ms": float(prompt_result.total_ms), + "prompt_text_parallel_future_finish_after_submit_ms": float(prompt_result.total_ms), + "prompt_text_parallel_future_queue_tail_after_target_ms": 0.0, + "prompt_text_parallel_future_run_tail_after_target_ms": 0.0, + "text_bert_wait_ms": float(target_result.profile.get("bert_wait_ms", 0.0)), + "text_bert_admission_wait_ms": float(target_result.profile.get("bert_admission_wait_ms", 0.0)), + "text_bert_queue_wait_ms": float(target_result.profile.get("bert_queue_wait_ms", 0.0)), + "text_bert_batch_collect_wait_ms": float(target_result.profile.get("bert_batch_collect_wait_ms", 0.0)), + "text_bert_forward_ms": float(target_result.profile.get("bert_forward_ms", 0.0)), + "text_bert_tokenize_ms": float(target_result.profile.get("bert_tokenize_ms", 0.0)), + "text_bert_scatter_ms": float(target_result.profile.get("bert_scatter_ms", 0.0)), + "text_bert_calls": float(target_result.profile.get("bert_calls", 0.0)), + "text_bert_stage_slots": float(target_result.profile.get("bert_stage_slots", 0.0)), + "text_bert_stage_inflight_peak": float(target_result.profile.get("bert_stage_inflight_peak", 0.0)), + "text_bert_batch_size_peak": float(target_result.profile.get("bert_batch_size_peak", 0.0)), + "text_bert_batch_tokens_peak": float(target_result.profile.get("bert_batch_tokens_peak", 0.0)), + "text_bert_pending_depth_on_enqueue_peak": float( + target_result.profile.get("bert_pending_depth_on_enqueue_peak", 0.0) + ), + "text_bert_pending_depth_on_collect_peak": float( + target_result.profile.get("bert_pending_depth_on_collect_peak", 0.0) + ), + "text_bert_high_pressure_mode_peak": float(target_result.profile.get("bert_high_pressure_mode_peak", 0.0)), + "text_bert_batch_window_ms": float(target_result.profile.get("bert_batch_window_ms", 0.0)), + "text_feature_pair_ms": float(max(prompt_result.total_ms, target_result.total_ms)), "text_cpu_parallel_workers": float(getattr(tts, "prepare_text_cpu_workers", 0)), "audio_load_ms": audio_load_ms, "audio_stage_wait_ms": float(bundle_profile.get("audio_stage_wait_ms", 0.0)), @@ -233,6 +263,8 @@ def prepare_request_state( "total_ms": (time.perf_counter() - prepare_sync_start) * 1000.0, "wall_total_ms": (time.perf_counter() - prepare_start) * 1000.0, } + if profile_overrides: + prepare_profile.update({key: float(value) for key, value in profile_overrides.items()}) return T2SRequestState( request_id=spec.request_id, ref_audio_path=spec.ref_audio_path, @@ -240,8 +272,8 @@ def prepare_request_state( prompt_lang=spec.prompt_lang, text=text, text_lang=spec.text_lang, - norm_prompt_text=prompt_norm_text, - norm_text=norm_text, + norm_prompt_text=prompt_result.norm_text, + norm_text=target_result.norm_text, phones=phones_tensor, prompt_phones=prompt_phones_tensor, all_phones=all_phones, @@ -260,6 +292,33 @@ def prepare_request_state( ) +@torch.inference_mode() +def prepare_request_state( + tts: Any, + spec: SchedulerRequestSpec, +) -> T2SRequestState: + prepare_start = time.perf_counter() + prepare_sync_start = time.perf_counter() + prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) + text = spec.text.strip("\n") + prompt_result = prepare_text_features(tts, prompt_text, spec.prompt_lang) + target_result = prepare_text_features(tts, text, spec.text_lang) + if target_result.phones is None: + raise ValueError(f"{spec.request_id} text preprocessing returned no phones") + ref_audio_bundle = tts.extract_ref_audio_bundle(str(spec.ref_audio_path)) + return build_request_state_from_parts( + tts=tts, + spec=spec, + prompt_text=prompt_text, + text=text, + prompt_result=prompt_result, + target_result=target_result, + ref_audio_bundle=ref_audio_bundle, + prepare_start=prepare_start, + prepare_sync_start=prepare_sync_start, + ) + + def _left_pad_hidden(hidden: torch.Tensor, target_len: int) -> torch.Tensor: if hidden.shape[0] >= target_len: return hidden @@ -417,7 +476,8 @@ def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SAct decode_attn_mask=None, k_cache=None, v_cache=None, - step_idx=0, + kv_lens=None, + step_indices=torch.zeros((len(states),), dtype=torch.long, device=device), prefill_done=False, ) @@ -433,6 +493,64 @@ def build_next_xy_pos(model: Any, y_sequences: Sequence[torch.LongTensor]) -> to ) +def _compact_cache_to_kv_lens( + cache: torch.Tensor, + kv_lens: torch.LongTensor, +) -> torch.Tensor: + target_len = int(kv_lens.max().item()) + if cache.shape[1] == target_len and torch.all(kv_lens == target_len).item(): + return cache + compacted = cache.new_zeros((cache.shape[0], target_len, cache.shape[2])) + for batch_index, kv_len in enumerate(kv_lens.tolist()): + if kv_len <= 0: + continue + compacted[batch_index, -kv_len:, :] = cache[batch_index, -kv_len:, :] + return compacted + + +def _compact_decode_mask_to_kv_lens( + decode_attn_mask: Optional[torch.Tensor], + kv_lens: torch.LongTensor, +) -> Optional[torch.Tensor]: + target_len = int(kv_lens.max().item()) + 1 + if decode_attn_mask is None: + return None + if decode_attn_mask.shape[-1] == target_len and torch.all(kv_lens + 1 == target_len).item(): + return decode_attn_mask + compacted = torch.ones( + (decode_attn_mask.shape[0], 1, 1, target_len), + dtype=decode_attn_mask.dtype, + device=decode_attn_mask.device, + ) + for batch_index, kv_len in enumerate(kv_lens.tolist()): + current_len = kv_len + 1 + compacted[batch_index, :, :, -current_len:] = decode_attn_mask[batch_index, :, :, -current_len:] + if not compacted.any().item(): + return None + return compacted + + +def _advance_decode_mask( + decode_attn_mask: Optional[torch.Tensor], + kv_lens: torch.LongTensor, +) -> Optional[torch.Tensor]: + if decode_attn_mask is None: + return None + target_len = int(kv_lens.max().item()) + 2 + advanced = torch.zeros( + (decode_attn_mask.shape[0], 1, 1, target_len), + dtype=decode_attn_mask.dtype, + device=decode_attn_mask.device, + ) + for batch_index, kv_len in enumerate(kv_lens.tolist()): + current_len = kv_len + 1 + next_mask = F.pad(decode_attn_mask[batch_index : batch_index + 1, :, :, -current_len:], (0, 1), value=False) + advanced[batch_index : batch_index + 1, :, :, -next_mask.shape[-1] :] = next_mask + if not advanced.any().item(): + return None + return advanced + + def _sample_per_request( model: Any, active_batch: T2SActiveBatch, @@ -443,16 +561,15 @@ def _sample_per_request( keep_indices: List[int] = [] updated_sequences: List[torch.LongTensor] = [] - step_idx = active_batch.step_idx sampling_keys = [ _sampling_group_key( top_k=state.top_k, top_p=state.top_p, temperature=state.temperature, repetition_penalty=state.repetition_penalty, - trim_eos=False, + trim_eos=int(active_batch.step_indices[batch_index].item()) < 11, ) - for state in active_batch.states + for batch_index, state in enumerate(active_batch.states) ] sampled_items, argmax_tokens = _batched_sample_by_group( logits=logits, @@ -460,6 +577,7 @@ def _sample_per_request( sampling_keys=sampling_keys, ) for batch_index, state in enumerate(active_batch.states): + step_index = int(active_batch.step_indices[batch_index].item()) current_history = active_batch.y_sequences[batch_index] sampled = sampled_items[batch_index] sampled_token = int(sampled[0, 0].item()) @@ -469,7 +587,7 @@ def _sample_per_request( finish_reason: Optional[str] = None if state.early_stop_num != -1 and (new_history.shape[0] - int(active_batch.prefix_lens[batch_index].item())) > state.early_stop_num: finish_reason = "early_stop" - elif step_idx + 1 >= max_steps: + elif step_index + 1 >= max_steps: finish_reason = "max_step" elif sampled_token == model.EOS: finish_reason = "eos_sample" @@ -482,7 +600,7 @@ def _sample_per_request( T2SFinishedItem( request_id=state.request_id, semantic_tokens=new_history[prefix_len:-1].clone(), - finish_idx=step_idx, + finish_idx=step_index, finish_reason=finish_reason, ) ) @@ -493,30 +611,48 @@ def _sample_per_request( return finished_items, keep_indices, updated_sequences +@torch.inference_mode() def decode_one_step( model: Any, active_batch: T2SActiveBatch, max_steps: int, ) -> Tuple[Optional[T2SActiveBatch], List[T2SFinishedItem]]: - if not active_batch.prefill_done: + was_prefill = not active_batch.prefill_done + if was_prefill: + if active_batch.prefill_attn_mask is None or active_batch.key_padding_mask is None: + raise ValueError("prefill 阶段缺少必要 mask") xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.process_prompt( active_batch.xy_pos, active_batch.prefill_attn_mask, None ) + active_batch.kv_lens = active_batch.x_lens + active_batch.prefix_lens active_batch.decode_attn_mask = F.pad(active_batch.key_padding_mask.unsqueeze(1).unsqueeze(1), (0, 1), value=False) + if active_batch.k_cache is None or active_batch.v_cache is None or active_batch.kv_lens is None: + raise ValueError("prefill 阶段未生成完整 KV cache") + active_batch.k_cache = [_compact_cache_to_kv_lens(layer, active_batch.kv_lens) for layer in active_batch.k_cache] + active_batch.v_cache = [_compact_cache_to_kv_lens(layer, active_batch.kv_lens) for layer in active_batch.v_cache] + active_batch.decode_attn_mask = _compact_decode_mask_to_kv_lens(active_batch.decode_attn_mask, active_batch.kv_lens) + active_batch.x = None + active_batch.x_lens = None + active_batch.key_padding_mask = None + active_batch.prefill_attn_mask = None active_batch.prefill_done = True else: + if active_batch.k_cache is None or active_batch.v_cache is None or active_batch.kv_lens is None: + raise ValueError("decode 阶段缺少 KV cache") + batched_decode_attn_mask = None + if active_batch.decode_attn_mask is not None: + batched_decode_attn_mask = _materialize_decode_mask_for_active_batch(active_batch) + if not batched_decode_attn_mask.any().item(): + batched_decode_attn_mask = None xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.decode_next_token( active_batch.xy_pos, active_batch.k_cache, active_batch.v_cache, - active_batch.decode_attn_mask, + batched_decode_attn_mask, ) - if active_batch.decode_attn_mask is not None: - active_batch.decode_attn_mask = F.pad(active_batch.decode_attn_mask, (0, 1), value=False) + active_batch.decode_attn_mask = _advance_decode_mask(active_batch.decode_attn_mask, active_batch.kv_lens) logits = model.ar_predict_layer(xy_dec[:, -1]) - if active_batch.step_idx < 11: - logits = logits[:, :-1] finished_items, keep_indices, updated_sequences = _sample_per_request(model, active_batch, logits, max_steps=max_steps) if len(keep_indices) == 0: @@ -528,16 +664,32 @@ def decode_one_step( active_batch.states = [active_batch.states[i] for i in keep_indices] active_batch.y_sequences = updated_sequences active_batch.prefix_lens = torch.index_select(active_batch.prefix_lens, dim=0, index=keep_tensor) + next_step_indices = torch.index_select(active_batch.step_indices, dim=0, index=keep_tensor) + next_kv_lens = None if active_batch.kv_lens is None else torch.index_select(active_batch.kv_lens, dim=0, index=keep_tensor) + active_batch.step_indices = next_step_indices + 1 + if not was_prefill: + if next_kv_lens is not None: + active_batch.kv_lens = next_kv_lens + 1 + else: + active_batch.kv_lens = next_kv_lens if active_batch.decode_attn_mask is not None: active_batch.decode_attn_mask = torch.index_select(active_batch.decode_attn_mask, dim=0, index=keep_tensor) + if not active_batch.decode_attn_mask.any().item(): + active_batch.decode_attn_mask = None if active_batch.k_cache is not None and active_batch.v_cache is not None: for cache_index in range(len(active_batch.k_cache)): active_batch.k_cache[cache_index] = torch.index_select(active_batch.k_cache[cache_index], dim=0, index=keep_tensor) active_batch.v_cache[cache_index] = torch.index_select(active_batch.v_cache[cache_index], dim=0, index=keep_tensor) + if active_batch.kv_lens is not None: + active_batch.k_cache = [_compact_cache_to_kv_lens(layer, active_batch.kv_lens) for layer in active_batch.k_cache] + active_batch.v_cache = [_compact_cache_to_kv_lens(layer, active_batch.kv_lens) for layer in active_batch.v_cache] + active_batch.decode_attn_mask = _compact_decode_mask_to_kv_lens( + active_batch.decode_attn_mask, + active_batch.kv_lens, + ) active_batch.xy_pos = build_next_xy_pos(model, active_batch.y_sequences) - active_batch.step_idx += 1 return active_batch, finished_items @@ -583,6 +735,126 @@ def _materialize_decode_mask_for_request(running_request: T2SRunningRequest) -> ) +def _materialize_decode_mask_for_active_batch( + active_batch: T2SActiveBatch, + target_mask_len: Optional[int] = None, +) -> torch.Tensor: + if active_batch.k_cache is None or active_batch.kv_lens is None: + raise ValueError("active batch 缺少 KV cache 或 kv_lens") + current_mask_len = active_batch.k_cache[0].shape[1] + 1 + if target_mask_len is None: + target_mask_len = current_mask_len + if active_batch.decode_attn_mask is None: + mask = torch.zeros( + (len(active_batch.request_ids), 1, 1, current_mask_len), + dtype=torch.bool, + device=active_batch.k_cache[0].device, + ) + else: + rows: List[torch.Tensor] = [] + for batch_index, kv_len in enumerate(active_batch.kv_lens.tolist()): + row_len = kv_len + 1 + row_mask = _fit_decode_mask_length( + active_batch.decode_attn_mask[batch_index : batch_index + 1], + row_len, + ) + rows.append(_pad_decode_mask_left(row_mask, target_mask_len)) + mask = torch.cat(rows, dim=0) + if target_mask_len != current_mask_len and active_batch.decode_attn_mask is None: + mask = _pad_decode_mask_left(mask, target_mask_len) + return mask + + +@torch.inference_mode() +def run_prefill_active_batch( + model: Any, + states: Sequence[T2SRequestState], + max_steps: int, +) -> Tuple[Optional[T2SActiveBatch], List[T2SFinishedItem]]: + if not states: + return None, [] + active_batch = build_prefill_batch(model, states) + return decode_one_step(model, active_batch, max_steps=max_steps) + + +@torch.inference_mode() +def merge_active_batches( + model: Any, + left_batch: Optional[T2SActiveBatch], + right_batch: Optional[T2SActiveBatch], +) -> Optional[T2SActiveBatch]: + if left_batch is None: + return right_batch + if right_batch is None: + return left_batch + if not left_batch.prefill_done or not right_batch.prefill_done: + raise ValueError("只有 prefill 完成后的 active batch 才能 merge") + if left_batch.k_cache is None or left_batch.v_cache is None or right_batch.k_cache is None or right_batch.v_cache is None: + raise ValueError("merge active batch 时缺少 KV cache") + + left_kv_len = int(left_batch.k_cache[0].shape[1]) + right_kv_len = int(right_batch.k_cache[0].shape[1]) + merged_kv_len = max(left_kv_len, right_kv_len) + merged_mask_len = merged_kv_len + 1 + + merged_k_cache: List[torch.Tensor] = [] + merged_v_cache: List[torch.Tensor] = [] + for layer_index in range(len(left_batch.k_cache)): + merged_k_cache.append( + torch.cat( + [ + _pad_cache_left(left_batch.k_cache[layer_index], merged_kv_len), + _pad_cache_left(right_batch.k_cache[layer_index], merged_kv_len), + ], + dim=0, + ) + ) + merged_v_cache.append( + torch.cat( + [ + _pad_cache_left(left_batch.v_cache[layer_index], merged_kv_len), + _pad_cache_left(right_batch.v_cache[layer_index], merged_kv_len), + ], + dim=0, + ) + ) + + merged_decode_attn_mask = torch.cat( + [ + _materialize_decode_mask_for_active_batch(left_batch, merged_mask_len), + _materialize_decode_mask_for_active_batch(right_batch, merged_mask_len), + ], + dim=0, + ) + merged_request_ids = list(left_batch.request_ids) + list(right_batch.request_ids) + merged_states = list(left_batch.states) + list(right_batch.states) + merged_y_sequences = list(left_batch.y_sequences) + list(right_batch.y_sequences) + merged_prefix_lens = torch.cat([left_batch.prefix_lens, right_batch.prefix_lens], dim=0) + if left_batch.kv_lens is None or right_batch.kv_lens is None: + raise ValueError("merge active batch 时缺少 kv_lens") + merged_kv_lens = torch.cat([left_batch.kv_lens, right_batch.kv_lens], dim=0) + merged_decode_attn_mask = _compact_decode_mask_to_kv_lens(merged_decode_attn_mask, merged_kv_lens) + merged_step_indices = torch.cat([left_batch.step_indices, right_batch.step_indices], dim=0) + + return T2SActiveBatch( + request_ids=merged_request_ids, + states=merged_states, + x=None, + x_lens=None, + y_sequences=merged_y_sequences, + prefix_lens=merged_prefix_lens, + xy_pos=build_next_xy_pos(model, merged_y_sequences), + key_padding_mask=None, + prefill_attn_mask=None, + decode_attn_mask=merged_decode_attn_mask, + k_cache=merged_k_cache, + v_cache=merged_v_cache, + kv_lens=merged_kv_lens, + step_indices=merged_step_indices, + prefill_done=True, + ) + + @torch.inference_mode() def run_prefill_step( model: Any, @@ -804,29 +1076,24 @@ def run_scheduler_continuous( max_steps: int, ) -> List[T2SFinishedItem]: pending = sorted(states, key=lambda item: (item.ready_step, item.request_id)) - running_requests: List[T2SRunningRequest] = [] + active_batch: Optional[T2SActiveBatch] = None finished: List[T2SFinishedItem] = [] current_tick = 0 - while pending or running_requests: + while pending or active_batch is not None: admitted: List[T2SRequestState] = [] while pending and pending[0].ready_step <= current_tick: admitted.append(pending.pop(0)) - admitted_running, admitted_finished = run_prefill_step(model, admitted, max_steps=max_steps) + admitted_active_batch, admitted_finished = run_prefill_active_batch(model, admitted, max_steps=max_steps) finished.extend(admitted_finished) + active_batch = merge_active_batches(model, active_batch, admitted_active_batch) - if running_requests: - running_requests, step_finished = run_decode_step_for_running( - model, - running_requests, - max_steps=max_steps, - ) + if active_batch is not None: + active_batch, step_finished = decode_one_step(model, active_batch, max_steps=max_steps) finished.extend(step_finished) - running_requests.extend(admitted_running) - - if not running_requests and pending: + if active_batch is None and pending: current_tick = max(current_tick + 1, pending[0].ready_step) continue diff --git a/GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py b/GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py new file mode 100644 index 00000000..e2398251 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py @@ -0,0 +1,100 @@ +import os +import re +import sys +from typing import Dict, List, Optional, Tuple + +now_dir = os.getcwd() +sys.path.append(now_dir) + +from text.LangSegmenter import LangSegmenter +from text import cleaned_text_to_sequence +from text.cleaner import clean_text + + +PreparedTextSegmentPayload = Dict[str, object] + + +def split_text_by_language(text: str, language: str) -> Tuple[List[str], List[str]]: + textlist: List[str] = [] + langlist: List[str] = [] + if language == "all_zh": + for tmp in LangSegmenter.getTexts(text, "zh"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_yue": + for tmp in LangSegmenter.getTexts(text, "zh"): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_ja": + for tmp in LangSegmenter.getTexts(text, "ja"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_ko": + for tmp in LangSegmenter.getTexts(text, "ko"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "en": + langlist.append("en") + textlist.append(text) + elif language == "auto": + for tmp in LangSegmenter.getTexts(text): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "auto_yue": + for tmp in LangSegmenter.getTexts(text): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + else: + for tmp in LangSegmenter.getTexts(text): + if langlist: + same_group = (tmp["lang"] == "en" and langlist[-1] == "en") or ( + tmp["lang"] != "en" and langlist[-1] != "en" + ) + if same_group: + textlist[-1] += tmp["text"] + continue + if tmp["lang"] == "en": + langlist.append(tmp["lang"]) + else: + langlist.append(language) + textlist.append(tmp["text"]) + return textlist, langlist + + +def clean_text_segment(text: str, language: str, version: str) -> Tuple[List[int], Optional[List[int]], str]: + normalized_language = language.replace("all_", "") + phones, word2ph, norm_text = clean_text(text, normalized_language, version) + phones = cleaned_text_to_sequence(phones, version) + return list(phones), None if word2ph is None else list(word2ph), str(norm_text) + + +def preprocess_text_segments_payload( + text: str, + language: str, + version: str, + final: bool = False, +) -> List[PreparedTextSegmentPayload]: + text = re.sub(r" {2,}", " ", text) + textlist, langlist = split_text_by_language(text, language) + payloads: List[PreparedTextSegmentPayload] = [] + total_phones_len = 0 + for segment_text, segment_lang in zip(textlist, langlist): + phones, word2ph, norm_text = clean_text_segment(segment_text, segment_lang, version) + payloads.append( + { + "language": segment_lang.replace("all_", ""), + "phones": phones, + "word2ph": word2ph, + "norm_text": norm_text, + } + ) + total_phones_len += len(phones) + + if not final and total_phones_len < 6: + return preprocess_text_segments_payload("." + text, language, version, final=True) + + return payloads diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py index 348ddb3f..c6d147cf 100644 --- a/GPT_SoVITS/module/models.py +++ b/GPT_SoVITS/module/models.py @@ -2,6 +2,7 @@ import warnings warnings.filterwarnings("ignore") import math +from typing import List import torch from torch import nn @@ -1038,6 +1039,67 @@ class SynthesizerTrn(nn.Module): o = self.dec((z * y_mask)[:, :, :], g=ge) return o + @torch.no_grad() + def decode_batched_request_local( + self, + codes: torch.Tensor, + code_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + refer_list: List[torch.Tensor], + noise_scale: float = 0.5, + speed: float = 1, + sv_emb: torch.Tensor | None = None, + ): + batch_size = int(codes.size(1)) + if batch_size <= 0: + raise ValueError("decode_batched_request_local 收到空 batch") + if len(refer_list) != batch_size: + raise ValueError("refer_list 数量与 batch size 不一致") + + refer_lengths = torch.LongTensor([int(item.size(2)) for item in refer_list]).to(codes.device) + max_refer_len = int(refer_lengths.max().item()) + refer_batch = torch.zeros( + (batch_size, int(refer_list[0].size(1)), max_refer_len), + dtype=refer_list[0].dtype, + device=codes.device, + ) + for batch_index, refer in enumerate(refer_list): + refer_batch[batch_index, :, : int(refer.size(2))] = refer.squeeze(0) + refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, max_refer_len), 1).to(refer_batch.dtype) + if self.version == "v1": + ge = self.ref_enc(refer_batch * refer_mask, refer_mask) + else: + ge = self.ref_enc(refer_batch[:, :704] * refer_mask, refer_mask) + if self.is_v2pro: + if sv_emb is None: + raise ValueError("v2Pro batched request-local synthesis 缺少 sv_emb") + ge = ge + self.sv_emb(sv_emb).unsqueeze(-1) + ge = self.prelu(ge) + + quantized = self.quantizer.decode(codes) + if self.semantic_frame_rate == "25hz": + quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") + y_lengths = code_lengths.to(device=codes.device, dtype=torch.long) * 2 + text_lengths = text_lengths.to(device=text.device, dtype=torch.long) + x, m_p, logs_p, y_mask, _, _ = self.enc_p( + quantized, + y_lengths, + text, + text_lengths, + self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge, + speed, + ) + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + z = self.flow(z_p, y_mask, g=ge, reverse=True) + audio = self.dec((z * y_mask)[:, :, :], g=ge) + upsample_factor = 1 + for up_layer in self.dec.ups: + stride = up_layer.stride[0] if isinstance(up_layer.stride, tuple) else int(up_layer.stride) + upsample_factor *= int(stride) + audio_lengths = y_mask.squeeze(1).sum(dim=1).to(dtype=torch.long) * int(upsample_factor) + return audio, audio_lengths + @torch.no_grad() def decode_streaming(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None): diff --git a/api_v3.py b/api_v3.py index 92f9a3b9..74bc7ac8 100644 --- a/api_v3.py +++ b/api_v3.py @@ -107,6 +107,7 @@ import sys import time import traceback import uuid +from collections import deque from dataclasses import dataclass from pathlib import Path from typing import Generator, List, Union @@ -115,6 +116,10 @@ now_dir = os.getcwd() sys.path.append(now_dir) sys.path.append("%s/GPT_SoVITS" % (now_dir)) +from runtime_preload import preload_text_runtime_deps + +preload_text_runtime_deps() + import argparse import subprocess import wave @@ -128,14 +133,15 @@ import uvicorn from io import BytesIO from tools.i18n.i18n import I18nAuto from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config +from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( SchedulerRequestSpec, + T2SActiveBatch, T2SFinishedItem, - T2SRunningRequest, T2SRequestState, - prepare_request_state, - run_decode_step_for_running, - run_prefill_step, + merge_active_batches, + decode_one_step, + run_prefill_active_batch, run_scheduler_continuous, ) from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names @@ -238,39 +244,71 @@ class SchedulerPendingJob: request_id: str state: T2SRequestState done_event: threading.Event + done_loop: asyncio.AbstractEventLoop | None + done_future: asyncio.Future | None enqueue_time: float speed_factor: float sample_steps: int media_type: str - prepare_ms: float = 0.0 prepare_wall_ms: float = 0.0 + prepare_profile_total_ms: float = 0.0 first_schedule_time: float | None = None prefill_ms: float = 0.0 + merge_ms: float = 0.0 decode_ms: float = 0.0 + finalize_wait_ms: float = 0.0 synth_ms: float = 0.0 pack_ms: float = 0.0 decode_steps: int = 0 + result_ready_time: float | None = None result: dict | None = None sample_rate: int | None = None audio_data: np.ndarray | None = None error: str | None = None +@dataclass +class SchedulerFinalizeTask: + request_id: str + item: T2SFinishedItem + enqueued_time: float + + class SchedulerDebugWorker: def __init__(self, tts: TTS, max_steps: int = 1500, micro_batch_wait_ms: int = 5): self.tts = tts self.max_steps = max_steps self.micro_batch_wait_s = micro_batch_wait_ms / 1000.0 + self.prepare_coordinator = PrepareCoordinator(tts) self.condition = threading.Condition() self.prepare_inflight = 0 self.prepare_peak_inflight = 0 + self.finalize_condition = threading.Condition() + self.finalize_pending_tasks: deque[SchedulerFinalizeTask] = deque() + self.finalize_pending_peak = 0 + self.finalize_inflight = 0 + self.finalize_inflight_peak = 0 + self.finalize_workers = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1))) + self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower() + self.finalize_batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16))) + self.finalize_batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0) self.pending_jobs: List[SchedulerPendingJob] = [] - self.running_requests: List[T2SRunningRequest] = [] + self.active_batch: T2SActiveBatch | None = None self.job_map: dict[str, SchedulerPendingJob] = {} self.total_finished = 0 self.total_submitted = 0 self.worker_thread = threading.Thread(target=self._run_loop, name="t2s-scheduler-debug-worker", daemon=True) self.worker_thread.start() + self.finalize_threads = [ + threading.Thread( + target=self._run_finalize_loop, + name=f"t2s-scheduler-finalize-{worker_index}", + daemon=True, + ) + for worker_index in range(self.finalize_workers) + ] + for finalize_thread in self.finalize_threads: + finalize_thread.start() def _sync_device(self) -> None: try: @@ -283,20 +321,7 @@ class SchedulerDebugWorker: pass def prepare_state(self, spec: SchedulerRequestSpec) -> T2SRequestState: - with self.condition: - self.prepare_inflight += 1 - prepare_inflight_on_enter = self.prepare_inflight - if self.prepare_inflight > self.prepare_peak_inflight: - self.prepare_peak_inflight = self.prepare_inflight - prepare_peak_inflight = self.prepare_peak_inflight - try: - state = prepare_request_state(self.tts, spec) - state.prepare_profile["worker_prepare_inflight_on_enter"] = float(prepare_inflight_on_enter) - state.prepare_profile["worker_prepare_peak_inflight"] = float(prepare_peak_inflight) - return state - finally: - with self.condition: - self.prepare_inflight = max(0, self.prepare_inflight - 1) + raise RuntimeError("prepare_state sync path has been replaced by PrepareCoordinator") def submit( self, @@ -304,27 +329,47 @@ class SchedulerDebugWorker: speed_factor: float, sample_steps: int, media_type: str, - prepare_ms: float, prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None = None, + done_future: asyncio.Future | None = None, ) -> SchedulerPendingJob: job = SchedulerPendingJob( request_id=state.request_id, state=state, done_event=threading.Event(), + done_loop=done_loop, + done_future=done_future, enqueue_time=time.perf_counter(), speed_factor=float(speed_factor), sample_steps=int(sample_steps), media_type=media_type, - prepare_ms=float(prepare_ms), prepare_wall_ms=float(prepare_wall_ms), + prepare_profile_total_ms=float(prepare_profile_total_ms), ) with self.condition: self.pending_jobs.append(job) self.job_map[job.request_id] = job self.total_submitted += 1 self.condition.notify_all() + with self.finalize_condition: + self.finalize_condition.notify_all() return job + async def prepare_state_async(self, spec: SchedulerRequestSpec) -> T2SRequestState: + state, _, _ = await self.prepare_coordinator.prepare_state_profiled_async(spec, time.perf_counter()) + return state + + async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: + return await asyncio.gather(*[self.prepare_state_async(spec) for spec in specs]) + + async def prepare_state_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> tuple[T2SRequestState, float, float]: + return await self.prepare_coordinator.prepare_state_profiled_async(spec, prepare_submit_at) + def _mark_prefill_started(self, jobs: List[SchedulerPendingJob], started_at: float) -> None: with self.condition: for job in jobs: @@ -340,6 +385,14 @@ class SchedulerDebugWorker: if tracked_job is not None: tracked_job.prefill_ms += elapsed_ms + def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: + elapsed_ms = elapsed_s * 1000.0 + with self.condition: + for request_id in request_ids: + job = self.job_map.get(request_id) + if job is not None: + job.merge_ms += elapsed_ms + def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: elapsed_ms = elapsed_s * 1000.0 with self.condition: @@ -349,16 +402,30 @@ class SchedulerDebugWorker: job.decode_ms += elapsed_ms job.decode_steps += 1 + def _add_finalize_wait_ms(self, request_ids: List[str], elapsed_ms: float) -> None: + with self.condition: + for request_id in request_ids: + job = self.job_map.get(request_id) + if job is not None: + job.finalize_wait_ms += elapsed_ms + def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]: - semantic_tokens = item.semantic_tokens.unsqueeze(0).unsqueeze(0).to(self.tts.configs.device) - phones = job.state.phones.unsqueeze(0).to(self.tts.configs.device) + semantic_tokens = item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0).to(self.tts.configs.device) + phones = job.state.phones.detach().clone().unsqueeze(0).to(self.tts.configs.device) + prompt_semantic = job.state.prompt_semantic.detach().clone() + prompt_phones = job.state.prompt_phones.detach().clone() + refer_spec = ( + job.state.refer_spec[0].detach().clone(), + None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), + ) + raw_audio = job.state.raw_audio.detach().clone() audio_fragment = self.tts.synthesize_audio_request_local( semantic_tokens=semantic_tokens, phones=phones, - prompt_semantic=job.state.prompt_semantic, - prompt_phones=job.state.prompt_phones, - refer_spec=job.state.refer_spec, - raw_audio=job.state.raw_audio, + prompt_semantic=prompt_semantic, + prompt_phones=prompt_phones, + refer_spec=refer_spec, + raw_audio=raw_audio, raw_sr=job.state.raw_sr, speed=float(job.speed_factor), sample_steps=int(job.sample_steps), @@ -375,6 +442,11 @@ class SchedulerDebugWorker: ) def get_state(self) -> dict: + with self.finalize_condition: + finalize_pending = len(self.finalize_pending_tasks) + finalize_pending_peak = self.finalize_pending_peak + finalize_inflight = self.finalize_inflight + finalize_inflight_peak = self.finalize_inflight_peak with self.condition: bert_stage = self.tts.prepare_bert_stage_limiter.snapshot() ref_audio_stage = self.tts.prepare_ref_audio_stage_limiter.snapshot() @@ -388,12 +460,24 @@ class SchedulerDebugWorker: if self.tts.prepare_ref_semantic_batch_worker is None else self.tts.prepare_ref_semantic_batch_worker.snapshot() ) + prepare_coordinator_state = self.prepare_coordinator.snapshot() return { "pending_jobs": len(self.pending_jobs), - "running_requests": len(self.running_requests), - "prepare_inflight": self.prepare_inflight, - "prepare_peak_inflight": self.prepare_peak_inflight, + "running_requests": 0 if self.active_batch is None else len(self.active_batch.request_ids), + "prepare_inflight": prepare_coordinator_state["inflight"], + "prepare_peak_inflight": prepare_coordinator_state["peak_inflight"], + "finalize_pending": finalize_pending, + "finalize_pending_peak": finalize_pending_peak, + "finalize_inflight": finalize_inflight, + "finalize_inflight_peak": finalize_inflight_peak, + "finalize_workers": self.finalize_workers, + "finalize_mode": self.finalize_mode, + "finalize_batch_max_items": self.finalize_batch_max_items, + "finalize_batch_wait_ms": self.finalize_batch_wait_s * 1000.0, + "prepare_request_executor_workers": 0, "prepare_text_cpu_workers": int(getattr(self.tts, "prepare_text_cpu_workers", 0)), + "prepare_text_feature_workers": int(prepare_coordinator_state["text_feature_workers"]), + "prepare_ref_audio_workers": int(prepare_coordinator_state["ref_audio_workers"]), "prepare_bert_stage": bert_stage, "prepare_bert_batch_worker": bert_batch_worker, "prepare_ref_audio_stage": ref_audio_stage, @@ -405,59 +489,217 @@ class SchedulerDebugWorker: "micro_batch_wait_ms": int(self.micro_batch_wait_s * 1000), } - def _finalize_finished(self, items: List[T2SFinishedItem]) -> None: + def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None: if not items: return - jobs_to_finalize: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + tasks: List[SchedulerFinalizeTask] = [] + enqueued_time = time.perf_counter() with self.condition: for item in items: job = self.job_map.get(item.request_id) if job is not None: - jobs_to_finalize.append((job, item)) + tasks.append( + SchedulerFinalizeTask( + request_id=item.request_id, + item=item, + enqueued_time=enqueued_time, + ) + ) + if not tasks: + return + with self.finalize_condition: + self.finalize_pending_tasks.extend(tasks) + if len(self.finalize_pending_tasks) > self.finalize_pending_peak: + self.finalize_pending_peak = len(self.finalize_pending_tasks) + self.finalize_condition.notify_all() - for job, item in jobs_to_finalize: + @staticmethod + def _finalize_batch_key(job: SchedulerPendingJob) -> tuple[float, int]: + return (round(float(job.speed_factor), 6), int(job.sample_steps)) + + def _take_finalize_task_batch(self) -> List[SchedulerFinalizeTask]: + with self.finalize_condition: + while not self.finalize_pending_tasks: + self.finalize_condition.wait() + if self.finalize_mode == "after_t2s_drain": + while not self._is_t2s_drained(): + self.finalize_condition.wait(timeout=self.micro_batch_wait_s) + task = self.finalize_pending_tasks.popleft() + selected_tasks = [task] + batch_key = None + with self.condition: + first_job = self.job_map.get(task.request_id) + if first_job is not None: + batch_key = self._finalize_batch_key(first_job) + batch_deadline = time.perf_counter() + self.finalize_batch_wait_s + while len(selected_tasks) < self.finalize_batch_max_items: + if batch_key is None: + break + matched_index = None + for pending_index, pending_task in enumerate(self.finalize_pending_tasks): + with self.condition: + pending_job = self.job_map.get(pending_task.request_id) + if pending_job is None: + matched_index = pending_index + break + if self._finalize_batch_key(pending_job) == batch_key: + matched_index = pending_index + break + if matched_index is not None: + selected_tasks.append(self.finalize_pending_tasks[matched_index]) + del self.finalize_pending_tasks[matched_index] + continue + remaining = batch_deadline - time.perf_counter() + if remaining <= 0: + break + self.finalize_condition.wait(timeout=remaining) + self.finalize_inflight += len(selected_tasks) + if self.finalize_inflight > self.finalize_inflight_peak: + self.finalize_inflight_peak = self.finalize_inflight + return selected_tasks + + def _finalize_task_done(self, count: int) -> None: + with self.finalize_condition: + self.finalize_inflight = max(0, self.finalize_inflight - count) + + def _is_t2s_drained(self) -> bool: + with self.condition: + return ( + self.active_batch is None + and not self.pending_jobs + and self.prepare_inflight <= 0 + ) + + def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None: + finished_at = time.perf_counter() + with self.condition: + if self.job_map.get(item.request_id) is not job: + return + queue_wait_ms = 0.0 + if job.first_schedule_time is not None: + queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) + worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) + worker_residual_ms = max( + 0.0, + worker_total_ms + - queue_wait_ms + - job.prefill_ms + - job.merge_ms + - job.decode_ms + - job.finalize_wait_ms + - job.synth_ms, + ) + worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms) + job.sample_rate = int(sample_rate) + job.audio_data = audio_data + job.result_ready_time = finished_at + prepare_profile = dict(job.state.prepare_profile) + job.result = { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + "prepare_ms": job.prepare_wall_ms, + "prepare_wall_ms": job.prepare_wall_ms, + "prepare_profile_total_ms": job.prepare_profile_total_ms, + "prepare_profile": prepare_profile, + "queue_wait_ms": queue_wait_ms, + "prefill_ms": job.prefill_ms, + "merge_ms": job.merge_ms, + "decode_ms": job.decode_ms, + "finalize_wait_ms": job.finalize_wait_ms, + "synth_ms": job.synth_ms, + "worker_residual_ms": worker_residual_ms, + "worker_other_ms": worker_other_ms, + "worker_total_ms": worker_total_ms, + "decode_steps": int(job.decode_steps), + "sample_rate": int(sample_rate), + "media_type": job.media_type, + } + job.done_event.set() + self._notify_done_future(job) + self.job_map.pop(item.request_id, None) + self.total_finished += 1 + + def _synthesize_finished_audio_batch( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> List[tuple[int, np.ndarray]]: + semantic_tokens_list = [item.semantic_tokens.detach().clone() for _, item in jobs_and_items] + phones_list = [job.state.phones.detach().clone() for job, _ in jobs_and_items] + refer_specs = [] + speeds = [] + sample_steps_list = [] + for job, _ in jobs_and_items: + refer_specs.append( + ( + job.state.refer_spec[0].detach().clone(), + None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), + ) + ) + speeds.append(float(job.speed_factor)) + sample_steps_list.append(int(job.sample_steps)) + audio_fragments = self.tts.synthesize_audio_requests_local_batched( + semantic_tokens_list=semantic_tokens_list, + phones_list=phones_list, + refer_specs=refer_specs, + speeds=speeds, + sample_steps_list=sample_steps_list, + ) + output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] + results: List[tuple[int, np.ndarray]] = [] + for (job, _), audio_fragment in zip(jobs_and_items, audio_fragments): + results.append( + self.tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=float(job.speed_factor), + split_bucket=False, + fragment_interval=0.0, + super_sampling=False, + ) + ) + return results + + def _run_finalize_loop(self) -> None: + while True: + tasks = self._take_finalize_task_batch() try: + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + finalize_wait_request_ids: List[str] = [] + with self.condition: + for task in tasks: + job = self.job_map.get(task.request_id) + if job is None: + continue + jobs_and_items.append((job, task.item)) + finalize_wait_request_ids.append(task.request_id) + if not jobs_and_items: + continue + now = time.perf_counter() + for task in tasks: + self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0)) self._sync_device() synth_start = time.perf_counter() - sample_rate, audio_data = self._synthesize_finished_audio(job, item) + if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder: + job, item = jobs_and_items[0] + batch_results = [self._synthesize_finished_audio(job, item)] + else: + batch_results = self._synthesize_finished_audio_batch(jobs_and_items) self._sync_device() synth_ms = (time.perf_counter() - synth_start) * 1000.0 + with self.condition: + for job, _ in jobs_and_items: + tracked_job = self.job_map.get(job.request_id) + if tracked_job is not None: + tracked_job.synth_ms += synth_ms + for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): + self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data) except Exception as exc: - self._finalize_error([item.request_id], str(exc)) - continue - - finished_at = time.perf_counter() - with self.condition: - if self.job_map.get(item.request_id) is not job: - continue - queue_wait_ms = 0.0 - if job.first_schedule_time is not None: - queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) - worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) - job.synth_ms += synth_ms - job.sample_rate = int(sample_rate) - job.audio_data = audio_data - prepare_profile = dict(job.state.prepare_profile) - job.result = { - "request_id": item.request_id, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - "finish_reason": item.finish_reason, - "prepare_ms": job.prepare_ms, - "prepare_wall_ms": job.prepare_wall_ms, - "prepare_profile": prepare_profile, - "queue_wait_ms": queue_wait_ms, - "prefill_ms": job.prefill_ms, - "decode_ms": job.decode_ms, - "synth_ms": job.synth_ms, - "worker_total_ms": worker_total_ms, - "decode_steps": int(job.decode_steps), - "sample_rate": int(sample_rate), - "media_type": job.media_type, - } - job.done_event.set() - self.job_map.pop(item.request_id, None) - self.total_finished += 1 + self._finalize_error([task.request_id for task in tasks], str(exc)) + finally: + self._finalize_task_done(len(tasks)) def _finalize_error(self, request_ids: List[str], error: str) -> None: if not request_ids: @@ -469,12 +711,28 @@ class SchedulerDebugWorker: continue job.error = error job.done_event.set() + self._notify_done_future(job) self.job_map.pop(request_id, None) self.total_finished += 1 + @staticmethod + def _resolve_done_future(job: SchedulerPendingJob) -> None: + future = job.done_future + if future is None or future.done(): + return + future.set_result(True) + + def _notify_done_future(self, job: SchedulerPendingJob) -> None: + if job.done_loop is None or job.done_future is None: + return + try: + job.done_loop.call_soon_threadsafe(self._resolve_done_future, job) + except RuntimeError: + pass + def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: with self.condition: - if not self.pending_jobs and not self.running_requests: + if not self.pending_jobs and self.active_batch is None: self.condition.wait(timeout=self.micro_batch_wait_s) elif wait_for_batch and self.pending_jobs: self.condition.wait(timeout=self.micro_batch_wait_s) @@ -482,11 +740,13 @@ class SchedulerDebugWorker: return [] pending = list(self.pending_jobs) self.pending_jobs.clear() + with self.finalize_condition: + self.finalize_condition.notify_all() return pending def _run_loop(self) -> None: while True: - wait_for_batch = len(self.running_requests) == 0 + wait_for_batch = self.active_batch is None pending_jobs = self._take_pending_snapshot(wait_for_batch=wait_for_batch) if pending_jobs: @@ -494,37 +754,54 @@ class SchedulerDebugWorker: self._sync_device() prefill_start = time.perf_counter() self._mark_prefill_started(pending_jobs, prefill_start) - admitted_running, admitted_finished = run_prefill_step( + admitted_active_batch, admitted_finished = run_prefill_active_batch( self.tts.t2s_model.model, [job.state for job in pending_jobs], max_steps=self.max_steps, ) self._sync_device() self._add_prefill_time(pending_jobs, time.perf_counter() - prefill_start) - self._finalize_finished(admitted_finished) - self.running_requests.extend(admitted_running) + self._enqueue_finalize_finished(admitted_finished) + merge_start = time.perf_counter() + self.active_batch = merge_active_batches( + self.tts.t2s_model.model, + self.active_batch, + admitted_active_batch, + ) + self._add_merge_time( + [] if self.active_batch is None else list(self.active_batch.request_ids), + time.perf_counter() - merge_start, + ) + with self.finalize_condition: + self.finalize_condition.notify_all() except Exception as exc: self._finalize_error([job.request_id for job in pending_jobs], str(exc)) - if self.running_requests: + if self.active_batch is not None: try: - active_request_ids = [item.state.request_id for item in self.running_requests] + active_request_ids = [state.request_id for state in self.active_batch.states] self._sync_device() decode_start = time.perf_counter() - self.running_requests, step_finished = run_decode_step_for_running( + self.active_batch, step_finished = decode_one_step( self.tts.t2s_model.model, - self.running_requests, + self.active_batch, max_steps=self.max_steps, ) self._sync_device() self._add_decode_time(active_request_ids, time.perf_counter() - decode_start) - self._finalize_finished(step_finished) + self._enqueue_finalize_finished(step_finished) + with self.finalize_condition: + self.finalize_condition.notify_all() except Exception as exc: self._finalize_error(active_request_ids, str(exc)) - self.running_requests = [] + self.active_batch = None + with self.finalize_condition: + self.finalize_condition.notify_all() continue if not pending_jobs: + with self.finalize_condition: + self.finalize_condition.notify_all() time.sleep(self.micro_batch_wait_s) @@ -788,10 +1065,6 @@ def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: ] -def prepare_scheduler_states_batch(specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: - return [scheduler_debug_worker.prepare_state(spec) for spec in specs] - - def build_scheduler_submit_spec(request: Scheduler_Submit_Request) -> SchedulerRequestSpec: payload = request.dict() request_id = payload["request_id"] or f"job_{uuid.uuid4().hex[:12]}" @@ -845,7 +1118,7 @@ async def tts_scheduler_debug_handle(request: Scheduler_Debug_Request): try: set_scheduler_seed(request.seed) specs = build_scheduler_request_specs(request.requests) - states = await asyncio.to_thread(prepare_scheduler_states_batch, specs) + states = await scheduler_debug_worker.prepare_states_batch_async(specs) finished = run_scheduler_continuous(tts_pipeline.t2s_model.model, states, max_steps=int(request.max_steps)) return JSONResponse( status_code=200, @@ -867,20 +1140,51 @@ async def tts_scheduler_debug_handle(request: Scheduler_Debug_Request): async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): try: request_start = time.perf_counter() + prepare_start = request_start spec = build_scheduler_submit_spec(request) - prepare_start = time.perf_counter() - state = await asyncio.to_thread(scheduler_debug_worker.prepare_state, spec) - prepare_wall_ms = (time.perf_counter() - prepare_start) * 1000.0 - prepare_ms = float(state.prepare_profile.get("total_ms", prepare_wall_ms)) + spec_ready_at = time.perf_counter() + prepare_spec_build_ms = max(0.0, (spec_ready_at - prepare_start) * 1000.0) + state, prepare_exec_started_at, prepare_exec_finished_at = await scheduler_debug_worker.prepare_state_profiled_async( + spec, + spec_ready_at, + ) + prepare_end = time.perf_counter() + prepare_wall_ms = (prepare_end - prepare_start) * 1000.0 + prepare_profile_total_ms = float(state.prepare_profile.get("total_ms", prepare_wall_ms)) + prepare_profile_wall_ms = float(state.prepare_profile.get("wall_total_ms", prepare_profile_total_ms)) + prepare_executor_queue_ms = float( + state.prepare_profile.get("executor_queue_ms", max(0.0, (prepare_exec_started_at - spec_ready_at) * 1000.0)) + ) + prepare_executor_run_ms = float( + state.prepare_profile.get( + "executor_run_wall_ms", + max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0), + ) + ) + prepare_other_ms = max( + 0.0, + prepare_wall_ms - prepare_spec_build_ms - prepare_executor_queue_ms - prepare_profile_wall_ms, + ) + loop = asyncio.get_running_loop() + done_future = loop.create_future() job = scheduler_debug_worker.submit( state, speed_factor=float(request.speed_factor), sample_steps=int(request.sample_steps), media_type=request.media_type, - prepare_ms=prepare_ms, prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=loop, + done_future=done_future, ) - timeout_ok = await asyncio.to_thread(job.done_event.wait, float(request.timeout_sec)) + api_after_prepare_ms = max(0.0, (job.enqueue_time - prepare_end) * 1000.0) + timeout_ok = False + try: + await asyncio.wait_for(asyncio.shield(done_future), timeout=float(request.timeout_sec)) + timeout_ok = True + except asyncio.TimeoutError: + timeout_ok = False + wait_return_at = time.perf_counter() if not timeout_ok: return JSONResponse( status_code=202, @@ -888,8 +1192,10 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): "message": "queued", "request_id": job.request_id, "timings": { - "prepare_ms": prepare_ms, + "prepare_ms": prepare_wall_ms, "prepare_wall_ms": prepare_wall_ms, + "prepare_profile_total_ms": prepare_profile_total_ms, + "api_after_prepare_ms": api_after_prepare_ms, "request_elapsed_ms": max(0.0, (time.perf_counter() - request_start) * 1000.0), }, "worker_state": scheduler_debug_worker.get_state(), @@ -911,9 +1217,13 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): ) pack_start = time.perf_counter() audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue() - pack_ms = (time.perf_counter() - pack_start) * 1000.0 + pack_end = time.perf_counter() + pack_ms = (pack_end - pack_start) * 1000.0 job.pack_ms = pack_ms - request_total_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0) + api_wait_result_ms = 0.0 + if job.result_ready_time is not None: + api_wait_result_ms = max(0.0, (wait_return_at - job.result_ready_time) * 1000.0) + worker_total_ms = float(job.result["worker_total_ms"]) if job.result is not None else 0.0 headers = { "X-Request-Id": job.request_id, "X-Semantic-Len": str(job.result["semantic_len"]) if job.result is not None else "0", @@ -921,16 +1231,32 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): "X-Queue-Wait-Ms": ( f"{float(job.result['queue_wait_ms']):.3f}" if job.result is not None else "0.000" ), - "X-Prepare-Ms": f"{prepare_ms:.3f}", + "X-Prepare-Ms": f"{prepare_wall_ms:.3f}", "X-Prepare-Wall-Ms": f"{prepare_wall_ms:.3f}", + "X-Prepare-Spec-Build-Ms": f"{prepare_spec_build_ms:.3f}", + "X-Prepare-Executor-Queue-Ms": f"{prepare_executor_queue_ms:.3f}", + "X-Prepare-Admission-Wait-Ms": ( + f"{float(job.result['prepare_profile'].get('prepare_admission_wait_ms', 0.0)):.3f}" + if job.result is not None + else "0.000" + ), + "X-Prepare-Executor-Run-Ms": f"{prepare_executor_run_ms:.3f}", + "X-Prepare-Profile-Total-Ms": f"{prepare_profile_total_ms:.3f}", + "X-Prepare-Profile-Wall-Ms": f"{prepare_profile_wall_ms:.3f}", + "X-Prepare-Other-Ms": f"{prepare_other_ms:.3f}", + "X-Api-After-Prepare-Ms": f"{api_after_prepare_ms:.3f}", "X-Prefill-Ms": f"{float(job.result['prefill_ms']):.3f}" if job.result is not None else "0.000", + "X-Merge-Ms": f"{float(job.result['merge_ms']):.3f}" if job.result is not None else "0.000", "X-Decode-Ms": f"{float(job.result['decode_ms']):.3f}" if job.result is not None else "0.000", + "X-Finalize-Wait-Ms": f"{float(job.result['finalize_wait_ms']):.3f}" if job.result is not None else "0.000", "X-Synth-Ms": f"{float(job.result['synth_ms']):.3f}" if job.result is not None else "0.000", + "X-Worker-Residual-Ms": f"{float(job.result['worker_residual_ms']):.3f}" if job.result is not None else "0.000", + "X-Worker-Other-Ms": f"{float(job.result['worker_other_ms']):.3f}" if job.result is not None else "0.000", "X-Pack-Ms": f"{pack_ms:.3f}", "X-Worker-Total-Ms": ( f"{float(job.result['worker_total_ms']):.3f}" if job.result is not None else "0.000" ), - "X-Request-Total-Ms": f"{request_total_ms:.3f}", + "X-Api-Wait-Result-Ms": f"{api_wait_result_ms:.3f}", "X-Decode-Steps": str(job.result["decode_steps"]) if job.result is not None else "0", } if job.result is not None: @@ -939,16 +1265,48 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): { "X-Prepare-Prompt-Text-Ms": f"{float(prepare_profile.get('prompt_text_features_ms', 0.0)):.3f}", "X-Prepare-Target-Text-Ms": f"{float(prepare_profile.get('text_features_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Text-CPU-Preprocess-Ms": f"{float(prepare_profile.get('prompt_text_cpu_preprocess_ms', 0.0)):.3f}", + "X-Prepare-Target-Text-CPU-Preprocess-Ms": f"{float(prepare_profile.get('text_cpu_preprocess_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Text-CPU-Queue-Ms": f"{float(prepare_profile.get('prompt_text_cpu_queue_ms', 0.0)):.3f}", + "X-Prepare-Target-Text-CPU-Queue-Ms": f"{float(prepare_profile.get('text_cpu_queue_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Text-Feature-Queue-Ms": f"{float(prepare_profile.get('prompt_text_feature_queue_ms', 0.0)):.3f}", + "X-Prepare-Target-Text-Feature-Queue-Ms": f"{float(prepare_profile.get('text_feature_queue_ms', 0.0)):.3f}", "X-Prepare-Prompt-Bert-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_wait_ms', 0.0)):.3f}", "X-Prepare-Target-Bert-Wait-Ms": f"{float(prepare_profile.get('text_bert_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Admission-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_admission_wait_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Admission-Wait-Ms": f"{float(prepare_profile.get('text_bert_admission_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Queue-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_queue_wait_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Queue-Wait-Ms": f"{float(prepare_profile.get('text_bert_queue_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Batch-Collect-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_batch_collect_wait_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Batch-Collect-Wait-Ms": f"{float(prepare_profile.get('text_bert_batch_collect_wait_ms', 0.0)):.3f}", "X-Prepare-Prompt-Bert-Forward-Ms": f"{float(prepare_profile.get('prompt_text_bert_forward_ms', 0.0)):.3f}", "X-Prepare-Target-Bert-Forward-Ms": f"{float(prepare_profile.get('text_bert_forward_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Pending-On-Enqueue-Peak": str( + int(prepare_profile.get("prompt_text_bert_pending_depth_on_enqueue_peak", 0.0)) + ), + "X-Prepare-Target-Bert-Pending-On-Enqueue-Peak": str( + int(prepare_profile.get("text_bert_pending_depth_on_enqueue_peak", 0.0)) + ), + "X-Prepare-Prompt-Bert-Pending-On-Collect-Peak": str( + int(prepare_profile.get("prompt_text_bert_pending_depth_on_collect_peak", 0.0)) + ), + "X-Prepare-Target-Bert-Pending-On-Collect-Peak": str( + int(prepare_profile.get("text_bert_pending_depth_on_collect_peak", 0.0)) + ), + "X-Prepare-Prompt-Bert-High-Pressure-Peak": str( + int(prepare_profile.get("prompt_text_bert_high_pressure_mode_peak", 0.0)) + ), + "X-Prepare-Target-Bert-High-Pressure-Peak": str( + int(prepare_profile.get("text_bert_high_pressure_mode_peak", 0.0)) + ), "X-Prepare-Prompt-Bert-Batch-Size-Peak": str( int(prepare_profile.get("prompt_text_bert_batch_size_peak", 0.0)) ), "X-Prepare-Target-Bert-Batch-Size-Peak": str( int(prepare_profile.get("text_bert_batch_size_peak", 0.0)) ), + "X-Prepare-Prompt-Bert-Batch-Window-Ms": f"{float(prepare_profile.get('prompt_text_bert_batch_window_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Batch-Window-Ms": f"{float(prepare_profile.get('text_bert_batch_window_ms', 0.0)):.3f}", "X-Prepare-Text-Pair-Wall-Ms": f"{float(prepare_profile.get('text_feature_pair_ms', 0.0)):.3f}", "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))), "X-Prepare-Audio-Load-Ms": f"{float(prepare_profile.get('audio_load_ms', 0.0)):.3f}", @@ -964,13 +1322,22 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): "X-Prepare-Ref-Spec-Wait-Ms": f"{float(prepare_profile.get('ref_spec_wait_ms', 0.0)):.3f}", "X-Prepare-Ref-Bundle-Ms": f"{float(prepare_profile.get('ref_audio_bundle_ms', 0.0)):.3f}", "X-Prepare-Tensorize-Ms": f"{float(prepare_profile.get('tensorize_ms', 0.0)):.3f}", - "X-Prepare-Profile-Wall-Ms": f"{float(prepare_profile.get('wall_total_ms', 0.0)):.3f}", "X-Prepare-Inflight-On-Enter": str( int(prepare_profile.get("worker_prepare_inflight_on_enter", 0.0)) ), "X-Prepare-Inflight-Peak": str(int(prepare_profile.get("worker_prepare_peak_inflight", 0.0))), } ) + response_ready_at = time.perf_counter() + response_overhead_ms = max(0.0, (response_ready_at - pack_end) * 1000.0) + request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) + request_other_ms = max( + 0.0, + request_total_ms - prepare_wall_ms - api_after_prepare_ms - worker_total_ms - api_wait_result_ms - pack_ms, + ) + headers["X-Response-Overhead-Ms"] = f"{response_overhead_ms:.3f}" + headers["X-Request-Other-Ms"] = f"{request_other_ms:.3f}" + headers["X-Request-Total-Ms"] = f"{request_total_ms:.3f}" return Response(audio_data, media_type=f"audio/{job.media_type}", headers=headers) except Exception as e: return JSONResponse( From 69ac7f90271bc3099fe608cabbf9979077a68a71 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Tue, 10 Mar 2026 06:59:28 +0800 Subject: [PATCH 09/24] Integrate UnifiedTTSEngine into TTS API for improved audio processing and control. Refactor tts_handle and control endpoints to utilize the new engine, enhancing error handling and response management. Update set_refer_audio and set_gpt_weights endpoints to return payloads from the engine, streamlining audio configuration processes. --- GPT_SoVITS/TTS_infer_pack/TTS.py | 36 +- GPT_SoVITS/TTS_infer_pack/unified_engine.py | 1255 +++++++++++++++++++ api_v2.py | 99 +- api_v3.py | 323 +---- 4 files changed, 1332 insertions(+), 381 deletions(-) create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine.py diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index d475b804..c7ae465c 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -468,7 +468,26 @@ class TTS: ) self._init_models() + self.refresh_runtime_components() + self.prompt_cache: dict = { + "ref_audio_path": None, + "prompt_semantic": None, + "refer_spec": [], + "prompt_text": None, + "prompt_lang": None, + "phones": None, + "bert_features": None, + "norm_text": None, + "aux_ref_audio_paths": [], + } + + self.stop_flag: bool = False + self.precision: torch.dtype = torch.float16 if self.configs.is_half else torch.float32 + + def refresh_runtime_components(self): + self.prepare_bert_batch_worker = None + self.prepare_ref_semantic_batch_worker = None if os.environ.get("GPTSOVITS_PREPARE_BERT_BATCHING", "1") != "0": self.prepare_bert_batch_worker = PrepareBertBatchWorker( bert_model=self.bert_model, @@ -509,7 +528,7 @@ class TTS: max_batch_samples=int(ref_max_batch_samples), ) - self.text_preprocessor: TextPreprocessor = TextPreprocessor( + self.text_preprocessor = TextPreprocessor( self.bert_model, self.bert_tokenizer, self.configs.device, @@ -517,21 +536,6 @@ class TTS: bert_batch_worker=self.prepare_bert_batch_worker, ) - self.prompt_cache: dict = { - "ref_audio_path": None, - "prompt_semantic": None, - "refer_spec": [], - "prompt_text": None, - "prompt_lang": None, - "phones": None, - "bert_features": None, - "norm_text": None, - "aux_ref_audio_paths": [], - } - - self.stop_flag: bool = False - self.precision: torch.dtype = torch.float16 if self.configs.is_half else torch.float32 - def _init_models( self, ): diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine.py b/GPT_SoVITS/TTS_infer_pack/unified_engine.py new file mode 100644 index 00000000..0a95015f --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine.py @@ -0,0 +1,1255 @@ +from __future__ import annotations + +import asyncio +import os +import signal +import subprocess +import sys +import threading +import time +import uuid +import wave +from collections import deque +from dataclasses import dataclass, field +from io import BytesIO +from pathlib import Path +from typing import Any, Callable, Deque, Dict, Generator, List, Optional, Sequence, Tuple, Union + +import numpy as np +import soundfile as sf +import torch + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( + SchedulerRequestSpec, + T2SActiveBatch, + T2SFinishedItem, + T2SRequestState, + decode_one_step, + merge_active_batches, + run_prefill_active_batch, + run_scheduler_continuous, +) + + +@dataclass +class RuntimeControlCallbacks: + restart: Callable[[], None] | None = None + exit: Callable[[], None] | None = None + + +@dataclass +class DefaultReferenceState: + ref_audio_path: str | None = None + updated_at: float = 0.0 + + +class ReferenceRegistry: + def __init__(self) -> None: + self._lock = threading.Lock() + self._state = DefaultReferenceState() + + def set_default(self, ref_audio_path: str) -> DefaultReferenceState: + with self._lock: + self._state = DefaultReferenceState(ref_audio_path=str(ref_audio_path), updated_at=time.time()) + return self._state + + def clear(self) -> DefaultReferenceState: + with self._lock: + self._state = DefaultReferenceState() + return self._state + + def get_default(self) -> DefaultReferenceState: + with self._lock: + return DefaultReferenceState( + ref_audio_path=self._state.ref_audio_path, + updated_at=self._state.updated_at, + ) + + +@dataclass +class ModelRegistryState: + t2s_weights_path: str + vits_weights_path: str + generation: int = 0 + t2s_generation: int = 0 + vits_generation: int = 0 + updated_at: float = field(default_factory=time.time) + + +class ModelRegistry: + def __init__(self, t2s_weights_path: str, vits_weights_path: str) -> None: + self._lock = threading.Lock() + self._state = ModelRegistryState( + t2s_weights_path=str(t2s_weights_path), + vits_weights_path=str(vits_weights_path), + ) + + def snapshot(self) -> ModelRegistryState: + with self._lock: + return ModelRegistryState( + t2s_weights_path=self._state.t2s_weights_path, + vits_weights_path=self._state.vits_weights_path, + generation=self._state.generation, + t2s_generation=self._state.t2s_generation, + vits_generation=self._state.vits_generation, + updated_at=self._state.updated_at, + ) + + def mark_t2s_reload(self, weights_path: str) -> ModelRegistryState: + with self._lock: + self._state.t2s_weights_path = str(weights_path) + self._state.generation += 1 + self._state.t2s_generation += 1 + self._state.updated_at = time.time() + return ModelRegistryState( + t2s_weights_path=self._state.t2s_weights_path, + vits_weights_path=self._state.vits_weights_path, + generation=self._state.generation, + t2s_generation=self._state.t2s_generation, + vits_generation=self._state.vits_generation, + updated_at=self._state.updated_at, + ) + + def mark_vits_reload(self, weights_path: str) -> ModelRegistryState: + with self._lock: + self._state.vits_weights_path = str(weights_path) + self._state.generation += 1 + self._state.vits_generation += 1 + self._state.updated_at = time.time() + return ModelRegistryState( + t2s_weights_path=self._state.t2s_weights_path, + vits_weights_path=self._state.vits_weights_path, + generation=self._state.generation, + t2s_generation=self._state.t2s_generation, + vits_generation=self._state.vits_generation, + updated_at=self._state.updated_at, + ) + + +@dataclass +class DirectTTSExecution: + media_type: str + streaming: bool + audio_generator: Optional[Generator[bytes, None, None]] = None + audio_bytes: Optional[bytes] = None + + +@dataclass +class SchedulerDebugExecution: + payload: Dict[str, Any] + + +@dataclass +class SchedulerSubmitExecution: + audio_bytes: bytes + media_type: str + headers: Dict[str, str] + + +@dataclass +class SchedulerPendingJob: + request_id: str + state: T2SRequestState + done_event: threading.Event + done_loop: asyncio.AbstractEventLoop | None + done_future: asyncio.Future | None + enqueue_time: float + speed_factor: float + sample_steps: int + media_type: str + prepare_wall_ms: float = 0.0 + prepare_profile_total_ms: float = 0.0 + first_schedule_time: float | None = None + prefill_ms: float = 0.0 + merge_ms: float = 0.0 + decode_ms: float = 0.0 + finalize_wait_ms: float = 0.0 + synth_ms: float = 0.0 + pack_ms: float = 0.0 + decode_steps: int = 0 + result_ready_time: float | None = None + result: dict | None = None + sample_rate: int | None = None + audio_data: np.ndarray | None = None + error: str | None = None + + +@dataclass +class SchedulerFinalizeTask: + request_id: str + item: T2SFinishedItem + enqueued_time: float + + +class UnifiedSchedulerWorker: + def __init__(self, tts: TTS, max_steps: int = 1500, micro_batch_wait_ms: int = 5): + self.tts = tts + self.max_steps = int(max_steps) + self.micro_batch_wait_s = float(micro_batch_wait_ms) / 1000.0 + self.prepare_coordinator = PrepareCoordinator(tts) + self.condition = threading.Condition() + self.prepare_inflight = 0 + self.prepare_peak_inflight = 0 + self.finalize_condition = threading.Condition() + self.finalize_pending_tasks: Deque[SchedulerFinalizeTask] = deque() + self.finalize_pending_peak = 0 + self.finalize_inflight = 0 + self.finalize_inflight_peak = 0 + self.finalize_workers = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1))) + self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower() + self.finalize_batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16))) + self.finalize_batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0) + self.pending_jobs: List[SchedulerPendingJob] = [] + self.active_batch: T2SActiveBatch | None = None + self.job_map: Dict[str, SchedulerPendingJob] = {} + self.total_finished = 0 + self.total_submitted = 0 + self.worker_thread = threading.Thread(target=self._run_loop, name="unified-t2s-scheduler-worker", daemon=True) + self.worker_thread.start() + self.finalize_threads = [ + threading.Thread( + target=self._run_finalize_loop, + name=f"unified-t2s-finalize-{worker_index}", + daemon=True, + ) + for worker_index in range(self.finalize_workers) + ] + for finalize_thread in self.finalize_threads: + finalize_thread.start() + + def snapshot(self) -> dict: + with self.condition: + finalize_pending = len(self.finalize_pending_tasks) + prepare_state = self.prepare_coordinator.snapshot() + return { + "pending_jobs": len(self.pending_jobs), + "running_requests": 0 if self.active_batch is None else len(self.active_batch.request_ids), + "prepare_inflight": prepare_state["inflight"], + "prepare_peak_inflight": prepare_state["peak_inflight"], + "prepare_max_inflight": prepare_state.get("max_inflight", 0), + "finalize_pending": finalize_pending, + "finalize_pending_peak": self.finalize_pending_peak, + "finalize_inflight": self.finalize_inflight, + "finalize_inflight_peak": self.finalize_inflight_peak, + "finalize_workers": self.finalize_workers, + "finalize_mode": self.finalize_mode, + "finalize_batch_max_items": self.finalize_batch_max_items, + "finalize_batch_wait_ms": self.finalize_batch_wait_s * 1000.0, + "total_submitted": self.total_submitted, + "total_finished": self.total_finished, + "drained": self.is_drained(), + } + + def is_drained(self) -> bool: + with self.condition: + with self.finalize_condition: + return ( + self.active_batch is None + and not self.pending_jobs + and not self.job_map + and self.prepare_coordinator.snapshot()["inflight"] <= 0 + and self.finalize_inflight <= 0 + and not self.finalize_pending_tasks + ) + + def wait_until_idle(self, timeout_sec: float = 60.0, poll_interval_sec: float = 0.01) -> bool: + deadline = time.perf_counter() + max(0.0, timeout_sec) + while time.perf_counter() < deadline: + if self.is_drained(): + return True + time.sleep(poll_interval_sec) + return self.is_drained() + + def _sync_device(self) -> None: + try: + device_str = str(self.tts.configs.device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(self.tts.configs.device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + def submit( + self, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None = None, + done_future: asyncio.Future | None = None, + ) -> SchedulerPendingJob: + job = SchedulerPendingJob( + request_id=state.request_id, + state=state, + done_event=threading.Event(), + done_loop=done_loop, + done_future=done_future, + enqueue_time=time.perf_counter(), + speed_factor=float(speed_factor), + sample_steps=int(sample_steps), + media_type=media_type, + prepare_wall_ms=float(prepare_wall_ms), + prepare_profile_total_ms=float(prepare_profile_total_ms), + ) + with self.condition: + self.pending_jobs.append(job) + self.job_map[job.request_id] = job + self.total_submitted += 1 + self.condition.notify_all() + with self.finalize_condition: + self.finalize_condition.notify_all() + return job + + async def prepare_state_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> tuple[T2SRequestState, float, float]: + with self.condition: + self.prepare_inflight += 1 + self.prepare_peak_inflight = max(self.prepare_peak_inflight, self.prepare_inflight) + try: + return await self.prepare_coordinator.prepare_state_profiled_async(spec, prepare_submit_at) + finally: + with self.condition: + self.prepare_inflight = max(0, self.prepare_inflight - 1) + self.condition.notify_all() + with self.finalize_condition: + self.finalize_condition.notify_all() + + async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: + results = await asyncio.gather( + *[self.prepare_state_profiled_async(spec, time.perf_counter()) for spec in specs] + ) + return [state for state, _, _ in results] + + def _mark_prefill_started(self, pending_jobs: List[SchedulerPendingJob], started_at: float) -> None: + with self.condition: + for job in pending_jobs: + tracked_job = self.job_map.get(job.request_id) + if tracked_job is None: + continue + tracked_job.first_schedule_time = float(started_at) + + def _add_prefill_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_map.get(request_id) + if job is not None: + job.prefill_ms += delta_ms + + def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_map.get(request_id) + if job is not None: + job.merge_ms += delta_ms + + def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_map.get(request_id) + if job is not None: + job.decode_ms += delta_ms + job.decode_steps += 1 + + def _add_finalize_wait_ms(self, request_ids: List[str], delta_ms: float) -> None: + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_map.get(request_id) + if job is not None: + job.finalize_wait_ms += float(delta_ms) + + def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None: + if not items: + return + enqueued_at = time.perf_counter() + with self.finalize_condition: + for item in items: + self.finalize_pending_tasks.append( + SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at) + ) + self.finalize_pending_peak = max(self.finalize_pending_peak, len(self.finalize_pending_tasks)) + self.finalize_condition.notify_all() + + def _take_finalize_task_batch(self) -> List[SchedulerFinalizeTask]: + with self.finalize_condition: + while not self.finalize_pending_tasks: + self.finalize_condition.wait() + selected_tasks = [self.finalize_pending_tasks.popleft()] + if self.finalize_mode == "sync" or self.tts.configs.use_vocoder: + self.finalize_inflight += len(selected_tasks) + self.finalize_inflight_peak = max(self.finalize_inflight_peak, self.finalize_inflight) + return selected_tasks + batch_deadline = time.perf_counter() + self.finalize_batch_wait_s + while len(selected_tasks) < self.finalize_batch_max_items: + if not self.finalize_pending_tasks: + remaining = batch_deadline - time.perf_counter() + if remaining <= 0: + break + self.finalize_condition.wait(timeout=remaining) + continue + first_task = selected_tasks[0] + matched_index = None + for index, task in enumerate(self.finalize_pending_tasks): + if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: + matched_index = index + break + if matched_index is not None: + selected_tasks.append(self.finalize_pending_tasks[matched_index]) + del self.finalize_pending_tasks[matched_index] + continue + remaining = batch_deadline - time.perf_counter() + if remaining <= 0: + break + self.finalize_condition.wait(timeout=remaining) + self.finalize_inflight += len(selected_tasks) + self.finalize_inflight_peak = max(self.finalize_inflight_peak, self.finalize_inflight) + return selected_tasks + + def _finalize_task_done(self, count: int) -> None: + with self.finalize_condition: + self.finalize_inflight = max(0, self.finalize_inflight - count) + self.finalize_condition.notify_all() + + def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]: + audio_fragment = self.tts.synthesize_audio_request_local( + semantic_tokens=item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0), + phones=job.state.phones.detach().clone().unsqueeze(0), + prompt_semantic=job.state.prompt_semantic.detach().clone(), + prompt_phones=job.state.prompt_phones.detach().clone(), + refer_spec=( + job.state.refer_spec[0].detach().clone(), + None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), + ), + raw_audio=job.state.raw_audio.detach().clone(), + raw_sr=int(job.state.raw_sr), + speed=float(job.speed_factor), + sample_steps=int(job.sample_steps), + ) + output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] + return self.tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=float(job.speed_factor), + split_bucket=False, + fragment_interval=0.0, + super_sampling=False, + ) + + def _synthesize_finished_audio_batch( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> List[tuple[int, np.ndarray]]: + semantic_tokens_list = [item.semantic_tokens.detach().clone() for _, item in jobs_and_items] + phones_list = [job.state.phones.detach().clone() for job, _ in jobs_and_items] + refer_specs = [] + speeds = [] + sample_steps_list = [] + for job, _ in jobs_and_items: + refer_specs.append( + ( + job.state.refer_spec[0].detach().clone(), + None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), + ) + ) + speeds.append(float(job.speed_factor)) + sample_steps_list.append(int(job.sample_steps)) + audio_fragments = self.tts.synthesize_audio_requests_local_batched( + semantic_tokens_list=semantic_tokens_list, + phones_list=phones_list, + refer_specs=refer_specs, + speeds=speeds, + sample_steps_list=sample_steps_list, + ) + output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] + results: List[tuple[int, np.ndarray]] = [] + for (job, _), audio_fragment in zip(jobs_and_items, audio_fragments): + results.append( + self.tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=float(job.speed_factor), + split_bucket=False, + fragment_interval=0.0, + super_sampling=False, + ) + ) + return results + + def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None: + finished_at = time.perf_counter() + with self.condition: + if self.job_map.get(item.request_id) is not job: + return + queue_wait_ms = 0.0 + if job.first_schedule_time is not None: + queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) + worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) + worker_residual_ms = max( + 0.0, + worker_total_ms + - queue_wait_ms + - job.prefill_ms + - job.merge_ms + - job.decode_ms + - job.finalize_wait_ms + - job.synth_ms, + ) + worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms) + job.sample_rate = int(sample_rate) + job.audio_data = audio_data + job.result_ready_time = finished_at + prepare_profile = dict(job.state.prepare_profile) + job.result = { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + "prepare_ms": job.prepare_wall_ms, + "prepare_wall_ms": job.prepare_wall_ms, + "prepare_profile_total_ms": job.prepare_profile_total_ms, + "prepare_profile": prepare_profile, + "queue_wait_ms": queue_wait_ms, + "prefill_ms": job.prefill_ms, + "merge_ms": job.merge_ms, + "decode_ms": job.decode_ms, + "finalize_wait_ms": job.finalize_wait_ms, + "synth_ms": job.synth_ms, + "worker_residual_ms": worker_residual_ms, + "worker_other_ms": worker_other_ms, + "worker_total_ms": worker_total_ms, + "decode_steps": int(job.decode_steps), + "sample_rate": int(sample_rate), + "media_type": job.media_type, + } + job.done_event.set() + self._notify_done_future(job) + self.job_map.pop(item.request_id, None) + self.total_finished += 1 + self.condition.notify_all() + + def _finalize_error(self, request_ids: List[str], error: str) -> None: + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_map.get(request_id) + if job is None: + continue + job.error = error + job.done_event.set() + self._notify_done_future(job) + self.job_map.pop(request_id, None) + self.total_finished += 1 + self.condition.notify_all() + + @staticmethod + def _resolve_done_future(job: SchedulerPendingJob) -> None: + future = job.done_future + if future is None or future.done(): + return + future.set_result(True) + + def _notify_done_future(self, job: SchedulerPendingJob) -> None: + if job.done_loop is None or job.done_future is None: + return + try: + job.done_loop.call_soon_threadsafe(self._resolve_done_future, job) + except RuntimeError: + pass + + def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs and self.active_batch is None: + self.condition.wait(timeout=self.micro_batch_wait_s) + elif wait_for_batch and self.pending_jobs: + self.condition.wait(timeout=self.micro_batch_wait_s) + if not self.pending_jobs: + return [] + pending = list(self.pending_jobs) + self.pending_jobs.clear() + return pending + + def _run_finalize_loop(self) -> None: + while True: + tasks = self._take_finalize_task_batch() + try: + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + with self.condition: + for task in tasks: + job = self.job_map.get(task.request_id) + if job is None: + continue + jobs_and_items.append((job, task.item)) + if not jobs_and_items: + continue + now = time.perf_counter() + for task in tasks: + self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0)) + self._sync_device() + synth_start = time.perf_counter() + if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder: + job, item = jobs_and_items[0] + batch_results = [self._synthesize_finished_audio(job, item)] + else: + batch_results = self._synthesize_finished_audio_batch(jobs_and_items) + self._sync_device() + synth_ms = (time.perf_counter() - synth_start) * 1000.0 + with self.condition: + for job, _ in jobs_and_items: + tracked_job = self.job_map.get(job.request_id) + if tracked_job is not None: + tracked_job.synth_ms += synth_ms + for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): + self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data) + except Exception as exc: + self._finalize_error([task.request_id for task in tasks], str(exc)) + finally: + self._finalize_task_done(len(tasks)) + + def _run_loop(self) -> None: + while True: + wait_for_batch = self.active_batch is None + pending_jobs = self._take_pending_snapshot(wait_for_batch=wait_for_batch) + + if pending_jobs: + try: + self._sync_device() + prefill_start = time.perf_counter() + self._mark_prefill_started(pending_jobs, prefill_start) + admitted_active_batch, admitted_finished = run_prefill_active_batch( + self.tts.t2s_model.model, + [job.state for job in pending_jobs], + max_steps=self.max_steps, + ) + self._sync_device() + self._add_prefill_time([job.request_id for job in pending_jobs], time.perf_counter() - prefill_start) + self._enqueue_finalize_finished(admitted_finished) + merge_start = time.perf_counter() + self.active_batch = merge_active_batches( + self.tts.t2s_model.model, + self.active_batch, + admitted_active_batch, + ) + self._add_merge_time( + [] if self.active_batch is None else list(self.active_batch.request_ids), + time.perf_counter() - merge_start, + ) + except Exception as exc: + self._finalize_error([job.request_id for job in pending_jobs], str(exc)) + + if self.active_batch is not None: + active_request_ids: List[str] = [] + try: + active_request_ids = [state.request_id for state in self.active_batch.states] + self._sync_device() + decode_start = time.perf_counter() + self.active_batch, step_finished = decode_one_step( + self.tts.t2s_model.model, + self.active_batch, + max_steps=self.max_steps, + ) + self._sync_device() + self._add_decode_time(active_request_ids, time.perf_counter() - decode_start) + self._enqueue_finalize_finished(step_finished) + except Exception as exc: + self._finalize_error(active_request_ids, str(exc)) + self.active_batch = None + continue + + if not pending_jobs: + time.sleep(self.micro_batch_wait_s) + + +def set_scheduler_seed(seed: int): + if seed in ["", None]: + return + seed = int(seed) + if seed < 0: + return + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int): + def handle_pack_ogg(): + with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: + audio_file.write(data) + + stack_size = 4096 * 4096 + try: + threading.stack_size(stack_size) + pack_ogg_thread = threading.Thread(target=handle_pack_ogg) + pack_ogg_thread.start() + pack_ogg_thread.join() + except (RuntimeError, ValueError): + handle_pack_ogg() + return io_buffer + + +def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int): + io_buffer.write(data.tobytes()) + return io_buffer + + +def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int): + io_buffer = BytesIO() + sf.write(io_buffer, data, rate, format="wav") + return io_buffer + + +def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int): + process = subprocess.Popen( + [ + "ffmpeg", + "-f", + "s16le", + "-ar", + str(rate), + "-ac", + "1", + "-i", + "pipe:0", + "-c:a", + "aac", + "-b:a", + "192k", + "-vn", + "-f", + "adts", + "pipe:1", + ], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + out, _ = process.communicate(input=data.tobytes()) + io_buffer.write(out) + return io_buffer + + +def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str): + if media_type == "ogg": + io_buffer = pack_ogg(io_buffer, data, rate) + elif media_type == "aac": + io_buffer = pack_aac(io_buffer, data, rate) + elif media_type == "wav": + io_buffer = pack_wav(io_buffer, data, rate) + else: + io_buffer = pack_raw(io_buffer, data, rate) + io_buffer.seek(0) + return io_buffer + + +def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000): + wav_buf = BytesIO() + with wave.open(wav_buf, "wb") as vfout: + vfout.setnchannels(channels) + vfout.setsampwidth(sample_width) + vfout.setframerate(sample_rate) + vfout.writeframes(frame_input) + wav_buf.seek(0) + return wav_buf.read() + + +class UnifiedTTSEngine: + def __init__( + self, + tts: TTS, + cut_method_names: Sequence[str], + control_callbacks: RuntimeControlCallbacks | None = None, + max_steps: int = 1500, + micro_batch_wait_ms: int = 5, + ) -> None: + self.tts = tts + self.cut_method_names = set(cut_method_names) + self.control_callbacks = control_callbacks or RuntimeControlCallbacks() + self.reference_registry = ReferenceRegistry() + self.model_registry = ModelRegistry( + t2s_weights_path=str(self.tts.configs.t2s_weights_path), + vits_weights_path=str(self.tts.configs.vits_weights_path), + ) + self.scheduler_worker = UnifiedSchedulerWorker(tts, max_steps=max_steps, micro_batch_wait_ms=micro_batch_wait_ms) + self.direct_tts_lock = threading.RLock() + self.management_lock = threading.RLock() + + def _normalize_lang(self, value: str | None) -> str | None: + if value in [None, ""]: + return value + return str(value).lower() + + def _apply_default_reference(self, req: dict) -> dict: + normalized = dict(req) + default_ref = self.reference_registry.get_default() + if normalized.get("ref_audio_path") in [None, ""] and default_ref.ref_audio_path not in [None, ""]: + normalized["ref_audio_path"] = default_ref.ref_audio_path + if "text_lang" in normalized: + normalized["text_lang"] = self._normalize_lang(normalized.get("text_lang")) + if "prompt_lang" in normalized: + normalized["prompt_lang"] = self._normalize_lang(normalized.get("prompt_lang")) + return normalized + + def check_params(self, req: dict) -> Optional[str]: + text = req.get("text", "") + text_lang = req.get("text_lang", "") + ref_audio_path = req.get("ref_audio_path", "") + media_type = req.get("media_type", "wav") + prompt_lang = req.get("prompt_lang", "") + text_split_method = req.get("text_split_method", "cut5") + + if ref_audio_path in [None, ""]: + return "ref_audio_path is required" + if text in [None, ""]: + return "text is required" + if text_lang in [None, ""]: + return "text_lang is required" + if text_lang.lower() not in self.tts.configs.languages: + return f"text_lang: {text_lang} is not supported in version {self.tts.configs.version}" + if prompt_lang in [None, ""]: + return "prompt_lang is required" + if prompt_lang.lower() not in self.tts.configs.languages: + return f"prompt_lang: {prompt_lang} is not supported in version {self.tts.configs.version}" + if media_type not in ["wav", "raw", "ogg", "aac"]: + return f"media_type: {media_type} is not supported" + if text_split_method not in self.cut_method_names: + return f"text_split_method:{text_split_method} is not supported" + return None + + @staticmethod + def _normalize_streaming_mode(req: dict) -> dict: + normalized = dict(req) + streaming_mode = normalized.get("streaming_mode", False) + return_fragment = normalized.get("return_fragment", False) + if streaming_mode is False: + normalized["streaming_mode"] = False + normalized["return_fragment"] = False + normalized["fixed_length_chunk"] = False + elif streaming_mode == 0: + normalized["streaming_mode"] = False + normalized["return_fragment"] = False + normalized["fixed_length_chunk"] = False + elif streaming_mode == 1 or streaming_mode is True: + normalized["streaming_mode"] = False + normalized["return_fragment"] = True + normalized["fixed_length_chunk"] = False + elif streaming_mode == 2: + normalized["streaming_mode"] = True + normalized["return_fragment"] = False + normalized["fixed_length_chunk"] = False + elif streaming_mode == 3: + normalized["streaming_mode"] = True + normalized["return_fragment"] = False + normalized["fixed_length_chunk"] = True + else: + raise ValueError("the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)") + normalized["response_streaming"] = bool(normalized["streaming_mode"] or normalized["return_fragment"] or return_fragment) + return normalized + + def _iter_direct_tts_bytes(self, req: dict) -> Generator[bytes, None, None]: + media_type = req["media_type"] + with self.direct_tts_lock: + tts_generator = self.tts.run(req) + first_chunk = True + current_media_type = media_type + for sr, chunk in tts_generator: + if first_chunk and media_type == "wav": + yield wave_header_chunk(sample_rate=sr) + current_media_type = "raw" + first_chunk = False + yield pack_audio(BytesIO(), chunk, sr, current_media_type).getvalue() + + def run_direct_tts(self, req: dict) -> DirectTTSExecution: + normalized = self._normalize_streaming_mode(self._apply_default_reference(req)) + error = self.check_params(normalized) + if error is not None: + raise ValueError(error) + media_type = normalized.get("media_type", "wav") + if normalized["response_streaming"]: + return DirectTTSExecution( + media_type=media_type, + streaming=True, + audio_generator=self._iter_direct_tts_bytes(normalized), + ) + with self.direct_tts_lock: + tts_generator = self.tts.run(normalized) + sr, audio_data = next(tts_generator) + return DirectTTSExecution( + media_type=media_type, + streaming=False, + audio_bytes=pack_audio(BytesIO(), audio_data, sr, media_type).getvalue(), + ) + + def build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]: + specs: List[SchedulerRequestSpec] = [] + for index, payload in enumerate(request_items): + req = self._apply_default_reference( + { + "text": payload["text"], + "text_lang": self._normalize_lang(payload["text_lang"]), + "ref_audio_path": payload["ref_audio_path"], + "aux_ref_audio_paths": None, + "prompt_text": payload["prompt_text"], + "prompt_lang": self._normalize_lang(payload["prompt_lang"]), + "top_k": payload["top_k"], + "top_p": payload["top_p"], + "temperature": payload["temperature"], + "text_split_method": "cut5", + "batch_size": 1, + "batch_threshold": 0.75, + "speed_factor": 1.0, + "split_bucket": False, + "fragment_interval": 0.3, + "seed": -1, + "media_type": "wav", + "streaming_mode": False, + "parallel_infer": False, + "repetition_penalty": payload["repetition_penalty"], + "sample_steps": 32, + "super_sampling": False, + "overlap_length": 2, + "min_chunk_length": 16, + } + ) + error = self.check_params(req) + if error is not None: + raise ValueError(f"request[{index}] 参数非法: {error}") + specs.append( + SchedulerRequestSpec( + request_id=payload.get("request_id") or f"req_{index:03d}", + ref_audio_path=Path(req["ref_audio_path"]), + prompt_text=payload["prompt_text"], + prompt_lang=req["prompt_lang"], + text=payload["text"], + text_lang=req["text_lang"], + top_k=int(payload["top_k"]), + top_p=float(payload["top_p"]), + temperature=float(payload["temperature"]), + repetition_penalty=float(payload["repetition_penalty"]), + early_stop_num=int(payload.get("early_stop_num", -1)), + ready_step=int(payload.get("ready_step", 0)), + ) + ) + return specs + + def build_scheduler_submit_spec(self, payload: dict) -> SchedulerRequestSpec: + request_id = payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}" + req = self._apply_default_reference( + { + "text": payload["text"], + "text_lang": self._normalize_lang(payload["text_lang"]), + "ref_audio_path": payload["ref_audio_path"], + "aux_ref_audio_paths": None, + "prompt_text": payload["prompt_text"], + "prompt_lang": self._normalize_lang(payload["prompt_lang"]), + "top_k": payload["top_k"], + "top_p": payload["top_p"], + "temperature": payload["temperature"], + "text_split_method": "cut5", + "batch_size": 1, + "batch_threshold": 0.75, + "speed_factor": float(payload["speed_factor"]), + "split_bucket": False, + "fragment_interval": 0.3, + "seed": -1, + "media_type": payload["media_type"], + "streaming_mode": False, + "parallel_infer": False, + "repetition_penalty": payload["repetition_penalty"], + "sample_steps": int(payload["sample_steps"]), + "super_sampling": False, + "overlap_length": 2, + "min_chunk_length": 16, + } + ) + error = self.check_params(req) + if error is not None: + raise ValueError(f"request 参数非法: {error}") + return SchedulerRequestSpec( + request_id=request_id, + ref_audio_path=Path(req["ref_audio_path"]), + prompt_text=payload["prompt_text"], + prompt_lang=req["prompt_lang"], + text=payload["text"], + text_lang=req["text_lang"], + top_k=int(payload["top_k"]), + top_p=float(payload["top_p"]), + temperature=float(payload["temperature"]), + repetition_penalty=float(payload["repetition_penalty"]), + early_stop_num=int(payload.get("early_stop_num", -1)), + ready_step=0, + ) + + @staticmethod + def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: + return [ + { + "request_id": state.request_id, + "ready_step": int(state.ready_step), + "ref_audio_path": str(state.ref_audio_path), + "prompt_semantic_len": int(state.prompt_semantic.shape[0]), + "all_phone_len": int(state.all_phones.shape[0]), + "bert_len": int(state.all_bert_features.shape[-1]), + "norm_text": state.norm_text, + } + for state in states + ] + + @staticmethod + def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: + return [ + { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + } + for item in items + ] + + async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution: + set_scheduler_seed(seed) + specs = self.build_scheduler_request_specs(request_items) + states = await self.scheduler_worker.prepare_states_batch_async(specs) + finished = run_scheduler_continuous(self.tts.t2s_model.model, states, max_steps=int(max_steps)) + return SchedulerDebugExecution( + payload={ + "message": "success", + "request_count": len(states), + "max_steps": int(max_steps), + "requests": self.summarize_scheduler_states(states), + "finished": self.summarize_scheduler_finished(finished), + } + ) + + async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution: + request_start = time.perf_counter() + prepare_start = request_start + spec = self.build_scheduler_submit_spec(payload) + spec_ready_at = time.perf_counter() + prepare_spec_build_ms = max(0.0, (spec_ready_at - prepare_start) * 1000.0) + state, prepare_exec_started_at, prepare_exec_finished_at = await self.scheduler_worker.prepare_state_profiled_async( + spec, + spec_ready_at, + ) + prepare_wall_ms = max(0.0, (prepare_exec_finished_at - spec_ready_at) * 1000.0) + prepare_executor_queue_ms = max(0.0, (prepare_exec_started_at - spec_ready_at) * 1000.0) + prepare_executor_run_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0) + prepare_profile = dict(state.prepare_profile) + prepare_profile_total_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms)) + prepare_profile_wall_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms)) + prepare_other_ms = max(0.0, prepare_wall_ms - prepare_spec_build_ms - prepare_executor_queue_ms - prepare_executor_run_ms) + api_after_prepare_start = time.perf_counter() + loop = asyncio.get_running_loop() + done_future = loop.create_future() + job = self.scheduler_worker.submit( + state=state, + speed_factor=float(payload["speed_factor"]), + sample_steps=int(payload["sample_steps"]), + media_type=str(payload["media_type"]), + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=loop, + done_future=done_future, + ) + api_after_prepare_ms = max(0.0, (time.perf_counter() - api_after_prepare_start) * 1000.0) + await asyncio.wait_for(done_future, timeout=float(payload.get("timeout_sec", 30.0))) + wait_return_at = time.perf_counter() + if job.error is not None: + raise RuntimeError(job.error) + if job.audio_data is None or job.sample_rate is None or job.result is None: + raise RuntimeError(f"{job.request_id} finished without audio result") + pack_start = time.perf_counter() + audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue() + pack_end = time.perf_counter() + pack_ms = (pack_end - pack_start) * 1000.0 + api_wait_result_ms = 0.0 + if job.result_ready_time is not None: + api_wait_result_ms = max(0.0, (wait_return_at - job.result_ready_time) * 1000.0) + worker_total_ms = float(job.result["worker_total_ms"]) if job.result is not None else 0.0 + headers = { + "X-Request-Id": job.request_id, + "X-Semantic-Len": str(job.result["semantic_len"]) if job.result is not None else "0", + "X-Finish-Reason": job.result["finish_reason"] if job.result is not None else "unknown", + "X-Queue-Wait-Ms": f"{float(job.result['queue_wait_ms']):.3f}" if job.result is not None else "0.000", + "X-Prepare-Ms": f"{prepare_wall_ms:.3f}", + "X-Prepare-Wall-Ms": f"{prepare_wall_ms:.3f}", + "X-Prepare-Spec-Build-Ms": f"{prepare_spec_build_ms:.3f}", + "X-Prepare-Executor-Queue-Ms": f"{prepare_executor_queue_ms:.3f}", + "X-Prepare-Admission-Wait-Ms": ( + f"{float(job.result['prepare_profile'].get('prepare_admission_wait_ms', 0.0)):.3f}" + if job.result is not None + else "0.000" + ), + "X-Prepare-Executor-Run-Ms": f"{prepare_executor_run_ms:.3f}", + "X-Prepare-Profile-Total-Ms": f"{prepare_profile_total_ms:.3f}", + "X-Prepare-Profile-Wall-Ms": f"{prepare_profile_wall_ms:.3f}", + "X-Prepare-Other-Ms": f"{prepare_other_ms:.3f}", + "X-Api-After-Prepare-Ms": f"{api_after_prepare_ms:.3f}", + "X-Prefill-Ms": f"{float(job.result['prefill_ms']):.3f}" if job.result is not None else "0.000", + "X-Merge-Ms": f"{float(job.result['merge_ms']):.3f}" if job.result is not None else "0.000", + "X-Decode-Ms": f"{float(job.result['decode_ms']):.3f}" if job.result is not None else "0.000", + "X-Finalize-Wait-Ms": f"{float(job.result['finalize_wait_ms']):.3f}" if job.result is not None else "0.000", + "X-Synth-Ms": f"{float(job.result['synth_ms']):.3f}" if job.result is not None else "0.000", + "X-Worker-Residual-Ms": f"{float(job.result['worker_residual_ms']):.3f}" if job.result is not None else "0.000", + "X-Worker-Other-Ms": f"{float(job.result['worker_other_ms']):.3f}" if job.result is not None else "0.000", + "X-Pack-Ms": f"{pack_ms:.3f}", + "X-Worker-Total-Ms": f"{float(job.result['worker_total_ms']):.3f}" if job.result is not None else "0.000", + "X-Api-Wait-Result-Ms": f"{api_wait_result_ms:.3f}", + "X-Decode-Steps": str(int(job.result["decode_steps"])) if job.result is not None else "0", + "X-Sample-Rate": str(int(job.sample_rate)), + } + prepare_profile = job.result.get("prepare_profile", {}) if job.result is not None else {} + if job.result is not None: + headers.update( + { + "X-Prepare-Prompt-Text-Ms": f"{float(prepare_profile.get('prompt_text_features_ms', 0.0)):.3f}", + "X-Prepare-Target-Text-Ms": f"{float(prepare_profile.get('text_features_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Text-CPU-Preprocess-Ms": f"{float(prepare_profile.get('prompt_text_cpu_preprocess_ms', 0.0)):.3f}", + "X-Prepare-Target-Text-CPU-Preprocess-Ms": f"{float(prepare_profile.get('text_cpu_preprocess_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Text-CPU-Queue-Ms": f"{float(prepare_profile.get('prompt_text_cpu_queue_ms', 0.0)):.3f}", + "X-Prepare-Target-Text-CPU-Queue-Ms": f"{float(prepare_profile.get('text_cpu_queue_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Text-Feature-Queue-Ms": f"{float(prepare_profile.get('prompt_text_feature_queue_ms', 0.0)):.3f}", + "X-Prepare-Target-Text-Feature-Queue-Ms": f"{float(prepare_profile.get('text_feature_queue_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_wait_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Wait-Ms": f"{float(prepare_profile.get('text_bert_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Admission-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_admission_wait_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Admission-Wait-Ms": f"{float(prepare_profile.get('text_bert_admission_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Queue-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_queue_wait_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Queue-Wait-Ms": f"{float(prepare_profile.get('text_bert_queue_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Batch-Collect-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_batch_collect_wait_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Batch-Collect-Wait-Ms": f"{float(prepare_profile.get('text_bert_batch_collect_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Forward-Ms": f"{float(prepare_profile.get('prompt_text_bert_forward_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Forward-Ms": f"{float(prepare_profile.get('text_bert_forward_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_enqueue_peak", 0.0))), + "X-Prepare-Target-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_enqueue_peak", 0.0))), + "X-Prepare-Prompt-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_collect_peak", 0.0))), + "X-Prepare-Target-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_collect_peak", 0.0))), + "X-Prepare-Prompt-Bert-High-Pressure-Peak": str(int(prepare_profile.get("prompt_text_bert_high_pressure_mode_peak", 0.0))), + "X-Prepare-Target-Bert-High-Pressure-Peak": str(int(prepare_profile.get("text_bert_high_pressure_mode_peak", 0.0))), + "X-Prepare-Prompt-Bert-Batch-Window-Ms": f"{float(prepare_profile.get('prompt_text_bert_batch_window_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Batch-Window-Ms": f"{float(prepare_profile.get('text_bert_batch_window_ms', 0.0)):.3f}", + "X-Prepare-Text-Pair-Wall-Ms": f"{float(prepare_profile.get('text_feature_pair_ms', 0.0)):.3f}", + "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get('text_cpu_parallel_workers', 0.0))), + "X-Prepare-Audio-Load-Ms": f"{float(prepare_profile.get('audio_load_ms', 0.0)):.3f}", + "X-Prepare-Audio-Stage-Wait-Ms": f"{float(prepare_profile.get('audio_stage_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-Ms": f"{float(prepare_profile.get('prompt_semantic_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-Wait-Ms": f"{float(prepare_profile.get('prompt_semantic_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-CPU-Ms": f"{float(prepare_profile.get('prompt_semantic_cpu_prepare_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-Forward-Ms": f"{float(prepare_profile.get('prompt_semantic_forward_ms', 0.0)):.3f}", + "X-Prepare-Ref-Spec-Ms": f"{float(prepare_profile.get('ref_spec_ms', 0.0)):.3f}", + "X-Prepare-Ref-Spec-Wait-Ms": f"{float(prepare_profile.get('ref_spec_wait_ms', 0.0)):.3f}", + "X-Prepare-Ref-Bundle-Ms": f"{float(prepare_profile.get('ref_audio_bundle_ms', 0.0)):.3f}", + "X-Prepare-Tensorize-Ms": f"{float(prepare_profile.get('tensorize_ms', 0.0)):.3f}", + "X-Prepare-Inflight-On-Enter": str(int(prepare_profile.get('worker_prepare_inflight_on_enter', 0.0))), + "X-Prepare-Inflight-Peak": str(int(prepare_profile.get('worker_prepare_peak_inflight', 0.0))), + } + ) + response_ready_at = time.perf_counter() + response_overhead_ms = max(0.0, (response_ready_at - pack_end) * 1000.0) + request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) + request_other_ms = max( + 0.0, + request_total_ms - prepare_wall_ms - api_after_prepare_ms - worker_total_ms - api_wait_result_ms - pack_ms, + ) + headers["X-Response-Overhead-Ms"] = f"{response_overhead_ms:.3f}" + headers["X-Request-Other-Ms"] = f"{request_other_ms:.3f}" + headers["X-Request-Total-Ms"] = f"{request_total_ms:.3f}" + return SchedulerSubmitExecution(audio_bytes=audio_data, media_type=f"audio/{job.media_type}", headers=headers) + + def get_scheduler_state(self) -> dict: + return self.scheduler_worker.snapshot() + + def get_runtime_state(self) -> dict: + model_state = self.model_registry.snapshot() + default_ref = self.reference_registry.get_default() + scheduler_state = self.get_scheduler_state() + return { + "message": "success", + "default_reference": { + "ref_audio_path": default_ref.ref_audio_path, + "updated_at": default_ref.updated_at, + }, + "model_registry": { + "generation": model_state.generation, + "t2s_generation": model_state.t2s_generation, + "vits_generation": model_state.vits_generation, + "t2s_weights_path": model_state.t2s_weights_path, + "vits_weights_path": model_state.vits_weights_path, + "updated_at": model_state.updated_at, + }, + "worker_state": scheduler_state, + } + + def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None: + if not self.scheduler_worker.wait_until_idle(timeout_sec=timeout_sec): + raise TimeoutError("scheduler worker did not drain before model reload") + + def set_refer_audio(self, refer_audio_path: str | None) -> dict: + if refer_audio_path in [None, ""]: + state = self.reference_registry.clear() + return {"message": "success", "default_ref_audio_path": state.ref_audio_path} + if not os.path.exists(str(refer_audio_path)): + raise FileNotFoundError(f"{refer_audio_path} not exists") + with self.management_lock: + with self.direct_tts_lock: + self.tts.set_ref_audio(str(refer_audio_path)) + state = self.reference_registry.set_default(str(refer_audio_path)) + return {"message": "success", "default_ref_audio_path": state.ref_audio_path} + + def set_gpt_weights(self, weights_path: str) -> dict: + if weights_path in ["", None]: + raise ValueError("gpt weight path is required") + with self.management_lock: + self._wait_for_safe_reload() + with self.direct_tts_lock: + self.tts.init_t2s_weights(weights_path) + self.tts.refresh_runtime_components() + state = self.model_registry.mark_t2s_reload(str(weights_path)) + return {"message": "success", "t2s_generation": state.t2s_generation, "generation": state.generation} + + def set_sovits_weights(self, weights_path: str) -> dict: + if weights_path in ["", None]: + raise ValueError("sovits weight path is required") + with self.management_lock: + self._wait_for_safe_reload() + with self.direct_tts_lock: + self.tts.init_vits_weights(weights_path) + self.tts.refresh_runtime_components() + state = self.model_registry.mark_vits_reload(str(weights_path)) + return {"message": "success", "vits_generation": state.vits_generation, "generation": state.generation} + + def handle_control(self, command: str) -> None: + if command == "restart": + if self.control_callbacks.restart is None: + os.execl(sys.executable, sys.executable, *sys.argv) + self.control_callbacks.restart() + return + if command == "exit": + if self.control_callbacks.exit is None: + os.kill(os.getpid(), signal.SIGTERM) + return + self.control_callbacks.exit() + return + raise ValueError(f"unsupported command: {command}") diff --git a/api_v2.py b/api_v2.py index 21511db3..fb17dfd6 100644 --- a/api_v2.py +++ b/api_v2.py @@ -123,6 +123,7 @@ from io import BytesIO from tools.i18n.i18n import I18nAuto from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names +from GPT_SoVITS.TTS_infer_pack.unified_engine import RuntimeControlCallbacks, UnifiedTTSEngine from pydantic import BaseModel import threading @@ -147,6 +148,14 @@ if config_path in [None, ""]: tts_config = TTS_Config(config_path) print(tts_config) tts_pipeline = TTS(tts_config) +tts_engine = UnifiedTTSEngine( + tts_pipeline, + cut_method_names=cut_method_names, + control_callbacks=RuntimeControlCallbacks( + restart=lambda: os.execl(sys.executable, sys.executable, *argv), + exit=lambda: os.kill(os.getpid(), signal.SIGTERM), + ), +) APP = FastAPI() @@ -377,70 +386,11 @@ async def tts_handle(req: dict): StreamingResponse: audio stream response. """ - streaming_mode = req.get("streaming_mode", False) - return_fragment = req.get("return_fragment", False) - media_type = req.get("media_type", "wav") - - check_res = check_params(req) - if check_res is not None: - return check_res - - if streaming_mode == 0: - streaming_mode = False - return_fragment = False - fixed_length_chunk = False - elif streaming_mode == 1: - streaming_mode = False - return_fragment = True - fixed_length_chunk = False - elif streaming_mode == 2: - streaming_mode = True - return_fragment = False - fixed_length_chunk = False - elif streaming_mode == 3: - streaming_mode = True - return_fragment = False - fixed_length_chunk = True - - else: - return JSONResponse(status_code=400, content={"message": f"the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)"}) - - req["streaming_mode"] = streaming_mode - req["return_fragment"] = return_fragment - req["fixed_length_chunk"] = fixed_length_chunk - - print(f"{streaming_mode} {return_fragment} {fixed_length_chunk}") - - streaming_mode = streaming_mode or return_fragment - - try: - tts_generator = tts_pipeline.run(req) - - if streaming_mode: - - def streaming_generator(tts_generator: Generator, media_type: str): - if_frist_chunk = True - for sr, chunk in tts_generator: - if if_frist_chunk and media_type == "wav": - yield wave_header_chunk(sample_rate=sr) - media_type = "raw" - if_frist_chunk = False - yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue() - - # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}" - return StreamingResponse( - streaming_generator( - tts_generator, - media_type, - ), - media_type=f"audio/{media_type}", - ) - - else: - sr, audio_data = next(tts_generator) - audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() - return Response(audio_data, media_type=f"audio/{media_type}") + result = tts_engine.run_direct_tts(req) + if result.streaming: + return StreamingResponse(result.audio_generator, media_type=f"audio/{result.media_type}") + return Response(result.audio_bytes, media_type=f"audio/{result.media_type}") except Exception as e: return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(e)}) @@ -449,7 +399,11 @@ async def tts_handle(req: dict): async def control(command: str = None): if command is None: return JSONResponse(status_code=400, content={"message": "command is required"}) - handle_control(command) + try: + tts_engine.handle_control(command) + return JSONResponse(status_code=200, content={"message": "success"}) + except Exception as e: + return JSONResponse(status_code=400, content={"message": "control failed", "Exception": str(e)}) @APP.get("/tts") @@ -517,10 +471,10 @@ async def tts_post_endpoint(request: TTS_Request): @APP.get("/set_refer_audio") async def set_refer_aduio(refer_audio_path: str = None): try: - tts_pipeline.set_ref_audio(refer_audio_path) + payload = tts_engine.set_refer_audio(refer_audio_path) except Exception as e: return JSONResponse(status_code=400, content={"message": "set refer audio failed", "Exception": str(e)}) - return JSONResponse(status_code=200, content={"message": "success"}) + return JSONResponse(status_code=200, content=payload) # @APP.post("/set_refer_audio") @@ -545,24 +499,19 @@ async def set_refer_aduio(refer_audio_path: str = None): @APP.get("/set_gpt_weights") async def set_gpt_weights(weights_path: str = None): try: - if weights_path in ["", None]: - return JSONResponse(status_code=400, content={"message": "gpt weight path is required"}) - tts_pipeline.init_t2s_weights(weights_path) + payload = tts_engine.set_gpt_weights(weights_path) except Exception as e: return JSONResponse(status_code=400, content={"message": "change gpt weight failed", "Exception": str(e)}) - - return JSONResponse(status_code=200, content={"message": "success"}) + return JSONResponse(status_code=200, content=payload) @APP.get("/set_sovits_weights") async def set_sovits_weights(weights_path: str = None): try: - if weights_path in ["", None]: - return JSONResponse(status_code=400, content={"message": "sovits weight path is required"}) - tts_pipeline.init_vits_weights(weights_path) + payload = tts_engine.set_sovits_weights(weights_path) except Exception as e: return JSONResponse(status_code=400, content={"message": "change sovits weight failed", "Exception": str(e)}) - return JSONResponse(status_code=200, content={"message": "success"}) + return JSONResponse(status_code=200, content=payload) if __name__ == "__main__": diff --git a/api_v3.py b/api_v3.py index 74bc7ac8..37d66977 100644 --- a/api_v3.py +++ b/api_v3.py @@ -144,6 +144,7 @@ from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( run_prefill_active_batch, run_scheduler_continuous, ) +from GPT_SoVITS.TTS_infer_pack.unified_engine import RuntimeControlCallbacks, UnifiedTTSEngine from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names from pydantic import BaseModel import threading @@ -169,6 +170,14 @@ if config_path in [None, ""]: tts_config = TTS_Config(config_path) print(tts_config) tts_pipeline = TTS(tts_config) +tts_engine = UnifiedTTSEngine( + tts_pipeline, + cut_method_names=cut_method_names, + control_callbacks=RuntimeControlCallbacks( + restart=lambda: os.execl(sys.executable, sys.executable, *argv), + exit=lambda: os.kill(os.getpid(), signal.SIGTERM), + ), +) APP = FastAPI() @@ -805,7 +814,7 @@ class SchedulerDebugWorker: time.sleep(self.micro_batch_wait_s) -scheduler_debug_worker = SchedulerDebugWorker(tts_pipeline) +scheduler_debug_worker = tts_engine.scheduler_worker def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int): @@ -1116,20 +1125,12 @@ def build_scheduler_submit_spec(request: Scheduler_Submit_Request) -> SchedulerR async def tts_scheduler_debug_handle(request: Scheduler_Debug_Request): try: - set_scheduler_seed(request.seed) - specs = build_scheduler_request_specs(request.requests) - states = await scheduler_debug_worker.prepare_states_batch_async(specs) - finished = run_scheduler_continuous(tts_pipeline.t2s_model.model, states, max_steps=int(request.max_steps)) - return JSONResponse( - status_code=200, - content={ - "message": "success", - "request_count": len(states), - "max_steps": int(request.max_steps), - "requests": summarize_scheduler_states(states), - "finished": summarize_scheduler_finished(finished), - }, + result = await tts_engine.run_scheduler_debug( + request_items=[item.dict() for item in request.requests], + max_steps=int(request.max_steps), + seed=int(request.seed), ) + return JSONResponse(status_code=200, content=result.payload) except Exception as e: return JSONResponse( status_code=400, @@ -1139,206 +1140,8 @@ async def tts_scheduler_debug_handle(request: Scheduler_Debug_Request): async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): try: - request_start = time.perf_counter() - prepare_start = request_start - spec = build_scheduler_submit_spec(request) - spec_ready_at = time.perf_counter() - prepare_spec_build_ms = max(0.0, (spec_ready_at - prepare_start) * 1000.0) - state, prepare_exec_started_at, prepare_exec_finished_at = await scheduler_debug_worker.prepare_state_profiled_async( - spec, - spec_ready_at, - ) - prepare_end = time.perf_counter() - prepare_wall_ms = (prepare_end - prepare_start) * 1000.0 - prepare_profile_total_ms = float(state.prepare_profile.get("total_ms", prepare_wall_ms)) - prepare_profile_wall_ms = float(state.prepare_profile.get("wall_total_ms", prepare_profile_total_ms)) - prepare_executor_queue_ms = float( - state.prepare_profile.get("executor_queue_ms", max(0.0, (prepare_exec_started_at - spec_ready_at) * 1000.0)) - ) - prepare_executor_run_ms = float( - state.prepare_profile.get( - "executor_run_wall_ms", - max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0), - ) - ) - prepare_other_ms = max( - 0.0, - prepare_wall_ms - prepare_spec_build_ms - prepare_executor_queue_ms - prepare_profile_wall_ms, - ) - loop = asyncio.get_running_loop() - done_future = loop.create_future() - job = scheduler_debug_worker.submit( - state, - speed_factor=float(request.speed_factor), - sample_steps=int(request.sample_steps), - media_type=request.media_type, - prepare_wall_ms=prepare_wall_ms, - prepare_profile_total_ms=prepare_profile_total_ms, - done_loop=loop, - done_future=done_future, - ) - api_after_prepare_ms = max(0.0, (job.enqueue_time - prepare_end) * 1000.0) - timeout_ok = False - try: - await asyncio.wait_for(asyncio.shield(done_future), timeout=float(request.timeout_sec)) - timeout_ok = True - except asyncio.TimeoutError: - timeout_ok = False - wait_return_at = time.perf_counter() - if not timeout_ok: - return JSONResponse( - status_code=202, - content={ - "message": "queued", - "request_id": job.request_id, - "timings": { - "prepare_ms": prepare_wall_ms, - "prepare_wall_ms": prepare_wall_ms, - "prepare_profile_total_ms": prepare_profile_total_ms, - "api_after_prepare_ms": api_after_prepare_ms, - "request_elapsed_ms": max(0.0, (time.perf_counter() - request_start) * 1000.0), - }, - "worker_state": scheduler_debug_worker.get_state(), - }, - ) - if job.error is not None: - return JSONResponse( - status_code=400, - content={"message": "scheduler submit failed", "request_id": job.request_id, "Exception": job.error}, - ) - if job.audio_data is None or job.sample_rate is None: - return JSONResponse( - status_code=500, - content={ - "message": "scheduler submit failed", - "request_id": job.request_id, - "Exception": "job finished without audio payload", - }, - ) - pack_start = time.perf_counter() - audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue() - pack_end = time.perf_counter() - pack_ms = (pack_end - pack_start) * 1000.0 - job.pack_ms = pack_ms - api_wait_result_ms = 0.0 - if job.result_ready_time is not None: - api_wait_result_ms = max(0.0, (wait_return_at - job.result_ready_time) * 1000.0) - worker_total_ms = float(job.result["worker_total_ms"]) if job.result is not None else 0.0 - headers = { - "X-Request-Id": job.request_id, - "X-Semantic-Len": str(job.result["semantic_len"]) if job.result is not None else "0", - "X-Finish-Reason": job.result["finish_reason"] if job.result is not None else "unknown", - "X-Queue-Wait-Ms": ( - f"{float(job.result['queue_wait_ms']):.3f}" if job.result is not None else "0.000" - ), - "X-Prepare-Ms": f"{prepare_wall_ms:.3f}", - "X-Prepare-Wall-Ms": f"{prepare_wall_ms:.3f}", - "X-Prepare-Spec-Build-Ms": f"{prepare_spec_build_ms:.3f}", - "X-Prepare-Executor-Queue-Ms": f"{prepare_executor_queue_ms:.3f}", - "X-Prepare-Admission-Wait-Ms": ( - f"{float(job.result['prepare_profile'].get('prepare_admission_wait_ms', 0.0)):.3f}" - if job.result is not None - else "0.000" - ), - "X-Prepare-Executor-Run-Ms": f"{prepare_executor_run_ms:.3f}", - "X-Prepare-Profile-Total-Ms": f"{prepare_profile_total_ms:.3f}", - "X-Prepare-Profile-Wall-Ms": f"{prepare_profile_wall_ms:.3f}", - "X-Prepare-Other-Ms": f"{prepare_other_ms:.3f}", - "X-Api-After-Prepare-Ms": f"{api_after_prepare_ms:.3f}", - "X-Prefill-Ms": f"{float(job.result['prefill_ms']):.3f}" if job.result is not None else "0.000", - "X-Merge-Ms": f"{float(job.result['merge_ms']):.3f}" if job.result is not None else "0.000", - "X-Decode-Ms": f"{float(job.result['decode_ms']):.3f}" if job.result is not None else "0.000", - "X-Finalize-Wait-Ms": f"{float(job.result['finalize_wait_ms']):.3f}" if job.result is not None else "0.000", - "X-Synth-Ms": f"{float(job.result['synth_ms']):.3f}" if job.result is not None else "0.000", - "X-Worker-Residual-Ms": f"{float(job.result['worker_residual_ms']):.3f}" if job.result is not None else "0.000", - "X-Worker-Other-Ms": f"{float(job.result['worker_other_ms']):.3f}" if job.result is not None else "0.000", - "X-Pack-Ms": f"{pack_ms:.3f}", - "X-Worker-Total-Ms": ( - f"{float(job.result['worker_total_ms']):.3f}" if job.result is not None else "0.000" - ), - "X-Api-Wait-Result-Ms": f"{api_wait_result_ms:.3f}", - "X-Decode-Steps": str(job.result["decode_steps"]) if job.result is not None else "0", - } - if job.result is not None: - prepare_profile = job.result.get("prepare_profile", {}) - headers.update( - { - "X-Prepare-Prompt-Text-Ms": f"{float(prepare_profile.get('prompt_text_features_ms', 0.0)):.3f}", - "X-Prepare-Target-Text-Ms": f"{float(prepare_profile.get('text_features_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Text-CPU-Preprocess-Ms": f"{float(prepare_profile.get('prompt_text_cpu_preprocess_ms', 0.0)):.3f}", - "X-Prepare-Target-Text-CPU-Preprocess-Ms": f"{float(prepare_profile.get('text_cpu_preprocess_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Text-CPU-Queue-Ms": f"{float(prepare_profile.get('prompt_text_cpu_queue_ms', 0.0)):.3f}", - "X-Prepare-Target-Text-CPU-Queue-Ms": f"{float(prepare_profile.get('text_cpu_queue_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Text-Feature-Queue-Ms": f"{float(prepare_profile.get('prompt_text_feature_queue_ms', 0.0)):.3f}", - "X-Prepare-Target-Text-Feature-Queue-Ms": f"{float(prepare_profile.get('text_feature_queue_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Bert-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_wait_ms', 0.0)):.3f}", - "X-Prepare-Target-Bert-Wait-Ms": f"{float(prepare_profile.get('text_bert_wait_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Bert-Admission-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_admission_wait_ms', 0.0)):.3f}", - "X-Prepare-Target-Bert-Admission-Wait-Ms": f"{float(prepare_profile.get('text_bert_admission_wait_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Bert-Queue-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_queue_wait_ms', 0.0)):.3f}", - "X-Prepare-Target-Bert-Queue-Wait-Ms": f"{float(prepare_profile.get('text_bert_queue_wait_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Bert-Batch-Collect-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_batch_collect_wait_ms', 0.0)):.3f}", - "X-Prepare-Target-Bert-Batch-Collect-Wait-Ms": f"{float(prepare_profile.get('text_bert_batch_collect_wait_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Bert-Forward-Ms": f"{float(prepare_profile.get('prompt_text_bert_forward_ms', 0.0)):.3f}", - "X-Prepare-Target-Bert-Forward-Ms": f"{float(prepare_profile.get('text_bert_forward_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Bert-Pending-On-Enqueue-Peak": str( - int(prepare_profile.get("prompt_text_bert_pending_depth_on_enqueue_peak", 0.0)) - ), - "X-Prepare-Target-Bert-Pending-On-Enqueue-Peak": str( - int(prepare_profile.get("text_bert_pending_depth_on_enqueue_peak", 0.0)) - ), - "X-Prepare-Prompt-Bert-Pending-On-Collect-Peak": str( - int(prepare_profile.get("prompt_text_bert_pending_depth_on_collect_peak", 0.0)) - ), - "X-Prepare-Target-Bert-Pending-On-Collect-Peak": str( - int(prepare_profile.get("text_bert_pending_depth_on_collect_peak", 0.0)) - ), - "X-Prepare-Prompt-Bert-High-Pressure-Peak": str( - int(prepare_profile.get("prompt_text_bert_high_pressure_mode_peak", 0.0)) - ), - "X-Prepare-Target-Bert-High-Pressure-Peak": str( - int(prepare_profile.get("text_bert_high_pressure_mode_peak", 0.0)) - ), - "X-Prepare-Prompt-Bert-Batch-Size-Peak": str( - int(prepare_profile.get("prompt_text_bert_batch_size_peak", 0.0)) - ), - "X-Prepare-Target-Bert-Batch-Size-Peak": str( - int(prepare_profile.get("text_bert_batch_size_peak", 0.0)) - ), - "X-Prepare-Prompt-Bert-Batch-Window-Ms": f"{float(prepare_profile.get('prompt_text_bert_batch_window_ms', 0.0)):.3f}", - "X-Prepare-Target-Bert-Batch-Window-Ms": f"{float(prepare_profile.get('text_bert_batch_window_ms', 0.0)):.3f}", - "X-Prepare-Text-Pair-Wall-Ms": f"{float(prepare_profile.get('text_feature_pair_ms', 0.0)):.3f}", - "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))), - "X-Prepare-Audio-Load-Ms": f"{float(prepare_profile.get('audio_load_ms', 0.0)):.3f}", - "X-Prepare-Audio-Stage-Wait-Ms": f"{float(prepare_profile.get('audio_stage_wait_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Semantic-Ms": f"{float(prepare_profile.get('prompt_semantic_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Semantic-Wait-Ms": f"{float(prepare_profile.get('prompt_semantic_wait_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Semantic-CPU-Ms": f"{float(prepare_profile.get('prompt_semantic_cpu_prepare_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Semantic-Forward-Ms": f"{float(prepare_profile.get('prompt_semantic_forward_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Semantic-Batch-Size": str( - int(prepare_profile.get("prompt_semantic_batch_size", 0.0)) - ), - "X-Prepare-Ref-Spec-Ms": f"{float(prepare_profile.get('ref_spec_ms', 0.0)):.3f}", - "X-Prepare-Ref-Spec-Wait-Ms": f"{float(prepare_profile.get('ref_spec_wait_ms', 0.0)):.3f}", - "X-Prepare-Ref-Bundle-Ms": f"{float(prepare_profile.get('ref_audio_bundle_ms', 0.0)):.3f}", - "X-Prepare-Tensorize-Ms": f"{float(prepare_profile.get('tensorize_ms', 0.0)):.3f}", - "X-Prepare-Inflight-On-Enter": str( - int(prepare_profile.get("worker_prepare_inflight_on_enter", 0.0)) - ), - "X-Prepare-Inflight-Peak": str(int(prepare_profile.get("worker_prepare_peak_inflight", 0.0))), - } - ) - response_ready_at = time.perf_counter() - response_overhead_ms = max(0.0, (response_ready_at - pack_end) * 1000.0) - request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) - request_other_ms = max( - 0.0, - request_total_ms - prepare_wall_ms - api_after_prepare_ms - worker_total_ms - api_wait_result_ms - pack_ms, - ) - headers["X-Response-Overhead-Ms"] = f"{response_overhead_ms:.3f}" - headers["X-Request-Other-Ms"] = f"{request_other_ms:.3f}" - headers["X-Request-Total-Ms"] = f"{request_total_ms:.3f}" - return Response(audio_data, media_type=f"audio/{job.media_type}", headers=headers) + result = await tts_engine.run_scheduler_submit(request.dict()) + return Response(result.audio_bytes, media_type=result.media_type, headers=result.headers) except Exception as e: return JSONResponse( status_code=400, @@ -1381,70 +1184,11 @@ async def tts_handle(req: dict): StreamingResponse: audio stream response. """ - streaming_mode = req.get("streaming_mode", False) - return_fragment = req.get("return_fragment", False) - media_type = req.get("media_type", "wav") - - check_res = check_params(req) - if check_res is not None: - return check_res - - if streaming_mode == 0: - streaming_mode = False - return_fragment = False - fixed_length_chunk = False - elif streaming_mode == 1: - streaming_mode = False - return_fragment = True - fixed_length_chunk = False - elif streaming_mode == 2: - streaming_mode = True - return_fragment = False - fixed_length_chunk = False - elif streaming_mode == 3: - streaming_mode = True - return_fragment = False - fixed_length_chunk = True - - else: - return JSONResponse(status_code=400, content={"message": f"the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)"}) - - req["streaming_mode"] = streaming_mode - req["return_fragment"] = return_fragment - req["fixed_length_chunk"] = fixed_length_chunk - - print(f"{streaming_mode} {return_fragment} {fixed_length_chunk}") - - streaming_mode = streaming_mode or return_fragment - - try: - tts_generator = tts_pipeline.run(req) - - if streaming_mode: - - def streaming_generator(tts_generator: Generator, media_type: str): - if_frist_chunk = True - for sr, chunk in tts_generator: - if if_frist_chunk and media_type == "wav": - yield wave_header_chunk(sample_rate=sr) - media_type = "raw" - if_frist_chunk = False - yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue() - - # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}" - return StreamingResponse( - streaming_generator( - tts_generator, - media_type, - ), - media_type=f"audio/{media_type}", - ) - - else: - sr, audio_data = next(tts_generator) - audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() - return Response(audio_data, media_type=f"audio/{media_type}") + result = tts_engine.run_direct_tts(req) + if result.streaming: + return StreamingResponse(result.audio_generator, media_type=f"audio/{result.media_type}") + return Response(result.audio_bytes, media_type=f"audio/{result.media_type}") except Exception as e: return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(e)}) @@ -1453,7 +1197,11 @@ async def tts_handle(req: dict): async def control(command: str = None): if command is None: return JSONResponse(status_code=400, content={"message": "command is required"}) - handle_control(command) + try: + tts_engine.handle_control(command) + return JSONResponse(status_code=200, content={"message": "success"}) + except Exception as e: + return JSONResponse(status_code=400, content={"message": "control failed", "Exception": str(e)}) @APP.get("/tts") @@ -1530,16 +1278,16 @@ async def tts_scheduler_submit_endpoint(request: Scheduler_Submit_Request): @APP.get("/tts_scheduler_state") async def tts_scheduler_state_endpoint(): - return JSONResponse(status_code=200, content={"message": "success", "worker_state": scheduler_debug_worker.get_state()}) + return JSONResponse(status_code=200, content=tts_engine.get_runtime_state()) @APP.get("/set_refer_audio") async def set_refer_aduio(refer_audio_path: str = None): try: - tts_pipeline.set_ref_audio(refer_audio_path) + payload = tts_engine.set_refer_audio(refer_audio_path) except Exception as e: return JSONResponse(status_code=400, content={"message": "set refer audio failed", "Exception": str(e)}) - return JSONResponse(status_code=200, content={"message": "success"}) + return JSONResponse(status_code=200, content=payload) # @APP.post("/set_refer_audio") @@ -1564,24 +1312,19 @@ async def set_refer_aduio(refer_audio_path: str = None): @APP.get("/set_gpt_weights") async def set_gpt_weights(weights_path: str = None): try: - if weights_path in ["", None]: - return JSONResponse(status_code=400, content={"message": "gpt weight path is required"}) - tts_pipeline.init_t2s_weights(weights_path) + payload = tts_engine.set_gpt_weights(weights_path) except Exception as e: return JSONResponse(status_code=400, content={"message": "change gpt weight failed", "Exception": str(e)}) - - return JSONResponse(status_code=200, content={"message": "success"}) + return JSONResponse(status_code=200, content=payload) @APP.get("/set_sovits_weights") async def set_sovits_weights(weights_path: str = None): try: - if weights_path in ["", None]: - return JSONResponse(status_code=400, content={"message": "sovits weight path is required"}) - tts_pipeline.init_vits_weights(weights_path) + payload = tts_engine.set_sovits_weights(weights_path) except Exception as e: return JSONResponse(status_code=400, content={"message": "change sovits weight failed", "Exception": str(e)}) - return JSONResponse(status_code=200, content={"message": "success"}) + return JSONResponse(status_code=200, content=payload) if __name__ == "__main__": From d1a97fd04d79014d2f5e8fce40de3c6470d6e289 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Tue, 10 Mar 2026 20:46:14 +0800 Subject: [PATCH 10/24] Refactor TTS API to streamline audio processing by removing unused packing functions and optimizing the tts_handle method for asynchronous execution. Update type hints and clean up imports for improved code clarity and maintainability. --- api_v2.py | 176 +---------- api_v3.py | 905 +----------------------------------------------------- 2 files changed, 10 insertions(+), 1071 deletions(-) diff --git a/api_v2.py b/api_v2.py index fb17dfd6..21be1a10 100644 --- a/api_v2.py +++ b/api_v2.py @@ -104,28 +104,22 @@ RESP: import os import sys import traceback -from typing import Generator, Union +from typing import Union now_dir = os.getcwd() sys.path.append(now_dir) sys.path.append("%s/GPT_SoVITS" % (now_dir)) import argparse -import subprocess -import wave import signal -import numpy as np -import soundfile as sf from fastapi import FastAPI, Response from fastapi.responses import StreamingResponse, JSONResponse import uvicorn -from io import BytesIO from tools.i18n.i18n import I18nAuto from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names from GPT_SoVITS.TTS_infer_pack.unified_engine import RuntimeControlCallbacks, UnifiedTTSEngine from pydantic import BaseModel -import threading # print(sys.path) i18n = I18nAuto() @@ -187,168 +181,8 @@ class TTS_Request(BaseModel): min_chunk_length: int = 16 -def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int): - # Author: AkagawaTsurunaki - # Issue: - # Stack overflow probabilistically occurs - # when the function `sf_writef_short` of `libsndfile_64bit.dll` is called - # using the Python library `soundfile` - # Note: - # This is an issue related to `libsndfile`, not this project itself. - # It happens when you generate a large audio tensor (about 499804 frames in my PC) - # and try to convert it to an ogg file. - # Related: - # https://github.com/RVC-Boss/GPT-SoVITS/issues/1199 - # https://github.com/libsndfile/libsndfile/issues/1023 - # https://github.com/bastibe/python-soundfile/issues/396 - # Suggestion: - # Or split the whole audio data into smaller audio segment to avoid stack overflow? - - def handle_pack_ogg(): - with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: - audio_file.write(data) - - - - # See: https://docs.python.org/3/library/threading.html - # The stack size of this thread is at least 32768 - # If stack overflow error still occurs, just modify the `stack_size`. - # stack_size = n * 4096, where n should be a positive integer. - # Here we chose n = 4096. - stack_size = 4096 * 4096 - try: - threading.stack_size(stack_size) - pack_ogg_thread = threading.Thread(target=handle_pack_ogg) - pack_ogg_thread.start() - pack_ogg_thread.join() - except RuntimeError as e: - # If changing the thread stack size is unsupported, a RuntimeError is raised. - print("RuntimeError: {}".format(e)) - print("Changing the thread stack size is unsupported.") - except ValueError as e: - # If the specified stack size is invalid, a ValueError is raised and the stack size is unmodified. - print("ValueError: {}".format(e)) - print("The specified stack size is invalid.") - - return io_buffer - - -def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int): - io_buffer.write(data.tobytes()) - return io_buffer - - -def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int): - io_buffer = BytesIO() - sf.write(io_buffer, data, rate, format="wav") - return io_buffer - - -def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int): - process = subprocess.Popen( - [ - "ffmpeg", - "-f", - "s16le", # 输入16位有符号小端整数PCM - "-ar", - str(rate), # 设置采样率 - "-ac", - "1", # 单声道 - "-i", - "pipe:0", # 从管道读取输入 - "-c:a", - "aac", # 音频编码器为AAC - "-b:a", - "192k", # 比特率 - "-vn", # 不包含视频 - "-f", - "adts", # 输出AAC数据流格式 - "pipe:1", # 将输出写入管道 - ], - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - out, _ = process.communicate(input=data.tobytes()) - io_buffer.write(out) - return io_buffer - - -def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str): - if media_type == "ogg": - io_buffer = pack_ogg(io_buffer, data, rate) - elif media_type == "aac": - io_buffer = pack_aac(io_buffer, data, rate) - elif media_type == "wav": - io_buffer = pack_wav(io_buffer, data, rate) - else: - io_buffer = pack_raw(io_buffer, data, rate) - io_buffer.seek(0) - return io_buffer - - -# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py -def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000): - # This will create a wave header then append the frame input - # It should be first on a streaming wav file - # Other frames better should not have it (else you will hear some artifacts each chunk start) - wav_buf = BytesIO() - with wave.open(wav_buf, "wb") as vfout: - vfout.setnchannels(channels) - vfout.setsampwidth(sample_width) - vfout.setframerate(sample_rate) - vfout.writeframes(frame_input) - - wav_buf.seek(0) - return wav_buf.read() - - -def handle_control(command: str): - if command == "restart": - os.execl(sys.executable, sys.executable, *argv) - elif command == "exit": - os.kill(os.getpid(), signal.SIGTERM) - exit(0) - - -def check_params(req: dict): - text: str = req.get("text", "") - text_lang: str = req.get("text_lang", "") - ref_audio_path: str = req.get("ref_audio_path", "") - streaming_mode: bool = req.get("streaming_mode", False) - media_type: str = req.get("media_type", "wav") - prompt_lang: str = req.get("prompt_lang", "") - text_split_method: str = req.get("text_split_method", "cut5") - - if ref_audio_path in [None, ""]: - return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"}) - if text in [None, ""]: - return JSONResponse(status_code=400, content={"message": "text is required"}) - if text_lang in [None, ""]: - return JSONResponse(status_code=400, content={"message": "text_lang is required"}) - elif text_lang.lower() not in tts_config.languages: - return JSONResponse( - status_code=400, - content={"message": f"text_lang: {text_lang} is not supported in version {tts_config.version}"}, - ) - if prompt_lang in [None, ""]: - return JSONResponse(status_code=400, content={"message": "prompt_lang is required"}) - elif prompt_lang.lower() not in tts_config.languages: - return JSONResponse( - status_code=400, - content={"message": f"prompt_lang: {prompt_lang} is not supported in version {tts_config.version}"}, - ) - if media_type not in ["wav", "raw", "ogg", "aac"]: - return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"}) - # elif media_type == "ogg" and not streaming_mode: - # return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"}) - - if text_split_method not in cut_method_names: - return JSONResponse( - status_code=400, content={"message": f"text_split_method:{text_split_method} is not supported"} - ) - - return None +def _lower_or_none(value: str | None) -> str | None: + return value.lower() if isinstance(value, str) else value async def tts_handle(req: dict): @@ -435,11 +269,11 @@ async def tts_get_endpoint( ): req = { "text": text, - "text_lang": text_lang.lower(), + "text_lang": _lower_or_none(text_lang), "ref_audio_path": ref_audio_path, "aux_ref_audio_paths": aux_ref_audio_paths, "prompt_text": prompt_text, - "prompt_lang": prompt_lang.lower(), + "prompt_lang": _lower_or_none(prompt_lang), "top_k": top_k, "top_p": top_p, "temperature": temperature, diff --git a/api_v3.py b/api_v3.py index 37d66977..5c97995f 100644 --- a/api_v3.py +++ b/api_v3.py @@ -101,16 +101,10 @@ RESP: """ -import asyncio import os import sys -import time import traceback -import uuid -from collections import deque -from dataclasses import dataclass -from pathlib import Path -from typing import Generator, List, Union +from typing import List, Union now_dir = os.getcwd() sys.path.append(now_dir) @@ -121,33 +115,15 @@ from runtime_preload import preload_text_runtime_deps preload_text_runtime_deps() import argparse -import subprocess -import wave import signal -import numpy as np -import soundfile as sf -import torch from fastapi import FastAPI, Response from fastapi.responses import StreamingResponse, JSONResponse import uvicorn -from io import BytesIO from tools.i18n.i18n import I18nAuto from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config -from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator -from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( - SchedulerRequestSpec, - T2SActiveBatch, - T2SFinishedItem, - T2SRequestState, - merge_active_batches, - decode_one_step, - run_prefill_active_batch, - run_scheduler_continuous, -) from GPT_SoVITS.TTS_infer_pack.unified_engine import RuntimeControlCallbacks, UnifiedTTSEngine from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names from pydantic import BaseModel -import threading # print(sys.path) i18n = I18nAuto() @@ -248,879 +224,8 @@ class Scheduler_Submit_Request(BaseModel): timeout_sec: float = 30.0 -@dataclass -class SchedulerPendingJob: - request_id: str - state: T2SRequestState - done_event: threading.Event - done_loop: asyncio.AbstractEventLoop | None - done_future: asyncio.Future | None - enqueue_time: float - speed_factor: float - sample_steps: int - media_type: str - prepare_wall_ms: float = 0.0 - prepare_profile_total_ms: float = 0.0 - first_schedule_time: float | None = None - prefill_ms: float = 0.0 - merge_ms: float = 0.0 - decode_ms: float = 0.0 - finalize_wait_ms: float = 0.0 - synth_ms: float = 0.0 - pack_ms: float = 0.0 - decode_steps: int = 0 - result_ready_time: float | None = None - result: dict | None = None - sample_rate: int | None = None - audio_data: np.ndarray | None = None - error: str | None = None - - -@dataclass -class SchedulerFinalizeTask: - request_id: str - item: T2SFinishedItem - enqueued_time: float - - -class SchedulerDebugWorker: - def __init__(self, tts: TTS, max_steps: int = 1500, micro_batch_wait_ms: int = 5): - self.tts = tts - self.max_steps = max_steps - self.micro_batch_wait_s = micro_batch_wait_ms / 1000.0 - self.prepare_coordinator = PrepareCoordinator(tts) - self.condition = threading.Condition() - self.prepare_inflight = 0 - self.prepare_peak_inflight = 0 - self.finalize_condition = threading.Condition() - self.finalize_pending_tasks: deque[SchedulerFinalizeTask] = deque() - self.finalize_pending_peak = 0 - self.finalize_inflight = 0 - self.finalize_inflight_peak = 0 - self.finalize_workers = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1))) - self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower() - self.finalize_batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16))) - self.finalize_batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0) - self.pending_jobs: List[SchedulerPendingJob] = [] - self.active_batch: T2SActiveBatch | None = None - self.job_map: dict[str, SchedulerPendingJob] = {} - self.total_finished = 0 - self.total_submitted = 0 - self.worker_thread = threading.Thread(target=self._run_loop, name="t2s-scheduler-debug-worker", daemon=True) - self.worker_thread.start() - self.finalize_threads = [ - threading.Thread( - target=self._run_finalize_loop, - name=f"t2s-scheduler-finalize-{worker_index}", - daemon=True, - ) - for worker_index in range(self.finalize_workers) - ] - for finalize_thread in self.finalize_threads: - finalize_thread.start() - - def _sync_device(self) -> None: - try: - device_str = str(self.tts.configs.device) - if device_str.startswith("cuda") and torch.cuda.is_available(): - torch.cuda.synchronize(self.tts.configs.device) - elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): - torch.mps.synchronize() - except Exception: - pass - - def prepare_state(self, spec: SchedulerRequestSpec) -> T2SRequestState: - raise RuntimeError("prepare_state sync path has been replaced by PrepareCoordinator") - - def submit( - self, - state: T2SRequestState, - speed_factor: float, - sample_steps: int, - media_type: str, - prepare_wall_ms: float, - prepare_profile_total_ms: float, - done_loop: asyncio.AbstractEventLoop | None = None, - done_future: asyncio.Future | None = None, - ) -> SchedulerPendingJob: - job = SchedulerPendingJob( - request_id=state.request_id, - state=state, - done_event=threading.Event(), - done_loop=done_loop, - done_future=done_future, - enqueue_time=time.perf_counter(), - speed_factor=float(speed_factor), - sample_steps=int(sample_steps), - media_type=media_type, - prepare_wall_ms=float(prepare_wall_ms), - prepare_profile_total_ms=float(prepare_profile_total_ms), - ) - with self.condition: - self.pending_jobs.append(job) - self.job_map[job.request_id] = job - self.total_submitted += 1 - self.condition.notify_all() - with self.finalize_condition: - self.finalize_condition.notify_all() - return job - - async def prepare_state_async(self, spec: SchedulerRequestSpec) -> T2SRequestState: - state, _, _ = await self.prepare_coordinator.prepare_state_profiled_async(spec, time.perf_counter()) - return state - - async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: - return await asyncio.gather(*[self.prepare_state_async(spec) for spec in specs]) - - async def prepare_state_profiled_async( - self, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - ) -> tuple[T2SRequestState, float, float]: - return await self.prepare_coordinator.prepare_state_profiled_async(spec, prepare_submit_at) - - def _mark_prefill_started(self, jobs: List[SchedulerPendingJob], started_at: float) -> None: - with self.condition: - for job in jobs: - tracked_job = self.job_map.get(job.request_id) - if tracked_job is not None and tracked_job.first_schedule_time is None: - tracked_job.first_schedule_time = started_at - - def _add_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None: - elapsed_ms = elapsed_s * 1000.0 - with self.condition: - for job in jobs: - tracked_job = self.job_map.get(job.request_id) - if tracked_job is not None: - tracked_job.prefill_ms += elapsed_ms - - def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: - elapsed_ms = elapsed_s * 1000.0 - with self.condition: - for request_id in request_ids: - job = self.job_map.get(request_id) - if job is not None: - job.merge_ms += elapsed_ms - - def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: - elapsed_ms = elapsed_s * 1000.0 - with self.condition: - for request_id in request_ids: - job = self.job_map.get(request_id) - if job is not None: - job.decode_ms += elapsed_ms - job.decode_steps += 1 - - def _add_finalize_wait_ms(self, request_ids: List[str], elapsed_ms: float) -> None: - with self.condition: - for request_id in request_ids: - job = self.job_map.get(request_id) - if job is not None: - job.finalize_wait_ms += elapsed_ms - - def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]: - semantic_tokens = item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0).to(self.tts.configs.device) - phones = job.state.phones.detach().clone().unsqueeze(0).to(self.tts.configs.device) - prompt_semantic = job.state.prompt_semantic.detach().clone() - prompt_phones = job.state.prompt_phones.detach().clone() - refer_spec = ( - job.state.refer_spec[0].detach().clone(), - None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), - ) - raw_audio = job.state.raw_audio.detach().clone() - audio_fragment = self.tts.synthesize_audio_request_local( - semantic_tokens=semantic_tokens, - phones=phones, - prompt_semantic=prompt_semantic, - prompt_phones=prompt_phones, - refer_spec=refer_spec, - raw_audio=raw_audio, - raw_sr=job.state.raw_sr, - speed=float(job.speed_factor), - sample_steps=int(job.sample_steps), - ) - output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] - return self.tts.audio_postprocess( - audio=[[audio_fragment]], - sr=int(output_sr), - batch_index_list=None, - speed_factor=float(job.speed_factor), - split_bucket=False, - fragment_interval=0.0, - super_sampling=False, - ) - - def get_state(self) -> dict: - with self.finalize_condition: - finalize_pending = len(self.finalize_pending_tasks) - finalize_pending_peak = self.finalize_pending_peak - finalize_inflight = self.finalize_inflight - finalize_inflight_peak = self.finalize_inflight_peak - with self.condition: - bert_stage = self.tts.prepare_bert_stage_limiter.snapshot() - ref_audio_stage = self.tts.prepare_ref_audio_stage_limiter.snapshot() - bert_batch_worker = ( - None - if self.tts.prepare_bert_batch_worker is None - else self.tts.prepare_bert_batch_worker.snapshot() - ) - ref_semantic_batch_worker = ( - None - if self.tts.prepare_ref_semantic_batch_worker is None - else self.tts.prepare_ref_semantic_batch_worker.snapshot() - ) - prepare_coordinator_state = self.prepare_coordinator.snapshot() - return { - "pending_jobs": len(self.pending_jobs), - "running_requests": 0 if self.active_batch is None else len(self.active_batch.request_ids), - "prepare_inflight": prepare_coordinator_state["inflight"], - "prepare_peak_inflight": prepare_coordinator_state["peak_inflight"], - "finalize_pending": finalize_pending, - "finalize_pending_peak": finalize_pending_peak, - "finalize_inflight": finalize_inflight, - "finalize_inflight_peak": finalize_inflight_peak, - "finalize_workers": self.finalize_workers, - "finalize_mode": self.finalize_mode, - "finalize_batch_max_items": self.finalize_batch_max_items, - "finalize_batch_wait_ms": self.finalize_batch_wait_s * 1000.0, - "prepare_request_executor_workers": 0, - "prepare_text_cpu_workers": int(getattr(self.tts, "prepare_text_cpu_workers", 0)), - "prepare_text_feature_workers": int(prepare_coordinator_state["text_feature_workers"]), - "prepare_ref_audio_workers": int(prepare_coordinator_state["ref_audio_workers"]), - "prepare_bert_stage": bert_stage, - "prepare_bert_batch_worker": bert_batch_worker, - "prepare_ref_audio_stage": ref_audio_stage, - "prepare_ref_semantic_batch_worker": ref_semantic_batch_worker, - "tracked_jobs": len(self.job_map), - "total_submitted": self.total_submitted, - "total_finished": self.total_finished, - "max_steps": self.max_steps, - "micro_batch_wait_ms": int(self.micro_batch_wait_s * 1000), - } - - def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None: - if not items: - return - tasks: List[SchedulerFinalizeTask] = [] - enqueued_time = time.perf_counter() - with self.condition: - for item in items: - job = self.job_map.get(item.request_id) - if job is not None: - tasks.append( - SchedulerFinalizeTask( - request_id=item.request_id, - item=item, - enqueued_time=enqueued_time, - ) - ) - if not tasks: - return - with self.finalize_condition: - self.finalize_pending_tasks.extend(tasks) - if len(self.finalize_pending_tasks) > self.finalize_pending_peak: - self.finalize_pending_peak = len(self.finalize_pending_tasks) - self.finalize_condition.notify_all() - - @staticmethod - def _finalize_batch_key(job: SchedulerPendingJob) -> tuple[float, int]: - return (round(float(job.speed_factor), 6), int(job.sample_steps)) - - def _take_finalize_task_batch(self) -> List[SchedulerFinalizeTask]: - with self.finalize_condition: - while not self.finalize_pending_tasks: - self.finalize_condition.wait() - if self.finalize_mode == "after_t2s_drain": - while not self._is_t2s_drained(): - self.finalize_condition.wait(timeout=self.micro_batch_wait_s) - task = self.finalize_pending_tasks.popleft() - selected_tasks = [task] - batch_key = None - with self.condition: - first_job = self.job_map.get(task.request_id) - if first_job is not None: - batch_key = self._finalize_batch_key(first_job) - batch_deadline = time.perf_counter() + self.finalize_batch_wait_s - while len(selected_tasks) < self.finalize_batch_max_items: - if batch_key is None: - break - matched_index = None - for pending_index, pending_task in enumerate(self.finalize_pending_tasks): - with self.condition: - pending_job = self.job_map.get(pending_task.request_id) - if pending_job is None: - matched_index = pending_index - break - if self._finalize_batch_key(pending_job) == batch_key: - matched_index = pending_index - break - if matched_index is not None: - selected_tasks.append(self.finalize_pending_tasks[matched_index]) - del self.finalize_pending_tasks[matched_index] - continue - remaining = batch_deadline - time.perf_counter() - if remaining <= 0: - break - self.finalize_condition.wait(timeout=remaining) - self.finalize_inflight += len(selected_tasks) - if self.finalize_inflight > self.finalize_inflight_peak: - self.finalize_inflight_peak = self.finalize_inflight - return selected_tasks - - def _finalize_task_done(self, count: int) -> None: - with self.finalize_condition: - self.finalize_inflight = max(0, self.finalize_inflight - count) - - def _is_t2s_drained(self) -> bool: - with self.condition: - return ( - self.active_batch is None - and not self.pending_jobs - and self.prepare_inflight <= 0 - ) - - def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None: - finished_at = time.perf_counter() - with self.condition: - if self.job_map.get(item.request_id) is not job: - return - queue_wait_ms = 0.0 - if job.first_schedule_time is not None: - queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) - worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) - worker_residual_ms = max( - 0.0, - worker_total_ms - - queue_wait_ms - - job.prefill_ms - - job.merge_ms - - job.decode_ms - - job.finalize_wait_ms - - job.synth_ms, - ) - worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms) - job.sample_rate = int(sample_rate) - job.audio_data = audio_data - job.result_ready_time = finished_at - prepare_profile = dict(job.state.prepare_profile) - job.result = { - "request_id": item.request_id, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - "finish_reason": item.finish_reason, - "prepare_ms": job.prepare_wall_ms, - "prepare_wall_ms": job.prepare_wall_ms, - "prepare_profile_total_ms": job.prepare_profile_total_ms, - "prepare_profile": prepare_profile, - "queue_wait_ms": queue_wait_ms, - "prefill_ms": job.prefill_ms, - "merge_ms": job.merge_ms, - "decode_ms": job.decode_ms, - "finalize_wait_ms": job.finalize_wait_ms, - "synth_ms": job.synth_ms, - "worker_residual_ms": worker_residual_ms, - "worker_other_ms": worker_other_ms, - "worker_total_ms": worker_total_ms, - "decode_steps": int(job.decode_steps), - "sample_rate": int(sample_rate), - "media_type": job.media_type, - } - job.done_event.set() - self._notify_done_future(job) - self.job_map.pop(item.request_id, None) - self.total_finished += 1 - - def _synthesize_finished_audio_batch( - self, - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], - ) -> List[tuple[int, np.ndarray]]: - semantic_tokens_list = [item.semantic_tokens.detach().clone() for _, item in jobs_and_items] - phones_list = [job.state.phones.detach().clone() for job, _ in jobs_and_items] - refer_specs = [] - speeds = [] - sample_steps_list = [] - for job, _ in jobs_and_items: - refer_specs.append( - ( - job.state.refer_spec[0].detach().clone(), - None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), - ) - ) - speeds.append(float(job.speed_factor)) - sample_steps_list.append(int(job.sample_steps)) - audio_fragments = self.tts.synthesize_audio_requests_local_batched( - semantic_tokens_list=semantic_tokens_list, - phones_list=phones_list, - refer_specs=refer_specs, - speeds=speeds, - sample_steps_list=sample_steps_list, - ) - output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] - results: List[tuple[int, np.ndarray]] = [] - for (job, _), audio_fragment in zip(jobs_and_items, audio_fragments): - results.append( - self.tts.audio_postprocess( - audio=[[audio_fragment]], - sr=int(output_sr), - batch_index_list=None, - speed_factor=float(job.speed_factor), - split_bucket=False, - fragment_interval=0.0, - super_sampling=False, - ) - ) - return results - - def _run_finalize_loop(self) -> None: - while True: - tasks = self._take_finalize_task_batch() - try: - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] - finalize_wait_request_ids: List[str] = [] - with self.condition: - for task in tasks: - job = self.job_map.get(task.request_id) - if job is None: - continue - jobs_and_items.append((job, task.item)) - finalize_wait_request_ids.append(task.request_id) - if not jobs_and_items: - continue - now = time.perf_counter() - for task in tasks: - self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0)) - self._sync_device() - synth_start = time.perf_counter() - if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder: - job, item = jobs_and_items[0] - batch_results = [self._synthesize_finished_audio(job, item)] - else: - batch_results = self._synthesize_finished_audio_batch(jobs_and_items) - self._sync_device() - synth_ms = (time.perf_counter() - synth_start) * 1000.0 - with self.condition: - for job, _ in jobs_and_items: - tracked_job = self.job_map.get(job.request_id) - if tracked_job is not None: - tracked_job.synth_ms += synth_ms - for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): - self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data) - except Exception as exc: - self._finalize_error([task.request_id for task in tasks], str(exc)) - finally: - self._finalize_task_done(len(tasks)) - - def _finalize_error(self, request_ids: List[str], error: str) -> None: - if not request_ids: - return - with self.condition: - for request_id in request_ids: - job = self.job_map.get(request_id) - if job is None: - continue - job.error = error - job.done_event.set() - self._notify_done_future(job) - self.job_map.pop(request_id, None) - self.total_finished += 1 - - @staticmethod - def _resolve_done_future(job: SchedulerPendingJob) -> None: - future = job.done_future - if future is None or future.done(): - return - future.set_result(True) - - def _notify_done_future(self, job: SchedulerPendingJob) -> None: - if job.done_loop is None or job.done_future is None: - return - try: - job.done_loop.call_soon_threadsafe(self._resolve_done_future, job) - except RuntimeError: - pass - - def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - with self.condition: - if not self.pending_jobs and self.active_batch is None: - self.condition.wait(timeout=self.micro_batch_wait_s) - elif wait_for_batch and self.pending_jobs: - self.condition.wait(timeout=self.micro_batch_wait_s) - if not self.pending_jobs: - return [] - pending = list(self.pending_jobs) - self.pending_jobs.clear() - with self.finalize_condition: - self.finalize_condition.notify_all() - return pending - - def _run_loop(self) -> None: - while True: - wait_for_batch = self.active_batch is None - pending_jobs = self._take_pending_snapshot(wait_for_batch=wait_for_batch) - - if pending_jobs: - try: - self._sync_device() - prefill_start = time.perf_counter() - self._mark_prefill_started(pending_jobs, prefill_start) - admitted_active_batch, admitted_finished = run_prefill_active_batch( - self.tts.t2s_model.model, - [job.state for job in pending_jobs], - max_steps=self.max_steps, - ) - self._sync_device() - self._add_prefill_time(pending_jobs, time.perf_counter() - prefill_start) - self._enqueue_finalize_finished(admitted_finished) - merge_start = time.perf_counter() - self.active_batch = merge_active_batches( - self.tts.t2s_model.model, - self.active_batch, - admitted_active_batch, - ) - self._add_merge_time( - [] if self.active_batch is None else list(self.active_batch.request_ids), - time.perf_counter() - merge_start, - ) - with self.finalize_condition: - self.finalize_condition.notify_all() - except Exception as exc: - self._finalize_error([job.request_id for job in pending_jobs], str(exc)) - - if self.active_batch is not None: - try: - active_request_ids = [state.request_id for state in self.active_batch.states] - self._sync_device() - decode_start = time.perf_counter() - self.active_batch, step_finished = decode_one_step( - self.tts.t2s_model.model, - self.active_batch, - max_steps=self.max_steps, - ) - self._sync_device() - self._add_decode_time(active_request_ids, time.perf_counter() - decode_start) - self._enqueue_finalize_finished(step_finished) - with self.finalize_condition: - self.finalize_condition.notify_all() - except Exception as exc: - self._finalize_error(active_request_ids, str(exc)) - self.active_batch = None - with self.finalize_condition: - self.finalize_condition.notify_all() - continue - - if not pending_jobs: - with self.finalize_condition: - self.finalize_condition.notify_all() - time.sleep(self.micro_batch_wait_s) - - -scheduler_debug_worker = tts_engine.scheduler_worker - - -def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int): - # Author: AkagawaTsurunaki - # Issue: - # Stack overflow probabilistically occurs - # when the function `sf_writef_short` of `libsndfile_64bit.dll` is called - # using the Python library `soundfile` - # Note: - # This is an issue related to `libsndfile`, not this project itself. - # It happens when you generate a large audio tensor (about 499804 frames in my PC) - # and try to convert it to an ogg file. - # Related: - # https://github.com/RVC-Boss/GPT-SoVITS/issues/1199 - # https://github.com/libsndfile/libsndfile/issues/1023 - # https://github.com/bastibe/python-soundfile/issues/396 - # Suggestion: - # Or split the whole audio data into smaller audio segment to avoid stack overflow? - - def handle_pack_ogg(): - with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: - audio_file.write(data) - - - - # See: https://docs.python.org/3/library/threading.html - # The stack size of this thread is at least 32768 - # If stack overflow error still occurs, just modify the `stack_size`. - # stack_size = n * 4096, where n should be a positive integer. - # Here we chose n = 4096. - stack_size = 4096 * 4096 - try: - threading.stack_size(stack_size) - pack_ogg_thread = threading.Thread(target=handle_pack_ogg) - pack_ogg_thread.start() - pack_ogg_thread.join() - except RuntimeError as e: - # If changing the thread stack size is unsupported, a RuntimeError is raised. - print("RuntimeError: {}".format(e)) - print("Changing the thread stack size is unsupported.") - except ValueError as e: - # If the specified stack size is invalid, a ValueError is raised and the stack size is unmodified. - print("ValueError: {}".format(e)) - print("The specified stack size is invalid.") - - return io_buffer - - -def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int): - io_buffer.write(data.tobytes()) - return io_buffer - - -def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int): - io_buffer = BytesIO() - sf.write(io_buffer, data, rate, format="wav") - return io_buffer - - -def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int): - process = subprocess.Popen( - [ - "ffmpeg", - "-f", - "s16le", # 输入16位有符号小端整数PCM - "-ar", - str(rate), # 设置采样率 - "-ac", - "1", # 单声道 - "-i", - "pipe:0", # 从管道读取输入 - "-c:a", - "aac", # 音频编码器为AAC - "-b:a", - "192k", # 比特率 - "-vn", # 不包含视频 - "-f", - "adts", # 输出AAC数据流格式 - "pipe:1", # 将输出写入管道 - ], - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - out, _ = process.communicate(input=data.tobytes()) - io_buffer.write(out) - return io_buffer - - -def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str): - if media_type == "ogg": - io_buffer = pack_ogg(io_buffer, data, rate) - elif media_type == "aac": - io_buffer = pack_aac(io_buffer, data, rate) - elif media_type == "wav": - io_buffer = pack_wav(io_buffer, data, rate) - else: - io_buffer = pack_raw(io_buffer, data, rate) - io_buffer.seek(0) - return io_buffer - - -# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py -def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000): - # This will create a wave header then append the frame input - # It should be first on a streaming wav file - # Other frames better should not have it (else you will hear some artifacts each chunk start) - wav_buf = BytesIO() - with wave.open(wav_buf, "wb") as vfout: - vfout.setnchannels(channels) - vfout.setsampwidth(sample_width) - vfout.setframerate(sample_rate) - vfout.writeframes(frame_input) - - wav_buf.seek(0) - return wav_buf.read() - - -def handle_control(command: str): - if command == "restart": - os.execl(sys.executable, sys.executable, *argv) - elif command == "exit": - os.kill(os.getpid(), signal.SIGTERM) - exit(0) - - -def check_params(req: dict): - text: str = req.get("text", "") - text_lang: str = req.get("text_lang", "") - ref_audio_path: str = req.get("ref_audio_path", "") - streaming_mode: bool = req.get("streaming_mode", False) - media_type: str = req.get("media_type", "wav") - prompt_lang: str = req.get("prompt_lang", "") - text_split_method: str = req.get("text_split_method", "cut5") - - if ref_audio_path in [None, ""]: - return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"}) - if text in [None, ""]: - return JSONResponse(status_code=400, content={"message": "text is required"}) - if text_lang in [None, ""]: - return JSONResponse(status_code=400, content={"message": "text_lang is required"}) - elif text_lang.lower() not in tts_config.languages: - return JSONResponse( - status_code=400, - content={"message": f"text_lang: {text_lang} is not supported in version {tts_config.version}"}, - ) - if prompt_lang in [None, ""]: - return JSONResponse(status_code=400, content={"message": "prompt_lang is required"}) - elif prompt_lang.lower() not in tts_config.languages: - return JSONResponse( - status_code=400, - content={"message": f"prompt_lang: {prompt_lang} is not supported in version {tts_config.version}"}, - ) - if media_type not in ["wav", "raw", "ogg", "aac"]: - return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"}) - # elif media_type == "ogg" and not streaming_mode: - # return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"}) - - if text_split_method not in cut_method_names: - return JSONResponse( - status_code=400, content={"message": f"text_split_method:{text_split_method} is not supported"} - ) - - return None - - -def set_scheduler_seed(seed: int): - if seed in ["", None]: - return - seed = int(seed) - if seed < 0: - return - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def build_scheduler_request_specs(request_items: List[Scheduler_Debug_Request_Item]) -> List[SchedulerRequestSpec]: - specs: List[SchedulerRequestSpec] = [] - for index, item in enumerate(request_items): - payload = item.dict() - req = { - "text": payload["text"], - "text_lang": payload["text_lang"].lower(), - "ref_audio_path": payload["ref_audio_path"], - "aux_ref_audio_paths": None, - "prompt_text": payload["prompt_text"], - "prompt_lang": payload["prompt_lang"].lower(), - "top_k": payload["top_k"], - "top_p": payload["top_p"], - "temperature": payload["temperature"], - "text_split_method": "cut5", - "batch_size": 1, - "batch_threshold": 0.75, - "speed_factor": 1.0, - "split_bucket": False, - "fragment_interval": 0.3, - "seed": -1, - "media_type": "wav", - "streaming_mode": False, - "parallel_infer": False, - "repetition_penalty": payload["repetition_penalty"], - "sample_steps": 32, - "super_sampling": False, - "overlap_length": 2, - "min_chunk_length": 16, - } - check_res = check_params(req) - if check_res is not None: - detail = check_res.body.decode("utf-8") if hasattr(check_res, "body") else str(check_res) - raise ValueError(f"request[{index}] 参数非法: {detail}") - specs.append( - SchedulerRequestSpec( - request_id=payload["request_id"] or f"req_{index:03d}", - ref_audio_path=Path(payload["ref_audio_path"]), - prompt_text=payload["prompt_text"], - prompt_lang=payload["prompt_lang"].lower(), - text=payload["text"], - text_lang=payload["text_lang"].lower(), - top_k=int(payload["top_k"]), - top_p=float(payload["top_p"]), - temperature=float(payload["temperature"]), - repetition_penalty=float(payload["repetition_penalty"]), - early_stop_num=int(payload["early_stop_num"]), - ready_step=int(payload["ready_step"]), - ) - ) - return specs - - -def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: - return [ - { - "request_id": state.request_id, - "ready_step": int(state.ready_step), - "ref_audio_path": str(state.ref_audio_path), - "prompt_semantic_len": int(state.prompt_semantic.shape[0]), - "all_phone_len": int(state.all_phones.shape[0]), - "bert_len": int(state.all_bert_features.shape[-1]), - "norm_text": state.norm_text, - } - for state in states - ] - - -def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: - return [ - { - "request_id": item.request_id, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - "finish_reason": item.finish_reason, - } - for item in items - ] - - -def build_scheduler_submit_spec(request: Scheduler_Submit_Request) -> SchedulerRequestSpec: - payload = request.dict() - request_id = payload["request_id"] or f"job_{uuid.uuid4().hex[:12]}" - req = { - "text": payload["text"], - "text_lang": payload["text_lang"].lower(), - "ref_audio_path": payload["ref_audio_path"], - "aux_ref_audio_paths": None, - "prompt_text": payload["prompt_text"], - "prompt_lang": payload["prompt_lang"].lower(), - "top_k": payload["top_k"], - "top_p": payload["top_p"], - "temperature": payload["temperature"], - "text_split_method": "cut5", - "batch_size": 1, - "batch_threshold": 0.75, - "speed_factor": float(payload["speed_factor"]), - "split_bucket": False, - "fragment_interval": 0.3, - "seed": -1, - "media_type": payload["media_type"], - "streaming_mode": False, - "parallel_infer": False, - "repetition_penalty": payload["repetition_penalty"], - "sample_steps": int(payload["sample_steps"]), - "super_sampling": False, - "overlap_length": 2, - "min_chunk_length": 16, - } - check_res = check_params(req) - if check_res is not None: - detail = check_res.body.decode("utf-8") if hasattr(check_res, "body") else str(check_res) - raise ValueError(f"request 参数非法: {detail}") - return SchedulerRequestSpec( - request_id=request_id, - ref_audio_path=Path(payload["ref_audio_path"]), - prompt_text=payload["prompt_text"], - prompt_lang=payload["prompt_lang"].lower(), - text=payload["text"], - text_lang=payload["text_lang"].lower(), - top_k=int(payload["top_k"]), - top_p=float(payload["top_p"]), - temperature=float(payload["temperature"]), - repetition_penalty=float(payload["repetition_penalty"]), - early_stop_num=int(payload["early_stop_num"]), - ready_step=0, - ) +def _lower_or_none(value: str | None) -> str | None: + return value.lower() if isinstance(value, str) else value async def tts_scheduler_debug_handle(request: Scheduler_Debug_Request): @@ -1233,11 +338,11 @@ async def tts_get_endpoint( ): req = { "text": text, - "text_lang": text_lang.lower(), + "text_lang": _lower_or_none(text_lang), "ref_audio_path": ref_audio_path, "aux_ref_audio_paths": aux_ref_audio_paths, "prompt_text": prompt_text, - "prompt_lang": prompt_lang.lower(), + "prompt_lang": _lower_or_none(prompt_lang), "top_k": top_k, "top_p": top_p, "temperature": temperature, From 6a427b4f547066175f91c4d9fc1eaf302823a7a8 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Tue, 10 Mar 2026 21:25:14 +0800 Subject: [PATCH 11/24] Update TTS API to support asynchronous execution by replacing synchronous TTS calls with asynchronous counterparts in both api_v2.py and api_v3.py. Introduce new data classes in unified_engine.py for enhanced request handling and state management, improving overall system performance and maintainability. --- GPT_SoVITS/TTS_infer_pack/unified_engine.py | 1825 ++++++++++++++++--- api_v2.py | 2 +- api_v3.py | 2 +- 3 files changed, 1607 insertions(+), 222 deletions(-) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine.py b/GPT_SoVITS/TTS_infer_pack/unified_engine.py index 0a95015f..aed7b146 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine.py @@ -134,6 +134,93 @@ class DirectTTSExecution: streaming: bool audio_generator: Optional[Generator[bytes, None, None]] = None audio_bytes: Optional[bytes] = None + request_id: Optional[str] = None + + +@dataclass +class NormalizedEngineRequest: + request_id: str + text: str + text_lang: str + ref_audio_path: str + prompt_lang: str + prompt_text: str = "" + aux_ref_audio_paths: List[str] | None = None + top_k: int = 15 + top_p: float = 1.0 + temperature: float = 1.0 + repetition_penalty: float = 1.35 + early_stop_num: int = -1 + ready_step: int = 0 + text_split_method: str = "cut5" + batch_size: int = 1 + batch_threshold: float = 0.75 + split_bucket: bool = False + speed_factor: float = 1.0 + fragment_interval: float = 0.3 + seed: int = -1 + media_type: str = "wav" + streaming_mode: bool | int = False + return_fragment: bool = False + fixed_length_chunk: bool = False + response_streaming: bool = False + parallel_infer: bool = False + sample_steps: int = 32 + super_sampling: bool = False + overlap_length: int = 2 + min_chunk_length: int = 16 + timeout_sec: float | None = None + + def to_payload(self) -> Dict[str, Any]: + return { + "request_id": self.request_id, + "text": self.text, + "text_lang": self.text_lang, + "ref_audio_path": self.ref_audio_path, + "aux_ref_audio_paths": list(self.aux_ref_audio_paths) if self.aux_ref_audio_paths else None, + "prompt_text": self.prompt_text, + "prompt_lang": self.prompt_lang, + "top_k": self.top_k, + "top_p": self.top_p, + "temperature": self.temperature, + "text_split_method": self.text_split_method, + "batch_size": self.batch_size, + "batch_threshold": self.batch_threshold, + "speed_factor": self.speed_factor, + "split_bucket": self.split_bucket, + "fragment_interval": self.fragment_interval, + "seed": self.seed, + "media_type": self.media_type, + "streaming_mode": self.streaming_mode, + "return_fragment": self.return_fragment, + "fixed_length_chunk": self.fixed_length_chunk, + "response_streaming": self.response_streaming, + "parallel_infer": self.parallel_infer, + "repetition_penalty": self.repetition_penalty, + "sample_steps": self.sample_steps, + "super_sampling": self.super_sampling, + "overlap_length": self.overlap_length, + "min_chunk_length": self.min_chunk_length, + "early_stop_num": self.early_stop_num, + "ready_step": self.ready_step, + "timeout_sec": self.timeout_sec, + } + + def to_scheduler_spec(self) -> SchedulerRequestSpec: + return SchedulerRequestSpec( + request_id=self.request_id, + ref_audio_path=Path(self.ref_audio_path), + prompt_text=self.prompt_text, + prompt_lang=self.prompt_lang, + text=self.text, + text_lang=self.text_lang, + top_k=self.top_k, + top_p=self.top_p, + temperature=self.temperature, + repetition_penalty=self.repetition_penalty, + early_stop_num=self.early_stop_num, + ready_step=self.ready_step, + ) @dataclass @@ -148,6 +235,57 @@ class SchedulerSubmitExecution: headers: Dict[str, str] +class EngineStatus: + NEW = "NEW" + QUEUED = "QUEUED" + VALIDATED = "VALIDATED" + CPU_PREPARING = "CPU_PREPARING" + GPU_PREPARING = "GPU_PREPARING" + READY_FOR_PREFILL = "READY_FOR_PREFILL" + ACTIVE_DECODE = "ACTIVE_DECODE" + READY_FOR_FINALIZE = "READY_FOR_FINALIZE" + FINALIZING = "FINALIZING" + STREAMING = "STREAMING" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + + +@dataclass +class EngineRequestState: + request_id: str + api_mode: str + backend: str + media_type: str + response_streaming: bool + submit_ts: float + deadline_ts: float | None = None + status: str = EngineStatus.NEW + updated_ts: float = 0.0 + error: str | None = None + finish_reason: str | None = None + meta: Dict[str, Any] = field(default_factory=dict) + profile: Dict[str, Any] = field(default_factory=dict) + lifecycle_timestamps: Dict[str, float] = field(default_factory=dict) + + def to_summary(self) -> Dict[str, Any]: + return { + "request_id": self.request_id, + "api_mode": self.api_mode, + "backend": self.backend, + "media_type": self.media_type, + "response_streaming": self.response_streaming, + "status": self.status, + "submit_ts": self.submit_ts, + "updated_ts": self.updated_ts, + "deadline_ts": self.deadline_ts, + "error": self.error, + "finish_reason": self.finish_reason, + "meta": dict(self.meta), + "profile": dict(self.profile), + "lifecycle_timestamps": dict(self.lifecycle_timestamps), + } + + @dataclass class SchedulerPendingJob: request_id: str @@ -159,6 +297,7 @@ class SchedulerPendingJob: speed_factor: float sample_steps: int media_type: str + admission_wait_ms: float = 0.0 prepare_wall_ms: float = 0.0 prepare_profile_total_ms: float = 0.0 first_schedule_time: float | None = None @@ -174,6 +313,7 @@ class SchedulerPendingJob: sample_rate: int | None = None audio_data: np.ndarray | None = None error: str | None = None + engine_request_id: str | None = None @dataclass @@ -183,11 +323,25 @@ class SchedulerFinalizeTask: enqueued_time: float +@dataclass +class RuntimeStateCallbacks: + update: Callable[[str, str, Optional[Dict[str, Any]]], None] | None = None + complete: Callable[[str, Optional[Dict[str, Any]]], None] | None = None + fail: Callable[[str, str], None] | None = None + + class UnifiedSchedulerWorker: - def __init__(self, tts: TTS, max_steps: int = 1500, micro_batch_wait_ms: int = 5): + def __init__( + self, + tts: TTS, + max_steps: int = 1500, + micro_batch_wait_ms: int = 5, + runtime_callbacks: RuntimeStateCallbacks | None = None, + ): self.tts = tts self.max_steps = int(max_steps) self.micro_batch_wait_s = float(micro_batch_wait_ms) / 1000.0 + self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() self.prepare_coordinator = PrepareCoordinator(tts) self.condition = threading.Condition() self.prepare_inflight = 0 @@ -201,6 +355,8 @@ class UnifiedSchedulerWorker: self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower() self.finalize_batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16))) self.finalize_batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0) + self.decode_backlog_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_DECODE_BACKLOG_MAX", "0"))) + self.finalize_pending_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_FINALIZE_PENDING_MAX", "0"))) self.pending_jobs: List[SchedulerPendingJob] = [] self.active_batch: T2SActiveBatch | None = None self.job_map: Dict[str, SchedulerPendingJob] = {} @@ -219,16 +375,95 @@ class UnifiedSchedulerWorker: for finalize_thread in self.finalize_threads: finalize_thread.start() + def _current_decode_backlog_locked(self) -> int: + running_requests = 0 if self.active_batch is None else len(self.active_batch.request_ids) + return int(len(self.pending_jobs) + running_requests) + + def _can_accept_submit_locked(self) -> tuple[bool, Dict[str, int]]: + decode_backlog = self._current_decode_backlog_locked() + finalize_pending = int(len(self.finalize_pending_tasks)) + prepare_inflight = int(self.prepare_coordinator.snapshot()["inflight"]) + blocked_decode = self.decode_backlog_max > 0 and decode_backlog >= self.decode_backlog_max + blocked_finalize = self.finalize_pending_max > 0 and finalize_pending >= self.finalize_pending_max + return ( + not blocked_decode and not blocked_finalize, + { + "decode_backlog": decode_backlog, + "finalize_pending": finalize_pending, + "prepare_inflight": prepare_inflight, + "decode_backlog_max": int(self.decode_backlog_max), + "finalize_pending_max": int(self.finalize_pending_max), + }, + ) + + def wait_for_submit_capacity_blocking(self, timeout_sec: float | None = None) -> tuple[float, Dict[str, int]]: + start = time.perf_counter() + deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) + last_snapshot: Dict[str, int] = {} + while True: + with self.condition: + allowed, snapshot = self._can_accept_submit_locked() + last_snapshot = snapshot + if allowed: + return max(0.0, (time.perf_counter() - start) * 1000.0), snapshot + if deadline is not None and time.perf_counter() >= deadline: + raise TimeoutError( + "scheduler submit admission timeout " + f"(decode_backlog={snapshot['decode_backlog']}, finalize_pending={snapshot['finalize_pending']})" + ) + self.condition.wait(timeout=self.micro_batch_wait_s) + + async def submit_async( + self, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None = None, + done_future: asyncio.Future | None = None, + engine_request_id: str | None = None, + timeout_sec: float | None = None, + ) -> SchedulerPendingJob: + return await asyncio.to_thread( + self.submit, + state, + speed_factor, + sample_steps, + media_type, + prepare_wall_ms, + prepare_profile_total_ms, + done_loop, + done_future, + engine_request_id, + timeout_sec, + ) + def snapshot(self) -> dict: with self.condition: finalize_pending = len(self.finalize_pending_tasks) prepare_state = self.prepare_coordinator.snapshot() + active_batch = self.active_batch + active_batch_summary = None + if active_batch is not None: + active_batch_summary = { + "request_count": int(len(active_batch.request_ids)), + "request_ids": list(active_batch.request_ids), + "prefill_done": bool(active_batch.prefill_done), + "decode_step_index_max": ( + int(active_batch.step_indices.max().item()) + if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0 + else 0 + ), + } return { "pending_jobs": len(self.pending_jobs), - "running_requests": 0 if self.active_batch is None else len(self.active_batch.request_ids), + "running_requests": 0 if active_batch is None else len(active_batch.request_ids), "prepare_inflight": prepare_state["inflight"], "prepare_peak_inflight": prepare_state["peak_inflight"], "prepare_max_inflight": prepare_state.get("max_inflight", 0), + "prepare_state": dict(prepare_state), "finalize_pending": finalize_pending, "finalize_pending_peak": self.finalize_pending_peak, "finalize_inflight": self.finalize_inflight, @@ -237,6 +472,9 @@ class UnifiedSchedulerWorker: "finalize_mode": self.finalize_mode, "finalize_batch_max_items": self.finalize_batch_max_items, "finalize_batch_wait_ms": self.finalize_batch_wait_s * 1000.0, + "decode_backlog_max": self.decode_backlog_max, + "finalize_pending_max": self.finalize_pending_max, + "active_batch": active_batch_summary, "total_submitted": self.total_submitted, "total_finished": self.total_finished, "drained": self.is_drained(), @@ -282,7 +520,10 @@ class UnifiedSchedulerWorker: prepare_profile_total_ms: float, done_loop: asyncio.AbstractEventLoop | None = None, done_future: asyncio.Future | None = None, + engine_request_id: str | None = None, + timeout_sec: float | None = None, ) -> SchedulerPendingJob: + admission_wait_ms, admission_snapshot = self.wait_for_submit_capacity_blocking(timeout_sec=timeout_sec) job = SchedulerPendingJob( request_id=state.request_id, state=state, @@ -293,14 +534,25 @@ class UnifiedSchedulerWorker: speed_factor=float(speed_factor), sample_steps=int(sample_steps), media_type=media_type, + admission_wait_ms=float(admission_wait_ms), prepare_wall_ms=float(prepare_wall_ms), prepare_profile_total_ms=float(prepare_profile_total_ms), + engine_request_id=engine_request_id or state.request_id, ) with self.condition: self.pending_jobs.append(job) self.job_map[job.request_id] = job self.total_submitted += 1 self.condition.notify_all() + self._runtime_update( + job.engine_request_id, + EngineStatus.QUEUED, + { + "scheduler_request_id": job.request_id, + "decode_admission_wait_ms": float(admission_wait_ms), + "admission_snapshot": dict(admission_snapshot), + }, + ) with self.finalize_condition: self.finalize_condition.notify_all() return job @@ -335,6 +587,11 @@ class UnifiedSchedulerWorker: if tracked_job is None: continue tracked_job.first_schedule_time = float(started_at) + self._runtime_update( + tracked_job.engine_request_id, + EngineStatus.GPU_PREPARING, + {"scheduler_request_id": tracked_job.request_id, "prefill_started_at": float(started_at)}, + ) def _add_prefill_time(self, request_ids: List[str], elapsed_s: float) -> None: delta_ms = float(elapsed_s) * 1000.0 @@ -360,12 +617,17 @@ class UnifiedSchedulerWorker: delta_ms = float(elapsed_s) * 1000.0 if not request_ids: return + activate_request_ids: List[str] = [] with self.condition: for request_id in request_ids: job = self.job_map.get(request_id) if job is not None: + if job.decode_steps == 0: + activate_request_ids.append(job.engine_request_id) job.decode_ms += delta_ms job.decode_steps += 1 + for engine_request_id in activate_request_ids: + self._runtime_update(engine_request_id, EngineStatus.ACTIVE_DECODE, None) def _add_finalize_wait_ms(self, request_ids: List[str], delta_ms: float) -> None: if not request_ids: @@ -382,6 +644,17 @@ class UnifiedSchedulerWorker: enqueued_at = time.perf_counter() with self.finalize_condition: for item in items: + job = self.job_map.get(item.request_id) + if job is not None: + self._runtime_update( + job.engine_request_id, + EngineStatus.READY_FOR_FINALIZE, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + }, + ) self.finalize_pending_tasks.append( SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at) ) @@ -421,12 +694,16 @@ class UnifiedSchedulerWorker: self.finalize_condition.wait(timeout=remaining) self.finalize_inflight += len(selected_tasks) self.finalize_inflight_peak = max(self.finalize_inflight_peak, self.finalize_inflight) + with self.condition: + self.condition.notify_all() return selected_tasks def _finalize_task_done(self, count: int) -> None: with self.finalize_condition: self.finalize_inflight = max(0, self.finalize_inflight - count) self.finalize_condition.notify_all() + with self.condition: + self.condition.notify_all() def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]: audio_fragment = self.tts.synthesize_audio_request_local( @@ -524,6 +801,7 @@ class UnifiedSchedulerWorker: "semantic_len": int(item.semantic_tokens.shape[0]), "finish_idx": int(item.finish_idx), "finish_reason": item.finish_reason, + "decode_admission_wait_ms": float(job.admission_wait_ms), "prepare_ms": job.prepare_wall_ms, "prepare_wall_ms": job.prepare_wall_ms, "prepare_profile_total_ms": job.prepare_profile_total_ms, @@ -546,6 +824,16 @@ class UnifiedSchedulerWorker: self.job_map.pop(item.request_id, None) self.total_finished += 1 self.condition.notify_all() + self._runtime_complete( + job.engine_request_id, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "sample_rate": int(sample_rate), + "worker_profile": dict(job.result or {}), + }, + ) def _finalize_error(self, request_ids: List[str], error: str) -> None: if not request_ids: @@ -560,6 +848,7 @@ class UnifiedSchedulerWorker: self._notify_done_future(job) self.job_map.pop(request_id, None) self.total_finished += 1 + self._runtime_fail(job.engine_request_id, error) self.condition.notify_all() @staticmethod @@ -577,6 +866,21 @@ class UnifiedSchedulerWorker: except RuntimeError: pass + def _runtime_update(self, request_id: str | None, status: str, extra: Optional[Dict[str, Any]] = None) -> None: + if request_id is None or self.runtime_callbacks.update is None: + return + self.runtime_callbacks.update(request_id, status, extra) + + def _runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: + if request_id is None or self.runtime_callbacks.complete is None: + return + self.runtime_callbacks.complete(request_id, extra) + + def _runtime_fail(self, request_id: str | None, error: str) -> None: + if request_id is None or self.runtime_callbacks.fail is None: + return + self.runtime_callbacks.fail(request_id, error) + def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: with self.condition: if not self.pending_jobs and self.active_batch is None: @@ -605,6 +909,15 @@ class UnifiedSchedulerWorker: now = time.perf_counter() for task in tasks: self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0)) + for job, item in jobs_and_items: + self._runtime_update( + job.engine_request_id, + EngineStatus.FINALIZING, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + }, + ) self._sync_device() synth_start = time.perf_counter() if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder: @@ -791,15 +1104,586 @@ class UnifiedTTSEngine: t2s_weights_path=str(self.tts.configs.t2s_weights_path), vits_weights_path=str(self.tts.configs.vits_weights_path), ) - self.scheduler_worker = UnifiedSchedulerWorker(tts, max_steps=max_steps, micro_batch_wait_ms=micro_batch_wait_ms) + self.request_registry_lock = threading.Lock() + self.active_requests: Dict[str, EngineRequestState] = {} + self.recent_requests: Deque[EngineRequestState] = deque() + self.recent_request_limit = max(1, int(os.environ.get("GPTSOVITS_ENGINE_RECENT_REQUEST_LIMIT", "64"))) + self.scheduler_worker = UnifiedSchedulerWorker( + tts, + max_steps=max_steps, + micro_batch_wait_ms=micro_batch_wait_ms, + runtime_callbacks=RuntimeStateCallbacks( + update=self._update_request_state, + complete=self._complete_request_state, + fail=self._fail_request_state, + ), + ) self.direct_tts_lock = threading.RLock() self.management_lock = threading.RLock() + def _register_request_state( + self, + request_id: str, + api_mode: str, + backend: str, + media_type: str, + response_streaming: bool, + deadline_ts: float | None = None, + meta: Optional[Dict[str, Any]] = None, + ) -> EngineRequestState: + now = time.perf_counter() + state = EngineRequestState( + request_id=request_id, + api_mode=api_mode, + backend=backend, + media_type=media_type, + response_streaming=bool(response_streaming), + submit_ts=now, + deadline_ts=deadline_ts, + updated_ts=now, + meta=dict(meta or {}), + lifecycle_timestamps={EngineStatus.NEW: now}, + ) + with self.request_registry_lock: + self.active_requests[request_id] = state + return state + + def _move_to_recent_locked(self, state: EngineRequestState) -> None: + self.recent_requests.appendleft(state) + while len(self.recent_requests) > self.recent_request_limit: + self.recent_requests.pop() + + def _update_request_state( + self, + request_id: str, + status: str, + extra: Optional[Dict[str, Any]] = None, + ) -> None: + now = time.perf_counter() + with self.request_registry_lock: + state = self.active_requests.get(request_id) + if state is None: + return + state.status = status + state.updated_ts = now + state.lifecycle_timestamps[status] = now + if extra: + backend = extra.pop("backend", None) + if backend is not None: + state.backend = str(backend) + finish_reason = extra.pop("finish_reason", None) + if finish_reason is not None: + state.finish_reason = str(finish_reason) + error = extra.pop("error", None) + if error is not None: + state.error = str(error) + state.profile.update(extra) + + def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + if not extra: + return + now = time.perf_counter() + with self.request_registry_lock: + state = self.active_requests.get(request_id) + if state is None: + for recent_state in self.recent_requests: + if recent_state.request_id == request_id: + state = recent_state + break + if state is None: + return + state.updated_ts = now + backend = extra.get("backend") + if backend is not None: + state.backend = str(backend) + finish_reason = extra.get("finish_reason") + if finish_reason is not None: + state.finish_reason = str(finish_reason) + error = extra.get("error") + if error is not None: + state.error = str(error) + merged = dict(extra) + merged.pop("backend", None) + merged.pop("finish_reason", None) + merged.pop("error", None) + state.profile.update(merged) + + def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + now = time.perf_counter() + with self.request_registry_lock: + state = self.active_requests.pop(request_id, None) + if state is None: + return + state.status = EngineStatus.COMPLETED + state.updated_ts = now + state.lifecycle_timestamps[EngineStatus.COMPLETED] = now + if extra: + finish_reason = extra.pop("finish_reason", None) + if finish_reason is not None: + state.finish_reason = str(finish_reason) + state.profile.update(extra) + self._move_to_recent_locked(state) + + def _fail_request_state(self, request_id: str, error: str) -> None: + now = time.perf_counter() + with self.request_registry_lock: + state = self.active_requests.pop(request_id, None) + if state is None: + return + state.status = EngineStatus.FAILED + state.updated_ts = now + state.error = str(error) + state.lifecycle_timestamps[EngineStatus.FAILED] = now + self._move_to_recent_locked(state) + + def _snapshot_request_registry(self) -> Dict[str, Any]: + with self.request_registry_lock: + active = [state.to_summary() for state in self.active_requests.values()] + recent = [state.to_summary() for state in list(self.recent_requests)] + active.sort(key=lambda item: item["submit_ts"]) + return { + "active_count": len(active), + "recent_count": len(recent), + "recent_limit": self.recent_request_limit, + "active_requests": active, + "recent_requests": recent, + } + + @staticmethod + def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None: + if component is None or not hasattr(component, "snapshot"): + return None + try: + return dict(component.snapshot()) + except Exception: + return None + + def _build_stage_summary( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + active_requests = list(request_registry.get("active_requests", [])) + status_counts: Dict[str, int] = {} + for item in active_requests: + status = str(item.get("status", "UNKNOWN")) + status_counts[status] = status_counts.get(status, 0) + 1 + + bert_worker_state = self._safe_component_snapshot(getattr(self.tts, "prepare_bert_batch_worker", None)) + ref_semantic_worker_state = self._safe_component_snapshot(getattr(self.tts, "prepare_ref_semantic_batch_worker", None)) + text_preprocessor_state = self._safe_component_snapshot(getattr(self.tts, "text_preprocessor", None)) + + return { + "active_request_count": int(len(active_requests)), + "status_counts": status_counts, + "queued_request_count": int(status_counts.get(EngineStatus.QUEUED, 0)), + "cpu_prepare_request_count": int(status_counts.get(EngineStatus.CPU_PREPARING, 0)), + "gpu_prepare_request_count": int(status_counts.get(EngineStatus.GPU_PREPARING, 0)), + "ready_for_prefill_request_count": int(status_counts.get(EngineStatus.READY_FOR_PREFILL, 0)), + "active_decode_request_count": int(status_counts.get(EngineStatus.ACTIVE_DECODE, 0)), + "ready_for_finalize_request_count": int(status_counts.get(EngineStatus.READY_FOR_FINALIZE, 0)), + "finalizing_request_count": int(status_counts.get(EngineStatus.FINALIZING, 0)), + "streaming_request_count": int(status_counts.get(EngineStatus.STREAMING, 0)), + "worker_pending_jobs": int(worker_state.get("pending_jobs", 0)), + "worker_decode_active_size": int(worker_state.get("running_requests", 0)), + "worker_prepare_inflight": int(worker_state.get("prepare_inflight", 0)), + "worker_finalize_pending": int(worker_state.get("finalize_pending", 0)), + "worker_finalize_inflight": int(worker_state.get("finalize_inflight", 0)), + "admission_config": { + "decode_backlog_max": int(worker_state.get("decode_backlog_max", 0)), + "finalize_pending_max": int(worker_state.get("finalize_pending_max", 0)), + }, + "active_batch": dict(worker_state.get("active_batch") or {}), + "prepare_state": dict(worker_state.get("prepare_state") or {}), + "bert_batch_worker_state": bert_worker_state, + "ref_semantic_worker_state": ref_semantic_worker_state, + "text_preprocessor_state": text_preprocessor_state, + } + + def _collect_request_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]: + requested = set(request_ids) + results: List[Dict[str, Any]] = [] + with self.request_registry_lock: + for state in self.active_requests.values(): + if state.request_id in requested: + results.append(state.to_summary()) + for state in self.recent_requests: + if state.request_id in requested and all(item["request_id"] != state.request_id for item in results): + results.append(state.to_summary()) + results.sort(key=lambda item: item["request_id"]) + return results + + def _has_active_request(self, request_id: str) -> bool: + with self.request_registry_lock: + return request_id in self.active_requests + + @staticmethod + def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]: + text = payload.get("text") + prompt_text = payload.get("prompt_text") + return { + "text_len": 0 if text is None else len(str(text)), + "prompt_text_len": 0 if prompt_text is None else len(str(prompt_text)), + "text_lang": payload.get("text_lang"), + "prompt_lang": payload.get("prompt_lang"), + "ref_audio_path": payload.get("ref_audio_path"), + } + + @staticmethod + def _sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float: + total = 0.0 + for item in items: + value = item.get(key, 0.0) + if isinstance(value, (int, float)): + total += float(value) + return total + + def _build_direct_segment_trace( + self, + segment_texts: Sequence[str], + prepare_profiles: Sequence[Dict[str, Any]], + worker_profiles: Sequence[Dict[str, Any]], + ) -> List[Dict[str, Any]]: + results: List[Dict[str, Any]] = [] + for index, segment_text in enumerate(segment_texts): + prepare_item = prepare_profiles[index] if index < len(prepare_profiles) else {} + worker_item = worker_profiles[index] if index < len(worker_profiles) else {} + prepare_profile = dict(prepare_item.get("prepare_profile", {})) + results.append( + { + "segment_index": index, + "request_id": prepare_item.get("request_id") or worker_item.get("request_id"), + "text_len": len(str(segment_text)), + "prepare_wall_ms": float(prepare_item.get("prepare_wall_ms", 0.0)), + "prepare_profile_total_ms": float(prepare_item.get("prepare_profile_total_ms", 0.0)), + "decode_admission_wait_ms": float(worker_item.get("decode_admission_wait_ms", 0.0)), + "queue_wait_ms": float(worker_item.get("queue_wait_ms", 0.0)), + "prefill_ms": float(worker_item.get("prefill_ms", 0.0)), + "merge_ms": float(worker_item.get("merge_ms", 0.0)), + "decode_ms": float(worker_item.get("decode_ms", 0.0)), + "finalize_wait_ms": float(worker_item.get("finalize_wait_ms", 0.0)), + "synth_ms": float(worker_item.get("synth_ms", 0.0)), + "worker_total_ms": float(worker_item.get("worker_total_ms", 0.0)), + "decode_steps": int(worker_item.get("decode_steps", 0)), + "semantic_len": int(worker_item.get("semantic_len", 0)), + "finish_reason": worker_item.get("finish_reason"), + "norm_text": prepare_profile.get("norm_text"), + } + ) + return results + + def _build_direct_scheduler_profile( + self, + *, + backend: str, + request_start: float, + response_ready_at: float, + audio_bytes: int, + sample_rate: int, + segment_texts: Sequence[str], + prepare_profiles: Sequence[Dict[str, Any]], + worker_profiles: Sequence[Dict[str, Any]], + pack_ms: float, + response_overhead_ms: float, + ) -> Dict[str, Any]: + segment_trace = self._build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles) + prepare_profile_dicts = [dict(item.get("prepare_profile", {})) for item in prepare_profiles] + request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) + prepare_wall_ms = self._sum_profile_field(prepare_profiles, "prepare_wall_ms") + prepare_profile_total_ms = self._sum_profile_field(prepare_profiles, "prepare_profile_total_ms") + decode_admission_wait_ms = self._sum_profile_field(worker_profiles, "decode_admission_wait_ms") + queue_wait_ms = self._sum_profile_field(worker_profiles, "queue_wait_ms") + prefill_ms = self._sum_profile_field(worker_profiles, "prefill_ms") + merge_ms = self._sum_profile_field(worker_profiles, "merge_ms") + decode_ms = self._sum_profile_field(worker_profiles, "decode_ms") + finalize_wait_ms = self._sum_profile_field(worker_profiles, "finalize_wait_ms") + synth_ms = self._sum_profile_field(worker_profiles, "synth_ms") + worker_total_ms = self._sum_profile_field(worker_profiles, "worker_total_ms") + decode_steps = sum(int(item.get("decode_steps", 0)) for item in worker_profiles) + semantic_len = sum(int(item.get("semantic_len", 0)) for item in worker_profiles) + request_other_ms = max( + 0.0, + request_total_ms - prepare_wall_ms - worker_total_ms - pack_ms - response_overhead_ms, + ) + return { + "backend": backend, + "backend_mode": backend, + "segment_count": len(segment_texts), + "sample_rate": int(sample_rate), + "audio_bytes": int(audio_bytes), + "request_total_ms": request_total_ms, + "prepare_ms": prepare_wall_ms, + "prepare_wall_ms": prepare_wall_ms, + "prepare_profile_total_ms": prepare_profile_total_ms, + "decode_admission_wait_ms": decode_admission_wait_ms, + "queue_wait_ms": queue_wait_ms, + "prefill_ms": prefill_ms, + "merge_ms": merge_ms, + "decode_ms": decode_ms, + "finalize_wait_ms": finalize_wait_ms, + "synth_ms": synth_ms, + "pack_ms": pack_ms, + "response_overhead_ms": response_overhead_ms, + "worker_total_ms": worker_total_ms, + "request_other_ms": request_other_ms, + "decode_steps": decode_steps, + "semantic_len": semantic_len, + "prepare_segments": list(prepare_profiles), + "worker_segments": list(worker_profiles), + "segment_trace": segment_trace, + "prepare_aggregate": self._aggregate_numeric_dicts(prepare_profile_dicts), + } + + def _build_legacy_direct_profile( + self, + *, + backend: str, + fallback_reason: str | None, + request_start: float, + finished_at: float, + sample_rate: int | None = None, + audio_bytes: int = 0, + pack_ms: float = 0.0, + chunk_count: int = 0, + stream_total_bytes: int = 0, + first_chunk_ms: float | None = None, + ) -> Dict[str, Any]: + request_total_ms = max(0.0, (finished_at - request_start) * 1000.0) + legacy_infer_ms = max(0.0, request_total_ms - pack_ms) + return { + "backend": backend, + "backend_mode": backend, + "fallback_reason": fallback_reason, + "request_total_ms": request_total_ms, + "prepare_ms": 0.0, + "queue_wait_ms": 0.0, + "prefill_ms": 0.0, + "merge_ms": 0.0, + "decode_ms": 0.0, + "finalize_wait_ms": 0.0, + "synth_ms": 0.0, + "pack_ms": pack_ms, + "worker_total_ms": legacy_infer_ms, + "request_other_ms": 0.0, + "legacy_infer_ms": legacy_infer_ms, + "sample_rate": int(sample_rate) if sample_rate is not None else None, + "audio_bytes": int(audio_bytes), + "chunk_count": int(chunk_count), + "stream_total_bytes": int(stream_total_bytes), + "first_chunk_ms": None if first_chunk_ms is None else float(first_chunk_ms), + } + + def _build_scheduler_submit_profile( + self, + *, + backend: str, + request_start: float, + response_ready_at: float, + audio_bytes: int, + sample_rate: int, + prepare_spec_build_ms: float, + prepare_wall_ms: float, + prepare_executor_queue_ms: float, + prepare_executor_run_ms: float, + prepare_profile_total_ms: float, + prepare_profile_wall_ms: float, + prepare_other_ms: float, + api_after_prepare_ms: float, + api_wait_result_ms: float, + pack_ms: float, + response_overhead_ms: float, + worker_profile: Dict[str, Any], + ) -> Dict[str, Any]: + worker_total_ms = float(worker_profile.get("worker_total_ms", 0.0)) + request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) + request_other_ms = max( + 0.0, + request_total_ms - prepare_wall_ms - api_after_prepare_ms - worker_total_ms - api_wait_result_ms - pack_ms, + ) + result = { + "backend": backend, + "backend_mode": backend, + "audio_bytes": int(audio_bytes), + "sample_rate": int(sample_rate), + "prepare_spec_build_ms": prepare_spec_build_ms, + "prepare_ms": prepare_wall_ms, + "prepare_wall_ms": prepare_wall_ms, + "prepare_executor_queue_ms": prepare_executor_queue_ms, + "prepare_executor_run_ms": prepare_executor_run_ms, + "prepare_profile_total_ms": prepare_profile_total_ms, + "prepare_profile_wall_ms": prepare_profile_wall_ms, + "prepare_other_ms": prepare_other_ms, + "api_after_prepare_ms": api_after_prepare_ms, + "api_wait_result_ms": api_wait_result_ms, + "pack_ms": pack_ms, + "response_overhead_ms": response_overhead_ms, + "request_total_ms": request_total_ms, + "request_other_ms": request_other_ms, + } + result.update({key: value for key, value in worker_profile.items()}) + return result + + @staticmethod + def _format_ms_header(value: Any) -> str: + return f"{float(value):.3f}" + + def _build_scheduler_submit_headers( + self, + *, + request_id: str, + media_type: str, + sample_rate: int, + profile: Dict[str, Any], + ) -> Dict[str, str]: + prepare_profile = dict(profile.get("prepare_profile", {})) + headers = { + "X-Request-Id": request_id, + "X-Semantic-Len": str(int(profile.get("semantic_len", 0))), + "X-Finish-Reason": str(profile.get("finish_reason", "unknown")), + "X-Queue-Wait-Ms": self._format_ms_header(profile.get("queue_wait_ms", 0.0)), + "X-Decode-Admission-Wait-Ms": self._format_ms_header(profile.get("decode_admission_wait_ms", 0.0)), + "X-Prepare-Ms": self._format_ms_header(profile.get("prepare_wall_ms", 0.0)), + "X-Prepare-Wall-Ms": self._format_ms_header(profile.get("prepare_wall_ms", 0.0)), + "X-Prepare-Spec-Build-Ms": self._format_ms_header(profile.get("prepare_spec_build_ms", 0.0)), + "X-Prepare-Executor-Queue-Ms": self._format_ms_header(profile.get("prepare_executor_queue_ms", 0.0)), + "X-Prepare-Admission-Wait-Ms": self._format_ms_header(prepare_profile.get("prepare_admission_wait_ms", 0.0)), + "X-Prepare-Executor-Run-Ms": self._format_ms_header(profile.get("prepare_executor_run_ms", 0.0)), + "X-Prepare-Profile-Total-Ms": self._format_ms_header(profile.get("prepare_profile_total_ms", 0.0)), + "X-Prepare-Profile-Wall-Ms": self._format_ms_header(profile.get("prepare_profile_wall_ms", 0.0)), + "X-Prepare-Other-Ms": self._format_ms_header(profile.get("prepare_other_ms", 0.0)), + "X-Api-After-Prepare-Ms": self._format_ms_header(profile.get("api_after_prepare_ms", 0.0)), + "X-Prefill-Ms": self._format_ms_header(profile.get("prefill_ms", 0.0)), + "X-Merge-Ms": self._format_ms_header(profile.get("merge_ms", 0.0)), + "X-Decode-Ms": self._format_ms_header(profile.get("decode_ms", 0.0)), + "X-Finalize-Wait-Ms": self._format_ms_header(profile.get("finalize_wait_ms", 0.0)), + "X-Synth-Ms": self._format_ms_header(profile.get("synth_ms", 0.0)), + "X-Worker-Residual-Ms": self._format_ms_header(profile.get("worker_residual_ms", 0.0)), + "X-Worker-Other-Ms": self._format_ms_header(profile.get("worker_other_ms", 0.0)), + "X-Pack-Ms": self._format_ms_header(profile.get("pack_ms", 0.0)), + "X-Worker-Total-Ms": self._format_ms_header(profile.get("worker_total_ms", 0.0)), + "X-Api-Wait-Result-Ms": self._format_ms_header(profile.get("api_wait_result_ms", 0.0)), + "X-Decode-Steps": str(int(profile.get("decode_steps", 0))), + "X-Sample-Rate": str(int(sample_rate)), + "X-Response-Overhead-Ms": self._format_ms_header(profile.get("response_overhead_ms", 0.0)), + "X-Request-Other-Ms": self._format_ms_header(profile.get("request_other_ms", 0.0)), + "X-Request-Total-Ms": self._format_ms_header(profile.get("request_total_ms", 0.0)), + } + headers.update( + { + "X-Prepare-Prompt-Text-Ms": self._format_ms_header(prepare_profile.get("prompt_text_features_ms", 0.0)), + "X-Prepare-Target-Text-Ms": self._format_ms_header(prepare_profile.get("text_features_ms", 0.0)), + "X-Prepare-Prompt-Text-CPU-Preprocess-Ms": self._format_ms_header(prepare_profile.get("prompt_text_cpu_preprocess_ms", 0.0)), + "X-Prepare-Target-Text-CPU-Preprocess-Ms": self._format_ms_header(prepare_profile.get("text_cpu_preprocess_ms", 0.0)), + "X-Prepare-Prompt-Text-CPU-Queue-Ms": self._format_ms_header(prepare_profile.get("prompt_text_cpu_queue_ms", 0.0)), + "X-Prepare-Target-Text-CPU-Queue-Ms": self._format_ms_header(prepare_profile.get("text_cpu_queue_ms", 0.0)), + "X-Prepare-Prompt-Text-Feature-Queue-Ms": self._format_ms_header(prepare_profile.get("prompt_text_feature_queue_ms", 0.0)), + "X-Prepare-Target-Text-Feature-Queue-Ms": self._format_ms_header(prepare_profile.get("text_feature_queue_ms", 0.0)), + "X-Prepare-Prompt-Bert-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_wait_ms", 0.0)), + "X-Prepare-Target-Bert-Wait-Ms": self._format_ms_header(prepare_profile.get("text_bert_wait_ms", 0.0)), + "X-Prepare-Prompt-Bert-Admission-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_admission_wait_ms", 0.0)), + "X-Prepare-Target-Bert-Admission-Wait-Ms": self._format_ms_header(prepare_profile.get("text_bert_admission_wait_ms", 0.0)), + "X-Prepare-Prompt-Bert-Queue-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_queue_wait_ms", 0.0)), + "X-Prepare-Target-Bert-Queue-Wait-Ms": self._format_ms_header(prepare_profile.get("text_bert_queue_wait_ms", 0.0)), + "X-Prepare-Prompt-Bert-Batch-Collect-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_batch_collect_wait_ms", 0.0)), + "X-Prepare-Target-Bert-Batch-Collect-Wait-Ms": self._format_ms_header(prepare_profile.get("text_bert_batch_collect_wait_ms", 0.0)), + "X-Prepare-Prompt-Bert-Forward-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_forward_ms", 0.0)), + "X-Prepare-Target-Bert-Forward-Ms": self._format_ms_header(prepare_profile.get("text_bert_forward_ms", 0.0)), + "X-Prepare-Prompt-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_enqueue_peak", 0.0))), + "X-Prepare-Target-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_enqueue_peak", 0.0))), + "X-Prepare-Prompt-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_collect_peak", 0.0))), + "X-Prepare-Target-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_collect_peak", 0.0))), + "X-Prepare-Prompt-Bert-High-Pressure-Peak": str(int(prepare_profile.get("prompt_text_bert_high_pressure_mode_peak", 0.0))), + "X-Prepare-Target-Bert-High-Pressure-Peak": str(int(prepare_profile.get("text_bert_high_pressure_mode_peak", 0.0))), + "X-Prepare-Prompt-Bert-Batch-Window-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_batch_window_ms", 0.0)), + "X-Prepare-Target-Bert-Batch-Window-Ms": self._format_ms_header(prepare_profile.get("text_bert_batch_window_ms", 0.0)), + "X-Prepare-Text-Pair-Wall-Ms": self._format_ms_header(prepare_profile.get("text_feature_pair_ms", 0.0)), + "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))), + "X-Prepare-Audio-Load-Ms": self._format_ms_header(prepare_profile.get("audio_load_ms", 0.0)), + "X-Prepare-Audio-Stage-Wait-Ms": self._format_ms_header(prepare_profile.get("audio_stage_wait_ms", 0.0)), + "X-Prepare-Prompt-Semantic-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_ms", 0.0)), + "X-Prepare-Prompt-Semantic-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_wait_ms", 0.0)), + "X-Prepare-Prompt-Semantic-CPU-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), + "X-Prepare-Prompt-Semantic-Forward-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_forward_ms", 0.0)), + "X-Prepare-Ref-Spec-Ms": self._format_ms_header(prepare_profile.get("ref_spec_ms", 0.0)), + "X-Prepare-Ref-Spec-Wait-Ms": self._format_ms_header(prepare_profile.get("ref_spec_wait_ms", 0.0)), + "X-Prepare-Ref-Bundle-Ms": self._format_ms_header(prepare_profile.get("ref_audio_bundle_ms", 0.0)), + "X-Prepare-Tensorize-Ms": self._format_ms_header(prepare_profile.get("tensorize_ms", 0.0)), + "X-Prepare-Inflight-On-Enter": str(int(prepare_profile.get("worker_prepare_inflight_on_enter", 0.0))), + "X-Prepare-Inflight-Peak": str(int(prepare_profile.get("worker_prepare_peak_inflight", 0.0))), + } + ) + return headers + + def _build_scheduler_debug_request_profile( + self, + *, + state: T2SRequestState, + item: T2SFinishedItem, + batch_request_count: int, + prepare_batch_wall_ms: float, + decode_batch_wall_ms: float, + batch_request_total_ms: float, + ) -> Dict[str, Any]: + prepare_profile = dict(state.prepare_profile) + prepare_wall_ms = float(prepare_profile.get("wall_total_ms", 0.0)) + return { + "backend": "scheduler_debug", + "backend_mode": "scheduler_debug", + "batch_request_count": int(batch_request_count), + "batch_prepare_wall_ms": float(prepare_batch_wall_ms), + "batch_decode_wall_ms": float(decode_batch_wall_ms), + "batch_request_total_ms": float(batch_request_total_ms), + "prepare_ms": prepare_wall_ms, + "prepare_wall_ms": prepare_wall_ms, + "prepare_profile_total_ms": float(prepare_profile.get("wall_total_ms", prepare_wall_ms)), + "prepare_profile": prepare_profile, + "decode_steps": int(item.finish_idx), + "finish_idx": int(item.finish_idx), + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_reason": item.finish_reason, + "norm_text": state.norm_text, + "norm_prompt_text": state.norm_prompt_text, + } + + @staticmethod + def _build_scheduler_debug_batch_profile( + *, + request_count: int, + max_steps: int, + prepare_batch_wall_ms: float, + decode_batch_wall_ms: float, + request_total_ms: float, + finished_items: Sequence[T2SFinishedItem], + ) -> Dict[str, Any]: + finish_reason_counts: Dict[str, int] = {} + total_semantic_len = 0 + for item in finished_items: + finish_reason_counts[item.finish_reason] = finish_reason_counts.get(item.finish_reason, 0) + 1 + total_semantic_len += int(item.semantic_tokens.shape[0]) + return { + "request_count": int(request_count), + "max_steps": int(max_steps), + "prepare_batch_wall_ms": float(prepare_batch_wall_ms), + "decode_batch_wall_ms": float(decode_batch_wall_ms), + "request_total_ms": float(request_total_ms), + "total_semantic_len": int(total_semantic_len), + "finish_reason_counts": finish_reason_counts, + } + def _normalize_lang(self, value: str | None) -> str | None: if value in [None, ""]: return value return str(value).lower() + @staticmethod + def _aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]: + totals: Dict[str, float] = {} + for item in items: + for key, value in item.items(): + if isinstance(value, (int, float)): + totals[key] = totals.get(key, 0.0) + float(value) + return totals + def _apply_default_reference(self, req: dict) -> dict: normalized = dict(req) default_ref = self.reference_registry.get_default() @@ -837,6 +1721,109 @@ class UnifiedTTSEngine: return f"text_split_method:{text_split_method} is not supported" return None + @staticmethod + def _base_request_defaults() -> Dict[str, Any]: + return { + "request_id": None, + "text": None, + "text_lang": None, + "ref_audio_path": None, + "aux_ref_audio_paths": None, + "prompt_text": "", + "prompt_lang": None, + "top_k": 15, + "top_p": 1.0, + "temperature": 1.0, + "text_split_method": "cut5", + "batch_size": 1, + "batch_threshold": 0.75, + "speed_factor": 1.0, + "split_bucket": False, + "fragment_interval": 0.3, + "seed": -1, + "media_type": "wav", + "streaming_mode": False, + "return_fragment": False, + "fixed_length_chunk": False, + "response_streaming": False, + "parallel_infer": False, + "repetition_penalty": 1.35, + "sample_steps": 32, + "super_sampling": False, + "overlap_length": 2, + "min_chunk_length": 16, + "early_stop_num": -1, + "ready_step": 0, + "timeout_sec": None, + } + + def _normalize_engine_request( + self, + payload: dict | NormalizedEngineRequest, + *, + request_id: str | None = None, + normalize_streaming: bool = False, + error_prefix: str = "request 参数非法: ", + ) -> NormalizedEngineRequest: + if isinstance(payload, NormalizedEngineRequest): + normalized_payload = payload.to_payload() + else: + normalized_payload = self._base_request_defaults() + normalized_payload.update(dict(payload)) + if request_id not in [None, ""]: + normalized_payload["request_id"] = str(request_id) + elif normalized_payload.get("request_id") in [None, ""]: + raise ValueError("request_id is required after normalization") + normalized_payload = self._apply_default_reference(normalized_payload) + if normalize_streaming: + normalized_payload = self._normalize_streaming_mode(normalized_payload) + error = self.check_params(normalized_payload) + if error is not None: + raise ValueError(f"{error_prefix}{error}") + timeout_sec = normalized_payload.get("timeout_sec") + if timeout_sec in [None, ""]: + parsed_timeout = None + else: + parsed_timeout = float(timeout_sec) + aux_ref_audio_paths = normalized_payload.get("aux_ref_audio_paths") + if aux_ref_audio_paths in [None, "", []]: + normalized_aux_ref_audio_paths = None + else: + normalized_aux_ref_audio_paths = [str(item) for item in aux_ref_audio_paths] + return NormalizedEngineRequest( + request_id=str(normalized_payload["request_id"]), + text=str(normalized_payload["text"]), + text_lang=str(normalized_payload["text_lang"]), + ref_audio_path=str(normalized_payload["ref_audio_path"]), + prompt_lang=str(normalized_payload["prompt_lang"]), + prompt_text="" if normalized_payload.get("prompt_text") is None else str(normalized_payload.get("prompt_text")), + aux_ref_audio_paths=normalized_aux_ref_audio_paths, + top_k=int(normalized_payload["top_k"]), + top_p=float(normalized_payload["top_p"]), + temperature=float(normalized_payload["temperature"]), + repetition_penalty=float(normalized_payload["repetition_penalty"]), + early_stop_num=int(normalized_payload.get("early_stop_num", -1)), + ready_step=int(normalized_payload.get("ready_step", 0)), + text_split_method=str(normalized_payload["text_split_method"]), + batch_size=int(normalized_payload["batch_size"]), + batch_threshold=float(normalized_payload["batch_threshold"]), + split_bucket=bool(normalized_payload["split_bucket"]), + speed_factor=float(normalized_payload["speed_factor"]), + fragment_interval=float(normalized_payload["fragment_interval"]), + seed=int(normalized_payload["seed"]), + media_type=str(normalized_payload["media_type"]), + streaming_mode=normalized_payload["streaming_mode"], + return_fragment=bool(normalized_payload.get("return_fragment", False)), + fixed_length_chunk=bool(normalized_payload.get("fixed_length_chunk", False)), + response_streaming=bool(normalized_payload.get("response_streaming", False)), + parallel_infer=bool(normalized_payload["parallel_infer"]), + sample_steps=int(normalized_payload["sample_steps"]), + super_sampling=bool(normalized_payload["super_sampling"]), + overlap_length=int(normalized_payload["overlap_length"]), + min_chunk_length=int(normalized_payload["min_chunk_length"]), + timeout_sec=parsed_timeout, + ) + @staticmethod def _normalize_streaming_mode(req: dict) -> dict: normalized = dict(req) @@ -867,139 +1854,468 @@ class UnifiedTTSEngine: normalized["response_streaming"] = bool(normalized["streaming_mode"] or normalized["return_fragment"] or return_fragment) return normalized - def _iter_direct_tts_bytes(self, req: dict) -> Generator[bytes, None, None]: - media_type = req["media_type"] - with self.direct_tts_lock: - tts_generator = self.tts.run(req) - first_chunk = True - current_media_type = media_type - for sr, chunk in tts_generator: - if first_chunk and media_type == "wav": - yield wave_header_chunk(sample_rate=sr) - current_media_type = "raw" - first_chunk = False - yield pack_audio(BytesIO(), chunk, sr, current_media_type).getvalue() + @staticmethod + def _is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool: + return aux_ref_audio_paths not in [None, [], ()] - def run_direct_tts(self, req: dict) -> DirectTTSExecution: - normalized = self._normalize_streaming_mode(self._apply_default_reference(req)) - error = self.check_params(normalized) - if error is not None: - raise ValueError(error) - media_type = normalized.get("media_type", "wav") - if normalized["response_streaming"]: - return DirectTTSExecution( - media_type=media_type, - streaming=True, - audio_generator=self._iter_direct_tts_bytes(normalized), + def _select_direct_backend(self, normalized: NormalizedEngineRequest) -> Tuple[str, str | None]: + if normalized.response_streaming: + if normalized.return_fragment or normalized.fixed_length_chunk: + return "legacy_direct_fragment", "fragment_streaming_mode" + return "legacy_direct_streaming", "streaming_mode" + if self._is_aux_ref_enabled(normalized.aux_ref_audio_paths): + return "legacy_direct_aux_ref", "aux_ref_audio_paths" + if normalized.super_sampling: + return "legacy_direct_super_sampling", "super_sampling" + if normalized.prompt_text in [None, ""]: + return "legacy_direct_missing_prompt", "missing_prompt_text" + return "scheduler_v1_direct", None + + def _iter_legacy_direct_tts_bytes( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> Generator[bytes, None, None]: + payload = normalized.to_payload() + media_type = normalized.media_type + request_id = normalized.request_id + request_start = time.perf_counter() + chunk_count = 0 + stream_total_bytes = 0 + first_chunk_ms: float | None = None + self._update_request_state( + request_id, + EngineStatus.ACTIVE_DECODE, + {"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason}, + ) + try: + with self.direct_tts_lock: + tts_generator = self.tts.run(payload) + first_chunk = True + current_media_type = media_type + for sr, chunk in tts_generator: + if first_chunk: + first_chunk_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0) + self._update_request_state( + request_id, + EngineStatus.STREAMING, + { + "backend": backend, + "backend_mode": backend, + "fallback_reason": fallback_reason, + "sample_rate": int(sr), + }, + ) + if first_chunk and media_type == "wav": + header = wave_header_chunk(sample_rate=sr) + chunk_count += 1 + stream_total_bytes += len(header) + yield header + current_media_type = "raw" + first_chunk = False + elif first_chunk: + first_chunk = False + packed_chunk = pack_audio(BytesIO(), chunk, sr, current_media_type).getvalue() + chunk_count += 1 + stream_total_bytes += len(packed_chunk) + yield packed_chunk + except Exception as exc: + self._fail_request_state(request_id, str(exc)) + raise + self._complete_request_state( + request_id, + dict( + self._build_legacy_direct_profile( + backend=backend, + fallback_reason=fallback_reason, + request_start=request_start, + finished_at=time.perf_counter(), + audio_bytes=stream_total_bytes, + chunk_count=chunk_count, + stream_total_bytes=stream_total_bytes, + first_chunk_ms=first_chunk_ms, + ), + streaming_completed=True, + ), + ) + + def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool: + if isinstance(req, NormalizedEngineRequest): + normalized = req + else: + normalized = self._normalize_engine_request( + req, + request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"), + normalize_streaming=True, ) - with self.direct_tts_lock: - tts_generator = self.tts.run(normalized) - sr, audio_data = next(tts_generator) + backend, _ = self._select_direct_backend(normalized) + return backend == "scheduler_v1_direct" + + def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]: + payload = normalized.to_payload() if isinstance(normalized, NormalizedEngineRequest) else normalized + return self.tts.text_preprocessor.pre_seg_text( + str(payload["text"]), + str(payload["text_lang"]), + str(payload.get("text_split_method", "cut5")), + ) + + def _build_segment_request( + self, + normalized: NormalizedEngineRequest, + *, + request_id: str, + text: str, + ) -> NormalizedEngineRequest: + payload = normalized.to_payload() + payload["request_id"] = request_id + payload["text"] = text + payload["streaming_mode"] = False + payload["return_fragment"] = False + payload["fixed_length_chunk"] = False + payload["response_streaming"] = False + return self._normalize_engine_request(payload, error_prefix="segment request 参数非法: ") + + async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution: + request_start = time.perf_counter() + request_id = normalized.request_id + media_type = normalized.media_type + segment_texts = self._segment_direct_text(normalized) + if not segment_texts: + raise ValueError("text preprocessing returned no valid segments") + self._update_request_state( + request_id, + EngineStatus.CPU_PREPARING, + {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_texts)}, + ) + segment_specs: List[SchedulerRequestSpec] = [] + for segment_index, segment_text in enumerate(segment_texts): + segment_request = self._build_segment_request( + normalized, + request_id=f"{request_id}_seg_{segment_index:03d}", + text=segment_text, + ) + segment_specs.append(self.build_scheduler_submit_spec(segment_request)) + + prepared_items = await asyncio.gather( + *[ + self.scheduler_worker.prepare_state_profiled_async(spec, time.perf_counter()) + for spec in segment_specs + ] + ) + prepare_profiles: List[Dict[str, Any]] = [] + jobs: List[SchedulerPendingJob] = [] + loop = asyncio.get_running_loop() + done_futures: List[asyncio.Future] = [] + for spec, (state, prepare_exec_started_at, prepare_exec_finished_at) in zip(segment_specs, prepared_items): + prepare_wall_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0) + prepare_profile_total_ms = float(state.prepare_profile.get("wall_total_ms", prepare_wall_ms)) + prepare_profiles.append( + { + "request_id": spec.request_id, + "prepare_wall_ms": prepare_wall_ms, + "prepare_profile_total_ms": prepare_profile_total_ms, + "prepare_profile": dict(state.prepare_profile), + } + ) + done_future = loop.create_future() + done_futures.append(done_future) + jobs.append( + await self.scheduler_worker.submit_async( + state=state, + speed_factor=float(normalized.speed_factor), + sample_steps=int(normalized.sample_steps), + media_type=media_type, + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=loop, + done_future=done_future, + engine_request_id=None, + timeout_sec=normalized.timeout_sec, + ) + ) + self._update_request_state( + request_id, + EngineStatus.READY_FOR_PREFILL, + { + "backend": "scheduler_v1_direct", + "backend_mode": "scheduler_v1_direct", + "segment_count": len(segment_specs), + "prepare_aggregate": self._aggregate_numeric_dicts( + [item["prepare_profile"] for item in prepare_profiles] + ), + }, + ) + self._update_request_state( + request_id, + EngineStatus.ACTIVE_DECODE, + {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct"}, + ) + timeout_sec = float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0) + await asyncio.wait_for(asyncio.gather(*done_futures), timeout=timeout_sec) + + sample_rate: int | None = None + audio_parts: List[np.ndarray] = [] + worker_profiles: List[Dict[str, Any]] = [] + fragment_interval = float(normalized.fragment_interval) + silence_chunk: Optional[np.ndarray] = None + for job in jobs: + if job.error is not None: + raise RuntimeError(job.error) + if job.audio_data is None or job.sample_rate is None or job.result is None: + raise RuntimeError(f"{job.request_id} finished without audio result") + if sample_rate is None: + sample_rate = int(job.sample_rate) + silence_samples = int(fragment_interval * float(sample_rate)) + if silence_samples > 0: + silence_chunk = np.zeros(silence_samples, dtype=np.int16) + elif int(job.sample_rate) != sample_rate: + raise RuntimeError("segment sample rate mismatch") + audio_parts.append(job.audio_data) + if silence_chunk is not None: + audio_parts.append(silence_chunk.copy()) + worker_profiles.append(dict(job.result)) + if sample_rate is None or not audio_parts: + raise RuntimeError("direct scheduler backend produced no audio") + self._update_request_state( + request_id, + EngineStatus.FINALIZING, + {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct"}, + ) + merged_audio = np.concatenate(audio_parts, axis=0) + pack_start = time.perf_counter() + audio_bytes = pack_audio(BytesIO(), merged_audio, sample_rate, media_type).getvalue() + pack_ms = max(0.0, (time.perf_counter() - pack_start) * 1000.0) + direct_profile = self._build_direct_scheduler_profile( + backend="scheduler_v1_direct", + request_start=request_start, + response_ready_at=time.perf_counter(), + audio_bytes=len(audio_bytes), + sample_rate=int(sample_rate), + segment_texts=segment_texts, + prepare_profiles=prepare_profiles, + worker_profiles=worker_profiles, + pack_ms=pack_ms, + response_overhead_ms=0.0, + ) + self._complete_request_state( + request_id, + dict(direct_profile, streaming_completed=False), + ) return DirectTTSExecution( media_type=media_type, streaming=False, - audio_bytes=pack_audio(BytesIO(), audio_data, sr, media_type).getvalue(), + audio_bytes=audio_bytes, + request_id=request_id, + ) + + def _run_legacy_direct_tts_blocking( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> DirectTTSExecution: + normalized_payload = normalized.to_payload() + request_id = normalized.request_id + media_type = normalized.media_type + request_start = time.perf_counter() + self._update_request_state( + request_id, + EngineStatus.ACTIVE_DECODE, + {"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason}, + ) + with self.direct_tts_lock: + tts_generator = self.tts.run(normalized_payload) + try: + sr, audio_data = next(tts_generator) + except Exception as exc: + self._fail_request_state(request_id, str(exc)) + raise + self._update_request_state( + request_id, + EngineStatus.FINALIZING, + {"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason}, + ) + pack_start = time.perf_counter() + packed_audio = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() + pack_ms = max(0.0, (time.perf_counter() - pack_start) * 1000.0) + self._complete_request_state( + request_id, + dict( + self._build_legacy_direct_profile( + backend=backend, + fallback_reason=fallback_reason, + request_start=request_start, + finished_at=time.perf_counter(), + sample_rate=int(sr), + audio_bytes=len(packed_audio), + pack_ms=pack_ms, + ), + streaming_completed=False, + ), + ) + return DirectTTSExecution( + media_type=media_type, + streaming=False, + audio_bytes=packed_audio, + request_id=request_id, + ) + + async def _run_direct_tts_via_legacy_backend( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> DirectTTSExecution: + if normalized.response_streaming: + return DirectTTSExecution( + media_type=normalized.media_type, + streaming=True, + audio_generator=self._iter_legacy_direct_tts_bytes( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ), + request_id=normalized.request_id, + ) + return await asyncio.to_thread( + self._run_legacy_direct_tts_blocking, + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + + async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution: + normalized = self._normalize_engine_request( + req, + request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"), + normalize_streaming=True, + error_prefix="", + ) + request_id = normalized.request_id + media_type = normalized.media_type + backend, fallback_reason = self._select_direct_backend(normalized) + self._register_request_state( + request_id=request_id, + api_mode="tts", + backend=backend, + media_type=media_type, + response_streaming=bool(normalized.response_streaming), + deadline_ts=( + time.perf_counter() + float(normalized.timeout_sec) + if normalized.timeout_sec is not None + else None + ), + meta=self._build_request_meta(normalized.to_payload()), + ) + self._update_request_state( + request_id, + EngineStatus.VALIDATED, + { + "request_source": "direct_tts", + "selected_backend": backend, + "fallback_reason": fallback_reason, + }, + ) + if backend == "scheduler_v1_direct": + try: + return await self._run_direct_tts_via_scheduler(normalized) + except Exception as exc: + self._fail_request_state(request_id, str(exc)) + raise + return await self._run_direct_tts_via_legacy_backend( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + + def run_direct_tts(self, req: dict) -> DirectTTSExecution: + normalized = self._normalize_engine_request( + req, + request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"), + normalize_streaming=True, + error_prefix="", + ) + request_id = normalized.request_id + media_type = normalized.media_type + backend, fallback_reason = self._select_direct_backend(normalized) + if not self._has_active_request(request_id): + self._register_request_state( + request_id=request_id, + api_mode="tts", + backend=backend, + media_type=media_type, + response_streaming=bool(normalized.response_streaming), + meta=self._build_request_meta(normalized.to_payload()), + ) + self._update_request_state( + request_id, + EngineStatus.VALIDATED, + { + "request_source": "direct_tts", + "selected_backend": backend, + "fallback_reason": fallback_reason, + }, + ) + if backend != "scheduler_v1_direct": + if normalized.response_streaming: + return DirectTTSExecution( + media_type=media_type, + streaming=True, + audio_generator=self._iter_legacy_direct_tts_bytes( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ), + request_id=request_id, + ) + return self._run_legacy_direct_tts_blocking( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + normalized_payload = normalized.to_payload() + if normalized.response_streaming: + return DirectTTSExecution( + media_type=media_type, + streaming=True, + audio_generator=self._iter_legacy_direct_tts_bytes( + normalized, + backend="legacy_direct_sync_compat", + fallback_reason="sync_direct_compat", + ), + request_id=request_id, + ) + return self._run_legacy_direct_tts_blocking( + normalized, + backend="legacy_direct_sync_compat", + fallback_reason="sync_direct_compat", ) def build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]: specs: List[SchedulerRequestSpec] = [] for index, payload in enumerate(request_items): - req = self._apply_default_reference( - { - "text": payload["text"], - "text_lang": self._normalize_lang(payload["text_lang"]), - "ref_audio_path": payload["ref_audio_path"], - "aux_ref_audio_paths": None, - "prompt_text": payload["prompt_text"], - "prompt_lang": self._normalize_lang(payload["prompt_lang"]), - "top_k": payload["top_k"], - "top_p": payload["top_p"], - "temperature": payload["temperature"], - "text_split_method": "cut5", - "batch_size": 1, - "batch_threshold": 0.75, - "speed_factor": 1.0, - "split_bucket": False, - "fragment_interval": 0.3, - "seed": -1, - "media_type": "wav", - "streaming_mode": False, - "parallel_infer": False, - "repetition_penalty": payload["repetition_penalty"], - "sample_steps": 32, - "super_sampling": False, - "overlap_length": 2, - "min_chunk_length": 16, - } - ) - error = self.check_params(req) - if error is not None: - raise ValueError(f"request[{index}] 参数非法: {error}") - specs.append( - SchedulerRequestSpec( - request_id=payload.get("request_id") or f"req_{index:03d}", - ref_audio_path=Path(req["ref_audio_path"]), - prompt_text=payload["prompt_text"], - prompt_lang=req["prompt_lang"], - text=payload["text"], - text_lang=req["text_lang"], - top_k=int(payload["top_k"]), - top_p=float(payload["top_p"]), - temperature=float(payload["temperature"]), - repetition_penalty=float(payload["repetition_penalty"]), - early_stop_num=int(payload.get("early_stop_num", -1)), - ready_step=int(payload.get("ready_step", 0)), - ) + normalized = self._normalize_engine_request( + payload, + request_id=str(payload.get("request_id") or f"req_{index:03d}"), + error_prefix=f"request[{index}] 参数非法: ", ) + specs.append(normalized.to_scheduler_spec()) return specs - def build_scheduler_submit_spec(self, payload: dict) -> SchedulerRequestSpec: - request_id = payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}" - req = self._apply_default_reference( - { - "text": payload["text"], - "text_lang": self._normalize_lang(payload["text_lang"]), - "ref_audio_path": payload["ref_audio_path"], - "aux_ref_audio_paths": None, - "prompt_text": payload["prompt_text"], - "prompt_lang": self._normalize_lang(payload["prompt_lang"]), - "top_k": payload["top_k"], - "top_p": payload["top_p"], - "temperature": payload["temperature"], - "text_split_method": "cut5", - "batch_size": 1, - "batch_threshold": 0.75, - "speed_factor": float(payload["speed_factor"]), - "split_bucket": False, - "fragment_interval": 0.3, - "seed": -1, - "media_type": payload["media_type"], - "streaming_mode": False, - "parallel_infer": False, - "repetition_penalty": payload["repetition_penalty"], - "sample_steps": int(payload["sample_steps"]), - "super_sampling": False, - "overlap_length": 2, - "min_chunk_length": 16, - } - ) - error = self.check_params(req) - if error is not None: - raise ValueError(f"request 参数非法: {error}") - return SchedulerRequestSpec( - request_id=request_id, - ref_audio_path=Path(req["ref_audio_path"]), - prompt_text=payload["prompt_text"], - prompt_lang=req["prompt_lang"], - text=payload["text"], - text_lang=req["text_lang"], - top_k=int(payload["top_k"]), - top_p=float(payload["top_p"]), - temperature=float(payload["temperature"]), - repetition_penalty=float(payload["repetition_penalty"]), - early_stop_num=int(payload.get("early_stop_num", -1)), - ready_step=0, + def build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec: + normalized = self._normalize_engine_request( + payload, + request_id=( + payload.request_id + if isinstance(payload, NormalizedEngineRequest) + else str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}") + ), ) + return normalized.to_scheduler_spec() @staticmethod def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: @@ -1029,30 +2345,138 @@ class UnifiedTTSEngine: ] async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution: + request_start = time.perf_counter() set_scheduler_seed(seed) specs = self.build_scheduler_request_specs(request_items) - states = await self.scheduler_worker.prepare_states_batch_async(specs) - finished = run_scheduler_continuous(self.tts.t2s_model.model, states, max_steps=int(max_steps)) + request_ids = [spec.request_id for spec in specs] + for spec in specs: + self._register_request_state( + request_id=spec.request_id, + api_mode="scheduler_debug", + backend="scheduler_debug", + media_type="wav", + response_streaming=False, + meta={ + "text_len": len(spec.text), + "prompt_text_len": len(spec.prompt_text), + "text_lang": spec.text_lang, + "prompt_lang": spec.prompt_lang, + "ref_audio_path": str(spec.ref_audio_path), + "ready_step": int(spec.ready_step), + }, + ) + self._update_request_state(spec.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_debug"}) + self._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, None) + prepare_started_at = time.perf_counter() + try: + states = await self.scheduler_worker.prepare_states_batch_async(specs) + except Exception as exc: + for request_id in request_ids: + self._fail_request_state(request_id, str(exc)) + raise + prepare_finished_at = time.perf_counter() + prepare_batch_wall_ms = max(0.0, (prepare_finished_at - prepare_started_at) * 1000.0) + for state in states: + self._update_request_state( + state.request_id, + EngineStatus.ACTIVE_DECODE, + { + "prepare_profile": dict(state.prepare_profile), + "norm_text": state.norm_text, + "norm_prompt_text": state.norm_prompt_text, + }, + ) + decode_started_at = time.perf_counter() + try: + finished = run_scheduler_continuous(self.tts.t2s_model.model, states, max_steps=int(max_steps)) + except Exception as exc: + for request_id in request_ids: + self._fail_request_state(request_id, str(exc)) + raise + decode_finished_at = time.perf_counter() + decode_batch_wall_ms = max(0.0, (decode_finished_at - decode_started_at) * 1000.0) + request_total_ms = max(0.0, (decode_finished_at - request_start) * 1000.0) + finished_map = {item.request_id: item for item in finished} + request_profiles: List[Dict[str, Any]] = [] + for state in states: + item = finished_map.get(state.request_id) + if item is None: + self._fail_request_state(state.request_id, "scheduler_debug finished without result") + continue + request_profile = self._build_scheduler_debug_request_profile( + state=state, + item=item, + batch_request_count=len(states), + prepare_batch_wall_ms=prepare_batch_wall_ms, + decode_batch_wall_ms=decode_batch_wall_ms, + batch_request_total_ms=request_total_ms, + ) + request_profiles.append( + { + "request_id": state.request_id, + "profile": dict(request_profile), + } + ) + self._complete_request_state( + state.request_id, + dict(request_profile), + ) return SchedulerDebugExecution( payload={ "message": "success", "request_count": len(states), "max_steps": int(max_steps), + "batch_profile": self._build_scheduler_debug_batch_profile( + request_count=len(states), + max_steps=int(max_steps), + prepare_batch_wall_ms=prepare_batch_wall_ms, + decode_batch_wall_ms=decode_batch_wall_ms, + request_total_ms=request_total_ms, + finished_items=finished, + ), "requests": self.summarize_scheduler_states(states), "finished": self.summarize_scheduler_finished(finished), + "request_profiles": request_profiles, + "request_traces": self._collect_request_summaries(request_ids), } ) async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution: request_start = time.perf_counter() prepare_start = request_start - spec = self.build_scheduler_submit_spec(payload) + normalized = self._normalize_engine_request( + payload, + request_id=str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}"), + ) + spec = self.build_scheduler_submit_spec(normalized) + deadline_ts = None + timeout_sec = normalized.timeout_sec + if timeout_sec is not None: + try: + deadline_ts = request_start + float(timeout_sec) + except Exception: + deadline_ts = None + self._register_request_state( + request_id=spec.request_id, + api_mode="scheduler_submit", + backend="scheduler_v1", + media_type=normalized.media_type, + response_streaming=False, + deadline_ts=deadline_ts, + meta=self._build_request_meta(normalized.to_payload()), + ) + self._update_request_state(spec.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_submit"}) spec_ready_at = time.perf_counter() prepare_spec_build_ms = max(0.0, (spec_ready_at - prepare_start) * 1000.0) - state, prepare_exec_started_at, prepare_exec_finished_at = await self.scheduler_worker.prepare_state_profiled_async( - spec, - spec_ready_at, - ) + self._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, {"prepare_spec_build_ms": prepare_spec_build_ms}) + try: + state, prepare_exec_started_at, prepare_exec_finished_at = await self.scheduler_worker.prepare_state_profiled_async( + spec, + spec_ready_at, + ) + except Exception as exc: + self._fail_request_state(spec.request_id, str(exc)) + raise prepare_wall_ms = max(0.0, (prepare_exec_finished_at - spec_ready_at) * 1000.0) prepare_executor_queue_ms = max(0.0, (prepare_exec_started_at - spec_ready_at) * 1000.0) prepare_executor_run_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0) @@ -1060,25 +2484,41 @@ class UnifiedTTSEngine: prepare_profile_total_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms)) prepare_profile_wall_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms)) prepare_other_ms = max(0.0, prepare_wall_ms - prepare_spec_build_ms - prepare_executor_queue_ms - prepare_executor_run_ms) + self._update_request_state( + spec.request_id, + EngineStatus.READY_FOR_PREFILL, + { + "prepare_wall_ms": prepare_wall_ms, + "prepare_profile_total_ms": prepare_profile_total_ms, + "prepare_profile": prepare_profile, + }, + ) api_after_prepare_start = time.perf_counter() loop = asyncio.get_running_loop() done_future = loop.create_future() - job = self.scheduler_worker.submit( + job = await self.scheduler_worker.submit_async( state=state, - speed_factor=float(payload["speed_factor"]), - sample_steps=int(payload["sample_steps"]), - media_type=str(payload["media_type"]), + speed_factor=float(normalized.speed_factor), + sample_steps=int(normalized.sample_steps), + media_type=normalized.media_type, prepare_wall_ms=prepare_wall_ms, prepare_profile_total_ms=prepare_profile_total_ms, done_loop=loop, done_future=done_future, + engine_request_id=spec.request_id, + timeout_sec=normalized.timeout_sec, ) api_after_prepare_ms = max(0.0, (time.perf_counter() - api_after_prepare_start) * 1000.0) - await asyncio.wait_for(done_future, timeout=float(payload.get("timeout_sec", 30.0))) + try: + await asyncio.wait_for(done_future, timeout=float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0)) + except Exception as exc: + self._fail_request_state(spec.request_id, str(exc)) + raise wait_return_at = time.perf_counter() if job.error is not None: raise RuntimeError(job.error) if job.audio_data is None or job.sample_rate is None or job.result is None: + self._fail_request_state(spec.request_id, f"{job.request_id} finished without audio result") raise RuntimeError(f"{job.request_id} finished without audio result") pack_start = time.perf_counter() audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue() @@ -1087,95 +2527,37 @@ class UnifiedTTSEngine: api_wait_result_ms = 0.0 if job.result_ready_time is not None: api_wait_result_ms = max(0.0, (wait_return_at - job.result_ready_time) * 1000.0) - worker_total_ms = float(job.result["worker_total_ms"]) if job.result is not None else 0.0 - headers = { - "X-Request-Id": job.request_id, - "X-Semantic-Len": str(job.result["semantic_len"]) if job.result is not None else "0", - "X-Finish-Reason": job.result["finish_reason"] if job.result is not None else "unknown", - "X-Queue-Wait-Ms": f"{float(job.result['queue_wait_ms']):.3f}" if job.result is not None else "0.000", - "X-Prepare-Ms": f"{prepare_wall_ms:.3f}", - "X-Prepare-Wall-Ms": f"{prepare_wall_ms:.3f}", - "X-Prepare-Spec-Build-Ms": f"{prepare_spec_build_ms:.3f}", - "X-Prepare-Executor-Queue-Ms": f"{prepare_executor_queue_ms:.3f}", - "X-Prepare-Admission-Wait-Ms": ( - f"{float(job.result['prepare_profile'].get('prepare_admission_wait_ms', 0.0)):.3f}" - if job.result is not None - else "0.000" - ), - "X-Prepare-Executor-Run-Ms": f"{prepare_executor_run_ms:.3f}", - "X-Prepare-Profile-Total-Ms": f"{prepare_profile_total_ms:.3f}", - "X-Prepare-Profile-Wall-Ms": f"{prepare_profile_wall_ms:.3f}", - "X-Prepare-Other-Ms": f"{prepare_other_ms:.3f}", - "X-Api-After-Prepare-Ms": f"{api_after_prepare_ms:.3f}", - "X-Prefill-Ms": f"{float(job.result['prefill_ms']):.3f}" if job.result is not None else "0.000", - "X-Merge-Ms": f"{float(job.result['merge_ms']):.3f}" if job.result is not None else "0.000", - "X-Decode-Ms": f"{float(job.result['decode_ms']):.3f}" if job.result is not None else "0.000", - "X-Finalize-Wait-Ms": f"{float(job.result['finalize_wait_ms']):.3f}" if job.result is not None else "0.000", - "X-Synth-Ms": f"{float(job.result['synth_ms']):.3f}" if job.result is not None else "0.000", - "X-Worker-Residual-Ms": f"{float(job.result['worker_residual_ms']):.3f}" if job.result is not None else "0.000", - "X-Worker-Other-Ms": f"{float(job.result['worker_other_ms']):.3f}" if job.result is not None else "0.000", - "X-Pack-Ms": f"{pack_ms:.3f}", - "X-Worker-Total-Ms": f"{float(job.result['worker_total_ms']):.3f}" if job.result is not None else "0.000", - "X-Api-Wait-Result-Ms": f"{api_wait_result_ms:.3f}", - "X-Decode-Steps": str(int(job.result["decode_steps"])) if job.result is not None else "0", - "X-Sample-Rate": str(int(job.sample_rate)), - } - prepare_profile = job.result.get("prepare_profile", {}) if job.result is not None else {} - if job.result is not None: - headers.update( - { - "X-Prepare-Prompt-Text-Ms": f"{float(prepare_profile.get('prompt_text_features_ms', 0.0)):.3f}", - "X-Prepare-Target-Text-Ms": f"{float(prepare_profile.get('text_features_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Text-CPU-Preprocess-Ms": f"{float(prepare_profile.get('prompt_text_cpu_preprocess_ms', 0.0)):.3f}", - "X-Prepare-Target-Text-CPU-Preprocess-Ms": f"{float(prepare_profile.get('text_cpu_preprocess_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Text-CPU-Queue-Ms": f"{float(prepare_profile.get('prompt_text_cpu_queue_ms', 0.0)):.3f}", - "X-Prepare-Target-Text-CPU-Queue-Ms": f"{float(prepare_profile.get('text_cpu_queue_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Text-Feature-Queue-Ms": f"{float(prepare_profile.get('prompt_text_feature_queue_ms', 0.0)):.3f}", - "X-Prepare-Target-Text-Feature-Queue-Ms": f"{float(prepare_profile.get('text_feature_queue_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Bert-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_wait_ms', 0.0)):.3f}", - "X-Prepare-Target-Bert-Wait-Ms": f"{float(prepare_profile.get('text_bert_wait_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Bert-Admission-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_admission_wait_ms', 0.0)):.3f}", - "X-Prepare-Target-Bert-Admission-Wait-Ms": f"{float(prepare_profile.get('text_bert_admission_wait_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Bert-Queue-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_queue_wait_ms', 0.0)):.3f}", - "X-Prepare-Target-Bert-Queue-Wait-Ms": f"{float(prepare_profile.get('text_bert_queue_wait_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Bert-Batch-Collect-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_batch_collect_wait_ms', 0.0)):.3f}", - "X-Prepare-Target-Bert-Batch-Collect-Wait-Ms": f"{float(prepare_profile.get('text_bert_batch_collect_wait_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Bert-Forward-Ms": f"{float(prepare_profile.get('prompt_text_bert_forward_ms', 0.0)):.3f}", - "X-Prepare-Target-Bert-Forward-Ms": f"{float(prepare_profile.get('text_bert_forward_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_enqueue_peak", 0.0))), - "X-Prepare-Target-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_enqueue_peak", 0.0))), - "X-Prepare-Prompt-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_collect_peak", 0.0))), - "X-Prepare-Target-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_collect_peak", 0.0))), - "X-Prepare-Prompt-Bert-High-Pressure-Peak": str(int(prepare_profile.get("prompt_text_bert_high_pressure_mode_peak", 0.0))), - "X-Prepare-Target-Bert-High-Pressure-Peak": str(int(prepare_profile.get("text_bert_high_pressure_mode_peak", 0.0))), - "X-Prepare-Prompt-Bert-Batch-Window-Ms": f"{float(prepare_profile.get('prompt_text_bert_batch_window_ms', 0.0)):.3f}", - "X-Prepare-Target-Bert-Batch-Window-Ms": f"{float(prepare_profile.get('text_bert_batch_window_ms', 0.0)):.3f}", - "X-Prepare-Text-Pair-Wall-Ms": f"{float(prepare_profile.get('text_feature_pair_ms', 0.0)):.3f}", - "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get('text_cpu_parallel_workers', 0.0))), - "X-Prepare-Audio-Load-Ms": f"{float(prepare_profile.get('audio_load_ms', 0.0)):.3f}", - "X-Prepare-Audio-Stage-Wait-Ms": f"{float(prepare_profile.get('audio_stage_wait_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Semantic-Ms": f"{float(prepare_profile.get('prompt_semantic_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Semantic-Wait-Ms": f"{float(prepare_profile.get('prompt_semantic_wait_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Semantic-CPU-Ms": f"{float(prepare_profile.get('prompt_semantic_cpu_prepare_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Semantic-Forward-Ms": f"{float(prepare_profile.get('prompt_semantic_forward_ms', 0.0)):.3f}", - "X-Prepare-Ref-Spec-Ms": f"{float(prepare_profile.get('ref_spec_ms', 0.0)):.3f}", - "X-Prepare-Ref-Spec-Wait-Ms": f"{float(prepare_profile.get('ref_spec_wait_ms', 0.0)):.3f}", - "X-Prepare-Ref-Bundle-Ms": f"{float(prepare_profile.get('ref_audio_bundle_ms', 0.0)):.3f}", - "X-Prepare-Tensorize-Ms": f"{float(prepare_profile.get('tensorize_ms', 0.0)):.3f}", - "X-Prepare-Inflight-On-Enter": str(int(prepare_profile.get('worker_prepare_inflight_on_enter', 0.0))), - "X-Prepare-Inflight-Peak": str(int(prepare_profile.get('worker_prepare_peak_inflight', 0.0))), - } - ) response_ready_at = time.perf_counter() response_overhead_ms = max(0.0, (response_ready_at - pack_end) * 1000.0) - request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) - request_other_ms = max( - 0.0, - request_total_ms - prepare_wall_ms - api_after_prepare_ms - worker_total_ms - api_wait_result_ms - pack_ms, + submit_profile = self._build_scheduler_submit_profile( + backend="scheduler_v1", + request_start=request_start, + response_ready_at=response_ready_at, + audio_bytes=len(audio_data), + sample_rate=int(job.sample_rate), + prepare_spec_build_ms=prepare_spec_build_ms, + prepare_wall_ms=prepare_wall_ms, + prepare_executor_queue_ms=prepare_executor_queue_ms, + prepare_executor_run_ms=prepare_executor_run_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + prepare_profile_wall_ms=prepare_profile_wall_ms, + prepare_other_ms=prepare_other_ms, + api_after_prepare_ms=api_after_prepare_ms, + api_wait_result_ms=api_wait_result_ms, + pack_ms=pack_ms, + response_overhead_ms=response_overhead_ms, + worker_profile=dict(job.result or {}), + ) + headers = self._build_scheduler_submit_headers( + request_id=job.request_id, + media_type=job.media_type, + sample_rate=int(job.sample_rate), + profile=submit_profile, + ) + self._merge_request_state_profile( + spec.request_id, + dict(submit_profile, response_headers_emitted=True), ) - headers["X-Response-Overhead-Ms"] = f"{response_overhead_ms:.3f}" - headers["X-Request-Other-Ms"] = f"{request_other_ms:.3f}" - headers["X-Request-Total-Ms"] = f"{request_total_ms:.3f}" return SchedulerSubmitExecution(audio_bytes=audio_data, media_type=f"audio/{job.media_type}", headers=headers) def get_scheduler_state(self) -> dict: @@ -1185,6 +2567,7 @@ class UnifiedTTSEngine: model_state = self.model_registry.snapshot() default_ref = self.reference_registry.get_default() scheduler_state = self.get_scheduler_state() + request_registry = self._snapshot_request_registry() return { "message": "success", "default_reference": { @@ -1200,6 +2583,8 @@ class UnifiedTTSEngine: "updated_at": model_state.updated_at, }, "worker_state": scheduler_state, + "request_registry": request_registry, + "stage_summary": self._build_stage_summary(request_registry, scheduler_state), } def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None: diff --git a/api_v2.py b/api_v2.py index 21be1a10..35b70c8e 100644 --- a/api_v2.py +++ b/api_v2.py @@ -221,7 +221,7 @@ async def tts_handle(req: dict): """ try: - result = tts_engine.run_direct_tts(req) + result = await tts_engine.run_direct_tts_async(req) if result.streaming: return StreamingResponse(result.audio_generator, media_type=f"audio/{result.media_type}") return Response(result.audio_bytes, media_type=f"audio/{result.media_type}") diff --git a/api_v3.py b/api_v3.py index 5c97995f..1a6457ec 100644 --- a/api_v3.py +++ b/api_v3.py @@ -290,7 +290,7 @@ async def tts_handle(req: dict): """ try: - result = tts_engine.run_direct_tts(req) + result = await tts_engine.run_direct_tts_async(req) if result.streaming: return StreamingResponse(result.audio_generator, media_type=f"audio/{result.media_type}") return Response(result.audio_bytes, media_type=f"audio/{result.media_type}") From 06d6b67f73c8dc980ab7ce91f45eed5d1e2cb470 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Wed, 11 Mar 2026 05:29:30 +0800 Subject: [PATCH 12/24] Add PreparedCpuStage data class and refactor prepare_cpu_stage_profiled_async method in PrepareCoordinator for improved CPU profiling. Introduce prepare_gpu_stage_profiled_async method to streamline GPU stage preparation using the new data class, enhancing overall performance and maintainability. --- .../TTS_infer_pack/prepare_coordinator.py | 100 +- GPT_SoVITS/TTS_infer_pack/unified_engine.py | 3574 +++++++++++++---- 2 files changed, 2955 insertions(+), 719 deletions(-) diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py index 1fdf95c5..306b1b45 100644 --- a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py +++ b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py @@ -33,6 +33,20 @@ class ProfiledResult: return max(0.0, (self.finished_at - self.started_at) * 1000.0) +@dataclass +class PreparedCpuStage: + spec: SchedulerRequestSpec + prepare_submit_at: float + prepare_start: float + prompt_text: str + text: str + prepare_admission_wait_ms: float + current_inflight: int + peak_inflight: int + prompt_cpu_profiled: ProfiledResult + target_cpu_profiled: ProfiledResult + + class PrepareCoordinator: def __init__(self, tts: Any): self.tts = tts @@ -216,11 +230,16 @@ class PrepareCoordinator: async def _run_ref_audio_stage(self, ref_audio_path: str) -> ProfiledResult: return await self._run_on_executor(self.ref_audio_executor, self.tts.extract_ref_audio_bundle, ref_audio_path) - async def prepare_state_profiled_async( + def _release_split_stage_slot(self) -> None: + self._mark_leave() + if self._inflight_semaphore is not None: + self._inflight_semaphore.release() + + async def prepare_cpu_stage_profiled_async( self, spec: SchedulerRequestSpec, prepare_submit_at: float, - ) -> tuple[T2SRequestState, float, float]: + ) -> PreparedCpuStage: admission_start = time.perf_counter() if self._inflight_semaphore is not None: await self._inflight_semaphore.acquire() @@ -230,17 +249,38 @@ class PrepareCoordinator: prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) text = spec.text.strip("\n") try: - text_pair_start = time.perf_counter() prompt_cpu_task = asyncio.create_task(self._run_text_cpu_stage(prompt_text, spec.prompt_lang)) target_cpu_task = asyncio.create_task(self._run_text_cpu_stage(text, spec.text_lang)) - ref_audio_task = asyncio.create_task(self._run_ref_audio_stage(str(spec.ref_audio_path))) prompt_cpu_profiled, target_cpu_profiled = await asyncio.gather(prompt_cpu_task, target_cpu_task) + return PreparedCpuStage( + spec=spec, + prepare_submit_at=float(prepare_submit_at), + prepare_start=float(prepare_start), + prompt_text=prompt_text, + text=text, + prepare_admission_wait_ms=float(prepare_admission_wait_ms), + current_inflight=int(current_inflight), + peak_inflight=int(peak_inflight), + prompt_cpu_profiled=prompt_cpu_profiled, + target_cpu_profiled=target_cpu_profiled, + ) + except Exception: + self._release_split_stage_slot() + raise + + async def prepare_gpu_stage_profiled_async( + self, + cpu_stage: PreparedCpuStage, + ) -> tuple[T2SRequestState, float, float]: + try: + text_pair_start = time.perf_counter() + ref_audio_task = asyncio.create_task(self._run_ref_audio_stage(str(cpu_stage.spec.ref_audio_path))) text_feature_pair_task = asyncio.create_task( self._run_text_feature_pair_stage( - prompt_cpu_profiled.result, - target_cpu_profiled.result, - prompt_cpu_profiled.run_ms, - target_cpu_profiled.run_ms, + cpu_stage.prompt_cpu_profiled.result, + cpu_stage.target_cpu_profiled.result, + cpu_stage.prompt_cpu_profiled.run_ms, + cpu_stage.target_cpu_profiled.run_ms, ) ) (prompt_feature_profiled, target_feature_profiled), ref_audio_profiled = await asyncio.gather( @@ -250,18 +290,18 @@ class PrepareCoordinator: text_pair_end = time.perf_counter() state = build_request_state_from_parts( tts=self.tts, - spec=spec, - prompt_text=prompt_text, - text=text, + spec=cpu_stage.spec, + prompt_text=cpu_stage.prompt_text, + text=cpu_stage.text, prompt_result=prompt_feature_profiled.result, target_result=target_feature_profiled.result, ref_audio_bundle=ref_audio_profiled.result, - prepare_start=prepare_start, - prepare_sync_start=prepare_start, + prepare_start=cpu_stage.prepare_start, + prepare_sync_start=cpu_stage.prepare_start, profile_overrides={ - "executor_queue_ms": max(0.0, (prepare_start - prepare_submit_at) * 1000.0), - "prepare_admission_wait_ms": prepare_admission_wait_ms, - "executor_run_wall_ms": max(0.0, (time.perf_counter() - prepare_start) * 1000.0), + "executor_queue_ms": max(0.0, (cpu_stage.prepare_start - cpu_stage.prepare_submit_at) * 1000.0), + "prepare_admission_wait_ms": cpu_stage.prepare_admission_wait_ms, + "executor_run_wall_ms": max(0.0, (time.perf_counter() - cpu_stage.prepare_start) * 1000.0), "text_feature_pair_ms": max(0.0, (text_pair_end - text_pair_start) * 1000.0), "prompt_text_parallel_future_wait_ms": 0.0, "prompt_text_parallel_future_executor_queue_ms": 0.0, @@ -269,26 +309,32 @@ class PrepareCoordinator: "prompt_text_parallel_future_finish_after_submit_ms": 0.0, "prompt_text_parallel_future_queue_tail_after_target_ms": 0.0, "prompt_text_parallel_future_run_tail_after_target_ms": 0.0, - "prompt_text_cpu_queue_ms": prompt_cpu_profiled.queue_ms, - "prompt_text_cpu_run_ms": prompt_cpu_profiled.run_ms, + "prompt_text_cpu_queue_ms": cpu_stage.prompt_cpu_profiled.queue_ms, + "prompt_text_cpu_run_ms": cpu_stage.prompt_cpu_profiled.run_ms, "prompt_text_feature_queue_ms": prompt_feature_profiled.queue_ms, "prompt_text_feature_run_ms": prompt_feature_profiled.run_ms, - "text_cpu_queue_ms": target_cpu_profiled.queue_ms, - "text_cpu_run_ms": target_cpu_profiled.run_ms, + "text_cpu_queue_ms": cpu_stage.target_cpu_profiled.queue_ms, + "text_cpu_run_ms": cpu_stage.target_cpu_profiled.run_ms, "text_feature_queue_ms": target_feature_profiled.queue_ms, "text_feature_run_ms": target_feature_profiled.run_ms, "ref_audio_task_queue_ms": ref_audio_profiled.queue_ms, "ref_audio_task_run_ms": ref_audio_profiled.run_ms, - "worker_prepare_inflight_on_enter": float(current_inflight), - "worker_prepare_peak_inflight": float(peak_inflight), + "worker_prepare_inflight_on_enter": float(cpu_stage.current_inflight), + "worker_prepare_peak_inflight": float(cpu_stage.peak_inflight), }, ) prepare_exec_finished_at = time.perf_counter() state.prepare_profile["executor_run_wall_ms"] = max( - 0.0, (prepare_exec_finished_at - prepare_start) * 1000.0 + 0.0, (prepare_exec_finished_at - cpu_stage.prepare_start) * 1000.0 ) - return state, prepare_start, prepare_exec_finished_at + return state, cpu_stage.prepare_start, prepare_exec_finished_at finally: - self._mark_leave() - if self._inflight_semaphore is not None: - self._inflight_semaphore.release() + self._release_split_stage_slot() + + async def prepare_state_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> tuple[T2SRequestState, float, float]: + cpu_stage = await self.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + return await self.prepare_gpu_stage_profiled_async(cpu_stage) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine.py b/GPT_SoVITS/TTS_infer_pack/unified_engine.py index aed7b146..9b56199a 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine.py @@ -20,7 +20,7 @@ import soundfile as sf import torch from GPT_SoVITS.TTS_infer_pack.TTS import TTS -from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator +from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator, PreparedCpuStage from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( SchedulerRequestSpec, T2SActiveBatch, @@ -235,6 +235,46 @@ class SchedulerSubmitExecution: headers: Dict[str, str] +@dataclass +class EnginePolicyConfig: + enabled: bool = True + poll_wait_ms: float = 5.0 + decode_backlog_soft_max: int = 0 + finalize_pending_soft_max: int = 0 + prepare_inflight_soft_max: int = 0 + active_decode_soft_max: int = 0 + ready_for_prefill_soft_max: int = 0 + active_request_soft_max: int = 0 + + def to_dict(self) -> Dict[str, Any]: + return { + "enabled": bool(self.enabled), + "poll_wait_ms": float(self.poll_wait_ms), + "decode_backlog_soft_max": int(self.decode_backlog_soft_max), + "finalize_pending_soft_max": int(self.finalize_pending_soft_max), + "prepare_inflight_soft_max": int(self.prepare_inflight_soft_max), + "active_decode_soft_max": int(self.active_decode_soft_max), + "ready_for_prefill_soft_max": int(self.ready_for_prefill_soft_max), + "active_request_soft_max": int(self.active_request_soft_max), + } + + +@dataclass +class EngineArbiterConfig: + poll_wait_ms: float = 5.0 + decode_burst: int = 4 + prepare_aging_ms: float = 10.0 + finalize_aging_ms: float = 10.0 + + def to_dict(self) -> Dict[str, Any]: + return { + "poll_wait_ms": float(self.poll_wait_ms), + "decode_burst": int(self.decode_burst), + "prepare_aging_ms": float(self.prepare_aging_ms), + "finalize_aging_ms": float(self.finalize_aging_ms), + } + + class EngineStatus: NEW = "NEW" QUEUED = "QUEUED" @@ -286,6 +326,146 @@ class EngineRequestState: } +class EngineRequestRegistry: + def __init__(self, recent_limit: int) -> None: + self.lock = threading.Lock() + self.active_requests: Dict[str, EngineRequestState] = {} + self.recent_requests: Deque[EngineRequestState] = deque() + self.recent_limit = max(1, int(recent_limit)) + + def register( + self, + *, + request_id: str, + api_mode: str, + backend: str, + media_type: str, + response_streaming: bool, + deadline_ts: float | None = None, + meta: Optional[Dict[str, Any]] = None, + ) -> EngineRequestState: + now = time.perf_counter() + state = EngineRequestState( + request_id=request_id, + api_mode=api_mode, + backend=backend, + media_type=media_type, + response_streaming=bool(response_streaming), + submit_ts=now, + deadline_ts=deadline_ts, + updated_ts=now, + meta=dict(meta or {}), + lifecycle_timestamps={EngineStatus.NEW: now}, + ) + with self.lock: + self.active_requests[request_id] = state + return state + + def _move_to_recent_locked(self, state: EngineRequestState) -> None: + self.recent_requests.appendleft(state) + while len(self.recent_requests) > self.recent_limit: + self.recent_requests.pop() + + @staticmethod + def _apply_state_extra(state: EngineRequestState, extra: Optional[Dict[str, Any]]) -> None: + if not extra: + return + payload = dict(extra) + backend = payload.pop("backend", None) + if backend is not None: + state.backend = str(backend) + finish_reason = payload.pop("finish_reason", None) + if finish_reason is not None: + state.finish_reason = str(finish_reason) + error = payload.pop("error", None) + if error is not None: + state.error = str(error) + state.profile.update(payload) + + def update(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None: + now = time.perf_counter() + with self.lock: + state = self.active_requests.get(request_id) + if state is None: + return + state.status = str(status) + state.updated_ts = now + state.lifecycle_timestamps[str(status)] = now + self._apply_state_extra(state, extra) + + def merge_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + if not extra: + return + now = time.perf_counter() + with self.lock: + state = self.active_requests.get(request_id) + if state is None: + for recent_state in self.recent_requests: + if recent_state.request_id == request_id: + state = recent_state + break + if state is None: + return + state.updated_ts = now + self._apply_state_extra(state, extra) + + def complete(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + now = time.perf_counter() + with self.lock: + state = self.active_requests.pop(request_id, None) + if state is None: + return + state.status = EngineStatus.COMPLETED + state.updated_ts = now + state.lifecycle_timestamps[EngineStatus.COMPLETED] = now + self._apply_state_extra(state, extra) + self._move_to_recent_locked(state) + + def fail(self, request_id: str, error: str) -> None: + now = time.perf_counter() + with self.lock: + state = self.active_requests.pop(request_id, None) + if state is None: + return + state.status = EngineStatus.FAILED + state.updated_ts = now + state.error = str(error) + state.lifecycle_timestamps[EngineStatus.FAILED] = now + self._move_to_recent_locked(state) + + def snapshot(self) -> Dict[str, Any]: + with self.lock: + active = [state.to_summary() for state in self.active_requests.values()] + recent = [state.to_summary() for state in list(self.recent_requests)] + recent_limit = self.recent_limit + active.sort(key=lambda item: item["submit_ts"]) + return { + "active_count": len(active), + "recent_count": len(recent), + "recent_limit": recent_limit, + "active_requests": active, + "recent_requests": recent, + } + + def collect_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]: + requested = set(request_ids) + results: List[Dict[str, Any]] = [] + with self.lock: + for state in self.active_requests.values(): + if state.request_id in requested: + results.append(state.to_summary()) + existing_ids = {item["request_id"] for item in results} + for state in self.recent_requests: + if state.request_id in requested and state.request_id not in existing_ids: + results.append(state.to_summary()) + results.sort(key=lambda item: item["request_id"]) + return results + + def has_active(self, request_id: str) -> bool: + with self.lock: + return request_id in self.active_requests + + @dataclass class SchedulerPendingJob: request_id: str @@ -298,6 +478,8 @@ class SchedulerPendingJob: sample_steps: int media_type: str admission_wait_ms: float = 0.0 + engine_policy_wait_ms: float = 0.0 + engine_dispatch_wait_ms: float = 0.0 prepare_wall_ms: float = 0.0 prepare_profile_total_ms: float = 0.0 first_schedule_time: float | None = None @@ -316,6 +498,584 @@ class SchedulerPendingJob: engine_request_id: str | None = None +class SchedulerJobRegistry: + def __init__(self, lock: threading.Lock | threading.RLock | threading.Condition) -> None: + self._lock = lock + self._job_map: Dict[str, SchedulerPendingJob] = {} + self._total_submitted = 0 + self._total_finished = 0 + + def register(self, job: SchedulerPendingJob, *, keep_job: bool = True) -> None: + with self._lock: + if keep_job: + self._job_map[job.request_id] = job + self._total_submitted += 1 + + def get(self, request_id: str) -> SchedulerPendingJob | None: + with self._lock: + return self._job_map.get(request_id) + + def pop(self, request_id: str) -> SchedulerPendingJob | None: + with self._lock: + return self._job_map.pop(request_id, None) + + def remove(self, request_id: str) -> None: + with self._lock: + self._job_map.pop(request_id, None) + + def mark_finished(self) -> None: + with self._lock: + self._total_finished += 1 + + def mark_finished_and_remove(self, request_id: str) -> None: + with self._lock: + self._job_map.pop(request_id, None) + self._total_finished += 1 + + def is_empty(self) -> bool: + with self._lock: + return not self._job_map + + def submitted_count(self) -> int: + with self._lock: + return int(self._total_submitted) + + def finished_count(self) -> int: + with self._lock: + return int(self._total_finished) + + def snapshot(self, max_request_ids: int = 32) -> Dict[str, Any]: + with self._lock: + request_ids = list(self._job_map.keys()) + return { + "job_count": int(len(request_ids)), + "request_ids": request_ids[: max(0, int(max_request_ids))], + "total_submitted": int(self._total_submitted), + "total_finished": int(self._total_finished), + } + + +class EngineTaskQueueOwner: + def __init__(self, completion_key: str = "total_completed") -> None: + self.condition = threading.Condition() + self.queue: Deque[Any] = deque() + self.total_submitted = 0 + self.total_completed = 0 + self.peak_waiting = 0 + self.completion_key = str(completion_key) + + def enqueue(self, item: Any) -> None: + with self.condition: + self.queue.append(item) + self.total_submitted += 1 + self.peak_waiting = max(self.peak_waiting, len(self.queue)) + self.condition.notify_all() + + def enqueue_many(self, items: Sequence[Any]) -> None: + if not items: + return + with self.condition: + for item in items: + self.queue.append(item) + self.total_submitted += len(items) + self.peak_waiting = max(self.peak_waiting, len(self.queue)) + self.condition.notify_all() + + def pop_left(self) -> Any | None: + with self.condition: + if not self.queue: + return None + return self.queue.popleft() + + def mark_completed(self, count: int = 1, *, notify: bool = False) -> None: + if count <= 0: + return + with self.condition: + self.total_completed += int(count) + if notify: + self.condition.notify_all() + + def has_items(self) -> bool: + with self.condition: + return bool(self.queue) + + def waiting_count(self) -> int: + with self.condition: + return int(len(self.queue)) + + def snapshot(self, *, max_request_ids: int = 16, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + with self.condition: + waiting_items = list(self.queue)[: max(0, int(max_request_ids))] + snapshot = { + "waiting_count": int(len(self.queue)), + "waiting_request_ids": [str(getattr(item, "request_id", "")) for item in waiting_items], + "peak_waiting": int(self.peak_waiting), + "total_submitted": int(self.total_submitted), + self.completion_key: int(self.total_completed), + } + if extra: + snapshot.update(dict(extra)) + return snapshot + + def peek_oldest_age_ms(self, timestamp_attr: str) -> float: + with self.condition: + if not self.queue: + return 0.0 + enqueue_time = float(getattr(self.queue[0], timestamp_attr)) + return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0) + + def is_drained(self) -> bool: + with self.condition: + return not self.queue and self.total_submitted == self.total_completed + + def take_finalize_batch( + self, + *, + finalize_mode: str, + batch_max_items: int, + batch_wait_s: float, + use_vocoder: bool, + ) -> List[SchedulerFinalizeTask]: + with self.condition: + if not self.queue: + return [] + selected_tasks = [self.queue.popleft()] + if finalize_mode == "sync" or use_vocoder: + return selected_tasks + if batch_max_items <= 1: + return selected_tasks + first_task = selected_tasks[0] + oldest_age_s = max(0.0, time.perf_counter() - first_task.enqueued_time) + if len(self.queue) + 1 < batch_max_items and oldest_age_s < batch_wait_s: + self.queue.appendleft(first_task) + return [] + while len(selected_tasks) < batch_max_items: + if not self.queue: + break + matched_index = None + for index, task in enumerate(self.queue): + if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: + matched_index = index + break + if matched_index is None: + break + selected_tasks.append(self.queue[matched_index]) + del self.queue[matched_index] + return selected_tasks + + +class EnginePolicyArbiterController: + def __init__( + self, + *, + policy_config: EnginePolicyConfig, + arbiter_config: EngineArbiterConfig, + snapshot_request_registry: Callable[[], Dict[str, Any]], + get_worker_state: Callable[[], Dict[str, Any]], + snapshot_prepare_state: Callable[[], Dict[str, Any]], + snapshot_finalize_state: Callable[[], Dict[str, Any]], + snapshot_dispatch_state: Callable[[], Dict[str, Any]], + snapshot_decode_runtime_state: Callable[[], Dict[str, Any]], + snapshot_job_registry: Callable[[], Dict[str, Any]], + peek_queue_age_ms: Callable[[str], float], + merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None], + ) -> None: + self.policy_config = policy_config + self.policy_poll_s = max(0.001, float(self.policy_config.poll_wait_ms) / 1000.0) + self.arbiter_config = arbiter_config + self.arbiter_poll_s = max(0.001, float(self.arbiter_config.poll_wait_ms) / 1000.0) + self.condition = threading.Condition() + self.state = EngineArbiterState( + decode_budget_remaining=int(self.arbiter_config.decode_burst), + last_observed_at=time.perf_counter(), + ) + self.snapshot_request_registry = snapshot_request_registry + self.get_worker_state = get_worker_state + self.snapshot_prepare_state = snapshot_prepare_state + self.snapshot_finalize_state = snapshot_finalize_state + self.snapshot_dispatch_state = snapshot_dispatch_state + self.snapshot_decode_runtime_state = snapshot_decode_runtime_state + self.snapshot_job_registry = snapshot_job_registry + self.peek_queue_age_ms = peek_queue_age_ms + self.merge_request_state_profile = merge_request_state_profile + + def snapshot_state(self) -> Dict[str, Any]: + with self.condition: + return { + "config": self.arbiter_config.to_dict(), + "total_ticks": int(self.state.total_ticks), + "total_idle_ticks": int(self.state.total_idle_ticks), + "total_prepare_dispatches": int(self.state.total_prepare_dispatches), + "total_decode_dispatches": int(self.state.total_decode_dispatches), + "total_decode_runtime_ticks": int(self.state.total_decode_runtime_ticks), + "total_finalize_dispatches": int(self.state.total_finalize_dispatches), + "decode_budget_remaining": int(self.state.decode_budget_remaining), + "last_stage": str(self.state.last_stage), + "last_reason": str(self.state.last_reason), + "last_policy_allowed": bool(self.state.last_policy_allowed), + "last_observed_at": float(self.state.last_observed_at), + } + + def notify(self) -> None: + with self.condition: + self.condition.notify_all() + + def wait(self) -> None: + with self.condition: + self.condition.wait(timeout=self.arbiter_poll_s) + + def mark_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: + with self.condition: + self.state.total_ticks += 1 + if stage == "idle": + self.state.total_idle_ticks += 1 + elif stage == "prepare": + self.state.total_prepare_dispatches += 1 + self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) + elif stage == "finalize": + self.state.total_finalize_dispatches += 1 + self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) + elif stage == "decode_dispatch": + self.state.total_decode_dispatches += 1 + elif stage == "decode_runtime": + self.state.total_decode_runtime_ticks += 1 + self.state.decode_budget_remaining = max(0, int(self.state.decode_budget_remaining) - 1) + self.state.last_stage = str(stage) + self.state.last_reason = str(reason) + self.state.last_policy_allowed = bool(policy_allowed) + self.state.last_observed_at = time.perf_counter() + + def build_stage_counters( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + prepare_dispatcher_state = self.snapshot_prepare_state() + finalize_dispatcher_state = self.snapshot_finalize_state() + dispatcher_state = self.snapshot_dispatch_state() + active_requests = list(request_registry.get("active_requests", [])) + status_counts: Dict[str, int] = {} + for item in active_requests: + status = str(item.get("status", "UNKNOWN")) + status_counts[status] = status_counts.get(status, 0) + 1 + + worker_pending_jobs = int(worker_state.get("pending_jobs", 0)) + worker_decode_active_size = int(worker_state.get("running_requests", 0)) + worker_prepare_inflight = int(worker_state.get("prepare_inflight", 0)) + worker_finalize_pending = int(worker_state.get("finalize_pending", 0)) + worker_finalize_inflight = int(worker_state.get("finalize_inflight", 0)) + engine_decode_runtime_state = self.snapshot_decode_runtime_state() + engine_job_registry = self.snapshot_job_registry() + decode_runtime_pending_jobs = int(engine_decode_runtime_state.get("pending_jobs", 0)) + decode_runtime_active_size = int(engine_decode_runtime_state.get("active_request_count", 0)) + return { + "active_request_count": int(len(active_requests)), + "status_counts": status_counts, + "queued_request_count": int(status_counts.get(EngineStatus.QUEUED, 0)), + "cpu_prepare_request_count": int(status_counts.get(EngineStatus.CPU_PREPARING, 0)), + "gpu_prepare_request_count": int(status_counts.get(EngineStatus.GPU_PREPARING, 0)), + "ready_for_prefill_request_count": int(status_counts.get(EngineStatus.READY_FOR_PREFILL, 0)), + "active_decode_request_count": int(status_counts.get(EngineStatus.ACTIVE_DECODE, 0)), + "ready_for_finalize_request_count": int(status_counts.get(EngineStatus.READY_FOR_FINALIZE, 0)), + "finalizing_request_count": int(status_counts.get(EngineStatus.FINALIZING, 0)), + "streaming_request_count": int(status_counts.get(EngineStatus.STREAMING, 0)), + "worker_pending_jobs": worker_pending_jobs, + "worker_decode_active_size": worker_decode_active_size, + "worker_decode_control_enabled": bool(worker_state.get("engine_decode_control_enabled", False)), + "worker_decode_runtime_has_work": bool(worker_state.get("decode_runtime_has_work", False)), + "engine_decode_runtime_pending_jobs": decode_runtime_pending_jobs, + "engine_decode_runtime_active_request_count": decode_runtime_active_size, + "engine_decode_runtime_has_work": bool(engine_decode_runtime_state.get("has_work", False)), + "engine_job_registry_count": int(engine_job_registry.get("job_count", 0)), + "worker_prepare_inflight": worker_prepare_inflight, + "worker_finalize_pending": worker_finalize_pending, + "worker_finalize_inflight": worker_finalize_inflight, + "engine_gpu_prepare_queue_count": int(prepare_dispatcher_state.get("waiting_count", 0)), + "engine_finalize_queue_count": int(finalize_dispatcher_state.get("waiting_count", 0)), + "engine_decode_waiting_queue_count": int(dispatcher_state.get("waiting_count", 0)), + "decode_backlog": int( + decode_runtime_pending_jobs + decode_runtime_active_size + if bool(worker_state.get("engine_decode_control_enabled", False)) + else worker_pending_jobs + worker_decode_active_size + ), + } + + def build_policy_snapshot( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + counters = self.build_stage_counters(request_registry, worker_state) + config = self.policy_config.to_dict() + blocked_reasons: List[Dict[str, Any]] = [] + finalize_pending_total = int(counters["worker_finalize_pending"]) + int(counters.get("engine_finalize_queue_count", 0)) + limit_checks = [ + ("decode_backlog", counters["decode_backlog"], int(config["decode_backlog_soft_max"])), + ("finalize_pending", finalize_pending_total, int(config["finalize_pending_soft_max"])), + ("prepare_inflight", counters["worker_prepare_inflight"], int(config["prepare_inflight_soft_max"])), + ("active_decode_requests", counters["active_decode_request_count"], int(config["active_decode_soft_max"])), + ("ready_for_prefill_requests", counters["ready_for_prefill_request_count"], int(config["ready_for_prefill_soft_max"])), + ("active_requests", counters["active_request_count"], int(config["active_request_soft_max"])), + ] + if bool(config["enabled"]): + for name, value, limit in limit_checks: + if limit > 0 and int(value) >= int(limit): + blocked_reasons.append({"metric": name, "value": int(value), "limit": int(limit)}) + return { + "enabled": bool(config["enabled"]), + "allowed": (not bool(config["enabled"])) or not blocked_reasons, + "blocked_reasons": blocked_reasons, + "config": config, + "metrics": { + "active_request_count": int(counters["active_request_count"]), + "queued_request_count": int(counters["queued_request_count"]), + "ready_for_prefill_request_count": int(counters["ready_for_prefill_request_count"]), + "active_decode_request_count": int(counters["active_decode_request_count"]), + "engine_gpu_prepare_queue_count": int(counters["engine_gpu_prepare_queue_count"]), + "engine_decode_waiting_queue_count": int(counters["engine_decode_waiting_queue_count"]), + "decode_backlog": int(counters["decode_backlog"]), + "prepare_inflight": int(counters["worker_prepare_inflight"]), + "finalize_pending": int(finalize_pending_total), + "engine_finalize_queue_count": int(counters.get("engine_finalize_queue_count", 0)), + "finalize_inflight": int(counters["worker_finalize_inflight"]), + }, + "observed_at": time.perf_counter(), + } + + async def wait_for_policy_admission( + self, + *, + request_id: str | None, + timeout_sec: float | None, + ) -> tuple[float, Dict[str, Any]]: + request_registry = self.snapshot_request_registry() + worker_state = self.get_worker_state() + snapshot = self.build_policy_snapshot(request_registry, worker_state) + if not self.policy_config.enabled: + return 0.0, snapshot + start = time.perf_counter() + deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) + while True: + request_registry = self.snapshot_request_registry() + worker_state = self.get_worker_state() + snapshot = self.build_policy_snapshot(request_registry, worker_state) + if snapshot["allowed"]: + wait_ms = max(0.0, (time.perf_counter() - start) * 1000.0) + if request_id not in [None, ""]: + self.merge_request_state_profile( + str(request_id), + { + "engine_policy_wait_ms": float(wait_ms), + "engine_policy_snapshot": snapshot, + }, + ) + return wait_ms, snapshot + now = time.perf_counter() + if deadline is not None and now >= deadline: + blocked_summary = ", ".join( + f"{item['metric']}={item['value']}/{item['limit']}" for item in snapshot.get("blocked_reasons", []) + ) + raise TimeoutError(f"engine policy admission timeout ({blocked_summary})") + await asyncio.sleep(self.policy_poll_s) + + def select_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: + request_registry = self.snapshot_request_registry() + worker_state = self.get_worker_state() + policy_snapshot = self.build_policy_snapshot(request_registry, worker_state) + prepare_waiting = int(self.snapshot_prepare_state().get("waiting_count", 0)) + finalize_waiting = int(self.snapshot_finalize_state().get("waiting_count", 0)) + decode_waiting = int(self.snapshot_dispatch_state().get("waiting_count", 0)) + decode_runtime_state = self.snapshot_decode_runtime_state() + worker_decode_has_work = bool(decode_runtime_state.get("has_work", False)) + worker_decode_control_enabled = bool(worker_state.get("engine_decode_control_enabled", False)) + worker_pending_jobs = int(decode_runtime_state.get("pending_jobs", 0)) + worker_running_requests = int(decode_runtime_state.get("active_request_count", 0)) + prepare_age_ms = float(self.peek_queue_age_ms("prepare")) + finalize_age_ms = float(self.peek_queue_age_ms("finalize")) + decode_runtime_pending_age_ms = float(self.peek_queue_age_ms("decode_runtime_pending")) + decode_budget_remaining = int(self.snapshot_state().get("decode_budget_remaining", 0)) + policy_allowed = bool(policy_snapshot.get("allowed", True)) + if ( + worker_decode_control_enabled + and worker_decode_has_work + and policy_allowed + and decode_budget_remaining > 0 + and (worker_running_requests > 0 or worker_pending_jobs > 0) + ): + return "decode_runtime", "worker_active_batch_progress", policy_snapshot, worker_state + if ( + worker_decode_control_enabled + and worker_pending_jobs > 0 + and policy_allowed + and decode_runtime_pending_age_ms >= float(self.arbiter_config.prepare_aging_ms) + ): + return "decode_runtime", "decode_runtime_pending_aging", policy_snapshot, worker_state + if ( + decode_waiting > 0 + and policy_allowed + and (not worker_decode_control_enabled or not worker_decode_has_work or worker_pending_jobs <= 0) + ): + return "decode_dispatch", "dispatch_prepared_state", policy_snapshot, worker_state + if finalize_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): + return "finalize", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state + if finalize_waiting > 0 and finalize_age_ms >= float(self.arbiter_config.finalize_aging_ms): + return "finalize", "finalize_aging", policy_snapshot, worker_state + if prepare_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): + return "prepare", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state + if prepare_waiting > 0 and prepare_age_ms >= float(self.arbiter_config.prepare_aging_ms): + return "prepare", "prepare_aging", policy_snapshot, worker_state + if worker_decode_control_enabled and worker_decode_has_work and policy_allowed: + return "decode_runtime", "worker_active_batch_progress_fallback", policy_snapshot, worker_state + if decode_waiting > 0 and policy_allowed: + return "decode_dispatch", "decode_priority_fallback", policy_snapshot, worker_state + if finalize_waiting > 0: + return "finalize", "finalize_fallback", policy_snapshot, worker_state + if prepare_waiting > 0: + return "prepare", "prepare_fallback", policy_snapshot, worker_state + return "idle", "no_pending_work", policy_snapshot, worker_state + + +class EngineDecodeRuntimeOwner: + def __init__( + self, + *, + get_decode_runtime_counters: Callable[[], Dict[str, int]], + get_micro_batch_wait_s: Callable[[], float], + ) -> None: + self.get_decode_runtime_counters = get_decode_runtime_counters + self.get_micro_batch_wait_s = get_micro_batch_wait_s + self.condition = threading.Condition() + self.pending_jobs: Deque[SchedulerPendingJob] = deque() + self.active_batch: T2SActiveBatch | None = None + self.state_lock = threading.Lock() + self.state = EngineDecodeRuntimeState(updated_at=time.perf_counter()) + + @staticmethod + def summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: + if active_batch is None: + return {} + decode_step_index_max = 0 + if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0: + decode_step_index_max = int(active_batch.step_indices.max().item()) + return { + "request_count": int(len(active_batch.request_ids)), + "request_ids": list(active_batch.request_ids), + "prefill_done": bool(active_batch.prefill_done), + "decode_step_index_max": int(decode_step_index_max), + } + + def snapshot_pending_queue_state(self) -> Dict[str, Any]: + with self.condition: + return { + "pending_jobs": int(len(self.pending_jobs)), + "pending_request_ids": [job.request_id for job in list(self.pending_jobs)[:32]], + } + + def enqueue_pending_job(self, job: SchedulerPendingJob) -> None: + with self.condition: + self.pending_jobs.append(job) + self.condition.notify_all() + self.refresh_state("engine_decode_pending_enqueue") + + def take_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs: + return [] + if wait_for_batch: + oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time) + if (time.perf_counter() - oldest_enqueue_time) < self.get_micro_batch_wait_s(): + return [] + pending_jobs = list(self.pending_jobs) + self.pending_jobs.clear() + self.refresh_state("engine_decode_pending_dequeue") + return pending_jobs + + def pending_age_ms(self) -> float: + with self.condition: + if not self.pending_jobs: + return 0.0 + enqueue_time = float(self.pending_jobs[0].enqueue_time) + return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0) + + def has_pending_jobs(self) -> bool: + with self.condition: + return bool(self.pending_jobs) + + def get_active_batch(self) -> T2SActiveBatch | None: + return self.active_batch + + def set_active_batch(self, active_batch: T2SActiveBatch | None) -> None: + self.active_batch = active_batch + + def active_batch_summary(self) -> Dict[str, Any]: + return self.summarize_active_batch(self.active_batch) + + def refresh_state(self, last_event: str) -> None: + pending_state = self.snapshot_pending_queue_state() + active_batch_summary = self.active_batch_summary() + worker_decode_counters = self.get_decode_runtime_counters() + with self.state_lock: + self.state.pending_jobs = int(pending_state.get("pending_jobs", 0)) + self.state.pending_request_ids = list(pending_state.get("pending_request_ids", [])) + self.state.active_request_count = int(active_batch_summary.get("request_count", 0)) + self.state.active_request_ids = list(active_batch_summary.get("request_ids", []))[:32] + self.state.prefill_done = bool(active_batch_summary.get("prefill_done", False)) + self.state.decode_step_index_max = int(active_batch_summary.get("decode_step_index_max", 0)) + self.state.total_cycles = int(worker_decode_counters.get("total_cycles", 0)) + self.state.prefill_cycles = int(worker_decode_counters.get("prefill_cycles", 0)) + self.state.step_cycles = int(worker_decode_counters.get("step_cycles", 0)) + self.state.has_work = bool(pending_state.get("pending_jobs", 0) or active_batch_summary.get("request_count", 0)) + self.state.last_event = str(last_event) + self.state.updated_at = float(time.perf_counter()) + + def update_from_worker_snapshot(self, snapshot: Dict[str, Any]) -> None: + if not snapshot: + return + pending_state = self.snapshot_pending_queue_state() + with self.state_lock: + self.state.pending_jobs = int(pending_state.get("pending_jobs", 0)) + self.state.pending_request_ids = list(pending_state.get("pending_request_ids", [])) + self.state.active_request_count = int(snapshot.get("active_request_count", 0)) + self.state.active_request_ids = list(snapshot.get("active_request_ids", []))[:32] + self.state.prefill_done = bool(snapshot.get("prefill_done", False)) + self.state.decode_step_index_max = int(snapshot.get("decode_step_index_max", 0)) + self.state.total_cycles = int(snapshot.get("total_cycles", 0)) + self.state.prefill_cycles = int(snapshot.get("prefill_cycles", 0)) + self.state.step_cycles = int(snapshot.get("step_cycles", 0)) + self.state.has_work = bool( + pending_state.get("pending_jobs", 0) + or snapshot.get("active_request_count", 0) + or snapshot.get("has_work", False) + ) + self.state.last_event = str(snapshot.get("last_event", "unknown")) + self.state.updated_at = float(snapshot.get("updated_at", time.perf_counter())) + + def snapshot_state(self) -> Dict[str, Any]: + pending_state = self.snapshot_pending_queue_state() + active_batch_summary = self.active_batch_summary() + worker_decode_counters = self.get_decode_runtime_counters() + with self.state_lock: + return { + "pending_jobs": int(pending_state.get("pending_jobs", self.state.pending_jobs)), + "pending_request_ids": list(pending_state.get("pending_request_ids", self.state.pending_request_ids)), + "active_request_count": int(active_batch_summary.get("request_count", self.state.active_request_count)), + "active_request_ids": list(active_batch_summary.get("request_ids", self.state.active_request_ids)), + "prefill_done": bool(active_batch_summary.get("prefill_done", self.state.prefill_done)), + "decode_step_index_max": int( + active_batch_summary.get("decode_step_index_max", self.state.decode_step_index_max) + ), + "total_cycles": int(worker_decode_counters.get("total_cycles", 0)), + "prefill_cycles": int(worker_decode_counters.get("prefill_cycles", 0)), + "step_cycles": int(worker_decode_counters.get("step_cycles", 0)), + "has_work": bool( + pending_state.get("pending_jobs", 0) + or active_batch_summary.get("request_count", self.state.active_request_count) + or self.state.has_work + ), + "last_event": str(self.state.last_event), + "updated_at": float(self.state.updated_at), + } + @dataclass class SchedulerFinalizeTask: request_id: str @@ -323,182 +1083,273 @@ class SchedulerFinalizeTask: enqueued_time: float +@dataclass +class EngineDispatchTask: + request_id: str + state: T2SRequestState + speed_factor: float + sample_steps: int + media_type: str + prepare_wall_ms: float + prepare_profile_total_ms: float + done_loop: asyncio.AbstractEventLoop | None + done_future: asyncio.Future | None + engine_request_id: str | None + timeout_sec: float | None + enqueue_time: float + worker_job: SchedulerPendingJob | None = None + engine_policy_wait_ms: float = 0.0 + engine_dispatch_wait_ms: float = 0.0 + engine_policy_snapshot: Dict[str, Any] | None = None + error: str | None = None + + +@dataclass +class EngineGpuPrepareTask: + request_id: str + cpu_stage: PreparedCpuStage + done_loop: asyncio.AbstractEventLoop | None + done_future: asyncio.Future | None + engine_request_id: str | None + enqueue_time: float + queue_wait_ms: float = 0.0 + error: str | None = None + + +@dataclass +class EngineFinalizeQueueState: + waiting_count: int + waiting_request_ids: List[str] + peak_waiting: int + total_submitted: int + total_completed: int + + +@dataclass +class EngineArbiterState: + total_ticks: int = 0 + total_idle_ticks: int = 0 + total_prepare_dispatches: int = 0 + total_decode_dispatches: int = 0 + total_decode_runtime_ticks: int = 0 + total_finalize_dispatches: int = 0 + decode_budget_remaining: int = 0 + last_stage: str = "idle" + last_reason: str = "init" + last_observed_at: float = 0.0 + last_policy_allowed: bool = True + + +@dataclass +class EngineDecodeRuntimeState: + pending_jobs: int = 0 + pending_request_ids: List[str] = field(default_factory=list) + active_request_count: int = 0 + active_request_ids: List[str] = field(default_factory=list) + prefill_done: bool = False + decode_step_index_max: int = 0 + total_cycles: int = 0 + prefill_cycles: int = 0 + step_cycles: int = 0 + has_work: bool = False + last_event: str = "init" + updated_at: float = 0.0 + + @dataclass class RuntimeStateCallbacks: update: Callable[[str, str, Optional[Dict[str, Any]]], None] | None = None complete: Callable[[str, Optional[Dict[str, Any]]], None] | None = None fail: Callable[[str, str], None] | None = None + decode_runtime_update: Callable[[Dict[str, Any]], None] | None = None -class UnifiedSchedulerWorker: +class WorkerPrepareExecutor: def __init__( self, tts: TTS, - max_steps: int = 1500, - micro_batch_wait_ms: int = 5, - runtime_callbacks: RuntimeStateCallbacks | None = None, - ): - self.tts = tts - self.max_steps = int(max_steps) - self.micro_batch_wait_s = float(micro_batch_wait_ms) / 1000.0 - self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() - self.prepare_coordinator = PrepareCoordinator(tts) - self.condition = threading.Condition() - self.prepare_inflight = 0 - self.prepare_peak_inflight = 0 - self.finalize_condition = threading.Condition() - self.finalize_pending_tasks: Deque[SchedulerFinalizeTask] = deque() - self.finalize_pending_peak = 0 - self.finalize_inflight = 0 - self.finalize_inflight_peak = 0 - self.finalize_workers = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1))) - self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower() - self.finalize_batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16))) - self.finalize_batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0) - self.decode_backlog_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_DECODE_BACKLOG_MAX", "0"))) - self.finalize_pending_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_FINALIZE_PENDING_MAX", "0"))) - self.pending_jobs: List[SchedulerPendingJob] = [] - self.active_batch: T2SActiveBatch | None = None - self.job_map: Dict[str, SchedulerPendingJob] = {} - self.total_finished = 0 - self.total_submitted = 0 - self.worker_thread = threading.Thread(target=self._run_loop, name="unified-t2s-scheduler-worker", daemon=True) - self.worker_thread.start() - self.finalize_threads = [ - threading.Thread( - target=self._run_finalize_loop, - name=f"unified-t2s-finalize-{worker_index}", - daemon=True, - ) - for worker_index in range(self.finalize_workers) - ] - for finalize_thread in self.finalize_threads: - finalize_thread.start() + on_state_change: Callable[[], None] | None = None, + ) -> None: + self.coordinator = PrepareCoordinator(tts) + self.on_state_change = on_state_change - def _current_decode_backlog_locked(self) -> int: - running_requests = 0 if self.active_batch is None else len(self.active_batch.request_ids) - return int(len(self.pending_jobs) + running_requests) + def _notify_state_change(self) -> None: + if self.on_state_change is None: + return + try: + self.on_state_change() + except Exception: + pass - def _can_accept_submit_locked(self) -> tuple[bool, Dict[str, int]]: - decode_backlog = self._current_decode_backlog_locked() - finalize_pending = int(len(self.finalize_pending_tasks)) - prepare_inflight = int(self.prepare_coordinator.snapshot()["inflight"]) - blocked_decode = self.decode_backlog_max > 0 and decode_backlog >= self.decode_backlog_max - blocked_finalize = self.finalize_pending_max > 0 and finalize_pending >= self.finalize_pending_max - return ( - not blocked_decode and not blocked_finalize, - { - "decode_backlog": decode_backlog, - "finalize_pending": finalize_pending, - "prepare_inflight": prepare_inflight, - "decode_backlog_max": int(self.decode_backlog_max), - "finalize_pending_max": int(self.finalize_pending_max), - }, - ) + def snapshot(self) -> Dict[str, int]: + return dict(self.coordinator.snapshot()) - def wait_for_submit_capacity_blocking(self, timeout_sec: float | None = None) -> tuple[float, Dict[str, int]]: - start = time.perf_counter() - deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) - last_snapshot: Dict[str, int] = {} - while True: - with self.condition: - allowed, snapshot = self._can_accept_submit_locked() - last_snapshot = snapshot - if allowed: - return max(0.0, (time.perf_counter() - start) * 1000.0), snapshot - if deadline is not None and time.perf_counter() >= deadline: - raise TimeoutError( - "scheduler submit admission timeout " - f"(decode_backlog={snapshot['decode_backlog']}, finalize_pending={snapshot['finalize_pending']})" - ) - self.condition.wait(timeout=self.micro_batch_wait_s) + def get_max_inflight(self) -> int: + return int(self.coordinator.snapshot().get("max_inflight", 0)) - async def submit_async( + def is_idle(self) -> bool: + return int(self.coordinator.snapshot().get("inflight", 0)) <= 0 + + async def prepare_state_profiled_async( self, - state: T2SRequestState, - speed_factor: float, - sample_steps: int, - media_type: str, - prepare_wall_ms: float, - prepare_profile_total_ms: float, - done_loop: asyncio.AbstractEventLoop | None = None, - done_future: asyncio.Future | None = None, - engine_request_id: str | None = None, - timeout_sec: float | None = None, - ) -> SchedulerPendingJob: - return await asyncio.to_thread( - self.submit, - state, - speed_factor, - sample_steps, - media_type, - prepare_wall_ms, - prepare_profile_total_ms, - done_loop, - done_future, - engine_request_id, - timeout_sec, - ) + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> tuple[T2SRequestState, float, float]: + try: + return await self.coordinator.prepare_state_profiled_async(spec, prepare_submit_at) + finally: + self._notify_state_change() - def snapshot(self) -> dict: + async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: + results = await asyncio.gather( + *[self.prepare_state_profiled_async(spec, time.perf_counter()) for spec in specs] + ) + return [state for state, _, _ in results] + + async def prepare_cpu_stage_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> PreparedCpuStage: + try: + return await self.coordinator.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + finally: + self._notify_state_change() + + async def prepare_gpu_stage_profiled_async( + self, + cpu_stage: PreparedCpuStage, + ) -> tuple[T2SRequestState, float, float]: + try: + return await self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) + finally: + self._notify_state_change() + + +class WorkerFinalizeExecutor: + def __init__( + self, + tts: TTS, + on_state_change: Callable[[], None] | None = None, + external_submit: Callable[[List[SchedulerFinalizeTask]], None] | None = None, + ) -> None: + self.tts = tts + self.on_state_change = on_state_change + self.external_submit = external_submit + self.condition = threading.Condition() + self.pending_tasks: Deque[SchedulerFinalizeTask] = deque() + self.pending_peak = 0 + self.inflight = 0 + self.inflight_peak = 0 + self.worker_count = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1))) + self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower() + self.batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16))) + self.batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0) + + def _notify_state_change(self) -> None: + if self.on_state_change is None: + return + try: + self.on_state_change() + except Exception: + pass + + def get_worker_count(self) -> int: + return int(self.worker_count) + + def get_batch_policy(self) -> Dict[str, Any]: + return { + "finalize_mode": str(self.finalize_mode), + "finalize_batch_max_items": int(self.batch_max_items), + "finalize_batch_wait_s": float(self.batch_wait_s), + } + + def get_pending_count(self) -> int: + with self.condition: + return int(len(self.pending_tasks)) + + def snapshot(self) -> Dict[str, Any]: with self.condition: - finalize_pending = len(self.finalize_pending_tasks) - prepare_state = self.prepare_coordinator.snapshot() - active_batch = self.active_batch - active_batch_summary = None - if active_batch is not None: - active_batch_summary = { - "request_count": int(len(active_batch.request_ids)), - "request_ids": list(active_batch.request_ids), - "prefill_done": bool(active_batch.prefill_done), - "decode_step_index_max": ( - int(active_batch.step_indices.max().item()) - if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0 - else 0 - ), - } return { - "pending_jobs": len(self.pending_jobs), - "running_requests": 0 if active_batch is None else len(active_batch.request_ids), - "prepare_inflight": prepare_state["inflight"], - "prepare_peak_inflight": prepare_state["peak_inflight"], - "prepare_max_inflight": prepare_state.get("max_inflight", 0), - "prepare_state": dict(prepare_state), - "finalize_pending": finalize_pending, - "finalize_pending_peak": self.finalize_pending_peak, - "finalize_inflight": self.finalize_inflight, - "finalize_inflight_peak": self.finalize_inflight_peak, - "finalize_workers": self.finalize_workers, - "finalize_mode": self.finalize_mode, - "finalize_batch_max_items": self.finalize_batch_max_items, - "finalize_batch_wait_ms": self.finalize_batch_wait_s * 1000.0, - "decode_backlog_max": self.decode_backlog_max, - "finalize_pending_max": self.finalize_pending_max, - "active_batch": active_batch_summary, - "total_submitted": self.total_submitted, - "total_finished": self.total_finished, - "drained": self.is_drained(), + "finalize_pending": int(len(self.pending_tasks)), + "finalize_pending_peak": int(self.pending_peak), + "finalize_inflight": int(self.inflight), + "finalize_inflight_peak": int(self.inflight_peak), + "finalize_workers": int(self.worker_count), + "finalize_mode": str(self.finalize_mode), + "finalize_batch_max_items": int(self.batch_max_items), + "finalize_batch_wait_ms": float(self.batch_wait_s * 1000.0), } - def is_drained(self) -> bool: + def is_idle(self) -> bool: with self.condition: - with self.finalize_condition: - return ( - self.active_batch is None - and not self.pending_jobs - and not self.job_map - and self.prepare_coordinator.snapshot()["inflight"] <= 0 - and self.finalize_inflight <= 0 - and not self.finalize_pending_tasks - ) + return self.inflight <= 0 and not self.pending_tasks - def wait_until_idle(self, timeout_sec: float = 60.0, poll_interval_sec: float = 0.01) -> bool: - deadline = time.perf_counter() + max(0.0, timeout_sec) - while time.perf_counter() < deadline: - if self.is_drained(): - return True - time.sleep(poll_interval_sec) - return self.is_drained() + def enqueue_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None: + if not tasks: + return + if self.external_submit is not None: + self.external_submit(tasks) + self._notify_state_change() + return + with self.condition: + for task in tasks: + self.pending_tasks.append(task) + self.pending_peak = max(self.pending_peak, len(self.pending_tasks)) + self.condition.notify_all() + self._notify_state_change() + + def begin_execution(self, task_count: int) -> None: + if task_count <= 0: + return + with self.condition: + self.inflight += int(task_count) + self.inflight_peak = max(self.inflight_peak, self.inflight) + self.condition.notify_all() + self._notify_state_change() + + def end_execution(self, task_count: int) -> None: + with self.condition: + self.inflight = max(0, self.inflight - int(task_count)) + self.condition.notify_all() + self._notify_state_change() + + def take_task_batch_blocking(self) -> List[SchedulerFinalizeTask]: + with self.condition: + while not self.pending_tasks: + self.condition.wait() + selected_tasks = [self.pending_tasks.popleft()] + if self.finalize_mode == "sync" or self.tts.configs.use_vocoder: + self.inflight += len(selected_tasks) + self.inflight_peak = max(self.inflight_peak, self.inflight) + self._notify_state_change() + return selected_tasks + batch_deadline = time.perf_counter() + self.batch_wait_s + while len(selected_tasks) < self.batch_max_items: + if not self.pending_tasks: + remaining = batch_deadline - time.perf_counter() + if remaining <= 0: + break + self.condition.wait(timeout=remaining) + continue + first_task = selected_tasks[0] + matched_index = None + for index, task in enumerate(self.pending_tasks): + if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: + matched_index = index + break + if matched_index is not None: + selected_tasks.append(self.pending_tasks[matched_index]) + del self.pending_tasks[matched_index] + continue + remaining = batch_deadline - time.perf_counter() + if remaining <= 0: + break + self.condition.wait(timeout=remaining) + self.inflight += len(selected_tasks) + self.inflight_peak = max(self.inflight_peak, self.inflight) + self._notify_state_change() + return selected_tasks def _sync_device(self) -> None: try: @@ -510,201 +1361,6 @@ class UnifiedSchedulerWorker: except Exception: pass - def submit( - self, - state: T2SRequestState, - speed_factor: float, - sample_steps: int, - media_type: str, - prepare_wall_ms: float, - prepare_profile_total_ms: float, - done_loop: asyncio.AbstractEventLoop | None = None, - done_future: asyncio.Future | None = None, - engine_request_id: str | None = None, - timeout_sec: float | None = None, - ) -> SchedulerPendingJob: - admission_wait_ms, admission_snapshot = self.wait_for_submit_capacity_blocking(timeout_sec=timeout_sec) - job = SchedulerPendingJob( - request_id=state.request_id, - state=state, - done_event=threading.Event(), - done_loop=done_loop, - done_future=done_future, - enqueue_time=time.perf_counter(), - speed_factor=float(speed_factor), - sample_steps=int(sample_steps), - media_type=media_type, - admission_wait_ms=float(admission_wait_ms), - prepare_wall_ms=float(prepare_wall_ms), - prepare_profile_total_ms=float(prepare_profile_total_ms), - engine_request_id=engine_request_id or state.request_id, - ) - with self.condition: - self.pending_jobs.append(job) - self.job_map[job.request_id] = job - self.total_submitted += 1 - self.condition.notify_all() - self._runtime_update( - job.engine_request_id, - EngineStatus.QUEUED, - { - "scheduler_request_id": job.request_id, - "decode_admission_wait_ms": float(admission_wait_ms), - "admission_snapshot": dict(admission_snapshot), - }, - ) - with self.finalize_condition: - self.finalize_condition.notify_all() - return job - - async def prepare_state_profiled_async( - self, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - ) -> tuple[T2SRequestState, float, float]: - with self.condition: - self.prepare_inflight += 1 - self.prepare_peak_inflight = max(self.prepare_peak_inflight, self.prepare_inflight) - try: - return await self.prepare_coordinator.prepare_state_profiled_async(spec, prepare_submit_at) - finally: - with self.condition: - self.prepare_inflight = max(0, self.prepare_inflight - 1) - self.condition.notify_all() - with self.finalize_condition: - self.finalize_condition.notify_all() - - async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: - results = await asyncio.gather( - *[self.prepare_state_profiled_async(spec, time.perf_counter()) for spec in specs] - ) - return [state for state, _, _ in results] - - def _mark_prefill_started(self, pending_jobs: List[SchedulerPendingJob], started_at: float) -> None: - with self.condition: - for job in pending_jobs: - tracked_job = self.job_map.get(job.request_id) - if tracked_job is None: - continue - tracked_job.first_schedule_time = float(started_at) - self._runtime_update( - tracked_job.engine_request_id, - EngineStatus.GPU_PREPARING, - {"scheduler_request_id": tracked_job.request_id, "prefill_started_at": float(started_at)}, - ) - - def _add_prefill_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - if not request_ids: - return - with self.condition: - for request_id in request_ids: - job = self.job_map.get(request_id) - if job is not None: - job.prefill_ms += delta_ms - - def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - if not request_ids: - return - with self.condition: - for request_id in request_ids: - job = self.job_map.get(request_id) - if job is not None: - job.merge_ms += delta_ms - - def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - if not request_ids: - return - activate_request_ids: List[str] = [] - with self.condition: - for request_id in request_ids: - job = self.job_map.get(request_id) - if job is not None: - if job.decode_steps == 0: - activate_request_ids.append(job.engine_request_id) - job.decode_ms += delta_ms - job.decode_steps += 1 - for engine_request_id in activate_request_ids: - self._runtime_update(engine_request_id, EngineStatus.ACTIVE_DECODE, None) - - def _add_finalize_wait_ms(self, request_ids: List[str], delta_ms: float) -> None: - if not request_ids: - return - with self.condition: - for request_id in request_ids: - job = self.job_map.get(request_id) - if job is not None: - job.finalize_wait_ms += float(delta_ms) - - def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None: - if not items: - return - enqueued_at = time.perf_counter() - with self.finalize_condition: - for item in items: - job = self.job_map.get(item.request_id) - if job is not None: - self._runtime_update( - job.engine_request_id, - EngineStatus.READY_FOR_FINALIZE, - { - "finish_reason": item.finish_reason, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - }, - ) - self.finalize_pending_tasks.append( - SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at) - ) - self.finalize_pending_peak = max(self.finalize_pending_peak, len(self.finalize_pending_tasks)) - self.finalize_condition.notify_all() - - def _take_finalize_task_batch(self) -> List[SchedulerFinalizeTask]: - with self.finalize_condition: - while not self.finalize_pending_tasks: - self.finalize_condition.wait() - selected_tasks = [self.finalize_pending_tasks.popleft()] - if self.finalize_mode == "sync" or self.tts.configs.use_vocoder: - self.finalize_inflight += len(selected_tasks) - self.finalize_inflight_peak = max(self.finalize_inflight_peak, self.finalize_inflight) - return selected_tasks - batch_deadline = time.perf_counter() + self.finalize_batch_wait_s - while len(selected_tasks) < self.finalize_batch_max_items: - if not self.finalize_pending_tasks: - remaining = batch_deadline - time.perf_counter() - if remaining <= 0: - break - self.finalize_condition.wait(timeout=remaining) - continue - first_task = selected_tasks[0] - matched_index = None - for index, task in enumerate(self.finalize_pending_tasks): - if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: - matched_index = index - break - if matched_index is not None: - selected_tasks.append(self.finalize_pending_tasks[matched_index]) - del self.finalize_pending_tasks[matched_index] - continue - remaining = batch_deadline - time.perf_counter() - if remaining <= 0: - break - self.finalize_condition.wait(timeout=remaining) - self.finalize_inflight += len(selected_tasks) - self.finalize_inflight_peak = max(self.finalize_inflight_peak, self.finalize_inflight) - with self.condition: - self.condition.notify_all() - return selected_tasks - - def _finalize_task_done(self, count: int) -> None: - with self.finalize_condition: - self.finalize_inflight = max(0, self.finalize_inflight - count) - self.finalize_condition.notify_all() - with self.condition: - self.condition.notify_all() - def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]: audio_fragment = self.tts.synthesize_audio_request_local( semantic_tokens=item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0), @@ -772,93 +1428,36 @@ class UnifiedSchedulerWorker: ) return results - def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None: - finished_at = time.perf_counter() - with self.condition: - if self.job_map.get(item.request_id) is not job: - return - queue_wait_ms = 0.0 - if job.first_schedule_time is not None: - queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) - worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) - worker_residual_ms = max( - 0.0, - worker_total_ms - - queue_wait_ms - - job.prefill_ms - - job.merge_ms - - job.decode_ms - - job.finalize_wait_ms - - job.synth_ms, - ) - worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms) - job.sample_rate = int(sample_rate) - job.audio_data = audio_data - job.result_ready_time = finished_at - prepare_profile = dict(job.state.prepare_profile) - job.result = { - "request_id": item.request_id, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - "finish_reason": item.finish_reason, - "decode_admission_wait_ms": float(job.admission_wait_ms), - "prepare_ms": job.prepare_wall_ms, - "prepare_wall_ms": job.prepare_wall_ms, - "prepare_profile_total_ms": job.prepare_profile_total_ms, - "prepare_profile": prepare_profile, - "queue_wait_ms": queue_wait_ms, - "prefill_ms": job.prefill_ms, - "merge_ms": job.merge_ms, - "decode_ms": job.decode_ms, - "finalize_wait_ms": job.finalize_wait_ms, - "synth_ms": job.synth_ms, - "worker_residual_ms": worker_residual_ms, - "worker_other_ms": worker_other_ms, - "worker_total_ms": worker_total_ms, - "decode_steps": int(job.decode_steps), - "sample_rate": int(sample_rate), - "media_type": job.media_type, - } - job.done_event.set() - self._notify_done_future(job) - self.job_map.pop(item.request_id, None) - self.total_finished += 1 - self.condition.notify_all() - self._runtime_complete( - job.engine_request_id, - { - "finish_reason": item.finish_reason, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - "sample_rate": int(sample_rate), - "worker_profile": dict(job.result or {}), - }, - ) + def synthesize_finalize_jobs( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> tuple[float, List[tuple[int, np.ndarray]]]: + if not jobs_and_items: + return 0.0, [] + self._sync_device() + synth_start = time.perf_counter() + if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder: + job, item = jobs_and_items[0] + batch_results = [self._synthesize_finished_audio(job, item)] + else: + batch_results = self._synthesize_finished_audio_batch(jobs_and_items) + self._sync_device() + synth_ms = (time.perf_counter() - synth_start) * 1000.0 + return float(synth_ms), batch_results - def _finalize_error(self, request_ids: List[str], error: str) -> None: - if not request_ids: - return - with self.condition: - for request_id in request_ids: - job = self.job_map.get(request_id) - if job is None: - continue - job.error = error - job.done_event.set() - self._notify_done_future(job) - self.job_map.pop(request_id, None) - self.total_finished += 1 - self._runtime_fail(job.engine_request_id, error) - self.condition.notify_all() + +class WorkerCompletionBridge: + def __init__(self, runtime_callbacks: RuntimeStateCallbacks | None = None) -> None: + self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() @staticmethod def _resolve_done_future(job: SchedulerPendingJob) -> None: future = job.done_future if future is None or future.done(): return - future.set_result(True) + future.set_result(job) - def _notify_done_future(self, job: SchedulerPendingJob) -> None: + def notify_done_future(self, job: SchedulerPendingJob) -> None: if job.done_loop is None or job.done_future is None: return try: @@ -866,22 +1465,409 @@ class UnifiedSchedulerWorker: except RuntimeError: pass - def _runtime_update(self, request_id: str | None, status: str, extra: Optional[Dict[str, Any]] = None) -> None: - if request_id is None or self.runtime_callbacks.update is None: - return - self.runtime_callbacks.update(request_id, status, extra) - - def _runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: + def runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: if request_id is None or self.runtime_callbacks.complete is None: return self.runtime_callbacks.complete(request_id, extra) - def _runtime_fail(self, request_id: str | None, error: str) -> None: + def runtime_fail(self, request_id: str | None, error: str) -> None: if request_id is None or self.runtime_callbacks.fail is None: return self.runtime_callbacks.fail(request_id, error) - def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + @staticmethod + def build_completed_job_result( + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + audio_data: np.ndarray, + finished_at: float | None = None, + ) -> Dict[str, Any]: + finished_at = float(time.perf_counter() if finished_at is None else finished_at) + queue_wait_ms = 0.0 + if job.first_schedule_time is not None: + queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) + worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) + worker_residual_ms = max( + 0.0, + worker_total_ms + - queue_wait_ms + - job.prefill_ms + - job.merge_ms + - job.decode_ms + - job.finalize_wait_ms + - job.synth_ms, + ) + worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms) + job.sample_rate = int(sample_rate) + job.audio_data = audio_data + job.result_ready_time = finished_at + prepare_profile = dict(job.state.prepare_profile) + result = { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + "decode_admission_wait_ms": float(job.admission_wait_ms), + "engine_policy_wait_ms": float(job.engine_policy_wait_ms), + "engine_dispatch_wait_ms": float(job.engine_dispatch_wait_ms), + "prepare_ms": job.prepare_wall_ms, + "prepare_wall_ms": job.prepare_wall_ms, + "prepare_profile_total_ms": job.prepare_profile_total_ms, + "prepare_profile": prepare_profile, + "queue_wait_ms": queue_wait_ms, + "prefill_ms": job.prefill_ms, + "merge_ms": job.merge_ms, + "decode_ms": job.decode_ms, + "finalize_wait_ms": job.finalize_wait_ms, + "synth_ms": job.synth_ms, + "worker_residual_ms": worker_residual_ms, + "worker_other_ms": worker_other_ms, + "worker_total_ms": worker_total_ms, + "decode_steps": int(job.decode_steps), + "sample_rate": int(sample_rate), + "media_type": job.media_type, + } + job.result = result + return result + + @staticmethod + def build_runtime_complete_payload( + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + ) -> Dict[str, Any]: + return { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "sample_rate": int(sample_rate), + "worker_profile": dict(job.result or {}), + } + + def complete_job( + self, + job: SchedulerPendingJob, + *, + runtime_request_id: str | None, + runtime_extra: Optional[Dict[str, Any]] = None, + remove_job: Callable[[], None] | None = None, + on_job_finished: Callable[[], None] | None = None, + notify_waiters: Callable[[], None] | None = None, + ) -> None: + job.done_event.set() + self.notify_done_future(job) + if remove_job is not None: + remove_job() + if on_job_finished is not None: + on_job_finished() + if notify_waiters is not None: + notify_waiters() + self.runtime_complete(runtime_request_id, runtime_extra) + + def fail_job( + self, + job: SchedulerPendingJob, + *, + error: str, + remove_job: Callable[[], None] | None = None, + on_job_finished: Callable[[], None] | None = None, + notify_waiters: Callable[[], None] | None = None, + ) -> None: + job.error = str(error) + job.done_event.set() + self.notify_done_future(job) + if remove_job is not None: + remove_job() + if on_job_finished is not None: + on_job_finished() + if notify_waiters is not None: + notify_waiters() + self.runtime_fail(job.engine_request_id, str(error)) + + def complete_finalize_task( + self, + *, + condition: threading.Condition, + job_registry: SchedulerJobRegistry, + job: SchedulerPendingJob, + item: T2SFinishedItem, + sample_rate: int, + audio_data: np.ndarray, + ) -> None: + runtime_extra: Optional[Dict[str, Any]] = None + with condition: + if job_registry.get(item.request_id) is not job: + return + self.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data) + runtime_extra = self.build_runtime_complete_payload(job, item, sample_rate=sample_rate) + self.complete_job( + job, + runtime_request_id=job.engine_request_id, + runtime_extra=runtime_extra, + on_job_finished=lambda: job_registry.mark_finished_and_remove(item.request_id), + notify_waiters=condition.notify_all, + ) + + def fail_jobs( + self, + *, + condition: threading.Condition, + job_registry: SchedulerJobRegistry, + request_ids: List[str], + error: str, + ) -> None: + if not request_ids: + return + with condition: + for request_id in request_ids: + job = job_registry.get(request_id) + if job is None: + continue + self.fail_job( + job, + error=error, + on_job_finished=lambda rid=request_id: job_registry.mark_finished_and_remove(rid), + ) + condition.notify_all() + + +class WorkerDecodeExecutor: + def __init__(self, tts: TTS, max_steps: int) -> None: + self.tts = tts + self.max_steps = int(max_steps) + + def _sync_device(self) -> None: + try: + device_str = str(self.tts.configs.device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(self.tts.configs.device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + def execute_prefill_merge( + self, + *, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], + add_prefill_time: Callable[[List[str], float], None] | None, + add_merge_time: Callable[[List[str], float], None] | None, + enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, + finalize_error: Callable[[List[str], str], None] | None, + ) -> Dict[str, Any]: + if not pending_jobs: + return { + "executed": False, + "active_batch": active_batch, + "pending_jobs": [], + "prefill_elapsed_s": 0.0, + "merge_elapsed_s": 0.0, + "finished_items": [], + "error": None, + "error_request_ids": [], + } + admitted_finished: List[T2SFinishedItem] = [] + prefill_elapsed_s = 0.0 + merge_elapsed_s = 0.0 + error: str | None = None + error_request_ids: List[str] = [] + try: + self._sync_device() + prefill_start = time.perf_counter() + mark_prefill_started(pending_jobs, prefill_start) + admitted_active_batch, admitted_finished = run_prefill_active_batch( + self.tts.t2s_model.model, + [job.state for job in pending_jobs], + max_steps=self.max_steps, + ) + self._sync_device() + prefill_elapsed_s = time.perf_counter() - prefill_start + if add_prefill_time is not None: + add_prefill_time([job.request_id for job in pending_jobs], prefill_elapsed_s) + if enqueue_finished is not None: + enqueue_finished(admitted_finished) + merge_start = time.perf_counter() + active_batch = merge_active_batches( + self.tts.t2s_model.model, + active_batch, + admitted_active_batch, + ) + merge_elapsed_s = time.perf_counter() - merge_start + if add_merge_time is not None: + add_merge_time( + [] if active_batch is None else list(active_batch.request_ids), + merge_elapsed_s, + ) + except Exception as exc: + error = str(exc) + error_request_ids = [job.request_id for job in pending_jobs] + if finalize_error is not None: + finalize_error(error_request_ids, error) + return { + "executed": True, + "active_batch": active_batch, + "pending_jobs": list(pending_jobs), + "prefill_elapsed_s": float(prefill_elapsed_s), + "merge_elapsed_s": float(merge_elapsed_s), + "finished_items": list(admitted_finished), + "error": error, + "error_request_ids": error_request_ids, + } + + def execute_decode_step( + self, + *, + active_batch: Optional[T2SActiveBatch], + add_decode_time: Callable[[List[str], float], None] | None, + enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, + finalize_error: Callable[[List[str], str], None] | None, + ) -> Dict[str, Any]: + if active_batch is None: + return { + "executed": False, + "active_batch": None, + "request_ids": [], + "decode_elapsed_s": 0.0, + "finished_items": [], + "error": None, + "error_request_ids": [], + } + active_request_ids: List[str] = [] + step_finished: List[T2SFinishedItem] = [] + decode_elapsed_s = 0.0 + error: str | None = None + error_request_ids: List[str] = [] + try: + active_request_ids = [state.request_id for state in active_batch.states] + self._sync_device() + decode_start = time.perf_counter() + active_batch, step_finished = decode_one_step( + self.tts.t2s_model.model, + active_batch, + max_steps=self.max_steps, + ) + self._sync_device() + decode_elapsed_s = time.perf_counter() - decode_start + if add_decode_time is not None: + add_decode_time(active_request_ids, decode_elapsed_s) + if enqueue_finished is not None: + enqueue_finished(step_finished) + except Exception as exc: + error = str(exc) + error_request_ids = list(active_request_ids) + if finalize_error is not None: + finalize_error(error_request_ids, error) + active_batch = None + return { + "executed": True, + "active_batch": active_batch, + "request_ids": active_request_ids, + "decode_elapsed_s": float(decode_elapsed_s), + "finished_items": list(step_finished), + "error": error, + "error_request_ids": error_request_ids, + } + + def execute_decode_cycle( + self, + *, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], + add_prefill_time: Callable[[List[str], float], None] | None, + add_merge_time: Callable[[List[str], float], None] | None, + add_decode_time: Callable[[List[str], float], None] | None, + enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, + finalize_error: Callable[[List[str], str], None] | None, + ) -> Dict[str, Any]: + result = { + "executed": False, + "prefill_merge_executed": False, + "decode_step_executed": False, + "active_batch": active_batch, + "prefill_phase": {}, + "decode_phase": {}, + } + prefill_phase = self.execute_prefill_merge( + pending_jobs=list(pending_jobs), + active_batch=result["active_batch"], + mark_prefill_started=mark_prefill_started, + add_prefill_time=add_prefill_time, + add_merge_time=add_merge_time, + enqueue_finished=enqueue_finished, + finalize_error=finalize_error, + ) + prefill_executed = bool(prefill_phase.get("executed", False)) + result["prefill_phase"] = prefill_phase + result["active_batch"] = prefill_phase.get("active_batch") + if prefill_executed: + result["executed"] = True + result["prefill_merge_executed"] = True + decode_phase = self.execute_decode_step( + active_batch=result["active_batch"], + add_decode_time=add_decode_time, + enqueue_finished=enqueue_finished, + finalize_error=finalize_error, + ) + decode_executed = bool(decode_phase.get("executed", False)) + result["decode_phase"] = decode_phase + result["active_batch"] = decode_phase.get("active_batch") + if decode_executed: + result["executed"] = True + result["decode_step_executed"] = True + return result + + +class WorkerDecodeLegacyShell: + def __init__(self, condition: threading.Condition, micro_batch_wait_s: float) -> None: + self.condition = condition + self.micro_batch_wait_s = float(micro_batch_wait_s) + self.pending_jobs: List[SchedulerPendingJob] = [] + self.active_batch: T2SActiveBatch | None = None + + @staticmethod + def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any] | None: + if active_batch is None: + return None + return { + "request_count": int(len(active_batch.request_ids)), + "request_ids": list(active_batch.request_ids), + "prefill_done": bool(active_batch.prefill_done), + "decode_step_index_max": ( + int(active_batch.step_indices.max().item()) + if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0 + else 0 + ), + } + + def current_backlog_locked(self) -> int: + running_requests = 0 if self.active_batch is None else len(self.active_batch.request_ids) + return int(len(self.pending_jobs) + running_requests) + + def enqueue_pending_job_locked(self, job: SchedulerPendingJob) -> None: + self.pending_jobs.append(job) + + def snapshot_locked(self) -> Dict[str, Any]: + active_batch_summary = self._summarize_active_batch(self.active_batch) + executor_local_pending_jobs = int(len(self.pending_jobs)) + executor_local_running_requests = 0 if self.active_batch is None else int(len(self.active_batch.request_ids)) + executor_local_has_work = bool(self.pending_jobs or self.active_batch is not None) + return { + "executor_local_pending_jobs": executor_local_pending_jobs, + "executor_local_running_requests": executor_local_running_requests, + "executor_local_has_work": executor_local_has_work, + "executor_local_active_batch": active_batch_summary, + } + + def is_idle_locked(self) -> bool: + return self.active_batch is None and not self.pending_jobs + + def take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: with self.condition: if not self.pending_jobs and self.active_batch is None: self.condition.wait(timeout=self.micro_batch_wait_s) @@ -893,104 +1879,781 @@ class UnifiedSchedulerWorker: self.pending_jobs.clear() return pending - def _run_finalize_loop(self) -> None: + def take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs: + return [] + if wait_for_batch: + oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time) + if (time.perf_counter() - oldest_enqueue_time) < self.micro_batch_wait_s: + return [] + pending = list(self.pending_jobs) + self.pending_jobs.clear() + return pending + + def has_decode_runtime_work(self) -> bool: + with self.condition: + return bool(self.pending_jobs or self.active_batch is not None) + + def build_runtime_summary_locked(self, *, total_cycles: int, prefill_cycles: int, step_cycles: int, last_event: str) -> Dict[str, Any]: + active_request_ids = [] if self.active_batch is None else list(self.active_batch.request_ids) + decode_step_index_max = 0 + prefill_done = False + if self.active_batch is not None: + prefill_done = bool(self.active_batch.prefill_done) + if self.active_batch.step_indices is not None and self.active_batch.step_indices.numel() > 0: + decode_step_index_max = int(self.active_batch.step_indices.max().item()) + return { + "pending_jobs": int(len(self.pending_jobs)), + "active_request_count": int(len(active_request_ids)), + "active_request_ids": active_request_ids[:32], + "prefill_done": bool(prefill_done), + "decode_step_index_max": int(decode_step_index_max), + "total_cycles": int(total_cycles), + "prefill_cycles": int(prefill_cycles), + "step_cycles": int(step_cycles), + "has_work": bool(self.pending_jobs or self.active_batch is not None), + "last_event": str(last_event), + "updated_at": float(time.perf_counter()), + } + + def run_prefill_merge_once_nonblocking( + self, + *, + external_pending_jobs: Optional[List[SchedulerPendingJob]], + external_active_batch: Optional[T2SActiveBatch], + execute_prefill_merge: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], + ) -> Dict[str, Any]: + pending_jobs = ( + list(external_pending_jobs) + if external_pending_jobs is not None + else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) + ) + active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch + result = execute_prefill_merge(pending_jobs, active_batch) + if external_pending_jobs is None: + with self.condition: + self.active_batch = result.get("active_batch") + self.condition.notify_all() + return result + + def run_decode_step_once_nonblocking( + self, + *, + external_active_batch: Optional[T2SActiveBatch], + execute_decode_step: Callable[[Optional[T2SActiveBatch]], Dict[str, Any]], + ) -> Dict[str, Any]: + active_batch = self.active_batch if external_active_batch is None else external_active_batch + result = execute_decode_step(active_batch) + if external_active_batch is None: + with self.condition: + self.active_batch = result.get("active_batch") + self.condition.notify_all() + return result + + def run_decode_cycle_nonblocking( + self, + *, + external_pending_jobs: Optional[List[SchedulerPendingJob]], + external_active_batch: Optional[T2SActiveBatch], + execute_decode_cycle: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], + on_cycle_executed: Callable[[Dict[str, Any]], None] | None, + ) -> Dict[str, Any]: + pending_jobs = ( + list(external_pending_jobs) + if external_pending_jobs is not None + else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) + ) + active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch + result = execute_decode_cycle(pending_jobs, active_batch) + if external_pending_jobs is None: + with self.condition: + self.active_batch = result.get("active_batch") + self.condition.notify_all() + if result.get("executed") and on_cycle_executed is not None: + on_cycle_executed(result) + return result + + def run_loop( + self, + *, + run_decode_cycle_nonblocking: Callable[[], Dict[str, Any]], + ) -> None: while True: - tasks = self._take_finalize_task_batch() - try: - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + executed = run_decode_cycle_nonblocking() + if executed.get("executed"): + continue + wait_for_batch = self.active_batch is None + pending_jobs = self.take_pending_snapshot(wait_for_batch=wait_for_batch) + if pending_jobs: with self.condition: - for task in tasks: - job = self.job_map.get(task.request_id) - if job is None: - continue - jobs_and_items.append((job, task.item)) - if not jobs_and_items: - continue - now = time.perf_counter() - for task in tasks: - self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0)) - for job, item in jobs_and_items: + self.pending_jobs = pending_jobs + self.pending_jobs + self.condition.notify_all() + continue + time.sleep(self.micro_batch_wait_s) + + +class WorkerDecodeRuntimeTracker: + def __init__( + self, + runtime_callbacks: RuntimeStateCallbacks | None = None, + ) -> None: + self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() + self.total_cycles = 0 + self.prefill_cycles = 0 + self.step_cycles = 0 + + def get_counters(self) -> Dict[str, int]: + return { + "total_cycles": int(self.total_cycles), + "prefill_cycles": int(self.prefill_cycles), + "step_cycles": int(self.step_cycles), + } + + def record_cycle(self, result: Dict[str, Any]) -> None: + if not bool(result.get("executed")): + return + self.total_cycles += 1 + if bool(result.get("prefill_merge_executed")): + self.prefill_cycles += 1 + if bool(result.get("decode_step_executed")): + self.step_cycles += 1 + + def build_runtime_summary_locked( + self, + *, + legacy_shell: WorkerDecodeLegacyShell, + last_event: str, + ) -> Dict[str, Any]: + return legacy_shell.build_runtime_summary_locked( + total_cycles=int(self.total_cycles), + prefill_cycles=int(self.prefill_cycles), + step_cycles=int(self.step_cycles), + last_event=str(last_event), + ) + + def notify_runtime_update_locked( + self, + *, + legacy_shell: WorkerDecodeLegacyShell, + last_event: str, + ) -> None: + if self.runtime_callbacks.decode_runtime_update is None: + return + snapshot = self.build_runtime_summary_locked( + legacy_shell=legacy_shell, + last_event=last_event, + ) + self.runtime_callbacks.decode_runtime_update(snapshot) + + +class UnifiedSchedulerWorker: + def __init__( + self, + tts: TTS, + max_steps: int = 1500, + micro_batch_wait_ms: int = 5, + runtime_callbacks: RuntimeStateCallbacks | None = None, + external_finalize_submit: Callable[[List[SchedulerFinalizeTask]], None] | None = None, + ): + self.tts = tts + self.max_steps = int(max_steps) + self.micro_batch_wait_s = float(micro_batch_wait_ms) / 1000.0 + self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() + self.condition = threading.Condition() + self.completion_bridge = WorkerCompletionBridge(self.runtime_callbacks) + self.decode_executor = WorkerDecodeExecutor(tts, max_steps=max_steps) + self.decode_legacy_shell = WorkerDecodeLegacyShell(self.condition, self.micro_batch_wait_s) + self.decode_runtime_tracker = WorkerDecodeRuntimeTracker(self.runtime_callbacks) + self.prepare_executor = WorkerPrepareExecutor(tts, on_state_change=self._notify_worker_state_change) + self.finalize_executor = WorkerFinalizeExecutor( + tts, + on_state_change=self._notify_worker_state_change, + external_submit=external_finalize_submit, + ) + self.decode_backlog_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_DECODE_BACKLOG_MAX", "0"))) + self.finalize_pending_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_FINALIZE_PENDING_MAX", "0"))) + self.engine_decode_control_enabled = ( + str(os.environ.get("GPTSOVITS_ENGINE_DRIVE_DECODE", "0")).strip().lower() in {"1", "true", "yes", "on"} + ) + self.job_registry = SchedulerJobRegistry(self.condition) + self.worker_thread: threading.Thread | None = None + if not self.engine_decode_control_enabled: + self.worker_thread = threading.Thread(target=self._run_loop, name="unified-t2s-scheduler-worker", daemon=True) + self.worker_thread.start() + self.finalize_threads = [] + if external_finalize_submit is None: + self.finalize_threads = [ + threading.Thread( + target=self._run_finalize_loop, + name=f"unified-t2s-finalize-{worker_index}", + daemon=True, + ) + for worker_index in range(self.finalize_executor.get_worker_count()) + ] + for finalize_thread in self.finalize_threads: + finalize_thread.start() + + def _notify_worker_state_change(self) -> None: + with self.condition: + self.condition.notify_all() + + def _current_decode_backlog_locked(self) -> int: + return self.decode_legacy_shell.current_backlog_locked() + + def get_micro_batch_wait_s(self) -> float: + return float(self.micro_batch_wait_s) + + def is_engine_decode_control_enabled(self) -> bool: + return bool(self.engine_decode_control_enabled) + + def get_prepare_max_inflight(self) -> int: + return int(self.prepare_executor.get_max_inflight()) + + def get_capacity_limits(self) -> Dict[str, int]: + return { + "decode_backlog_max": int(self.decode_backlog_max), + "finalize_pending_max": int(self.finalize_pending_max), + } + + def get_finalize_batch_policy(self) -> Dict[str, Any]: + return dict(self.finalize_executor.get_batch_policy()) + + def get_decode_runtime_counters(self) -> Dict[str, int]: + with self.condition: + return self.decode_runtime_tracker.get_counters() + + def _can_accept_submit_locked(self) -> tuple[bool, Dict[str, int]]: + decode_backlog = self._current_decode_backlog_locked() + finalize_pending = int(self.finalize_executor.get_pending_count()) + prepare_inflight = int(self.prepare_executor.snapshot()["inflight"]) + blocked_decode = self.decode_backlog_max > 0 and decode_backlog >= self.decode_backlog_max + blocked_finalize = self.finalize_pending_max > 0 and finalize_pending >= self.finalize_pending_max + return ( + not blocked_decode and not blocked_finalize, + { + "decode_backlog": decode_backlog, + "finalize_pending": finalize_pending, + "prepare_inflight": prepare_inflight, + "decode_backlog_max": int(self.decode_backlog_max), + "finalize_pending_max": int(self.finalize_pending_max), + }, + ) + + def wait_for_submit_capacity_blocking(self, timeout_sec: float | None = None) -> tuple[float, Dict[str, int]]: + start = time.perf_counter() + deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) + last_snapshot: Dict[str, int] = {} + while True: + with self.condition: + allowed, snapshot = self._can_accept_submit_locked() + last_snapshot = snapshot + if allowed: + return max(0.0, (time.perf_counter() - start) * 1000.0), snapshot + if deadline is not None and time.perf_counter() >= deadline: + raise TimeoutError( + "scheduler submit admission timeout " + f"(decode_backlog={snapshot['decode_backlog']}, finalize_pending={snapshot['finalize_pending']})" + ) + self.condition.wait(timeout=self.micro_batch_wait_s) + + def _admission_snapshot_locked(self) -> Dict[str, int]: + _, snapshot = self._can_accept_submit_locked() + return snapshot + + async def submit_async( + self, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None = None, + done_future: asyncio.Future | None = None, + engine_request_id: str | None = None, + timeout_sec: float | None = None, + skip_capacity_wait: bool = False, + admission_wait_ms_override: float | None = None, + admission_snapshot_override: Dict[str, Any] | None = None, + engine_policy_wait_ms: float = 0.0, + engine_dispatch_wait_ms: float = 0.0, + enqueue_pending: bool = True, + ) -> SchedulerPendingJob: + return await asyncio.to_thread( + self.submit, + state, + speed_factor, + sample_steps, + media_type, + prepare_wall_ms, + prepare_profile_total_ms, + done_loop, + done_future, + engine_request_id, + timeout_sec, + skip_capacity_wait, + admission_wait_ms_override, + admission_snapshot_override, + engine_policy_wait_ms, + engine_dispatch_wait_ms, + enqueue_pending, + ) + + def snapshot(self) -> dict: + with self.condition: + prepare_state = self.prepare_executor.snapshot() + finalize_state = self.finalize_executor.snapshot() + shell_state = self.decode_legacy_shell.snapshot_locked() + decode_runtime_counters = self.decode_runtime_tracker.get_counters() + engine_owned_decode_state = bool(self.engine_decode_control_enabled) + active_batch_summary = shell_state.get("executor_local_active_batch") + executor_local_pending_jobs = int(shell_state.get("executor_local_pending_jobs", 0)) + executor_local_running_requests = int(shell_state.get("executor_local_running_requests", 0)) + executor_local_has_work = bool(shell_state.get("executor_local_has_work", False)) + return { + "pending_jobs": 0 if engine_owned_decode_state else executor_local_pending_jobs, + "running_requests": 0 if engine_owned_decode_state else executor_local_running_requests, + "engine_decode_control_enabled": bool(self.engine_decode_control_enabled), + "legacy_state_owner_mode": not engine_owned_decode_state, + "decode_state_owner": "engine" if engine_owned_decode_state else "worker", + "decode_runtime_has_work": False if engine_owned_decode_state else executor_local_has_work, + "executor_local_pending_jobs": executor_local_pending_jobs, + "executor_local_running_requests": executor_local_running_requests, + "executor_local_has_work": executor_local_has_work, + "decode_runtime_total_cycles": int(decode_runtime_counters.get("total_cycles", 0)), + "decode_runtime_prefill_cycles": int(decode_runtime_counters.get("prefill_cycles", 0)), + "decode_runtime_step_cycles": int(decode_runtime_counters.get("step_cycles", 0)), + "prepare_inflight": prepare_state["inflight"], + "prepare_peak_inflight": prepare_state["peak_inflight"], + "prepare_max_inflight": prepare_state.get("max_inflight", 0), + "prepare_state": dict(prepare_state), + **finalize_state, + "decode_backlog_max": self.decode_backlog_max, + "finalize_pending_max": self.finalize_pending_max, + "active_batch": {} if engine_owned_decode_state else active_batch_summary, + "executor_local_active_batch": active_batch_summary if engine_owned_decode_state else None, + "total_submitted": self.job_registry.submitted_count(), + "total_finished": self.job_registry.finished_count(), + "drained": self.is_drained(), + } + + def is_drained(self) -> bool: + with self.condition: + return ( + self.decode_legacy_shell.is_idle_locked() + and self.job_registry.is_empty() + and self.prepare_executor.is_idle() + and self.finalize_executor.is_idle() + ) + + def wait_until_idle(self, timeout_sec: float = 60.0, poll_interval_sec: float = 0.01) -> bool: + deadline = time.perf_counter() + max(0.0, timeout_sec) + while time.perf_counter() < deadline: + if self.is_drained(): + return True + time.sleep(poll_interval_sec) + return self.is_drained() + + def submit( + self, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None = None, + done_future: asyncio.Future | None = None, + engine_request_id: str | None = None, + timeout_sec: float | None = None, + skip_capacity_wait: bool = False, + admission_wait_ms_override: float | None = None, + admission_snapshot_override: Dict[str, Any] | None = None, + engine_policy_wait_ms: float = 0.0, + engine_dispatch_wait_ms: float = 0.0, + enqueue_pending: bool = True, + ) -> SchedulerPendingJob: + if skip_capacity_wait: + with self.condition: + admission_snapshot = ( + dict(admission_snapshot_override) + if admission_snapshot_override is not None + else dict(self._admission_snapshot_locked()) + ) + admission_wait_ms = 0.0 if admission_wait_ms_override is None else float(admission_wait_ms_override) + else: + admission_wait_ms, admission_snapshot = self.wait_for_submit_capacity_blocking(timeout_sec=timeout_sec) + job = SchedulerPendingJob( + request_id=state.request_id, + state=state, + done_event=threading.Event(), + done_loop=done_loop, + done_future=done_future, + enqueue_time=time.perf_counter(), + speed_factor=float(speed_factor), + sample_steps=int(sample_steps), + media_type=media_type, + admission_wait_ms=float(admission_wait_ms), + engine_policy_wait_ms=float(engine_policy_wait_ms), + engine_dispatch_wait_ms=float(engine_dispatch_wait_ms), + prepare_wall_ms=float(prepare_wall_ms), + prepare_profile_total_ms=float(prepare_profile_total_ms), + engine_request_id=engine_request_id or state.request_id, + ) + with self.condition: + self.job_registry.register(job, keep_job=not self.engine_decode_control_enabled) + if enqueue_pending: + self.decode_legacy_shell.enqueue_pending_job_locked(job) + self.condition.notify_all() + if enqueue_pending: + self._notify_decode_runtime_state("submit") + self._runtime_update( + job.engine_request_id, + EngineStatus.QUEUED, + { + "scheduler_request_id": job.request_id, + "decode_admission_wait_ms": float(admission_wait_ms), + "engine_policy_wait_ms": float(engine_policy_wait_ms), + "engine_dispatch_wait_ms": float(engine_dispatch_wait_ms), + "admission_snapshot": dict(admission_snapshot), + }, + ) + return job + + async def prepare_state_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> tuple[T2SRequestState, float, float]: + return await self.prepare_executor.prepare_state_profiled_async(spec, prepare_submit_at) + + async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: + return await self.prepare_executor.prepare_states_batch_async(specs) + + async def prepare_cpu_stage_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> PreparedCpuStage: + return await self.prepare_executor.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + + async def prepare_gpu_stage_profiled_async( + self, + cpu_stage: PreparedCpuStage, + ) -> tuple[T2SRequestState, float, float]: + return await self.prepare_executor.prepare_gpu_stage_profiled_async(cpu_stage) + + def _mark_prefill_started(self, pending_jobs: List[SchedulerPendingJob], started_at: float) -> None: + with self.condition: + for job in pending_jobs: + job.first_schedule_time = float(started_at) + self._runtime_update( + job.engine_request_id, + EngineStatus.GPU_PREPARING, + {"scheduler_request_id": job.request_id, "prefill_started_at": float(started_at)}, + ) + + def _add_prefill_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + job.prefill_ms += delta_ms + + def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + job.merge_ms += delta_ms + + def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + activate_request_ids: List[str] = [] + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + if job.decode_steps == 0: + activate_request_ids.append(job.engine_request_id) + job.decode_ms += delta_ms + job.decode_steps += 1 + for engine_request_id in activate_request_ids: + self._runtime_update(engine_request_id, EngineStatus.ACTIVE_DECODE, None) + + def _add_finalize_wait_ms(self, request_ids: List[str], delta_ms: float) -> None: + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + job.finalize_wait_ms += float(delta_ms) + + def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None: + if not items: + return + enqueued_at = time.perf_counter() + tasks: List[SchedulerFinalizeTask] = [] + with self.condition: + for item in items: + job = self.job_registry.get(item.request_id) + if job is not None: self._runtime_update( job.engine_request_id, - EngineStatus.FINALIZING, + EngineStatus.READY_FOR_FINALIZE, { "finish_reason": item.finish_reason, "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), }, ) - self._sync_device() - synth_start = time.perf_counter() - if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder: - job, item = jobs_and_items[0] - batch_results = [self._synthesize_finished_audio(job, item)] - else: - batch_results = self._synthesize_finished_audio_batch(jobs_and_items) - self._sync_device() - synth_ms = (time.perf_counter() - synth_start) * 1000.0 - with self.condition: - for job, _ in jobs_and_items: - tracked_job = self.job_map.get(job.request_id) - if tracked_job is not None: - tracked_job.synth_ms += synth_ms - for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): - self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data) - except Exception as exc: - self._finalize_error([task.request_id for task in tasks], str(exc)) - finally: - self._finalize_task_done(len(tasks)) + tasks.append(SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at)) + self.finalize_executor.enqueue_tasks(tasks) + + def begin_finalize_execution(self, task_count: int) -> None: + self.finalize_executor.begin_execution(task_count) + + def end_finalize_execution(self, task_count: int) -> None: + self.finalize_executor.end_execution(task_count) + + def record_external_job_done(self, request_id: str) -> None: + with self.condition: + self.job_registry.mark_finished_and_remove(request_id) + self.condition.notify_all() + + def synthesize_finalize_jobs( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> tuple[float, List[tuple[int, np.ndarray]]]: + return self.finalize_executor.synthesize_finalize_jobs(jobs_and_items) + + def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None: + self.completion_bridge.complete_finalize_task( + condition=self.condition, + job_registry=self.job_registry, + job=job, + item=item, + sample_rate=sample_rate, + audio_data=audio_data, + ) + + def _finalize_error(self, request_ids: List[str], error: str) -> None: + self.completion_bridge.fail_jobs( + condition=self.condition, + job_registry=self.job_registry, + request_ids=request_ids, + error=error, + ) + + @staticmethod + def _resolve_done_future(job: SchedulerPendingJob) -> None: + future = job.done_future + if future is None or future.done(): + return + future.set_result(job) + + def _notify_done_future(self, job: SchedulerPendingJob) -> None: + self.completion_bridge.notify_done_future(job) + + def _runtime_update(self, request_id: str | None, status: str, extra: Optional[Dict[str, Any]] = None) -> None: + if request_id is None or self.runtime_callbacks.update is None: + return + self.runtime_callbacks.update(request_id, status, extra) + + def _runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: + self.completion_bridge.runtime_complete(request_id, extra) + + def _runtime_fail(self, request_id: str | None, error: str) -> None: + self.completion_bridge.runtime_fail(request_id, error) + + def _build_decode_runtime_summary_locked(self, last_event: str) -> Dict[str, Any]: + return self.decode_runtime_tracker.build_runtime_summary_locked( + legacy_shell=self.decode_legacy_shell, + last_event=str(last_event), + ) + + def _notify_decode_runtime_state(self, last_event: str) -> None: + with self.condition: + self.decode_runtime_tracker.notify_runtime_update_locked( + legacy_shell=self.decode_legacy_shell, + last_event=str(last_event), + ) + + def _record_decode_runtime_cycle(self, result: Dict[str, Any]) -> None: + with self.condition: + self.decode_runtime_tracker.record_cycle(result) + + def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.decode_legacy_shell.take_pending_snapshot(wait_for_batch) + + def _take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.decode_legacy_shell.take_pending_snapshot_nonblocking(wait_for_batch) + + def has_decode_runtime_work(self) -> bool: + return self.decode_legacy_shell.has_decode_runtime_work() + + def execute_prefill_merge( + self, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + return self.decode_executor.execute_prefill_merge( + pending_jobs=pending_jobs, + active_batch=active_batch, + mark_prefill_started=self._mark_prefill_started, + add_prefill_time=None if external_bookkeeping else self._add_prefill_time, + add_merge_time=None if external_bookkeeping else self._add_merge_time, + enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, + finalize_error=None if external_bookkeeping else self._finalize_error, + ) + + def execute_decode_step( + self, + active_batch: Optional[T2SActiveBatch], + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + return self.decode_executor.execute_decode_step( + active_batch=active_batch, + add_decode_time=None if external_bookkeeping else self._add_decode_time, + enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, + finalize_error=None if external_bookkeeping else self._finalize_error, + ) + + def execute_decode_cycle( + self, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_executor.execute_decode_cycle( + pending_jobs=pending_jobs, + active_batch=active_batch, + mark_prefill_started=self._mark_prefill_started, + add_prefill_time=None if external_bookkeeping else self._add_prefill_time, + add_merge_time=None if external_bookkeeping else self._add_merge_time, + add_decode_time=None if external_bookkeeping else self._add_decode_time, + enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, + finalize_error=None if external_bookkeeping else self._finalize_error, + ) + self._record_decode_runtime_cycle(result) + return result + + def run_prefill_merge_once_nonblocking( + self, + external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, + external_active_batch: Optional[T2SActiveBatch] = None, + emit_runtime_state: bool = True, + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_legacy_shell.run_prefill_merge_once_nonblocking( + external_pending_jobs=external_pending_jobs, + external_active_batch=external_active_batch, + execute_prefill_merge=lambda batch_jobs, batch_state: self.execute_prefill_merge( + pending_jobs=batch_jobs, + active_batch=batch_state, + external_bookkeeping=external_bookkeeping, + ), + ) + if emit_runtime_state: + self._notify_decode_runtime_state("prefill_merge") + return result + + def run_decode_step_once_nonblocking( + self, + external_active_batch: Optional[T2SActiveBatch] = None, + emit_runtime_state: bool = True, + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_legacy_shell.run_decode_step_once_nonblocking( + external_active_batch=external_active_batch, + execute_decode_step=lambda batch_state: self.execute_decode_step( + active_batch=batch_state, + external_bookkeeping=external_bookkeeping, + ), + ) + if emit_runtime_state: + self._notify_decode_runtime_state("decode_step") + return result + + def run_decode_cycle_nonblocking( + self, + external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, + external_active_batch: Optional[T2SActiveBatch] = None, + emit_runtime_state: bool = True, + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_legacy_shell.run_decode_cycle_nonblocking( + external_pending_jobs=external_pending_jobs, + external_active_batch=external_active_batch, + execute_decode_cycle=lambda batch_jobs, batch_state: self.execute_decode_cycle( + pending_jobs=batch_jobs, + active_batch=batch_state, + external_bookkeeping=external_bookkeeping, + ), + on_cycle_executed=None, + ) + if result.get("executed") and emit_runtime_state: + self._notify_decode_runtime_state("decode_cycle") + return result + + def execute_finalize_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None: + if not tasks: + return + try: + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + with self.condition: + for task in tasks: + job = self.job_registry.get(task.request_id) + if job is None: + continue + jobs_and_items.append((job, task.item)) + if not jobs_and_items: + return + now = time.perf_counter() + for task in tasks: + self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0)) + for job, item in jobs_and_items: + self._runtime_update( + job.engine_request_id, + EngineStatus.FINALIZING, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + }, + ) + synth_ms, batch_results = self.synthesize_finalize_jobs(jobs_and_items) + with self.condition: + for job, _ in jobs_and_items: + tracked_job = self.job_registry.get(job.request_id) + if tracked_job is not None: + tracked_job.synth_ms += synth_ms + for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): + self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data) + except Exception as exc: + self._finalize_error([task.request_id for task in tasks], str(exc)) + finally: + self.finalize_executor.end_execution(len(tasks)) + + def _run_finalize_loop(self) -> None: + while True: + tasks = self.finalize_executor.take_task_batch_blocking() + self.execute_finalize_tasks(tasks) def _run_loop(self) -> None: - while True: - wait_for_batch = self.active_batch is None - pending_jobs = self._take_pending_snapshot(wait_for_batch=wait_for_batch) - - if pending_jobs: - try: - self._sync_device() - prefill_start = time.perf_counter() - self._mark_prefill_started(pending_jobs, prefill_start) - admitted_active_batch, admitted_finished = run_prefill_active_batch( - self.tts.t2s_model.model, - [job.state for job in pending_jobs], - max_steps=self.max_steps, - ) - self._sync_device() - self._add_prefill_time([job.request_id for job in pending_jobs], time.perf_counter() - prefill_start) - self._enqueue_finalize_finished(admitted_finished) - merge_start = time.perf_counter() - self.active_batch = merge_active_batches( - self.tts.t2s_model.model, - self.active_batch, - admitted_active_batch, - ) - self._add_merge_time( - [] if self.active_batch is None else list(self.active_batch.request_ids), - time.perf_counter() - merge_start, - ) - except Exception as exc: - self._finalize_error([job.request_id for job in pending_jobs], str(exc)) - - if self.active_batch is not None: - active_request_ids: List[str] = [] - try: - active_request_ids = [state.request_id for state in self.active_batch.states] - self._sync_device() - decode_start = time.perf_counter() - self.active_batch, step_finished = decode_one_step( - self.tts.t2s_model.model, - self.active_batch, - max_steps=self.max_steps, - ) - self._sync_device() - self._add_decode_time(active_request_ids, time.perf_counter() - decode_start) - self._enqueue_finalize_finished(step_finished) - except Exception as exc: - self._finalize_error(active_request_ids, str(exc)) - self.active_batch = None - continue - - if not pending_jobs: - time.sleep(self.micro_batch_wait_s) + self.decode_legacy_shell.run_loop( + run_decode_cycle_nonblocking=lambda: self.run_decode_cycle_nonblocking() + ) def set_scheduler_seed(seed: int): @@ -1088,6 +2751,27 @@ def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=3 class UnifiedTTSEngine: + @staticmethod + def _env_flag(name: str, default: bool) -> bool: + value = os.environ.get(name) + if value is None: + return bool(default) + return str(value).strip().lower() not in {"0", "false", "no", "off", ""} + + @staticmethod + def _env_int(name: str, default: int) -> int: + value = os.environ.get(name) + if value in [None, ""]: + return int(default) + return int(value) + + @staticmethod + def _env_float(name: str, default: float) -> float: + value = os.environ.get(name) + if value in [None, ""]: + return float(default) + return float(value) + def __init__( self, tts: TTS, @@ -1104,10 +2788,10 @@ class UnifiedTTSEngine: t2s_weights_path=str(self.tts.configs.t2s_weights_path), vits_weights_path=str(self.tts.configs.vits_weights_path), ) - self.request_registry_lock = threading.Lock() - self.active_requests: Dict[str, EngineRequestState] = {} - self.recent_requests: Deque[EngineRequestState] = deque() - self.recent_request_limit = max(1, int(os.environ.get("GPTSOVITS_ENGINE_RECENT_REQUEST_LIMIT", "64"))) + self.request_registry = EngineRequestRegistry( + recent_limit=max(1, int(os.environ.get("GPTSOVITS_ENGINE_RECENT_REQUEST_LIMIT", "64"))) + ) + self.engine_job_registry = SchedulerJobRegistry(threading.Lock()) self.scheduler_worker = UnifiedSchedulerWorker( tts, max_steps=max_steps, @@ -1116,10 +2800,72 @@ class UnifiedTTSEngine: update=self._update_request_state, complete=self._complete_request_state, fail=self._fail_request_state, + decode_runtime_update=self._update_engine_decode_runtime_state, ), + external_finalize_submit=self._enqueue_worker_finished_for_finalize, ) self.direct_tts_lock = threading.RLock() self.management_lock = threading.RLock() + worker_capacity_limits = self.scheduler_worker.get_capacity_limits() + prepare_max_inflight = int(self.scheduler_worker.get_prepare_max_inflight()) + self.engine_policy_config = EnginePolicyConfig( + enabled=self._env_flag("GPTSOVITS_ENGINE_POLICY_ENABLE", True), + poll_wait_ms=max(1.0, self._env_float("GPTSOVITS_ENGINE_POLICY_POLL_WAIT_MS", float(micro_batch_wait_ms))), + decode_backlog_soft_max=max( + 0, + self._env_int( + "GPTSOVITS_ENGINE_POLICY_DECODE_BACKLOG_SOFT_MAX", + int(worker_capacity_limits["decode_backlog_max"]), + ), + ), + finalize_pending_soft_max=max( + 0, + self._env_int( + "GPTSOVITS_ENGINE_POLICY_FINALIZE_PENDING_SOFT_MAX", + int(worker_capacity_limits["finalize_pending_max"]), + ), + ), + prepare_inflight_soft_max=max( + 0, + self._env_int("GPTSOVITS_ENGINE_POLICY_PREPARE_INFLIGHT_SOFT_MAX", prepare_max_inflight), + ), + active_decode_soft_max=max(0, self._env_int("GPTSOVITS_ENGINE_POLICY_ACTIVE_DECODE_SOFT_MAX", 0)), + ready_for_prefill_soft_max=max(0, self._env_int("GPTSOVITS_ENGINE_POLICY_READY_FOR_PREFILL_SOFT_MAX", 0)), + active_request_soft_max=max(0, self._env_int("GPTSOVITS_ENGINE_POLICY_ACTIVE_REQUEST_SOFT_MAX", 0)), + ) + self.engine_arbiter_config = EngineArbiterConfig( + poll_wait_ms=max(1.0, self._env_float("GPTSOVITS_ENGINE_ARBITER_POLL_WAIT_MS", float(micro_batch_wait_ms))), + decode_burst=max(1, self._env_int("GPTSOVITS_ENGINE_ARBITER_DECODE_BURST", 4)), + prepare_aging_ms=max(0.0, self._env_float("GPTSOVITS_ENGINE_ARBITER_PREPARE_AGING_MS", 10.0)), + finalize_aging_ms=max(0.0, self._env_float("GPTSOVITS_ENGINE_ARBITER_FINALIZE_AGING_MS", 10.0)), + ) + self.engine_decode_runtime_owner = EngineDecodeRuntimeOwner( + get_decode_runtime_counters=self.scheduler_worker.get_decode_runtime_counters, + get_micro_batch_wait_s=self.scheduler_worker.get_micro_batch_wait_s, + ) + self.engine_prepare_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") + self.engine_finalize_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") + self.engine_dispatch_queue_owner = EngineTaskQueueOwner(completion_key="total_dispatched") + self.engine_dispatch_last_snapshot: Dict[str, Any] = {} + self.engine_policy_arbiter = EnginePolicyArbiterController( + policy_config=self.engine_policy_config, + arbiter_config=self.engine_arbiter_config, + snapshot_request_registry=self._snapshot_request_registry, + get_worker_state=self.get_scheduler_state, + snapshot_prepare_state=self._snapshot_engine_prepare_state, + snapshot_finalize_state=self._snapshot_engine_finalize_state, + snapshot_dispatch_state=self._snapshot_engine_dispatch_state, + snapshot_decode_runtime_state=self._snapshot_engine_decode_runtime_state, + snapshot_job_registry=self._snapshot_engine_job_registry, + peek_queue_age_ms=self._peek_queue_age_ms, + merge_request_state_profile=self._merge_request_state_profile, + ) + self.engine_arbiter_thread = threading.Thread( + target=self._run_engine_arbiter_loop, + name="unified-engine-arbiter", + daemon=True, + ) + self.engine_arbiter_thread.start() def _register_request_state( self, @@ -1131,27 +2877,15 @@ class UnifiedTTSEngine: deadline_ts: float | None = None, meta: Optional[Dict[str, Any]] = None, ) -> EngineRequestState: - now = time.perf_counter() - state = EngineRequestState( + return self.request_registry.register( request_id=request_id, api_mode=api_mode, backend=backend, media_type=media_type, - response_streaming=bool(response_streaming), - submit_ts=now, + response_streaming=response_streaming, deadline_ts=deadline_ts, - updated_ts=now, - meta=dict(meta or {}), - lifecycle_timestamps={EngineStatus.NEW: now}, + meta=meta, ) - with self.request_registry_lock: - self.active_requests[request_id] = state - return state - - def _move_to_recent_locked(self, state: EngineRequestState) -> None: - self.recent_requests.appendleft(state) - while len(self.recent_requests) > self.recent_request_limit: - self.recent_requests.pop() def _update_request_state( self, @@ -1159,95 +2893,501 @@ class UnifiedTTSEngine: status: str, extra: Optional[Dict[str, Any]] = None, ) -> None: - now = time.perf_counter() - with self.request_registry_lock: - state = self.active_requests.get(request_id) - if state is None: - return - state.status = status - state.updated_ts = now - state.lifecycle_timestamps[status] = now - if extra: - backend = extra.pop("backend", None) - if backend is not None: - state.backend = str(backend) - finish_reason = extra.pop("finish_reason", None) - if finish_reason is not None: - state.finish_reason = str(finish_reason) - error = extra.pop("error", None) - if error is not None: - state.error = str(error) - state.profile.update(extra) + self.request_registry.update(request_id, status, extra) def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - if not extra: + self.request_registry.merge_profile(request_id, extra) + + def _snapshot_engine_prepare_state(self) -> Dict[str, Any]: + return self.engine_prepare_queue_owner.snapshot(max_request_ids=16) + + def _snapshot_engine_finalize_state(self) -> Dict[str, Any]: + return self.engine_finalize_queue_owner.snapshot(max_request_ids=16) + + def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]: + return self.engine_dispatch_queue_owner.snapshot( + max_request_ids=16, + extra={"last_policy_snapshot": dict(self.engine_dispatch_last_snapshot or {})}, + ) + + def _register_engine_job(self, job: SchedulerPendingJob) -> None: + self.engine_job_registry.register(job, keep_job=True) + + def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None: + return self.engine_job_registry.get(request_id) + + def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None: + return self.engine_job_registry.pop(request_id) + + def _snapshot_engine_job_registry(self) -> Dict[str, Any]: + return self.engine_job_registry.snapshot(max_request_ids=32) + + def _is_engine_drained(self) -> bool: + prepare_empty = self.engine_prepare_queue_owner.is_drained() + dispatch_empty = self.engine_dispatch_queue_owner.is_drained() + finalize_empty = self.engine_finalize_queue_owner.is_drained() + decode_pending_empty = not self.engine_decode_runtime_owner.has_pending_jobs() + job_empty = self.engine_job_registry.is_empty() + worker_state = self.scheduler_worker.snapshot() + return bool( + prepare_empty + and dispatch_empty + and finalize_empty + and decode_pending_empty + and job_empty + and self.engine_decode_runtime_owner.get_active_batch() is None + and int(worker_state.get("prepare_inflight", 0)) <= 0 + and int(worker_state.get("finalize_inflight", 0)) <= 0 + and int(worker_state.get("finalize_pending", 0)) <= 0 + ) + + def _record_engine_job_done(self, request_id: str) -> None: + self.engine_job_registry.mark_finished_and_remove(request_id) + self.scheduler_worker.record_external_job_done(request_id) + + def _complete_engine_job( + self, + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + audio_data: np.ndarray, + ) -> None: + completion_bridge = self.scheduler_worker.completion_bridge + completion_bridge.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data) + completion_bridge.complete_job( + job, + runtime_request_id=job.engine_request_id, + runtime_extra=completion_bridge.build_runtime_complete_payload(job, item, sample_rate=sample_rate), + on_job_finished=lambda rid=item.request_id: self._record_engine_job_done(rid), + ) + + def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None: + if not request_ids: return - now = time.perf_counter() - with self.request_registry_lock: - state = self.active_requests.get(request_id) - if state is None: - for recent_state in self.recent_requests: - if recent_state.request_id == request_id: - state = recent_state - break - if state is None: - return - state.updated_ts = now - backend = extra.get("backend") - if backend is not None: - state.backend = str(backend) - finish_reason = extra.get("finish_reason") - if finish_reason is not None: - state.finish_reason = str(finish_reason) - error = extra.get("error") - if error is not None: - state.error = str(error) - merged = dict(extra) - merged.pop("backend", None) - merged.pop("finish_reason", None) - merged.pop("error", None) - state.profile.update(merged) + completion_bridge = self.scheduler_worker.completion_bridge + for request_id in request_ids: + job = self._get_engine_job(request_id) + if job is None: + continue + completion_bridge.fail_job( + job, + error=error, + on_job_finished=lambda rid=request_id: self._record_engine_job_done(rid), + ) + + def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + for job in jobs: + job.prefill_ms += delta_ms + + def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + for request_id in request_ids: + job = self._get_engine_job(request_id) + if job is not None: + job.merge_ms += delta_ms + + def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + activate_request_ids: List[str] = [] + for request_id in request_ids: + job = self._get_engine_job(request_id) + if job is None: + continue + if job.decode_steps == 0: + activate_request_ids.append(job.engine_request_id) + job.decode_ms += delta_ms + job.decode_steps += 1 + for engine_request_id in activate_request_ids: + self._update_request_state(engine_request_id, EngineStatus.ACTIVE_DECODE, None) + + def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None: + if not items: + return + enqueued_at = time.perf_counter() + tasks = [SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at) for item in items] + self._enqueue_worker_finished_for_finalize(tasks) + + def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]: + return self.engine_decode_runtime_owner.snapshot_pending_queue_state() + + @staticmethod + def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: + return EngineDecodeRuntimeOwner.summarize_active_batch(active_batch) + + def _refresh_engine_decode_runtime_state(self, last_event: str) -> None: + self.engine_decode_runtime_owner.refresh_state(last_event) + + def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None: + if not snapshot: + return + if self.scheduler_worker.is_engine_decode_control_enabled(): + return + self.engine_decode_runtime_owner.update_from_worker_snapshot(snapshot) + + def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]: + return self.engine_decode_runtime_owner.snapshot_state() + + def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]: + return self.engine_policy_arbiter.snapshot_state() + + def _notify_engine_arbiter(self) -> None: + self.engine_policy_arbiter.notify() + + def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None: + self.engine_decode_runtime_owner.enqueue_pending_job(job) + self._notify_engine_arbiter() + + def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.engine_decode_runtime_owner.take_pending_jobs_nonblocking(wait_for_batch) + + def _peek_queue_age_ms(self, queue_name: str) -> float: + if queue_name == "prepare": + return self.engine_prepare_queue_owner.peek_oldest_age_ms("enqueue_time") + elif queue_name == "finalize": + return self.engine_finalize_queue_owner.peek_oldest_age_ms("enqueued_time") + elif queue_name == "decode_runtime_pending": + return self.engine_decode_runtime_owner.pending_age_ms() + else: + return self.engine_dispatch_queue_owner.peek_oldest_age_ms("enqueue_time") + + def _engine_has_pending_work(self) -> bool: + if self.scheduler_worker.is_engine_decode_control_enabled(): + if self.engine_decode_runtime_owner.has_pending_jobs(): + return True + if self.scheduler_worker.is_engine_decode_control_enabled() and self._snapshot_engine_decode_runtime_state().get("active_request_count", 0) > 0: + return True + if self.engine_prepare_queue_owner.has_items(): + return True + if self.engine_finalize_queue_owner.has_items(): + return True + return self.engine_dispatch_queue_owner.has_items() + + @staticmethod + def _resolve_dispatch_error_future(future: asyncio.Future, error: Exception) -> None: + if future.done(): + return + future.set_exception(error) + + def _notify_dispatch_error(self, task: EngineDispatchTask, error: Exception) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) + except RuntimeError: + pass + + @staticmethod + def _resolve_prepare_future( + future: asyncio.Future, + payload: tuple[T2SRequestState, float, float], + ) -> None: + if future.done(): + return + future.set_result(payload) + + def _notify_prepare_error(self, task: EngineGpuPrepareTask, error: Exception) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) + except RuntimeError: + pass + + def _notify_prepare_result( + self, + task: EngineGpuPrepareTask, + payload: tuple[T2SRequestState, float, float], + ) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_prepare_future, task.done_future, payload) + except RuntimeError: + pass + + async def _prepare_state_via_engine_gpu_queue( + self, + *, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + engine_request_id: str | None, + ) -> tuple[T2SRequestState, float, float]: + cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + if engine_request_id not in [None, ""]: + self._update_request_state( + str(engine_request_id), + EngineStatus.GPU_PREPARING, + { + "prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms), + "prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms), + "text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms), + "text_cpu_run_ms": float(cpu_stage.target_cpu_profiled.run_ms), + }, + ) + loop = asyncio.get_running_loop() + done_future = loop.create_future() + task = EngineGpuPrepareTask( + request_id=spec.request_id, + cpu_stage=cpu_stage, + done_loop=loop, + done_future=done_future, + engine_request_id=engine_request_id or spec.request_id, + enqueue_time=time.perf_counter(), + ) + self.engine_prepare_queue_owner.enqueue(task) + self._notify_engine_arbiter() + state, prepare_exec_started_at, prepare_exec_finished_at = await done_future + return state, prepare_exec_started_at, prepare_exec_finished_at + + def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: + if not tasks: + return + for task in tasks: + job = self._get_engine_job(task.request_id) + if job is not None: + self._update_request_state( + job.engine_request_id, + EngineStatus.READY_FOR_FINALIZE, + { + "finish_reason": task.item.finish_reason, + "semantic_len": int(task.item.semantic_tokens.shape[0]), + "finish_idx": int(task.item.finish_idx), + }, + ) + self.engine_finalize_queue_owner.enqueue_many(tasks) + self._notify_engine_arbiter() + + def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: + finalize_policy = self.scheduler_worker.get_finalize_batch_policy() + return self.engine_finalize_queue_owner.take_finalize_batch( + finalize_mode=str(finalize_policy.get("finalize_mode", "async")), + batch_max_items=int(finalize_policy.get("finalize_batch_max_items", 1)), + batch_wait_s=float(finalize_policy.get("finalize_batch_wait_s", 0.0)), + use_vocoder=bool(self.tts.configs.use_vocoder), + ) + + async def _enqueue_prepared_state_for_dispatch( + self, + *, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None, + done_future: asyncio.Future | None, + engine_request_id: str | None, + timeout_sec: float | None, + ) -> EngineDispatchTask: + task = EngineDispatchTask( + request_id=state.request_id, + state=state, + speed_factor=float(speed_factor), + sample_steps=int(sample_steps), + media_type=media_type, + prepare_wall_ms=float(prepare_wall_ms), + prepare_profile_total_ms=float(prepare_profile_total_ms), + done_loop=done_loop, + done_future=done_future, + engine_request_id=engine_request_id or state.request_id, + timeout_sec=timeout_sec, + enqueue_time=time.perf_counter(), + ) + self.engine_dispatch_queue_owner.enqueue(task) + self._notify_engine_arbiter() + self._merge_request_state_profile( + task.engine_request_id or task.request_id, + { + "engine_dispatch_queue_depth_on_enqueue": int(self._snapshot_engine_dispatch_state()["waiting_count"]), + }, + ) + return task + + def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: + self.engine_policy_arbiter.mark_tick(stage=stage, reason=reason, policy_allowed=policy_allowed) + + def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: + stage, reason, policy_snapshot, worker_state = self.engine_policy_arbiter.select_stage() + self.engine_dispatch_last_snapshot = dict(policy_snapshot) + return stage, reason, policy_snapshot, worker_state + + def _run_engine_prepare_once(self) -> bool: + task = self.engine_prepare_queue_owner.pop_left() + if task is None: + return False + queue_wait_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0) + try: + state, prepare_exec_started_at, prepare_exec_finished_at = asyncio.run( + self.scheduler_worker.prepare_gpu_stage_profiled_async(task.cpu_stage) + ) + state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms) + if task.engine_request_id not in [None, ""]: + self._merge_request_state_profile( + str(task.engine_request_id), + {"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms)}, + ) + self.engine_prepare_queue_owner.mark_completed(1) + self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at)) + return True + except Exception as exc: + task.error = str(exc) + self._fail_request_state(task.engine_request_id or task.request_id, str(exc)) + self._notify_prepare_error(task, exc) + return True + + def _run_engine_finalize_once(self) -> bool: + tasks = self._take_engine_finalize_batch_nonblocking() + if not tasks: + return False + self.scheduler_worker.begin_finalize_execution(len(tasks)) + try: + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + for task in tasks: + job = self._get_engine_job(task.request_id) + if job is None: + continue + jobs_and_items.append((job, task.item)) + if not jobs_and_items: + return False + now = time.perf_counter() + for task in tasks: + job = self._get_engine_job(task.request_id) + if job is not None: + job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0) + for job, item in jobs_and_items: + self._update_request_state( + job.engine_request_id, + EngineStatus.FINALIZING, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + }, + ) + synth_ms, batch_results = self.scheduler_worker.synthesize_finalize_jobs(jobs_and_items) + for job, _ in jobs_and_items: + job.synth_ms += float(synth_ms) + for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): + self._complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) + except Exception as exc: + self._fail_engine_jobs([task.request_id for task in tasks], str(exc)) + finally: + self.scheduler_worker.end_finalize_execution(len(tasks)) + self.engine_finalize_queue_owner.mark_completed(len(tasks), notify=True) + return True + + def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: + if not bool(policy_snapshot.get("allowed", True)): + return False + dispatch_task = self.engine_dispatch_queue_owner.pop_left() + if dispatch_task is None: + return False + dispatched_at = time.perf_counter() + dispatch_wait_ms = max(0.0, (dispatched_at - dispatch_task.enqueue_time) * 1000.0) + dispatch_task.engine_policy_wait_ms = float(dispatch_wait_ms) + dispatch_task.engine_dispatch_wait_ms = float(dispatch_wait_ms) + dispatch_task.engine_policy_snapshot = dict(policy_snapshot) + try: + worker_job = self.scheduler_worker.submit( + state=dispatch_task.state, + speed_factor=dispatch_task.speed_factor, + sample_steps=dispatch_task.sample_steps, + media_type=dispatch_task.media_type, + prepare_wall_ms=dispatch_task.prepare_wall_ms, + prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms, + done_loop=dispatch_task.done_loop, + done_future=dispatch_task.done_future, + engine_request_id=dispatch_task.engine_request_id, + timeout_sec=dispatch_task.timeout_sec, + skip_capacity_wait=True, + admission_wait_ms_override=0.0, + admission_snapshot_override=dict(worker_state), + engine_policy_wait_ms=dispatch_task.engine_policy_wait_ms, + engine_dispatch_wait_ms=dispatch_task.engine_dispatch_wait_ms, + enqueue_pending=not self.scheduler_worker.is_engine_decode_control_enabled(), + ) + dispatch_task.worker_job = worker_job + self._register_engine_job(worker_job) + if self.scheduler_worker.is_engine_decode_control_enabled(): + self._enqueue_engine_decode_pending_job(worker_job) + self.engine_dispatch_queue_owner.mark_completed(1) + return True + except Exception as exc: + dispatch_task.error = str(exc) + self._fail_request_state(dispatch_task.engine_request_id or dispatch_task.request_id, str(exc)) + self._notify_dispatch_error(dispatch_task, exc) + return True + + def _run_engine_decode_runtime_once(self) -> bool: + if not self.scheduler_worker.is_engine_decode_control_enabled(): + return False + runtime_state = self._snapshot_engine_decode_runtime_state() + pending_jobs = self._take_engine_decode_pending_jobs_nonblocking( + wait_for_batch=int(runtime_state.get("active_request_count", 0)) <= 0 + ) + result = self.scheduler_worker.execute_decode_cycle( + pending_jobs=pending_jobs, + active_batch=self.engine_decode_runtime_owner.get_active_batch(), + external_bookkeeping=True, + ) + prefill_phase = dict(result.get("prefill_phase") or {}) + if prefill_phase.get("error"): + self._fail_engine_jobs(list(prefill_phase.get("error_request_ids") or []), str(prefill_phase.get("error"))) + else: + prefill_jobs = list(prefill_phase.get("pending_jobs") or []) + self._add_engine_prefill_time(prefill_jobs, float(prefill_phase.get("prefill_elapsed_s", 0.0))) + self._add_engine_merge_time( + [] if result.get("active_batch") is None else list(result["active_batch"].request_ids), + float(prefill_phase.get("merge_elapsed_s", 0.0)), + ) + self._enqueue_engine_finished_items(list(prefill_phase.get("finished_items") or [])) + decode_phase = dict(result.get("decode_phase") or {}) + if decode_phase.get("error"): + self._fail_engine_jobs(list(decode_phase.get("error_request_ids") or []), str(decode_phase.get("error"))) + else: + self._add_engine_decode_time( + list(decode_phase.get("request_ids") or []), + float(decode_phase.get("decode_elapsed_s", 0.0)), + ) + self._enqueue_engine_finished_items(list(decode_phase.get("finished_items") or [])) + self.engine_decode_runtime_owner.set_active_batch(result.get("active_batch")) + if result.get("executed", False): + self._refresh_engine_decode_runtime_state("engine_decode_cycle") + return bool(result.get("executed", False)) + + def _run_engine_arbiter_loop(self) -> None: + while True: + if not self._engine_has_pending_work(): + self._mark_arbiter_tick(stage="idle", reason="no_pending_work", policy_allowed=True) + self.engine_policy_arbiter.wait() + continue + stage, reason, policy_snapshot, worker_state = self._select_engine_stage() + policy_allowed = bool(policy_snapshot.get("allowed", True)) + executed = False + if stage == "prepare": + executed = self._run_engine_prepare_once() + elif stage == "finalize": + executed = self._run_engine_finalize_once() + elif stage == "decode_dispatch": + executed = self._run_engine_dispatch_once(policy_snapshot, worker_state) + elif stage == "decode_runtime": + executed = self._run_engine_decode_runtime_once() + if not executed: + self._mark_arbiter_tick(stage="idle", reason=f"{stage}_not_ready", policy_allowed=policy_allowed) + self.engine_policy_arbiter.wait() + continue + self._mark_arbiter_tick(stage=stage, reason=reason, policy_allowed=policy_allowed) def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - now = time.perf_counter() - with self.request_registry_lock: - state = self.active_requests.pop(request_id, None) - if state is None: - return - state.status = EngineStatus.COMPLETED - state.updated_ts = now - state.lifecycle_timestamps[EngineStatus.COMPLETED] = now - if extra: - finish_reason = extra.pop("finish_reason", None) - if finish_reason is not None: - state.finish_reason = str(finish_reason) - state.profile.update(extra) - self._move_to_recent_locked(state) + self.request_registry.complete(request_id, extra) def _fail_request_state(self, request_id: str, error: str) -> None: - now = time.perf_counter() - with self.request_registry_lock: - state = self.active_requests.pop(request_id, None) - if state is None: - return - state.status = EngineStatus.FAILED - state.updated_ts = now - state.error = str(error) - state.lifecycle_timestamps[EngineStatus.FAILED] = now - self._move_to_recent_locked(state) + self.request_registry.fail(request_id, error) def _snapshot_request_registry(self) -> Dict[str, Any]: - with self.request_registry_lock: - active = [state.to_summary() for state in self.active_requests.values()] - recent = [state.to_summary() for state in list(self.recent_requests)] - active.sort(key=lambda item: item["submit_ts"]) - return { - "active_count": len(active), - "recent_count": len(recent), - "recent_limit": self.recent_request_limit, - "active_requests": active, - "recent_requests": recent, - } + return self.request_registry.snapshot() @staticmethod def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None: @@ -1258,41 +3398,56 @@ class UnifiedTTSEngine: except Exception: return None + def _build_stage_counters( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.engine_policy_arbiter.build_stage_counters(request_registry, worker_state) + + def _build_engine_policy_snapshot( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.engine_policy_arbiter.build_policy_snapshot(request_registry, worker_state) + + async def _wait_for_engine_policy_admission( + self, + *, + request_id: str | None, + timeout_sec: float | None, + ) -> tuple[float, Dict[str, Any]]: + return await self.engine_policy_arbiter.wait_for_policy_admission( + request_id=request_id, + timeout_sec=timeout_sec, + ) + def _build_stage_summary( self, request_registry: Dict[str, Any], worker_state: Dict[str, Any], ) -> Dict[str, Any]: - active_requests = list(request_registry.get("active_requests", [])) - status_counts: Dict[str, int] = {} - for item in active_requests: - status = str(item.get("status", "UNKNOWN")) - status_counts[status] = status_counts.get(status, 0) + 1 - + counters = self._build_stage_counters(request_registry, worker_state) bert_worker_state = self._safe_component_snapshot(getattr(self.tts, "prepare_bert_batch_worker", None)) ref_semantic_worker_state = self._safe_component_snapshot(getattr(self.tts, "prepare_ref_semantic_batch_worker", None)) text_preprocessor_state = self._safe_component_snapshot(getattr(self.tts, "text_preprocessor", None)) return { - "active_request_count": int(len(active_requests)), - "status_counts": status_counts, - "queued_request_count": int(status_counts.get(EngineStatus.QUEUED, 0)), - "cpu_prepare_request_count": int(status_counts.get(EngineStatus.CPU_PREPARING, 0)), - "gpu_prepare_request_count": int(status_counts.get(EngineStatus.GPU_PREPARING, 0)), - "ready_for_prefill_request_count": int(status_counts.get(EngineStatus.READY_FOR_PREFILL, 0)), - "active_decode_request_count": int(status_counts.get(EngineStatus.ACTIVE_DECODE, 0)), - "ready_for_finalize_request_count": int(status_counts.get(EngineStatus.READY_FOR_FINALIZE, 0)), - "finalizing_request_count": int(status_counts.get(EngineStatus.FINALIZING, 0)), - "streaming_request_count": int(status_counts.get(EngineStatus.STREAMING, 0)), - "worker_pending_jobs": int(worker_state.get("pending_jobs", 0)), - "worker_decode_active_size": int(worker_state.get("running_requests", 0)), - "worker_prepare_inflight": int(worker_state.get("prepare_inflight", 0)), - "worker_finalize_pending": int(worker_state.get("finalize_pending", 0)), - "worker_finalize_inflight": int(worker_state.get("finalize_inflight", 0)), + **counters, + "engine_drained": bool(self._is_engine_drained()), "admission_config": { "decode_backlog_max": int(worker_state.get("decode_backlog_max", 0)), "finalize_pending_max": int(worker_state.get("finalize_pending_max", 0)), }, + "engine_policy": self._build_engine_policy_snapshot(request_registry, worker_state), + "engine_arbiter_state": self._snapshot_engine_arbiter_state(), + "engine_decode_runtime_state": self._snapshot_engine_decode_runtime_state(), + "engine_job_registry": self._snapshot_engine_job_registry(), + "engine_active_batch_state": self.engine_decode_runtime_owner.active_batch_summary(), + "engine_prepare_state": self._snapshot_engine_prepare_state(), + "engine_finalize_state": self._snapshot_engine_finalize_state(), + "engine_dispatcher_state": self._snapshot_engine_dispatch_state(), "active_batch": dict(worker_state.get("active_batch") or {}), "prepare_state": dict(worker_state.get("prepare_state") or {}), "bert_batch_worker_state": bert_worker_state, @@ -1301,21 +3456,10 @@ class UnifiedTTSEngine: } def _collect_request_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]: - requested = set(request_ids) - results: List[Dict[str, Any]] = [] - with self.request_registry_lock: - for state in self.active_requests.values(): - if state.request_id in requested: - results.append(state.to_summary()) - for state in self.recent_requests: - if state.request_id in requested and all(item["request_id"] != state.request_id for item in results): - results.append(state.to_summary()) - results.sort(key=lambda item: item["request_id"]) - return results + return self.request_registry.collect_summaries(request_ids) def _has_active_request(self, request_id: str) -> bool: - with self.request_registry_lock: - return request_id in self.active_requests + return self.request_registry.has_active(request_id) @staticmethod def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]: @@ -1356,6 +3500,11 @@ class UnifiedTTSEngine: "text_len": len(str(segment_text)), "prepare_wall_ms": float(prepare_item.get("prepare_wall_ms", 0.0)), "prepare_profile_total_ms": float(prepare_item.get("prepare_profile_total_ms", 0.0)), + "prepare_engine_gpu_queue_wait_ms": float( + dict(prepare_item.get("prepare_profile", {})).get("engine_gpu_prepare_queue_wait_ms", 0.0) + ), + "engine_policy_wait_ms": float(prepare_item.get("engine_policy_wait_ms", 0.0)), + "engine_dispatch_wait_ms": float(prepare_item.get("engine_dispatch_wait_ms", 0.0)), "decode_admission_wait_ms": float(worker_item.get("decode_admission_wait_ms", 0.0)), "queue_wait_ms": float(worker_item.get("queue_wait_ms", 0.0)), "prefill_ms": float(worker_item.get("prefill_ms", 0.0)), @@ -1391,6 +3540,8 @@ class UnifiedTTSEngine: request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) prepare_wall_ms = self._sum_profile_field(prepare_profiles, "prepare_wall_ms") prepare_profile_total_ms = self._sum_profile_field(prepare_profiles, "prepare_profile_total_ms") + engine_policy_wait_ms = self._sum_profile_field(prepare_profiles, "engine_policy_wait_ms") + engine_dispatch_wait_ms = self._sum_profile_field(prepare_profiles, "engine_dispatch_wait_ms") decode_admission_wait_ms = self._sum_profile_field(worker_profiles, "decode_admission_wait_ms") queue_wait_ms = self._sum_profile_field(worker_profiles, "queue_wait_ms") prefill_ms = self._sum_profile_field(worker_profiles, "prefill_ms") @@ -1403,7 +3554,7 @@ class UnifiedTTSEngine: semantic_len = sum(int(item.get("semantic_len", 0)) for item in worker_profiles) request_other_ms = max( 0.0, - request_total_ms - prepare_wall_ms - worker_total_ms - pack_ms - response_overhead_ms, + request_total_ms - prepare_wall_ms - engine_policy_wait_ms - worker_total_ms - pack_ms - response_overhead_ms, ) return { "backend": backend, @@ -1415,6 +3566,8 @@ class UnifiedTTSEngine: "prepare_ms": prepare_wall_ms, "prepare_wall_ms": prepare_wall_ms, "prepare_profile_total_ms": prepare_profile_total_ms, + "engine_policy_wait_ms": engine_policy_wait_ms, + "engine_dispatch_wait_ms": engine_dispatch_wait_ms, "decode_admission_wait_ms": decode_admission_wait_ms, "queue_wait_ms": queue_wait_ms, "prefill_ms": prefill_ms, @@ -1488,6 +3641,7 @@ class UnifiedTTSEngine: prepare_profile_total_ms: float, prepare_profile_wall_ms: float, prepare_other_ms: float, + engine_policy_wait_ms: float, api_after_prepare_ms: float, api_wait_result_ms: float, pack_ms: float, @@ -1498,7 +3652,13 @@ class UnifiedTTSEngine: request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) request_other_ms = max( 0.0, - request_total_ms - prepare_wall_ms - api_after_prepare_ms - worker_total_ms - api_wait_result_ms - pack_ms, + request_total_ms + - prepare_wall_ms + - engine_policy_wait_ms + - api_after_prepare_ms + - worker_total_ms + - api_wait_result_ms + - pack_ms, ) result = { "backend": backend, @@ -1513,6 +3673,7 @@ class UnifiedTTSEngine: "prepare_profile_total_ms": prepare_profile_total_ms, "prepare_profile_wall_ms": prepare_profile_wall_ms, "prepare_other_ms": prepare_other_ms, + "engine_policy_wait_ms": float(engine_policy_wait_ms), "api_after_prepare_ms": api_after_prepare_ms, "api_wait_result_ms": api_wait_result_ms, "pack_ms": pack_ms, @@ -1542,6 +3703,8 @@ class UnifiedTTSEngine: "X-Finish-Reason": str(profile.get("finish_reason", "unknown")), "X-Queue-Wait-Ms": self._format_ms_header(profile.get("queue_wait_ms", 0.0)), "X-Decode-Admission-Wait-Ms": self._format_ms_header(profile.get("decode_admission_wait_ms", 0.0)), + "X-Engine-Policy-Wait-Ms": self._format_ms_header(profile.get("engine_policy_wait_ms", 0.0)), + "X-Engine-Dispatch-Wait-Ms": self._format_ms_header(profile.get("engine_dispatch_wait_ms", 0.0)), "X-Prepare-Ms": self._format_ms_header(profile.get("prepare_wall_ms", 0.0)), "X-Prepare-Wall-Ms": self._format_ms_header(profile.get("prepare_wall_ms", 0.0)), "X-Prepare-Spec-Build-Ms": self._format_ms_header(profile.get("prepare_spec_build_ms", 0.0)), @@ -1598,6 +3761,7 @@ class UnifiedTTSEngine: "X-Prepare-Target-Bert-Batch-Window-Ms": self._format_ms_header(prepare_profile.get("text_bert_batch_window_ms", 0.0)), "X-Prepare-Text-Pair-Wall-Ms": self._format_ms_header(prepare_profile.get("text_feature_pair_ms", 0.0)), "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))), + "X-Prepare-Engine-GPU-Queue-Wait-Ms": self._format_ms_header(prepare_profile.get("engine_gpu_prepare_queue_wait_ms", 0.0)), "X-Prepare-Audio-Load-Ms": self._format_ms_header(prepare_profile.get("audio_load_ms", 0.0)), "X-Prepare-Audio-Stage-Wait-Ms": self._format_ms_header(prepare_profile.get("audio_stage_wait_ms", 0.0)), "X-Prepare-Prompt-Semantic-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_ms", 0.0)), @@ -2000,14 +4164,22 @@ class UnifiedTTSEngine: prepared_items = await asyncio.gather( *[ - self.scheduler_worker.prepare_state_profiled_async(spec, time.perf_counter()) + self._prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=time.perf_counter(), + engine_request_id=None, + ) for spec in segment_specs ] ) prepare_profiles: List[Dict[str, Any]] = [] - jobs: List[SchedulerPendingJob] = [] loop = asyncio.get_running_loop() done_futures: List[asyncio.Future] = [] + self._update_request_state( + request_id, + EngineStatus.READY_FOR_PREFILL, + {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_specs)}, + ) for spec, (state, prepare_exec_started_at, prepare_exec_finished_at) in zip(segment_specs, prepared_items): prepare_wall_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0) prepare_profile_total_ms = float(state.prepare_profile.get("wall_total_ms", prepare_wall_ms)) @@ -2021,39 +4193,38 @@ class UnifiedTTSEngine: ) done_future = loop.create_future() done_futures.append(done_future) - jobs.append( - await self.scheduler_worker.submit_async( - state=state, - speed_factor=float(normalized.speed_factor), - sample_steps=int(normalized.sample_steps), - media_type=media_type, - prepare_wall_ms=prepare_wall_ms, - prepare_profile_total_ms=prepare_profile_total_ms, - done_loop=loop, - done_future=done_future, - engine_request_id=None, - timeout_sec=normalized.timeout_sec, - ) + await self._enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=float(normalized.speed_factor), + sample_steps=int(normalized.sample_steps), + media_type=media_type, + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=loop, + done_future=done_future, + engine_request_id=None, + timeout_sec=normalized.timeout_sec, ) - self._update_request_state( - request_id, - EngineStatus.READY_FOR_PREFILL, - { - "backend": "scheduler_v1_direct", - "backend_mode": "scheduler_v1_direct", - "segment_count": len(segment_specs), - "prepare_aggregate": self._aggregate_numeric_dicts( - [item["prepare_profile"] for item in prepare_profiles] - ), - }, - ) self._update_request_state( request_id, EngineStatus.ACTIVE_DECODE, {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct"}, ) timeout_sec = float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0) - await asyncio.wait_for(asyncio.gather(*done_futures), timeout=timeout_sec) + jobs: List[SchedulerPendingJob] = list(await asyncio.wait_for(asyncio.gather(*done_futures), timeout=timeout_sec)) + for profile_item, job in zip(prepare_profiles, jobs): + profile_item["engine_policy_wait_ms"] = float(job.engine_policy_wait_ms) + profile_item["engine_dispatch_wait_ms"] = float(job.engine_dispatch_wait_ms) + self._merge_request_state_profile( + request_id, + { + "engine_policy_wait_ms": sum(float(job.engine_policy_wait_ms) for job in jobs), + "engine_dispatch_wait_ms": sum(float(job.engine_dispatch_wait_ms) for job in jobs), + "prepare_aggregate": self._aggregate_numeric_dicts( + [item["prepare_profile"] for item in prepare_profiles] + ), + }, + ) sample_rate: int | None = None audio_parts: List[np.ndarray] = [] @@ -2470,9 +4641,10 @@ class UnifiedTTSEngine: prepare_spec_build_ms = max(0.0, (spec_ready_at - prepare_start) * 1000.0) self._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, {"prepare_spec_build_ms": prepare_spec_build_ms}) try: - state, prepare_exec_started_at, prepare_exec_finished_at = await self.scheduler_worker.prepare_state_profiled_async( - spec, - spec_ready_at, + state, prepare_exec_started_at, prepare_exec_finished_at = await self._prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=spec_ready_at, + engine_request_id=spec.request_id, ) except Exception as exc: self._fail_request_state(spec.request_id, str(exc)) @@ -2496,7 +4668,7 @@ class UnifiedTTSEngine: api_after_prepare_start = time.perf_counter() loop = asyncio.get_running_loop() done_future = loop.create_future() - job = await self.scheduler_worker.submit_async( + await self._enqueue_prepared_state_for_dispatch( state=state, speed_factor=float(normalized.speed_factor), sample_steps=int(normalized.sample_steps), @@ -2510,7 +4682,7 @@ class UnifiedTTSEngine: ) api_after_prepare_ms = max(0.0, (time.perf_counter() - api_after_prepare_start) * 1000.0) try: - await asyncio.wait_for(done_future, timeout=float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0)) + job = await asyncio.wait_for(done_future, timeout=float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0)) except Exception as exc: self._fail_request_state(spec.request_id, str(exc)) raise @@ -2542,6 +4714,7 @@ class UnifiedTTSEngine: prepare_profile_total_ms=prepare_profile_total_ms, prepare_profile_wall_ms=prepare_profile_wall_ms, prepare_other_ms=prepare_other_ms, + engine_policy_wait_ms=float(job.result.get("engine_policy_wait_ms", 0.0)), api_after_prepare_ms=api_after_prepare_ms, api_wait_result_ms=api_wait_result_ms, pack_ms=pack_ms, @@ -2568,6 +4741,14 @@ class UnifiedTTSEngine: default_ref = self.reference_registry.get_default() scheduler_state = self.get_scheduler_state() request_registry = self._snapshot_request_registry() + engine_policy = self._build_engine_policy_snapshot(request_registry, scheduler_state) + engine_arbiter_state = self._snapshot_engine_arbiter_state() + engine_decode_runtime_state = self._snapshot_engine_decode_runtime_state() + engine_job_registry = self._snapshot_engine_job_registry() + engine_prepare_state = self._snapshot_engine_prepare_state() + engine_finalize_state = self._snapshot_engine_finalize_state() + engine_dispatcher_state = self._snapshot_engine_dispatch_state() + engine_drained = self._is_engine_drained() return { "message": "success", "default_reference": { @@ -2583,6 +4764,15 @@ class UnifiedTTSEngine: "updated_at": model_state.updated_at, }, "worker_state": scheduler_state, + "engine_policy": engine_policy, + "engine_arbiter_state": engine_arbiter_state, + "engine_decode_runtime_state": engine_decode_runtime_state, + "engine_job_registry": engine_job_registry, + "engine_active_batch_state": self.engine_decode_runtime_owner.active_batch_summary(), + "engine_prepare_state": engine_prepare_state, + "engine_finalize_state": engine_finalize_state, + "engine_dispatcher_state": engine_dispatcher_state, + "engine_drained": bool(engine_drained), "request_registry": request_registry, "stage_summary": self._build_stage_summary(request_registry, scheduler_state), } From d1ec7d9e5442124d1401884e33290fcc49fe6c7c Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Wed, 11 Mar 2026 08:32:56 +0800 Subject: [PATCH 13/24] Add unified engine components and API for enhanced TTS processing Introduce multiple new modules including unified_engine_api, unified_engine_audio, unified_engine_bridge, unified_engine_builder, unified_engine_components, unified_engine_delegates, and unified_engine_runtime. These additions provide a comprehensive framework for managing TTS requests, audio packing, and engine state management, significantly improving the architecture and maintainability of the TTS system. The new structure supports asynchronous operations and enhances overall performance through better request handling and processing capabilities. --- GPT_SoVITS/TTS_infer_pack/unified_engine.py | 4797 +---------------- .../TTS_infer_pack/unified_engine_api.py | 1399 +++++ .../TTS_infer_pack/unified_engine_audio.py | 106 + .../TTS_infer_pack/unified_engine_bridge.py | 310 ++ .../TTS_infer_pack/unified_engine_builder.py | 179 + .../unified_engine_components.py | 1150 ++++ .../unified_engine_delegates.py | 446 ++ .../TTS_infer_pack/unified_engine_runtime.py | 198 + .../TTS_infer_pack/unified_engine_stage.py | 420 ++ .../TTS_infer_pack/unified_engine_worker.py | 1510 ++++++ 10 files changed, 5724 insertions(+), 4791 deletions(-) create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_api.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_audio.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_bridge.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_components.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_runtime.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine.py b/GPT_SoVITS/TTS_infer_pack/unified_engine.py index 9b56199a..24e1c98b 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine.py @@ -1,2756 +1,15 @@ from __future__ import annotations -import asyncio import os -import signal -import subprocess -import sys -import threading -import time -import uuid -import wave -from collections import deque -from dataclasses import dataclass, field -from io import BytesIO -from pathlib import Path -from typing import Any, Callable, Deque, Dict, Generator, List, Optional, Sequence, Tuple, Union - -import numpy as np -import soundfile as sf -import torch +from typing import Sequence from GPT_SoVITS.TTS_infer_pack.TTS import TTS -from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator, PreparedCpuStage -from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( - SchedulerRequestSpec, - T2SActiveBatch, - T2SFinishedItem, - T2SRequestState, - decode_one_step, - merge_active_batches, - run_prefill_active_batch, - run_scheduler_continuous, -) +from GPT_SoVITS.TTS_infer_pack.unified_engine_builder import EngineCompositionBuilder +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeControlCallbacks +from GPT_SoVITS.TTS_infer_pack.unified_engine_delegates import EngineApiDelegates, EngineBridgeDelegates, EngineRuntimeDelegates -@dataclass -class RuntimeControlCallbacks: - restart: Callable[[], None] | None = None - exit: Callable[[], None] | None = None - - -@dataclass -class DefaultReferenceState: - ref_audio_path: str | None = None - updated_at: float = 0.0 - - -class ReferenceRegistry: - def __init__(self) -> None: - self._lock = threading.Lock() - self._state = DefaultReferenceState() - - def set_default(self, ref_audio_path: str) -> DefaultReferenceState: - with self._lock: - self._state = DefaultReferenceState(ref_audio_path=str(ref_audio_path), updated_at=time.time()) - return self._state - - def clear(self) -> DefaultReferenceState: - with self._lock: - self._state = DefaultReferenceState() - return self._state - - def get_default(self) -> DefaultReferenceState: - with self._lock: - return DefaultReferenceState( - ref_audio_path=self._state.ref_audio_path, - updated_at=self._state.updated_at, - ) - - -@dataclass -class ModelRegistryState: - t2s_weights_path: str - vits_weights_path: str - generation: int = 0 - t2s_generation: int = 0 - vits_generation: int = 0 - updated_at: float = field(default_factory=time.time) - - -class ModelRegistry: - def __init__(self, t2s_weights_path: str, vits_weights_path: str) -> None: - self._lock = threading.Lock() - self._state = ModelRegistryState( - t2s_weights_path=str(t2s_weights_path), - vits_weights_path=str(vits_weights_path), - ) - - def snapshot(self) -> ModelRegistryState: - with self._lock: - return ModelRegistryState( - t2s_weights_path=self._state.t2s_weights_path, - vits_weights_path=self._state.vits_weights_path, - generation=self._state.generation, - t2s_generation=self._state.t2s_generation, - vits_generation=self._state.vits_generation, - updated_at=self._state.updated_at, - ) - - def mark_t2s_reload(self, weights_path: str) -> ModelRegistryState: - with self._lock: - self._state.t2s_weights_path = str(weights_path) - self._state.generation += 1 - self._state.t2s_generation += 1 - self._state.updated_at = time.time() - return ModelRegistryState( - t2s_weights_path=self._state.t2s_weights_path, - vits_weights_path=self._state.vits_weights_path, - generation=self._state.generation, - t2s_generation=self._state.t2s_generation, - vits_generation=self._state.vits_generation, - updated_at=self._state.updated_at, - ) - - def mark_vits_reload(self, weights_path: str) -> ModelRegistryState: - with self._lock: - self._state.vits_weights_path = str(weights_path) - self._state.generation += 1 - self._state.vits_generation += 1 - self._state.updated_at = time.time() - return ModelRegistryState( - t2s_weights_path=self._state.t2s_weights_path, - vits_weights_path=self._state.vits_weights_path, - generation=self._state.generation, - t2s_generation=self._state.t2s_generation, - vits_generation=self._state.vits_generation, - updated_at=self._state.updated_at, - ) - - -@dataclass -class DirectTTSExecution: - media_type: str - streaming: bool - audio_generator: Optional[Generator[bytes, None, None]] = None - audio_bytes: Optional[bytes] = None - request_id: Optional[str] = None - - -@dataclass -class NormalizedEngineRequest: - request_id: str - text: str - text_lang: str - ref_audio_path: str - prompt_lang: str - prompt_text: str = "" - aux_ref_audio_paths: List[str] | None = None - top_k: int = 15 - top_p: float = 1.0 - temperature: float = 1.0 - repetition_penalty: float = 1.35 - early_stop_num: int = -1 - ready_step: int = 0 - text_split_method: str = "cut5" - batch_size: int = 1 - batch_threshold: float = 0.75 - split_bucket: bool = False - speed_factor: float = 1.0 - fragment_interval: float = 0.3 - seed: int = -1 - media_type: str = "wav" - streaming_mode: bool | int = False - return_fragment: bool = False - fixed_length_chunk: bool = False - response_streaming: bool = False - parallel_infer: bool = False - sample_steps: int = 32 - super_sampling: bool = False - overlap_length: int = 2 - min_chunk_length: int = 16 - timeout_sec: float | None = None - - def to_payload(self) -> Dict[str, Any]: - return { - "request_id": self.request_id, - "text": self.text, - "text_lang": self.text_lang, - "ref_audio_path": self.ref_audio_path, - "aux_ref_audio_paths": list(self.aux_ref_audio_paths) if self.aux_ref_audio_paths else None, - "prompt_text": self.prompt_text, - "prompt_lang": self.prompt_lang, - "top_k": self.top_k, - "top_p": self.top_p, - "temperature": self.temperature, - "text_split_method": self.text_split_method, - "batch_size": self.batch_size, - "batch_threshold": self.batch_threshold, - "speed_factor": self.speed_factor, - "split_bucket": self.split_bucket, - "fragment_interval": self.fragment_interval, - "seed": self.seed, - "media_type": self.media_type, - "streaming_mode": self.streaming_mode, - "return_fragment": self.return_fragment, - "fixed_length_chunk": self.fixed_length_chunk, - "response_streaming": self.response_streaming, - "parallel_infer": self.parallel_infer, - "repetition_penalty": self.repetition_penalty, - "sample_steps": self.sample_steps, - "super_sampling": self.super_sampling, - "overlap_length": self.overlap_length, - "min_chunk_length": self.min_chunk_length, - "early_stop_num": self.early_stop_num, - "ready_step": self.ready_step, - "timeout_sec": self.timeout_sec, - } - - def to_scheduler_spec(self) -> SchedulerRequestSpec: - return SchedulerRequestSpec( - request_id=self.request_id, - ref_audio_path=Path(self.ref_audio_path), - prompt_text=self.prompt_text, - prompt_lang=self.prompt_lang, - text=self.text, - text_lang=self.text_lang, - top_k=self.top_k, - top_p=self.top_p, - temperature=self.temperature, - repetition_penalty=self.repetition_penalty, - early_stop_num=self.early_stop_num, - ready_step=self.ready_step, - ) - - -@dataclass -class SchedulerDebugExecution: - payload: Dict[str, Any] - - -@dataclass -class SchedulerSubmitExecution: - audio_bytes: bytes - media_type: str - headers: Dict[str, str] - - -@dataclass -class EnginePolicyConfig: - enabled: bool = True - poll_wait_ms: float = 5.0 - decode_backlog_soft_max: int = 0 - finalize_pending_soft_max: int = 0 - prepare_inflight_soft_max: int = 0 - active_decode_soft_max: int = 0 - ready_for_prefill_soft_max: int = 0 - active_request_soft_max: int = 0 - - def to_dict(self) -> Dict[str, Any]: - return { - "enabled": bool(self.enabled), - "poll_wait_ms": float(self.poll_wait_ms), - "decode_backlog_soft_max": int(self.decode_backlog_soft_max), - "finalize_pending_soft_max": int(self.finalize_pending_soft_max), - "prepare_inflight_soft_max": int(self.prepare_inflight_soft_max), - "active_decode_soft_max": int(self.active_decode_soft_max), - "ready_for_prefill_soft_max": int(self.ready_for_prefill_soft_max), - "active_request_soft_max": int(self.active_request_soft_max), - } - - -@dataclass -class EngineArbiterConfig: - poll_wait_ms: float = 5.0 - decode_burst: int = 4 - prepare_aging_ms: float = 10.0 - finalize_aging_ms: float = 10.0 - - def to_dict(self) -> Dict[str, Any]: - return { - "poll_wait_ms": float(self.poll_wait_ms), - "decode_burst": int(self.decode_burst), - "prepare_aging_ms": float(self.prepare_aging_ms), - "finalize_aging_ms": float(self.finalize_aging_ms), - } - - -class EngineStatus: - NEW = "NEW" - QUEUED = "QUEUED" - VALIDATED = "VALIDATED" - CPU_PREPARING = "CPU_PREPARING" - GPU_PREPARING = "GPU_PREPARING" - READY_FOR_PREFILL = "READY_FOR_PREFILL" - ACTIVE_DECODE = "ACTIVE_DECODE" - READY_FOR_FINALIZE = "READY_FOR_FINALIZE" - FINALIZING = "FINALIZING" - STREAMING = "STREAMING" - COMPLETED = "COMPLETED" - FAILED = "FAILED" - - -@dataclass -class EngineRequestState: - request_id: str - api_mode: str - backend: str - media_type: str - response_streaming: bool - submit_ts: float - deadline_ts: float | None = None - status: str = EngineStatus.NEW - updated_ts: float = 0.0 - error: str | None = None - finish_reason: str | None = None - meta: Dict[str, Any] = field(default_factory=dict) - profile: Dict[str, Any] = field(default_factory=dict) - lifecycle_timestamps: Dict[str, float] = field(default_factory=dict) - - def to_summary(self) -> Dict[str, Any]: - return { - "request_id": self.request_id, - "api_mode": self.api_mode, - "backend": self.backend, - "media_type": self.media_type, - "response_streaming": self.response_streaming, - "status": self.status, - "submit_ts": self.submit_ts, - "updated_ts": self.updated_ts, - "deadline_ts": self.deadline_ts, - "error": self.error, - "finish_reason": self.finish_reason, - "meta": dict(self.meta), - "profile": dict(self.profile), - "lifecycle_timestamps": dict(self.lifecycle_timestamps), - } - - -class EngineRequestRegistry: - def __init__(self, recent_limit: int) -> None: - self.lock = threading.Lock() - self.active_requests: Dict[str, EngineRequestState] = {} - self.recent_requests: Deque[EngineRequestState] = deque() - self.recent_limit = max(1, int(recent_limit)) - - def register( - self, - *, - request_id: str, - api_mode: str, - backend: str, - media_type: str, - response_streaming: bool, - deadline_ts: float | None = None, - meta: Optional[Dict[str, Any]] = None, - ) -> EngineRequestState: - now = time.perf_counter() - state = EngineRequestState( - request_id=request_id, - api_mode=api_mode, - backend=backend, - media_type=media_type, - response_streaming=bool(response_streaming), - submit_ts=now, - deadline_ts=deadline_ts, - updated_ts=now, - meta=dict(meta or {}), - lifecycle_timestamps={EngineStatus.NEW: now}, - ) - with self.lock: - self.active_requests[request_id] = state - return state - - def _move_to_recent_locked(self, state: EngineRequestState) -> None: - self.recent_requests.appendleft(state) - while len(self.recent_requests) > self.recent_limit: - self.recent_requests.pop() - - @staticmethod - def _apply_state_extra(state: EngineRequestState, extra: Optional[Dict[str, Any]]) -> None: - if not extra: - return - payload = dict(extra) - backend = payload.pop("backend", None) - if backend is not None: - state.backend = str(backend) - finish_reason = payload.pop("finish_reason", None) - if finish_reason is not None: - state.finish_reason = str(finish_reason) - error = payload.pop("error", None) - if error is not None: - state.error = str(error) - state.profile.update(payload) - - def update(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None: - now = time.perf_counter() - with self.lock: - state = self.active_requests.get(request_id) - if state is None: - return - state.status = str(status) - state.updated_ts = now - state.lifecycle_timestamps[str(status)] = now - self._apply_state_extra(state, extra) - - def merge_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - if not extra: - return - now = time.perf_counter() - with self.lock: - state = self.active_requests.get(request_id) - if state is None: - for recent_state in self.recent_requests: - if recent_state.request_id == request_id: - state = recent_state - break - if state is None: - return - state.updated_ts = now - self._apply_state_extra(state, extra) - - def complete(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - now = time.perf_counter() - with self.lock: - state = self.active_requests.pop(request_id, None) - if state is None: - return - state.status = EngineStatus.COMPLETED - state.updated_ts = now - state.lifecycle_timestamps[EngineStatus.COMPLETED] = now - self._apply_state_extra(state, extra) - self._move_to_recent_locked(state) - - def fail(self, request_id: str, error: str) -> None: - now = time.perf_counter() - with self.lock: - state = self.active_requests.pop(request_id, None) - if state is None: - return - state.status = EngineStatus.FAILED - state.updated_ts = now - state.error = str(error) - state.lifecycle_timestamps[EngineStatus.FAILED] = now - self._move_to_recent_locked(state) - - def snapshot(self) -> Dict[str, Any]: - with self.lock: - active = [state.to_summary() for state in self.active_requests.values()] - recent = [state.to_summary() for state in list(self.recent_requests)] - recent_limit = self.recent_limit - active.sort(key=lambda item: item["submit_ts"]) - return { - "active_count": len(active), - "recent_count": len(recent), - "recent_limit": recent_limit, - "active_requests": active, - "recent_requests": recent, - } - - def collect_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]: - requested = set(request_ids) - results: List[Dict[str, Any]] = [] - with self.lock: - for state in self.active_requests.values(): - if state.request_id in requested: - results.append(state.to_summary()) - existing_ids = {item["request_id"] for item in results} - for state in self.recent_requests: - if state.request_id in requested and state.request_id not in existing_ids: - results.append(state.to_summary()) - results.sort(key=lambda item: item["request_id"]) - return results - - def has_active(self, request_id: str) -> bool: - with self.lock: - return request_id in self.active_requests - - -@dataclass -class SchedulerPendingJob: - request_id: str - state: T2SRequestState - done_event: threading.Event - done_loop: asyncio.AbstractEventLoop | None - done_future: asyncio.Future | None - enqueue_time: float - speed_factor: float - sample_steps: int - media_type: str - admission_wait_ms: float = 0.0 - engine_policy_wait_ms: float = 0.0 - engine_dispatch_wait_ms: float = 0.0 - prepare_wall_ms: float = 0.0 - prepare_profile_total_ms: float = 0.0 - first_schedule_time: float | None = None - prefill_ms: float = 0.0 - merge_ms: float = 0.0 - decode_ms: float = 0.0 - finalize_wait_ms: float = 0.0 - synth_ms: float = 0.0 - pack_ms: float = 0.0 - decode_steps: int = 0 - result_ready_time: float | None = None - result: dict | None = None - sample_rate: int | None = None - audio_data: np.ndarray | None = None - error: str | None = None - engine_request_id: str | None = None - - -class SchedulerJobRegistry: - def __init__(self, lock: threading.Lock | threading.RLock | threading.Condition) -> None: - self._lock = lock - self._job_map: Dict[str, SchedulerPendingJob] = {} - self._total_submitted = 0 - self._total_finished = 0 - - def register(self, job: SchedulerPendingJob, *, keep_job: bool = True) -> None: - with self._lock: - if keep_job: - self._job_map[job.request_id] = job - self._total_submitted += 1 - - def get(self, request_id: str) -> SchedulerPendingJob | None: - with self._lock: - return self._job_map.get(request_id) - - def pop(self, request_id: str) -> SchedulerPendingJob | None: - with self._lock: - return self._job_map.pop(request_id, None) - - def remove(self, request_id: str) -> None: - with self._lock: - self._job_map.pop(request_id, None) - - def mark_finished(self) -> None: - with self._lock: - self._total_finished += 1 - - def mark_finished_and_remove(self, request_id: str) -> None: - with self._lock: - self._job_map.pop(request_id, None) - self._total_finished += 1 - - def is_empty(self) -> bool: - with self._lock: - return not self._job_map - - def submitted_count(self) -> int: - with self._lock: - return int(self._total_submitted) - - def finished_count(self) -> int: - with self._lock: - return int(self._total_finished) - - def snapshot(self, max_request_ids: int = 32) -> Dict[str, Any]: - with self._lock: - request_ids = list(self._job_map.keys()) - return { - "job_count": int(len(request_ids)), - "request_ids": request_ids[: max(0, int(max_request_ids))], - "total_submitted": int(self._total_submitted), - "total_finished": int(self._total_finished), - } - - -class EngineTaskQueueOwner: - def __init__(self, completion_key: str = "total_completed") -> None: - self.condition = threading.Condition() - self.queue: Deque[Any] = deque() - self.total_submitted = 0 - self.total_completed = 0 - self.peak_waiting = 0 - self.completion_key = str(completion_key) - - def enqueue(self, item: Any) -> None: - with self.condition: - self.queue.append(item) - self.total_submitted += 1 - self.peak_waiting = max(self.peak_waiting, len(self.queue)) - self.condition.notify_all() - - def enqueue_many(self, items: Sequence[Any]) -> None: - if not items: - return - with self.condition: - for item in items: - self.queue.append(item) - self.total_submitted += len(items) - self.peak_waiting = max(self.peak_waiting, len(self.queue)) - self.condition.notify_all() - - def pop_left(self) -> Any | None: - with self.condition: - if not self.queue: - return None - return self.queue.popleft() - - def mark_completed(self, count: int = 1, *, notify: bool = False) -> None: - if count <= 0: - return - with self.condition: - self.total_completed += int(count) - if notify: - self.condition.notify_all() - - def has_items(self) -> bool: - with self.condition: - return bool(self.queue) - - def waiting_count(self) -> int: - with self.condition: - return int(len(self.queue)) - - def snapshot(self, *, max_request_ids: int = 16, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: - with self.condition: - waiting_items = list(self.queue)[: max(0, int(max_request_ids))] - snapshot = { - "waiting_count": int(len(self.queue)), - "waiting_request_ids": [str(getattr(item, "request_id", "")) for item in waiting_items], - "peak_waiting": int(self.peak_waiting), - "total_submitted": int(self.total_submitted), - self.completion_key: int(self.total_completed), - } - if extra: - snapshot.update(dict(extra)) - return snapshot - - def peek_oldest_age_ms(self, timestamp_attr: str) -> float: - with self.condition: - if not self.queue: - return 0.0 - enqueue_time = float(getattr(self.queue[0], timestamp_attr)) - return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0) - - def is_drained(self) -> bool: - with self.condition: - return not self.queue and self.total_submitted == self.total_completed - - def take_finalize_batch( - self, - *, - finalize_mode: str, - batch_max_items: int, - batch_wait_s: float, - use_vocoder: bool, - ) -> List[SchedulerFinalizeTask]: - with self.condition: - if not self.queue: - return [] - selected_tasks = [self.queue.popleft()] - if finalize_mode == "sync" or use_vocoder: - return selected_tasks - if batch_max_items <= 1: - return selected_tasks - first_task = selected_tasks[0] - oldest_age_s = max(0.0, time.perf_counter() - first_task.enqueued_time) - if len(self.queue) + 1 < batch_max_items and oldest_age_s < batch_wait_s: - self.queue.appendleft(first_task) - return [] - while len(selected_tasks) < batch_max_items: - if not self.queue: - break - matched_index = None - for index, task in enumerate(self.queue): - if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: - matched_index = index - break - if matched_index is None: - break - selected_tasks.append(self.queue[matched_index]) - del self.queue[matched_index] - return selected_tasks - - -class EnginePolicyArbiterController: - def __init__( - self, - *, - policy_config: EnginePolicyConfig, - arbiter_config: EngineArbiterConfig, - snapshot_request_registry: Callable[[], Dict[str, Any]], - get_worker_state: Callable[[], Dict[str, Any]], - snapshot_prepare_state: Callable[[], Dict[str, Any]], - snapshot_finalize_state: Callable[[], Dict[str, Any]], - snapshot_dispatch_state: Callable[[], Dict[str, Any]], - snapshot_decode_runtime_state: Callable[[], Dict[str, Any]], - snapshot_job_registry: Callable[[], Dict[str, Any]], - peek_queue_age_ms: Callable[[str], float], - merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None], - ) -> None: - self.policy_config = policy_config - self.policy_poll_s = max(0.001, float(self.policy_config.poll_wait_ms) / 1000.0) - self.arbiter_config = arbiter_config - self.arbiter_poll_s = max(0.001, float(self.arbiter_config.poll_wait_ms) / 1000.0) - self.condition = threading.Condition() - self.state = EngineArbiterState( - decode_budget_remaining=int(self.arbiter_config.decode_burst), - last_observed_at=time.perf_counter(), - ) - self.snapshot_request_registry = snapshot_request_registry - self.get_worker_state = get_worker_state - self.snapshot_prepare_state = snapshot_prepare_state - self.snapshot_finalize_state = snapshot_finalize_state - self.snapshot_dispatch_state = snapshot_dispatch_state - self.snapshot_decode_runtime_state = snapshot_decode_runtime_state - self.snapshot_job_registry = snapshot_job_registry - self.peek_queue_age_ms = peek_queue_age_ms - self.merge_request_state_profile = merge_request_state_profile - - def snapshot_state(self) -> Dict[str, Any]: - with self.condition: - return { - "config": self.arbiter_config.to_dict(), - "total_ticks": int(self.state.total_ticks), - "total_idle_ticks": int(self.state.total_idle_ticks), - "total_prepare_dispatches": int(self.state.total_prepare_dispatches), - "total_decode_dispatches": int(self.state.total_decode_dispatches), - "total_decode_runtime_ticks": int(self.state.total_decode_runtime_ticks), - "total_finalize_dispatches": int(self.state.total_finalize_dispatches), - "decode_budget_remaining": int(self.state.decode_budget_remaining), - "last_stage": str(self.state.last_stage), - "last_reason": str(self.state.last_reason), - "last_policy_allowed": bool(self.state.last_policy_allowed), - "last_observed_at": float(self.state.last_observed_at), - } - - def notify(self) -> None: - with self.condition: - self.condition.notify_all() - - def wait(self) -> None: - with self.condition: - self.condition.wait(timeout=self.arbiter_poll_s) - - def mark_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: - with self.condition: - self.state.total_ticks += 1 - if stage == "idle": - self.state.total_idle_ticks += 1 - elif stage == "prepare": - self.state.total_prepare_dispatches += 1 - self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) - elif stage == "finalize": - self.state.total_finalize_dispatches += 1 - self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) - elif stage == "decode_dispatch": - self.state.total_decode_dispatches += 1 - elif stage == "decode_runtime": - self.state.total_decode_runtime_ticks += 1 - self.state.decode_budget_remaining = max(0, int(self.state.decode_budget_remaining) - 1) - self.state.last_stage = str(stage) - self.state.last_reason = str(reason) - self.state.last_policy_allowed = bool(policy_allowed) - self.state.last_observed_at = time.perf_counter() - - def build_stage_counters( - self, - request_registry: Dict[str, Any], - worker_state: Dict[str, Any], - ) -> Dict[str, Any]: - prepare_dispatcher_state = self.snapshot_prepare_state() - finalize_dispatcher_state = self.snapshot_finalize_state() - dispatcher_state = self.snapshot_dispatch_state() - active_requests = list(request_registry.get("active_requests", [])) - status_counts: Dict[str, int] = {} - for item in active_requests: - status = str(item.get("status", "UNKNOWN")) - status_counts[status] = status_counts.get(status, 0) + 1 - - worker_pending_jobs = int(worker_state.get("pending_jobs", 0)) - worker_decode_active_size = int(worker_state.get("running_requests", 0)) - worker_prepare_inflight = int(worker_state.get("prepare_inflight", 0)) - worker_finalize_pending = int(worker_state.get("finalize_pending", 0)) - worker_finalize_inflight = int(worker_state.get("finalize_inflight", 0)) - engine_decode_runtime_state = self.snapshot_decode_runtime_state() - engine_job_registry = self.snapshot_job_registry() - decode_runtime_pending_jobs = int(engine_decode_runtime_state.get("pending_jobs", 0)) - decode_runtime_active_size = int(engine_decode_runtime_state.get("active_request_count", 0)) - return { - "active_request_count": int(len(active_requests)), - "status_counts": status_counts, - "queued_request_count": int(status_counts.get(EngineStatus.QUEUED, 0)), - "cpu_prepare_request_count": int(status_counts.get(EngineStatus.CPU_PREPARING, 0)), - "gpu_prepare_request_count": int(status_counts.get(EngineStatus.GPU_PREPARING, 0)), - "ready_for_prefill_request_count": int(status_counts.get(EngineStatus.READY_FOR_PREFILL, 0)), - "active_decode_request_count": int(status_counts.get(EngineStatus.ACTIVE_DECODE, 0)), - "ready_for_finalize_request_count": int(status_counts.get(EngineStatus.READY_FOR_FINALIZE, 0)), - "finalizing_request_count": int(status_counts.get(EngineStatus.FINALIZING, 0)), - "streaming_request_count": int(status_counts.get(EngineStatus.STREAMING, 0)), - "worker_pending_jobs": worker_pending_jobs, - "worker_decode_active_size": worker_decode_active_size, - "worker_decode_control_enabled": bool(worker_state.get("engine_decode_control_enabled", False)), - "worker_decode_runtime_has_work": bool(worker_state.get("decode_runtime_has_work", False)), - "engine_decode_runtime_pending_jobs": decode_runtime_pending_jobs, - "engine_decode_runtime_active_request_count": decode_runtime_active_size, - "engine_decode_runtime_has_work": bool(engine_decode_runtime_state.get("has_work", False)), - "engine_job_registry_count": int(engine_job_registry.get("job_count", 0)), - "worker_prepare_inflight": worker_prepare_inflight, - "worker_finalize_pending": worker_finalize_pending, - "worker_finalize_inflight": worker_finalize_inflight, - "engine_gpu_prepare_queue_count": int(prepare_dispatcher_state.get("waiting_count", 0)), - "engine_finalize_queue_count": int(finalize_dispatcher_state.get("waiting_count", 0)), - "engine_decode_waiting_queue_count": int(dispatcher_state.get("waiting_count", 0)), - "decode_backlog": int( - decode_runtime_pending_jobs + decode_runtime_active_size - if bool(worker_state.get("engine_decode_control_enabled", False)) - else worker_pending_jobs + worker_decode_active_size - ), - } - - def build_policy_snapshot( - self, - request_registry: Dict[str, Any], - worker_state: Dict[str, Any], - ) -> Dict[str, Any]: - counters = self.build_stage_counters(request_registry, worker_state) - config = self.policy_config.to_dict() - blocked_reasons: List[Dict[str, Any]] = [] - finalize_pending_total = int(counters["worker_finalize_pending"]) + int(counters.get("engine_finalize_queue_count", 0)) - limit_checks = [ - ("decode_backlog", counters["decode_backlog"], int(config["decode_backlog_soft_max"])), - ("finalize_pending", finalize_pending_total, int(config["finalize_pending_soft_max"])), - ("prepare_inflight", counters["worker_prepare_inflight"], int(config["prepare_inflight_soft_max"])), - ("active_decode_requests", counters["active_decode_request_count"], int(config["active_decode_soft_max"])), - ("ready_for_prefill_requests", counters["ready_for_prefill_request_count"], int(config["ready_for_prefill_soft_max"])), - ("active_requests", counters["active_request_count"], int(config["active_request_soft_max"])), - ] - if bool(config["enabled"]): - for name, value, limit in limit_checks: - if limit > 0 and int(value) >= int(limit): - blocked_reasons.append({"metric": name, "value": int(value), "limit": int(limit)}) - return { - "enabled": bool(config["enabled"]), - "allowed": (not bool(config["enabled"])) or not blocked_reasons, - "blocked_reasons": blocked_reasons, - "config": config, - "metrics": { - "active_request_count": int(counters["active_request_count"]), - "queued_request_count": int(counters["queued_request_count"]), - "ready_for_prefill_request_count": int(counters["ready_for_prefill_request_count"]), - "active_decode_request_count": int(counters["active_decode_request_count"]), - "engine_gpu_prepare_queue_count": int(counters["engine_gpu_prepare_queue_count"]), - "engine_decode_waiting_queue_count": int(counters["engine_decode_waiting_queue_count"]), - "decode_backlog": int(counters["decode_backlog"]), - "prepare_inflight": int(counters["worker_prepare_inflight"]), - "finalize_pending": int(finalize_pending_total), - "engine_finalize_queue_count": int(counters.get("engine_finalize_queue_count", 0)), - "finalize_inflight": int(counters["worker_finalize_inflight"]), - }, - "observed_at": time.perf_counter(), - } - - async def wait_for_policy_admission( - self, - *, - request_id: str | None, - timeout_sec: float | None, - ) -> tuple[float, Dict[str, Any]]: - request_registry = self.snapshot_request_registry() - worker_state = self.get_worker_state() - snapshot = self.build_policy_snapshot(request_registry, worker_state) - if not self.policy_config.enabled: - return 0.0, snapshot - start = time.perf_counter() - deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) - while True: - request_registry = self.snapshot_request_registry() - worker_state = self.get_worker_state() - snapshot = self.build_policy_snapshot(request_registry, worker_state) - if snapshot["allowed"]: - wait_ms = max(0.0, (time.perf_counter() - start) * 1000.0) - if request_id not in [None, ""]: - self.merge_request_state_profile( - str(request_id), - { - "engine_policy_wait_ms": float(wait_ms), - "engine_policy_snapshot": snapshot, - }, - ) - return wait_ms, snapshot - now = time.perf_counter() - if deadline is not None and now >= deadline: - blocked_summary = ", ".join( - f"{item['metric']}={item['value']}/{item['limit']}" for item in snapshot.get("blocked_reasons", []) - ) - raise TimeoutError(f"engine policy admission timeout ({blocked_summary})") - await asyncio.sleep(self.policy_poll_s) - - def select_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: - request_registry = self.snapshot_request_registry() - worker_state = self.get_worker_state() - policy_snapshot = self.build_policy_snapshot(request_registry, worker_state) - prepare_waiting = int(self.snapshot_prepare_state().get("waiting_count", 0)) - finalize_waiting = int(self.snapshot_finalize_state().get("waiting_count", 0)) - decode_waiting = int(self.snapshot_dispatch_state().get("waiting_count", 0)) - decode_runtime_state = self.snapshot_decode_runtime_state() - worker_decode_has_work = bool(decode_runtime_state.get("has_work", False)) - worker_decode_control_enabled = bool(worker_state.get("engine_decode_control_enabled", False)) - worker_pending_jobs = int(decode_runtime_state.get("pending_jobs", 0)) - worker_running_requests = int(decode_runtime_state.get("active_request_count", 0)) - prepare_age_ms = float(self.peek_queue_age_ms("prepare")) - finalize_age_ms = float(self.peek_queue_age_ms("finalize")) - decode_runtime_pending_age_ms = float(self.peek_queue_age_ms("decode_runtime_pending")) - decode_budget_remaining = int(self.snapshot_state().get("decode_budget_remaining", 0)) - policy_allowed = bool(policy_snapshot.get("allowed", True)) - if ( - worker_decode_control_enabled - and worker_decode_has_work - and policy_allowed - and decode_budget_remaining > 0 - and (worker_running_requests > 0 or worker_pending_jobs > 0) - ): - return "decode_runtime", "worker_active_batch_progress", policy_snapshot, worker_state - if ( - worker_decode_control_enabled - and worker_pending_jobs > 0 - and policy_allowed - and decode_runtime_pending_age_ms >= float(self.arbiter_config.prepare_aging_ms) - ): - return "decode_runtime", "decode_runtime_pending_aging", policy_snapshot, worker_state - if ( - decode_waiting > 0 - and policy_allowed - and (not worker_decode_control_enabled or not worker_decode_has_work or worker_pending_jobs <= 0) - ): - return "decode_dispatch", "dispatch_prepared_state", policy_snapshot, worker_state - if finalize_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): - return "finalize", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state - if finalize_waiting > 0 and finalize_age_ms >= float(self.arbiter_config.finalize_aging_ms): - return "finalize", "finalize_aging", policy_snapshot, worker_state - if prepare_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): - return "prepare", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state - if prepare_waiting > 0 and prepare_age_ms >= float(self.arbiter_config.prepare_aging_ms): - return "prepare", "prepare_aging", policy_snapshot, worker_state - if worker_decode_control_enabled and worker_decode_has_work and policy_allowed: - return "decode_runtime", "worker_active_batch_progress_fallback", policy_snapshot, worker_state - if decode_waiting > 0 and policy_allowed: - return "decode_dispatch", "decode_priority_fallback", policy_snapshot, worker_state - if finalize_waiting > 0: - return "finalize", "finalize_fallback", policy_snapshot, worker_state - if prepare_waiting > 0: - return "prepare", "prepare_fallback", policy_snapshot, worker_state - return "idle", "no_pending_work", policy_snapshot, worker_state - - -class EngineDecodeRuntimeOwner: - def __init__( - self, - *, - get_decode_runtime_counters: Callable[[], Dict[str, int]], - get_micro_batch_wait_s: Callable[[], float], - ) -> None: - self.get_decode_runtime_counters = get_decode_runtime_counters - self.get_micro_batch_wait_s = get_micro_batch_wait_s - self.condition = threading.Condition() - self.pending_jobs: Deque[SchedulerPendingJob] = deque() - self.active_batch: T2SActiveBatch | None = None - self.state_lock = threading.Lock() - self.state = EngineDecodeRuntimeState(updated_at=time.perf_counter()) - - @staticmethod - def summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: - if active_batch is None: - return {} - decode_step_index_max = 0 - if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0: - decode_step_index_max = int(active_batch.step_indices.max().item()) - return { - "request_count": int(len(active_batch.request_ids)), - "request_ids": list(active_batch.request_ids), - "prefill_done": bool(active_batch.prefill_done), - "decode_step_index_max": int(decode_step_index_max), - } - - def snapshot_pending_queue_state(self) -> Dict[str, Any]: - with self.condition: - return { - "pending_jobs": int(len(self.pending_jobs)), - "pending_request_ids": [job.request_id for job in list(self.pending_jobs)[:32]], - } - - def enqueue_pending_job(self, job: SchedulerPendingJob) -> None: - with self.condition: - self.pending_jobs.append(job) - self.condition.notify_all() - self.refresh_state("engine_decode_pending_enqueue") - - def take_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - with self.condition: - if not self.pending_jobs: - return [] - if wait_for_batch: - oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time) - if (time.perf_counter() - oldest_enqueue_time) < self.get_micro_batch_wait_s(): - return [] - pending_jobs = list(self.pending_jobs) - self.pending_jobs.clear() - self.refresh_state("engine_decode_pending_dequeue") - return pending_jobs - - def pending_age_ms(self) -> float: - with self.condition: - if not self.pending_jobs: - return 0.0 - enqueue_time = float(self.pending_jobs[0].enqueue_time) - return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0) - - def has_pending_jobs(self) -> bool: - with self.condition: - return bool(self.pending_jobs) - - def get_active_batch(self) -> T2SActiveBatch | None: - return self.active_batch - - def set_active_batch(self, active_batch: T2SActiveBatch | None) -> None: - self.active_batch = active_batch - - def active_batch_summary(self) -> Dict[str, Any]: - return self.summarize_active_batch(self.active_batch) - - def refresh_state(self, last_event: str) -> None: - pending_state = self.snapshot_pending_queue_state() - active_batch_summary = self.active_batch_summary() - worker_decode_counters = self.get_decode_runtime_counters() - with self.state_lock: - self.state.pending_jobs = int(pending_state.get("pending_jobs", 0)) - self.state.pending_request_ids = list(pending_state.get("pending_request_ids", [])) - self.state.active_request_count = int(active_batch_summary.get("request_count", 0)) - self.state.active_request_ids = list(active_batch_summary.get("request_ids", []))[:32] - self.state.prefill_done = bool(active_batch_summary.get("prefill_done", False)) - self.state.decode_step_index_max = int(active_batch_summary.get("decode_step_index_max", 0)) - self.state.total_cycles = int(worker_decode_counters.get("total_cycles", 0)) - self.state.prefill_cycles = int(worker_decode_counters.get("prefill_cycles", 0)) - self.state.step_cycles = int(worker_decode_counters.get("step_cycles", 0)) - self.state.has_work = bool(pending_state.get("pending_jobs", 0) or active_batch_summary.get("request_count", 0)) - self.state.last_event = str(last_event) - self.state.updated_at = float(time.perf_counter()) - - def update_from_worker_snapshot(self, snapshot: Dict[str, Any]) -> None: - if not snapshot: - return - pending_state = self.snapshot_pending_queue_state() - with self.state_lock: - self.state.pending_jobs = int(pending_state.get("pending_jobs", 0)) - self.state.pending_request_ids = list(pending_state.get("pending_request_ids", [])) - self.state.active_request_count = int(snapshot.get("active_request_count", 0)) - self.state.active_request_ids = list(snapshot.get("active_request_ids", []))[:32] - self.state.prefill_done = bool(snapshot.get("prefill_done", False)) - self.state.decode_step_index_max = int(snapshot.get("decode_step_index_max", 0)) - self.state.total_cycles = int(snapshot.get("total_cycles", 0)) - self.state.prefill_cycles = int(snapshot.get("prefill_cycles", 0)) - self.state.step_cycles = int(snapshot.get("step_cycles", 0)) - self.state.has_work = bool( - pending_state.get("pending_jobs", 0) - or snapshot.get("active_request_count", 0) - or snapshot.get("has_work", False) - ) - self.state.last_event = str(snapshot.get("last_event", "unknown")) - self.state.updated_at = float(snapshot.get("updated_at", time.perf_counter())) - - def snapshot_state(self) -> Dict[str, Any]: - pending_state = self.snapshot_pending_queue_state() - active_batch_summary = self.active_batch_summary() - worker_decode_counters = self.get_decode_runtime_counters() - with self.state_lock: - return { - "pending_jobs": int(pending_state.get("pending_jobs", self.state.pending_jobs)), - "pending_request_ids": list(pending_state.get("pending_request_ids", self.state.pending_request_ids)), - "active_request_count": int(active_batch_summary.get("request_count", self.state.active_request_count)), - "active_request_ids": list(active_batch_summary.get("request_ids", self.state.active_request_ids)), - "prefill_done": bool(active_batch_summary.get("prefill_done", self.state.prefill_done)), - "decode_step_index_max": int( - active_batch_summary.get("decode_step_index_max", self.state.decode_step_index_max) - ), - "total_cycles": int(worker_decode_counters.get("total_cycles", 0)), - "prefill_cycles": int(worker_decode_counters.get("prefill_cycles", 0)), - "step_cycles": int(worker_decode_counters.get("step_cycles", 0)), - "has_work": bool( - pending_state.get("pending_jobs", 0) - or active_batch_summary.get("request_count", self.state.active_request_count) - or self.state.has_work - ), - "last_event": str(self.state.last_event), - "updated_at": float(self.state.updated_at), - } - -@dataclass -class SchedulerFinalizeTask: - request_id: str - item: T2SFinishedItem - enqueued_time: float - - -@dataclass -class EngineDispatchTask: - request_id: str - state: T2SRequestState - speed_factor: float - sample_steps: int - media_type: str - prepare_wall_ms: float - prepare_profile_total_ms: float - done_loop: asyncio.AbstractEventLoop | None - done_future: asyncio.Future | None - engine_request_id: str | None - timeout_sec: float | None - enqueue_time: float - worker_job: SchedulerPendingJob | None = None - engine_policy_wait_ms: float = 0.0 - engine_dispatch_wait_ms: float = 0.0 - engine_policy_snapshot: Dict[str, Any] | None = None - error: str | None = None - - -@dataclass -class EngineGpuPrepareTask: - request_id: str - cpu_stage: PreparedCpuStage - done_loop: asyncio.AbstractEventLoop | None - done_future: asyncio.Future | None - engine_request_id: str | None - enqueue_time: float - queue_wait_ms: float = 0.0 - error: str | None = None - - -@dataclass -class EngineFinalizeQueueState: - waiting_count: int - waiting_request_ids: List[str] - peak_waiting: int - total_submitted: int - total_completed: int - - -@dataclass -class EngineArbiterState: - total_ticks: int = 0 - total_idle_ticks: int = 0 - total_prepare_dispatches: int = 0 - total_decode_dispatches: int = 0 - total_decode_runtime_ticks: int = 0 - total_finalize_dispatches: int = 0 - decode_budget_remaining: int = 0 - last_stage: str = "idle" - last_reason: str = "init" - last_observed_at: float = 0.0 - last_policy_allowed: bool = True - - -@dataclass -class EngineDecodeRuntimeState: - pending_jobs: int = 0 - pending_request_ids: List[str] = field(default_factory=list) - active_request_count: int = 0 - active_request_ids: List[str] = field(default_factory=list) - prefill_done: bool = False - decode_step_index_max: int = 0 - total_cycles: int = 0 - prefill_cycles: int = 0 - step_cycles: int = 0 - has_work: bool = False - last_event: str = "init" - updated_at: float = 0.0 - - -@dataclass -class RuntimeStateCallbacks: - update: Callable[[str, str, Optional[Dict[str, Any]]], None] | None = None - complete: Callable[[str, Optional[Dict[str, Any]]], None] | None = None - fail: Callable[[str, str], None] | None = None - decode_runtime_update: Callable[[Dict[str, Any]], None] | None = None - - -class WorkerPrepareExecutor: - def __init__( - self, - tts: TTS, - on_state_change: Callable[[], None] | None = None, - ) -> None: - self.coordinator = PrepareCoordinator(tts) - self.on_state_change = on_state_change - - def _notify_state_change(self) -> None: - if self.on_state_change is None: - return - try: - self.on_state_change() - except Exception: - pass - - def snapshot(self) -> Dict[str, int]: - return dict(self.coordinator.snapshot()) - - def get_max_inflight(self) -> int: - return int(self.coordinator.snapshot().get("max_inflight", 0)) - - def is_idle(self) -> bool: - return int(self.coordinator.snapshot().get("inflight", 0)) <= 0 - - async def prepare_state_profiled_async( - self, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - ) -> tuple[T2SRequestState, float, float]: - try: - return await self.coordinator.prepare_state_profiled_async(spec, prepare_submit_at) - finally: - self._notify_state_change() - - async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: - results = await asyncio.gather( - *[self.prepare_state_profiled_async(spec, time.perf_counter()) for spec in specs] - ) - return [state for state, _, _ in results] - - async def prepare_cpu_stage_profiled_async( - self, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - ) -> PreparedCpuStage: - try: - return await self.coordinator.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) - finally: - self._notify_state_change() - - async def prepare_gpu_stage_profiled_async( - self, - cpu_stage: PreparedCpuStage, - ) -> tuple[T2SRequestState, float, float]: - try: - return await self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) - finally: - self._notify_state_change() - - -class WorkerFinalizeExecutor: - def __init__( - self, - tts: TTS, - on_state_change: Callable[[], None] | None = None, - external_submit: Callable[[List[SchedulerFinalizeTask]], None] | None = None, - ) -> None: - self.tts = tts - self.on_state_change = on_state_change - self.external_submit = external_submit - self.condition = threading.Condition() - self.pending_tasks: Deque[SchedulerFinalizeTask] = deque() - self.pending_peak = 0 - self.inflight = 0 - self.inflight_peak = 0 - self.worker_count = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1))) - self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower() - self.batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16))) - self.batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0) - - def _notify_state_change(self) -> None: - if self.on_state_change is None: - return - try: - self.on_state_change() - except Exception: - pass - - def get_worker_count(self) -> int: - return int(self.worker_count) - - def get_batch_policy(self) -> Dict[str, Any]: - return { - "finalize_mode": str(self.finalize_mode), - "finalize_batch_max_items": int(self.batch_max_items), - "finalize_batch_wait_s": float(self.batch_wait_s), - } - - def get_pending_count(self) -> int: - with self.condition: - return int(len(self.pending_tasks)) - - def snapshot(self) -> Dict[str, Any]: - with self.condition: - return { - "finalize_pending": int(len(self.pending_tasks)), - "finalize_pending_peak": int(self.pending_peak), - "finalize_inflight": int(self.inflight), - "finalize_inflight_peak": int(self.inflight_peak), - "finalize_workers": int(self.worker_count), - "finalize_mode": str(self.finalize_mode), - "finalize_batch_max_items": int(self.batch_max_items), - "finalize_batch_wait_ms": float(self.batch_wait_s * 1000.0), - } - - def is_idle(self) -> bool: - with self.condition: - return self.inflight <= 0 and not self.pending_tasks - - def enqueue_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None: - if not tasks: - return - if self.external_submit is not None: - self.external_submit(tasks) - self._notify_state_change() - return - with self.condition: - for task in tasks: - self.pending_tasks.append(task) - self.pending_peak = max(self.pending_peak, len(self.pending_tasks)) - self.condition.notify_all() - self._notify_state_change() - - def begin_execution(self, task_count: int) -> None: - if task_count <= 0: - return - with self.condition: - self.inflight += int(task_count) - self.inflight_peak = max(self.inflight_peak, self.inflight) - self.condition.notify_all() - self._notify_state_change() - - def end_execution(self, task_count: int) -> None: - with self.condition: - self.inflight = max(0, self.inflight - int(task_count)) - self.condition.notify_all() - self._notify_state_change() - - def take_task_batch_blocking(self) -> List[SchedulerFinalizeTask]: - with self.condition: - while not self.pending_tasks: - self.condition.wait() - selected_tasks = [self.pending_tasks.popleft()] - if self.finalize_mode == "sync" or self.tts.configs.use_vocoder: - self.inflight += len(selected_tasks) - self.inflight_peak = max(self.inflight_peak, self.inflight) - self._notify_state_change() - return selected_tasks - batch_deadline = time.perf_counter() + self.batch_wait_s - while len(selected_tasks) < self.batch_max_items: - if not self.pending_tasks: - remaining = batch_deadline - time.perf_counter() - if remaining <= 0: - break - self.condition.wait(timeout=remaining) - continue - first_task = selected_tasks[0] - matched_index = None - for index, task in enumerate(self.pending_tasks): - if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: - matched_index = index - break - if matched_index is not None: - selected_tasks.append(self.pending_tasks[matched_index]) - del self.pending_tasks[matched_index] - continue - remaining = batch_deadline - time.perf_counter() - if remaining <= 0: - break - self.condition.wait(timeout=remaining) - self.inflight += len(selected_tasks) - self.inflight_peak = max(self.inflight_peak, self.inflight) - self._notify_state_change() - return selected_tasks - - def _sync_device(self) -> None: - try: - device_str = str(self.tts.configs.device) - if device_str.startswith("cuda") and torch.cuda.is_available(): - torch.cuda.synchronize(self.tts.configs.device) - elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): - torch.mps.synchronize() - except Exception: - pass - - def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]: - audio_fragment = self.tts.synthesize_audio_request_local( - semantic_tokens=item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0), - phones=job.state.phones.detach().clone().unsqueeze(0), - prompt_semantic=job.state.prompt_semantic.detach().clone(), - prompt_phones=job.state.prompt_phones.detach().clone(), - refer_spec=( - job.state.refer_spec[0].detach().clone(), - None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), - ), - raw_audio=job.state.raw_audio.detach().clone(), - raw_sr=int(job.state.raw_sr), - speed=float(job.speed_factor), - sample_steps=int(job.sample_steps), - ) - output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] - return self.tts.audio_postprocess( - audio=[[audio_fragment]], - sr=int(output_sr), - batch_index_list=None, - speed_factor=float(job.speed_factor), - split_bucket=False, - fragment_interval=0.0, - super_sampling=False, - ) - - def _synthesize_finished_audio_batch( - self, - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], - ) -> List[tuple[int, np.ndarray]]: - semantic_tokens_list = [item.semantic_tokens.detach().clone() for _, item in jobs_and_items] - phones_list = [job.state.phones.detach().clone() for job, _ in jobs_and_items] - refer_specs = [] - speeds = [] - sample_steps_list = [] - for job, _ in jobs_and_items: - refer_specs.append( - ( - job.state.refer_spec[0].detach().clone(), - None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), - ) - ) - speeds.append(float(job.speed_factor)) - sample_steps_list.append(int(job.sample_steps)) - audio_fragments = self.tts.synthesize_audio_requests_local_batched( - semantic_tokens_list=semantic_tokens_list, - phones_list=phones_list, - refer_specs=refer_specs, - speeds=speeds, - sample_steps_list=sample_steps_list, - ) - output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] - results: List[tuple[int, np.ndarray]] = [] - for (job, _), audio_fragment in zip(jobs_and_items, audio_fragments): - results.append( - self.tts.audio_postprocess( - audio=[[audio_fragment]], - sr=int(output_sr), - batch_index_list=None, - speed_factor=float(job.speed_factor), - split_bucket=False, - fragment_interval=0.0, - super_sampling=False, - ) - ) - return results - - def synthesize_finalize_jobs( - self, - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], - ) -> tuple[float, List[tuple[int, np.ndarray]]]: - if not jobs_and_items: - return 0.0, [] - self._sync_device() - synth_start = time.perf_counter() - if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder: - job, item = jobs_and_items[0] - batch_results = [self._synthesize_finished_audio(job, item)] - else: - batch_results = self._synthesize_finished_audio_batch(jobs_and_items) - self._sync_device() - synth_ms = (time.perf_counter() - synth_start) * 1000.0 - return float(synth_ms), batch_results - - -class WorkerCompletionBridge: - def __init__(self, runtime_callbacks: RuntimeStateCallbacks | None = None) -> None: - self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() - - @staticmethod - def _resolve_done_future(job: SchedulerPendingJob) -> None: - future = job.done_future - if future is None or future.done(): - return - future.set_result(job) - - def notify_done_future(self, job: SchedulerPendingJob) -> None: - if job.done_loop is None or job.done_future is None: - return - try: - job.done_loop.call_soon_threadsafe(self._resolve_done_future, job) - except RuntimeError: - pass - - def runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: - if request_id is None or self.runtime_callbacks.complete is None: - return - self.runtime_callbacks.complete(request_id, extra) - - def runtime_fail(self, request_id: str | None, error: str) -> None: - if request_id is None or self.runtime_callbacks.fail is None: - return - self.runtime_callbacks.fail(request_id, error) - - @staticmethod - def build_completed_job_result( - job: SchedulerPendingJob, - item: T2SFinishedItem, - *, - sample_rate: int, - audio_data: np.ndarray, - finished_at: float | None = None, - ) -> Dict[str, Any]: - finished_at = float(time.perf_counter() if finished_at is None else finished_at) - queue_wait_ms = 0.0 - if job.first_schedule_time is not None: - queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) - worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) - worker_residual_ms = max( - 0.0, - worker_total_ms - - queue_wait_ms - - job.prefill_ms - - job.merge_ms - - job.decode_ms - - job.finalize_wait_ms - - job.synth_ms, - ) - worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms) - job.sample_rate = int(sample_rate) - job.audio_data = audio_data - job.result_ready_time = finished_at - prepare_profile = dict(job.state.prepare_profile) - result = { - "request_id": item.request_id, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - "finish_reason": item.finish_reason, - "decode_admission_wait_ms": float(job.admission_wait_ms), - "engine_policy_wait_ms": float(job.engine_policy_wait_ms), - "engine_dispatch_wait_ms": float(job.engine_dispatch_wait_ms), - "prepare_ms": job.prepare_wall_ms, - "prepare_wall_ms": job.prepare_wall_ms, - "prepare_profile_total_ms": job.prepare_profile_total_ms, - "prepare_profile": prepare_profile, - "queue_wait_ms": queue_wait_ms, - "prefill_ms": job.prefill_ms, - "merge_ms": job.merge_ms, - "decode_ms": job.decode_ms, - "finalize_wait_ms": job.finalize_wait_ms, - "synth_ms": job.synth_ms, - "worker_residual_ms": worker_residual_ms, - "worker_other_ms": worker_other_ms, - "worker_total_ms": worker_total_ms, - "decode_steps": int(job.decode_steps), - "sample_rate": int(sample_rate), - "media_type": job.media_type, - } - job.result = result - return result - - @staticmethod - def build_runtime_complete_payload( - job: SchedulerPendingJob, - item: T2SFinishedItem, - *, - sample_rate: int, - ) -> Dict[str, Any]: - return { - "finish_reason": item.finish_reason, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - "sample_rate": int(sample_rate), - "worker_profile": dict(job.result or {}), - } - - def complete_job( - self, - job: SchedulerPendingJob, - *, - runtime_request_id: str | None, - runtime_extra: Optional[Dict[str, Any]] = None, - remove_job: Callable[[], None] | None = None, - on_job_finished: Callable[[], None] | None = None, - notify_waiters: Callable[[], None] | None = None, - ) -> None: - job.done_event.set() - self.notify_done_future(job) - if remove_job is not None: - remove_job() - if on_job_finished is not None: - on_job_finished() - if notify_waiters is not None: - notify_waiters() - self.runtime_complete(runtime_request_id, runtime_extra) - - def fail_job( - self, - job: SchedulerPendingJob, - *, - error: str, - remove_job: Callable[[], None] | None = None, - on_job_finished: Callable[[], None] | None = None, - notify_waiters: Callable[[], None] | None = None, - ) -> None: - job.error = str(error) - job.done_event.set() - self.notify_done_future(job) - if remove_job is not None: - remove_job() - if on_job_finished is not None: - on_job_finished() - if notify_waiters is not None: - notify_waiters() - self.runtime_fail(job.engine_request_id, str(error)) - - def complete_finalize_task( - self, - *, - condition: threading.Condition, - job_registry: SchedulerJobRegistry, - job: SchedulerPendingJob, - item: T2SFinishedItem, - sample_rate: int, - audio_data: np.ndarray, - ) -> None: - runtime_extra: Optional[Dict[str, Any]] = None - with condition: - if job_registry.get(item.request_id) is not job: - return - self.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data) - runtime_extra = self.build_runtime_complete_payload(job, item, sample_rate=sample_rate) - self.complete_job( - job, - runtime_request_id=job.engine_request_id, - runtime_extra=runtime_extra, - on_job_finished=lambda: job_registry.mark_finished_and_remove(item.request_id), - notify_waiters=condition.notify_all, - ) - - def fail_jobs( - self, - *, - condition: threading.Condition, - job_registry: SchedulerJobRegistry, - request_ids: List[str], - error: str, - ) -> None: - if not request_ids: - return - with condition: - for request_id in request_ids: - job = job_registry.get(request_id) - if job is None: - continue - self.fail_job( - job, - error=error, - on_job_finished=lambda rid=request_id: job_registry.mark_finished_and_remove(rid), - ) - condition.notify_all() - - -class WorkerDecodeExecutor: - def __init__(self, tts: TTS, max_steps: int) -> None: - self.tts = tts - self.max_steps = int(max_steps) - - def _sync_device(self) -> None: - try: - device_str = str(self.tts.configs.device) - if device_str.startswith("cuda") and torch.cuda.is_available(): - torch.cuda.synchronize(self.tts.configs.device) - elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): - torch.mps.synchronize() - except Exception: - pass - - def execute_prefill_merge( - self, - *, - pending_jobs: List[SchedulerPendingJob], - active_batch: Optional[T2SActiveBatch], - mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], - add_prefill_time: Callable[[List[str], float], None] | None, - add_merge_time: Callable[[List[str], float], None] | None, - enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, - finalize_error: Callable[[List[str], str], None] | None, - ) -> Dict[str, Any]: - if not pending_jobs: - return { - "executed": False, - "active_batch": active_batch, - "pending_jobs": [], - "prefill_elapsed_s": 0.0, - "merge_elapsed_s": 0.0, - "finished_items": [], - "error": None, - "error_request_ids": [], - } - admitted_finished: List[T2SFinishedItem] = [] - prefill_elapsed_s = 0.0 - merge_elapsed_s = 0.0 - error: str | None = None - error_request_ids: List[str] = [] - try: - self._sync_device() - prefill_start = time.perf_counter() - mark_prefill_started(pending_jobs, prefill_start) - admitted_active_batch, admitted_finished = run_prefill_active_batch( - self.tts.t2s_model.model, - [job.state for job in pending_jobs], - max_steps=self.max_steps, - ) - self._sync_device() - prefill_elapsed_s = time.perf_counter() - prefill_start - if add_prefill_time is not None: - add_prefill_time([job.request_id for job in pending_jobs], prefill_elapsed_s) - if enqueue_finished is not None: - enqueue_finished(admitted_finished) - merge_start = time.perf_counter() - active_batch = merge_active_batches( - self.tts.t2s_model.model, - active_batch, - admitted_active_batch, - ) - merge_elapsed_s = time.perf_counter() - merge_start - if add_merge_time is not None: - add_merge_time( - [] if active_batch is None else list(active_batch.request_ids), - merge_elapsed_s, - ) - except Exception as exc: - error = str(exc) - error_request_ids = [job.request_id for job in pending_jobs] - if finalize_error is not None: - finalize_error(error_request_ids, error) - return { - "executed": True, - "active_batch": active_batch, - "pending_jobs": list(pending_jobs), - "prefill_elapsed_s": float(prefill_elapsed_s), - "merge_elapsed_s": float(merge_elapsed_s), - "finished_items": list(admitted_finished), - "error": error, - "error_request_ids": error_request_ids, - } - - def execute_decode_step( - self, - *, - active_batch: Optional[T2SActiveBatch], - add_decode_time: Callable[[List[str], float], None] | None, - enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, - finalize_error: Callable[[List[str], str], None] | None, - ) -> Dict[str, Any]: - if active_batch is None: - return { - "executed": False, - "active_batch": None, - "request_ids": [], - "decode_elapsed_s": 0.0, - "finished_items": [], - "error": None, - "error_request_ids": [], - } - active_request_ids: List[str] = [] - step_finished: List[T2SFinishedItem] = [] - decode_elapsed_s = 0.0 - error: str | None = None - error_request_ids: List[str] = [] - try: - active_request_ids = [state.request_id for state in active_batch.states] - self._sync_device() - decode_start = time.perf_counter() - active_batch, step_finished = decode_one_step( - self.tts.t2s_model.model, - active_batch, - max_steps=self.max_steps, - ) - self._sync_device() - decode_elapsed_s = time.perf_counter() - decode_start - if add_decode_time is not None: - add_decode_time(active_request_ids, decode_elapsed_s) - if enqueue_finished is not None: - enqueue_finished(step_finished) - except Exception as exc: - error = str(exc) - error_request_ids = list(active_request_ids) - if finalize_error is not None: - finalize_error(error_request_ids, error) - active_batch = None - return { - "executed": True, - "active_batch": active_batch, - "request_ids": active_request_ids, - "decode_elapsed_s": float(decode_elapsed_s), - "finished_items": list(step_finished), - "error": error, - "error_request_ids": error_request_ids, - } - - def execute_decode_cycle( - self, - *, - pending_jobs: List[SchedulerPendingJob], - active_batch: Optional[T2SActiveBatch], - mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], - add_prefill_time: Callable[[List[str], float], None] | None, - add_merge_time: Callable[[List[str], float], None] | None, - add_decode_time: Callable[[List[str], float], None] | None, - enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, - finalize_error: Callable[[List[str], str], None] | None, - ) -> Dict[str, Any]: - result = { - "executed": False, - "prefill_merge_executed": False, - "decode_step_executed": False, - "active_batch": active_batch, - "prefill_phase": {}, - "decode_phase": {}, - } - prefill_phase = self.execute_prefill_merge( - pending_jobs=list(pending_jobs), - active_batch=result["active_batch"], - mark_prefill_started=mark_prefill_started, - add_prefill_time=add_prefill_time, - add_merge_time=add_merge_time, - enqueue_finished=enqueue_finished, - finalize_error=finalize_error, - ) - prefill_executed = bool(prefill_phase.get("executed", False)) - result["prefill_phase"] = prefill_phase - result["active_batch"] = prefill_phase.get("active_batch") - if prefill_executed: - result["executed"] = True - result["prefill_merge_executed"] = True - decode_phase = self.execute_decode_step( - active_batch=result["active_batch"], - add_decode_time=add_decode_time, - enqueue_finished=enqueue_finished, - finalize_error=finalize_error, - ) - decode_executed = bool(decode_phase.get("executed", False)) - result["decode_phase"] = decode_phase - result["active_batch"] = decode_phase.get("active_batch") - if decode_executed: - result["executed"] = True - result["decode_step_executed"] = True - return result - - -class WorkerDecodeLegacyShell: - def __init__(self, condition: threading.Condition, micro_batch_wait_s: float) -> None: - self.condition = condition - self.micro_batch_wait_s = float(micro_batch_wait_s) - self.pending_jobs: List[SchedulerPendingJob] = [] - self.active_batch: T2SActiveBatch | None = None - - @staticmethod - def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any] | None: - if active_batch is None: - return None - return { - "request_count": int(len(active_batch.request_ids)), - "request_ids": list(active_batch.request_ids), - "prefill_done": bool(active_batch.prefill_done), - "decode_step_index_max": ( - int(active_batch.step_indices.max().item()) - if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0 - else 0 - ), - } - - def current_backlog_locked(self) -> int: - running_requests = 0 if self.active_batch is None else len(self.active_batch.request_ids) - return int(len(self.pending_jobs) + running_requests) - - def enqueue_pending_job_locked(self, job: SchedulerPendingJob) -> None: - self.pending_jobs.append(job) - - def snapshot_locked(self) -> Dict[str, Any]: - active_batch_summary = self._summarize_active_batch(self.active_batch) - executor_local_pending_jobs = int(len(self.pending_jobs)) - executor_local_running_requests = 0 if self.active_batch is None else int(len(self.active_batch.request_ids)) - executor_local_has_work = bool(self.pending_jobs or self.active_batch is not None) - return { - "executor_local_pending_jobs": executor_local_pending_jobs, - "executor_local_running_requests": executor_local_running_requests, - "executor_local_has_work": executor_local_has_work, - "executor_local_active_batch": active_batch_summary, - } - - def is_idle_locked(self) -> bool: - return self.active_batch is None and not self.pending_jobs - - def take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - with self.condition: - if not self.pending_jobs and self.active_batch is None: - self.condition.wait(timeout=self.micro_batch_wait_s) - elif wait_for_batch and self.pending_jobs: - self.condition.wait(timeout=self.micro_batch_wait_s) - if not self.pending_jobs: - return [] - pending = list(self.pending_jobs) - self.pending_jobs.clear() - return pending - - def take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - with self.condition: - if not self.pending_jobs: - return [] - if wait_for_batch: - oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time) - if (time.perf_counter() - oldest_enqueue_time) < self.micro_batch_wait_s: - return [] - pending = list(self.pending_jobs) - self.pending_jobs.clear() - return pending - - def has_decode_runtime_work(self) -> bool: - with self.condition: - return bool(self.pending_jobs or self.active_batch is not None) - - def build_runtime_summary_locked(self, *, total_cycles: int, prefill_cycles: int, step_cycles: int, last_event: str) -> Dict[str, Any]: - active_request_ids = [] if self.active_batch is None else list(self.active_batch.request_ids) - decode_step_index_max = 0 - prefill_done = False - if self.active_batch is not None: - prefill_done = bool(self.active_batch.prefill_done) - if self.active_batch.step_indices is not None and self.active_batch.step_indices.numel() > 0: - decode_step_index_max = int(self.active_batch.step_indices.max().item()) - return { - "pending_jobs": int(len(self.pending_jobs)), - "active_request_count": int(len(active_request_ids)), - "active_request_ids": active_request_ids[:32], - "prefill_done": bool(prefill_done), - "decode_step_index_max": int(decode_step_index_max), - "total_cycles": int(total_cycles), - "prefill_cycles": int(prefill_cycles), - "step_cycles": int(step_cycles), - "has_work": bool(self.pending_jobs or self.active_batch is not None), - "last_event": str(last_event), - "updated_at": float(time.perf_counter()), - } - - def run_prefill_merge_once_nonblocking( - self, - *, - external_pending_jobs: Optional[List[SchedulerPendingJob]], - external_active_batch: Optional[T2SActiveBatch], - execute_prefill_merge: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], - ) -> Dict[str, Any]: - pending_jobs = ( - list(external_pending_jobs) - if external_pending_jobs is not None - else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) - ) - active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch - result = execute_prefill_merge(pending_jobs, active_batch) - if external_pending_jobs is None: - with self.condition: - self.active_batch = result.get("active_batch") - self.condition.notify_all() - return result - - def run_decode_step_once_nonblocking( - self, - *, - external_active_batch: Optional[T2SActiveBatch], - execute_decode_step: Callable[[Optional[T2SActiveBatch]], Dict[str, Any]], - ) -> Dict[str, Any]: - active_batch = self.active_batch if external_active_batch is None else external_active_batch - result = execute_decode_step(active_batch) - if external_active_batch is None: - with self.condition: - self.active_batch = result.get("active_batch") - self.condition.notify_all() - return result - - def run_decode_cycle_nonblocking( - self, - *, - external_pending_jobs: Optional[List[SchedulerPendingJob]], - external_active_batch: Optional[T2SActiveBatch], - execute_decode_cycle: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], - on_cycle_executed: Callable[[Dict[str, Any]], None] | None, - ) -> Dict[str, Any]: - pending_jobs = ( - list(external_pending_jobs) - if external_pending_jobs is not None - else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) - ) - active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch - result = execute_decode_cycle(pending_jobs, active_batch) - if external_pending_jobs is None: - with self.condition: - self.active_batch = result.get("active_batch") - self.condition.notify_all() - if result.get("executed") and on_cycle_executed is not None: - on_cycle_executed(result) - return result - - def run_loop( - self, - *, - run_decode_cycle_nonblocking: Callable[[], Dict[str, Any]], - ) -> None: - while True: - executed = run_decode_cycle_nonblocking() - if executed.get("executed"): - continue - wait_for_batch = self.active_batch is None - pending_jobs = self.take_pending_snapshot(wait_for_batch=wait_for_batch) - if pending_jobs: - with self.condition: - self.pending_jobs = pending_jobs + self.pending_jobs - self.condition.notify_all() - continue - time.sleep(self.micro_batch_wait_s) - - -class WorkerDecodeRuntimeTracker: - def __init__( - self, - runtime_callbacks: RuntimeStateCallbacks | None = None, - ) -> None: - self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() - self.total_cycles = 0 - self.prefill_cycles = 0 - self.step_cycles = 0 - - def get_counters(self) -> Dict[str, int]: - return { - "total_cycles": int(self.total_cycles), - "prefill_cycles": int(self.prefill_cycles), - "step_cycles": int(self.step_cycles), - } - - def record_cycle(self, result: Dict[str, Any]) -> None: - if not bool(result.get("executed")): - return - self.total_cycles += 1 - if bool(result.get("prefill_merge_executed")): - self.prefill_cycles += 1 - if bool(result.get("decode_step_executed")): - self.step_cycles += 1 - - def build_runtime_summary_locked( - self, - *, - legacy_shell: WorkerDecodeLegacyShell, - last_event: str, - ) -> Dict[str, Any]: - return legacy_shell.build_runtime_summary_locked( - total_cycles=int(self.total_cycles), - prefill_cycles=int(self.prefill_cycles), - step_cycles=int(self.step_cycles), - last_event=str(last_event), - ) - - def notify_runtime_update_locked( - self, - *, - legacy_shell: WorkerDecodeLegacyShell, - last_event: str, - ) -> None: - if self.runtime_callbacks.decode_runtime_update is None: - return - snapshot = self.build_runtime_summary_locked( - legacy_shell=legacy_shell, - last_event=last_event, - ) - self.runtime_callbacks.decode_runtime_update(snapshot) - - -class UnifiedSchedulerWorker: - def __init__( - self, - tts: TTS, - max_steps: int = 1500, - micro_batch_wait_ms: int = 5, - runtime_callbacks: RuntimeStateCallbacks | None = None, - external_finalize_submit: Callable[[List[SchedulerFinalizeTask]], None] | None = None, - ): - self.tts = tts - self.max_steps = int(max_steps) - self.micro_batch_wait_s = float(micro_batch_wait_ms) / 1000.0 - self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() - self.condition = threading.Condition() - self.completion_bridge = WorkerCompletionBridge(self.runtime_callbacks) - self.decode_executor = WorkerDecodeExecutor(tts, max_steps=max_steps) - self.decode_legacy_shell = WorkerDecodeLegacyShell(self.condition, self.micro_batch_wait_s) - self.decode_runtime_tracker = WorkerDecodeRuntimeTracker(self.runtime_callbacks) - self.prepare_executor = WorkerPrepareExecutor(tts, on_state_change=self._notify_worker_state_change) - self.finalize_executor = WorkerFinalizeExecutor( - tts, - on_state_change=self._notify_worker_state_change, - external_submit=external_finalize_submit, - ) - self.decode_backlog_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_DECODE_BACKLOG_MAX", "0"))) - self.finalize_pending_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_FINALIZE_PENDING_MAX", "0"))) - self.engine_decode_control_enabled = ( - str(os.environ.get("GPTSOVITS_ENGINE_DRIVE_DECODE", "0")).strip().lower() in {"1", "true", "yes", "on"} - ) - self.job_registry = SchedulerJobRegistry(self.condition) - self.worker_thread: threading.Thread | None = None - if not self.engine_decode_control_enabled: - self.worker_thread = threading.Thread(target=self._run_loop, name="unified-t2s-scheduler-worker", daemon=True) - self.worker_thread.start() - self.finalize_threads = [] - if external_finalize_submit is None: - self.finalize_threads = [ - threading.Thread( - target=self._run_finalize_loop, - name=f"unified-t2s-finalize-{worker_index}", - daemon=True, - ) - for worker_index in range(self.finalize_executor.get_worker_count()) - ] - for finalize_thread in self.finalize_threads: - finalize_thread.start() - - def _notify_worker_state_change(self) -> None: - with self.condition: - self.condition.notify_all() - - def _current_decode_backlog_locked(self) -> int: - return self.decode_legacy_shell.current_backlog_locked() - - def get_micro_batch_wait_s(self) -> float: - return float(self.micro_batch_wait_s) - - def is_engine_decode_control_enabled(self) -> bool: - return bool(self.engine_decode_control_enabled) - - def get_prepare_max_inflight(self) -> int: - return int(self.prepare_executor.get_max_inflight()) - - def get_capacity_limits(self) -> Dict[str, int]: - return { - "decode_backlog_max": int(self.decode_backlog_max), - "finalize_pending_max": int(self.finalize_pending_max), - } - - def get_finalize_batch_policy(self) -> Dict[str, Any]: - return dict(self.finalize_executor.get_batch_policy()) - - def get_decode_runtime_counters(self) -> Dict[str, int]: - with self.condition: - return self.decode_runtime_tracker.get_counters() - - def _can_accept_submit_locked(self) -> tuple[bool, Dict[str, int]]: - decode_backlog = self._current_decode_backlog_locked() - finalize_pending = int(self.finalize_executor.get_pending_count()) - prepare_inflight = int(self.prepare_executor.snapshot()["inflight"]) - blocked_decode = self.decode_backlog_max > 0 and decode_backlog >= self.decode_backlog_max - blocked_finalize = self.finalize_pending_max > 0 and finalize_pending >= self.finalize_pending_max - return ( - not blocked_decode and not blocked_finalize, - { - "decode_backlog": decode_backlog, - "finalize_pending": finalize_pending, - "prepare_inflight": prepare_inflight, - "decode_backlog_max": int(self.decode_backlog_max), - "finalize_pending_max": int(self.finalize_pending_max), - }, - ) - - def wait_for_submit_capacity_blocking(self, timeout_sec: float | None = None) -> tuple[float, Dict[str, int]]: - start = time.perf_counter() - deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) - last_snapshot: Dict[str, int] = {} - while True: - with self.condition: - allowed, snapshot = self._can_accept_submit_locked() - last_snapshot = snapshot - if allowed: - return max(0.0, (time.perf_counter() - start) * 1000.0), snapshot - if deadline is not None and time.perf_counter() >= deadline: - raise TimeoutError( - "scheduler submit admission timeout " - f"(decode_backlog={snapshot['decode_backlog']}, finalize_pending={snapshot['finalize_pending']})" - ) - self.condition.wait(timeout=self.micro_batch_wait_s) - - def _admission_snapshot_locked(self) -> Dict[str, int]: - _, snapshot = self._can_accept_submit_locked() - return snapshot - - async def submit_async( - self, - state: T2SRequestState, - speed_factor: float, - sample_steps: int, - media_type: str, - prepare_wall_ms: float, - prepare_profile_total_ms: float, - done_loop: asyncio.AbstractEventLoop | None = None, - done_future: asyncio.Future | None = None, - engine_request_id: str | None = None, - timeout_sec: float | None = None, - skip_capacity_wait: bool = False, - admission_wait_ms_override: float | None = None, - admission_snapshot_override: Dict[str, Any] | None = None, - engine_policy_wait_ms: float = 0.0, - engine_dispatch_wait_ms: float = 0.0, - enqueue_pending: bool = True, - ) -> SchedulerPendingJob: - return await asyncio.to_thread( - self.submit, - state, - speed_factor, - sample_steps, - media_type, - prepare_wall_ms, - prepare_profile_total_ms, - done_loop, - done_future, - engine_request_id, - timeout_sec, - skip_capacity_wait, - admission_wait_ms_override, - admission_snapshot_override, - engine_policy_wait_ms, - engine_dispatch_wait_ms, - enqueue_pending, - ) - - def snapshot(self) -> dict: - with self.condition: - prepare_state = self.prepare_executor.snapshot() - finalize_state = self.finalize_executor.snapshot() - shell_state = self.decode_legacy_shell.snapshot_locked() - decode_runtime_counters = self.decode_runtime_tracker.get_counters() - engine_owned_decode_state = bool(self.engine_decode_control_enabled) - active_batch_summary = shell_state.get("executor_local_active_batch") - executor_local_pending_jobs = int(shell_state.get("executor_local_pending_jobs", 0)) - executor_local_running_requests = int(shell_state.get("executor_local_running_requests", 0)) - executor_local_has_work = bool(shell_state.get("executor_local_has_work", False)) - return { - "pending_jobs": 0 if engine_owned_decode_state else executor_local_pending_jobs, - "running_requests": 0 if engine_owned_decode_state else executor_local_running_requests, - "engine_decode_control_enabled": bool(self.engine_decode_control_enabled), - "legacy_state_owner_mode": not engine_owned_decode_state, - "decode_state_owner": "engine" if engine_owned_decode_state else "worker", - "decode_runtime_has_work": False if engine_owned_decode_state else executor_local_has_work, - "executor_local_pending_jobs": executor_local_pending_jobs, - "executor_local_running_requests": executor_local_running_requests, - "executor_local_has_work": executor_local_has_work, - "decode_runtime_total_cycles": int(decode_runtime_counters.get("total_cycles", 0)), - "decode_runtime_prefill_cycles": int(decode_runtime_counters.get("prefill_cycles", 0)), - "decode_runtime_step_cycles": int(decode_runtime_counters.get("step_cycles", 0)), - "prepare_inflight": prepare_state["inflight"], - "prepare_peak_inflight": prepare_state["peak_inflight"], - "prepare_max_inflight": prepare_state.get("max_inflight", 0), - "prepare_state": dict(prepare_state), - **finalize_state, - "decode_backlog_max": self.decode_backlog_max, - "finalize_pending_max": self.finalize_pending_max, - "active_batch": {} if engine_owned_decode_state else active_batch_summary, - "executor_local_active_batch": active_batch_summary if engine_owned_decode_state else None, - "total_submitted": self.job_registry.submitted_count(), - "total_finished": self.job_registry.finished_count(), - "drained": self.is_drained(), - } - - def is_drained(self) -> bool: - with self.condition: - return ( - self.decode_legacy_shell.is_idle_locked() - and self.job_registry.is_empty() - and self.prepare_executor.is_idle() - and self.finalize_executor.is_idle() - ) - - def wait_until_idle(self, timeout_sec: float = 60.0, poll_interval_sec: float = 0.01) -> bool: - deadline = time.perf_counter() + max(0.0, timeout_sec) - while time.perf_counter() < deadline: - if self.is_drained(): - return True - time.sleep(poll_interval_sec) - return self.is_drained() - - def submit( - self, - state: T2SRequestState, - speed_factor: float, - sample_steps: int, - media_type: str, - prepare_wall_ms: float, - prepare_profile_total_ms: float, - done_loop: asyncio.AbstractEventLoop | None = None, - done_future: asyncio.Future | None = None, - engine_request_id: str | None = None, - timeout_sec: float | None = None, - skip_capacity_wait: bool = False, - admission_wait_ms_override: float | None = None, - admission_snapshot_override: Dict[str, Any] | None = None, - engine_policy_wait_ms: float = 0.0, - engine_dispatch_wait_ms: float = 0.0, - enqueue_pending: bool = True, - ) -> SchedulerPendingJob: - if skip_capacity_wait: - with self.condition: - admission_snapshot = ( - dict(admission_snapshot_override) - if admission_snapshot_override is not None - else dict(self._admission_snapshot_locked()) - ) - admission_wait_ms = 0.0 if admission_wait_ms_override is None else float(admission_wait_ms_override) - else: - admission_wait_ms, admission_snapshot = self.wait_for_submit_capacity_blocking(timeout_sec=timeout_sec) - job = SchedulerPendingJob( - request_id=state.request_id, - state=state, - done_event=threading.Event(), - done_loop=done_loop, - done_future=done_future, - enqueue_time=time.perf_counter(), - speed_factor=float(speed_factor), - sample_steps=int(sample_steps), - media_type=media_type, - admission_wait_ms=float(admission_wait_ms), - engine_policy_wait_ms=float(engine_policy_wait_ms), - engine_dispatch_wait_ms=float(engine_dispatch_wait_ms), - prepare_wall_ms=float(prepare_wall_ms), - prepare_profile_total_ms=float(prepare_profile_total_ms), - engine_request_id=engine_request_id or state.request_id, - ) - with self.condition: - self.job_registry.register(job, keep_job=not self.engine_decode_control_enabled) - if enqueue_pending: - self.decode_legacy_shell.enqueue_pending_job_locked(job) - self.condition.notify_all() - if enqueue_pending: - self._notify_decode_runtime_state("submit") - self._runtime_update( - job.engine_request_id, - EngineStatus.QUEUED, - { - "scheduler_request_id": job.request_id, - "decode_admission_wait_ms": float(admission_wait_ms), - "engine_policy_wait_ms": float(engine_policy_wait_ms), - "engine_dispatch_wait_ms": float(engine_dispatch_wait_ms), - "admission_snapshot": dict(admission_snapshot), - }, - ) - return job - - async def prepare_state_profiled_async( - self, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - ) -> tuple[T2SRequestState, float, float]: - return await self.prepare_executor.prepare_state_profiled_async(spec, prepare_submit_at) - - async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: - return await self.prepare_executor.prepare_states_batch_async(specs) - - async def prepare_cpu_stage_profiled_async( - self, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - ) -> PreparedCpuStage: - return await self.prepare_executor.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) - - async def prepare_gpu_stage_profiled_async( - self, - cpu_stage: PreparedCpuStage, - ) -> tuple[T2SRequestState, float, float]: - return await self.prepare_executor.prepare_gpu_stage_profiled_async(cpu_stage) - - def _mark_prefill_started(self, pending_jobs: List[SchedulerPendingJob], started_at: float) -> None: - with self.condition: - for job in pending_jobs: - job.first_schedule_time = float(started_at) - self._runtime_update( - job.engine_request_id, - EngineStatus.GPU_PREPARING, - {"scheduler_request_id": job.request_id, "prefill_started_at": float(started_at)}, - ) - - def _add_prefill_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - if not request_ids: - return - with self.condition: - for request_id in request_ids: - job = self.job_registry.get(request_id) - if job is not None: - job.prefill_ms += delta_ms - - def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - if not request_ids: - return - with self.condition: - for request_id in request_ids: - job = self.job_registry.get(request_id) - if job is not None: - job.merge_ms += delta_ms - - def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - if not request_ids: - return - activate_request_ids: List[str] = [] - with self.condition: - for request_id in request_ids: - job = self.job_registry.get(request_id) - if job is not None: - if job.decode_steps == 0: - activate_request_ids.append(job.engine_request_id) - job.decode_ms += delta_ms - job.decode_steps += 1 - for engine_request_id in activate_request_ids: - self._runtime_update(engine_request_id, EngineStatus.ACTIVE_DECODE, None) - - def _add_finalize_wait_ms(self, request_ids: List[str], delta_ms: float) -> None: - if not request_ids: - return - with self.condition: - for request_id in request_ids: - job = self.job_registry.get(request_id) - if job is not None: - job.finalize_wait_ms += float(delta_ms) - - def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None: - if not items: - return - enqueued_at = time.perf_counter() - tasks: List[SchedulerFinalizeTask] = [] - with self.condition: - for item in items: - job = self.job_registry.get(item.request_id) - if job is not None: - self._runtime_update( - job.engine_request_id, - EngineStatus.READY_FOR_FINALIZE, - { - "finish_reason": item.finish_reason, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - }, - ) - tasks.append(SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at)) - self.finalize_executor.enqueue_tasks(tasks) - - def begin_finalize_execution(self, task_count: int) -> None: - self.finalize_executor.begin_execution(task_count) - - def end_finalize_execution(self, task_count: int) -> None: - self.finalize_executor.end_execution(task_count) - - def record_external_job_done(self, request_id: str) -> None: - with self.condition: - self.job_registry.mark_finished_and_remove(request_id) - self.condition.notify_all() - - def synthesize_finalize_jobs( - self, - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], - ) -> tuple[float, List[tuple[int, np.ndarray]]]: - return self.finalize_executor.synthesize_finalize_jobs(jobs_and_items) - - def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None: - self.completion_bridge.complete_finalize_task( - condition=self.condition, - job_registry=self.job_registry, - job=job, - item=item, - sample_rate=sample_rate, - audio_data=audio_data, - ) - - def _finalize_error(self, request_ids: List[str], error: str) -> None: - self.completion_bridge.fail_jobs( - condition=self.condition, - job_registry=self.job_registry, - request_ids=request_ids, - error=error, - ) - - @staticmethod - def _resolve_done_future(job: SchedulerPendingJob) -> None: - future = job.done_future - if future is None or future.done(): - return - future.set_result(job) - - def _notify_done_future(self, job: SchedulerPendingJob) -> None: - self.completion_bridge.notify_done_future(job) - - def _runtime_update(self, request_id: str | None, status: str, extra: Optional[Dict[str, Any]] = None) -> None: - if request_id is None or self.runtime_callbacks.update is None: - return - self.runtime_callbacks.update(request_id, status, extra) - - def _runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: - self.completion_bridge.runtime_complete(request_id, extra) - - def _runtime_fail(self, request_id: str | None, error: str) -> None: - self.completion_bridge.runtime_fail(request_id, error) - - def _build_decode_runtime_summary_locked(self, last_event: str) -> Dict[str, Any]: - return self.decode_runtime_tracker.build_runtime_summary_locked( - legacy_shell=self.decode_legacy_shell, - last_event=str(last_event), - ) - - def _notify_decode_runtime_state(self, last_event: str) -> None: - with self.condition: - self.decode_runtime_tracker.notify_runtime_update_locked( - legacy_shell=self.decode_legacy_shell, - last_event=str(last_event), - ) - - def _record_decode_runtime_cycle(self, result: Dict[str, Any]) -> None: - with self.condition: - self.decode_runtime_tracker.record_cycle(result) - - def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - return self.decode_legacy_shell.take_pending_snapshot(wait_for_batch) - - def _take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - return self.decode_legacy_shell.take_pending_snapshot_nonblocking(wait_for_batch) - - def has_decode_runtime_work(self) -> bool: - return self.decode_legacy_shell.has_decode_runtime_work() - - def execute_prefill_merge( - self, - pending_jobs: List[SchedulerPendingJob], - active_batch: Optional[T2SActiveBatch], - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - return self.decode_executor.execute_prefill_merge( - pending_jobs=pending_jobs, - active_batch=active_batch, - mark_prefill_started=self._mark_prefill_started, - add_prefill_time=None if external_bookkeeping else self._add_prefill_time, - add_merge_time=None if external_bookkeeping else self._add_merge_time, - enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, - finalize_error=None if external_bookkeeping else self._finalize_error, - ) - - def execute_decode_step( - self, - active_batch: Optional[T2SActiveBatch], - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - return self.decode_executor.execute_decode_step( - active_batch=active_batch, - add_decode_time=None if external_bookkeeping else self._add_decode_time, - enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, - finalize_error=None if external_bookkeeping else self._finalize_error, - ) - - def execute_decode_cycle( - self, - pending_jobs: List[SchedulerPendingJob], - active_batch: Optional[T2SActiveBatch], - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - result = self.decode_executor.execute_decode_cycle( - pending_jobs=pending_jobs, - active_batch=active_batch, - mark_prefill_started=self._mark_prefill_started, - add_prefill_time=None if external_bookkeeping else self._add_prefill_time, - add_merge_time=None if external_bookkeeping else self._add_merge_time, - add_decode_time=None if external_bookkeeping else self._add_decode_time, - enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, - finalize_error=None if external_bookkeeping else self._finalize_error, - ) - self._record_decode_runtime_cycle(result) - return result - - def run_prefill_merge_once_nonblocking( - self, - external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, - external_active_batch: Optional[T2SActiveBatch] = None, - emit_runtime_state: bool = True, - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - result = self.decode_legacy_shell.run_prefill_merge_once_nonblocking( - external_pending_jobs=external_pending_jobs, - external_active_batch=external_active_batch, - execute_prefill_merge=lambda batch_jobs, batch_state: self.execute_prefill_merge( - pending_jobs=batch_jobs, - active_batch=batch_state, - external_bookkeeping=external_bookkeeping, - ), - ) - if emit_runtime_state: - self._notify_decode_runtime_state("prefill_merge") - return result - - def run_decode_step_once_nonblocking( - self, - external_active_batch: Optional[T2SActiveBatch] = None, - emit_runtime_state: bool = True, - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - result = self.decode_legacy_shell.run_decode_step_once_nonblocking( - external_active_batch=external_active_batch, - execute_decode_step=lambda batch_state: self.execute_decode_step( - active_batch=batch_state, - external_bookkeeping=external_bookkeeping, - ), - ) - if emit_runtime_state: - self._notify_decode_runtime_state("decode_step") - return result - - def run_decode_cycle_nonblocking( - self, - external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, - external_active_batch: Optional[T2SActiveBatch] = None, - emit_runtime_state: bool = True, - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - result = self.decode_legacy_shell.run_decode_cycle_nonblocking( - external_pending_jobs=external_pending_jobs, - external_active_batch=external_active_batch, - execute_decode_cycle=lambda batch_jobs, batch_state: self.execute_decode_cycle( - pending_jobs=batch_jobs, - active_batch=batch_state, - external_bookkeeping=external_bookkeeping, - ), - on_cycle_executed=None, - ) - if result.get("executed") and emit_runtime_state: - self._notify_decode_runtime_state("decode_cycle") - return result - - def execute_finalize_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None: - if not tasks: - return - try: - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] - with self.condition: - for task in tasks: - job = self.job_registry.get(task.request_id) - if job is None: - continue - jobs_and_items.append((job, task.item)) - if not jobs_and_items: - return - now = time.perf_counter() - for task in tasks: - self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0)) - for job, item in jobs_and_items: - self._runtime_update( - job.engine_request_id, - EngineStatus.FINALIZING, - { - "finish_reason": item.finish_reason, - "semantic_len": int(item.semantic_tokens.shape[0]), - }, - ) - synth_ms, batch_results = self.synthesize_finalize_jobs(jobs_and_items) - with self.condition: - for job, _ in jobs_and_items: - tracked_job = self.job_registry.get(job.request_id) - if tracked_job is not None: - tracked_job.synth_ms += synth_ms - for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): - self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data) - except Exception as exc: - self._finalize_error([task.request_id for task in tasks], str(exc)) - finally: - self.finalize_executor.end_execution(len(tasks)) - - def _run_finalize_loop(self) -> None: - while True: - tasks = self.finalize_executor.take_task_batch_blocking() - self.execute_finalize_tasks(tasks) - - def _run_loop(self) -> None: - self.decode_legacy_shell.run_loop( - run_decode_cycle_nonblocking=lambda: self.run_decode_cycle_nonblocking() - ) - - -def set_scheduler_seed(seed: int): - if seed in ["", None]: - return - seed = int(seed) - if seed < 0: - return - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int): - def handle_pack_ogg(): - with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: - audio_file.write(data) - - stack_size = 4096 * 4096 - try: - threading.stack_size(stack_size) - pack_ogg_thread = threading.Thread(target=handle_pack_ogg) - pack_ogg_thread.start() - pack_ogg_thread.join() - except (RuntimeError, ValueError): - handle_pack_ogg() - return io_buffer - - -def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int): - io_buffer.write(data.tobytes()) - return io_buffer - - -def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int): - io_buffer = BytesIO() - sf.write(io_buffer, data, rate, format="wav") - return io_buffer - - -def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int): - process = subprocess.Popen( - [ - "ffmpeg", - "-f", - "s16le", - "-ar", - str(rate), - "-ac", - "1", - "-i", - "pipe:0", - "-c:a", - "aac", - "-b:a", - "192k", - "-vn", - "-f", - "adts", - "pipe:1", - ], - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - out, _ = process.communicate(input=data.tobytes()) - io_buffer.write(out) - return io_buffer - - -def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str): - if media_type == "ogg": - io_buffer = pack_ogg(io_buffer, data, rate) - elif media_type == "aac": - io_buffer = pack_aac(io_buffer, data, rate) - elif media_type == "wav": - io_buffer = pack_wav(io_buffer, data, rate) - else: - io_buffer = pack_raw(io_buffer, data, rate) - io_buffer.seek(0) - return io_buffer - - -def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000): - wav_buf = BytesIO() - with wave.open(wav_buf, "wb") as vfout: - vfout.setnchannels(channels) - vfout.setsampwidth(sample_width) - vfout.setframerate(sample_rate) - vfout.writeframes(frame_input) - wav_buf.seek(0) - return wav_buf.read() - - -class UnifiedTTSEngine: +class UnifiedTTSEngine(EngineBridgeDelegates, EngineApiDelegates, EngineRuntimeDelegates): @staticmethod def _env_flag(name: str, default: bool) -> bool: value = os.environ.get(name) @@ -2783,2048 +42,4 @@ class UnifiedTTSEngine: self.tts = tts self.cut_method_names = set(cut_method_names) self.control_callbacks = control_callbacks or RuntimeControlCallbacks() - self.reference_registry = ReferenceRegistry() - self.model_registry = ModelRegistry( - t2s_weights_path=str(self.tts.configs.t2s_weights_path), - vits_weights_path=str(self.tts.configs.vits_weights_path), - ) - self.request_registry = EngineRequestRegistry( - recent_limit=max(1, int(os.environ.get("GPTSOVITS_ENGINE_RECENT_REQUEST_LIMIT", "64"))) - ) - self.engine_job_registry = SchedulerJobRegistry(threading.Lock()) - self.scheduler_worker = UnifiedSchedulerWorker( - tts, - max_steps=max_steps, - micro_batch_wait_ms=micro_batch_wait_ms, - runtime_callbacks=RuntimeStateCallbacks( - update=self._update_request_state, - complete=self._complete_request_state, - fail=self._fail_request_state, - decode_runtime_update=self._update_engine_decode_runtime_state, - ), - external_finalize_submit=self._enqueue_worker_finished_for_finalize, - ) - self.direct_tts_lock = threading.RLock() - self.management_lock = threading.RLock() - worker_capacity_limits = self.scheduler_worker.get_capacity_limits() - prepare_max_inflight = int(self.scheduler_worker.get_prepare_max_inflight()) - self.engine_policy_config = EnginePolicyConfig( - enabled=self._env_flag("GPTSOVITS_ENGINE_POLICY_ENABLE", True), - poll_wait_ms=max(1.0, self._env_float("GPTSOVITS_ENGINE_POLICY_POLL_WAIT_MS", float(micro_batch_wait_ms))), - decode_backlog_soft_max=max( - 0, - self._env_int( - "GPTSOVITS_ENGINE_POLICY_DECODE_BACKLOG_SOFT_MAX", - int(worker_capacity_limits["decode_backlog_max"]), - ), - ), - finalize_pending_soft_max=max( - 0, - self._env_int( - "GPTSOVITS_ENGINE_POLICY_FINALIZE_PENDING_SOFT_MAX", - int(worker_capacity_limits["finalize_pending_max"]), - ), - ), - prepare_inflight_soft_max=max( - 0, - self._env_int("GPTSOVITS_ENGINE_POLICY_PREPARE_INFLIGHT_SOFT_MAX", prepare_max_inflight), - ), - active_decode_soft_max=max(0, self._env_int("GPTSOVITS_ENGINE_POLICY_ACTIVE_DECODE_SOFT_MAX", 0)), - ready_for_prefill_soft_max=max(0, self._env_int("GPTSOVITS_ENGINE_POLICY_READY_FOR_PREFILL_SOFT_MAX", 0)), - active_request_soft_max=max(0, self._env_int("GPTSOVITS_ENGINE_POLICY_ACTIVE_REQUEST_SOFT_MAX", 0)), - ) - self.engine_arbiter_config = EngineArbiterConfig( - poll_wait_ms=max(1.0, self._env_float("GPTSOVITS_ENGINE_ARBITER_POLL_WAIT_MS", float(micro_batch_wait_ms))), - decode_burst=max(1, self._env_int("GPTSOVITS_ENGINE_ARBITER_DECODE_BURST", 4)), - prepare_aging_ms=max(0.0, self._env_float("GPTSOVITS_ENGINE_ARBITER_PREPARE_AGING_MS", 10.0)), - finalize_aging_ms=max(0.0, self._env_float("GPTSOVITS_ENGINE_ARBITER_FINALIZE_AGING_MS", 10.0)), - ) - self.engine_decode_runtime_owner = EngineDecodeRuntimeOwner( - get_decode_runtime_counters=self.scheduler_worker.get_decode_runtime_counters, - get_micro_batch_wait_s=self.scheduler_worker.get_micro_batch_wait_s, - ) - self.engine_prepare_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") - self.engine_finalize_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") - self.engine_dispatch_queue_owner = EngineTaskQueueOwner(completion_key="total_dispatched") - self.engine_dispatch_last_snapshot: Dict[str, Any] = {} - self.engine_policy_arbiter = EnginePolicyArbiterController( - policy_config=self.engine_policy_config, - arbiter_config=self.engine_arbiter_config, - snapshot_request_registry=self._snapshot_request_registry, - get_worker_state=self.get_scheduler_state, - snapshot_prepare_state=self._snapshot_engine_prepare_state, - snapshot_finalize_state=self._snapshot_engine_finalize_state, - snapshot_dispatch_state=self._snapshot_engine_dispatch_state, - snapshot_decode_runtime_state=self._snapshot_engine_decode_runtime_state, - snapshot_job_registry=self._snapshot_engine_job_registry, - peek_queue_age_ms=self._peek_queue_age_ms, - merge_request_state_profile=self._merge_request_state_profile, - ) - self.engine_arbiter_thread = threading.Thread( - target=self._run_engine_arbiter_loop, - name="unified-engine-arbiter", - daemon=True, - ) - self.engine_arbiter_thread.start() - - def _register_request_state( - self, - request_id: str, - api_mode: str, - backend: str, - media_type: str, - response_streaming: bool, - deadline_ts: float | None = None, - meta: Optional[Dict[str, Any]] = None, - ) -> EngineRequestState: - return self.request_registry.register( - request_id=request_id, - api_mode=api_mode, - backend=backend, - media_type=media_type, - response_streaming=response_streaming, - deadline_ts=deadline_ts, - meta=meta, - ) - - def _update_request_state( - self, - request_id: str, - status: str, - extra: Optional[Dict[str, Any]] = None, - ) -> None: - self.request_registry.update(request_id, status, extra) - - def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - self.request_registry.merge_profile(request_id, extra) - - def _snapshot_engine_prepare_state(self) -> Dict[str, Any]: - return self.engine_prepare_queue_owner.snapshot(max_request_ids=16) - - def _snapshot_engine_finalize_state(self) -> Dict[str, Any]: - return self.engine_finalize_queue_owner.snapshot(max_request_ids=16) - - def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]: - return self.engine_dispatch_queue_owner.snapshot( - max_request_ids=16, - extra={"last_policy_snapshot": dict(self.engine_dispatch_last_snapshot or {})}, - ) - - def _register_engine_job(self, job: SchedulerPendingJob) -> None: - self.engine_job_registry.register(job, keep_job=True) - - def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None: - return self.engine_job_registry.get(request_id) - - def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None: - return self.engine_job_registry.pop(request_id) - - def _snapshot_engine_job_registry(self) -> Dict[str, Any]: - return self.engine_job_registry.snapshot(max_request_ids=32) - - def _is_engine_drained(self) -> bool: - prepare_empty = self.engine_prepare_queue_owner.is_drained() - dispatch_empty = self.engine_dispatch_queue_owner.is_drained() - finalize_empty = self.engine_finalize_queue_owner.is_drained() - decode_pending_empty = not self.engine_decode_runtime_owner.has_pending_jobs() - job_empty = self.engine_job_registry.is_empty() - worker_state = self.scheduler_worker.snapshot() - return bool( - prepare_empty - and dispatch_empty - and finalize_empty - and decode_pending_empty - and job_empty - and self.engine_decode_runtime_owner.get_active_batch() is None - and int(worker_state.get("prepare_inflight", 0)) <= 0 - and int(worker_state.get("finalize_inflight", 0)) <= 0 - and int(worker_state.get("finalize_pending", 0)) <= 0 - ) - - def _record_engine_job_done(self, request_id: str) -> None: - self.engine_job_registry.mark_finished_and_remove(request_id) - self.scheduler_worker.record_external_job_done(request_id) - - def _complete_engine_job( - self, - job: SchedulerPendingJob, - item: T2SFinishedItem, - *, - sample_rate: int, - audio_data: np.ndarray, - ) -> None: - completion_bridge = self.scheduler_worker.completion_bridge - completion_bridge.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data) - completion_bridge.complete_job( - job, - runtime_request_id=job.engine_request_id, - runtime_extra=completion_bridge.build_runtime_complete_payload(job, item, sample_rate=sample_rate), - on_job_finished=lambda rid=item.request_id: self._record_engine_job_done(rid), - ) - - def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None: - if not request_ids: - return - completion_bridge = self.scheduler_worker.completion_bridge - for request_id in request_ids: - job = self._get_engine_job(request_id) - if job is None: - continue - completion_bridge.fail_job( - job, - error=error, - on_job_finished=lambda rid=request_id: self._record_engine_job_done(rid), - ) - - def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - for job in jobs: - job.prefill_ms += delta_ms - - def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - for request_id in request_ids: - job = self._get_engine_job(request_id) - if job is not None: - job.merge_ms += delta_ms - - def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - activate_request_ids: List[str] = [] - for request_id in request_ids: - job = self._get_engine_job(request_id) - if job is None: - continue - if job.decode_steps == 0: - activate_request_ids.append(job.engine_request_id) - job.decode_ms += delta_ms - job.decode_steps += 1 - for engine_request_id in activate_request_ids: - self._update_request_state(engine_request_id, EngineStatus.ACTIVE_DECODE, None) - - def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None: - if not items: - return - enqueued_at = time.perf_counter() - tasks = [SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at) for item in items] - self._enqueue_worker_finished_for_finalize(tasks) - - def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]: - return self.engine_decode_runtime_owner.snapshot_pending_queue_state() - - @staticmethod - def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: - return EngineDecodeRuntimeOwner.summarize_active_batch(active_batch) - - def _refresh_engine_decode_runtime_state(self, last_event: str) -> None: - self.engine_decode_runtime_owner.refresh_state(last_event) - - def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None: - if not snapshot: - return - if self.scheduler_worker.is_engine_decode_control_enabled(): - return - self.engine_decode_runtime_owner.update_from_worker_snapshot(snapshot) - - def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]: - return self.engine_decode_runtime_owner.snapshot_state() - - def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]: - return self.engine_policy_arbiter.snapshot_state() - - def _notify_engine_arbiter(self) -> None: - self.engine_policy_arbiter.notify() - - def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None: - self.engine_decode_runtime_owner.enqueue_pending_job(job) - self._notify_engine_arbiter() - - def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - return self.engine_decode_runtime_owner.take_pending_jobs_nonblocking(wait_for_batch) - - def _peek_queue_age_ms(self, queue_name: str) -> float: - if queue_name == "prepare": - return self.engine_prepare_queue_owner.peek_oldest_age_ms("enqueue_time") - elif queue_name == "finalize": - return self.engine_finalize_queue_owner.peek_oldest_age_ms("enqueued_time") - elif queue_name == "decode_runtime_pending": - return self.engine_decode_runtime_owner.pending_age_ms() - else: - return self.engine_dispatch_queue_owner.peek_oldest_age_ms("enqueue_time") - - def _engine_has_pending_work(self) -> bool: - if self.scheduler_worker.is_engine_decode_control_enabled(): - if self.engine_decode_runtime_owner.has_pending_jobs(): - return True - if self.scheduler_worker.is_engine_decode_control_enabled() and self._snapshot_engine_decode_runtime_state().get("active_request_count", 0) > 0: - return True - if self.engine_prepare_queue_owner.has_items(): - return True - if self.engine_finalize_queue_owner.has_items(): - return True - return self.engine_dispatch_queue_owner.has_items() - - @staticmethod - def _resolve_dispatch_error_future(future: asyncio.Future, error: Exception) -> None: - if future.done(): - return - future.set_exception(error) - - def _notify_dispatch_error(self, task: EngineDispatchTask, error: Exception) -> None: - if task.done_loop is None or task.done_future is None: - return - try: - task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) - except RuntimeError: - pass - - @staticmethod - def _resolve_prepare_future( - future: asyncio.Future, - payload: tuple[T2SRequestState, float, float], - ) -> None: - if future.done(): - return - future.set_result(payload) - - def _notify_prepare_error(self, task: EngineGpuPrepareTask, error: Exception) -> None: - if task.done_loop is None or task.done_future is None: - return - try: - task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) - except RuntimeError: - pass - - def _notify_prepare_result( - self, - task: EngineGpuPrepareTask, - payload: tuple[T2SRequestState, float, float], - ) -> None: - if task.done_loop is None or task.done_future is None: - return - try: - task.done_loop.call_soon_threadsafe(self._resolve_prepare_future, task.done_future, payload) - except RuntimeError: - pass - - async def _prepare_state_via_engine_gpu_queue( - self, - *, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - engine_request_id: str | None, - ) -> tuple[T2SRequestState, float, float]: - cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) - if engine_request_id not in [None, ""]: - self._update_request_state( - str(engine_request_id), - EngineStatus.GPU_PREPARING, - { - "prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms), - "prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms), - "text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms), - "text_cpu_run_ms": float(cpu_stage.target_cpu_profiled.run_ms), - }, - ) - loop = asyncio.get_running_loop() - done_future = loop.create_future() - task = EngineGpuPrepareTask( - request_id=spec.request_id, - cpu_stage=cpu_stage, - done_loop=loop, - done_future=done_future, - engine_request_id=engine_request_id or spec.request_id, - enqueue_time=time.perf_counter(), - ) - self.engine_prepare_queue_owner.enqueue(task) - self._notify_engine_arbiter() - state, prepare_exec_started_at, prepare_exec_finished_at = await done_future - return state, prepare_exec_started_at, prepare_exec_finished_at - - def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: - if not tasks: - return - for task in tasks: - job = self._get_engine_job(task.request_id) - if job is not None: - self._update_request_state( - job.engine_request_id, - EngineStatus.READY_FOR_FINALIZE, - { - "finish_reason": task.item.finish_reason, - "semantic_len": int(task.item.semantic_tokens.shape[0]), - "finish_idx": int(task.item.finish_idx), - }, - ) - self.engine_finalize_queue_owner.enqueue_many(tasks) - self._notify_engine_arbiter() - - def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: - finalize_policy = self.scheduler_worker.get_finalize_batch_policy() - return self.engine_finalize_queue_owner.take_finalize_batch( - finalize_mode=str(finalize_policy.get("finalize_mode", "async")), - batch_max_items=int(finalize_policy.get("finalize_batch_max_items", 1)), - batch_wait_s=float(finalize_policy.get("finalize_batch_wait_s", 0.0)), - use_vocoder=bool(self.tts.configs.use_vocoder), - ) - - async def _enqueue_prepared_state_for_dispatch( - self, - *, - state: T2SRequestState, - speed_factor: float, - sample_steps: int, - media_type: str, - prepare_wall_ms: float, - prepare_profile_total_ms: float, - done_loop: asyncio.AbstractEventLoop | None, - done_future: asyncio.Future | None, - engine_request_id: str | None, - timeout_sec: float | None, - ) -> EngineDispatchTask: - task = EngineDispatchTask( - request_id=state.request_id, - state=state, - speed_factor=float(speed_factor), - sample_steps=int(sample_steps), - media_type=media_type, - prepare_wall_ms=float(prepare_wall_ms), - prepare_profile_total_ms=float(prepare_profile_total_ms), - done_loop=done_loop, - done_future=done_future, - engine_request_id=engine_request_id or state.request_id, - timeout_sec=timeout_sec, - enqueue_time=time.perf_counter(), - ) - self.engine_dispatch_queue_owner.enqueue(task) - self._notify_engine_arbiter() - self._merge_request_state_profile( - task.engine_request_id or task.request_id, - { - "engine_dispatch_queue_depth_on_enqueue": int(self._snapshot_engine_dispatch_state()["waiting_count"]), - }, - ) - return task - - def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: - self.engine_policy_arbiter.mark_tick(stage=stage, reason=reason, policy_allowed=policy_allowed) - - def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: - stage, reason, policy_snapshot, worker_state = self.engine_policy_arbiter.select_stage() - self.engine_dispatch_last_snapshot = dict(policy_snapshot) - return stage, reason, policy_snapshot, worker_state - - def _run_engine_prepare_once(self) -> bool: - task = self.engine_prepare_queue_owner.pop_left() - if task is None: - return False - queue_wait_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0) - try: - state, prepare_exec_started_at, prepare_exec_finished_at = asyncio.run( - self.scheduler_worker.prepare_gpu_stage_profiled_async(task.cpu_stage) - ) - state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms) - if task.engine_request_id not in [None, ""]: - self._merge_request_state_profile( - str(task.engine_request_id), - {"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms)}, - ) - self.engine_prepare_queue_owner.mark_completed(1) - self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at)) - return True - except Exception as exc: - task.error = str(exc) - self._fail_request_state(task.engine_request_id or task.request_id, str(exc)) - self._notify_prepare_error(task, exc) - return True - - def _run_engine_finalize_once(self) -> bool: - tasks = self._take_engine_finalize_batch_nonblocking() - if not tasks: - return False - self.scheduler_worker.begin_finalize_execution(len(tasks)) - try: - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] - for task in tasks: - job = self._get_engine_job(task.request_id) - if job is None: - continue - jobs_and_items.append((job, task.item)) - if not jobs_and_items: - return False - now = time.perf_counter() - for task in tasks: - job = self._get_engine_job(task.request_id) - if job is not None: - job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0) - for job, item in jobs_and_items: - self._update_request_state( - job.engine_request_id, - EngineStatus.FINALIZING, - { - "finish_reason": item.finish_reason, - "semantic_len": int(item.semantic_tokens.shape[0]), - }, - ) - synth_ms, batch_results = self.scheduler_worker.synthesize_finalize_jobs(jobs_and_items) - for job, _ in jobs_and_items: - job.synth_ms += float(synth_ms) - for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): - self._complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) - except Exception as exc: - self._fail_engine_jobs([task.request_id for task in tasks], str(exc)) - finally: - self.scheduler_worker.end_finalize_execution(len(tasks)) - self.engine_finalize_queue_owner.mark_completed(len(tasks), notify=True) - return True - - def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: - if not bool(policy_snapshot.get("allowed", True)): - return False - dispatch_task = self.engine_dispatch_queue_owner.pop_left() - if dispatch_task is None: - return False - dispatched_at = time.perf_counter() - dispatch_wait_ms = max(0.0, (dispatched_at - dispatch_task.enqueue_time) * 1000.0) - dispatch_task.engine_policy_wait_ms = float(dispatch_wait_ms) - dispatch_task.engine_dispatch_wait_ms = float(dispatch_wait_ms) - dispatch_task.engine_policy_snapshot = dict(policy_snapshot) - try: - worker_job = self.scheduler_worker.submit( - state=dispatch_task.state, - speed_factor=dispatch_task.speed_factor, - sample_steps=dispatch_task.sample_steps, - media_type=dispatch_task.media_type, - prepare_wall_ms=dispatch_task.prepare_wall_ms, - prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms, - done_loop=dispatch_task.done_loop, - done_future=dispatch_task.done_future, - engine_request_id=dispatch_task.engine_request_id, - timeout_sec=dispatch_task.timeout_sec, - skip_capacity_wait=True, - admission_wait_ms_override=0.0, - admission_snapshot_override=dict(worker_state), - engine_policy_wait_ms=dispatch_task.engine_policy_wait_ms, - engine_dispatch_wait_ms=dispatch_task.engine_dispatch_wait_ms, - enqueue_pending=not self.scheduler_worker.is_engine_decode_control_enabled(), - ) - dispatch_task.worker_job = worker_job - self._register_engine_job(worker_job) - if self.scheduler_worker.is_engine_decode_control_enabled(): - self._enqueue_engine_decode_pending_job(worker_job) - self.engine_dispatch_queue_owner.mark_completed(1) - return True - except Exception as exc: - dispatch_task.error = str(exc) - self._fail_request_state(dispatch_task.engine_request_id or dispatch_task.request_id, str(exc)) - self._notify_dispatch_error(dispatch_task, exc) - return True - - def _run_engine_decode_runtime_once(self) -> bool: - if not self.scheduler_worker.is_engine_decode_control_enabled(): - return False - runtime_state = self._snapshot_engine_decode_runtime_state() - pending_jobs = self._take_engine_decode_pending_jobs_nonblocking( - wait_for_batch=int(runtime_state.get("active_request_count", 0)) <= 0 - ) - result = self.scheduler_worker.execute_decode_cycle( - pending_jobs=pending_jobs, - active_batch=self.engine_decode_runtime_owner.get_active_batch(), - external_bookkeeping=True, - ) - prefill_phase = dict(result.get("prefill_phase") or {}) - if prefill_phase.get("error"): - self._fail_engine_jobs(list(prefill_phase.get("error_request_ids") or []), str(prefill_phase.get("error"))) - else: - prefill_jobs = list(prefill_phase.get("pending_jobs") or []) - self._add_engine_prefill_time(prefill_jobs, float(prefill_phase.get("prefill_elapsed_s", 0.0))) - self._add_engine_merge_time( - [] if result.get("active_batch") is None else list(result["active_batch"].request_ids), - float(prefill_phase.get("merge_elapsed_s", 0.0)), - ) - self._enqueue_engine_finished_items(list(prefill_phase.get("finished_items") or [])) - decode_phase = dict(result.get("decode_phase") or {}) - if decode_phase.get("error"): - self._fail_engine_jobs(list(decode_phase.get("error_request_ids") or []), str(decode_phase.get("error"))) - else: - self._add_engine_decode_time( - list(decode_phase.get("request_ids") or []), - float(decode_phase.get("decode_elapsed_s", 0.0)), - ) - self._enqueue_engine_finished_items(list(decode_phase.get("finished_items") or [])) - self.engine_decode_runtime_owner.set_active_batch(result.get("active_batch")) - if result.get("executed", False): - self._refresh_engine_decode_runtime_state("engine_decode_cycle") - return bool(result.get("executed", False)) - - def _run_engine_arbiter_loop(self) -> None: - while True: - if not self._engine_has_pending_work(): - self._mark_arbiter_tick(stage="idle", reason="no_pending_work", policy_allowed=True) - self.engine_policy_arbiter.wait() - continue - stage, reason, policy_snapshot, worker_state = self._select_engine_stage() - policy_allowed = bool(policy_snapshot.get("allowed", True)) - executed = False - if stage == "prepare": - executed = self._run_engine_prepare_once() - elif stage == "finalize": - executed = self._run_engine_finalize_once() - elif stage == "decode_dispatch": - executed = self._run_engine_dispatch_once(policy_snapshot, worker_state) - elif stage == "decode_runtime": - executed = self._run_engine_decode_runtime_once() - if not executed: - self._mark_arbiter_tick(stage="idle", reason=f"{stage}_not_ready", policy_allowed=policy_allowed) - self.engine_policy_arbiter.wait() - continue - self._mark_arbiter_tick(stage=stage, reason=reason, policy_allowed=policy_allowed) - - def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - self.request_registry.complete(request_id, extra) - - def _fail_request_state(self, request_id: str, error: str) -> None: - self.request_registry.fail(request_id, error) - - def _snapshot_request_registry(self) -> Dict[str, Any]: - return self.request_registry.snapshot() - - @staticmethod - def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None: - if component is None or not hasattr(component, "snapshot"): - return None - try: - return dict(component.snapshot()) - except Exception: - return None - - def _build_stage_counters( - self, - request_registry: Dict[str, Any], - worker_state: Dict[str, Any], - ) -> Dict[str, Any]: - return self.engine_policy_arbiter.build_stage_counters(request_registry, worker_state) - - def _build_engine_policy_snapshot( - self, - request_registry: Dict[str, Any], - worker_state: Dict[str, Any], - ) -> Dict[str, Any]: - return self.engine_policy_arbiter.build_policy_snapshot(request_registry, worker_state) - - async def _wait_for_engine_policy_admission( - self, - *, - request_id: str | None, - timeout_sec: float | None, - ) -> tuple[float, Dict[str, Any]]: - return await self.engine_policy_arbiter.wait_for_policy_admission( - request_id=request_id, - timeout_sec=timeout_sec, - ) - - def _build_stage_summary( - self, - request_registry: Dict[str, Any], - worker_state: Dict[str, Any], - ) -> Dict[str, Any]: - counters = self._build_stage_counters(request_registry, worker_state) - bert_worker_state = self._safe_component_snapshot(getattr(self.tts, "prepare_bert_batch_worker", None)) - ref_semantic_worker_state = self._safe_component_snapshot(getattr(self.tts, "prepare_ref_semantic_batch_worker", None)) - text_preprocessor_state = self._safe_component_snapshot(getattr(self.tts, "text_preprocessor", None)) - - return { - **counters, - "engine_drained": bool(self._is_engine_drained()), - "admission_config": { - "decode_backlog_max": int(worker_state.get("decode_backlog_max", 0)), - "finalize_pending_max": int(worker_state.get("finalize_pending_max", 0)), - }, - "engine_policy": self._build_engine_policy_snapshot(request_registry, worker_state), - "engine_arbiter_state": self._snapshot_engine_arbiter_state(), - "engine_decode_runtime_state": self._snapshot_engine_decode_runtime_state(), - "engine_job_registry": self._snapshot_engine_job_registry(), - "engine_active_batch_state": self.engine_decode_runtime_owner.active_batch_summary(), - "engine_prepare_state": self._snapshot_engine_prepare_state(), - "engine_finalize_state": self._snapshot_engine_finalize_state(), - "engine_dispatcher_state": self._snapshot_engine_dispatch_state(), - "active_batch": dict(worker_state.get("active_batch") or {}), - "prepare_state": dict(worker_state.get("prepare_state") or {}), - "bert_batch_worker_state": bert_worker_state, - "ref_semantic_worker_state": ref_semantic_worker_state, - "text_preprocessor_state": text_preprocessor_state, - } - - def _collect_request_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]: - return self.request_registry.collect_summaries(request_ids) - - def _has_active_request(self, request_id: str) -> bool: - return self.request_registry.has_active(request_id) - - @staticmethod - def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]: - text = payload.get("text") - prompt_text = payload.get("prompt_text") - return { - "text_len": 0 if text is None else len(str(text)), - "prompt_text_len": 0 if prompt_text is None else len(str(prompt_text)), - "text_lang": payload.get("text_lang"), - "prompt_lang": payload.get("prompt_lang"), - "ref_audio_path": payload.get("ref_audio_path"), - } - - @staticmethod - def _sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float: - total = 0.0 - for item in items: - value = item.get(key, 0.0) - if isinstance(value, (int, float)): - total += float(value) - return total - - def _build_direct_segment_trace( - self, - segment_texts: Sequence[str], - prepare_profiles: Sequence[Dict[str, Any]], - worker_profiles: Sequence[Dict[str, Any]], - ) -> List[Dict[str, Any]]: - results: List[Dict[str, Any]] = [] - for index, segment_text in enumerate(segment_texts): - prepare_item = prepare_profiles[index] if index < len(prepare_profiles) else {} - worker_item = worker_profiles[index] if index < len(worker_profiles) else {} - prepare_profile = dict(prepare_item.get("prepare_profile", {})) - results.append( - { - "segment_index": index, - "request_id": prepare_item.get("request_id") or worker_item.get("request_id"), - "text_len": len(str(segment_text)), - "prepare_wall_ms": float(prepare_item.get("prepare_wall_ms", 0.0)), - "prepare_profile_total_ms": float(prepare_item.get("prepare_profile_total_ms", 0.0)), - "prepare_engine_gpu_queue_wait_ms": float( - dict(prepare_item.get("prepare_profile", {})).get("engine_gpu_prepare_queue_wait_ms", 0.0) - ), - "engine_policy_wait_ms": float(prepare_item.get("engine_policy_wait_ms", 0.0)), - "engine_dispatch_wait_ms": float(prepare_item.get("engine_dispatch_wait_ms", 0.0)), - "decode_admission_wait_ms": float(worker_item.get("decode_admission_wait_ms", 0.0)), - "queue_wait_ms": float(worker_item.get("queue_wait_ms", 0.0)), - "prefill_ms": float(worker_item.get("prefill_ms", 0.0)), - "merge_ms": float(worker_item.get("merge_ms", 0.0)), - "decode_ms": float(worker_item.get("decode_ms", 0.0)), - "finalize_wait_ms": float(worker_item.get("finalize_wait_ms", 0.0)), - "synth_ms": float(worker_item.get("synth_ms", 0.0)), - "worker_total_ms": float(worker_item.get("worker_total_ms", 0.0)), - "decode_steps": int(worker_item.get("decode_steps", 0)), - "semantic_len": int(worker_item.get("semantic_len", 0)), - "finish_reason": worker_item.get("finish_reason"), - "norm_text": prepare_profile.get("norm_text"), - } - ) - return results - - def _build_direct_scheduler_profile( - self, - *, - backend: str, - request_start: float, - response_ready_at: float, - audio_bytes: int, - sample_rate: int, - segment_texts: Sequence[str], - prepare_profiles: Sequence[Dict[str, Any]], - worker_profiles: Sequence[Dict[str, Any]], - pack_ms: float, - response_overhead_ms: float, - ) -> Dict[str, Any]: - segment_trace = self._build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles) - prepare_profile_dicts = [dict(item.get("prepare_profile", {})) for item in prepare_profiles] - request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) - prepare_wall_ms = self._sum_profile_field(prepare_profiles, "prepare_wall_ms") - prepare_profile_total_ms = self._sum_profile_field(prepare_profiles, "prepare_profile_total_ms") - engine_policy_wait_ms = self._sum_profile_field(prepare_profiles, "engine_policy_wait_ms") - engine_dispatch_wait_ms = self._sum_profile_field(prepare_profiles, "engine_dispatch_wait_ms") - decode_admission_wait_ms = self._sum_profile_field(worker_profiles, "decode_admission_wait_ms") - queue_wait_ms = self._sum_profile_field(worker_profiles, "queue_wait_ms") - prefill_ms = self._sum_profile_field(worker_profiles, "prefill_ms") - merge_ms = self._sum_profile_field(worker_profiles, "merge_ms") - decode_ms = self._sum_profile_field(worker_profiles, "decode_ms") - finalize_wait_ms = self._sum_profile_field(worker_profiles, "finalize_wait_ms") - synth_ms = self._sum_profile_field(worker_profiles, "synth_ms") - worker_total_ms = self._sum_profile_field(worker_profiles, "worker_total_ms") - decode_steps = sum(int(item.get("decode_steps", 0)) for item in worker_profiles) - semantic_len = sum(int(item.get("semantic_len", 0)) for item in worker_profiles) - request_other_ms = max( - 0.0, - request_total_ms - prepare_wall_ms - engine_policy_wait_ms - worker_total_ms - pack_ms - response_overhead_ms, - ) - return { - "backend": backend, - "backend_mode": backend, - "segment_count": len(segment_texts), - "sample_rate": int(sample_rate), - "audio_bytes": int(audio_bytes), - "request_total_ms": request_total_ms, - "prepare_ms": prepare_wall_ms, - "prepare_wall_ms": prepare_wall_ms, - "prepare_profile_total_ms": prepare_profile_total_ms, - "engine_policy_wait_ms": engine_policy_wait_ms, - "engine_dispatch_wait_ms": engine_dispatch_wait_ms, - "decode_admission_wait_ms": decode_admission_wait_ms, - "queue_wait_ms": queue_wait_ms, - "prefill_ms": prefill_ms, - "merge_ms": merge_ms, - "decode_ms": decode_ms, - "finalize_wait_ms": finalize_wait_ms, - "synth_ms": synth_ms, - "pack_ms": pack_ms, - "response_overhead_ms": response_overhead_ms, - "worker_total_ms": worker_total_ms, - "request_other_ms": request_other_ms, - "decode_steps": decode_steps, - "semantic_len": semantic_len, - "prepare_segments": list(prepare_profiles), - "worker_segments": list(worker_profiles), - "segment_trace": segment_trace, - "prepare_aggregate": self._aggregate_numeric_dicts(prepare_profile_dicts), - } - - def _build_legacy_direct_profile( - self, - *, - backend: str, - fallback_reason: str | None, - request_start: float, - finished_at: float, - sample_rate: int | None = None, - audio_bytes: int = 0, - pack_ms: float = 0.0, - chunk_count: int = 0, - stream_total_bytes: int = 0, - first_chunk_ms: float | None = None, - ) -> Dict[str, Any]: - request_total_ms = max(0.0, (finished_at - request_start) * 1000.0) - legacy_infer_ms = max(0.0, request_total_ms - pack_ms) - return { - "backend": backend, - "backend_mode": backend, - "fallback_reason": fallback_reason, - "request_total_ms": request_total_ms, - "prepare_ms": 0.0, - "queue_wait_ms": 0.0, - "prefill_ms": 0.0, - "merge_ms": 0.0, - "decode_ms": 0.0, - "finalize_wait_ms": 0.0, - "synth_ms": 0.0, - "pack_ms": pack_ms, - "worker_total_ms": legacy_infer_ms, - "request_other_ms": 0.0, - "legacy_infer_ms": legacy_infer_ms, - "sample_rate": int(sample_rate) if sample_rate is not None else None, - "audio_bytes": int(audio_bytes), - "chunk_count": int(chunk_count), - "stream_total_bytes": int(stream_total_bytes), - "first_chunk_ms": None if first_chunk_ms is None else float(first_chunk_ms), - } - - def _build_scheduler_submit_profile( - self, - *, - backend: str, - request_start: float, - response_ready_at: float, - audio_bytes: int, - sample_rate: int, - prepare_spec_build_ms: float, - prepare_wall_ms: float, - prepare_executor_queue_ms: float, - prepare_executor_run_ms: float, - prepare_profile_total_ms: float, - prepare_profile_wall_ms: float, - prepare_other_ms: float, - engine_policy_wait_ms: float, - api_after_prepare_ms: float, - api_wait_result_ms: float, - pack_ms: float, - response_overhead_ms: float, - worker_profile: Dict[str, Any], - ) -> Dict[str, Any]: - worker_total_ms = float(worker_profile.get("worker_total_ms", 0.0)) - request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) - request_other_ms = max( - 0.0, - request_total_ms - - prepare_wall_ms - - engine_policy_wait_ms - - api_after_prepare_ms - - worker_total_ms - - api_wait_result_ms - - pack_ms, - ) - result = { - "backend": backend, - "backend_mode": backend, - "audio_bytes": int(audio_bytes), - "sample_rate": int(sample_rate), - "prepare_spec_build_ms": prepare_spec_build_ms, - "prepare_ms": prepare_wall_ms, - "prepare_wall_ms": prepare_wall_ms, - "prepare_executor_queue_ms": prepare_executor_queue_ms, - "prepare_executor_run_ms": prepare_executor_run_ms, - "prepare_profile_total_ms": prepare_profile_total_ms, - "prepare_profile_wall_ms": prepare_profile_wall_ms, - "prepare_other_ms": prepare_other_ms, - "engine_policy_wait_ms": float(engine_policy_wait_ms), - "api_after_prepare_ms": api_after_prepare_ms, - "api_wait_result_ms": api_wait_result_ms, - "pack_ms": pack_ms, - "response_overhead_ms": response_overhead_ms, - "request_total_ms": request_total_ms, - "request_other_ms": request_other_ms, - } - result.update({key: value for key, value in worker_profile.items()}) - return result - - @staticmethod - def _format_ms_header(value: Any) -> str: - return f"{float(value):.3f}" - - def _build_scheduler_submit_headers( - self, - *, - request_id: str, - media_type: str, - sample_rate: int, - profile: Dict[str, Any], - ) -> Dict[str, str]: - prepare_profile = dict(profile.get("prepare_profile", {})) - headers = { - "X-Request-Id": request_id, - "X-Semantic-Len": str(int(profile.get("semantic_len", 0))), - "X-Finish-Reason": str(profile.get("finish_reason", "unknown")), - "X-Queue-Wait-Ms": self._format_ms_header(profile.get("queue_wait_ms", 0.0)), - "X-Decode-Admission-Wait-Ms": self._format_ms_header(profile.get("decode_admission_wait_ms", 0.0)), - "X-Engine-Policy-Wait-Ms": self._format_ms_header(profile.get("engine_policy_wait_ms", 0.0)), - "X-Engine-Dispatch-Wait-Ms": self._format_ms_header(profile.get("engine_dispatch_wait_ms", 0.0)), - "X-Prepare-Ms": self._format_ms_header(profile.get("prepare_wall_ms", 0.0)), - "X-Prepare-Wall-Ms": self._format_ms_header(profile.get("prepare_wall_ms", 0.0)), - "X-Prepare-Spec-Build-Ms": self._format_ms_header(profile.get("prepare_spec_build_ms", 0.0)), - "X-Prepare-Executor-Queue-Ms": self._format_ms_header(profile.get("prepare_executor_queue_ms", 0.0)), - "X-Prepare-Admission-Wait-Ms": self._format_ms_header(prepare_profile.get("prepare_admission_wait_ms", 0.0)), - "X-Prepare-Executor-Run-Ms": self._format_ms_header(profile.get("prepare_executor_run_ms", 0.0)), - "X-Prepare-Profile-Total-Ms": self._format_ms_header(profile.get("prepare_profile_total_ms", 0.0)), - "X-Prepare-Profile-Wall-Ms": self._format_ms_header(profile.get("prepare_profile_wall_ms", 0.0)), - "X-Prepare-Other-Ms": self._format_ms_header(profile.get("prepare_other_ms", 0.0)), - "X-Api-After-Prepare-Ms": self._format_ms_header(profile.get("api_after_prepare_ms", 0.0)), - "X-Prefill-Ms": self._format_ms_header(profile.get("prefill_ms", 0.0)), - "X-Merge-Ms": self._format_ms_header(profile.get("merge_ms", 0.0)), - "X-Decode-Ms": self._format_ms_header(profile.get("decode_ms", 0.0)), - "X-Finalize-Wait-Ms": self._format_ms_header(profile.get("finalize_wait_ms", 0.0)), - "X-Synth-Ms": self._format_ms_header(profile.get("synth_ms", 0.0)), - "X-Worker-Residual-Ms": self._format_ms_header(profile.get("worker_residual_ms", 0.0)), - "X-Worker-Other-Ms": self._format_ms_header(profile.get("worker_other_ms", 0.0)), - "X-Pack-Ms": self._format_ms_header(profile.get("pack_ms", 0.0)), - "X-Worker-Total-Ms": self._format_ms_header(profile.get("worker_total_ms", 0.0)), - "X-Api-Wait-Result-Ms": self._format_ms_header(profile.get("api_wait_result_ms", 0.0)), - "X-Decode-Steps": str(int(profile.get("decode_steps", 0))), - "X-Sample-Rate": str(int(sample_rate)), - "X-Response-Overhead-Ms": self._format_ms_header(profile.get("response_overhead_ms", 0.0)), - "X-Request-Other-Ms": self._format_ms_header(profile.get("request_other_ms", 0.0)), - "X-Request-Total-Ms": self._format_ms_header(profile.get("request_total_ms", 0.0)), - } - headers.update( - { - "X-Prepare-Prompt-Text-Ms": self._format_ms_header(prepare_profile.get("prompt_text_features_ms", 0.0)), - "X-Prepare-Target-Text-Ms": self._format_ms_header(prepare_profile.get("text_features_ms", 0.0)), - "X-Prepare-Prompt-Text-CPU-Preprocess-Ms": self._format_ms_header(prepare_profile.get("prompt_text_cpu_preprocess_ms", 0.0)), - "X-Prepare-Target-Text-CPU-Preprocess-Ms": self._format_ms_header(prepare_profile.get("text_cpu_preprocess_ms", 0.0)), - "X-Prepare-Prompt-Text-CPU-Queue-Ms": self._format_ms_header(prepare_profile.get("prompt_text_cpu_queue_ms", 0.0)), - "X-Prepare-Target-Text-CPU-Queue-Ms": self._format_ms_header(prepare_profile.get("text_cpu_queue_ms", 0.0)), - "X-Prepare-Prompt-Text-Feature-Queue-Ms": self._format_ms_header(prepare_profile.get("prompt_text_feature_queue_ms", 0.0)), - "X-Prepare-Target-Text-Feature-Queue-Ms": self._format_ms_header(prepare_profile.get("text_feature_queue_ms", 0.0)), - "X-Prepare-Prompt-Bert-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_wait_ms", 0.0)), - "X-Prepare-Target-Bert-Wait-Ms": self._format_ms_header(prepare_profile.get("text_bert_wait_ms", 0.0)), - "X-Prepare-Prompt-Bert-Admission-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_admission_wait_ms", 0.0)), - "X-Prepare-Target-Bert-Admission-Wait-Ms": self._format_ms_header(prepare_profile.get("text_bert_admission_wait_ms", 0.0)), - "X-Prepare-Prompt-Bert-Queue-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_queue_wait_ms", 0.0)), - "X-Prepare-Target-Bert-Queue-Wait-Ms": self._format_ms_header(prepare_profile.get("text_bert_queue_wait_ms", 0.0)), - "X-Prepare-Prompt-Bert-Batch-Collect-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_batch_collect_wait_ms", 0.0)), - "X-Prepare-Target-Bert-Batch-Collect-Wait-Ms": self._format_ms_header(prepare_profile.get("text_bert_batch_collect_wait_ms", 0.0)), - "X-Prepare-Prompt-Bert-Forward-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_forward_ms", 0.0)), - "X-Prepare-Target-Bert-Forward-Ms": self._format_ms_header(prepare_profile.get("text_bert_forward_ms", 0.0)), - "X-Prepare-Prompt-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_enqueue_peak", 0.0))), - "X-Prepare-Target-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_enqueue_peak", 0.0))), - "X-Prepare-Prompt-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_collect_peak", 0.0))), - "X-Prepare-Target-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_collect_peak", 0.0))), - "X-Prepare-Prompt-Bert-High-Pressure-Peak": str(int(prepare_profile.get("prompt_text_bert_high_pressure_mode_peak", 0.0))), - "X-Prepare-Target-Bert-High-Pressure-Peak": str(int(prepare_profile.get("text_bert_high_pressure_mode_peak", 0.0))), - "X-Prepare-Prompt-Bert-Batch-Window-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_batch_window_ms", 0.0)), - "X-Prepare-Target-Bert-Batch-Window-Ms": self._format_ms_header(prepare_profile.get("text_bert_batch_window_ms", 0.0)), - "X-Prepare-Text-Pair-Wall-Ms": self._format_ms_header(prepare_profile.get("text_feature_pair_ms", 0.0)), - "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))), - "X-Prepare-Engine-GPU-Queue-Wait-Ms": self._format_ms_header(prepare_profile.get("engine_gpu_prepare_queue_wait_ms", 0.0)), - "X-Prepare-Audio-Load-Ms": self._format_ms_header(prepare_profile.get("audio_load_ms", 0.0)), - "X-Prepare-Audio-Stage-Wait-Ms": self._format_ms_header(prepare_profile.get("audio_stage_wait_ms", 0.0)), - "X-Prepare-Prompt-Semantic-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_ms", 0.0)), - "X-Prepare-Prompt-Semantic-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_wait_ms", 0.0)), - "X-Prepare-Prompt-Semantic-CPU-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), - "X-Prepare-Prompt-Semantic-Forward-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_forward_ms", 0.0)), - "X-Prepare-Ref-Spec-Ms": self._format_ms_header(prepare_profile.get("ref_spec_ms", 0.0)), - "X-Prepare-Ref-Spec-Wait-Ms": self._format_ms_header(prepare_profile.get("ref_spec_wait_ms", 0.0)), - "X-Prepare-Ref-Bundle-Ms": self._format_ms_header(prepare_profile.get("ref_audio_bundle_ms", 0.0)), - "X-Prepare-Tensorize-Ms": self._format_ms_header(prepare_profile.get("tensorize_ms", 0.0)), - "X-Prepare-Inflight-On-Enter": str(int(prepare_profile.get("worker_prepare_inflight_on_enter", 0.0))), - "X-Prepare-Inflight-Peak": str(int(prepare_profile.get("worker_prepare_peak_inflight", 0.0))), - } - ) - return headers - - def _build_scheduler_debug_request_profile( - self, - *, - state: T2SRequestState, - item: T2SFinishedItem, - batch_request_count: int, - prepare_batch_wall_ms: float, - decode_batch_wall_ms: float, - batch_request_total_ms: float, - ) -> Dict[str, Any]: - prepare_profile = dict(state.prepare_profile) - prepare_wall_ms = float(prepare_profile.get("wall_total_ms", 0.0)) - return { - "backend": "scheduler_debug", - "backend_mode": "scheduler_debug", - "batch_request_count": int(batch_request_count), - "batch_prepare_wall_ms": float(prepare_batch_wall_ms), - "batch_decode_wall_ms": float(decode_batch_wall_ms), - "batch_request_total_ms": float(batch_request_total_ms), - "prepare_ms": prepare_wall_ms, - "prepare_wall_ms": prepare_wall_ms, - "prepare_profile_total_ms": float(prepare_profile.get("wall_total_ms", prepare_wall_ms)), - "prepare_profile": prepare_profile, - "decode_steps": int(item.finish_idx), - "finish_idx": int(item.finish_idx), - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_reason": item.finish_reason, - "norm_text": state.norm_text, - "norm_prompt_text": state.norm_prompt_text, - } - - @staticmethod - def _build_scheduler_debug_batch_profile( - *, - request_count: int, - max_steps: int, - prepare_batch_wall_ms: float, - decode_batch_wall_ms: float, - request_total_ms: float, - finished_items: Sequence[T2SFinishedItem], - ) -> Dict[str, Any]: - finish_reason_counts: Dict[str, int] = {} - total_semantic_len = 0 - for item in finished_items: - finish_reason_counts[item.finish_reason] = finish_reason_counts.get(item.finish_reason, 0) + 1 - total_semantic_len += int(item.semantic_tokens.shape[0]) - return { - "request_count": int(request_count), - "max_steps": int(max_steps), - "prepare_batch_wall_ms": float(prepare_batch_wall_ms), - "decode_batch_wall_ms": float(decode_batch_wall_ms), - "request_total_ms": float(request_total_ms), - "total_semantic_len": int(total_semantic_len), - "finish_reason_counts": finish_reason_counts, - } - - def _normalize_lang(self, value: str | None) -> str | None: - if value in [None, ""]: - return value - return str(value).lower() - - @staticmethod - def _aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]: - totals: Dict[str, float] = {} - for item in items: - for key, value in item.items(): - if isinstance(value, (int, float)): - totals[key] = totals.get(key, 0.0) + float(value) - return totals - - def _apply_default_reference(self, req: dict) -> dict: - normalized = dict(req) - default_ref = self.reference_registry.get_default() - if normalized.get("ref_audio_path") in [None, ""] and default_ref.ref_audio_path not in [None, ""]: - normalized["ref_audio_path"] = default_ref.ref_audio_path - if "text_lang" in normalized: - normalized["text_lang"] = self._normalize_lang(normalized.get("text_lang")) - if "prompt_lang" in normalized: - normalized["prompt_lang"] = self._normalize_lang(normalized.get("prompt_lang")) - return normalized - - def check_params(self, req: dict) -> Optional[str]: - text = req.get("text", "") - text_lang = req.get("text_lang", "") - ref_audio_path = req.get("ref_audio_path", "") - media_type = req.get("media_type", "wav") - prompt_lang = req.get("prompt_lang", "") - text_split_method = req.get("text_split_method", "cut5") - - if ref_audio_path in [None, ""]: - return "ref_audio_path is required" - if text in [None, ""]: - return "text is required" - if text_lang in [None, ""]: - return "text_lang is required" - if text_lang.lower() not in self.tts.configs.languages: - return f"text_lang: {text_lang} is not supported in version {self.tts.configs.version}" - if prompt_lang in [None, ""]: - return "prompt_lang is required" - if prompt_lang.lower() not in self.tts.configs.languages: - return f"prompt_lang: {prompt_lang} is not supported in version {self.tts.configs.version}" - if media_type not in ["wav", "raw", "ogg", "aac"]: - return f"media_type: {media_type} is not supported" - if text_split_method not in self.cut_method_names: - return f"text_split_method:{text_split_method} is not supported" - return None - - @staticmethod - def _base_request_defaults() -> Dict[str, Any]: - return { - "request_id": None, - "text": None, - "text_lang": None, - "ref_audio_path": None, - "aux_ref_audio_paths": None, - "prompt_text": "", - "prompt_lang": None, - "top_k": 15, - "top_p": 1.0, - "temperature": 1.0, - "text_split_method": "cut5", - "batch_size": 1, - "batch_threshold": 0.75, - "speed_factor": 1.0, - "split_bucket": False, - "fragment_interval": 0.3, - "seed": -1, - "media_type": "wav", - "streaming_mode": False, - "return_fragment": False, - "fixed_length_chunk": False, - "response_streaming": False, - "parallel_infer": False, - "repetition_penalty": 1.35, - "sample_steps": 32, - "super_sampling": False, - "overlap_length": 2, - "min_chunk_length": 16, - "early_stop_num": -1, - "ready_step": 0, - "timeout_sec": None, - } - - def _normalize_engine_request( - self, - payload: dict | NormalizedEngineRequest, - *, - request_id: str | None = None, - normalize_streaming: bool = False, - error_prefix: str = "request 参数非法: ", - ) -> NormalizedEngineRequest: - if isinstance(payload, NormalizedEngineRequest): - normalized_payload = payload.to_payload() - else: - normalized_payload = self._base_request_defaults() - normalized_payload.update(dict(payload)) - if request_id not in [None, ""]: - normalized_payload["request_id"] = str(request_id) - elif normalized_payload.get("request_id") in [None, ""]: - raise ValueError("request_id is required after normalization") - normalized_payload = self._apply_default_reference(normalized_payload) - if normalize_streaming: - normalized_payload = self._normalize_streaming_mode(normalized_payload) - error = self.check_params(normalized_payload) - if error is not None: - raise ValueError(f"{error_prefix}{error}") - timeout_sec = normalized_payload.get("timeout_sec") - if timeout_sec in [None, ""]: - parsed_timeout = None - else: - parsed_timeout = float(timeout_sec) - aux_ref_audio_paths = normalized_payload.get("aux_ref_audio_paths") - if aux_ref_audio_paths in [None, "", []]: - normalized_aux_ref_audio_paths = None - else: - normalized_aux_ref_audio_paths = [str(item) for item in aux_ref_audio_paths] - return NormalizedEngineRequest( - request_id=str(normalized_payload["request_id"]), - text=str(normalized_payload["text"]), - text_lang=str(normalized_payload["text_lang"]), - ref_audio_path=str(normalized_payload["ref_audio_path"]), - prompt_lang=str(normalized_payload["prompt_lang"]), - prompt_text="" if normalized_payload.get("prompt_text") is None else str(normalized_payload.get("prompt_text")), - aux_ref_audio_paths=normalized_aux_ref_audio_paths, - top_k=int(normalized_payload["top_k"]), - top_p=float(normalized_payload["top_p"]), - temperature=float(normalized_payload["temperature"]), - repetition_penalty=float(normalized_payload["repetition_penalty"]), - early_stop_num=int(normalized_payload.get("early_stop_num", -1)), - ready_step=int(normalized_payload.get("ready_step", 0)), - text_split_method=str(normalized_payload["text_split_method"]), - batch_size=int(normalized_payload["batch_size"]), - batch_threshold=float(normalized_payload["batch_threshold"]), - split_bucket=bool(normalized_payload["split_bucket"]), - speed_factor=float(normalized_payload["speed_factor"]), - fragment_interval=float(normalized_payload["fragment_interval"]), - seed=int(normalized_payload["seed"]), - media_type=str(normalized_payload["media_type"]), - streaming_mode=normalized_payload["streaming_mode"], - return_fragment=bool(normalized_payload.get("return_fragment", False)), - fixed_length_chunk=bool(normalized_payload.get("fixed_length_chunk", False)), - response_streaming=bool(normalized_payload.get("response_streaming", False)), - parallel_infer=bool(normalized_payload["parallel_infer"]), - sample_steps=int(normalized_payload["sample_steps"]), - super_sampling=bool(normalized_payload["super_sampling"]), - overlap_length=int(normalized_payload["overlap_length"]), - min_chunk_length=int(normalized_payload["min_chunk_length"]), - timeout_sec=parsed_timeout, - ) - - @staticmethod - def _normalize_streaming_mode(req: dict) -> dict: - normalized = dict(req) - streaming_mode = normalized.get("streaming_mode", False) - return_fragment = normalized.get("return_fragment", False) - if streaming_mode is False: - normalized["streaming_mode"] = False - normalized["return_fragment"] = False - normalized["fixed_length_chunk"] = False - elif streaming_mode == 0: - normalized["streaming_mode"] = False - normalized["return_fragment"] = False - normalized["fixed_length_chunk"] = False - elif streaming_mode == 1 or streaming_mode is True: - normalized["streaming_mode"] = False - normalized["return_fragment"] = True - normalized["fixed_length_chunk"] = False - elif streaming_mode == 2: - normalized["streaming_mode"] = True - normalized["return_fragment"] = False - normalized["fixed_length_chunk"] = False - elif streaming_mode == 3: - normalized["streaming_mode"] = True - normalized["return_fragment"] = False - normalized["fixed_length_chunk"] = True - else: - raise ValueError("the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)") - normalized["response_streaming"] = bool(normalized["streaming_mode"] or normalized["return_fragment"] or return_fragment) - return normalized - - @staticmethod - def _is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool: - return aux_ref_audio_paths not in [None, [], ()] - - def _select_direct_backend(self, normalized: NormalizedEngineRequest) -> Tuple[str, str | None]: - if normalized.response_streaming: - if normalized.return_fragment or normalized.fixed_length_chunk: - return "legacy_direct_fragment", "fragment_streaming_mode" - return "legacy_direct_streaming", "streaming_mode" - if self._is_aux_ref_enabled(normalized.aux_ref_audio_paths): - return "legacy_direct_aux_ref", "aux_ref_audio_paths" - if normalized.super_sampling: - return "legacy_direct_super_sampling", "super_sampling" - if normalized.prompt_text in [None, ""]: - return "legacy_direct_missing_prompt", "missing_prompt_text" - return "scheduler_v1_direct", None - - def _iter_legacy_direct_tts_bytes( - self, - normalized: NormalizedEngineRequest, - *, - backend: str, - fallback_reason: str | None, - ) -> Generator[bytes, None, None]: - payload = normalized.to_payload() - media_type = normalized.media_type - request_id = normalized.request_id - request_start = time.perf_counter() - chunk_count = 0 - stream_total_bytes = 0 - first_chunk_ms: float | None = None - self._update_request_state( - request_id, - EngineStatus.ACTIVE_DECODE, - {"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason}, - ) - try: - with self.direct_tts_lock: - tts_generator = self.tts.run(payload) - first_chunk = True - current_media_type = media_type - for sr, chunk in tts_generator: - if first_chunk: - first_chunk_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0) - self._update_request_state( - request_id, - EngineStatus.STREAMING, - { - "backend": backend, - "backend_mode": backend, - "fallback_reason": fallback_reason, - "sample_rate": int(sr), - }, - ) - if first_chunk and media_type == "wav": - header = wave_header_chunk(sample_rate=sr) - chunk_count += 1 - stream_total_bytes += len(header) - yield header - current_media_type = "raw" - first_chunk = False - elif first_chunk: - first_chunk = False - packed_chunk = pack_audio(BytesIO(), chunk, sr, current_media_type).getvalue() - chunk_count += 1 - stream_total_bytes += len(packed_chunk) - yield packed_chunk - except Exception as exc: - self._fail_request_state(request_id, str(exc)) - raise - self._complete_request_state( - request_id, - dict( - self._build_legacy_direct_profile( - backend=backend, - fallback_reason=fallback_reason, - request_start=request_start, - finished_at=time.perf_counter(), - audio_bytes=stream_total_bytes, - chunk_count=chunk_count, - stream_total_bytes=stream_total_bytes, - first_chunk_ms=first_chunk_ms, - ), - streaming_completed=True, - ), - ) - - def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool: - if isinstance(req, NormalizedEngineRequest): - normalized = req - else: - normalized = self._normalize_engine_request( - req, - request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"), - normalize_streaming=True, - ) - backend, _ = self._select_direct_backend(normalized) - return backend == "scheduler_v1_direct" - - def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]: - payload = normalized.to_payload() if isinstance(normalized, NormalizedEngineRequest) else normalized - return self.tts.text_preprocessor.pre_seg_text( - str(payload["text"]), - str(payload["text_lang"]), - str(payload.get("text_split_method", "cut5")), - ) - - def _build_segment_request( - self, - normalized: NormalizedEngineRequest, - *, - request_id: str, - text: str, - ) -> NormalizedEngineRequest: - payload = normalized.to_payload() - payload["request_id"] = request_id - payload["text"] = text - payload["streaming_mode"] = False - payload["return_fragment"] = False - payload["fixed_length_chunk"] = False - payload["response_streaming"] = False - return self._normalize_engine_request(payload, error_prefix="segment request 参数非法: ") - - async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution: - request_start = time.perf_counter() - request_id = normalized.request_id - media_type = normalized.media_type - segment_texts = self._segment_direct_text(normalized) - if not segment_texts: - raise ValueError("text preprocessing returned no valid segments") - self._update_request_state( - request_id, - EngineStatus.CPU_PREPARING, - {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_texts)}, - ) - segment_specs: List[SchedulerRequestSpec] = [] - for segment_index, segment_text in enumerate(segment_texts): - segment_request = self._build_segment_request( - normalized, - request_id=f"{request_id}_seg_{segment_index:03d}", - text=segment_text, - ) - segment_specs.append(self.build_scheduler_submit_spec(segment_request)) - - prepared_items = await asyncio.gather( - *[ - self._prepare_state_via_engine_gpu_queue( - spec=spec, - prepare_submit_at=time.perf_counter(), - engine_request_id=None, - ) - for spec in segment_specs - ] - ) - prepare_profiles: List[Dict[str, Any]] = [] - loop = asyncio.get_running_loop() - done_futures: List[asyncio.Future] = [] - self._update_request_state( - request_id, - EngineStatus.READY_FOR_PREFILL, - {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_specs)}, - ) - for spec, (state, prepare_exec_started_at, prepare_exec_finished_at) in zip(segment_specs, prepared_items): - prepare_wall_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0) - prepare_profile_total_ms = float(state.prepare_profile.get("wall_total_ms", prepare_wall_ms)) - prepare_profiles.append( - { - "request_id": spec.request_id, - "prepare_wall_ms": prepare_wall_ms, - "prepare_profile_total_ms": prepare_profile_total_ms, - "prepare_profile": dict(state.prepare_profile), - } - ) - done_future = loop.create_future() - done_futures.append(done_future) - await self._enqueue_prepared_state_for_dispatch( - state=state, - speed_factor=float(normalized.speed_factor), - sample_steps=int(normalized.sample_steps), - media_type=media_type, - prepare_wall_ms=prepare_wall_ms, - prepare_profile_total_ms=prepare_profile_total_ms, - done_loop=loop, - done_future=done_future, - engine_request_id=None, - timeout_sec=normalized.timeout_sec, - ) - self._update_request_state( - request_id, - EngineStatus.ACTIVE_DECODE, - {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct"}, - ) - timeout_sec = float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0) - jobs: List[SchedulerPendingJob] = list(await asyncio.wait_for(asyncio.gather(*done_futures), timeout=timeout_sec)) - for profile_item, job in zip(prepare_profiles, jobs): - profile_item["engine_policy_wait_ms"] = float(job.engine_policy_wait_ms) - profile_item["engine_dispatch_wait_ms"] = float(job.engine_dispatch_wait_ms) - self._merge_request_state_profile( - request_id, - { - "engine_policy_wait_ms": sum(float(job.engine_policy_wait_ms) for job in jobs), - "engine_dispatch_wait_ms": sum(float(job.engine_dispatch_wait_ms) for job in jobs), - "prepare_aggregate": self._aggregate_numeric_dicts( - [item["prepare_profile"] for item in prepare_profiles] - ), - }, - ) - - sample_rate: int | None = None - audio_parts: List[np.ndarray] = [] - worker_profiles: List[Dict[str, Any]] = [] - fragment_interval = float(normalized.fragment_interval) - silence_chunk: Optional[np.ndarray] = None - for job in jobs: - if job.error is not None: - raise RuntimeError(job.error) - if job.audio_data is None or job.sample_rate is None or job.result is None: - raise RuntimeError(f"{job.request_id} finished without audio result") - if sample_rate is None: - sample_rate = int(job.sample_rate) - silence_samples = int(fragment_interval * float(sample_rate)) - if silence_samples > 0: - silence_chunk = np.zeros(silence_samples, dtype=np.int16) - elif int(job.sample_rate) != sample_rate: - raise RuntimeError("segment sample rate mismatch") - audio_parts.append(job.audio_data) - if silence_chunk is not None: - audio_parts.append(silence_chunk.copy()) - worker_profiles.append(dict(job.result)) - if sample_rate is None or not audio_parts: - raise RuntimeError("direct scheduler backend produced no audio") - self._update_request_state( - request_id, - EngineStatus.FINALIZING, - {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct"}, - ) - merged_audio = np.concatenate(audio_parts, axis=0) - pack_start = time.perf_counter() - audio_bytes = pack_audio(BytesIO(), merged_audio, sample_rate, media_type).getvalue() - pack_ms = max(0.0, (time.perf_counter() - pack_start) * 1000.0) - direct_profile = self._build_direct_scheduler_profile( - backend="scheduler_v1_direct", - request_start=request_start, - response_ready_at=time.perf_counter(), - audio_bytes=len(audio_bytes), - sample_rate=int(sample_rate), - segment_texts=segment_texts, - prepare_profiles=prepare_profiles, - worker_profiles=worker_profiles, - pack_ms=pack_ms, - response_overhead_ms=0.0, - ) - self._complete_request_state( - request_id, - dict(direct_profile, streaming_completed=False), - ) - return DirectTTSExecution( - media_type=media_type, - streaming=False, - audio_bytes=audio_bytes, - request_id=request_id, - ) - - def _run_legacy_direct_tts_blocking( - self, - normalized: NormalizedEngineRequest, - *, - backend: str, - fallback_reason: str | None, - ) -> DirectTTSExecution: - normalized_payload = normalized.to_payload() - request_id = normalized.request_id - media_type = normalized.media_type - request_start = time.perf_counter() - self._update_request_state( - request_id, - EngineStatus.ACTIVE_DECODE, - {"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason}, - ) - with self.direct_tts_lock: - tts_generator = self.tts.run(normalized_payload) - try: - sr, audio_data = next(tts_generator) - except Exception as exc: - self._fail_request_state(request_id, str(exc)) - raise - self._update_request_state( - request_id, - EngineStatus.FINALIZING, - {"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason}, - ) - pack_start = time.perf_counter() - packed_audio = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() - pack_ms = max(0.0, (time.perf_counter() - pack_start) * 1000.0) - self._complete_request_state( - request_id, - dict( - self._build_legacy_direct_profile( - backend=backend, - fallback_reason=fallback_reason, - request_start=request_start, - finished_at=time.perf_counter(), - sample_rate=int(sr), - audio_bytes=len(packed_audio), - pack_ms=pack_ms, - ), - streaming_completed=False, - ), - ) - return DirectTTSExecution( - media_type=media_type, - streaming=False, - audio_bytes=packed_audio, - request_id=request_id, - ) - - async def _run_direct_tts_via_legacy_backend( - self, - normalized: NormalizedEngineRequest, - *, - backend: str, - fallback_reason: str | None, - ) -> DirectTTSExecution: - if normalized.response_streaming: - return DirectTTSExecution( - media_type=normalized.media_type, - streaming=True, - audio_generator=self._iter_legacy_direct_tts_bytes( - normalized, - backend=backend, - fallback_reason=fallback_reason, - ), - request_id=normalized.request_id, - ) - return await asyncio.to_thread( - self._run_legacy_direct_tts_blocking, - normalized, - backend=backend, - fallback_reason=fallback_reason, - ) - - async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution: - normalized = self._normalize_engine_request( - req, - request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"), - normalize_streaming=True, - error_prefix="", - ) - request_id = normalized.request_id - media_type = normalized.media_type - backend, fallback_reason = self._select_direct_backend(normalized) - self._register_request_state( - request_id=request_id, - api_mode="tts", - backend=backend, - media_type=media_type, - response_streaming=bool(normalized.response_streaming), - deadline_ts=( - time.perf_counter() + float(normalized.timeout_sec) - if normalized.timeout_sec is not None - else None - ), - meta=self._build_request_meta(normalized.to_payload()), - ) - self._update_request_state( - request_id, - EngineStatus.VALIDATED, - { - "request_source": "direct_tts", - "selected_backend": backend, - "fallback_reason": fallback_reason, - }, - ) - if backend == "scheduler_v1_direct": - try: - return await self._run_direct_tts_via_scheduler(normalized) - except Exception as exc: - self._fail_request_state(request_id, str(exc)) - raise - return await self._run_direct_tts_via_legacy_backend( - normalized, - backend=backend, - fallback_reason=fallback_reason, - ) - - def run_direct_tts(self, req: dict) -> DirectTTSExecution: - normalized = self._normalize_engine_request( - req, - request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"), - normalize_streaming=True, - error_prefix="", - ) - request_id = normalized.request_id - media_type = normalized.media_type - backend, fallback_reason = self._select_direct_backend(normalized) - if not self._has_active_request(request_id): - self._register_request_state( - request_id=request_id, - api_mode="tts", - backend=backend, - media_type=media_type, - response_streaming=bool(normalized.response_streaming), - meta=self._build_request_meta(normalized.to_payload()), - ) - self._update_request_state( - request_id, - EngineStatus.VALIDATED, - { - "request_source": "direct_tts", - "selected_backend": backend, - "fallback_reason": fallback_reason, - }, - ) - if backend != "scheduler_v1_direct": - if normalized.response_streaming: - return DirectTTSExecution( - media_type=media_type, - streaming=True, - audio_generator=self._iter_legacy_direct_tts_bytes( - normalized, - backend=backend, - fallback_reason=fallback_reason, - ), - request_id=request_id, - ) - return self._run_legacy_direct_tts_blocking( - normalized, - backend=backend, - fallback_reason=fallback_reason, - ) - normalized_payload = normalized.to_payload() - if normalized.response_streaming: - return DirectTTSExecution( - media_type=media_type, - streaming=True, - audio_generator=self._iter_legacy_direct_tts_bytes( - normalized, - backend="legacy_direct_sync_compat", - fallback_reason="sync_direct_compat", - ), - request_id=request_id, - ) - return self._run_legacy_direct_tts_blocking( - normalized, - backend="legacy_direct_sync_compat", - fallback_reason="sync_direct_compat", - ) - - def build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]: - specs: List[SchedulerRequestSpec] = [] - for index, payload in enumerate(request_items): - normalized = self._normalize_engine_request( - payload, - request_id=str(payload.get("request_id") or f"req_{index:03d}"), - error_prefix=f"request[{index}] 参数非法: ", - ) - specs.append(normalized.to_scheduler_spec()) - return specs - - def build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec: - normalized = self._normalize_engine_request( - payload, - request_id=( - payload.request_id - if isinstance(payload, NormalizedEngineRequest) - else str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}") - ), - ) - return normalized.to_scheduler_spec() - - @staticmethod - def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: - return [ - { - "request_id": state.request_id, - "ready_step": int(state.ready_step), - "ref_audio_path": str(state.ref_audio_path), - "prompt_semantic_len": int(state.prompt_semantic.shape[0]), - "all_phone_len": int(state.all_phones.shape[0]), - "bert_len": int(state.all_bert_features.shape[-1]), - "norm_text": state.norm_text, - } - for state in states - ] - - @staticmethod - def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: - return [ - { - "request_id": item.request_id, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - "finish_reason": item.finish_reason, - } - for item in items - ] - - async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution: - request_start = time.perf_counter() - set_scheduler_seed(seed) - specs = self.build_scheduler_request_specs(request_items) - request_ids = [spec.request_id for spec in specs] - for spec in specs: - self._register_request_state( - request_id=spec.request_id, - api_mode="scheduler_debug", - backend="scheduler_debug", - media_type="wav", - response_streaming=False, - meta={ - "text_len": len(spec.text), - "prompt_text_len": len(spec.prompt_text), - "text_lang": spec.text_lang, - "prompt_lang": spec.prompt_lang, - "ref_audio_path": str(spec.ref_audio_path), - "ready_step": int(spec.ready_step), - }, - ) - self._update_request_state(spec.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_debug"}) - self._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, None) - prepare_started_at = time.perf_counter() - try: - states = await self.scheduler_worker.prepare_states_batch_async(specs) - except Exception as exc: - for request_id in request_ids: - self._fail_request_state(request_id, str(exc)) - raise - prepare_finished_at = time.perf_counter() - prepare_batch_wall_ms = max(0.0, (prepare_finished_at - prepare_started_at) * 1000.0) - for state in states: - self._update_request_state( - state.request_id, - EngineStatus.ACTIVE_DECODE, - { - "prepare_profile": dict(state.prepare_profile), - "norm_text": state.norm_text, - "norm_prompt_text": state.norm_prompt_text, - }, - ) - decode_started_at = time.perf_counter() - try: - finished = run_scheduler_continuous(self.tts.t2s_model.model, states, max_steps=int(max_steps)) - except Exception as exc: - for request_id in request_ids: - self._fail_request_state(request_id, str(exc)) - raise - decode_finished_at = time.perf_counter() - decode_batch_wall_ms = max(0.0, (decode_finished_at - decode_started_at) * 1000.0) - request_total_ms = max(0.0, (decode_finished_at - request_start) * 1000.0) - finished_map = {item.request_id: item for item in finished} - request_profiles: List[Dict[str, Any]] = [] - for state in states: - item = finished_map.get(state.request_id) - if item is None: - self._fail_request_state(state.request_id, "scheduler_debug finished without result") - continue - request_profile = self._build_scheduler_debug_request_profile( - state=state, - item=item, - batch_request_count=len(states), - prepare_batch_wall_ms=prepare_batch_wall_ms, - decode_batch_wall_ms=decode_batch_wall_ms, - batch_request_total_ms=request_total_ms, - ) - request_profiles.append( - { - "request_id": state.request_id, - "profile": dict(request_profile), - } - ) - self._complete_request_state( - state.request_id, - dict(request_profile), - ) - return SchedulerDebugExecution( - payload={ - "message": "success", - "request_count": len(states), - "max_steps": int(max_steps), - "batch_profile": self._build_scheduler_debug_batch_profile( - request_count=len(states), - max_steps=int(max_steps), - prepare_batch_wall_ms=prepare_batch_wall_ms, - decode_batch_wall_ms=decode_batch_wall_ms, - request_total_ms=request_total_ms, - finished_items=finished, - ), - "requests": self.summarize_scheduler_states(states), - "finished": self.summarize_scheduler_finished(finished), - "request_profiles": request_profiles, - "request_traces": self._collect_request_summaries(request_ids), - } - ) - - async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution: - request_start = time.perf_counter() - prepare_start = request_start - normalized = self._normalize_engine_request( - payload, - request_id=str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}"), - ) - spec = self.build_scheduler_submit_spec(normalized) - deadline_ts = None - timeout_sec = normalized.timeout_sec - if timeout_sec is not None: - try: - deadline_ts = request_start + float(timeout_sec) - except Exception: - deadline_ts = None - self._register_request_state( - request_id=spec.request_id, - api_mode="scheduler_submit", - backend="scheduler_v1", - media_type=normalized.media_type, - response_streaming=False, - deadline_ts=deadline_ts, - meta=self._build_request_meta(normalized.to_payload()), - ) - self._update_request_state(spec.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_submit"}) - spec_ready_at = time.perf_counter() - prepare_spec_build_ms = max(0.0, (spec_ready_at - prepare_start) * 1000.0) - self._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, {"prepare_spec_build_ms": prepare_spec_build_ms}) - try: - state, prepare_exec_started_at, prepare_exec_finished_at = await self._prepare_state_via_engine_gpu_queue( - spec=spec, - prepare_submit_at=spec_ready_at, - engine_request_id=spec.request_id, - ) - except Exception as exc: - self._fail_request_state(spec.request_id, str(exc)) - raise - prepare_wall_ms = max(0.0, (prepare_exec_finished_at - spec_ready_at) * 1000.0) - prepare_executor_queue_ms = max(0.0, (prepare_exec_started_at - spec_ready_at) * 1000.0) - prepare_executor_run_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0) - prepare_profile = dict(state.prepare_profile) - prepare_profile_total_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms)) - prepare_profile_wall_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms)) - prepare_other_ms = max(0.0, prepare_wall_ms - prepare_spec_build_ms - prepare_executor_queue_ms - prepare_executor_run_ms) - self._update_request_state( - spec.request_id, - EngineStatus.READY_FOR_PREFILL, - { - "prepare_wall_ms": prepare_wall_ms, - "prepare_profile_total_ms": prepare_profile_total_ms, - "prepare_profile": prepare_profile, - }, - ) - api_after_prepare_start = time.perf_counter() - loop = asyncio.get_running_loop() - done_future = loop.create_future() - await self._enqueue_prepared_state_for_dispatch( - state=state, - speed_factor=float(normalized.speed_factor), - sample_steps=int(normalized.sample_steps), - media_type=normalized.media_type, - prepare_wall_ms=prepare_wall_ms, - prepare_profile_total_ms=prepare_profile_total_ms, - done_loop=loop, - done_future=done_future, - engine_request_id=spec.request_id, - timeout_sec=normalized.timeout_sec, - ) - api_after_prepare_ms = max(0.0, (time.perf_counter() - api_after_prepare_start) * 1000.0) - try: - job = await asyncio.wait_for(done_future, timeout=float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0)) - except Exception as exc: - self._fail_request_state(spec.request_id, str(exc)) - raise - wait_return_at = time.perf_counter() - if job.error is not None: - raise RuntimeError(job.error) - if job.audio_data is None or job.sample_rate is None or job.result is None: - self._fail_request_state(spec.request_id, f"{job.request_id} finished without audio result") - raise RuntimeError(f"{job.request_id} finished without audio result") - pack_start = time.perf_counter() - audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue() - pack_end = time.perf_counter() - pack_ms = (pack_end - pack_start) * 1000.0 - api_wait_result_ms = 0.0 - if job.result_ready_time is not None: - api_wait_result_ms = max(0.0, (wait_return_at - job.result_ready_time) * 1000.0) - response_ready_at = time.perf_counter() - response_overhead_ms = max(0.0, (response_ready_at - pack_end) * 1000.0) - submit_profile = self._build_scheduler_submit_profile( - backend="scheduler_v1", - request_start=request_start, - response_ready_at=response_ready_at, - audio_bytes=len(audio_data), - sample_rate=int(job.sample_rate), - prepare_spec_build_ms=prepare_spec_build_ms, - prepare_wall_ms=prepare_wall_ms, - prepare_executor_queue_ms=prepare_executor_queue_ms, - prepare_executor_run_ms=prepare_executor_run_ms, - prepare_profile_total_ms=prepare_profile_total_ms, - prepare_profile_wall_ms=prepare_profile_wall_ms, - prepare_other_ms=prepare_other_ms, - engine_policy_wait_ms=float(job.result.get("engine_policy_wait_ms", 0.0)), - api_after_prepare_ms=api_after_prepare_ms, - api_wait_result_ms=api_wait_result_ms, - pack_ms=pack_ms, - response_overhead_ms=response_overhead_ms, - worker_profile=dict(job.result or {}), - ) - headers = self._build_scheduler_submit_headers( - request_id=job.request_id, - media_type=job.media_type, - sample_rate=int(job.sample_rate), - profile=submit_profile, - ) - self._merge_request_state_profile( - spec.request_id, - dict(submit_profile, response_headers_emitted=True), - ) - return SchedulerSubmitExecution(audio_bytes=audio_data, media_type=f"audio/{job.media_type}", headers=headers) - - def get_scheduler_state(self) -> dict: - return self.scheduler_worker.snapshot() - - def get_runtime_state(self) -> dict: - model_state = self.model_registry.snapshot() - default_ref = self.reference_registry.get_default() - scheduler_state = self.get_scheduler_state() - request_registry = self._snapshot_request_registry() - engine_policy = self._build_engine_policy_snapshot(request_registry, scheduler_state) - engine_arbiter_state = self._snapshot_engine_arbiter_state() - engine_decode_runtime_state = self._snapshot_engine_decode_runtime_state() - engine_job_registry = self._snapshot_engine_job_registry() - engine_prepare_state = self._snapshot_engine_prepare_state() - engine_finalize_state = self._snapshot_engine_finalize_state() - engine_dispatcher_state = self._snapshot_engine_dispatch_state() - engine_drained = self._is_engine_drained() - return { - "message": "success", - "default_reference": { - "ref_audio_path": default_ref.ref_audio_path, - "updated_at": default_ref.updated_at, - }, - "model_registry": { - "generation": model_state.generation, - "t2s_generation": model_state.t2s_generation, - "vits_generation": model_state.vits_generation, - "t2s_weights_path": model_state.t2s_weights_path, - "vits_weights_path": model_state.vits_weights_path, - "updated_at": model_state.updated_at, - }, - "worker_state": scheduler_state, - "engine_policy": engine_policy, - "engine_arbiter_state": engine_arbiter_state, - "engine_decode_runtime_state": engine_decode_runtime_state, - "engine_job_registry": engine_job_registry, - "engine_active_batch_state": self.engine_decode_runtime_owner.active_batch_summary(), - "engine_prepare_state": engine_prepare_state, - "engine_finalize_state": engine_finalize_state, - "engine_dispatcher_state": engine_dispatcher_state, - "engine_drained": bool(engine_drained), - "request_registry": request_registry, - "stage_summary": self._build_stage_summary(request_registry, scheduler_state), - } - - def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None: - if not self.scheduler_worker.wait_until_idle(timeout_sec=timeout_sec): - raise TimeoutError("scheduler worker did not drain before model reload") - - def set_refer_audio(self, refer_audio_path: str | None) -> dict: - if refer_audio_path in [None, ""]: - state = self.reference_registry.clear() - return {"message": "success", "default_ref_audio_path": state.ref_audio_path} - if not os.path.exists(str(refer_audio_path)): - raise FileNotFoundError(f"{refer_audio_path} not exists") - with self.management_lock: - with self.direct_tts_lock: - self.tts.set_ref_audio(str(refer_audio_path)) - state = self.reference_registry.set_default(str(refer_audio_path)) - return {"message": "success", "default_ref_audio_path": state.ref_audio_path} - - def set_gpt_weights(self, weights_path: str) -> dict: - if weights_path in ["", None]: - raise ValueError("gpt weight path is required") - with self.management_lock: - self._wait_for_safe_reload() - with self.direct_tts_lock: - self.tts.init_t2s_weights(weights_path) - self.tts.refresh_runtime_components() - state = self.model_registry.mark_t2s_reload(str(weights_path)) - return {"message": "success", "t2s_generation": state.t2s_generation, "generation": state.generation} - - def set_sovits_weights(self, weights_path: str) -> dict: - if weights_path in ["", None]: - raise ValueError("sovits weight path is required") - with self.management_lock: - self._wait_for_safe_reload() - with self.direct_tts_lock: - self.tts.init_vits_weights(weights_path) - self.tts.refresh_runtime_components() - state = self.model_registry.mark_vits_reload(str(weights_path)) - return {"message": "success", "vits_generation": state.vits_generation, "generation": state.generation} - - def handle_control(self, command: str) -> None: - if command == "restart": - if self.control_callbacks.restart is None: - os.execl(sys.executable, sys.executable, *sys.argv) - self.control_callbacks.restart() - return - if command == "exit": - if self.control_callbacks.exit is None: - os.kill(os.getpid(), signal.SIGTERM) - return - self.control_callbacks.exit() - return - raise ValueError(f"unsupported command: {command}") + EngineCompositionBuilder(self).build(max_steps=max_steps, micro_batch_wait_ms=micro_batch_wait_ms) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py new file mode 100644 index 00000000..bcc8bf0d --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py @@ -0,0 +1,1399 @@ +from __future__ import annotations + +import asyncio +import time +import uuid +from io import BytesIO +from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SFinishedItem, T2SRequestState, run_scheduler_continuous +from GPT_SoVITS.TTS_infer_pack.unified_engine_audio import pack_audio, set_scheduler_seed, wave_header_chunk +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import ( + DirectTTSExecution, + EngineStatus, + NormalizedEngineRequest, + SchedulerDebugExecution, + SchedulerPendingJob, + SchedulerSubmitExecution, +) + + +class EngineApiFacade: + def __init__(self, owner: Any) -> None: + self.owner = owner + + @property + def tts(self): + return self.owner.tts + + @property + def cut_method_names(self): + return self.owner.cut_method_names + + @property + def reference_registry(self): + return self.owner.reference_registry + + @property + def direct_tts_lock(self): + return self.owner.direct_tts_lock + + @property + def scheduler_worker(self): + return self.owner.scheduler_worker + + def _register_request_state( + self, + request_id: str, + api_mode: str, + backend: str, + media_type: str, + response_streaming: bool, + deadline_ts: float | None = None, + meta: Optional[Dict[str, Any]] = None, + ): + return self.owner._register_request_state( + request_id=request_id, + api_mode=api_mode, + backend=backend, + media_type=media_type, + response_streaming=response_streaming, + deadline_ts=deadline_ts, + meta=meta, + ) + + def _update_request_state( + self, + request_id: str, + status: str, + extra: Optional[Dict[str, Any]] = None, + ) -> None: + self.owner._update_request_state(request_id, status, extra) + + def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.owner._merge_request_state_profile(request_id, extra) + + def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.owner._complete_request_state(request_id, extra) + + def _fail_request_state(self, request_id: str, error: str) -> None: + self.owner._fail_request_state(request_id, error) + + async def _prepare_state_via_engine_gpu_queue( + self, + *, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + engine_request_id: str | None, + ) -> tuple[T2SRequestState, float, float]: + return await self.owner._prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=prepare_submit_at, + engine_request_id=engine_request_id, + ) + + async def _enqueue_prepared_state_for_dispatch( + self, + *, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None, + done_future: asyncio.Future | None, + engine_request_id: str | None, + timeout_sec: float | None, + ): + return await self.owner._enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=speed_factor, + sample_steps=sample_steps, + media_type=media_type, + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=done_loop, + done_future=done_future, + engine_request_id=engine_request_id, + timeout_sec=timeout_sec, + ) + + def _collect_request_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]: + return self.owner.request_registry.collect_summaries(request_ids) + + def _has_active_request(self, request_id: str) -> bool: + return self.owner.request_registry.has_active(request_id) + + @staticmethod + def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]: + text = payload.get("text") + prompt_text = payload.get("prompt_text") + return { + "text_len": 0 if text is None else len(str(text)), + "prompt_text_len": 0 if prompt_text is None else len(str(prompt_text)), + "text_lang": payload.get("text_lang"), + "prompt_lang": payload.get("prompt_lang"), + "ref_audio_path": payload.get("ref_audio_path"), + } + + @staticmethod + def _sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float: + total = 0.0 + for item in items: + value = item.get(key, 0.0) + if isinstance(value, (int, float)): + total += float(value) + return total + + def _build_direct_segment_trace( + self, + segment_texts: Sequence[str], + prepare_profiles: Sequence[Dict[str, Any]], + worker_profiles: Sequence[Dict[str, Any]], + ) -> List[Dict[str, Any]]: + results: List[Dict[str, Any]] = [] + for index, segment_text in enumerate(segment_texts): + prepare_item = prepare_profiles[index] if index < len(prepare_profiles) else {} + worker_item = worker_profiles[index] if index < len(worker_profiles) else {} + prepare_profile = dict(prepare_item.get("prepare_profile", {})) + results.append( + { + "segment_index": index, + "request_id": prepare_item.get("request_id") or worker_item.get("request_id"), + "text_len": len(str(segment_text)), + "prepare_wall_ms": float(prepare_item.get("prepare_wall_ms", 0.0)), + "prepare_profile_total_ms": float(prepare_item.get("prepare_profile_total_ms", 0.0)), + "prepare_engine_gpu_queue_wait_ms": float( + dict(prepare_item.get("prepare_profile", {})).get("engine_gpu_prepare_queue_wait_ms", 0.0) + ), + "engine_policy_wait_ms": float(prepare_item.get("engine_policy_wait_ms", 0.0)), + "engine_dispatch_wait_ms": float(prepare_item.get("engine_dispatch_wait_ms", 0.0)), + "decode_admission_wait_ms": float(worker_item.get("decode_admission_wait_ms", 0.0)), + "queue_wait_ms": float(worker_item.get("queue_wait_ms", 0.0)), + "prefill_ms": float(worker_item.get("prefill_ms", 0.0)), + "merge_ms": float(worker_item.get("merge_ms", 0.0)), + "decode_ms": float(worker_item.get("decode_ms", 0.0)), + "finalize_wait_ms": float(worker_item.get("finalize_wait_ms", 0.0)), + "synth_ms": float(worker_item.get("synth_ms", 0.0)), + "worker_total_ms": float(worker_item.get("worker_total_ms", 0.0)), + "decode_steps": int(worker_item.get("decode_steps", 0)), + "semantic_len": int(worker_item.get("semantic_len", 0)), + "finish_reason": worker_item.get("finish_reason"), + "norm_text": prepare_profile.get("norm_text"), + } + ) + return results + + def _build_direct_scheduler_profile( + self, + *, + backend: str, + request_start: float, + response_ready_at: float, + audio_bytes: int, + sample_rate: int, + segment_texts: Sequence[str], + prepare_profiles: Sequence[Dict[str, Any]], + worker_profiles: Sequence[Dict[str, Any]], + pack_ms: float, + response_overhead_ms: float, + ) -> Dict[str, Any]: + segment_trace = self._build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles) + prepare_profile_dicts = [dict(item.get("prepare_profile", {})) for item in prepare_profiles] + request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) + prepare_wall_ms = self._sum_profile_field(prepare_profiles, "prepare_wall_ms") + prepare_profile_total_ms = self._sum_profile_field(prepare_profiles, "prepare_profile_total_ms") + engine_policy_wait_ms = self._sum_profile_field(prepare_profiles, "engine_policy_wait_ms") + engine_dispatch_wait_ms = self._sum_profile_field(prepare_profiles, "engine_dispatch_wait_ms") + decode_admission_wait_ms = self._sum_profile_field(worker_profiles, "decode_admission_wait_ms") + queue_wait_ms = self._sum_profile_field(worker_profiles, "queue_wait_ms") + prefill_ms = self._sum_profile_field(worker_profiles, "prefill_ms") + merge_ms = self._sum_profile_field(worker_profiles, "merge_ms") + decode_ms = self._sum_profile_field(worker_profiles, "decode_ms") + finalize_wait_ms = self._sum_profile_field(worker_profiles, "finalize_wait_ms") + synth_ms = self._sum_profile_field(worker_profiles, "synth_ms") + worker_total_ms = self._sum_profile_field(worker_profiles, "worker_total_ms") + decode_steps = sum(int(item.get("decode_steps", 0)) for item in worker_profiles) + semantic_len = sum(int(item.get("semantic_len", 0)) for item in worker_profiles) + request_other_ms = max( + 0.0, + request_total_ms - prepare_wall_ms - engine_policy_wait_ms - worker_total_ms - pack_ms - response_overhead_ms, + ) + return { + "backend": backend, + "backend_mode": backend, + "segment_count": len(segment_texts), + "sample_rate": int(sample_rate), + "audio_bytes": int(audio_bytes), + "request_total_ms": request_total_ms, + "prepare_ms": prepare_wall_ms, + "prepare_wall_ms": prepare_wall_ms, + "prepare_profile_total_ms": prepare_profile_total_ms, + "engine_policy_wait_ms": engine_policy_wait_ms, + "engine_dispatch_wait_ms": engine_dispatch_wait_ms, + "decode_admission_wait_ms": decode_admission_wait_ms, + "queue_wait_ms": queue_wait_ms, + "prefill_ms": prefill_ms, + "merge_ms": merge_ms, + "decode_ms": decode_ms, + "finalize_wait_ms": finalize_wait_ms, + "synth_ms": synth_ms, + "pack_ms": pack_ms, + "response_overhead_ms": response_overhead_ms, + "worker_total_ms": worker_total_ms, + "request_other_ms": request_other_ms, + "decode_steps": decode_steps, + "semantic_len": semantic_len, + "prepare_segments": list(prepare_profiles), + "worker_segments": list(worker_profiles), + "segment_trace": segment_trace, + "prepare_aggregate": self._aggregate_numeric_dicts(prepare_profile_dicts), + } + + def _build_legacy_direct_profile( + self, + *, + backend: str, + fallback_reason: str | None, + request_start: float, + finished_at: float, + sample_rate: int | None = None, + audio_bytes: int = 0, + pack_ms: float = 0.0, + chunk_count: int = 0, + stream_total_bytes: int = 0, + first_chunk_ms: float | None = None, + ) -> Dict[str, Any]: + request_total_ms = max(0.0, (finished_at - request_start) * 1000.0) + legacy_infer_ms = max(0.0, request_total_ms - pack_ms) + return { + "backend": backend, + "backend_mode": backend, + "fallback_reason": fallback_reason, + "request_total_ms": request_total_ms, + "prepare_ms": 0.0, + "queue_wait_ms": 0.0, + "prefill_ms": 0.0, + "merge_ms": 0.0, + "decode_ms": 0.0, + "finalize_wait_ms": 0.0, + "synth_ms": 0.0, + "pack_ms": pack_ms, + "worker_total_ms": legacy_infer_ms, + "request_other_ms": 0.0, + "legacy_infer_ms": legacy_infer_ms, + "sample_rate": int(sample_rate) if sample_rate is not None else None, + "audio_bytes": int(audio_bytes), + "chunk_count": int(chunk_count), + "stream_total_bytes": int(stream_total_bytes), + "first_chunk_ms": None if first_chunk_ms is None else float(first_chunk_ms), + } + + def _build_scheduler_submit_profile( + self, + *, + backend: str, + request_start: float, + response_ready_at: float, + audio_bytes: int, + sample_rate: int, + prepare_spec_build_ms: float, + prepare_wall_ms: float, + prepare_executor_queue_ms: float, + prepare_executor_run_ms: float, + prepare_profile_total_ms: float, + prepare_profile_wall_ms: float, + prepare_other_ms: float, + engine_policy_wait_ms: float, + api_after_prepare_ms: float, + api_wait_result_ms: float, + pack_ms: float, + response_overhead_ms: float, + worker_profile: Dict[str, Any], + ) -> Dict[str, Any]: + worker_total_ms = float(worker_profile.get("worker_total_ms", 0.0)) + request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) + request_other_ms = max( + 0.0, + request_total_ms + - prepare_wall_ms + - engine_policy_wait_ms + - api_after_prepare_ms + - worker_total_ms + - api_wait_result_ms + - pack_ms, + ) + result = { + "backend": backend, + "backend_mode": backend, + "audio_bytes": int(audio_bytes), + "sample_rate": int(sample_rate), + "prepare_spec_build_ms": prepare_spec_build_ms, + "prepare_ms": prepare_wall_ms, + "prepare_wall_ms": prepare_wall_ms, + "prepare_executor_queue_ms": prepare_executor_queue_ms, + "prepare_executor_run_ms": prepare_executor_run_ms, + "prepare_profile_total_ms": prepare_profile_total_ms, + "prepare_profile_wall_ms": prepare_profile_wall_ms, + "prepare_other_ms": prepare_other_ms, + "engine_policy_wait_ms": float(engine_policy_wait_ms), + "api_after_prepare_ms": api_after_prepare_ms, + "api_wait_result_ms": api_wait_result_ms, + "pack_ms": pack_ms, + "response_overhead_ms": response_overhead_ms, + "request_total_ms": request_total_ms, + "request_other_ms": request_other_ms, + } + result.update({key: value for key, value in worker_profile.items()}) + return result + + @staticmethod + def _format_ms_header(value: Any) -> str: + return f"{float(value):.3f}" + + def _build_scheduler_submit_headers( + self, + *, + request_id: str, + media_type: str, + sample_rate: int, + profile: Dict[str, Any], + ) -> Dict[str, str]: + prepare_profile = dict(profile.get("prepare_profile", {})) + headers = { + "X-Request-Id": request_id, + "X-Semantic-Len": str(int(profile.get("semantic_len", 0))), + "X-Finish-Reason": str(profile.get("finish_reason", "unknown")), + "X-Queue-Wait-Ms": self._format_ms_header(profile.get("queue_wait_ms", 0.0)), + "X-Decode-Admission-Wait-Ms": self._format_ms_header(profile.get("decode_admission_wait_ms", 0.0)), + "X-Engine-Policy-Wait-Ms": self._format_ms_header(profile.get("engine_policy_wait_ms", 0.0)), + "X-Engine-Dispatch-Wait-Ms": self._format_ms_header(profile.get("engine_dispatch_wait_ms", 0.0)), + "X-Prepare-Ms": self._format_ms_header(profile.get("prepare_wall_ms", 0.0)), + "X-Prepare-Wall-Ms": self._format_ms_header(profile.get("prepare_wall_ms", 0.0)), + "X-Prepare-Spec-Build-Ms": self._format_ms_header(profile.get("prepare_spec_build_ms", 0.0)), + "X-Prepare-Executor-Queue-Ms": self._format_ms_header(profile.get("prepare_executor_queue_ms", 0.0)), + "X-Prepare-Admission-Wait-Ms": self._format_ms_header(prepare_profile.get("prepare_admission_wait_ms", 0.0)), + "X-Prepare-Executor-Run-Ms": self._format_ms_header(profile.get("prepare_executor_run_ms", 0.0)), + "X-Prepare-Profile-Total-Ms": self._format_ms_header(profile.get("prepare_profile_total_ms", 0.0)), + "X-Prepare-Profile-Wall-Ms": self._format_ms_header(profile.get("prepare_profile_wall_ms", 0.0)), + "X-Prepare-Other-Ms": self._format_ms_header(profile.get("prepare_other_ms", 0.0)), + "X-Api-After-Prepare-Ms": self._format_ms_header(profile.get("api_after_prepare_ms", 0.0)), + "X-Prefill-Ms": self._format_ms_header(profile.get("prefill_ms", 0.0)), + "X-Merge-Ms": self._format_ms_header(profile.get("merge_ms", 0.0)), + "X-Decode-Ms": self._format_ms_header(profile.get("decode_ms", 0.0)), + "X-Finalize-Wait-Ms": self._format_ms_header(profile.get("finalize_wait_ms", 0.0)), + "X-Synth-Ms": self._format_ms_header(profile.get("synth_ms", 0.0)), + "X-Worker-Residual-Ms": self._format_ms_header(profile.get("worker_residual_ms", 0.0)), + "X-Worker-Other-Ms": self._format_ms_header(profile.get("worker_other_ms", 0.0)), + "X-Pack-Ms": self._format_ms_header(profile.get("pack_ms", 0.0)), + "X-Worker-Total-Ms": self._format_ms_header(profile.get("worker_total_ms", 0.0)), + "X-Api-Wait-Result-Ms": self._format_ms_header(profile.get("api_wait_result_ms", 0.0)), + "X-Decode-Steps": str(int(profile.get("decode_steps", 0))), + "X-Sample-Rate": str(int(sample_rate)), + "X-Response-Overhead-Ms": self._format_ms_header(profile.get("response_overhead_ms", 0.0)), + "X-Request-Other-Ms": self._format_ms_header(profile.get("request_other_ms", 0.0)), + "X-Request-Total-Ms": self._format_ms_header(profile.get("request_total_ms", 0.0)), + } + headers.update( + { + "X-Prepare-Prompt-Text-Ms": self._format_ms_header(prepare_profile.get("prompt_text_features_ms", 0.0)), + "X-Prepare-Target-Text-Ms": self._format_ms_header(prepare_profile.get("text_features_ms", 0.0)), + "X-Prepare-Prompt-Text-CPU-Preprocess-Ms": self._format_ms_header(prepare_profile.get("prompt_text_cpu_preprocess_ms", 0.0)), + "X-Prepare-Target-Text-CPU-Preprocess-Ms": self._format_ms_header(prepare_profile.get("text_cpu_preprocess_ms", 0.0)), + "X-Prepare-Prompt-Text-CPU-Queue-Ms": self._format_ms_header(prepare_profile.get("prompt_text_cpu_queue_ms", 0.0)), + "X-Prepare-Target-Text-CPU-Queue-Ms": self._format_ms_header(prepare_profile.get("text_cpu_queue_ms", 0.0)), + "X-Prepare-Prompt-Text-Feature-Queue-Ms": self._format_ms_header(prepare_profile.get("prompt_text_feature_queue_ms", 0.0)), + "X-Prepare-Target-Text-Feature-Queue-Ms": self._format_ms_header(prepare_profile.get("text_feature_queue_ms", 0.0)), + "X-Prepare-Prompt-Bert-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_wait_ms", 0.0)), + "X-Prepare-Target-Bert-Wait-Ms": self._format_ms_header(prepare_profile.get("text_bert_wait_ms", 0.0)), + "X-Prepare-Prompt-Bert-Admission-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_admission_wait_ms", 0.0)), + "X-Prepare-Target-Bert-Admission-Wait-Ms": self._format_ms_header(prepare_profile.get("text_bert_admission_wait_ms", 0.0)), + "X-Prepare-Prompt-Bert-Queue-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_queue_wait_ms", 0.0)), + "X-Prepare-Target-Bert-Queue-Wait-Ms": self._format_ms_header(prepare_profile.get("text_bert_queue_wait_ms", 0.0)), + "X-Prepare-Prompt-Bert-Batch-Collect-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_batch_collect_wait_ms", 0.0)), + "X-Prepare-Target-Bert-Batch-Collect-Wait-Ms": self._format_ms_header(prepare_profile.get("text_bert_batch_collect_wait_ms", 0.0)), + "X-Prepare-Prompt-Bert-Forward-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_forward_ms", 0.0)), + "X-Prepare-Target-Bert-Forward-Ms": self._format_ms_header(prepare_profile.get("text_bert_forward_ms", 0.0)), + "X-Prepare-Prompt-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_enqueue_peak", 0.0))), + "X-Prepare-Target-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_enqueue_peak", 0.0))), + "X-Prepare-Prompt-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_collect_peak", 0.0))), + "X-Prepare-Target-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_collect_peak", 0.0))), + "X-Prepare-Prompt-Bert-High-Pressure-Peak": str(int(prepare_profile.get("prompt_text_bert_high_pressure_mode_peak", 0.0))), + "X-Prepare-Target-Bert-High-Pressure-Peak": str(int(prepare_profile.get("text_bert_high_pressure_mode_peak", 0.0))), + "X-Prepare-Prompt-Bert-Batch-Window-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_batch_window_ms", 0.0)), + "X-Prepare-Target-Bert-Batch-Window-Ms": self._format_ms_header(prepare_profile.get("text_bert_batch_window_ms", 0.0)), + "X-Prepare-Text-Pair-Wall-Ms": self._format_ms_header(prepare_profile.get("text_feature_pair_ms", 0.0)), + "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))), + "X-Prepare-Engine-GPU-Queue-Wait-Ms": self._format_ms_header(prepare_profile.get("engine_gpu_prepare_queue_wait_ms", 0.0)), + "X-Prepare-Audio-Load-Ms": self._format_ms_header(prepare_profile.get("audio_load_ms", 0.0)), + "X-Prepare-Audio-Stage-Wait-Ms": self._format_ms_header(prepare_profile.get("audio_stage_wait_ms", 0.0)), + "X-Prepare-Prompt-Semantic-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_ms", 0.0)), + "X-Prepare-Prompt-Semantic-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_wait_ms", 0.0)), + "X-Prepare-Prompt-Semantic-CPU-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), + "X-Prepare-Prompt-Semantic-Forward-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_forward_ms", 0.0)), + "X-Prepare-Ref-Spec-Ms": self._format_ms_header(prepare_profile.get("ref_spec_ms", 0.0)), + "X-Prepare-Ref-Spec-Wait-Ms": self._format_ms_header(prepare_profile.get("ref_spec_wait_ms", 0.0)), + "X-Prepare-Ref-Bundle-Ms": self._format_ms_header(prepare_profile.get("ref_audio_bundle_ms", 0.0)), + "X-Prepare-Tensorize-Ms": self._format_ms_header(prepare_profile.get("tensorize_ms", 0.0)), + "X-Prepare-Inflight-On-Enter": str(int(prepare_profile.get("worker_prepare_inflight_on_enter", 0.0))), + "X-Prepare-Inflight-Peak": str(int(prepare_profile.get("worker_prepare_peak_inflight", 0.0))), + } + ) + return headers + + def _build_scheduler_debug_request_profile( + self, + *, + state: T2SRequestState, + item: T2SFinishedItem, + batch_request_count: int, + prepare_batch_wall_ms: float, + decode_batch_wall_ms: float, + batch_request_total_ms: float, + ) -> Dict[str, Any]: + prepare_profile = dict(state.prepare_profile) + prepare_wall_ms = float(prepare_profile.get("wall_total_ms", 0.0)) + return { + "backend": "scheduler_debug", + "backend_mode": "scheduler_debug", + "batch_request_count": int(batch_request_count), + "batch_prepare_wall_ms": float(prepare_batch_wall_ms), + "batch_decode_wall_ms": float(decode_batch_wall_ms), + "batch_request_total_ms": float(batch_request_total_ms), + "prepare_ms": prepare_wall_ms, + "prepare_wall_ms": prepare_wall_ms, + "prepare_profile_total_ms": float(prepare_profile.get("wall_total_ms", prepare_wall_ms)), + "prepare_profile": prepare_profile, + "decode_steps": int(item.finish_idx), + "finish_idx": int(item.finish_idx), + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_reason": item.finish_reason, + "norm_text": state.norm_text, + "norm_prompt_text": state.norm_prompt_text, + } + + @staticmethod + def _build_scheduler_debug_batch_profile( + *, + request_count: int, + max_steps: int, + prepare_batch_wall_ms: float, + decode_batch_wall_ms: float, + request_total_ms: float, + finished_items: Sequence[T2SFinishedItem], + ) -> Dict[str, Any]: + finish_reason_counts: Dict[str, int] = {} + total_semantic_len = 0 + for item in finished_items: + finish_reason_counts[item.finish_reason] = finish_reason_counts.get(item.finish_reason, 0) + 1 + total_semantic_len += int(item.semantic_tokens.shape[0]) + return { + "request_count": int(request_count), + "max_steps": int(max_steps), + "prepare_batch_wall_ms": float(prepare_batch_wall_ms), + "decode_batch_wall_ms": float(decode_batch_wall_ms), + "request_total_ms": float(request_total_ms), + "total_semantic_len": int(total_semantic_len), + "finish_reason_counts": finish_reason_counts, + } + + def _normalize_lang(self, value: str | None) -> str | None: + if value in [None, ""]: + return value + return str(value).lower() + + @staticmethod + def _aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]: + totals: Dict[str, float] = {} + for item in items: + for key, value in item.items(): + if isinstance(value, (int, float)): + totals[key] = totals.get(key, 0.0) + float(value) + return totals + + def _apply_default_reference(self, req: dict) -> dict: + normalized = dict(req) + default_ref = self.reference_registry.get_default() + if normalized.get("ref_audio_path") in [None, ""] and default_ref.ref_audio_path not in [None, ""]: + normalized["ref_audio_path"] = default_ref.ref_audio_path + if "text_lang" in normalized: + normalized["text_lang"] = self._normalize_lang(normalized.get("text_lang")) + if "prompt_lang" in normalized: + normalized["prompt_lang"] = self._normalize_lang(normalized.get("prompt_lang")) + return normalized + + def check_params(self, req: dict) -> Optional[str]: + text = req.get("text", "") + text_lang = req.get("text_lang", "") + ref_audio_path = req.get("ref_audio_path", "") + media_type = req.get("media_type", "wav") + prompt_lang = req.get("prompt_lang", "") + text_split_method = req.get("text_split_method", "cut5") + + if ref_audio_path in [None, ""]: + return "ref_audio_path is required" + if text in [None, ""]: + return "text is required" + if text_lang in [None, ""]: + return "text_lang is required" + if text_lang.lower() not in self.tts.configs.languages: + return f"text_lang: {text_lang} is not supported in version {self.tts.configs.version}" + if prompt_lang in [None, ""]: + return "prompt_lang is required" + if prompt_lang.lower() not in self.tts.configs.languages: + return f"prompt_lang: {prompt_lang} is not supported in version {self.tts.configs.version}" + if media_type not in ["wav", "raw", "ogg", "aac"]: + return f"media_type: {media_type} is not supported" + if text_split_method not in self.cut_method_names: + return f"text_split_method:{text_split_method} is not supported" + return None + + @staticmethod + def _base_request_defaults() -> Dict[str, Any]: + return { + "request_id": None, + "text": None, + "text_lang": None, + "ref_audio_path": None, + "aux_ref_audio_paths": None, + "prompt_text": "", + "prompt_lang": None, + "top_k": 15, + "top_p": 1.0, + "temperature": 1.0, + "text_split_method": "cut5", + "batch_size": 1, + "batch_threshold": 0.75, + "speed_factor": 1.0, + "split_bucket": False, + "fragment_interval": 0.3, + "seed": -1, + "media_type": "wav", + "streaming_mode": False, + "return_fragment": False, + "fixed_length_chunk": False, + "response_streaming": False, + "parallel_infer": False, + "repetition_penalty": 1.35, + "sample_steps": 32, + "super_sampling": False, + "overlap_length": 2, + "min_chunk_length": 16, + "early_stop_num": -1, + "ready_step": 0, + "timeout_sec": None, + } + + def _normalize_engine_request( + self, + payload: dict | NormalizedEngineRequest, + *, + request_id: str | None = None, + normalize_streaming: bool = False, + error_prefix: str = "request 参数非法: ", + ) -> NormalizedEngineRequest: + if isinstance(payload, NormalizedEngineRequest): + normalized_payload = payload.to_payload() + else: + normalized_payload = self._base_request_defaults() + normalized_payload.update(dict(payload)) + if request_id not in [None, ""]: + normalized_payload["request_id"] = str(request_id) + elif normalized_payload.get("request_id") in [None, ""]: + raise ValueError("request_id is required after normalization") + normalized_payload = self._apply_default_reference(normalized_payload) + if normalize_streaming: + normalized_payload = self._normalize_streaming_mode(normalized_payload) + error = self.check_params(normalized_payload) + if error is not None: + raise ValueError(f"{error_prefix}{error}") + timeout_sec = normalized_payload.get("timeout_sec") + if timeout_sec in [None, ""]: + parsed_timeout = None + else: + parsed_timeout = float(timeout_sec) + aux_ref_audio_paths = normalized_payload.get("aux_ref_audio_paths") + if aux_ref_audio_paths in [None, "", []]: + normalized_aux_ref_audio_paths = None + else: + normalized_aux_ref_audio_paths = [str(item) for item in aux_ref_audio_paths] + return NormalizedEngineRequest( + request_id=str(normalized_payload["request_id"]), + text=str(normalized_payload["text"]), + text_lang=str(normalized_payload["text_lang"]), + ref_audio_path=str(normalized_payload["ref_audio_path"]), + prompt_lang=str(normalized_payload["prompt_lang"]), + prompt_text="" if normalized_payload.get("prompt_text") is None else str(normalized_payload.get("prompt_text")), + aux_ref_audio_paths=normalized_aux_ref_audio_paths, + top_k=int(normalized_payload["top_k"]), + top_p=float(normalized_payload["top_p"]), + temperature=float(normalized_payload["temperature"]), + repetition_penalty=float(normalized_payload["repetition_penalty"]), + early_stop_num=int(normalized_payload.get("early_stop_num", -1)), + ready_step=int(normalized_payload.get("ready_step", 0)), + text_split_method=str(normalized_payload["text_split_method"]), + batch_size=int(normalized_payload["batch_size"]), + batch_threshold=float(normalized_payload["batch_threshold"]), + split_bucket=bool(normalized_payload["split_bucket"]), + speed_factor=float(normalized_payload["speed_factor"]), + fragment_interval=float(normalized_payload["fragment_interval"]), + seed=int(normalized_payload["seed"]), + media_type=str(normalized_payload["media_type"]), + streaming_mode=normalized_payload["streaming_mode"], + return_fragment=bool(normalized_payload.get("return_fragment", False)), + fixed_length_chunk=bool(normalized_payload.get("fixed_length_chunk", False)), + response_streaming=bool(normalized_payload.get("response_streaming", False)), + parallel_infer=bool(normalized_payload["parallel_infer"]), + sample_steps=int(normalized_payload["sample_steps"]), + super_sampling=bool(normalized_payload["super_sampling"]), + overlap_length=int(normalized_payload["overlap_length"]), + min_chunk_length=int(normalized_payload["min_chunk_length"]), + timeout_sec=parsed_timeout, + ) + + @staticmethod + def _normalize_streaming_mode(req: dict) -> dict: + normalized = dict(req) + streaming_mode = normalized.get("streaming_mode", False) + return_fragment = normalized.get("return_fragment", False) + if streaming_mode is False: + normalized["streaming_mode"] = False + normalized["return_fragment"] = False + normalized["fixed_length_chunk"] = False + elif streaming_mode == 0: + normalized["streaming_mode"] = False + normalized["return_fragment"] = False + normalized["fixed_length_chunk"] = False + elif streaming_mode == 1 or streaming_mode is True: + normalized["streaming_mode"] = False + normalized["return_fragment"] = True + normalized["fixed_length_chunk"] = False + elif streaming_mode == 2: + normalized["streaming_mode"] = True + normalized["return_fragment"] = False + normalized["fixed_length_chunk"] = False + elif streaming_mode == 3: + normalized["streaming_mode"] = True + normalized["return_fragment"] = False + normalized["fixed_length_chunk"] = True + else: + raise ValueError("the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)") + normalized["response_streaming"] = bool(normalized["streaming_mode"] or normalized["return_fragment"] or return_fragment) + return normalized + + @staticmethod + def _is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool: + return aux_ref_audio_paths not in [None, [], ()] + + def _select_direct_backend(self, normalized: NormalizedEngineRequest) -> Tuple[str, str | None]: + if normalized.response_streaming: + if normalized.return_fragment or normalized.fixed_length_chunk: + return "legacy_direct_fragment", "fragment_streaming_mode" + return "legacy_direct_streaming", "streaming_mode" + if self._is_aux_ref_enabled(normalized.aux_ref_audio_paths): + return "legacy_direct_aux_ref", "aux_ref_audio_paths" + if normalized.super_sampling: + return "legacy_direct_super_sampling", "super_sampling" + if normalized.prompt_text in [None, ""]: + return "legacy_direct_missing_prompt", "missing_prompt_text" + return "scheduler_v1_direct", None + + def _iter_legacy_direct_tts_bytes( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> Generator[bytes, None, None]: + payload = normalized.to_payload() + media_type = normalized.media_type + request_id = normalized.request_id + request_start = time.perf_counter() + chunk_count = 0 + stream_total_bytes = 0 + first_chunk_ms: float | None = None + self._update_request_state( + request_id, + EngineStatus.ACTIVE_DECODE, + {"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason}, + ) + try: + with self.direct_tts_lock: + tts_generator = self.tts.run(payload) + first_chunk = True + current_media_type = media_type + for sr, chunk in tts_generator: + if first_chunk: + first_chunk_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0) + self._update_request_state( + request_id, + EngineStatus.STREAMING, + { + "backend": backend, + "backend_mode": backend, + "fallback_reason": fallback_reason, + "sample_rate": int(sr), + }, + ) + if first_chunk and media_type == "wav": + header = wave_header_chunk(sample_rate=sr) + chunk_count += 1 + stream_total_bytes += len(header) + yield header + current_media_type = "raw" + first_chunk = False + elif first_chunk: + first_chunk = False + packed_chunk = pack_audio(BytesIO(), chunk, sr, current_media_type).getvalue() + chunk_count += 1 + stream_total_bytes += len(packed_chunk) + yield packed_chunk + except Exception as exc: + self._fail_request_state(request_id, str(exc)) + raise + self._complete_request_state( + request_id, + dict( + self._build_legacy_direct_profile( + backend=backend, + fallback_reason=fallback_reason, + request_start=request_start, + finished_at=time.perf_counter(), + audio_bytes=stream_total_bytes, + chunk_count=chunk_count, + stream_total_bytes=stream_total_bytes, + first_chunk_ms=first_chunk_ms, + ), + streaming_completed=True, + ), + ) + + def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool: + if isinstance(req, NormalizedEngineRequest): + normalized = req + else: + normalized = self._normalize_engine_request( + req, + request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"), + normalize_streaming=True, + ) + backend, _ = self._select_direct_backend(normalized) + return backend == "scheduler_v1_direct" + + def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]: + payload = normalized.to_payload() if isinstance(normalized, NormalizedEngineRequest) else normalized + return self.tts.text_preprocessor.pre_seg_text( + str(payload["text"]), + str(payload["text_lang"]), + str(payload.get("text_split_method", "cut5")), + ) + + def _build_segment_request( + self, + normalized: NormalizedEngineRequest, + *, + request_id: str, + text: str, + ) -> NormalizedEngineRequest: + payload = normalized.to_payload() + payload["request_id"] = request_id + payload["text"] = text + payload["streaming_mode"] = False + payload["return_fragment"] = False + payload["fixed_length_chunk"] = False + payload["response_streaming"] = False + return self._normalize_engine_request(payload, error_prefix="segment request 参数非法: ") + + async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution: + request_start = time.perf_counter() + request_id = normalized.request_id + media_type = normalized.media_type + segment_texts = self._segment_direct_text(normalized) + if not segment_texts: + raise ValueError("text preprocessing returned no valid segments") + self._update_request_state( + request_id, + EngineStatus.CPU_PREPARING, + {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_texts)}, + ) + segment_specs: List[SchedulerRequestSpec] = [] + for segment_index, segment_text in enumerate(segment_texts): + segment_request = self._build_segment_request( + normalized, + request_id=f"{request_id}_seg_{segment_index:03d}", + text=segment_text, + ) + segment_specs.append(self.build_scheduler_submit_spec(segment_request)) + + prepared_items = await asyncio.gather( + *[ + self._prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=time.perf_counter(), + engine_request_id=None, + ) + for spec in segment_specs + ] + ) + prepare_profiles: List[Dict[str, Any]] = [] + loop = asyncio.get_running_loop() + done_futures: List[asyncio.Future] = [] + self._update_request_state( + request_id, + EngineStatus.READY_FOR_PREFILL, + {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_specs)}, + ) + for spec, (state, prepare_exec_started_at, prepare_exec_finished_at) in zip(segment_specs, prepared_items): + prepare_wall_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0) + prepare_profile_total_ms = float(state.prepare_profile.get("wall_total_ms", prepare_wall_ms)) + prepare_profiles.append( + { + "request_id": spec.request_id, + "prepare_wall_ms": prepare_wall_ms, + "prepare_profile_total_ms": prepare_profile_total_ms, + "prepare_profile": dict(state.prepare_profile), + } + ) + done_future = loop.create_future() + done_futures.append(done_future) + await self._enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=float(normalized.speed_factor), + sample_steps=int(normalized.sample_steps), + media_type=media_type, + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=loop, + done_future=done_future, + engine_request_id=None, + timeout_sec=normalized.timeout_sec, + ) + self._update_request_state( + request_id, + EngineStatus.ACTIVE_DECODE, + {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct"}, + ) + timeout_sec = float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0) + jobs: List[SchedulerPendingJob] = list(await asyncio.wait_for(asyncio.gather(*done_futures), timeout=timeout_sec)) + for profile_item, job in zip(prepare_profiles, jobs): + profile_item["engine_policy_wait_ms"] = float(job.engine_policy_wait_ms) + profile_item["engine_dispatch_wait_ms"] = float(job.engine_dispatch_wait_ms) + self._merge_request_state_profile( + request_id, + { + "engine_policy_wait_ms": sum(float(job.engine_policy_wait_ms) for job in jobs), + "engine_dispatch_wait_ms": sum(float(job.engine_dispatch_wait_ms) for job in jobs), + "prepare_aggregate": self._aggregate_numeric_dicts( + [item["prepare_profile"] for item in prepare_profiles] + ), + }, + ) + + sample_rate: int | None = None + audio_parts: List[np.ndarray] = [] + worker_profiles: List[Dict[str, Any]] = [] + fragment_interval = float(normalized.fragment_interval) + silence_chunk: Optional[np.ndarray] = None + for job in jobs: + if job.error is not None: + raise RuntimeError(job.error) + if job.audio_data is None or job.sample_rate is None or job.result is None: + raise RuntimeError(f"{job.request_id} finished without audio result") + if sample_rate is None: + sample_rate = int(job.sample_rate) + silence_samples = int(fragment_interval * float(sample_rate)) + if silence_samples > 0: + silence_chunk = np.zeros(silence_samples, dtype=np.int16) + elif int(job.sample_rate) != sample_rate: + raise RuntimeError("segment sample rate mismatch") + audio_parts.append(job.audio_data) + if silence_chunk is not None: + audio_parts.append(silence_chunk.copy()) + worker_profiles.append(dict(job.result)) + if sample_rate is None or not audio_parts: + raise RuntimeError("direct scheduler backend produced no audio") + self._update_request_state( + request_id, + EngineStatus.FINALIZING, + {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct"}, + ) + merged_audio = np.concatenate(audio_parts, axis=0) + pack_start = time.perf_counter() + audio_bytes = pack_audio(BytesIO(), merged_audio, sample_rate, media_type).getvalue() + pack_ms = max(0.0, (time.perf_counter() - pack_start) * 1000.0) + direct_profile = self._build_direct_scheduler_profile( + backend="scheduler_v1_direct", + request_start=request_start, + response_ready_at=time.perf_counter(), + audio_bytes=len(audio_bytes), + sample_rate=int(sample_rate), + segment_texts=segment_texts, + prepare_profiles=prepare_profiles, + worker_profiles=worker_profiles, + pack_ms=pack_ms, + response_overhead_ms=0.0, + ) + self._complete_request_state( + request_id, + dict(direct_profile, streaming_completed=False), + ) + return DirectTTSExecution( + media_type=media_type, + streaming=False, + audio_bytes=audio_bytes, + request_id=request_id, + ) + + def _run_legacy_direct_tts_blocking( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> DirectTTSExecution: + normalized_payload = normalized.to_payload() + request_id = normalized.request_id + media_type = normalized.media_type + request_start = time.perf_counter() + self._update_request_state( + request_id, + EngineStatus.ACTIVE_DECODE, + {"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason}, + ) + with self.direct_tts_lock: + tts_generator = self.tts.run(normalized_payload) + try: + sr, audio_data = next(tts_generator) + except Exception as exc: + self._fail_request_state(request_id, str(exc)) + raise + self._update_request_state( + request_id, + EngineStatus.FINALIZING, + {"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason}, + ) + pack_start = time.perf_counter() + packed_audio = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() + pack_ms = max(0.0, (time.perf_counter() - pack_start) * 1000.0) + self._complete_request_state( + request_id, + dict( + self._build_legacy_direct_profile( + backend=backend, + fallback_reason=fallback_reason, + request_start=request_start, + finished_at=time.perf_counter(), + sample_rate=int(sr), + audio_bytes=len(packed_audio), + pack_ms=pack_ms, + ), + streaming_completed=False, + ), + ) + return DirectTTSExecution( + media_type=media_type, + streaming=False, + audio_bytes=packed_audio, + request_id=request_id, + ) + + async def _run_direct_tts_via_legacy_backend( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> DirectTTSExecution: + if normalized.response_streaming: + return DirectTTSExecution( + media_type=normalized.media_type, + streaming=True, + audio_generator=self._iter_legacy_direct_tts_bytes( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ), + request_id=normalized.request_id, + ) + return await asyncio.to_thread( + self._run_legacy_direct_tts_blocking, + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + + async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution: + normalized = self._normalize_engine_request( + req, + request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"), + normalize_streaming=True, + error_prefix="", + ) + request_id = normalized.request_id + media_type = normalized.media_type + backend, fallback_reason = self._select_direct_backend(normalized) + self._register_request_state( + request_id=request_id, + api_mode="tts", + backend=backend, + media_type=media_type, + response_streaming=bool(normalized.response_streaming), + deadline_ts=( + time.perf_counter() + float(normalized.timeout_sec) + if normalized.timeout_sec is not None + else None + ), + meta=self._build_request_meta(normalized.to_payload()), + ) + self._update_request_state( + request_id, + EngineStatus.VALIDATED, + { + "request_source": "direct_tts", + "selected_backend": backend, + "fallback_reason": fallback_reason, + }, + ) + if backend == "scheduler_v1_direct": + try: + return await self._run_direct_tts_via_scheduler(normalized) + except Exception as exc: + self._fail_request_state(request_id, str(exc)) + raise + return await self._run_direct_tts_via_legacy_backend( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + + def run_direct_tts(self, req: dict) -> DirectTTSExecution: + normalized = self._normalize_engine_request( + req, + request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"), + normalize_streaming=True, + error_prefix="", + ) + request_id = normalized.request_id + media_type = normalized.media_type + backend, fallback_reason = self._select_direct_backend(normalized) + if not self._has_active_request(request_id): + self._register_request_state( + request_id=request_id, + api_mode="tts", + backend=backend, + media_type=media_type, + response_streaming=bool(normalized.response_streaming), + meta=self._build_request_meta(normalized.to_payload()), + ) + self._update_request_state( + request_id, + EngineStatus.VALIDATED, + { + "request_source": "direct_tts", + "selected_backend": backend, + "fallback_reason": fallback_reason, + }, + ) + if backend != "scheduler_v1_direct": + if normalized.response_streaming: + return DirectTTSExecution( + media_type=media_type, + streaming=True, + audio_generator=self._iter_legacy_direct_tts_bytes( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ), + request_id=request_id, + ) + return self._run_legacy_direct_tts_blocking( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + if normalized.response_streaming: + return DirectTTSExecution( + media_type=media_type, + streaming=True, + audio_generator=self._iter_legacy_direct_tts_bytes( + normalized, + backend="legacy_direct_sync_compat", + fallback_reason="sync_direct_compat", + ), + request_id=request_id, + ) + return self._run_legacy_direct_tts_blocking( + normalized, + backend="legacy_direct_sync_compat", + fallback_reason="sync_direct_compat", + ) + + def build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]: + specs: List[SchedulerRequestSpec] = [] + for index, payload in enumerate(request_items): + normalized = self._normalize_engine_request( + payload, + request_id=str(payload.get("request_id") or f"req_{index:03d}"), + error_prefix=f"request[{index}] 参数非法: ", + ) + specs.append(normalized.to_scheduler_spec()) + return specs + + def build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec: + normalized = self._normalize_engine_request( + payload, + request_id=( + payload.request_id + if isinstance(payload, NormalizedEngineRequest) + else str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}") + ), + ) + return normalized.to_scheduler_spec() + + @staticmethod + def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: + return [ + { + "request_id": state.request_id, + "ready_step": int(state.ready_step), + "ref_audio_path": str(state.ref_audio_path), + "prompt_semantic_len": int(state.prompt_semantic.shape[0]), + "all_phone_len": int(state.all_phones.shape[0]), + "bert_len": int(state.all_bert_features.shape[-1]), + "norm_text": state.norm_text, + } + for state in states + ] + + @staticmethod + def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: + return [ + { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + } + for item in items + ] + + async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution: + request_start = time.perf_counter() + set_scheduler_seed(seed) + specs = self.build_scheduler_request_specs(request_items) + request_ids = [spec.request_id for spec in specs] + for spec in specs: + self._register_request_state( + request_id=spec.request_id, + api_mode="scheduler_debug", + backend="scheduler_debug", + media_type="wav", + response_streaming=False, + meta={ + "text_len": len(spec.text), + "prompt_text_len": len(spec.prompt_text), + "text_lang": spec.text_lang, + "prompt_lang": spec.prompt_lang, + "ref_audio_path": str(spec.ref_audio_path), + "ready_step": int(spec.ready_step), + }, + ) + self._update_request_state(spec.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_debug"}) + self._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, None) + prepare_started_at = time.perf_counter() + try: + states = await self.scheduler_worker.prepare_states_batch_async(specs) + except Exception as exc: + for request_id in request_ids: + self._fail_request_state(request_id, str(exc)) + raise + prepare_finished_at = time.perf_counter() + prepare_batch_wall_ms = max(0.0, (prepare_finished_at - prepare_started_at) * 1000.0) + for state in states: + self._update_request_state( + state.request_id, + EngineStatus.ACTIVE_DECODE, + { + "prepare_profile": dict(state.prepare_profile), + "norm_text": state.norm_text, + "norm_prompt_text": state.norm_prompt_text, + }, + ) + decode_started_at = time.perf_counter() + try: + finished = run_scheduler_continuous(self.tts.t2s_model.model, states, max_steps=int(max_steps)) + except Exception as exc: + for request_id in request_ids: + self._fail_request_state(request_id, str(exc)) + raise + decode_finished_at = time.perf_counter() + decode_batch_wall_ms = max(0.0, (decode_finished_at - decode_started_at) * 1000.0) + request_total_ms = max(0.0, (decode_finished_at - request_start) * 1000.0) + finished_map = {item.request_id: item for item in finished} + request_profiles: List[Dict[str, Any]] = [] + for state in states: + item = finished_map.get(state.request_id) + if item is None: + self._fail_request_state(state.request_id, "scheduler_debug finished without result") + continue + request_profile = self._build_scheduler_debug_request_profile( + state=state, + item=item, + batch_request_count=len(states), + prepare_batch_wall_ms=prepare_batch_wall_ms, + decode_batch_wall_ms=decode_batch_wall_ms, + batch_request_total_ms=request_total_ms, + ) + request_profiles.append( + { + "request_id": state.request_id, + "profile": dict(request_profile), + } + ) + self._complete_request_state( + state.request_id, + dict(request_profile), + ) + return SchedulerDebugExecution( + payload={ + "message": "success", + "request_count": len(states), + "max_steps": int(max_steps), + "batch_profile": self._build_scheduler_debug_batch_profile( + request_count=len(states), + max_steps=int(max_steps), + prepare_batch_wall_ms=prepare_batch_wall_ms, + decode_batch_wall_ms=decode_batch_wall_ms, + request_total_ms=request_total_ms, + finished_items=finished, + ), + "requests": self.summarize_scheduler_states(states), + "finished": self.summarize_scheduler_finished(finished), + "request_profiles": request_profiles, + "request_traces": self._collect_request_summaries(request_ids), + } + ) + + async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution: + request_start = time.perf_counter() + prepare_start = request_start + normalized = self._normalize_engine_request( + payload, + request_id=str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}"), + ) + spec = self.build_scheduler_submit_spec(normalized) + deadline_ts = None + timeout_sec = normalized.timeout_sec + if timeout_sec is not None: + try: + deadline_ts = request_start + float(timeout_sec) + except Exception: + deadline_ts = None + self._register_request_state( + request_id=spec.request_id, + api_mode="scheduler_submit", + backend="scheduler_v1", + media_type=normalized.media_type, + response_streaming=False, + deadline_ts=deadline_ts, + meta=self._build_request_meta(normalized.to_payload()), + ) + self._update_request_state(spec.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_submit"}) + spec_ready_at = time.perf_counter() + prepare_spec_build_ms = max(0.0, (spec_ready_at - prepare_start) * 1000.0) + self._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, {"prepare_spec_build_ms": prepare_spec_build_ms}) + try: + state, prepare_exec_started_at, prepare_exec_finished_at = await self._prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=spec_ready_at, + engine_request_id=spec.request_id, + ) + except Exception as exc: + self._fail_request_state(spec.request_id, str(exc)) + raise + prepare_wall_ms = max(0.0, (prepare_exec_finished_at - spec_ready_at) * 1000.0) + prepare_executor_queue_ms = max(0.0, (prepare_exec_started_at - spec_ready_at) * 1000.0) + prepare_executor_run_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0) + prepare_profile = dict(state.prepare_profile) + prepare_profile_total_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms)) + prepare_profile_wall_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms)) + prepare_other_ms = max(0.0, prepare_wall_ms - prepare_spec_build_ms - prepare_executor_queue_ms - prepare_executor_run_ms) + self._update_request_state( + spec.request_id, + EngineStatus.READY_FOR_PREFILL, + { + "prepare_wall_ms": prepare_wall_ms, + "prepare_profile_total_ms": prepare_profile_total_ms, + "prepare_profile": prepare_profile, + }, + ) + api_after_prepare_start = time.perf_counter() + loop = asyncio.get_running_loop() + done_future = loop.create_future() + await self._enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=float(normalized.speed_factor), + sample_steps=int(normalized.sample_steps), + media_type=normalized.media_type, + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=loop, + done_future=done_future, + engine_request_id=spec.request_id, + timeout_sec=normalized.timeout_sec, + ) + api_after_prepare_ms = max(0.0, (time.perf_counter() - api_after_prepare_start) * 1000.0) + try: + job = await asyncio.wait_for(done_future, timeout=float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0)) + except Exception as exc: + self._fail_request_state(spec.request_id, str(exc)) + raise + wait_return_at = time.perf_counter() + if job.error is not None: + raise RuntimeError(job.error) + if job.audio_data is None or job.sample_rate is None or job.result is None: + self._fail_request_state(spec.request_id, f"{job.request_id} finished without audio result") + raise RuntimeError(f"{job.request_id} finished without audio result") + pack_start = time.perf_counter() + audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue() + pack_end = time.perf_counter() + pack_ms = (pack_end - pack_start) * 1000.0 + api_wait_result_ms = 0.0 + if job.result_ready_time is not None: + api_wait_result_ms = max(0.0, (wait_return_at - job.result_ready_time) * 1000.0) + response_ready_at = time.perf_counter() + response_overhead_ms = max(0.0, (response_ready_at - pack_end) * 1000.0) + submit_profile = self._build_scheduler_submit_profile( + backend="scheduler_v1", + request_start=request_start, + response_ready_at=response_ready_at, + audio_bytes=len(audio_data), + sample_rate=int(job.sample_rate), + prepare_spec_build_ms=prepare_spec_build_ms, + prepare_wall_ms=prepare_wall_ms, + prepare_executor_queue_ms=prepare_executor_queue_ms, + prepare_executor_run_ms=prepare_executor_run_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + prepare_profile_wall_ms=prepare_profile_wall_ms, + prepare_other_ms=prepare_other_ms, + engine_policy_wait_ms=float(job.result.get("engine_policy_wait_ms", 0.0)), + api_after_prepare_ms=api_after_prepare_ms, + api_wait_result_ms=api_wait_result_ms, + pack_ms=pack_ms, + response_overhead_ms=response_overhead_ms, + worker_profile=dict(job.result or {}), + ) + headers = self._build_scheduler_submit_headers( + request_id=job.request_id, + media_type=job.media_type, + sample_rate=int(job.sample_rate), + profile=submit_profile, + ) + self._merge_request_state_profile( + spec.request_id, + dict(submit_profile, response_headers_emitted=True), + ) + return SchedulerSubmitExecution(audio_bytes=audio_data, media_type=f"audio/{job.media_type}", headers=headers) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_audio.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_audio.py new file mode 100644 index 00000000..5c3bd7a5 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_audio.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +import subprocess +import threading +import wave +from io import BytesIO + +import numpy as np +import soundfile as sf +import torch + + +def set_scheduler_seed(seed: int): + if seed in ["", None]: + return + seed = int(seed) + if seed < 0: + return + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int): + def handle_pack_ogg(): + with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: + audio_file.write(data) + + stack_size = 4096 * 4096 + try: + threading.stack_size(stack_size) + pack_ogg_thread = threading.Thread(target=handle_pack_ogg) + pack_ogg_thread.start() + pack_ogg_thread.join() + except (RuntimeError, ValueError): + handle_pack_ogg() + return io_buffer + + +def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int): + io_buffer.write(data.tobytes()) + return io_buffer + + +def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int): + io_buffer = BytesIO() + sf.write(io_buffer, data, rate, format="wav") + return io_buffer + + +def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int): + process = subprocess.Popen( + [ + "ffmpeg", + "-f", + "s16le", + "-ar", + str(rate), + "-ac", + "1", + "-i", + "pipe:0", + "-c:a", + "aac", + "-b:a", + "192k", + "-vn", + "-f", + "adts", + "pipe:1", + ], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + out, _ = process.communicate(input=data.tobytes()) + io_buffer.write(out) + return io_buffer + + +def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str): + if media_type == "ogg": + io_buffer = pack_ogg(io_buffer, data, rate) + elif media_type == "aac": + io_buffer = pack_aac(io_buffer, data, rate) + elif media_type == "wav": + io_buffer = pack_wav(io_buffer, data, rate) + else: + io_buffer = pack_raw(io_buffer, data, rate) + io_buffer.seek(0) + return io_buffer + + +def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000): + wav_buf = BytesIO() + with wave.open(wav_buf, "wb") as vfout: + vfout.setnchannels(channels) + vfout.setsampwidth(sample_width) + vfout.setframerate(sample_rate) + vfout.writeframes(frame_input) + wav_buf.seek(0) + return wav_buf.read() + + diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge.py new file mode 100644 index 00000000..536efbc5 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge.py @@ -0,0 +1,310 @@ +from __future__ import annotations + +import asyncio +import time +from typing import Any, Dict, List, Optional + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDecodeRuntimeOwner, EngineDispatchTask, EngineRequestState, EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob + + +class EngineBridgeFacade: + def __init__(self, owner: Any) -> None: + self.owner = owner + + @property + def request_registry(self): + return self.owner.request_registry + + @property + def engine_prepare_queue_owner(self): + return self.owner.engine_prepare_queue_owner + + @property + def engine_finalize_queue_owner(self): + return self.owner.engine_finalize_queue_owner + + @property + def engine_dispatch_queue_owner(self): + return self.owner.engine_dispatch_queue_owner + + @property + def engine_decode_runtime_owner(self): + return self.owner.engine_decode_runtime_owner + + @property + def engine_job_registry(self): + return self.owner.engine_job_registry + + @property + def scheduler_worker(self): + return self.owner.scheduler_worker + + @property + def engine_stage_coordinator(self): + return self.owner.engine_stage_coordinator + + @property + def engine_policy_arbiter(self): + return self.owner.engine_policy_arbiter + + def _register_request_state( + self, + request_id: str, + api_mode: str, + backend: str, + media_type: str, + response_streaming: bool, + deadline_ts: float | None = None, + meta: Optional[Dict[str, Any]] = None, + ) -> EngineRequestState: + return self.request_registry.register( + request_id=request_id, + api_mode=api_mode, + backend=backend, + media_type=media_type, + response_streaming=response_streaming, + deadline_ts=deadline_ts, + meta=meta, + ) + + def _update_request_state( + self, + request_id: str, + status: str, + extra: Optional[Dict[str, Any]] = None, + ) -> None: + self.request_registry.update(request_id, status, extra) + + def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.request_registry.merge_profile(request_id, extra) + + def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.request_registry.complete(request_id, extra) + + def _fail_request_state(self, request_id: str, error: str) -> None: + self.request_registry.fail(request_id, error) + + def _snapshot_request_registry(self) -> Dict[str, Any]: + return self.request_registry.snapshot() + + def _snapshot_engine_prepare_state(self) -> Dict[str, Any]: + return self.engine_prepare_queue_owner.snapshot(max_request_ids=16) + + def _snapshot_engine_finalize_state(self) -> Dict[str, Any]: + return self.engine_finalize_queue_owner.snapshot(max_request_ids=16) + + def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]: + return self.engine_dispatch_queue_owner.snapshot( + max_request_ids=16, + extra={"last_policy_snapshot": dict(self.owner.engine_dispatch_last_snapshot or {})}, + ) + + def _register_engine_job(self, job: SchedulerPendingJob) -> None: + self.engine_job_registry.register(job, keep_job=True) + + def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None: + return self.engine_job_registry.get(request_id) + + def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None: + return self.engine_job_registry.pop(request_id) + + def _snapshot_engine_job_registry(self) -> Dict[str, Any]: + return self.engine_job_registry.snapshot(max_request_ids=32) + + def _is_engine_drained(self) -> bool: + prepare_empty = self.engine_prepare_queue_owner.is_drained() + dispatch_empty = self.engine_dispatch_queue_owner.is_drained() + finalize_empty = self.engine_finalize_queue_owner.is_drained() + decode_pending_empty = not self.engine_decode_runtime_owner.has_pending_jobs() + job_empty = self.engine_job_registry.is_empty() + worker_state = self.scheduler_worker.snapshot() + return bool( + prepare_empty + and dispatch_empty + and finalize_empty + and decode_pending_empty + and job_empty + and self.engine_decode_runtime_owner.get_active_batch() is None + and int(worker_state.get("prepare_inflight", 0)) <= 0 + and int(worker_state.get("finalize_inflight", 0)) <= 0 + and int(worker_state.get("finalize_pending", 0)) <= 0 + ) + + def _record_engine_job_done(self, request_id: str) -> None: + self.engine_job_registry.mark_finished_and_remove(request_id) + self.scheduler_worker.record_external_job_done(request_id) + + def _complete_engine_job( + self, + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + audio_data: np.ndarray, + ) -> None: + completion_bridge = self.scheduler_worker.completion_bridge + completion_bridge.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data) + completion_bridge.complete_job( + job, + runtime_request_id=job.engine_request_id, + runtime_extra=completion_bridge.build_runtime_complete_payload(job, item, sample_rate=sample_rate), + on_job_finished=lambda rid=item.request_id: self._record_engine_job_done(rid), + ) + + def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None: + if not request_ids: + return + completion_bridge = self.scheduler_worker.completion_bridge + for request_id in request_ids: + job = self._get_engine_job(request_id) + if job is None: + continue + completion_bridge.fail_job( + job, + error=error, + on_job_finished=lambda rid=request_id: self._record_engine_job_done(rid), + ) + + def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + for job in jobs: + job.prefill_ms += delta_ms + + def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + for request_id in request_ids: + job = self._get_engine_job(request_id) + if job is not None: + job.merge_ms += delta_ms + + def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + activate_request_ids: List[str] = [] + for request_id in request_ids: + job = self._get_engine_job(request_id) + if job is None: + continue + if job.decode_steps == 0: + activate_request_ids.append(job.engine_request_id) + job.decode_ms += delta_ms + job.decode_steps += 1 + for engine_request_id in activate_request_ids: + self._update_request_state(engine_request_id, EngineStatus.ACTIVE_DECODE, None) + + def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None: + if not items: + return + enqueued_at = time.perf_counter() + tasks = [SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at) for item in items] + self._enqueue_worker_finished_for_finalize(tasks) + + def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]: + return self.engine_decode_runtime_owner.snapshot_pending_queue_state() + + @staticmethod + def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: + return EngineDecodeRuntimeOwner.summarize_active_batch(active_batch) + + def _refresh_engine_decode_runtime_state(self, last_event: str) -> None: + self.engine_decode_runtime_owner.refresh_state(last_event) + + def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None: + if not snapshot: + return + if self.scheduler_worker.is_engine_decode_control_enabled(): + return + self.engine_decode_runtime_owner.update_from_worker_snapshot(snapshot) + + def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]: + return self.engine_decode_runtime_owner.snapshot_state() + + def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]: + return self.engine_policy_arbiter.snapshot_state() + + def _notify_engine_arbiter(self) -> None: + self.engine_policy_arbiter.notify() + + def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None: + self.engine_stage_coordinator.decode_runtime_owner.enqueue_pending_job(job) + self._notify_engine_arbiter() + + def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.engine_stage_coordinator.decode_runtime_owner.take_pending_jobs_nonblocking(wait_for_batch) + + def _peek_queue_age_ms(self, queue_name: str) -> float: + return self.engine_stage_coordinator.peek_queue_age_ms(queue_name) + + def _engine_has_pending_work(self) -> bool: + return self.engine_stage_coordinator.has_pending_work() + + async def _prepare_state_via_engine_gpu_queue( + self, + *, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + engine_request_id: str | None, + ) -> tuple[T2SRequestState, float, float]: + return await self.engine_stage_coordinator.prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=prepare_submit_at, + engine_request_id=engine_request_id, + ) + + def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: + self.engine_stage_coordinator.enqueue_worker_finished_for_finalize(tasks) + + def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: + return self.engine_stage_coordinator.take_engine_finalize_batch_nonblocking() + + async def _enqueue_prepared_state_for_dispatch( + self, + *, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None, + done_future: asyncio.Future | None, + engine_request_id: str | None, + timeout_sec: float | None, + ) -> EngineDispatchTask: + return await self.engine_stage_coordinator.enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=speed_factor, + sample_steps=sample_steps, + media_type=media_type, + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=done_loop, + done_future=done_future, + engine_request_id=engine_request_id, + timeout_sec=timeout_sec, + ) + + def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: + self.engine_policy_arbiter.mark_tick(stage=stage, reason=reason, policy_allowed=policy_allowed) + + def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: + stage, reason, policy_snapshot, worker_state = self.engine_policy_arbiter.select_stage() + self.owner.engine_dispatch_last_snapshot = dict(policy_snapshot) + return stage, reason, policy_snapshot, worker_state + + def _run_engine_prepare_once(self) -> bool: + return self.engine_stage_coordinator.run_engine_prepare_once() + + def _run_engine_finalize_once(self) -> bool: + return self.engine_stage_coordinator.run_engine_finalize_once() + + def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: + return self.engine_stage_coordinator.run_engine_dispatch_once(policy_snapshot, worker_state) + + def _run_engine_decode_runtime_once(self) -> bool: + return self.engine_stage_coordinator.run_engine_decode_runtime_once() + + def _run_engine_arbiter_loop(self) -> None: + self.engine_stage_coordinator.run_engine_arbiter_loop() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py new file mode 100644 index 00000000..45178b1f --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +import os +import threading +from typing import Any + +from GPT_SoVITS.TTS_infer_pack.unified_engine_api import EngineApiFacade +from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge import EngineBridgeFacade +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import ( + EngineArbiterConfig, + EngineDecodeRuntimeOwner, + EnginePolicyArbiterController, + EnginePolicyConfig, + EngineRequestRegistry, + EngineTaskQueueOwner, + ModelRegistry, + ReferenceRegistry, + RuntimeStateCallbacks, + SchedulerJobRegistry, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime import EngineRuntimeFacade +from GPT_SoVITS.TTS_infer_pack.unified_engine_stage import EngineStageCoordinator +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker + + +class EngineCompositionBuilder: + def __init__(self, owner: Any) -> None: + self.owner = owner + + def build(self, *, max_steps: int, micro_batch_wait_ms: int) -> None: + self._init_registries_and_locks() + self._init_worker(max_steps=max_steps, micro_batch_wait_ms=micro_batch_wait_ms) + self._init_policy_configs(micro_batch_wait_ms=micro_batch_wait_ms) + self._init_runtime_owners() + self._init_stage_coordinator() + self._init_arbiter() + self._init_facades() + self._start_arbiter_thread() + + def _init_registries_and_locks(self) -> None: + owner = self.owner + owner.reference_registry = ReferenceRegistry() + owner.model_registry = ModelRegistry( + t2s_weights_path=str(owner.tts.configs.t2s_weights_path), + vits_weights_path=str(owner.tts.configs.vits_weights_path), + ) + owner.request_registry = EngineRequestRegistry( + recent_limit=max(1, int(os.environ.get("GPTSOVITS_ENGINE_RECENT_REQUEST_LIMIT", "64"))) + ) + owner.engine_job_registry = SchedulerJobRegistry(threading.Lock()) + owner.direct_tts_lock = threading.RLock() + owner.management_lock = threading.RLock() + owner.engine_dispatch_last_snapshot = {} + + def _init_worker(self, *, max_steps: int, micro_batch_wait_ms: int) -> None: + owner = self.owner + owner.scheduler_worker = UnifiedSchedulerWorker( + owner.tts, + max_steps=max_steps, + micro_batch_wait_ms=micro_batch_wait_ms, + runtime_callbacks=RuntimeStateCallbacks( + update=owner._update_request_state, + complete=owner._complete_request_state, + fail=owner._fail_request_state, + decode_runtime_update=owner._update_engine_decode_runtime_state, + ), + external_finalize_submit=owner._enqueue_worker_finished_for_finalize, + ) + + def _init_policy_configs(self, *, micro_batch_wait_ms: int) -> None: + owner = self.owner + worker_capacity_limits = owner.scheduler_worker.get_capacity_limits() + prepare_max_inflight = int(owner.scheduler_worker.get_prepare_max_inflight()) + owner.engine_policy_config = EnginePolicyConfig( + enabled=owner._env_flag("GPTSOVITS_ENGINE_POLICY_ENABLE", True), + poll_wait_ms=max(1.0, owner._env_float("GPTSOVITS_ENGINE_POLICY_POLL_WAIT_MS", float(micro_batch_wait_ms))), + decode_backlog_soft_max=max( + 0, + owner._env_int( + "GPTSOVITS_ENGINE_POLICY_DECODE_BACKLOG_SOFT_MAX", + int(worker_capacity_limits["decode_backlog_max"]), + ), + ), + finalize_pending_soft_max=max( + 0, + owner._env_int( + "GPTSOVITS_ENGINE_POLICY_FINALIZE_PENDING_SOFT_MAX", + int(worker_capacity_limits["finalize_pending_max"]), + ), + ), + prepare_inflight_soft_max=max( + 0, + owner._env_int("GPTSOVITS_ENGINE_POLICY_PREPARE_INFLIGHT_SOFT_MAX", prepare_max_inflight), + ), + active_decode_soft_max=max(0, owner._env_int("GPTSOVITS_ENGINE_POLICY_ACTIVE_DECODE_SOFT_MAX", 0)), + ready_for_prefill_soft_max=max(0, owner._env_int("GPTSOVITS_ENGINE_POLICY_READY_FOR_PREFILL_SOFT_MAX", 0)), + active_request_soft_max=max(0, owner._env_int("GPTSOVITS_ENGINE_POLICY_ACTIVE_REQUEST_SOFT_MAX", 0)), + ) + owner.engine_arbiter_config = EngineArbiterConfig( + poll_wait_ms=max(1.0, owner._env_float("GPTSOVITS_ENGINE_ARBITER_POLL_WAIT_MS", float(micro_batch_wait_ms))), + decode_burst=max(1, owner._env_int("GPTSOVITS_ENGINE_ARBITER_DECODE_BURST", 4)), + prepare_aging_ms=max(0.0, owner._env_float("GPTSOVITS_ENGINE_ARBITER_PREPARE_AGING_MS", 10.0)), + finalize_aging_ms=max(0.0, owner._env_float("GPTSOVITS_ENGINE_ARBITER_FINALIZE_AGING_MS", 10.0)), + ) + + def _init_runtime_owners(self) -> None: + owner = self.owner + owner.engine_decode_runtime_owner = EngineDecodeRuntimeOwner( + get_decode_runtime_counters=owner.scheduler_worker.get_decode_runtime_counters, + get_micro_batch_wait_s=owner.scheduler_worker.get_micro_batch_wait_s, + ) + owner.engine_prepare_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") + owner.engine_finalize_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") + owner.engine_dispatch_queue_owner = EngineTaskQueueOwner(completion_key="total_dispatched") + + def _init_stage_coordinator(self) -> None: + owner = self.owner + owner.engine_stage_coordinator = EngineStageCoordinator( + tts=owner.tts, + scheduler_worker=owner.scheduler_worker, + prepare_queue_owner=owner.engine_prepare_queue_owner, + finalize_queue_owner=owner.engine_finalize_queue_owner, + dispatch_queue_owner=owner.engine_dispatch_queue_owner, + decode_runtime_owner=owner.engine_decode_runtime_owner, + update_request_state=owner._update_request_state, + merge_request_state_profile=owner._merge_request_state_profile, + fail_request_state=owner._fail_request_state, + get_engine_job=owner._get_engine_job, + register_engine_job=owner._register_engine_job, + fail_engine_jobs=owner._fail_engine_jobs, + complete_engine_job=owner._complete_engine_job, + add_engine_prefill_time=owner._add_engine_prefill_time, + add_engine_merge_time=owner._add_engine_merge_time, + add_engine_decode_time=owner._add_engine_decode_time, + enqueue_engine_finished_items=owner._enqueue_engine_finished_items, + snapshot_engine_dispatch_state=owner._snapshot_engine_dispatch_state, + snapshot_engine_decode_runtime_state=owner._snapshot_engine_decode_runtime_state, + ) + + def _init_arbiter(self) -> None: + owner = self.owner + owner.engine_policy_arbiter = EnginePolicyArbiterController( + policy_config=owner.engine_policy_config, + arbiter_config=owner.engine_arbiter_config, + snapshot_request_registry=owner._snapshot_request_registry, + get_worker_state=owner.get_scheduler_state, + snapshot_prepare_state=owner._snapshot_engine_prepare_state, + snapshot_finalize_state=owner._snapshot_engine_finalize_state, + snapshot_dispatch_state=owner._snapshot_engine_dispatch_state, + snapshot_decode_runtime_state=owner._snapshot_engine_decode_runtime_state, + snapshot_job_registry=owner._snapshot_engine_job_registry, + peek_queue_age_ms=owner.engine_stage_coordinator.peek_queue_age_ms, + merge_request_state_profile=owner._merge_request_state_profile, + ) + owner.engine_stage_coordinator.bind_arbiter( + notify_arbiter=owner._notify_engine_arbiter, + select_stage=owner._select_engine_stage, + mark_arbiter_tick=lambda stage, reason, policy_allowed: owner._mark_arbiter_tick( + stage=stage, + reason=reason, + policy_allowed=policy_allowed, + ), + wait_arbiter=owner.engine_policy_arbiter.wait, + ) + + def _init_facades(self) -> None: + owner = self.owner + owner.bridge_facade = EngineBridgeFacade(owner) + owner.api_facade = EngineApiFacade(owner) + owner.runtime_facade = EngineRuntimeFacade(owner) + + def _start_arbiter_thread(self) -> None: + owner = self.owner + owner.engine_arbiter_thread = threading.Thread( + target=owner._run_engine_arbiter_loop, + name="unified-engine-arbiter", + daemon=True, + ) + owner.engine_arbiter_thread.start() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_components.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_components.py new file mode 100644 index 00000000..3a124f4e --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_components.py @@ -0,0 +1,1150 @@ +from __future__ import annotations + +import asyncio +import os +import threading +import time +import uuid +from collections import deque +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Deque, Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState + + +@dataclass +class RuntimeControlCallbacks: + restart: Callable[[], None] | None = None + exit: Callable[[], None] | None = None + + +@dataclass +class DefaultReferenceState: + ref_audio_path: str | None = None + updated_at: float = 0.0 + + +class ReferenceRegistry: + def __init__(self) -> None: + self._lock = threading.Lock() + self._state = DefaultReferenceState() + + def set_default(self, ref_audio_path: str) -> DefaultReferenceState: + with self._lock: + self._state = DefaultReferenceState(ref_audio_path=str(ref_audio_path), updated_at=time.time()) + return self._state + + def clear(self) -> DefaultReferenceState: + with self._lock: + self._state = DefaultReferenceState() + return self._state + + def get_default(self) -> DefaultReferenceState: + with self._lock: + return DefaultReferenceState( + ref_audio_path=self._state.ref_audio_path, + updated_at=self._state.updated_at, + ) + + +@dataclass +class ModelRegistryState: + t2s_weights_path: str + vits_weights_path: str + generation: int = 0 + t2s_generation: int = 0 + vits_generation: int = 0 + updated_at: float = field(default_factory=time.time) + + +class ModelRegistry: + def __init__(self, t2s_weights_path: str, vits_weights_path: str) -> None: + self._lock = threading.Lock() + self._state = ModelRegistryState( + t2s_weights_path=str(t2s_weights_path), + vits_weights_path=str(vits_weights_path), + ) + + def snapshot(self) -> ModelRegistryState: + with self._lock: + return ModelRegistryState( + t2s_weights_path=self._state.t2s_weights_path, + vits_weights_path=self._state.vits_weights_path, + generation=self._state.generation, + t2s_generation=self._state.t2s_generation, + vits_generation=self._state.vits_generation, + updated_at=self._state.updated_at, + ) + + def mark_t2s_reload(self, weights_path: str) -> ModelRegistryState: + with self._lock: + self._state.t2s_weights_path = str(weights_path) + self._state.generation += 1 + self._state.t2s_generation += 1 + self._state.updated_at = time.time() + return ModelRegistryState( + t2s_weights_path=self._state.t2s_weights_path, + vits_weights_path=self._state.vits_weights_path, + generation=self._state.generation, + t2s_generation=self._state.t2s_generation, + vits_generation=self._state.vits_generation, + updated_at=self._state.updated_at, + ) + + def mark_vits_reload(self, weights_path: str) -> ModelRegistryState: + with self._lock: + self._state.vits_weights_path = str(weights_path) + self._state.generation += 1 + self._state.vits_generation += 1 + self._state.updated_at = time.time() + return ModelRegistryState( + t2s_weights_path=self._state.t2s_weights_path, + vits_weights_path=self._state.vits_weights_path, + generation=self._state.generation, + t2s_generation=self._state.t2s_generation, + vits_generation=self._state.vits_generation, + updated_at=self._state.updated_at, + ) + + +@dataclass +class DirectTTSExecution: + media_type: str + streaming: bool + audio_generator: Optional[Generator[bytes, None, None]] = None + audio_bytes: Optional[bytes] = None + request_id: Optional[str] = None + + +@dataclass +class NormalizedEngineRequest: + request_id: str + text: str + text_lang: str + ref_audio_path: str + prompt_lang: str + prompt_text: str = "" + aux_ref_audio_paths: List[str] | None = None + top_k: int = 15 + top_p: float = 1.0 + temperature: float = 1.0 + repetition_penalty: float = 1.35 + early_stop_num: int = -1 + ready_step: int = 0 + text_split_method: str = "cut5" + batch_size: int = 1 + batch_threshold: float = 0.75 + split_bucket: bool = False + speed_factor: float = 1.0 + fragment_interval: float = 0.3 + seed: int = -1 + media_type: str = "wav" + streaming_mode: bool | int = False + return_fragment: bool = False + fixed_length_chunk: bool = False + response_streaming: bool = False + parallel_infer: bool = False + sample_steps: int = 32 + super_sampling: bool = False + overlap_length: int = 2 + min_chunk_length: int = 16 + timeout_sec: float | None = None + + def to_payload(self) -> Dict[str, Any]: + return { + "request_id": self.request_id, + "text": self.text, + "text_lang": self.text_lang, + "ref_audio_path": self.ref_audio_path, + "aux_ref_audio_paths": list(self.aux_ref_audio_paths) if self.aux_ref_audio_paths else None, + "prompt_text": self.prompt_text, + "prompt_lang": self.prompt_lang, + "top_k": self.top_k, + "top_p": self.top_p, + "temperature": self.temperature, + "text_split_method": self.text_split_method, + "batch_size": self.batch_size, + "batch_threshold": self.batch_threshold, + "speed_factor": self.speed_factor, + "split_bucket": self.split_bucket, + "fragment_interval": self.fragment_interval, + "seed": self.seed, + "media_type": self.media_type, + "streaming_mode": self.streaming_mode, + "return_fragment": self.return_fragment, + "fixed_length_chunk": self.fixed_length_chunk, + "response_streaming": self.response_streaming, + "parallel_infer": self.parallel_infer, + "repetition_penalty": self.repetition_penalty, + "sample_steps": self.sample_steps, + "super_sampling": self.super_sampling, + "overlap_length": self.overlap_length, + "min_chunk_length": self.min_chunk_length, + "early_stop_num": self.early_stop_num, + "ready_step": self.ready_step, + "timeout_sec": self.timeout_sec, + } + + def to_scheduler_spec(self) -> SchedulerRequestSpec: + return SchedulerRequestSpec( + request_id=self.request_id, + ref_audio_path=Path(self.ref_audio_path), + prompt_text=self.prompt_text, + prompt_lang=self.prompt_lang, + text=self.text, + text_lang=self.text_lang, + top_k=self.top_k, + top_p=self.top_p, + temperature=self.temperature, + repetition_penalty=self.repetition_penalty, + early_stop_num=self.early_stop_num, + ready_step=self.ready_step, + ) + + +@dataclass +class SchedulerDebugExecution: + payload: Dict[str, Any] + + +@dataclass +class SchedulerSubmitExecution: + audio_bytes: bytes + media_type: str + headers: Dict[str, str] + + +@dataclass +class EnginePolicyConfig: + enabled: bool = True + poll_wait_ms: float = 5.0 + decode_backlog_soft_max: int = 0 + finalize_pending_soft_max: int = 0 + prepare_inflight_soft_max: int = 0 + active_decode_soft_max: int = 0 + ready_for_prefill_soft_max: int = 0 + active_request_soft_max: int = 0 + + def to_dict(self) -> Dict[str, Any]: + return { + "enabled": bool(self.enabled), + "poll_wait_ms": float(self.poll_wait_ms), + "decode_backlog_soft_max": int(self.decode_backlog_soft_max), + "finalize_pending_soft_max": int(self.finalize_pending_soft_max), + "prepare_inflight_soft_max": int(self.prepare_inflight_soft_max), + "active_decode_soft_max": int(self.active_decode_soft_max), + "ready_for_prefill_soft_max": int(self.ready_for_prefill_soft_max), + "active_request_soft_max": int(self.active_request_soft_max), + } + + +@dataclass +class EngineArbiterConfig: + poll_wait_ms: float = 5.0 + decode_burst: int = 4 + prepare_aging_ms: float = 10.0 + finalize_aging_ms: float = 10.0 + + def to_dict(self) -> Dict[str, Any]: + return { + "poll_wait_ms": float(self.poll_wait_ms), + "decode_burst": int(self.decode_burst), + "prepare_aging_ms": float(self.prepare_aging_ms), + "finalize_aging_ms": float(self.finalize_aging_ms), + } + + +class EngineStatus: + NEW = "NEW" + QUEUED = "QUEUED" + VALIDATED = "VALIDATED" + CPU_PREPARING = "CPU_PREPARING" + GPU_PREPARING = "GPU_PREPARING" + READY_FOR_PREFILL = "READY_FOR_PREFILL" + ACTIVE_DECODE = "ACTIVE_DECODE" + READY_FOR_FINALIZE = "READY_FOR_FINALIZE" + FINALIZING = "FINALIZING" + STREAMING = "STREAMING" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + + +@dataclass +class EngineRequestState: + request_id: str + api_mode: str + backend: str + media_type: str + response_streaming: bool + submit_ts: float + deadline_ts: float | None = None + status: str = EngineStatus.NEW + updated_ts: float = 0.0 + error: str | None = None + finish_reason: str | None = None + meta: Dict[str, Any] = field(default_factory=dict) + profile: Dict[str, Any] = field(default_factory=dict) + lifecycle_timestamps: Dict[str, float] = field(default_factory=dict) + + def to_summary(self) -> Dict[str, Any]: + return { + "request_id": self.request_id, + "api_mode": self.api_mode, + "backend": self.backend, + "media_type": self.media_type, + "response_streaming": self.response_streaming, + "status": self.status, + "submit_ts": self.submit_ts, + "updated_ts": self.updated_ts, + "deadline_ts": self.deadline_ts, + "error": self.error, + "finish_reason": self.finish_reason, + "meta": dict(self.meta), + "profile": dict(self.profile), + "lifecycle_timestamps": dict(self.lifecycle_timestamps), + } + + +class EngineRequestRegistry: + def __init__(self, recent_limit: int) -> None: + self.lock = threading.Lock() + self.active_requests: Dict[str, EngineRequestState] = {} + self.recent_requests: Deque[EngineRequestState] = deque() + self.recent_limit = max(1, int(recent_limit)) + + def register( + self, + *, + request_id: str, + api_mode: str, + backend: str, + media_type: str, + response_streaming: bool, + deadline_ts: float | None = None, + meta: Optional[Dict[str, Any]] = None, + ) -> EngineRequestState: + now = time.perf_counter() + state = EngineRequestState( + request_id=request_id, + api_mode=api_mode, + backend=backend, + media_type=media_type, + response_streaming=bool(response_streaming), + submit_ts=now, + deadline_ts=deadline_ts, + updated_ts=now, + meta=dict(meta or {}), + lifecycle_timestamps={EngineStatus.NEW: now}, + ) + with self.lock: + self.active_requests[request_id] = state + return state + + def _move_to_recent_locked(self, state: EngineRequestState) -> None: + self.recent_requests.appendleft(state) + while len(self.recent_requests) > self.recent_limit: + self.recent_requests.pop() + + @staticmethod + def _apply_state_extra(state: EngineRequestState, extra: Optional[Dict[str, Any]]) -> None: + if not extra: + return + payload = dict(extra) + backend = payload.pop("backend", None) + if backend is not None: + state.backend = str(backend) + finish_reason = payload.pop("finish_reason", None) + if finish_reason is not None: + state.finish_reason = str(finish_reason) + error = payload.pop("error", None) + if error is not None: + state.error = str(error) + state.profile.update(payload) + + def update(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None: + now = time.perf_counter() + with self.lock: + state = self.active_requests.get(request_id) + if state is None: + return + state.status = str(status) + state.updated_ts = now + state.lifecycle_timestamps[str(status)] = now + self._apply_state_extra(state, extra) + + def merge_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + if not extra: + return + now = time.perf_counter() + with self.lock: + state = self.active_requests.get(request_id) + if state is None: + for recent_state in self.recent_requests: + if recent_state.request_id == request_id: + state = recent_state + break + if state is None: + return + state.updated_ts = now + self._apply_state_extra(state, extra) + + def complete(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + now = time.perf_counter() + with self.lock: + state = self.active_requests.pop(request_id, None) + if state is None: + return + state.status = EngineStatus.COMPLETED + state.updated_ts = now + state.lifecycle_timestamps[EngineStatus.COMPLETED] = now + self._apply_state_extra(state, extra) + self._move_to_recent_locked(state) + + def fail(self, request_id: str, error: str) -> None: + now = time.perf_counter() + with self.lock: + state = self.active_requests.pop(request_id, None) + if state is None: + return + state.status = EngineStatus.FAILED + state.updated_ts = now + state.error = str(error) + state.lifecycle_timestamps[EngineStatus.FAILED] = now + self._move_to_recent_locked(state) + + def snapshot(self) -> Dict[str, Any]: + with self.lock: + active = [state.to_summary() for state in self.active_requests.values()] + recent = [state.to_summary() for state in list(self.recent_requests)] + recent_limit = self.recent_limit + active.sort(key=lambda item: item["submit_ts"]) + return { + "active_count": len(active), + "recent_count": len(recent), + "recent_limit": recent_limit, + "active_requests": active, + "recent_requests": recent, + } + + def collect_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]: + requested = set(request_ids) + results: List[Dict[str, Any]] = [] + with self.lock: + for state in self.active_requests.values(): + if state.request_id in requested: + results.append(state.to_summary()) + existing_ids = {item["request_id"] for item in results} + for state in self.recent_requests: + if state.request_id in requested and state.request_id not in existing_ids: + results.append(state.to_summary()) + results.sort(key=lambda item: item["request_id"]) + return results + + def has_active(self, request_id: str) -> bool: + with self.lock: + return request_id in self.active_requests + + +@dataclass +class SchedulerPendingJob: + request_id: str + state: T2SRequestState + done_event: threading.Event + done_loop: asyncio.AbstractEventLoop | None + done_future: asyncio.Future | None + enqueue_time: float + speed_factor: float + sample_steps: int + media_type: str + admission_wait_ms: float = 0.0 + engine_policy_wait_ms: float = 0.0 + engine_dispatch_wait_ms: float = 0.0 + prepare_wall_ms: float = 0.0 + prepare_profile_total_ms: float = 0.0 + first_schedule_time: float | None = None + prefill_ms: float = 0.0 + merge_ms: float = 0.0 + decode_ms: float = 0.0 + finalize_wait_ms: float = 0.0 + synth_ms: float = 0.0 + pack_ms: float = 0.0 + decode_steps: int = 0 + result_ready_time: float | None = None + result: dict | None = None + sample_rate: int | None = None + audio_data: np.ndarray | None = None + error: str | None = None + engine_request_id: str | None = None + + +class SchedulerJobRegistry: + def __init__(self, lock: threading.Lock | threading.RLock | threading.Condition) -> None: + self._lock = lock + self._job_map: Dict[str, SchedulerPendingJob] = {} + self._total_submitted = 0 + self._total_finished = 0 + + def register(self, job: SchedulerPendingJob, *, keep_job: bool = True) -> None: + with self._lock: + if keep_job: + self._job_map[job.request_id] = job + self._total_submitted += 1 + + def get(self, request_id: str) -> SchedulerPendingJob | None: + with self._lock: + return self._job_map.get(request_id) + + def pop(self, request_id: str) -> SchedulerPendingJob | None: + with self._lock: + return self._job_map.pop(request_id, None) + + def remove(self, request_id: str) -> None: + with self._lock: + self._job_map.pop(request_id, None) + + def mark_finished(self) -> None: + with self._lock: + self._total_finished += 1 + + def mark_finished_and_remove(self, request_id: str) -> None: + with self._lock: + self._job_map.pop(request_id, None) + self._total_finished += 1 + + def is_empty(self) -> bool: + with self._lock: + return not self._job_map + + def submitted_count(self) -> int: + with self._lock: + return int(self._total_submitted) + + def finished_count(self) -> int: + with self._lock: + return int(self._total_finished) + + def snapshot(self, max_request_ids: int = 32) -> Dict[str, Any]: + with self._lock: + request_ids = list(self._job_map.keys()) + return { + "job_count": int(len(request_ids)), + "request_ids": request_ids[: max(0, int(max_request_ids))], + "total_submitted": int(self._total_submitted), + "total_finished": int(self._total_finished), + } + + +class EngineTaskQueueOwner: + def __init__(self, completion_key: str = "total_completed") -> None: + self.condition = threading.Condition() + self.queue: Deque[Any] = deque() + self.total_submitted = 0 + self.total_completed = 0 + self.peak_waiting = 0 + self.completion_key = str(completion_key) + + def enqueue(self, item: Any) -> None: + with self.condition: + self.queue.append(item) + self.total_submitted += 1 + self.peak_waiting = max(self.peak_waiting, len(self.queue)) + self.condition.notify_all() + + def enqueue_many(self, items: Sequence[Any]) -> None: + if not items: + return + with self.condition: + for item in items: + self.queue.append(item) + self.total_submitted += len(items) + self.peak_waiting = max(self.peak_waiting, len(self.queue)) + self.condition.notify_all() + + def pop_left(self) -> Any | None: + with self.condition: + if not self.queue: + return None + return self.queue.popleft() + + def mark_completed(self, count: int = 1, *, notify: bool = False) -> None: + if count <= 0: + return + with self.condition: + self.total_completed += int(count) + if notify: + self.condition.notify_all() + + def has_items(self) -> bool: + with self.condition: + return bool(self.queue) + + def waiting_count(self) -> int: + with self.condition: + return int(len(self.queue)) + + def snapshot(self, *, max_request_ids: int = 16, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + with self.condition: + waiting_items = list(self.queue)[: max(0, int(max_request_ids))] + snapshot = { + "waiting_count": int(len(self.queue)), + "waiting_request_ids": [str(getattr(item, "request_id", "")) for item in waiting_items], + "peak_waiting": int(self.peak_waiting), + "total_submitted": int(self.total_submitted), + self.completion_key: int(self.total_completed), + } + if extra: + snapshot.update(dict(extra)) + return snapshot + + def peek_oldest_age_ms(self, timestamp_attr: str) -> float: + with self.condition: + if not self.queue: + return 0.0 + enqueue_time = float(getattr(self.queue[0], timestamp_attr)) + return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0) + + def is_drained(self) -> bool: + with self.condition: + return not self.queue and self.total_submitted == self.total_completed + + def take_finalize_batch( + self, + *, + finalize_mode: str, + batch_max_items: int, + batch_wait_s: float, + use_vocoder: bool, + ) -> List[SchedulerFinalizeTask]: + with self.condition: + if not self.queue: + return [] + selected_tasks = [self.queue.popleft()] + if finalize_mode == "sync" or use_vocoder: + return selected_tasks + if batch_max_items <= 1: + return selected_tasks + first_task = selected_tasks[0] + oldest_age_s = max(0.0, time.perf_counter() - first_task.enqueued_time) + if len(self.queue) + 1 < batch_max_items and oldest_age_s < batch_wait_s: + self.queue.appendleft(first_task) + return [] + while len(selected_tasks) < batch_max_items: + if not self.queue: + break + matched_index = None + for index, task in enumerate(self.queue): + if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: + matched_index = index + break + if matched_index is None: + break + selected_tasks.append(self.queue[matched_index]) + del self.queue[matched_index] + return selected_tasks + + +class EnginePolicyArbiterController: + def __init__( + self, + *, + policy_config: EnginePolicyConfig, + arbiter_config: EngineArbiterConfig, + snapshot_request_registry: Callable[[], Dict[str, Any]], + get_worker_state: Callable[[], Dict[str, Any]], + snapshot_prepare_state: Callable[[], Dict[str, Any]], + snapshot_finalize_state: Callable[[], Dict[str, Any]], + snapshot_dispatch_state: Callable[[], Dict[str, Any]], + snapshot_decode_runtime_state: Callable[[], Dict[str, Any]], + snapshot_job_registry: Callable[[], Dict[str, Any]], + peek_queue_age_ms: Callable[[str], float], + merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None], + ) -> None: + self.policy_config = policy_config + self.policy_poll_s = max(0.001, float(self.policy_config.poll_wait_ms) / 1000.0) + self.arbiter_config = arbiter_config + self.arbiter_poll_s = max(0.001, float(self.arbiter_config.poll_wait_ms) / 1000.0) + self.condition = threading.Condition() + self.state = EngineArbiterState( + decode_budget_remaining=int(self.arbiter_config.decode_burst), + last_observed_at=time.perf_counter(), + ) + self.snapshot_request_registry = snapshot_request_registry + self.get_worker_state = get_worker_state + self.snapshot_prepare_state = snapshot_prepare_state + self.snapshot_finalize_state = snapshot_finalize_state + self.snapshot_dispatch_state = snapshot_dispatch_state + self.snapshot_decode_runtime_state = snapshot_decode_runtime_state + self.snapshot_job_registry = snapshot_job_registry + self.peek_queue_age_ms = peek_queue_age_ms + self.merge_request_state_profile = merge_request_state_profile + + def snapshot_state(self) -> Dict[str, Any]: + with self.condition: + return { + "config": self.arbiter_config.to_dict(), + "total_ticks": int(self.state.total_ticks), + "total_idle_ticks": int(self.state.total_idle_ticks), + "total_prepare_dispatches": int(self.state.total_prepare_dispatches), + "total_decode_dispatches": int(self.state.total_decode_dispatches), + "total_decode_runtime_ticks": int(self.state.total_decode_runtime_ticks), + "total_finalize_dispatches": int(self.state.total_finalize_dispatches), + "decode_budget_remaining": int(self.state.decode_budget_remaining), + "last_stage": str(self.state.last_stage), + "last_reason": str(self.state.last_reason), + "last_policy_allowed": bool(self.state.last_policy_allowed), + "last_observed_at": float(self.state.last_observed_at), + } + + def notify(self) -> None: + with self.condition: + self.condition.notify_all() + + def wait(self) -> None: + with self.condition: + self.condition.wait(timeout=self.arbiter_poll_s) + + def mark_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: + with self.condition: + self.state.total_ticks += 1 + if stage == "idle": + self.state.total_idle_ticks += 1 + elif stage == "prepare": + self.state.total_prepare_dispatches += 1 + self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) + elif stage == "finalize": + self.state.total_finalize_dispatches += 1 + self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) + elif stage == "decode_dispatch": + self.state.total_decode_dispatches += 1 + elif stage == "decode_runtime": + self.state.total_decode_runtime_ticks += 1 + self.state.decode_budget_remaining = max(0, int(self.state.decode_budget_remaining) - 1) + self.state.last_stage = str(stage) + self.state.last_reason = str(reason) + self.state.last_policy_allowed = bool(policy_allowed) + self.state.last_observed_at = time.perf_counter() + + def build_stage_counters( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + prepare_dispatcher_state = self.snapshot_prepare_state() + finalize_dispatcher_state = self.snapshot_finalize_state() + dispatcher_state = self.snapshot_dispatch_state() + active_requests = list(request_registry.get("active_requests", [])) + status_counts: Dict[str, int] = {} + for item in active_requests: + status = str(item.get("status", "UNKNOWN")) + status_counts[status] = status_counts.get(status, 0) + 1 + + worker_pending_jobs = int(worker_state.get("pending_jobs", 0)) + worker_decode_active_size = int(worker_state.get("running_requests", 0)) + worker_prepare_inflight = int(worker_state.get("prepare_inflight", 0)) + worker_finalize_pending = int(worker_state.get("finalize_pending", 0)) + worker_finalize_inflight = int(worker_state.get("finalize_inflight", 0)) + engine_decode_runtime_state = self.snapshot_decode_runtime_state() + engine_job_registry = self.snapshot_job_registry() + decode_runtime_pending_jobs = int(engine_decode_runtime_state.get("pending_jobs", 0)) + decode_runtime_active_size = int(engine_decode_runtime_state.get("active_request_count", 0)) + return { + "active_request_count": int(len(active_requests)), + "status_counts": status_counts, + "queued_request_count": int(status_counts.get(EngineStatus.QUEUED, 0)), + "cpu_prepare_request_count": int(status_counts.get(EngineStatus.CPU_PREPARING, 0)), + "gpu_prepare_request_count": int(status_counts.get(EngineStatus.GPU_PREPARING, 0)), + "ready_for_prefill_request_count": int(status_counts.get(EngineStatus.READY_FOR_PREFILL, 0)), + "active_decode_request_count": int(status_counts.get(EngineStatus.ACTIVE_DECODE, 0)), + "ready_for_finalize_request_count": int(status_counts.get(EngineStatus.READY_FOR_FINALIZE, 0)), + "finalizing_request_count": int(status_counts.get(EngineStatus.FINALIZING, 0)), + "streaming_request_count": int(status_counts.get(EngineStatus.STREAMING, 0)), + "worker_pending_jobs": worker_pending_jobs, + "worker_decode_active_size": worker_decode_active_size, + "worker_decode_control_enabled": bool(worker_state.get("engine_decode_control_enabled", False)), + "worker_decode_runtime_has_work": bool(worker_state.get("decode_runtime_has_work", False)), + "engine_decode_runtime_pending_jobs": decode_runtime_pending_jobs, + "engine_decode_runtime_active_request_count": decode_runtime_active_size, + "engine_decode_runtime_has_work": bool(engine_decode_runtime_state.get("has_work", False)), + "engine_job_registry_count": int(engine_job_registry.get("job_count", 0)), + "worker_prepare_inflight": worker_prepare_inflight, + "worker_finalize_pending": worker_finalize_pending, + "worker_finalize_inflight": worker_finalize_inflight, + "engine_gpu_prepare_queue_count": int(prepare_dispatcher_state.get("waiting_count", 0)), + "engine_finalize_queue_count": int(finalize_dispatcher_state.get("waiting_count", 0)), + "engine_decode_waiting_queue_count": int(dispatcher_state.get("waiting_count", 0)), + "decode_backlog": int( + decode_runtime_pending_jobs + decode_runtime_active_size + if bool(worker_state.get("engine_decode_control_enabled", False)) + else worker_pending_jobs + worker_decode_active_size + ), + } + + def build_policy_snapshot( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + counters = self.build_stage_counters(request_registry, worker_state) + config = self.policy_config.to_dict() + blocked_reasons: List[Dict[str, Any]] = [] + finalize_pending_total = int(counters["worker_finalize_pending"]) + int(counters.get("engine_finalize_queue_count", 0)) + limit_checks = [ + ("decode_backlog", counters["decode_backlog"], int(config["decode_backlog_soft_max"])), + ("finalize_pending", finalize_pending_total, int(config["finalize_pending_soft_max"])), + ("prepare_inflight", counters["worker_prepare_inflight"], int(config["prepare_inflight_soft_max"])), + ("active_decode_requests", counters["active_decode_request_count"], int(config["active_decode_soft_max"])), + ("ready_for_prefill_requests", counters["ready_for_prefill_request_count"], int(config["ready_for_prefill_soft_max"])), + ("active_requests", counters["active_request_count"], int(config["active_request_soft_max"])), + ] + if bool(config["enabled"]): + for name, value, limit in limit_checks: + if limit > 0 and int(value) >= int(limit): + blocked_reasons.append({"metric": name, "value": int(value), "limit": int(limit)}) + return { + "enabled": bool(config["enabled"]), + "allowed": (not bool(config["enabled"])) or not blocked_reasons, + "blocked_reasons": blocked_reasons, + "config": config, + "metrics": { + "active_request_count": int(counters["active_request_count"]), + "queued_request_count": int(counters["queued_request_count"]), + "ready_for_prefill_request_count": int(counters["ready_for_prefill_request_count"]), + "active_decode_request_count": int(counters["active_decode_request_count"]), + "engine_gpu_prepare_queue_count": int(counters["engine_gpu_prepare_queue_count"]), + "engine_decode_waiting_queue_count": int(counters["engine_decode_waiting_queue_count"]), + "decode_backlog": int(counters["decode_backlog"]), + "prepare_inflight": int(counters["worker_prepare_inflight"]), + "finalize_pending": int(finalize_pending_total), + "engine_finalize_queue_count": int(counters.get("engine_finalize_queue_count", 0)), + "finalize_inflight": int(counters["worker_finalize_inflight"]), + }, + "observed_at": time.perf_counter(), + } + + async def wait_for_policy_admission( + self, + *, + request_id: str | None, + timeout_sec: float | None, + ) -> tuple[float, Dict[str, Any]]: + request_registry = self.snapshot_request_registry() + worker_state = self.get_worker_state() + snapshot = self.build_policy_snapshot(request_registry, worker_state) + if not self.policy_config.enabled: + return 0.0, snapshot + start = time.perf_counter() + deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) + while True: + request_registry = self.snapshot_request_registry() + worker_state = self.get_worker_state() + snapshot = self.build_policy_snapshot(request_registry, worker_state) + if snapshot["allowed"]: + wait_ms = max(0.0, (time.perf_counter() - start) * 1000.0) + if request_id not in [None, ""]: + self.merge_request_state_profile( + str(request_id), + { + "engine_policy_wait_ms": float(wait_ms), + "engine_policy_snapshot": snapshot, + }, + ) + return wait_ms, snapshot + now = time.perf_counter() + if deadline is not None and now >= deadline: + blocked_summary = ", ".join( + f"{item['metric']}={item['value']}/{item['limit']}" for item in snapshot.get("blocked_reasons", []) + ) + raise TimeoutError(f"engine policy admission timeout ({blocked_summary})") + await asyncio.sleep(self.policy_poll_s) + + def select_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: + request_registry = self.snapshot_request_registry() + worker_state = self.get_worker_state() + policy_snapshot = self.build_policy_snapshot(request_registry, worker_state) + prepare_waiting = int(self.snapshot_prepare_state().get("waiting_count", 0)) + finalize_waiting = int(self.snapshot_finalize_state().get("waiting_count", 0)) + decode_waiting = int(self.snapshot_dispatch_state().get("waiting_count", 0)) + decode_runtime_state = self.snapshot_decode_runtime_state() + worker_decode_has_work = bool(decode_runtime_state.get("has_work", False)) + worker_decode_control_enabled = bool(worker_state.get("engine_decode_control_enabled", False)) + worker_pending_jobs = int(decode_runtime_state.get("pending_jobs", 0)) + worker_running_requests = int(decode_runtime_state.get("active_request_count", 0)) + prepare_age_ms = float(self.peek_queue_age_ms("prepare")) + finalize_age_ms = float(self.peek_queue_age_ms("finalize")) + decode_runtime_pending_age_ms = float(self.peek_queue_age_ms("decode_runtime_pending")) + decode_budget_remaining = int(self.snapshot_state().get("decode_budget_remaining", 0)) + policy_allowed = bool(policy_snapshot.get("allowed", True)) + if ( + worker_decode_control_enabled + and worker_decode_has_work + and policy_allowed + and decode_budget_remaining > 0 + and (worker_running_requests > 0 or worker_pending_jobs > 0) + ): + return "decode_runtime", "worker_active_batch_progress", policy_snapshot, worker_state + if ( + worker_decode_control_enabled + and worker_pending_jobs > 0 + and policy_allowed + and decode_runtime_pending_age_ms >= float(self.arbiter_config.prepare_aging_ms) + ): + return "decode_runtime", "decode_runtime_pending_aging", policy_snapshot, worker_state + if ( + decode_waiting > 0 + and policy_allowed + and (not worker_decode_control_enabled or not worker_decode_has_work or worker_pending_jobs <= 0) + ): + return "decode_dispatch", "dispatch_prepared_state", policy_snapshot, worker_state + if finalize_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): + return "finalize", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state + if finalize_waiting > 0 and finalize_age_ms >= float(self.arbiter_config.finalize_aging_ms): + return "finalize", "finalize_aging", policy_snapshot, worker_state + if prepare_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): + return "prepare", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state + if prepare_waiting > 0 and prepare_age_ms >= float(self.arbiter_config.prepare_aging_ms): + return "prepare", "prepare_aging", policy_snapshot, worker_state + if worker_decode_control_enabled and worker_decode_has_work and policy_allowed: + return "decode_runtime", "worker_active_batch_progress_fallback", policy_snapshot, worker_state + if decode_waiting > 0 and policy_allowed: + return "decode_dispatch", "decode_priority_fallback", policy_snapshot, worker_state + if finalize_waiting > 0: + return "finalize", "finalize_fallback", policy_snapshot, worker_state + if prepare_waiting > 0: + return "prepare", "prepare_fallback", policy_snapshot, worker_state + return "idle", "no_pending_work", policy_snapshot, worker_state + + +class EngineDecodeRuntimeOwner: + def __init__( + self, + *, + get_decode_runtime_counters: Callable[[], Dict[str, int]], + get_micro_batch_wait_s: Callable[[], float], + ) -> None: + self.get_decode_runtime_counters = get_decode_runtime_counters + self.get_micro_batch_wait_s = get_micro_batch_wait_s + self.condition = threading.Condition() + self.pending_jobs: Deque[SchedulerPendingJob] = deque() + self.active_batch: T2SActiveBatch | None = None + self.state_lock = threading.Lock() + self.state = EngineDecodeRuntimeState(updated_at=time.perf_counter()) + + @staticmethod + def summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: + if active_batch is None: + return {} + decode_step_index_max = 0 + if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0: + decode_step_index_max = int(active_batch.step_indices.max().item()) + return { + "request_count": int(len(active_batch.request_ids)), + "request_ids": list(active_batch.request_ids), + "prefill_done": bool(active_batch.prefill_done), + "decode_step_index_max": int(decode_step_index_max), + } + + def snapshot_pending_queue_state(self) -> Dict[str, Any]: + with self.condition: + return { + "pending_jobs": int(len(self.pending_jobs)), + "pending_request_ids": [job.request_id for job in list(self.pending_jobs)[:32]], + } + + def enqueue_pending_job(self, job: SchedulerPendingJob) -> None: + with self.condition: + self.pending_jobs.append(job) + self.condition.notify_all() + self.refresh_state("engine_decode_pending_enqueue") + + def take_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs: + return [] + if wait_for_batch: + oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time) + if (time.perf_counter() - oldest_enqueue_time) < self.get_micro_batch_wait_s(): + return [] + pending_jobs = list(self.pending_jobs) + self.pending_jobs.clear() + self.refresh_state("engine_decode_pending_dequeue") + return pending_jobs + + def pending_age_ms(self) -> float: + with self.condition: + if not self.pending_jobs: + return 0.0 + enqueue_time = float(self.pending_jobs[0].enqueue_time) + return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0) + + def has_pending_jobs(self) -> bool: + with self.condition: + return bool(self.pending_jobs) + + def get_active_batch(self) -> T2SActiveBatch | None: + return self.active_batch + + def set_active_batch(self, active_batch: T2SActiveBatch | None) -> None: + self.active_batch = active_batch + + def active_batch_summary(self) -> Dict[str, Any]: + return self.summarize_active_batch(self.active_batch) + + def refresh_state(self, last_event: str) -> None: + pending_state = self.snapshot_pending_queue_state() + active_batch_summary = self.active_batch_summary() + worker_decode_counters = self.get_decode_runtime_counters() + with self.state_lock: + self.state.pending_jobs = int(pending_state.get("pending_jobs", 0)) + self.state.pending_request_ids = list(pending_state.get("pending_request_ids", [])) + self.state.active_request_count = int(active_batch_summary.get("request_count", 0)) + self.state.active_request_ids = list(active_batch_summary.get("request_ids", []))[:32] + self.state.prefill_done = bool(active_batch_summary.get("prefill_done", False)) + self.state.decode_step_index_max = int(active_batch_summary.get("decode_step_index_max", 0)) + self.state.total_cycles = int(worker_decode_counters.get("total_cycles", 0)) + self.state.prefill_cycles = int(worker_decode_counters.get("prefill_cycles", 0)) + self.state.step_cycles = int(worker_decode_counters.get("step_cycles", 0)) + self.state.has_work = bool(pending_state.get("pending_jobs", 0) or active_batch_summary.get("request_count", 0)) + self.state.last_event = str(last_event) + self.state.updated_at = float(time.perf_counter()) + + def update_from_worker_snapshot(self, snapshot: Dict[str, Any]) -> None: + if not snapshot: + return + pending_state = self.snapshot_pending_queue_state() + with self.state_lock: + self.state.pending_jobs = int(pending_state.get("pending_jobs", 0)) + self.state.pending_request_ids = list(pending_state.get("pending_request_ids", [])) + self.state.active_request_count = int(snapshot.get("active_request_count", 0)) + self.state.active_request_ids = list(snapshot.get("active_request_ids", []))[:32] + self.state.prefill_done = bool(snapshot.get("prefill_done", False)) + self.state.decode_step_index_max = int(snapshot.get("decode_step_index_max", 0)) + self.state.total_cycles = int(snapshot.get("total_cycles", 0)) + self.state.prefill_cycles = int(snapshot.get("prefill_cycles", 0)) + self.state.step_cycles = int(snapshot.get("step_cycles", 0)) + self.state.has_work = bool( + pending_state.get("pending_jobs", 0) + or snapshot.get("active_request_count", 0) + or snapshot.get("has_work", False) + ) + self.state.last_event = str(snapshot.get("last_event", "unknown")) + self.state.updated_at = float(snapshot.get("updated_at", time.perf_counter())) + + def snapshot_state(self) -> Dict[str, Any]: + pending_state = self.snapshot_pending_queue_state() + active_batch_summary = self.active_batch_summary() + worker_decode_counters = self.get_decode_runtime_counters() + with self.state_lock: + return { + "pending_jobs": int(pending_state.get("pending_jobs", self.state.pending_jobs)), + "pending_request_ids": list(pending_state.get("pending_request_ids", self.state.pending_request_ids)), + "active_request_count": int(active_batch_summary.get("request_count", self.state.active_request_count)), + "active_request_ids": list(active_batch_summary.get("request_ids", self.state.active_request_ids)), + "prefill_done": bool(active_batch_summary.get("prefill_done", self.state.prefill_done)), + "decode_step_index_max": int( + active_batch_summary.get("decode_step_index_max", self.state.decode_step_index_max) + ), + "total_cycles": int(worker_decode_counters.get("total_cycles", 0)), + "prefill_cycles": int(worker_decode_counters.get("prefill_cycles", 0)), + "step_cycles": int(worker_decode_counters.get("step_cycles", 0)), + "has_work": bool( + pending_state.get("pending_jobs", 0) + or active_batch_summary.get("request_count", self.state.active_request_count) + or self.state.has_work + ), + "last_event": str(self.state.last_event), + "updated_at": float(self.state.updated_at), + } + +@dataclass +class SchedulerFinalizeTask: + request_id: str + item: T2SFinishedItem + enqueued_time: float + + +@dataclass +class EngineDispatchTask: + request_id: str + state: T2SRequestState + speed_factor: float + sample_steps: int + media_type: str + prepare_wall_ms: float + prepare_profile_total_ms: float + done_loop: asyncio.AbstractEventLoop | None + done_future: asyncio.Future | None + engine_request_id: str | None + timeout_sec: float | None + enqueue_time: float + worker_job: SchedulerPendingJob | None = None + engine_policy_wait_ms: float = 0.0 + engine_dispatch_wait_ms: float = 0.0 + engine_policy_snapshot: Dict[str, Any] | None = None + error: str | None = None + + +@dataclass +class EngineGpuPrepareTask: + request_id: str + cpu_stage: PreparedCpuStage + done_loop: asyncio.AbstractEventLoop | None + done_future: asyncio.Future | None + engine_request_id: str | None + enqueue_time: float + queue_wait_ms: float = 0.0 + error: str | None = None + + +@dataclass +class EngineFinalizeQueueState: + waiting_count: int + waiting_request_ids: List[str] + peak_waiting: int + total_submitted: int + total_completed: int + + +@dataclass +class EngineArbiterState: + total_ticks: int = 0 + total_idle_ticks: int = 0 + total_prepare_dispatches: int = 0 + total_decode_dispatches: int = 0 + total_decode_runtime_ticks: int = 0 + total_finalize_dispatches: int = 0 + decode_budget_remaining: int = 0 + last_stage: str = "idle" + last_reason: str = "init" + last_observed_at: float = 0.0 + last_policy_allowed: bool = True + + +@dataclass +class EngineDecodeRuntimeState: + pending_jobs: int = 0 + pending_request_ids: List[str] = field(default_factory=list) + active_request_count: int = 0 + active_request_ids: List[str] = field(default_factory=list) + prefill_done: bool = False + decode_step_index_max: int = 0 + total_cycles: int = 0 + prefill_cycles: int = 0 + step_cycles: int = 0 + has_work: bool = False + last_event: str = "init" + updated_at: float = 0.0 + + +@dataclass +class RuntimeStateCallbacks: + update: Callable[[str, str, Optional[Dict[str, Any]]], None] | None = None + complete: Callable[[str, Optional[Dict[str, Any]]], None] | None = None + fail: Callable[[str, str], None] | None = None + decode_runtime_update: Callable[[Dict[str, Any]], None] | None = None + + diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py new file mode 100644 index 00000000..7dbbd5bd --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py @@ -0,0 +1,446 @@ +from __future__ import annotations + +import asyncio +from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_api import EngineApiFacade +from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge import EngineBridgeFacade +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, EngineDispatchTask, EngineRequestState, NormalizedEngineRequest, SchedulerDebugExecution, SchedulerFinalizeTask, SchedulerPendingJob, SchedulerSubmitExecution +from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime import EngineRuntimeFacade + + +class EngineBridgeDelegates: + def _register_request_state( + self, + request_id: str, + api_mode: str, + backend: str, + media_type: str, + response_streaming: bool, + deadline_ts: float | None = None, + meta: Optional[Dict[str, Any]] = None, + ) -> EngineRequestState: + return self.bridge_facade._register_request_state( + request_id=request_id, + api_mode=api_mode, + backend=backend, + media_type=media_type, + response_streaming=response_streaming, + deadline_ts=deadline_ts, + meta=meta, + ) + + def _update_request_state(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.bridge_facade._update_request_state(request_id, status, extra) + + def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.bridge_facade._merge_request_state_profile(request_id, extra) + + def _snapshot_engine_prepare_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_prepare_state() + + def _snapshot_engine_finalize_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_finalize_state() + + def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_dispatch_state() + + def _register_engine_job(self, job: SchedulerPendingJob) -> None: + self.bridge_facade._register_engine_job(job) + + def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None: + return self.bridge_facade._get_engine_job(request_id) + + def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None: + return self.bridge_facade._pop_engine_job(request_id) + + def _snapshot_engine_job_registry(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_job_registry() + + def _is_engine_drained(self) -> bool: + return self.bridge_facade._is_engine_drained() + + def _record_engine_job_done(self, request_id: str) -> None: + self.bridge_facade._record_engine_job_done(request_id) + + def _complete_engine_job( + self, + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + audio_data: np.ndarray, + ) -> None: + self.bridge_facade._complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) + + def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None: + self.bridge_facade._fail_engine_jobs(request_ids, error) + + def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None: + self.bridge_facade._add_engine_prefill_time(jobs, elapsed_s) + + def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: + self.bridge_facade._add_engine_merge_time(request_ids, elapsed_s) + + def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + self.bridge_facade._add_engine_decode_time(request_ids, elapsed_s) + + def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None: + self.bridge_facade._enqueue_engine_finished_items(items) + + def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_decode_pending_queue_state() + + @staticmethod + def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: + return EngineBridgeFacade._summarize_active_batch(active_batch) + + def _refresh_engine_decode_runtime_state(self, last_event: str) -> None: + self.bridge_facade._refresh_engine_decode_runtime_state(last_event) + + def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None: + self.bridge_facade._update_engine_decode_runtime_state(snapshot) + + def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_decode_runtime_state() + + def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_arbiter_state() + + def _notify_engine_arbiter(self) -> None: + self.bridge_facade._notify_engine_arbiter() + + def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None: + self.bridge_facade._enqueue_engine_decode_pending_job(job) + + def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.bridge_facade._take_engine_decode_pending_jobs_nonblocking(wait_for_batch) + + def _peek_queue_age_ms(self, queue_name: str) -> float: + return self.bridge_facade._peek_queue_age_ms(queue_name) + + def _engine_has_pending_work(self) -> bool: + return self.bridge_facade._engine_has_pending_work() + + async def _prepare_state_via_engine_gpu_queue( + self, + *, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + engine_request_id: str | None, + ) -> tuple[T2SRequestState, float, float]: + return await self.bridge_facade._prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=prepare_submit_at, + engine_request_id=engine_request_id, + ) + + def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: + self.bridge_facade._enqueue_worker_finished_for_finalize(tasks) + + def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: + return self.bridge_facade._take_engine_finalize_batch_nonblocking() + + async def _enqueue_prepared_state_for_dispatch( + self, + *, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None, + done_future: asyncio.Future | None, + engine_request_id: str | None, + timeout_sec: float | None, + ) -> EngineDispatchTask: + return await self.bridge_facade._enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=speed_factor, + sample_steps=sample_steps, + media_type=media_type, + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=done_loop, + done_future=done_future, + engine_request_id=engine_request_id, + timeout_sec=timeout_sec, + ) + + def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: + self.bridge_facade._mark_arbiter_tick(stage=stage, reason=reason, policy_allowed=policy_allowed) + + def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: + return self.bridge_facade._select_engine_stage() + + def _run_engine_prepare_once(self) -> bool: + return self.bridge_facade._run_engine_prepare_once() + + def _run_engine_finalize_once(self) -> bool: + return self.bridge_facade._run_engine_finalize_once() + + def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: + return self.bridge_facade._run_engine_dispatch_once(policy_snapshot, worker_state) + + def _run_engine_decode_runtime_once(self) -> bool: + return self.bridge_facade._run_engine_decode_runtime_once() + + def _run_engine_arbiter_loop(self) -> None: + self.bridge_facade._run_engine_arbiter_loop() + + def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.bridge_facade._complete_request_state(request_id, extra) + + def _fail_request_state(self, request_id: str, error: str) -> None: + self.bridge_facade._fail_request_state(request_id, error) + + def _snapshot_request_registry(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_request_registry() + + +class EngineApiDelegates: + def _collect_request_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]: + return self.api_facade._collect_request_summaries(request_ids) + + def _has_active_request(self, request_id: str) -> bool: + return self.api_facade._has_active_request(request_id) + + @staticmethod + def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]: + return EngineApiFacade._build_request_meta(payload) + + @staticmethod + def _sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float: + return EngineApiFacade._sum_profile_field(items, key) + + def _build_direct_segment_trace( + self, + segment_texts: Sequence[str], + prepare_profiles: Sequence[Dict[str, Any]], + worker_profiles: Sequence[Dict[str, Any]], + ) -> List[Dict[str, Any]]: + return self.api_facade._build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles) + + def _build_direct_scheduler_profile(self, **kwargs: Any) -> Dict[str, Any]: + return self.api_facade._build_direct_scheduler_profile(**kwargs) + + def _build_legacy_direct_profile(self, **kwargs: Any) -> Dict[str, Any]: + return self.api_facade._build_legacy_direct_profile(**kwargs) + + def _build_scheduler_submit_profile(self, **kwargs: Any) -> Dict[str, Any]: + return self.api_facade._build_scheduler_submit_profile(**kwargs) + + @staticmethod + def _format_ms_header(value: Any) -> str: + return EngineApiFacade._format_ms_header(value) + + def _build_scheduler_submit_headers( + self, + *, + request_id: str, + media_type: str, + sample_rate: int, + profile: Dict[str, Any], + ) -> Dict[str, str]: + return self.api_facade._build_scheduler_submit_headers( + request_id=request_id, + media_type=media_type, + sample_rate=sample_rate, + profile=profile, + ) + + def _build_scheduler_debug_request_profile(self, **kwargs: Any) -> Dict[str, Any]: + return self.api_facade._build_scheduler_debug_request_profile(**kwargs) + + @staticmethod + def _build_scheduler_debug_batch_profile(**kwargs: Any) -> Dict[str, Any]: + return EngineApiFacade._build_scheduler_debug_batch_profile(**kwargs) + + def _normalize_lang(self, value: str | None) -> str | None: + return self.api_facade._normalize_lang(value) + + @staticmethod + def _aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]: + return EngineApiFacade._aggregate_numeric_dicts(items) + + def _apply_default_reference(self, req: dict) -> dict: + return self.api_facade._apply_default_reference(req) + + def check_params(self, req: dict) -> Optional[str]: + return self.api_facade.check_params(req) + + @staticmethod + def _base_request_defaults() -> Dict[str, Any]: + return EngineApiFacade._base_request_defaults() + + def _normalize_engine_request( + self, + payload: dict | NormalizedEngineRequest, + *, + request_id: str | None = None, + normalize_streaming: bool = False, + error_prefix: str = "request 参数非法: ", + ) -> NormalizedEngineRequest: + return self.api_facade._normalize_engine_request( + payload, + request_id=request_id, + normalize_streaming=normalize_streaming, + error_prefix=error_prefix, + ) + + @staticmethod + def _normalize_streaming_mode(req: dict) -> dict: + return EngineApiFacade._normalize_streaming_mode(req) + + @staticmethod + def _is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool: + return EngineApiFacade._is_aux_ref_enabled(aux_ref_audio_paths) + + def _select_direct_backend(self, normalized: NormalizedEngineRequest) -> Tuple[str, str | None]: + return self.api_facade._select_direct_backend(normalized) + + def _iter_legacy_direct_tts_bytes( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> Generator[bytes, None, None]: + return self.api_facade._iter_legacy_direct_tts_bytes( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + + def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool: + return self.api_facade._should_use_scheduler_backend_for_direct(req) + + def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]: + return self.api_facade._segment_direct_text(normalized) + + def _build_segment_request( + self, + normalized: NormalizedEngineRequest, + *, + request_id: str, + text: str, + ) -> NormalizedEngineRequest: + return self.api_facade._build_segment_request(normalized, request_id=request_id, text=text) + + async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution: + return await self.api_facade._run_direct_tts_via_scheduler(normalized) + + def _run_legacy_direct_tts_blocking( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> DirectTTSExecution: + return self.api_facade._run_legacy_direct_tts_blocking( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + + async def _run_direct_tts_via_legacy_backend( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> DirectTTSExecution: + return await self.api_facade._run_direct_tts_via_legacy_backend( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + + async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution: + return await self.api_facade.run_direct_tts_async(req) + + def run_direct_tts(self, req: dict) -> DirectTTSExecution: + return self.api_facade.run_direct_tts(req) + + def build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]: + return self.api_facade.build_scheduler_request_specs(request_items) + + def build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec: + return self.api_facade.build_scheduler_submit_spec(payload) + + @staticmethod + def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: + return EngineApiFacade.summarize_scheduler_states(states) + + @staticmethod + def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: + return EngineApiFacade.summarize_scheduler_finished(items) + + async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution: + return await self.api_facade.run_scheduler_debug(request_items, max_steps, seed) + + async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution: + return await self.api_facade.run_scheduler_submit(payload) + + +class EngineRuntimeDelegates: + @staticmethod + def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None: + return EngineRuntimeFacade._safe_component_snapshot(component) + + def _build_stage_counters( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.runtime_facade._build_stage_counters(request_registry, worker_state) + + def _build_engine_policy_snapshot( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.runtime_facade._build_engine_policy_snapshot(request_registry, worker_state) + + async def _wait_for_engine_policy_admission( + self, + *, + request_id: str | None, + timeout_sec: float | None, + ) -> tuple[float, Dict[str, Any]]: + return await self.engine_policy_arbiter.wait_for_policy_admission( + request_id=request_id, + timeout_sec=timeout_sec, + ) + + def _build_stage_summary( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.runtime_facade._build_stage_summary(request_registry, worker_state) + + def get_scheduler_state(self) -> dict: + return self.runtime_facade.get_scheduler_state() + + def get_runtime_state(self) -> dict: + return self.runtime_facade.get_runtime_state() + + def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None: + self.runtime_facade._wait_for_safe_reload(timeout_sec=timeout_sec) + + def set_refer_audio(self, refer_audio_path: str | None) -> dict: + return self.runtime_facade.set_refer_audio(refer_audio_path) + + def set_gpt_weights(self, weights_path: str) -> dict: + return self.runtime_facade.set_gpt_weights(weights_path) + + def set_sovits_weights(self, weights_path: str) -> dict: + return self.runtime_facade.set_sovits_weights(weights_path) + + def handle_control(self, command: str) -> None: + self.runtime_facade.handle_control(command) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_runtime.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_runtime.py new file mode 100644 index 00000000..70212a4d --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_runtime.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +import os +import signal +import sys +from typing import Any, Dict, Optional + + +class EngineRuntimeFacade: + def __init__(self, owner: Any) -> None: + self.owner = owner + + @property + def tts(self): + return self.owner.tts + + @property + def reference_registry(self): + return self.owner.reference_registry + + @property + def model_registry(self): + return self.owner.model_registry + + @property + def scheduler_worker(self): + return self.owner.scheduler_worker + + @property + def engine_decode_runtime_owner(self): + return self.owner.engine_decode_runtime_owner + + @property + def engine_policy_arbiter(self): + return self.owner.engine_policy_arbiter + + @property + def management_lock(self): + return self.owner.management_lock + + @property + def direct_tts_lock(self): + return self.owner.direct_tts_lock + + @property + def control_callbacks(self): + return self.owner.control_callbacks + + @staticmethod + def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None: + if component is None or not hasattr(component, "snapshot"): + return None + try: + return dict(component.snapshot()) + except Exception: + return None + + def _build_stage_counters( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.engine_policy_arbiter.build_stage_counters(request_registry, worker_state) + + def _build_engine_policy_snapshot( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.engine_policy_arbiter.build_policy_snapshot(request_registry, worker_state) + + def _build_stage_summary( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + counters = self._build_stage_counters(request_registry, worker_state) + bert_worker_state = self._safe_component_snapshot(getattr(self.tts, "prepare_bert_batch_worker", None)) + ref_semantic_worker_state = self._safe_component_snapshot(getattr(self.tts, "prepare_ref_semantic_batch_worker", None)) + text_preprocessor_state = self._safe_component_snapshot(getattr(self.tts, "text_preprocessor", None)) + + return { + **counters, + "engine_drained": bool(self.owner._is_engine_drained()), + "admission_config": { + "decode_backlog_max": int(worker_state.get("decode_backlog_max", 0)), + "finalize_pending_max": int(worker_state.get("finalize_pending_max", 0)), + }, + "engine_policy": self._build_engine_policy_snapshot(request_registry, worker_state), + "engine_arbiter_state": self.owner._snapshot_engine_arbiter_state(), + "engine_decode_runtime_state": self.owner._snapshot_engine_decode_runtime_state(), + "engine_job_registry": self.owner._snapshot_engine_job_registry(), + "engine_active_batch_state": self.engine_decode_runtime_owner.active_batch_summary(), + "engine_prepare_state": self.owner._snapshot_engine_prepare_state(), + "engine_finalize_state": self.owner._snapshot_engine_finalize_state(), + "engine_dispatcher_state": self.owner._snapshot_engine_dispatch_state(), + "active_batch": dict(worker_state.get("active_batch") or {}), + "prepare_state": dict(worker_state.get("prepare_state") or {}), + "bert_batch_worker_state": bert_worker_state, + "ref_semantic_worker_state": ref_semantic_worker_state, + "text_preprocessor_state": text_preprocessor_state, + } + + def get_scheduler_state(self) -> dict: + return self.scheduler_worker.snapshot() + + def get_runtime_state(self) -> dict: + model_state = self.model_registry.snapshot() + default_ref = self.reference_registry.get_default() + scheduler_state = self.get_scheduler_state() + request_registry = self.owner._snapshot_request_registry() + engine_policy = self._build_engine_policy_snapshot(request_registry, scheduler_state) + engine_arbiter_state = self.owner._snapshot_engine_arbiter_state() + engine_decode_runtime_state = self.owner._snapshot_engine_decode_runtime_state() + engine_job_registry = self.owner._snapshot_engine_job_registry() + engine_prepare_state = self.owner._snapshot_engine_prepare_state() + engine_finalize_state = self.owner._snapshot_engine_finalize_state() + engine_dispatcher_state = self.owner._snapshot_engine_dispatch_state() + engine_drained = self.owner._is_engine_drained() + return { + "message": "success", + "default_reference": { + "ref_audio_path": default_ref.ref_audio_path, + "updated_at": default_ref.updated_at, + }, + "model_registry": { + "generation": model_state.generation, + "t2s_generation": model_state.t2s_generation, + "vits_generation": model_state.vits_generation, + "t2s_weights_path": model_state.t2s_weights_path, + "vits_weights_path": model_state.vits_weights_path, + "updated_at": model_state.updated_at, + }, + "worker_state": scheduler_state, + "engine_policy": engine_policy, + "engine_arbiter_state": engine_arbiter_state, + "engine_decode_runtime_state": engine_decode_runtime_state, + "engine_job_registry": engine_job_registry, + "engine_active_batch_state": self.engine_decode_runtime_owner.active_batch_summary(), + "engine_prepare_state": engine_prepare_state, + "engine_finalize_state": engine_finalize_state, + "engine_dispatcher_state": engine_dispatcher_state, + "engine_drained": bool(engine_drained), + "request_registry": request_registry, + "stage_summary": self._build_stage_summary(request_registry, scheduler_state), + } + + def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None: + if not self.scheduler_worker.wait_until_idle(timeout_sec=timeout_sec): + raise TimeoutError("scheduler worker did not drain before model reload") + + def set_refer_audio(self, refer_audio_path: str | None) -> dict: + if refer_audio_path in [None, ""]: + state = self.reference_registry.clear() + return {"message": "success", "default_ref_audio_path": state.ref_audio_path} + if not os.path.exists(str(refer_audio_path)): + raise FileNotFoundError(f"{refer_audio_path} not exists") + with self.management_lock: + with self.direct_tts_lock: + self.tts.set_ref_audio(str(refer_audio_path)) + state = self.reference_registry.set_default(str(refer_audio_path)) + return {"message": "success", "default_ref_audio_path": state.ref_audio_path} + + def set_gpt_weights(self, weights_path: str) -> dict: + if weights_path in ["", None]: + raise ValueError("gpt weight path is required") + with self.management_lock: + self._wait_for_safe_reload() + with self.direct_tts_lock: + self.tts.init_t2s_weights(weights_path) + self.tts.refresh_runtime_components() + state = self.model_registry.mark_t2s_reload(str(weights_path)) + return {"message": "success", "t2s_generation": state.t2s_generation, "generation": state.generation} + + def set_sovits_weights(self, weights_path: str) -> dict: + if weights_path in ["", None]: + raise ValueError("sovits weight path is required") + with self.management_lock: + self._wait_for_safe_reload() + with self.direct_tts_lock: + self.tts.init_vits_weights(weights_path) + self.tts.refresh_runtime_components() + state = self.model_registry.mark_vits_reload(str(weights_path)) + return {"message": "success", "vits_generation": state.vits_generation, "generation": state.generation} + + def handle_control(self, command: str) -> None: + if command == "restart": + if self.control_callbacks.restart is None: + os.execl(sys.executable, sys.executable, *sys.argv) + self.control_callbacks.restart() + return + if command == "exit": + if self.control_callbacks.exit is None: + os.kill(os.getpid(), signal.SIGTERM) + return + self.control_callbacks.exit() + return + raise ValueError(f"unsupported command: {command}") diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py new file mode 100644 index 00000000..65f0befe --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py @@ -0,0 +1,420 @@ +from __future__ import annotations + +import asyncio +import time +from typing import Any, Callable, Dict, List, Optional + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import ( + EngineDecodeRuntimeOwner, + EngineDispatchTask, + EngineGpuPrepareTask, + EngineStatus, + EngineTaskQueueOwner, + SchedulerFinalizeTask, + SchedulerPendingJob, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker + + +class EngineStageCoordinator: + def __init__( + self, + *, + tts: TTS, + scheduler_worker: UnifiedSchedulerWorker, + prepare_queue_owner: EngineTaskQueueOwner, + finalize_queue_owner: EngineTaskQueueOwner, + dispatch_queue_owner: EngineTaskQueueOwner, + decode_runtime_owner: EngineDecodeRuntimeOwner, + update_request_state: Callable[[str, str, Optional[Dict[str, Any]]], None], + merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None], + fail_request_state: Callable[[str, str], None], + get_engine_job: Callable[[str], SchedulerPendingJob | None], + register_engine_job: Callable[[SchedulerPendingJob], None], + fail_engine_jobs: Callable[[List[str], str], None], + complete_engine_job: Callable[..., None], + add_engine_prefill_time: Callable[[List[SchedulerPendingJob], float], None], + add_engine_merge_time: Callable[[List[str], float], None], + add_engine_decode_time: Callable[[List[str], float], None], + enqueue_engine_finished_items: Callable[[List[T2SFinishedItem]], None], + snapshot_engine_dispatch_state: Callable[[], Dict[str, Any]], + snapshot_engine_decode_runtime_state: Callable[[], Dict[str, Any]], + ) -> None: + self.tts = tts + self.scheduler_worker = scheduler_worker + self.prepare_queue_owner = prepare_queue_owner + self.finalize_queue_owner = finalize_queue_owner + self.dispatch_queue_owner = dispatch_queue_owner + self.decode_runtime_owner = decode_runtime_owner + self.update_request_state = update_request_state + self.merge_request_state_profile = merge_request_state_profile + self.fail_request_state = fail_request_state + self.get_engine_job = get_engine_job + self.register_engine_job = register_engine_job + self.fail_engine_jobs = fail_engine_jobs + self.complete_engine_job = complete_engine_job + self.add_engine_prefill_time = add_engine_prefill_time + self.add_engine_merge_time = add_engine_merge_time + self.add_engine_decode_time = add_engine_decode_time + self.enqueue_engine_finished_items = enqueue_engine_finished_items + self.snapshot_engine_dispatch_state = snapshot_engine_dispatch_state + self.snapshot_engine_decode_runtime_state = snapshot_engine_decode_runtime_state + self._notify_arbiter: Callable[[], None] | None = None + self._select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]] | None = None + self._mark_arbiter_tick: Callable[[str, str, bool], None] | None = None + self._wait_arbiter: Callable[[], None] | None = None + + def bind_arbiter( + self, + *, + notify_arbiter: Callable[[], None], + select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]], + mark_arbiter_tick: Callable[[str, str, bool], None], + wait_arbiter: Callable[[], None], + ) -> None: + self._notify_arbiter = notify_arbiter + self._select_stage = select_stage + self._mark_arbiter_tick = mark_arbiter_tick + self._wait_arbiter = wait_arbiter + + def notify_arbiter(self) -> None: + if self._notify_arbiter is not None: + self._notify_arbiter() + + @staticmethod + def _resolve_dispatch_error_future(future: asyncio.Future, error: Exception) -> None: + if future.done(): + return + future.set_exception(error) + + def _notify_dispatch_error(self, task: EngineDispatchTask, error: Exception) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) + except RuntimeError: + pass + + @staticmethod + def _resolve_prepare_future( + future: asyncio.Future, + payload: tuple[T2SRequestState, float, float], + ) -> None: + if future.done(): + return + future.set_result(payload) + + def _notify_prepare_error(self, task: EngineGpuPrepareTask, error: Exception) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) + except RuntimeError: + pass + + def _notify_prepare_result( + self, + task: EngineGpuPrepareTask, + payload: tuple[T2SRequestState, float, float], + ) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_prepare_future, task.done_future, payload) + except RuntimeError: + pass + + async def prepare_state_via_engine_gpu_queue( + self, + *, + spec, + prepare_submit_at: float, + engine_request_id: str | None, + ) -> tuple[T2SRequestState, float, float]: + cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + if engine_request_id not in [None, ""]: + self.update_request_state( + str(engine_request_id), + EngineStatus.GPU_PREPARING, + { + "prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms), + "prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms), + "text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms), + "text_cpu_run_ms": float(cpu_stage.target_cpu_profiled.run_ms), + }, + ) + loop = asyncio.get_running_loop() + done_future = loop.create_future() + task = EngineGpuPrepareTask( + request_id=spec.request_id, + cpu_stage=cpu_stage, + done_loop=loop, + done_future=done_future, + engine_request_id=engine_request_id or spec.request_id, + enqueue_time=time.perf_counter(), + ) + self.prepare_queue_owner.enqueue(task) + self.notify_arbiter() + state, prepare_exec_started_at, prepare_exec_finished_at = await done_future + return state, prepare_exec_started_at, prepare_exec_finished_at + + def enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: + if not tasks: + return + for task in tasks: + job = self.get_engine_job(task.request_id) + if job is not None: + self.update_request_state( + job.engine_request_id, + EngineStatus.READY_FOR_FINALIZE, + { + "finish_reason": task.item.finish_reason, + "semantic_len": int(task.item.semantic_tokens.shape[0]), + "finish_idx": int(task.item.finish_idx), + }, + ) + self.finalize_queue_owner.enqueue_many(tasks) + self.notify_arbiter() + + def take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: + finalize_policy = self.scheduler_worker.get_finalize_batch_policy() + return self.finalize_queue_owner.take_finalize_batch( + finalize_mode=str(finalize_policy.get("finalize_mode", "async")), + batch_max_items=int(finalize_policy.get("finalize_batch_max_items", 1)), + batch_wait_s=float(finalize_policy.get("finalize_batch_wait_s", 0.0)), + use_vocoder=bool(self.tts.configs.use_vocoder), + ) + + async def enqueue_prepared_state_for_dispatch( + self, + *, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None, + done_future: asyncio.Future | None, + engine_request_id: str | None, + timeout_sec: float | None, + ) -> EngineDispatchTask: + task = EngineDispatchTask( + request_id=state.request_id, + state=state, + speed_factor=float(speed_factor), + sample_steps=int(sample_steps), + media_type=media_type, + prepare_wall_ms=float(prepare_wall_ms), + prepare_profile_total_ms=float(prepare_profile_total_ms), + done_loop=done_loop, + done_future=done_future, + engine_request_id=engine_request_id or state.request_id, + timeout_sec=timeout_sec, + enqueue_time=time.perf_counter(), + ) + self.dispatch_queue_owner.enqueue(task) + self.notify_arbiter() + self.merge_request_state_profile( + task.engine_request_id or task.request_id, + { + "engine_dispatch_queue_depth_on_enqueue": int( + self.snapshot_engine_dispatch_state()["waiting_count"] + ), + }, + ) + return task + + def peek_queue_age_ms(self, queue_name: str) -> float: + if queue_name == "prepare": + return self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time") + if queue_name == "finalize": + return self.finalize_queue_owner.peek_oldest_age_ms("enqueued_time") + if queue_name == "decode_runtime_pending": + return self.decode_runtime_owner.pending_age_ms() + return self.dispatch_queue_owner.peek_oldest_age_ms("enqueue_time") + + def has_pending_work(self) -> bool: + if self.scheduler_worker.is_engine_decode_control_enabled(): + if self.decode_runtime_owner.has_pending_jobs(): + return True + if self.scheduler_worker.is_engine_decode_control_enabled() and self.snapshot_engine_decode_runtime_state().get( + "active_request_count", 0 + ) > 0: + return True + if self.prepare_queue_owner.has_items(): + return True + if self.finalize_queue_owner.has_items(): + return True + return self.dispatch_queue_owner.has_items() + + def run_engine_prepare_once(self) -> bool: + task = self.prepare_queue_owner.pop_left() + if task is None: + return False + queue_wait_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0) + try: + state, prepare_exec_started_at, prepare_exec_finished_at = asyncio.run( + self.scheduler_worker.prepare_gpu_stage_profiled_async(task.cpu_stage) + ) + state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms) + if task.engine_request_id not in [None, ""]: + self.merge_request_state_profile( + str(task.engine_request_id), + {"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms)}, + ) + self.prepare_queue_owner.mark_completed(1) + self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at)) + return True + except Exception as exc: + task.error = str(exc) + self.fail_request_state(task.engine_request_id or task.request_id, str(exc)) + self._notify_prepare_error(task, exc) + return True + + def run_engine_finalize_once(self) -> bool: + tasks = self.take_engine_finalize_batch_nonblocking() + if not tasks: + return False + self.scheduler_worker.begin_finalize_execution(len(tasks)) + try: + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + for task in tasks: + job = self.get_engine_job(task.request_id) + if job is None: + continue + jobs_and_items.append((job, task.item)) + if not jobs_and_items: + return False + now = time.perf_counter() + for task in tasks: + job = self.get_engine_job(task.request_id) + if job is not None: + job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0) + for job, item in jobs_and_items: + self.update_request_state( + job.engine_request_id, + EngineStatus.FINALIZING, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + }, + ) + synth_ms, batch_results = self.scheduler_worker.synthesize_finalize_jobs(jobs_and_items) + for job, _ in jobs_and_items: + job.synth_ms += float(synth_ms) + for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): + self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) + except Exception as exc: + self.fail_engine_jobs([task.request_id for task in tasks], str(exc)) + finally: + self.scheduler_worker.end_finalize_execution(len(tasks)) + self.finalize_queue_owner.mark_completed(len(tasks), notify=True) + return True + + def run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: + if not bool(policy_snapshot.get("allowed", True)): + return False + dispatch_task = self.dispatch_queue_owner.pop_left() + if dispatch_task is None: + return False + dispatched_at = time.perf_counter() + dispatch_wait_ms = max(0.0, (dispatched_at - dispatch_task.enqueue_time) * 1000.0) + dispatch_task.engine_policy_wait_ms = float(dispatch_wait_ms) + dispatch_task.engine_dispatch_wait_ms = float(dispatch_wait_ms) + dispatch_task.engine_policy_snapshot = dict(policy_snapshot) + try: + worker_job = self.scheduler_worker.submit( + state=dispatch_task.state, + speed_factor=dispatch_task.speed_factor, + sample_steps=dispatch_task.sample_steps, + media_type=dispatch_task.media_type, + prepare_wall_ms=dispatch_task.prepare_wall_ms, + prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms, + done_loop=dispatch_task.done_loop, + done_future=dispatch_task.done_future, + engine_request_id=dispatch_task.engine_request_id, + timeout_sec=dispatch_task.timeout_sec, + skip_capacity_wait=True, + admission_wait_ms_override=0.0, + admission_snapshot_override=dict(worker_state), + engine_policy_wait_ms=dispatch_task.engine_policy_wait_ms, + engine_dispatch_wait_ms=dispatch_task.engine_dispatch_wait_ms, + enqueue_pending=not self.scheduler_worker.is_engine_decode_control_enabled(), + ) + dispatch_task.worker_job = worker_job + self.register_engine_job(worker_job) + if self.scheduler_worker.is_engine_decode_control_enabled(): + self.decode_runtime_owner.enqueue_pending_job(worker_job) + self.notify_arbiter() + self.dispatch_queue_owner.mark_completed(1) + return True + except Exception as exc: + dispatch_task.error = str(exc) + self.fail_request_state(dispatch_task.engine_request_id or dispatch_task.request_id, str(exc)) + self._notify_dispatch_error(dispatch_task, exc) + return True + + def run_engine_decode_runtime_once(self) -> bool: + if not self.scheduler_worker.is_engine_decode_control_enabled(): + return False + runtime_state = self.snapshot_engine_decode_runtime_state() + pending_jobs = self.decode_runtime_owner.take_pending_jobs_nonblocking( + wait_for_batch=int(runtime_state.get("active_request_count", 0)) <= 0 + ) + result = self.scheduler_worker.execute_decode_cycle( + pending_jobs=pending_jobs, + active_batch=self.decode_runtime_owner.get_active_batch(), + external_bookkeeping=True, + ) + prefill_phase = dict(result.get("prefill_phase") or {}) + if prefill_phase.get("error"): + self.fail_engine_jobs(list(prefill_phase.get("error_request_ids") or []), str(prefill_phase.get("error"))) + else: + prefill_jobs = list(prefill_phase.get("pending_jobs") or []) + self.add_engine_prefill_time(prefill_jobs, float(prefill_phase.get("prefill_elapsed_s", 0.0))) + self.add_engine_merge_time( + [] if result.get("active_batch") is None else list(result["active_batch"].request_ids), + float(prefill_phase.get("merge_elapsed_s", 0.0)), + ) + self.enqueue_engine_finished_items(list(prefill_phase.get("finished_items") or [])) + decode_phase = dict(result.get("decode_phase") or {}) + if decode_phase.get("error"): + self.fail_engine_jobs(list(decode_phase.get("error_request_ids") or []), str(decode_phase.get("error"))) + else: + self.add_engine_decode_time( + list(decode_phase.get("request_ids") or []), + float(decode_phase.get("decode_elapsed_s", 0.0)), + ) + self.enqueue_engine_finished_items(list(decode_phase.get("finished_items") or [])) + self.decode_runtime_owner.set_active_batch(result.get("active_batch")) + if result.get("executed", False): + self.decode_runtime_owner.refresh_state("engine_decode_cycle") + return bool(result.get("executed", False)) + + def run_engine_arbiter_loop(self) -> None: + if self._select_stage is None or self._mark_arbiter_tick is None or self._wait_arbiter is None: + raise RuntimeError("arbiter callbacks are not bound") + while True: + if not self.has_pending_work(): + self._mark_arbiter_tick("idle", "no_pending_work", True) + self._wait_arbiter() + continue + stage, reason, policy_snapshot, worker_state = self._select_stage() + policy_allowed = bool(policy_snapshot.get("allowed", True)) + executed = False + if stage == "prepare": + executed = self.run_engine_prepare_once() + elif stage == "finalize": + executed = self.run_engine_finalize_once() + elif stage == "decode_dispatch": + executed = self.run_engine_dispatch_once(policy_snapshot, worker_state) + elif stage == "decode_runtime": + executed = self.run_engine_decode_runtime_once() + if not executed: + self._mark_arbiter_tick("idle", f"{stage}_not_ready", policy_allowed) + self._wait_arbiter() + continue + self._mark_arbiter_tick(stage, reason, policy_allowed) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py new file mode 100644 index 00000000..04d9090f --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py @@ -0,0 +1,1510 @@ +from __future__ import annotations + +import asyncio +import os +import threading +import time +from collections import deque +from typing import Any, Callable, Deque, Dict, List, Optional + +import numpy as np +import torch + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator, PreparedCpuStage +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState, decode_one_step, merge_active_batches, run_prefill_active_batch, run_scheduler_continuous +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, RuntimeStateCallbacks, SchedulerFinalizeTask, SchedulerJobRegistry, SchedulerPendingJob + + +class WorkerPrepareExecutor: + def __init__( + self, + tts: TTS, + on_state_change: Callable[[], None] | None = None, + ) -> None: + self.coordinator = PrepareCoordinator(tts) + self.on_state_change = on_state_change + + def _notify_state_change(self) -> None: + if self.on_state_change is None: + return + try: + self.on_state_change() + except Exception: + pass + + def snapshot(self) -> Dict[str, int]: + return dict(self.coordinator.snapshot()) + + def get_max_inflight(self) -> int: + return int(self.coordinator.snapshot().get("max_inflight", 0)) + + def is_idle(self) -> bool: + return int(self.coordinator.snapshot().get("inflight", 0)) <= 0 + + async def prepare_state_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> tuple[T2SRequestState, float, float]: + try: + return await self.coordinator.prepare_state_profiled_async(spec, prepare_submit_at) + finally: + self._notify_state_change() + + async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: + results = await asyncio.gather( + *[self.prepare_state_profiled_async(spec, time.perf_counter()) for spec in specs] + ) + return [state for state, _, _ in results] + + async def prepare_cpu_stage_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> PreparedCpuStage: + try: + return await self.coordinator.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + finally: + self._notify_state_change() + + async def prepare_gpu_stage_profiled_async( + self, + cpu_stage: PreparedCpuStage, + ) -> tuple[T2SRequestState, float, float]: + try: + return await self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) + finally: + self._notify_state_change() + + +class WorkerFinalizeExecutor: + def __init__( + self, + tts: TTS, + on_state_change: Callable[[], None] | None = None, + external_submit: Callable[[List[SchedulerFinalizeTask]], None] | None = None, + ) -> None: + self.tts = tts + self.on_state_change = on_state_change + self.external_submit = external_submit + self.condition = threading.Condition() + self.pending_tasks: Deque[SchedulerFinalizeTask] = deque() + self.pending_peak = 0 + self.inflight = 0 + self.inflight_peak = 0 + self.worker_count = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1))) + self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower() + self.batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16))) + self.batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0) + + def _notify_state_change(self) -> None: + if self.on_state_change is None: + return + try: + self.on_state_change() + except Exception: + pass + + def get_worker_count(self) -> int: + return int(self.worker_count) + + def get_batch_policy(self) -> Dict[str, Any]: + return { + "finalize_mode": str(self.finalize_mode), + "finalize_batch_max_items": int(self.batch_max_items), + "finalize_batch_wait_s": float(self.batch_wait_s), + } + + def get_pending_count(self) -> int: + with self.condition: + return int(len(self.pending_tasks)) + + def snapshot(self) -> Dict[str, Any]: + with self.condition: + return { + "finalize_pending": int(len(self.pending_tasks)), + "finalize_pending_peak": int(self.pending_peak), + "finalize_inflight": int(self.inflight), + "finalize_inflight_peak": int(self.inflight_peak), + "finalize_workers": int(self.worker_count), + "finalize_mode": str(self.finalize_mode), + "finalize_batch_max_items": int(self.batch_max_items), + "finalize_batch_wait_ms": float(self.batch_wait_s * 1000.0), + } + + def is_idle(self) -> bool: + with self.condition: + return self.inflight <= 0 and not self.pending_tasks + + def enqueue_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None: + if not tasks: + return + if self.external_submit is not None: + self.external_submit(tasks) + self._notify_state_change() + return + with self.condition: + for task in tasks: + self.pending_tasks.append(task) + self.pending_peak = max(self.pending_peak, len(self.pending_tasks)) + self.condition.notify_all() + self._notify_state_change() + + def begin_execution(self, task_count: int) -> None: + if task_count <= 0: + return + with self.condition: + self.inflight += int(task_count) + self.inflight_peak = max(self.inflight_peak, self.inflight) + self.condition.notify_all() + self._notify_state_change() + + def end_execution(self, task_count: int) -> None: + with self.condition: + self.inflight = max(0, self.inflight - int(task_count)) + self.condition.notify_all() + self._notify_state_change() + + def take_task_batch_blocking(self) -> List[SchedulerFinalizeTask]: + with self.condition: + while not self.pending_tasks: + self.condition.wait() + selected_tasks = [self.pending_tasks.popleft()] + if self.finalize_mode == "sync" or self.tts.configs.use_vocoder: + self.inflight += len(selected_tasks) + self.inflight_peak = max(self.inflight_peak, self.inflight) + self._notify_state_change() + return selected_tasks + batch_deadline = time.perf_counter() + self.batch_wait_s + while len(selected_tasks) < self.batch_max_items: + if not self.pending_tasks: + remaining = batch_deadline - time.perf_counter() + if remaining <= 0: + break + self.condition.wait(timeout=remaining) + continue + first_task = selected_tasks[0] + matched_index = None + for index, task in enumerate(self.pending_tasks): + if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: + matched_index = index + break + if matched_index is not None: + selected_tasks.append(self.pending_tasks[matched_index]) + del self.pending_tasks[matched_index] + continue + remaining = batch_deadline - time.perf_counter() + if remaining <= 0: + break + self.condition.wait(timeout=remaining) + self.inflight += len(selected_tasks) + self.inflight_peak = max(self.inflight_peak, self.inflight) + self._notify_state_change() + return selected_tasks + + def _sync_device(self) -> None: + try: + device_str = str(self.tts.configs.device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(self.tts.configs.device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]: + audio_fragment = self.tts.synthesize_audio_request_local( + semantic_tokens=item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0), + phones=job.state.phones.detach().clone().unsqueeze(0), + prompt_semantic=job.state.prompt_semantic.detach().clone(), + prompt_phones=job.state.prompt_phones.detach().clone(), + refer_spec=( + job.state.refer_spec[0].detach().clone(), + None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), + ), + raw_audio=job.state.raw_audio.detach().clone(), + raw_sr=int(job.state.raw_sr), + speed=float(job.speed_factor), + sample_steps=int(job.sample_steps), + ) + output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] + return self.tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=float(job.speed_factor), + split_bucket=False, + fragment_interval=0.0, + super_sampling=False, + ) + + def _synthesize_finished_audio_batch( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> List[tuple[int, np.ndarray]]: + semantic_tokens_list = [item.semantic_tokens.detach().clone() for _, item in jobs_and_items] + phones_list = [job.state.phones.detach().clone() for job, _ in jobs_and_items] + refer_specs = [] + speeds = [] + sample_steps_list = [] + for job, _ in jobs_and_items: + refer_specs.append( + ( + job.state.refer_spec[0].detach().clone(), + None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), + ) + ) + speeds.append(float(job.speed_factor)) + sample_steps_list.append(int(job.sample_steps)) + audio_fragments = self.tts.synthesize_audio_requests_local_batched( + semantic_tokens_list=semantic_tokens_list, + phones_list=phones_list, + refer_specs=refer_specs, + speeds=speeds, + sample_steps_list=sample_steps_list, + ) + output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] + results: List[tuple[int, np.ndarray]] = [] + for (job, _), audio_fragment in zip(jobs_and_items, audio_fragments): + results.append( + self.tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=float(job.speed_factor), + split_bucket=False, + fragment_interval=0.0, + super_sampling=False, + ) + ) + return results + + def synthesize_finalize_jobs( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> tuple[float, List[tuple[int, np.ndarray]]]: + if not jobs_and_items: + return 0.0, [] + self._sync_device() + synth_start = time.perf_counter() + if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder: + job, item = jobs_and_items[0] + batch_results = [self._synthesize_finished_audio(job, item)] + else: + batch_results = self._synthesize_finished_audio_batch(jobs_and_items) + self._sync_device() + synth_ms = (time.perf_counter() - synth_start) * 1000.0 + return float(synth_ms), batch_results + + +class WorkerCompletionBridge: + def __init__(self, runtime_callbacks: RuntimeStateCallbacks | None = None) -> None: + self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() + + @staticmethod + def _resolve_done_future(job: SchedulerPendingJob) -> None: + future = job.done_future + if future is None or future.done(): + return + future.set_result(job) + + def notify_done_future(self, job: SchedulerPendingJob) -> None: + if job.done_loop is None or job.done_future is None: + return + try: + job.done_loop.call_soon_threadsafe(self._resolve_done_future, job) + except RuntimeError: + pass + + def runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: + if request_id is None or self.runtime_callbacks.complete is None: + return + self.runtime_callbacks.complete(request_id, extra) + + def runtime_fail(self, request_id: str | None, error: str) -> None: + if request_id is None or self.runtime_callbacks.fail is None: + return + self.runtime_callbacks.fail(request_id, error) + + @staticmethod + def build_completed_job_result( + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + audio_data: np.ndarray, + finished_at: float | None = None, + ) -> Dict[str, Any]: + finished_at = float(time.perf_counter() if finished_at is None else finished_at) + queue_wait_ms = 0.0 + if job.first_schedule_time is not None: + queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) + worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) + worker_residual_ms = max( + 0.0, + worker_total_ms + - queue_wait_ms + - job.prefill_ms + - job.merge_ms + - job.decode_ms + - job.finalize_wait_ms + - job.synth_ms, + ) + worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms) + job.sample_rate = int(sample_rate) + job.audio_data = audio_data + job.result_ready_time = finished_at + prepare_profile = dict(job.state.prepare_profile) + result = { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + "decode_admission_wait_ms": float(job.admission_wait_ms), + "engine_policy_wait_ms": float(job.engine_policy_wait_ms), + "engine_dispatch_wait_ms": float(job.engine_dispatch_wait_ms), + "prepare_ms": job.prepare_wall_ms, + "prepare_wall_ms": job.prepare_wall_ms, + "prepare_profile_total_ms": job.prepare_profile_total_ms, + "prepare_profile": prepare_profile, + "queue_wait_ms": queue_wait_ms, + "prefill_ms": job.prefill_ms, + "merge_ms": job.merge_ms, + "decode_ms": job.decode_ms, + "finalize_wait_ms": job.finalize_wait_ms, + "synth_ms": job.synth_ms, + "worker_residual_ms": worker_residual_ms, + "worker_other_ms": worker_other_ms, + "worker_total_ms": worker_total_ms, + "decode_steps": int(job.decode_steps), + "sample_rate": int(sample_rate), + "media_type": job.media_type, + } + job.result = result + return result + + @staticmethod + def build_runtime_complete_payload( + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + ) -> Dict[str, Any]: + return { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "sample_rate": int(sample_rate), + "worker_profile": dict(job.result or {}), + } + + def complete_job( + self, + job: SchedulerPendingJob, + *, + runtime_request_id: str | None, + runtime_extra: Optional[Dict[str, Any]] = None, + remove_job: Callable[[], None] | None = None, + on_job_finished: Callable[[], None] | None = None, + notify_waiters: Callable[[], None] | None = None, + ) -> None: + job.done_event.set() + self.notify_done_future(job) + if remove_job is not None: + remove_job() + if on_job_finished is not None: + on_job_finished() + if notify_waiters is not None: + notify_waiters() + self.runtime_complete(runtime_request_id, runtime_extra) + + def fail_job( + self, + job: SchedulerPendingJob, + *, + error: str, + remove_job: Callable[[], None] | None = None, + on_job_finished: Callable[[], None] | None = None, + notify_waiters: Callable[[], None] | None = None, + ) -> None: + job.error = str(error) + job.done_event.set() + self.notify_done_future(job) + if remove_job is not None: + remove_job() + if on_job_finished is not None: + on_job_finished() + if notify_waiters is not None: + notify_waiters() + self.runtime_fail(job.engine_request_id, str(error)) + + def complete_finalize_task( + self, + *, + condition: threading.Condition, + job_registry: SchedulerJobRegistry, + job: SchedulerPendingJob, + item: T2SFinishedItem, + sample_rate: int, + audio_data: np.ndarray, + ) -> None: + runtime_extra: Optional[Dict[str, Any]] = None + with condition: + if job_registry.get(item.request_id) is not job: + return + self.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data) + runtime_extra = self.build_runtime_complete_payload(job, item, sample_rate=sample_rate) + self.complete_job( + job, + runtime_request_id=job.engine_request_id, + runtime_extra=runtime_extra, + on_job_finished=lambda: job_registry.mark_finished_and_remove(item.request_id), + notify_waiters=condition.notify_all, + ) + + def fail_jobs( + self, + *, + condition: threading.Condition, + job_registry: SchedulerJobRegistry, + request_ids: List[str], + error: str, + ) -> None: + if not request_ids: + return + with condition: + for request_id in request_ids: + job = job_registry.get(request_id) + if job is None: + continue + self.fail_job( + job, + error=error, + on_job_finished=lambda rid=request_id: job_registry.mark_finished_and_remove(rid), + ) + condition.notify_all() + + +class WorkerDecodeExecutor: + def __init__(self, tts: TTS, max_steps: int) -> None: + self.tts = tts + self.max_steps = int(max_steps) + + def _sync_device(self) -> None: + try: + device_str = str(self.tts.configs.device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(self.tts.configs.device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + def execute_prefill_merge( + self, + *, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], + add_prefill_time: Callable[[List[str], float], None] | None, + add_merge_time: Callable[[List[str], float], None] | None, + enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, + finalize_error: Callable[[List[str], str], None] | None, + ) -> Dict[str, Any]: + if not pending_jobs: + return { + "executed": False, + "active_batch": active_batch, + "pending_jobs": [], + "prefill_elapsed_s": 0.0, + "merge_elapsed_s": 0.0, + "finished_items": [], + "error": None, + "error_request_ids": [], + } + admitted_finished: List[T2SFinishedItem] = [] + prefill_elapsed_s = 0.0 + merge_elapsed_s = 0.0 + error: str | None = None + error_request_ids: List[str] = [] + try: + self._sync_device() + prefill_start = time.perf_counter() + mark_prefill_started(pending_jobs, prefill_start) + admitted_active_batch, admitted_finished = run_prefill_active_batch( + self.tts.t2s_model.model, + [job.state for job in pending_jobs], + max_steps=self.max_steps, + ) + self._sync_device() + prefill_elapsed_s = time.perf_counter() - prefill_start + if add_prefill_time is not None: + add_prefill_time([job.request_id for job in pending_jobs], prefill_elapsed_s) + if enqueue_finished is not None: + enqueue_finished(admitted_finished) + merge_start = time.perf_counter() + active_batch = merge_active_batches( + self.tts.t2s_model.model, + active_batch, + admitted_active_batch, + ) + merge_elapsed_s = time.perf_counter() - merge_start + if add_merge_time is not None: + add_merge_time( + [] if active_batch is None else list(active_batch.request_ids), + merge_elapsed_s, + ) + except Exception as exc: + error = str(exc) + error_request_ids = [job.request_id for job in pending_jobs] + if finalize_error is not None: + finalize_error(error_request_ids, error) + return { + "executed": True, + "active_batch": active_batch, + "pending_jobs": list(pending_jobs), + "prefill_elapsed_s": float(prefill_elapsed_s), + "merge_elapsed_s": float(merge_elapsed_s), + "finished_items": list(admitted_finished), + "error": error, + "error_request_ids": error_request_ids, + } + + def execute_decode_step( + self, + *, + active_batch: Optional[T2SActiveBatch], + add_decode_time: Callable[[List[str], float], None] | None, + enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, + finalize_error: Callable[[List[str], str], None] | None, + ) -> Dict[str, Any]: + if active_batch is None: + return { + "executed": False, + "active_batch": None, + "request_ids": [], + "decode_elapsed_s": 0.0, + "finished_items": [], + "error": None, + "error_request_ids": [], + } + active_request_ids: List[str] = [] + step_finished: List[T2SFinishedItem] = [] + decode_elapsed_s = 0.0 + error: str | None = None + error_request_ids: List[str] = [] + try: + active_request_ids = [state.request_id for state in active_batch.states] + self._sync_device() + decode_start = time.perf_counter() + active_batch, step_finished = decode_one_step( + self.tts.t2s_model.model, + active_batch, + max_steps=self.max_steps, + ) + self._sync_device() + decode_elapsed_s = time.perf_counter() - decode_start + if add_decode_time is not None: + add_decode_time(active_request_ids, decode_elapsed_s) + if enqueue_finished is not None: + enqueue_finished(step_finished) + except Exception as exc: + error = str(exc) + error_request_ids = list(active_request_ids) + if finalize_error is not None: + finalize_error(error_request_ids, error) + active_batch = None + return { + "executed": True, + "active_batch": active_batch, + "request_ids": active_request_ids, + "decode_elapsed_s": float(decode_elapsed_s), + "finished_items": list(step_finished), + "error": error, + "error_request_ids": error_request_ids, + } + + def execute_decode_cycle( + self, + *, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], + add_prefill_time: Callable[[List[str], float], None] | None, + add_merge_time: Callable[[List[str], float], None] | None, + add_decode_time: Callable[[List[str], float], None] | None, + enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, + finalize_error: Callable[[List[str], str], None] | None, + ) -> Dict[str, Any]: + result = { + "executed": False, + "prefill_merge_executed": False, + "decode_step_executed": False, + "active_batch": active_batch, + "prefill_phase": {}, + "decode_phase": {}, + } + prefill_phase = self.execute_prefill_merge( + pending_jobs=list(pending_jobs), + active_batch=result["active_batch"], + mark_prefill_started=mark_prefill_started, + add_prefill_time=add_prefill_time, + add_merge_time=add_merge_time, + enqueue_finished=enqueue_finished, + finalize_error=finalize_error, + ) + prefill_executed = bool(prefill_phase.get("executed", False)) + result["prefill_phase"] = prefill_phase + result["active_batch"] = prefill_phase.get("active_batch") + if prefill_executed: + result["executed"] = True + result["prefill_merge_executed"] = True + decode_phase = self.execute_decode_step( + active_batch=result["active_batch"], + add_decode_time=add_decode_time, + enqueue_finished=enqueue_finished, + finalize_error=finalize_error, + ) + decode_executed = bool(decode_phase.get("executed", False)) + result["decode_phase"] = decode_phase + result["active_batch"] = decode_phase.get("active_batch") + if decode_executed: + result["executed"] = True + result["decode_step_executed"] = True + return result + + +class WorkerDecodeLegacyShell: + def __init__(self, condition: threading.Condition, micro_batch_wait_s: float) -> None: + self.condition = condition + self.micro_batch_wait_s = float(micro_batch_wait_s) + self.pending_jobs: List[SchedulerPendingJob] = [] + self.active_batch: T2SActiveBatch | None = None + + @staticmethod + def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any] | None: + if active_batch is None: + return None + return { + "request_count": int(len(active_batch.request_ids)), + "request_ids": list(active_batch.request_ids), + "prefill_done": bool(active_batch.prefill_done), + "decode_step_index_max": ( + int(active_batch.step_indices.max().item()) + if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0 + else 0 + ), + } + + def current_backlog_locked(self) -> int: + running_requests = 0 if self.active_batch is None else len(self.active_batch.request_ids) + return int(len(self.pending_jobs) + running_requests) + + def enqueue_pending_job_locked(self, job: SchedulerPendingJob) -> None: + self.pending_jobs.append(job) + + def snapshot_locked(self) -> Dict[str, Any]: + active_batch_summary = self._summarize_active_batch(self.active_batch) + executor_local_pending_jobs = int(len(self.pending_jobs)) + executor_local_running_requests = 0 if self.active_batch is None else int(len(self.active_batch.request_ids)) + executor_local_has_work = bool(self.pending_jobs or self.active_batch is not None) + return { + "executor_local_pending_jobs": executor_local_pending_jobs, + "executor_local_running_requests": executor_local_running_requests, + "executor_local_has_work": executor_local_has_work, + "executor_local_active_batch": active_batch_summary, + } + + def is_idle_locked(self) -> bool: + return self.active_batch is None and not self.pending_jobs + + def take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs and self.active_batch is None: + self.condition.wait(timeout=self.micro_batch_wait_s) + elif wait_for_batch and self.pending_jobs: + self.condition.wait(timeout=self.micro_batch_wait_s) + if not self.pending_jobs: + return [] + pending = list(self.pending_jobs) + self.pending_jobs.clear() + return pending + + def take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs: + return [] + if wait_for_batch: + oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time) + if (time.perf_counter() - oldest_enqueue_time) < self.micro_batch_wait_s: + return [] + pending = list(self.pending_jobs) + self.pending_jobs.clear() + return pending + + def has_decode_runtime_work(self) -> bool: + with self.condition: + return bool(self.pending_jobs or self.active_batch is not None) + + def build_runtime_summary_locked(self, *, total_cycles: int, prefill_cycles: int, step_cycles: int, last_event: str) -> Dict[str, Any]: + active_request_ids = [] if self.active_batch is None else list(self.active_batch.request_ids) + decode_step_index_max = 0 + prefill_done = False + if self.active_batch is not None: + prefill_done = bool(self.active_batch.prefill_done) + if self.active_batch.step_indices is not None and self.active_batch.step_indices.numel() > 0: + decode_step_index_max = int(self.active_batch.step_indices.max().item()) + return { + "pending_jobs": int(len(self.pending_jobs)), + "active_request_count": int(len(active_request_ids)), + "active_request_ids": active_request_ids[:32], + "prefill_done": bool(prefill_done), + "decode_step_index_max": int(decode_step_index_max), + "total_cycles": int(total_cycles), + "prefill_cycles": int(prefill_cycles), + "step_cycles": int(step_cycles), + "has_work": bool(self.pending_jobs or self.active_batch is not None), + "last_event": str(last_event), + "updated_at": float(time.perf_counter()), + } + + def run_prefill_merge_once_nonblocking( + self, + *, + external_pending_jobs: Optional[List[SchedulerPendingJob]], + external_active_batch: Optional[T2SActiveBatch], + execute_prefill_merge: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], + ) -> Dict[str, Any]: + pending_jobs = ( + list(external_pending_jobs) + if external_pending_jobs is not None + else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) + ) + active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch + result = execute_prefill_merge(pending_jobs, active_batch) + if external_pending_jobs is None: + with self.condition: + self.active_batch = result.get("active_batch") + self.condition.notify_all() + return result + + def run_decode_step_once_nonblocking( + self, + *, + external_active_batch: Optional[T2SActiveBatch], + execute_decode_step: Callable[[Optional[T2SActiveBatch]], Dict[str, Any]], + ) -> Dict[str, Any]: + active_batch = self.active_batch if external_active_batch is None else external_active_batch + result = execute_decode_step(active_batch) + if external_active_batch is None: + with self.condition: + self.active_batch = result.get("active_batch") + self.condition.notify_all() + return result + + def run_decode_cycle_nonblocking( + self, + *, + external_pending_jobs: Optional[List[SchedulerPendingJob]], + external_active_batch: Optional[T2SActiveBatch], + execute_decode_cycle: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], + on_cycle_executed: Callable[[Dict[str, Any]], None] | None, + ) -> Dict[str, Any]: + pending_jobs = ( + list(external_pending_jobs) + if external_pending_jobs is not None + else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) + ) + active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch + result = execute_decode_cycle(pending_jobs, active_batch) + if external_pending_jobs is None: + with self.condition: + self.active_batch = result.get("active_batch") + self.condition.notify_all() + if result.get("executed") and on_cycle_executed is not None: + on_cycle_executed(result) + return result + + def run_loop( + self, + *, + run_decode_cycle_nonblocking: Callable[[], Dict[str, Any]], + ) -> None: + while True: + executed = run_decode_cycle_nonblocking() + if executed.get("executed"): + continue + wait_for_batch = self.active_batch is None + pending_jobs = self.take_pending_snapshot(wait_for_batch=wait_for_batch) + if pending_jobs: + with self.condition: + self.pending_jobs = pending_jobs + self.pending_jobs + self.condition.notify_all() + continue + time.sleep(self.micro_batch_wait_s) + + +class WorkerDecodeRuntimeTracker: + def __init__( + self, + runtime_callbacks: RuntimeStateCallbacks | None = None, + ) -> None: + self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() + self.total_cycles = 0 + self.prefill_cycles = 0 + self.step_cycles = 0 + + def get_counters(self) -> Dict[str, int]: + return { + "total_cycles": int(self.total_cycles), + "prefill_cycles": int(self.prefill_cycles), + "step_cycles": int(self.step_cycles), + } + + def record_cycle(self, result: Dict[str, Any]) -> None: + if not bool(result.get("executed")): + return + self.total_cycles += 1 + if bool(result.get("prefill_merge_executed")): + self.prefill_cycles += 1 + if bool(result.get("decode_step_executed")): + self.step_cycles += 1 + + def build_runtime_summary_locked( + self, + *, + legacy_shell: WorkerDecodeLegacyShell, + last_event: str, + ) -> Dict[str, Any]: + return legacy_shell.build_runtime_summary_locked( + total_cycles=int(self.total_cycles), + prefill_cycles=int(self.prefill_cycles), + step_cycles=int(self.step_cycles), + last_event=str(last_event), + ) + + def notify_runtime_update_locked( + self, + *, + legacy_shell: WorkerDecodeLegacyShell, + last_event: str, + ) -> None: + if self.runtime_callbacks.decode_runtime_update is None: + return + snapshot = self.build_runtime_summary_locked( + legacy_shell=legacy_shell, + last_event=last_event, + ) + self.runtime_callbacks.decode_runtime_update(snapshot) + + +class UnifiedSchedulerWorker: + def __init__( + self, + tts: TTS, + max_steps: int = 1500, + micro_batch_wait_ms: int = 5, + runtime_callbacks: RuntimeStateCallbacks | None = None, + external_finalize_submit: Callable[[List[SchedulerFinalizeTask]], None] | None = None, + ): + self.tts = tts + self.max_steps = int(max_steps) + self.micro_batch_wait_s = float(micro_batch_wait_ms) / 1000.0 + self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() + self.condition = threading.Condition() + self.completion_bridge = WorkerCompletionBridge(self.runtime_callbacks) + self.decode_executor = WorkerDecodeExecutor(tts, max_steps=max_steps) + self.decode_legacy_shell = WorkerDecodeLegacyShell(self.condition, self.micro_batch_wait_s) + self.decode_runtime_tracker = WorkerDecodeRuntimeTracker(self.runtime_callbacks) + self.prepare_executor = WorkerPrepareExecutor(tts, on_state_change=self._notify_worker_state_change) + self.finalize_executor = WorkerFinalizeExecutor( + tts, + on_state_change=self._notify_worker_state_change, + external_submit=external_finalize_submit, + ) + self.decode_backlog_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_DECODE_BACKLOG_MAX", "0"))) + self.finalize_pending_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_FINALIZE_PENDING_MAX", "0"))) + self.engine_decode_control_enabled = ( + str(os.environ.get("GPTSOVITS_ENGINE_DRIVE_DECODE", "0")).strip().lower() in {"1", "true", "yes", "on"} + ) + self.job_registry = SchedulerJobRegistry(self.condition) + self.worker_thread: threading.Thread | None = None + if not self.engine_decode_control_enabled: + self.worker_thread = threading.Thread(target=self._run_loop, name="unified-t2s-scheduler-worker", daemon=True) + self.worker_thread.start() + self.finalize_threads = [] + if external_finalize_submit is None: + self.finalize_threads = [ + threading.Thread( + target=self._run_finalize_loop, + name=f"unified-t2s-finalize-{worker_index}", + daemon=True, + ) + for worker_index in range(self.finalize_executor.get_worker_count()) + ] + for finalize_thread in self.finalize_threads: + finalize_thread.start() + + def _notify_worker_state_change(self) -> None: + with self.condition: + self.condition.notify_all() + + def _current_decode_backlog_locked(self) -> int: + return self.decode_legacy_shell.current_backlog_locked() + + def get_micro_batch_wait_s(self) -> float: + return float(self.micro_batch_wait_s) + + def is_engine_decode_control_enabled(self) -> bool: + return bool(self.engine_decode_control_enabled) + + def get_prepare_max_inflight(self) -> int: + return int(self.prepare_executor.get_max_inflight()) + + def get_capacity_limits(self) -> Dict[str, int]: + return { + "decode_backlog_max": int(self.decode_backlog_max), + "finalize_pending_max": int(self.finalize_pending_max), + } + + def get_finalize_batch_policy(self) -> Dict[str, Any]: + return dict(self.finalize_executor.get_batch_policy()) + + def get_decode_runtime_counters(self) -> Dict[str, int]: + with self.condition: + return self.decode_runtime_tracker.get_counters() + + def _can_accept_submit_locked(self) -> tuple[bool, Dict[str, int]]: + decode_backlog = self._current_decode_backlog_locked() + finalize_pending = int(self.finalize_executor.get_pending_count()) + prepare_inflight = int(self.prepare_executor.snapshot()["inflight"]) + blocked_decode = self.decode_backlog_max > 0 and decode_backlog >= self.decode_backlog_max + blocked_finalize = self.finalize_pending_max > 0 and finalize_pending >= self.finalize_pending_max + return ( + not blocked_decode and not blocked_finalize, + { + "decode_backlog": decode_backlog, + "finalize_pending": finalize_pending, + "prepare_inflight": prepare_inflight, + "decode_backlog_max": int(self.decode_backlog_max), + "finalize_pending_max": int(self.finalize_pending_max), + }, + ) + + def wait_for_submit_capacity_blocking(self, timeout_sec: float | None = None) -> tuple[float, Dict[str, int]]: + start = time.perf_counter() + deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) + last_snapshot: Dict[str, int] = {} + while True: + with self.condition: + allowed, snapshot = self._can_accept_submit_locked() + last_snapshot = snapshot + if allowed: + return max(0.0, (time.perf_counter() - start) * 1000.0), snapshot + if deadline is not None and time.perf_counter() >= deadline: + raise TimeoutError( + "scheduler submit admission timeout " + f"(decode_backlog={snapshot['decode_backlog']}, finalize_pending={snapshot['finalize_pending']})" + ) + self.condition.wait(timeout=self.micro_batch_wait_s) + + def _admission_snapshot_locked(self) -> Dict[str, int]: + _, snapshot = self._can_accept_submit_locked() + return snapshot + + async def submit_async( + self, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None = None, + done_future: asyncio.Future | None = None, + engine_request_id: str | None = None, + timeout_sec: float | None = None, + skip_capacity_wait: bool = False, + admission_wait_ms_override: float | None = None, + admission_snapshot_override: Dict[str, Any] | None = None, + engine_policy_wait_ms: float = 0.0, + engine_dispatch_wait_ms: float = 0.0, + enqueue_pending: bool = True, + ) -> SchedulerPendingJob: + return await asyncio.to_thread( + self.submit, + state, + speed_factor, + sample_steps, + media_type, + prepare_wall_ms, + prepare_profile_total_ms, + done_loop, + done_future, + engine_request_id, + timeout_sec, + skip_capacity_wait, + admission_wait_ms_override, + admission_snapshot_override, + engine_policy_wait_ms, + engine_dispatch_wait_ms, + enqueue_pending, + ) + + def snapshot(self) -> dict: + with self.condition: + prepare_state = self.prepare_executor.snapshot() + finalize_state = self.finalize_executor.snapshot() + shell_state = self.decode_legacy_shell.snapshot_locked() + decode_runtime_counters = self.decode_runtime_tracker.get_counters() + engine_owned_decode_state = bool(self.engine_decode_control_enabled) + active_batch_summary = shell_state.get("executor_local_active_batch") + executor_local_pending_jobs = int(shell_state.get("executor_local_pending_jobs", 0)) + executor_local_running_requests = int(shell_state.get("executor_local_running_requests", 0)) + executor_local_has_work = bool(shell_state.get("executor_local_has_work", False)) + return { + "pending_jobs": 0 if engine_owned_decode_state else executor_local_pending_jobs, + "running_requests": 0 if engine_owned_decode_state else executor_local_running_requests, + "engine_decode_control_enabled": bool(self.engine_decode_control_enabled), + "legacy_state_owner_mode": not engine_owned_decode_state, + "decode_state_owner": "engine" if engine_owned_decode_state else "worker", + "decode_runtime_has_work": False if engine_owned_decode_state else executor_local_has_work, + "executor_local_pending_jobs": executor_local_pending_jobs, + "executor_local_running_requests": executor_local_running_requests, + "executor_local_has_work": executor_local_has_work, + "decode_runtime_total_cycles": int(decode_runtime_counters.get("total_cycles", 0)), + "decode_runtime_prefill_cycles": int(decode_runtime_counters.get("prefill_cycles", 0)), + "decode_runtime_step_cycles": int(decode_runtime_counters.get("step_cycles", 0)), + "prepare_inflight": prepare_state["inflight"], + "prepare_peak_inflight": prepare_state["peak_inflight"], + "prepare_max_inflight": prepare_state.get("max_inflight", 0), + "prepare_state": dict(prepare_state), + **finalize_state, + "decode_backlog_max": self.decode_backlog_max, + "finalize_pending_max": self.finalize_pending_max, + "active_batch": {} if engine_owned_decode_state else active_batch_summary, + "executor_local_active_batch": active_batch_summary if engine_owned_decode_state else None, + "total_submitted": self.job_registry.submitted_count(), + "total_finished": self.job_registry.finished_count(), + "drained": self.is_drained(), + } + + def is_drained(self) -> bool: + with self.condition: + return ( + self.decode_legacy_shell.is_idle_locked() + and self.job_registry.is_empty() + and self.prepare_executor.is_idle() + and self.finalize_executor.is_idle() + ) + + def wait_until_idle(self, timeout_sec: float = 60.0, poll_interval_sec: float = 0.01) -> bool: + deadline = time.perf_counter() + max(0.0, timeout_sec) + while time.perf_counter() < deadline: + if self.is_drained(): + return True + time.sleep(poll_interval_sec) + return self.is_drained() + + def submit( + self, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None = None, + done_future: asyncio.Future | None = None, + engine_request_id: str | None = None, + timeout_sec: float | None = None, + skip_capacity_wait: bool = False, + admission_wait_ms_override: float | None = None, + admission_snapshot_override: Dict[str, Any] | None = None, + engine_policy_wait_ms: float = 0.0, + engine_dispatch_wait_ms: float = 0.0, + enqueue_pending: bool = True, + ) -> SchedulerPendingJob: + if skip_capacity_wait: + with self.condition: + admission_snapshot = ( + dict(admission_snapshot_override) + if admission_snapshot_override is not None + else dict(self._admission_snapshot_locked()) + ) + admission_wait_ms = 0.0 if admission_wait_ms_override is None else float(admission_wait_ms_override) + else: + admission_wait_ms, admission_snapshot = self.wait_for_submit_capacity_blocking(timeout_sec=timeout_sec) + job = SchedulerPendingJob( + request_id=state.request_id, + state=state, + done_event=threading.Event(), + done_loop=done_loop, + done_future=done_future, + enqueue_time=time.perf_counter(), + speed_factor=float(speed_factor), + sample_steps=int(sample_steps), + media_type=media_type, + admission_wait_ms=float(admission_wait_ms), + engine_policy_wait_ms=float(engine_policy_wait_ms), + engine_dispatch_wait_ms=float(engine_dispatch_wait_ms), + prepare_wall_ms=float(prepare_wall_ms), + prepare_profile_total_ms=float(prepare_profile_total_ms), + engine_request_id=engine_request_id or state.request_id, + ) + with self.condition: + self.job_registry.register(job, keep_job=not self.engine_decode_control_enabled) + if enqueue_pending: + self.decode_legacy_shell.enqueue_pending_job_locked(job) + self.condition.notify_all() + if enqueue_pending: + self._notify_decode_runtime_state("submit") + self._runtime_update( + job.engine_request_id, + EngineStatus.QUEUED, + { + "scheduler_request_id": job.request_id, + "decode_admission_wait_ms": float(admission_wait_ms), + "engine_policy_wait_ms": float(engine_policy_wait_ms), + "engine_dispatch_wait_ms": float(engine_dispatch_wait_ms), + "admission_snapshot": dict(admission_snapshot), + }, + ) + return job + + async def prepare_state_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> tuple[T2SRequestState, float, float]: + return await self.prepare_executor.prepare_state_profiled_async(spec, prepare_submit_at) + + async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: + return await self.prepare_executor.prepare_states_batch_async(specs) + + async def prepare_cpu_stage_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> PreparedCpuStage: + return await self.prepare_executor.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + + async def prepare_gpu_stage_profiled_async( + self, + cpu_stage: PreparedCpuStage, + ) -> tuple[T2SRequestState, float, float]: + return await self.prepare_executor.prepare_gpu_stage_profiled_async(cpu_stage) + + def _mark_prefill_started(self, pending_jobs: List[SchedulerPendingJob], started_at: float) -> None: + with self.condition: + for job in pending_jobs: + job.first_schedule_time = float(started_at) + self._runtime_update( + job.engine_request_id, + EngineStatus.GPU_PREPARING, + {"scheduler_request_id": job.request_id, "prefill_started_at": float(started_at)}, + ) + + def _add_prefill_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + job.prefill_ms += delta_ms + + def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + job.merge_ms += delta_ms + + def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + activate_request_ids: List[str] = [] + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + if job.decode_steps == 0: + activate_request_ids.append(job.engine_request_id) + job.decode_ms += delta_ms + job.decode_steps += 1 + for engine_request_id in activate_request_ids: + self._runtime_update(engine_request_id, EngineStatus.ACTIVE_DECODE, None) + + def _add_finalize_wait_ms(self, request_ids: List[str], delta_ms: float) -> None: + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + job.finalize_wait_ms += float(delta_ms) + + def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None: + if not items: + return + enqueued_at = time.perf_counter() + tasks: List[SchedulerFinalizeTask] = [] + with self.condition: + for item in items: + job = self.job_registry.get(item.request_id) + if job is not None: + self._runtime_update( + job.engine_request_id, + EngineStatus.READY_FOR_FINALIZE, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + }, + ) + tasks.append(SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at)) + self.finalize_executor.enqueue_tasks(tasks) + + def begin_finalize_execution(self, task_count: int) -> None: + self.finalize_executor.begin_execution(task_count) + + def end_finalize_execution(self, task_count: int) -> None: + self.finalize_executor.end_execution(task_count) + + def record_external_job_done(self, request_id: str) -> None: + with self.condition: + self.job_registry.mark_finished_and_remove(request_id) + self.condition.notify_all() + + def synthesize_finalize_jobs( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> tuple[float, List[tuple[int, np.ndarray]]]: + return self.finalize_executor.synthesize_finalize_jobs(jobs_and_items) + + def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None: + self.completion_bridge.complete_finalize_task( + condition=self.condition, + job_registry=self.job_registry, + job=job, + item=item, + sample_rate=sample_rate, + audio_data=audio_data, + ) + + def _finalize_error(self, request_ids: List[str], error: str) -> None: + self.completion_bridge.fail_jobs( + condition=self.condition, + job_registry=self.job_registry, + request_ids=request_ids, + error=error, + ) + + @staticmethod + def _resolve_done_future(job: SchedulerPendingJob) -> None: + future = job.done_future + if future is None or future.done(): + return + future.set_result(job) + + def _notify_done_future(self, job: SchedulerPendingJob) -> None: + self.completion_bridge.notify_done_future(job) + + def _runtime_update(self, request_id: str | None, status: str, extra: Optional[Dict[str, Any]] = None) -> None: + if request_id is None or self.runtime_callbacks.update is None: + return + self.runtime_callbacks.update(request_id, status, extra) + + def _runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: + self.completion_bridge.runtime_complete(request_id, extra) + + def _runtime_fail(self, request_id: str | None, error: str) -> None: + self.completion_bridge.runtime_fail(request_id, error) + + def _build_decode_runtime_summary_locked(self, last_event: str) -> Dict[str, Any]: + return self.decode_runtime_tracker.build_runtime_summary_locked( + legacy_shell=self.decode_legacy_shell, + last_event=str(last_event), + ) + + def _notify_decode_runtime_state(self, last_event: str) -> None: + with self.condition: + self.decode_runtime_tracker.notify_runtime_update_locked( + legacy_shell=self.decode_legacy_shell, + last_event=str(last_event), + ) + + def _record_decode_runtime_cycle(self, result: Dict[str, Any]) -> None: + with self.condition: + self.decode_runtime_tracker.record_cycle(result) + + def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.decode_legacy_shell.take_pending_snapshot(wait_for_batch) + + def _take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.decode_legacy_shell.take_pending_snapshot_nonblocking(wait_for_batch) + + def has_decode_runtime_work(self) -> bool: + return self.decode_legacy_shell.has_decode_runtime_work() + + def execute_prefill_merge( + self, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + return self.decode_executor.execute_prefill_merge( + pending_jobs=pending_jobs, + active_batch=active_batch, + mark_prefill_started=self._mark_prefill_started, + add_prefill_time=None if external_bookkeeping else self._add_prefill_time, + add_merge_time=None if external_bookkeeping else self._add_merge_time, + enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, + finalize_error=None if external_bookkeeping else self._finalize_error, + ) + + def execute_decode_step( + self, + active_batch: Optional[T2SActiveBatch], + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + return self.decode_executor.execute_decode_step( + active_batch=active_batch, + add_decode_time=None if external_bookkeeping else self._add_decode_time, + enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, + finalize_error=None if external_bookkeeping else self._finalize_error, + ) + + def execute_decode_cycle( + self, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_executor.execute_decode_cycle( + pending_jobs=pending_jobs, + active_batch=active_batch, + mark_prefill_started=self._mark_prefill_started, + add_prefill_time=None if external_bookkeeping else self._add_prefill_time, + add_merge_time=None if external_bookkeeping else self._add_merge_time, + add_decode_time=None if external_bookkeeping else self._add_decode_time, + enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, + finalize_error=None if external_bookkeeping else self._finalize_error, + ) + self._record_decode_runtime_cycle(result) + return result + + def run_prefill_merge_once_nonblocking( + self, + external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, + external_active_batch: Optional[T2SActiveBatch] = None, + emit_runtime_state: bool = True, + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_legacy_shell.run_prefill_merge_once_nonblocking( + external_pending_jobs=external_pending_jobs, + external_active_batch=external_active_batch, + execute_prefill_merge=lambda batch_jobs, batch_state: self.execute_prefill_merge( + pending_jobs=batch_jobs, + active_batch=batch_state, + external_bookkeeping=external_bookkeeping, + ), + ) + if emit_runtime_state: + self._notify_decode_runtime_state("prefill_merge") + return result + + def run_decode_step_once_nonblocking( + self, + external_active_batch: Optional[T2SActiveBatch] = None, + emit_runtime_state: bool = True, + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_legacy_shell.run_decode_step_once_nonblocking( + external_active_batch=external_active_batch, + execute_decode_step=lambda batch_state: self.execute_decode_step( + active_batch=batch_state, + external_bookkeeping=external_bookkeeping, + ), + ) + if emit_runtime_state: + self._notify_decode_runtime_state("decode_step") + return result + + def run_decode_cycle_nonblocking( + self, + external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, + external_active_batch: Optional[T2SActiveBatch] = None, + emit_runtime_state: bool = True, + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_legacy_shell.run_decode_cycle_nonblocking( + external_pending_jobs=external_pending_jobs, + external_active_batch=external_active_batch, + execute_decode_cycle=lambda batch_jobs, batch_state: self.execute_decode_cycle( + pending_jobs=batch_jobs, + active_batch=batch_state, + external_bookkeeping=external_bookkeeping, + ), + on_cycle_executed=None, + ) + if result.get("executed") and emit_runtime_state: + self._notify_decode_runtime_state("decode_cycle") + return result + + def execute_finalize_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None: + if not tasks: + return + try: + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + with self.condition: + for task in tasks: + job = self.job_registry.get(task.request_id) + if job is None: + continue + jobs_and_items.append((job, task.item)) + if not jobs_and_items: + return + now = time.perf_counter() + for task in tasks: + self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0)) + for job, item in jobs_and_items: + self._runtime_update( + job.engine_request_id, + EngineStatus.FINALIZING, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + }, + ) + synth_ms, batch_results = self.synthesize_finalize_jobs(jobs_and_items) + with self.condition: + for job, _ in jobs_and_items: + tracked_job = self.job_registry.get(job.request_id) + if tracked_job is not None: + tracked_job.synth_ms += synth_ms + for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): + self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data) + except Exception as exc: + self._finalize_error([task.request_id for task in tasks], str(exc)) + finally: + self.finalize_executor.end_execution(len(tasks)) + + def _run_finalize_loop(self) -> None: + while True: + tasks = self.finalize_executor.take_task_batch_blocking() + self.execute_finalize_tasks(tasks) + + def _run_loop(self) -> None: + self.decode_legacy_shell.run_loop( + run_decode_cycle_nonblocking=lambda: self.run_decode_cycle_nonblocking() + ) + + From 800f01790e243be3184681932d9747b90762b9d9 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Wed, 11 Mar 2026 17:58:20 +0800 Subject: [PATCH 14/24] Refactor EngineApiFacade and EngineApiDelegates for improved method naming and structure Rename several methods in EngineApiFacade to follow a consistent private naming convention, enhancing code clarity. Update EngineApiDelegates to remove redundant method definitions, streamlining the interface. Introduce EnginePublicInterface to encapsulate public API methods, improving organization and maintainability of the TTS system. Additionally, update the EngineCompositionBuilder to use the new scheduler worker state retrieval method. --- GPT_SoVITS/TTS_infer_pack/unified_engine.py | 3 +- .../TTS_infer_pack/unified_engine_api.py | 18 +++---- .../TTS_infer_pack/unified_engine_builder.py | 2 +- .../unified_engine_delegates.py | 45 ---------------- .../TTS_infer_pack/unified_engine_public.py | 53 +++++++++++++++++++ 5 files changed, 65 insertions(+), 56 deletions(-) create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_public.py diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine.py b/GPT_SoVITS/TTS_infer_pack/unified_engine.py index 24e1c98b..a6faaddb 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine.py @@ -7,9 +7,10 @@ from GPT_SoVITS.TTS_infer_pack.TTS import TTS from GPT_SoVITS.TTS_infer_pack.unified_engine_builder import EngineCompositionBuilder from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeControlCallbacks from GPT_SoVITS.TTS_infer_pack.unified_engine_delegates import EngineApiDelegates, EngineBridgeDelegates, EngineRuntimeDelegates +from GPT_SoVITS.TTS_infer_pack.unified_engine_public import EngineCompatInterface, EnginePublicInterface -class UnifiedTTSEngine(EngineBridgeDelegates, EngineApiDelegates, EngineRuntimeDelegates): +class UnifiedTTSEngine(EnginePublicInterface, EngineCompatInterface, EngineBridgeDelegates, EngineApiDelegates, EngineRuntimeDelegates): @staticmethod def _env_flag(name: str, default: bool) -> bool: value = os.environ.get(name) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py index bcc8bf0d..ca76252d 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py @@ -826,7 +826,7 @@ class EngineApiFacade: request_id=f"{request_id}_seg_{segment_index:03d}", text=segment_text, ) - segment_specs.append(self.build_scheduler_submit_spec(segment_request)) + segment_specs.append(self._build_scheduler_submit_spec(segment_request)) prepared_items = await asyncio.gather( *[ @@ -1131,7 +1131,7 @@ class EngineApiFacade: fallback_reason="sync_direct_compat", ) - def build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]: + def _build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]: specs: List[SchedulerRequestSpec] = [] for index, payload in enumerate(request_items): normalized = self._normalize_engine_request( @@ -1142,7 +1142,7 @@ class EngineApiFacade: specs.append(normalized.to_scheduler_spec()) return specs - def build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec: + def _build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec: normalized = self._normalize_engine_request( payload, request_id=( @@ -1154,7 +1154,7 @@ class EngineApiFacade: return normalized.to_scheduler_spec() @staticmethod - def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: + def _summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: return [ { "request_id": state.request_id, @@ -1169,7 +1169,7 @@ class EngineApiFacade: ] @staticmethod - def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: + def _summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: return [ { "request_id": item.request_id, @@ -1183,7 +1183,7 @@ class EngineApiFacade: async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution: request_start = time.perf_counter() set_scheduler_seed(seed) - specs = self.build_scheduler_request_specs(request_items) + specs = self._build_scheduler_request_specs(request_items) request_ids = [spec.request_id for spec in specs] for spec in specs: self._register_request_state( @@ -1270,8 +1270,8 @@ class EngineApiFacade: request_total_ms=request_total_ms, finished_items=finished, ), - "requests": self.summarize_scheduler_states(states), - "finished": self.summarize_scheduler_finished(finished), + "requests": self._summarize_scheduler_states(states), + "finished": self._summarize_scheduler_finished(finished), "request_profiles": request_profiles, "request_traces": self._collect_request_summaries(request_ids), } @@ -1284,7 +1284,7 @@ class EngineApiFacade: payload, request_id=str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}"), ) - spec = self.build_scheduler_submit_spec(normalized) + spec = self._build_scheduler_submit_spec(normalized) deadline_ts = None timeout_sec = normalized.timeout_sec if timeout_sec is not None: diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py index 45178b1f..0e93c442 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py @@ -143,7 +143,7 @@ class EngineCompositionBuilder: policy_config=owner.engine_policy_config, arbiter_config=owner.engine_arbiter_config, snapshot_request_registry=owner._snapshot_request_registry, - get_worker_state=owner.get_scheduler_state, + get_worker_state=owner.scheduler_worker.snapshot, snapshot_prepare_state=owner._snapshot_engine_prepare_state, snapshot_finalize_state=owner._snapshot_engine_finalize_state, snapshot_dispatch_state=owner._snapshot_engine_dispatch_state, diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py index 7dbbd5bd..f68d3ede 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py @@ -360,33 +360,6 @@ class EngineApiDelegates: fallback_reason=fallback_reason, ) - async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution: - return await self.api_facade.run_direct_tts_async(req) - - def run_direct_tts(self, req: dict) -> DirectTTSExecution: - return self.api_facade.run_direct_tts(req) - - def build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]: - return self.api_facade.build_scheduler_request_specs(request_items) - - def build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec: - return self.api_facade.build_scheduler_submit_spec(payload) - - @staticmethod - def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: - return EngineApiFacade.summarize_scheduler_states(states) - - @staticmethod - def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: - return EngineApiFacade.summarize_scheduler_finished(items) - - async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution: - return await self.api_facade.run_scheduler_debug(request_items, max_steps, seed) - - async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution: - return await self.api_facade.run_scheduler_submit(payload) - - class EngineRuntimeDelegates: @staticmethod def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None: @@ -424,23 +397,5 @@ class EngineRuntimeDelegates: ) -> Dict[str, Any]: return self.runtime_facade._build_stage_summary(request_registry, worker_state) - def get_scheduler_state(self) -> dict: - return self.runtime_facade.get_scheduler_state() - - def get_runtime_state(self) -> dict: - return self.runtime_facade.get_runtime_state() - def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None: self.runtime_facade._wait_for_safe_reload(timeout_sec=timeout_sec) - - def set_refer_audio(self, refer_audio_path: str | None) -> dict: - return self.runtime_facade.set_refer_audio(refer_audio_path) - - def set_gpt_weights(self, weights_path: str) -> dict: - return self.runtime_facade.set_gpt_weights(weights_path) - - def set_sovits_weights(self, weights_path: str) -> dict: - return self.runtime_facade.set_sovits_weights(weights_path) - - def handle_control(self, command: str) -> None: - self.runtime_facade.handle_control(command) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_public.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_public.py new file mode 100644 index 00000000..fbe88b85 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_public.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, SchedulerDebugExecution, SchedulerSubmitExecution + + +class EnginePublicInterface: + PUBLIC_API_METHODS = ( + "run_direct_tts_async", + "run_scheduler_submit", + "run_scheduler_debug", + "get_runtime_state", + "set_refer_audio", + "set_gpt_weights", + "set_sovits_weights", + "handle_control", + ) + + async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution: + return await self.api_facade.run_direct_tts_async(req) + + async def run_scheduler_debug(self, request_items: list[dict], max_steps: int, seed: int) -> SchedulerDebugExecution: + return await self.api_facade.run_scheduler_debug(request_items, max_steps, seed) + + async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution: + return await self.api_facade.run_scheduler_submit(payload) + + def get_runtime_state(self) -> dict: + return self.runtime_facade.get_runtime_state() + + def set_refer_audio(self, refer_audio_path: str | None) -> dict: + return self.runtime_facade.set_refer_audio(refer_audio_path) + + def set_gpt_weights(self, weights_path: str) -> dict: + return self.runtime_facade.set_gpt_weights(weights_path) + + def set_sovits_weights(self, weights_path: str) -> dict: + return self.runtime_facade.set_sovits_weights(weights_path) + + def handle_control(self, command: str) -> None: + self.runtime_facade.handle_control(command) + + +class EngineCompatInterface: + COMPAT_API_METHODS = ( + "run_direct_tts", + "get_scheduler_state", + ) + + def run_direct_tts(self, req: dict) -> DirectTTSExecution: + return self.api_facade.run_direct_tts(req) + + def get_scheduler_state(self) -> dict: + return self.runtime_facade.get_scheduler_state() From b046a093d3f03b139d98e81dbb26a18e21cc5336 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Wed, 11 Mar 2026 18:35:47 +0800 Subject: [PATCH 15/24] Add unified engine delegates and orchestration components for enhanced TTS processing Introduce new modules including EngineApiDelegates, EngineBridgeDelegates, EngineRegistryBridgeFacade, EngineRuntimeBridgeFacade, EngineStageBridgeFacade, and EngineStageOrchestrator. These additions provide a structured approach to managing TTS requests, engine states, and orchestration, significantly improving the architecture and maintainability of the TTS system. The new components support asynchronous operations and enhance overall performance through better request handling and processing capabilities. --- .../unified_engine_api_delegates.py | 165 +++++++ .../TTS_infer_pack/unified_engine_bridge.py | 313 +------------ .../unified_engine_bridge_delegates.py | 200 +++++++++ .../unified_engine_bridge_registry.py | 193 +++++++++ .../unified_engine_bridge_runtime.py | 33 ++ .../unified_engine_bridge_stage.py | 114 +++++ .../unified_engine_delegates.py | 410 +----------------- .../unified_engine_orchestration.py | 92 ++++ .../unified_engine_runtime_delegates.py | 46 ++ .../TTS_infer_pack/unified_engine_stage.py | 381 +++------------- .../unified_engine_stage_executor.py | 358 +++++++++++++++ 11 files changed, 1280 insertions(+), 1025 deletions(-) create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_api_delegates.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_delegates.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_runtime.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_stage.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_runtime_delegates.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api_delegates.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_delegates.py new file mode 100644 index 00000000..f42ec233 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_delegates.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple + +from GPT_SoVITS.TTS_infer_pack.unified_engine_api import EngineApiFacade +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, NormalizedEngineRequest + + +class EngineApiDelegates: + def _collect_request_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]: + return self.api_facade._collect_request_summaries(request_ids) + + def _has_active_request(self, request_id: str) -> bool: + return self.api_facade._has_active_request(request_id) + + @staticmethod + def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]: + return EngineApiFacade._build_request_meta(payload) + + @staticmethod + def _sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float: + return EngineApiFacade._sum_profile_field(items, key) + + def _build_direct_segment_trace( + self, + segment_texts: Sequence[str], + prepare_profiles: Sequence[Dict[str, Any]], + worker_profiles: Sequence[Dict[str, Any]], + ) -> List[Dict[str, Any]]: + return self.api_facade._build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles) + + def _build_direct_scheduler_profile(self, **kwargs: Any) -> Dict[str, Any]: + return self.api_facade._build_direct_scheduler_profile(**kwargs) + + def _build_legacy_direct_profile(self, **kwargs: Any) -> Dict[str, Any]: + return self.api_facade._build_legacy_direct_profile(**kwargs) + + def _build_scheduler_submit_profile(self, **kwargs: Any) -> Dict[str, Any]: + return self.api_facade._build_scheduler_submit_profile(**kwargs) + + @staticmethod + def _format_ms_header(value: Any) -> str: + return EngineApiFacade._format_ms_header(value) + + def _build_scheduler_submit_headers( + self, + *, + request_id: str, + media_type: str, + sample_rate: int, + profile: Dict[str, Any], + ) -> Dict[str, str]: + return self.api_facade._build_scheduler_submit_headers( + request_id=request_id, + media_type=media_type, + sample_rate=sample_rate, + profile=profile, + ) + + def _build_scheduler_debug_request_profile(self, **kwargs: Any) -> Dict[str, Any]: + return self.api_facade._build_scheduler_debug_request_profile(**kwargs) + + @staticmethod + def _build_scheduler_debug_batch_profile(**kwargs: Any) -> Dict[str, Any]: + return EngineApiFacade._build_scheduler_debug_batch_profile(**kwargs) + + def _normalize_lang(self, value: str | None) -> str | None: + return self.api_facade._normalize_lang(value) + + @staticmethod + def _aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]: + return EngineApiFacade._aggregate_numeric_dicts(items) + + def _apply_default_reference(self, req: dict) -> dict: + return self.api_facade._apply_default_reference(req) + + def check_params(self, req: dict) -> Optional[str]: + return self.api_facade.check_params(req) + + @staticmethod + def _base_request_defaults() -> Dict[str, Any]: + return EngineApiFacade._base_request_defaults() + + def _normalize_engine_request( + self, + payload: dict | NormalizedEngineRequest, + *, + request_id: str | None = None, + normalize_streaming: bool = False, + error_prefix: str = "request 参数非法: ", + ) -> NormalizedEngineRequest: + return self.api_facade._normalize_engine_request( + payload, + request_id=request_id, + normalize_streaming=normalize_streaming, + error_prefix=error_prefix, + ) + + @staticmethod + def _normalize_streaming_mode(req: dict) -> dict: + return EngineApiFacade._normalize_streaming_mode(req) + + @staticmethod + def _is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool: + return EngineApiFacade._is_aux_ref_enabled(aux_ref_audio_paths) + + def _select_direct_backend(self, normalized: NormalizedEngineRequest) -> Tuple[str, str | None]: + return self.api_facade._select_direct_backend(normalized) + + def _iter_legacy_direct_tts_bytes( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> Generator[bytes, None, None]: + return self.api_facade._iter_legacy_direct_tts_bytes( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + + def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool: + return self.api_facade._should_use_scheduler_backend_for_direct(req) + + def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]: + return self.api_facade._segment_direct_text(normalized) + + def _build_segment_request( + self, + normalized: NormalizedEngineRequest, + *, + request_id: str, + text: str, + ) -> NormalizedEngineRequest: + return self.api_facade._build_segment_request(normalized, request_id=request_id, text=text) + + async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution: + return await self.api_facade._run_direct_tts_via_scheduler(normalized) + + def _run_legacy_direct_tts_blocking( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> DirectTTSExecution: + return self.api_facade._run_legacy_direct_tts_blocking( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + + async def _run_direct_tts_via_legacy_backend( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> DirectTTSExecution: + return await self.api_facade._run_direct_tts_via_legacy_backend( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge.py index 536efbc5..d7740a52 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge.py @@ -1,310 +1,21 @@ from __future__ import annotations -import asyncio -import time -from typing import Any, Dict, List, Optional +from typing import Any -import numpy as np - -from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState -from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDecodeRuntimeOwner, EngineDispatchTask, EngineRequestState, EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob +from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge_registry import EngineRegistryBridgeFacade +from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge_runtime import EngineRuntimeBridgeFacade +from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge_stage import EngineStageBridgeFacade class EngineBridgeFacade: def __init__(self, owner: Any) -> None: self.owner = owner + self.registry_bridge = EngineRegistryBridgeFacade(owner) + self.stage_bridge = EngineStageBridgeFacade(owner) + self.runtime_bridge = EngineRuntimeBridgeFacade(owner) - @property - def request_registry(self): - return self.owner.request_registry - - @property - def engine_prepare_queue_owner(self): - return self.owner.engine_prepare_queue_owner - - @property - def engine_finalize_queue_owner(self): - return self.owner.engine_finalize_queue_owner - - @property - def engine_dispatch_queue_owner(self): - return self.owner.engine_dispatch_queue_owner - - @property - def engine_decode_runtime_owner(self): - return self.owner.engine_decode_runtime_owner - - @property - def engine_job_registry(self): - return self.owner.engine_job_registry - - @property - def scheduler_worker(self): - return self.owner.scheduler_worker - - @property - def engine_stage_coordinator(self): - return self.owner.engine_stage_coordinator - - @property - def engine_policy_arbiter(self): - return self.owner.engine_policy_arbiter - - def _register_request_state( - self, - request_id: str, - api_mode: str, - backend: str, - media_type: str, - response_streaming: bool, - deadline_ts: float | None = None, - meta: Optional[Dict[str, Any]] = None, - ) -> EngineRequestState: - return self.request_registry.register( - request_id=request_id, - api_mode=api_mode, - backend=backend, - media_type=media_type, - response_streaming=response_streaming, - deadline_ts=deadline_ts, - meta=meta, - ) - - def _update_request_state( - self, - request_id: str, - status: str, - extra: Optional[Dict[str, Any]] = None, - ) -> None: - self.request_registry.update(request_id, status, extra) - - def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - self.request_registry.merge_profile(request_id, extra) - - def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - self.request_registry.complete(request_id, extra) - - def _fail_request_state(self, request_id: str, error: str) -> None: - self.request_registry.fail(request_id, error) - - def _snapshot_request_registry(self) -> Dict[str, Any]: - return self.request_registry.snapshot() - - def _snapshot_engine_prepare_state(self) -> Dict[str, Any]: - return self.engine_prepare_queue_owner.snapshot(max_request_ids=16) - - def _snapshot_engine_finalize_state(self) -> Dict[str, Any]: - return self.engine_finalize_queue_owner.snapshot(max_request_ids=16) - - def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]: - return self.engine_dispatch_queue_owner.snapshot( - max_request_ids=16, - extra={"last_policy_snapshot": dict(self.owner.engine_dispatch_last_snapshot or {})}, - ) - - def _register_engine_job(self, job: SchedulerPendingJob) -> None: - self.engine_job_registry.register(job, keep_job=True) - - def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None: - return self.engine_job_registry.get(request_id) - - def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None: - return self.engine_job_registry.pop(request_id) - - def _snapshot_engine_job_registry(self) -> Dict[str, Any]: - return self.engine_job_registry.snapshot(max_request_ids=32) - - def _is_engine_drained(self) -> bool: - prepare_empty = self.engine_prepare_queue_owner.is_drained() - dispatch_empty = self.engine_dispatch_queue_owner.is_drained() - finalize_empty = self.engine_finalize_queue_owner.is_drained() - decode_pending_empty = not self.engine_decode_runtime_owner.has_pending_jobs() - job_empty = self.engine_job_registry.is_empty() - worker_state = self.scheduler_worker.snapshot() - return bool( - prepare_empty - and dispatch_empty - and finalize_empty - and decode_pending_empty - and job_empty - and self.engine_decode_runtime_owner.get_active_batch() is None - and int(worker_state.get("prepare_inflight", 0)) <= 0 - and int(worker_state.get("finalize_inflight", 0)) <= 0 - and int(worker_state.get("finalize_pending", 0)) <= 0 - ) - - def _record_engine_job_done(self, request_id: str) -> None: - self.engine_job_registry.mark_finished_and_remove(request_id) - self.scheduler_worker.record_external_job_done(request_id) - - def _complete_engine_job( - self, - job: SchedulerPendingJob, - item: T2SFinishedItem, - *, - sample_rate: int, - audio_data: np.ndarray, - ) -> None: - completion_bridge = self.scheduler_worker.completion_bridge - completion_bridge.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data) - completion_bridge.complete_job( - job, - runtime_request_id=job.engine_request_id, - runtime_extra=completion_bridge.build_runtime_complete_payload(job, item, sample_rate=sample_rate), - on_job_finished=lambda rid=item.request_id: self._record_engine_job_done(rid), - ) - - def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None: - if not request_ids: - return - completion_bridge = self.scheduler_worker.completion_bridge - for request_id in request_ids: - job = self._get_engine_job(request_id) - if job is None: - continue - completion_bridge.fail_job( - job, - error=error, - on_job_finished=lambda rid=request_id: self._record_engine_job_done(rid), - ) - - def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - for job in jobs: - job.prefill_ms += delta_ms - - def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - for request_id in request_ids: - job = self._get_engine_job(request_id) - if job is not None: - job.merge_ms += delta_ms - - def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - activate_request_ids: List[str] = [] - for request_id in request_ids: - job = self._get_engine_job(request_id) - if job is None: - continue - if job.decode_steps == 0: - activate_request_ids.append(job.engine_request_id) - job.decode_ms += delta_ms - job.decode_steps += 1 - for engine_request_id in activate_request_ids: - self._update_request_state(engine_request_id, EngineStatus.ACTIVE_DECODE, None) - - def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None: - if not items: - return - enqueued_at = time.perf_counter() - tasks = [SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at) for item in items] - self._enqueue_worker_finished_for_finalize(tasks) - - def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]: - return self.engine_decode_runtime_owner.snapshot_pending_queue_state() - - @staticmethod - def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: - return EngineDecodeRuntimeOwner.summarize_active_batch(active_batch) - - def _refresh_engine_decode_runtime_state(self, last_event: str) -> None: - self.engine_decode_runtime_owner.refresh_state(last_event) - - def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None: - if not snapshot: - return - if self.scheduler_worker.is_engine_decode_control_enabled(): - return - self.engine_decode_runtime_owner.update_from_worker_snapshot(snapshot) - - def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]: - return self.engine_decode_runtime_owner.snapshot_state() - - def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]: - return self.engine_policy_arbiter.snapshot_state() - - def _notify_engine_arbiter(self) -> None: - self.engine_policy_arbiter.notify() - - def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None: - self.engine_stage_coordinator.decode_runtime_owner.enqueue_pending_job(job) - self._notify_engine_arbiter() - - def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - return self.engine_stage_coordinator.decode_runtime_owner.take_pending_jobs_nonblocking(wait_for_batch) - - def _peek_queue_age_ms(self, queue_name: str) -> float: - return self.engine_stage_coordinator.peek_queue_age_ms(queue_name) - - def _engine_has_pending_work(self) -> bool: - return self.engine_stage_coordinator.has_pending_work() - - async def _prepare_state_via_engine_gpu_queue( - self, - *, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - engine_request_id: str | None, - ) -> tuple[T2SRequestState, float, float]: - return await self.engine_stage_coordinator.prepare_state_via_engine_gpu_queue( - spec=spec, - prepare_submit_at=prepare_submit_at, - engine_request_id=engine_request_id, - ) - - def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: - self.engine_stage_coordinator.enqueue_worker_finished_for_finalize(tasks) - - def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: - return self.engine_stage_coordinator.take_engine_finalize_batch_nonblocking() - - async def _enqueue_prepared_state_for_dispatch( - self, - *, - state: T2SRequestState, - speed_factor: float, - sample_steps: int, - media_type: str, - prepare_wall_ms: float, - prepare_profile_total_ms: float, - done_loop: asyncio.AbstractEventLoop | None, - done_future: asyncio.Future | None, - engine_request_id: str | None, - timeout_sec: float | None, - ) -> EngineDispatchTask: - return await self.engine_stage_coordinator.enqueue_prepared_state_for_dispatch( - state=state, - speed_factor=speed_factor, - sample_steps=sample_steps, - media_type=media_type, - prepare_wall_ms=prepare_wall_ms, - prepare_profile_total_ms=prepare_profile_total_ms, - done_loop=done_loop, - done_future=done_future, - engine_request_id=engine_request_id, - timeout_sec=timeout_sec, - ) - - def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: - self.engine_policy_arbiter.mark_tick(stage=stage, reason=reason, policy_allowed=policy_allowed) - - def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: - stage, reason, policy_snapshot, worker_state = self.engine_policy_arbiter.select_stage() - self.owner.engine_dispatch_last_snapshot = dict(policy_snapshot) - return stage, reason, policy_snapshot, worker_state - - def _run_engine_prepare_once(self) -> bool: - return self.engine_stage_coordinator.run_engine_prepare_once() - - def _run_engine_finalize_once(self) -> bool: - return self.engine_stage_coordinator.run_engine_finalize_once() - - def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: - return self.engine_stage_coordinator.run_engine_dispatch_once(policy_snapshot, worker_state) - - def _run_engine_decode_runtime_once(self) -> bool: - return self.engine_stage_coordinator.run_engine_decode_runtime_once() - - def _run_engine_arbiter_loop(self) -> None: - self.engine_stage_coordinator.run_engine_arbiter_loop() + def __getattr__(self, name: str) -> Any: + for component in (self.registry_bridge, self.stage_bridge, self.runtime_bridge): + if hasattr(component, name): + return getattr(component, name) + raise AttributeError(name) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_delegates.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_delegates.py new file mode 100644 index 00000000..92714750 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_delegates.py @@ -0,0 +1,200 @@ +from __future__ import annotations + +import asyncio +from typing import Any, Dict, List, Optional + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge import EngineBridgeFacade +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDispatchTask, EngineRequestState, SchedulerFinalizeTask, SchedulerPendingJob + + +class EngineBridgeDelegates: + def _register_request_state( + self, + request_id: str, + api_mode: str, + backend: str, + media_type: str, + response_streaming: bool, + deadline_ts: float | None = None, + meta: Optional[Dict[str, Any]] = None, + ) -> EngineRequestState: + return self.bridge_facade._register_request_state( + request_id=request_id, + api_mode=api_mode, + backend=backend, + media_type=media_type, + response_streaming=response_streaming, + deadline_ts=deadline_ts, + meta=meta, + ) + + def _update_request_state(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.bridge_facade._update_request_state(request_id, status, extra) + + def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.bridge_facade._merge_request_state_profile(request_id, extra) + + def _snapshot_engine_prepare_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_prepare_state() + + def _snapshot_engine_finalize_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_finalize_state() + + def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_dispatch_state() + + def _register_engine_job(self, job: SchedulerPendingJob) -> None: + self.bridge_facade._register_engine_job(job) + + def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None: + return self.bridge_facade._get_engine_job(request_id) + + def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None: + return self.bridge_facade._pop_engine_job(request_id) + + def _snapshot_engine_job_registry(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_job_registry() + + def _is_engine_drained(self) -> bool: + return self.bridge_facade._is_engine_drained() + + def _record_engine_job_done(self, request_id: str) -> None: + self.bridge_facade._record_engine_job_done(request_id) + + def _complete_engine_job( + self, + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + audio_data: np.ndarray, + ) -> None: + self.bridge_facade._complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) + + def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None: + self.bridge_facade._fail_engine_jobs(request_ids, error) + + def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None: + self.bridge_facade._add_engine_prefill_time(jobs, elapsed_s) + + def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: + self.bridge_facade._add_engine_merge_time(request_ids, elapsed_s) + + def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + self.bridge_facade._add_engine_decode_time(request_ids, elapsed_s) + + def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None: + self.bridge_facade._enqueue_engine_finished_items(items) + + def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_decode_pending_queue_state() + + @staticmethod + def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: + return EngineBridgeFacade._summarize_active_batch(active_batch) + + def _refresh_engine_decode_runtime_state(self, last_event: str) -> None: + self.bridge_facade._refresh_engine_decode_runtime_state(last_event) + + def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None: + self.bridge_facade._update_engine_decode_runtime_state(snapshot) + + def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_decode_runtime_state() + + def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_arbiter_state() + + def _notify_engine_arbiter(self) -> None: + self.bridge_facade._notify_engine_arbiter() + + def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None: + self.bridge_facade._enqueue_engine_decode_pending_job(job) + + def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.bridge_facade._take_engine_decode_pending_jobs_nonblocking(wait_for_batch) + + def _peek_queue_age_ms(self, queue_name: str) -> float: + return self.bridge_facade._peek_queue_age_ms(queue_name) + + def _engine_has_pending_work(self) -> bool: + return self.bridge_facade._engine_has_pending_work() + + async def _prepare_state_via_engine_gpu_queue( + self, + *, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + engine_request_id: str | None, + ) -> tuple[T2SRequestState, float, float]: + return await self.bridge_facade._prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=prepare_submit_at, + engine_request_id=engine_request_id, + ) + + def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: + self.bridge_facade._enqueue_worker_finished_for_finalize(tasks) + + def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: + return self.bridge_facade._take_engine_finalize_batch_nonblocking() + + async def _enqueue_prepared_state_for_dispatch( + self, + *, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None, + done_future: asyncio.Future | None, + engine_request_id: str | None, + timeout_sec: float | None, + ) -> EngineDispatchTask: + return await self.bridge_facade._enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=speed_factor, + sample_steps=sample_steps, + media_type=media_type, + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=done_loop, + done_future=done_future, + engine_request_id=engine_request_id, + timeout_sec=timeout_sec, + ) + + def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: + self.bridge_facade._mark_arbiter_tick(stage=stage, reason=reason, policy_allowed=policy_allowed) + + def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: + return self.bridge_facade._select_engine_stage() + + def _run_engine_prepare_once(self) -> bool: + return self.bridge_facade._run_engine_prepare_once() + + def _run_engine_finalize_once(self) -> bool: + return self.bridge_facade._run_engine_finalize_once() + + def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: + return self.bridge_facade._run_engine_dispatch_once(policy_snapshot, worker_state) + + def _run_engine_decode_runtime_once(self) -> bool: + return self.bridge_facade._run_engine_decode_runtime_once() + + def _run_engine_arbiter_loop(self) -> None: + self.bridge_facade._run_engine_arbiter_loop() + + def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.bridge_facade._complete_request_state(request_id, extra) + + def _fail_request_state(self, request_id: str, error: str) -> None: + self.bridge_facade._fail_request_state(request_id, error) + + def _snapshot_request_registry(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_request_registry() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py new file mode 100644 index 00000000..88b8cc5d --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py @@ -0,0 +1,193 @@ +from __future__ import annotations + +import time +from typing import Any, Dict, List, Optional + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineRequestState, EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob + + +class EngineRegistryBridgeFacade: + def __init__(self, owner: Any) -> None: + self.owner = owner + + @property + def request_registry(self): + return self.owner.request_registry + + @property + def engine_prepare_queue_owner(self): + return self.owner.engine_prepare_queue_owner + + @property + def engine_finalize_queue_owner(self): + return self.owner.engine_finalize_queue_owner + + @property + def engine_dispatch_queue_owner(self): + return self.owner.engine_dispatch_queue_owner + + @property + def engine_decode_runtime_owner(self): + return self.owner.engine_decode_runtime_owner + + @property + def engine_job_registry(self): + return self.owner.engine_job_registry + + @property + def scheduler_worker(self): + return self.owner.scheduler_worker + + def _register_request_state( + self, + request_id: str, + api_mode: str, + backend: str, + media_type: str, + response_streaming: bool, + deadline_ts: float | None = None, + meta: Optional[Dict[str, Any]] = None, + ) -> EngineRequestState: + return self.request_registry.register( + request_id=request_id, + api_mode=api_mode, + backend=backend, + media_type=media_type, + response_streaming=response_streaming, + deadline_ts=deadline_ts, + meta=meta, + ) + + def _update_request_state( + self, + request_id: str, + status: str, + extra: Optional[Dict[str, Any]] = None, + ) -> None: + self.request_registry.update(request_id, status, extra) + + def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.request_registry.merge_profile(request_id, extra) + + def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.request_registry.complete(request_id, extra) + + def _fail_request_state(self, request_id: str, error: str) -> None: + self.request_registry.fail(request_id, error) + + def _snapshot_request_registry(self) -> Dict[str, Any]: + return self.request_registry.snapshot() + + def _snapshot_engine_prepare_state(self) -> Dict[str, Any]: + return self.engine_prepare_queue_owner.snapshot(max_request_ids=16) + + def _snapshot_engine_finalize_state(self) -> Dict[str, Any]: + return self.engine_finalize_queue_owner.snapshot(max_request_ids=16) + + def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]: + return self.engine_dispatch_queue_owner.snapshot( + max_request_ids=16, + extra={"last_policy_snapshot": dict(self.owner.engine_dispatch_last_snapshot or {})}, + ) + + def _register_engine_job(self, job: SchedulerPendingJob) -> None: + self.engine_job_registry.register(job, keep_job=True) + + def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None: + return self.engine_job_registry.get(request_id) + + def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None: + return self.engine_job_registry.pop(request_id) + + def _snapshot_engine_job_registry(self) -> Dict[str, Any]: + return self.engine_job_registry.snapshot(max_request_ids=32) + + def _is_engine_drained(self) -> bool: + prepare_empty = self.engine_prepare_queue_owner.is_drained() + dispatch_empty = self.engine_dispatch_queue_owner.is_drained() + finalize_empty = self.engine_finalize_queue_owner.is_drained() + decode_pending_empty = not self.engine_decode_runtime_owner.has_pending_jobs() + job_empty = self.engine_job_registry.is_empty() + worker_state = self.scheduler_worker.snapshot() + return bool( + prepare_empty + and dispatch_empty + and finalize_empty + and decode_pending_empty + and job_empty + and self.engine_decode_runtime_owner.get_active_batch() is None + and int(worker_state.get("prepare_inflight", 0)) <= 0 + and int(worker_state.get("finalize_inflight", 0)) <= 0 + and int(worker_state.get("finalize_pending", 0)) <= 0 + ) + + def _record_engine_job_done(self, request_id: str) -> None: + self.engine_job_registry.mark_finished_and_remove(request_id) + self.scheduler_worker.record_external_job_done(request_id) + + def _complete_engine_job( + self, + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + audio_data: np.ndarray, + ) -> None: + completion_bridge = self.scheduler_worker.completion_bridge + completion_bridge.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data) + completion_bridge.complete_job( + job, + runtime_request_id=job.engine_request_id, + runtime_extra=completion_bridge.build_runtime_complete_payload(job, item, sample_rate=sample_rate), + on_job_finished=lambda rid=item.request_id: self._record_engine_job_done(rid), + ) + + def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None: + if not request_ids: + return + completion_bridge = self.scheduler_worker.completion_bridge + for request_id in request_ids: + job = self._get_engine_job(request_id) + if job is None: + continue + completion_bridge.fail_job( + job, + error=error, + on_job_finished=lambda rid=request_id: self._record_engine_job_done(rid), + ) + + def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + for job in jobs: + job.prefill_ms += delta_ms + + def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + for request_id in request_ids: + job = self._get_engine_job(request_id) + if job is not None: + job.merge_ms += delta_ms + + def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + activate_request_ids: List[str] = [] + for request_id in request_ids: + job = self._get_engine_job(request_id) + if job is None: + continue + if job.decode_steps == 0: + activate_request_ids.append(job.engine_request_id) + job.decode_ms += delta_ms + job.decode_steps += 1 + for engine_request_id in activate_request_ids: + self._update_request_state(engine_request_id, EngineStatus.ACTIVE_DECODE, None) + + def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None: + if not items: + return + enqueued_at = time.perf_counter() + tasks = [SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at) for item in items] + self.owner.engine_stage_coordinator.enqueue_worker_finished_for_finalize(tasks) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_runtime.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_runtime.py new file mode 100644 index 00000000..47be8b67 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_runtime.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from typing import Any, Dict + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SActiveBatch +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDecodeRuntimeOwner + + +class EngineRuntimeBridgeFacade: + def __init__(self, owner: Any) -> None: + self.owner = owner + + @property + def engine_policy_arbiter(self): + return self.owner.engine_policy_arbiter + + @staticmethod + def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: + return EngineDecodeRuntimeOwner.summarize_active_batch(active_batch) + + def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]: + return self.engine_policy_arbiter.snapshot_state() + + def _notify_engine_arbiter(self) -> None: + self.engine_policy_arbiter.notify() + + def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: + self.engine_policy_arbiter.mark_tick(stage=stage, reason=reason, policy_allowed=policy_allowed) + + def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: + stage, reason, policy_snapshot, worker_state = self.engine_policy_arbiter.select_stage() + self.owner.engine_dispatch_last_snapshot = dict(policy_snapshot) + return stage, reason, policy_snapshot, worker_state diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_stage.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_stage.py new file mode 100644 index 00000000..29b5aaab --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_stage.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import asyncio +from typing import Any, Dict, List + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDispatchTask, SchedulerFinalizeTask, SchedulerPendingJob + + +class EngineStageBridgeFacade: + def __init__(self, owner: Any) -> None: + self.owner = owner + + @property + def engine_decode_runtime_owner(self): + return self.owner.engine_decode_runtime_owner + + @property + def scheduler_worker(self): + return self.owner.scheduler_worker + + @property + def engine_stage_coordinator(self): + return self.owner.engine_stage_coordinator + + def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]: + return self.engine_decode_runtime_owner.snapshot_pending_queue_state() + + def _refresh_engine_decode_runtime_state(self, last_event: str) -> None: + self.engine_decode_runtime_owner.refresh_state(last_event) + + def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None: + if not snapshot: + return + if self.scheduler_worker.is_engine_decode_control_enabled(): + return + self.engine_decode_runtime_owner.update_from_worker_snapshot(snapshot) + + def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]: + return self.engine_decode_runtime_owner.snapshot_state() + + def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None: + self.engine_decode_runtime_owner.enqueue_pending_job(job) + self.owner.engine_policy_arbiter.notify() + + def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.engine_decode_runtime_owner.take_pending_jobs_nonblocking(wait_for_batch) + + def _peek_queue_age_ms(self, queue_name: str) -> float: + return self.engine_stage_coordinator.peek_queue_age_ms(queue_name) + + def _engine_has_pending_work(self) -> bool: + return self.engine_stage_coordinator.has_pending_work() + + async def _prepare_state_via_engine_gpu_queue( + self, + *, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + engine_request_id: str | None, + ) -> tuple[T2SRequestState, float, float]: + return await self.engine_stage_coordinator.prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=prepare_submit_at, + engine_request_id=engine_request_id, + ) + + def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: + self.engine_stage_coordinator.enqueue_worker_finished_for_finalize(tasks) + + def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: + return self.engine_stage_coordinator.take_engine_finalize_batch_nonblocking() + + async def _enqueue_prepared_state_for_dispatch( + self, + *, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None, + done_future: asyncio.Future | None, + engine_request_id: str | None, + timeout_sec: float | None, + ) -> EngineDispatchTask: + return await self.engine_stage_coordinator.enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=speed_factor, + sample_steps=sample_steps, + media_type=media_type, + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=done_loop, + done_future=done_future, + engine_request_id=engine_request_id, + timeout_sec=timeout_sec, + ) + + def _run_engine_prepare_once(self) -> bool: + return self.engine_stage_coordinator.run_engine_prepare_once() + + def _run_engine_finalize_once(self) -> bool: + return self.engine_stage_coordinator.run_engine_finalize_once() + + def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: + return self.engine_stage_coordinator.run_engine_dispatch_once(policy_snapshot, worker_state) + + def _run_engine_decode_runtime_once(self) -> bool: + return self.engine_stage_coordinator.run_engine_decode_runtime_once() + + def _run_engine_arbiter_loop(self) -> None: + return self.engine_stage_coordinator.run_engine_arbiter_loop() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py index f68d3ede..d60a3bb8 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py @@ -1,401 +1,9 @@ -from __future__ import annotations - -import asyncio -from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple - -import numpy as np - -from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState -from GPT_SoVITS.TTS_infer_pack.unified_engine_api import EngineApiFacade -from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge import EngineBridgeFacade -from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, EngineDispatchTask, EngineRequestState, NormalizedEngineRequest, SchedulerDebugExecution, SchedulerFinalizeTask, SchedulerPendingJob, SchedulerSubmitExecution -from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime import EngineRuntimeFacade - - -class EngineBridgeDelegates: - def _register_request_state( - self, - request_id: str, - api_mode: str, - backend: str, - media_type: str, - response_streaming: bool, - deadline_ts: float | None = None, - meta: Optional[Dict[str, Any]] = None, - ) -> EngineRequestState: - return self.bridge_facade._register_request_state( - request_id=request_id, - api_mode=api_mode, - backend=backend, - media_type=media_type, - response_streaming=response_streaming, - deadline_ts=deadline_ts, - meta=meta, - ) - - def _update_request_state(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None: - self.bridge_facade._update_request_state(request_id, status, extra) - - def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - self.bridge_facade._merge_request_state_profile(request_id, extra) - - def _snapshot_engine_prepare_state(self) -> Dict[str, Any]: - return self.bridge_facade._snapshot_engine_prepare_state() - - def _snapshot_engine_finalize_state(self) -> Dict[str, Any]: - return self.bridge_facade._snapshot_engine_finalize_state() - - def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]: - return self.bridge_facade._snapshot_engine_dispatch_state() - - def _register_engine_job(self, job: SchedulerPendingJob) -> None: - self.bridge_facade._register_engine_job(job) - - def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None: - return self.bridge_facade._get_engine_job(request_id) - - def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None: - return self.bridge_facade._pop_engine_job(request_id) - - def _snapshot_engine_job_registry(self) -> Dict[str, Any]: - return self.bridge_facade._snapshot_engine_job_registry() - - def _is_engine_drained(self) -> bool: - return self.bridge_facade._is_engine_drained() - - def _record_engine_job_done(self, request_id: str) -> None: - self.bridge_facade._record_engine_job_done(request_id) - - def _complete_engine_job( - self, - job: SchedulerPendingJob, - item: T2SFinishedItem, - *, - sample_rate: int, - audio_data: np.ndarray, - ) -> None: - self.bridge_facade._complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) - - def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None: - self.bridge_facade._fail_engine_jobs(request_ids, error) - - def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None: - self.bridge_facade._add_engine_prefill_time(jobs, elapsed_s) - - def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: - self.bridge_facade._add_engine_merge_time(request_ids, elapsed_s) - - def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: - self.bridge_facade._add_engine_decode_time(request_ids, elapsed_s) - - def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None: - self.bridge_facade._enqueue_engine_finished_items(items) - - def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]: - return self.bridge_facade._snapshot_engine_decode_pending_queue_state() - - @staticmethod - def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: - return EngineBridgeFacade._summarize_active_batch(active_batch) - - def _refresh_engine_decode_runtime_state(self, last_event: str) -> None: - self.bridge_facade._refresh_engine_decode_runtime_state(last_event) - - def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None: - self.bridge_facade._update_engine_decode_runtime_state(snapshot) - - def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]: - return self.bridge_facade._snapshot_engine_decode_runtime_state() - - def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]: - return self.bridge_facade._snapshot_engine_arbiter_state() - - def _notify_engine_arbiter(self) -> None: - self.bridge_facade._notify_engine_arbiter() - - def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None: - self.bridge_facade._enqueue_engine_decode_pending_job(job) - - def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - return self.bridge_facade._take_engine_decode_pending_jobs_nonblocking(wait_for_batch) - - def _peek_queue_age_ms(self, queue_name: str) -> float: - return self.bridge_facade._peek_queue_age_ms(queue_name) - - def _engine_has_pending_work(self) -> bool: - return self.bridge_facade._engine_has_pending_work() - - async def _prepare_state_via_engine_gpu_queue( - self, - *, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - engine_request_id: str | None, - ) -> tuple[T2SRequestState, float, float]: - return await self.bridge_facade._prepare_state_via_engine_gpu_queue( - spec=spec, - prepare_submit_at=prepare_submit_at, - engine_request_id=engine_request_id, - ) - - def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: - self.bridge_facade._enqueue_worker_finished_for_finalize(tasks) - - def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: - return self.bridge_facade._take_engine_finalize_batch_nonblocking() - - async def _enqueue_prepared_state_for_dispatch( - self, - *, - state: T2SRequestState, - speed_factor: float, - sample_steps: int, - media_type: str, - prepare_wall_ms: float, - prepare_profile_total_ms: float, - done_loop: asyncio.AbstractEventLoop | None, - done_future: asyncio.Future | None, - engine_request_id: str | None, - timeout_sec: float | None, - ) -> EngineDispatchTask: - return await self.bridge_facade._enqueue_prepared_state_for_dispatch( - state=state, - speed_factor=speed_factor, - sample_steps=sample_steps, - media_type=media_type, - prepare_wall_ms=prepare_wall_ms, - prepare_profile_total_ms=prepare_profile_total_ms, - done_loop=done_loop, - done_future=done_future, - engine_request_id=engine_request_id, - timeout_sec=timeout_sec, - ) - - def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: - self.bridge_facade._mark_arbiter_tick(stage=stage, reason=reason, policy_allowed=policy_allowed) - - def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: - return self.bridge_facade._select_engine_stage() - - def _run_engine_prepare_once(self) -> bool: - return self.bridge_facade._run_engine_prepare_once() - - def _run_engine_finalize_once(self) -> bool: - return self.bridge_facade._run_engine_finalize_once() - - def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: - return self.bridge_facade._run_engine_dispatch_once(policy_snapshot, worker_state) - - def _run_engine_decode_runtime_once(self) -> bool: - return self.bridge_facade._run_engine_decode_runtime_once() - - def _run_engine_arbiter_loop(self) -> None: - self.bridge_facade._run_engine_arbiter_loop() - - def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - self.bridge_facade._complete_request_state(request_id, extra) - - def _fail_request_state(self, request_id: str, error: str) -> None: - self.bridge_facade._fail_request_state(request_id, error) - - def _snapshot_request_registry(self) -> Dict[str, Any]: - return self.bridge_facade._snapshot_request_registry() - - -class EngineApiDelegates: - def _collect_request_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]: - return self.api_facade._collect_request_summaries(request_ids) - - def _has_active_request(self, request_id: str) -> bool: - return self.api_facade._has_active_request(request_id) - - @staticmethod - def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]: - return EngineApiFacade._build_request_meta(payload) - - @staticmethod - def _sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float: - return EngineApiFacade._sum_profile_field(items, key) - - def _build_direct_segment_trace( - self, - segment_texts: Sequence[str], - prepare_profiles: Sequence[Dict[str, Any]], - worker_profiles: Sequence[Dict[str, Any]], - ) -> List[Dict[str, Any]]: - return self.api_facade._build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles) - - def _build_direct_scheduler_profile(self, **kwargs: Any) -> Dict[str, Any]: - return self.api_facade._build_direct_scheduler_profile(**kwargs) - - def _build_legacy_direct_profile(self, **kwargs: Any) -> Dict[str, Any]: - return self.api_facade._build_legacy_direct_profile(**kwargs) - - def _build_scheduler_submit_profile(self, **kwargs: Any) -> Dict[str, Any]: - return self.api_facade._build_scheduler_submit_profile(**kwargs) - - @staticmethod - def _format_ms_header(value: Any) -> str: - return EngineApiFacade._format_ms_header(value) - - def _build_scheduler_submit_headers( - self, - *, - request_id: str, - media_type: str, - sample_rate: int, - profile: Dict[str, Any], - ) -> Dict[str, str]: - return self.api_facade._build_scheduler_submit_headers( - request_id=request_id, - media_type=media_type, - sample_rate=sample_rate, - profile=profile, - ) - - def _build_scheduler_debug_request_profile(self, **kwargs: Any) -> Dict[str, Any]: - return self.api_facade._build_scheduler_debug_request_profile(**kwargs) - - @staticmethod - def _build_scheduler_debug_batch_profile(**kwargs: Any) -> Dict[str, Any]: - return EngineApiFacade._build_scheduler_debug_batch_profile(**kwargs) - - def _normalize_lang(self, value: str | None) -> str | None: - return self.api_facade._normalize_lang(value) - - @staticmethod - def _aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]: - return EngineApiFacade._aggregate_numeric_dicts(items) - - def _apply_default_reference(self, req: dict) -> dict: - return self.api_facade._apply_default_reference(req) - - def check_params(self, req: dict) -> Optional[str]: - return self.api_facade.check_params(req) - - @staticmethod - def _base_request_defaults() -> Dict[str, Any]: - return EngineApiFacade._base_request_defaults() - - def _normalize_engine_request( - self, - payload: dict | NormalizedEngineRequest, - *, - request_id: str | None = None, - normalize_streaming: bool = False, - error_prefix: str = "request 参数非法: ", - ) -> NormalizedEngineRequest: - return self.api_facade._normalize_engine_request( - payload, - request_id=request_id, - normalize_streaming=normalize_streaming, - error_prefix=error_prefix, - ) - - @staticmethod - def _normalize_streaming_mode(req: dict) -> dict: - return EngineApiFacade._normalize_streaming_mode(req) - - @staticmethod - def _is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool: - return EngineApiFacade._is_aux_ref_enabled(aux_ref_audio_paths) - - def _select_direct_backend(self, normalized: NormalizedEngineRequest) -> Tuple[str, str | None]: - return self.api_facade._select_direct_backend(normalized) - - def _iter_legacy_direct_tts_bytes( - self, - normalized: NormalizedEngineRequest, - *, - backend: str, - fallback_reason: str | None, - ) -> Generator[bytes, None, None]: - return self.api_facade._iter_legacy_direct_tts_bytes( - normalized, - backend=backend, - fallback_reason=fallback_reason, - ) - - def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool: - return self.api_facade._should_use_scheduler_backend_for_direct(req) - - def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]: - return self.api_facade._segment_direct_text(normalized) - - def _build_segment_request( - self, - normalized: NormalizedEngineRequest, - *, - request_id: str, - text: str, - ) -> NormalizedEngineRequest: - return self.api_facade._build_segment_request(normalized, request_id=request_id, text=text) - - async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution: - return await self.api_facade._run_direct_tts_via_scheduler(normalized) - - def _run_legacy_direct_tts_blocking( - self, - normalized: NormalizedEngineRequest, - *, - backend: str, - fallback_reason: str | None, - ) -> DirectTTSExecution: - return self.api_facade._run_legacy_direct_tts_blocking( - normalized, - backend=backend, - fallback_reason=fallback_reason, - ) - - async def _run_direct_tts_via_legacy_backend( - self, - normalized: NormalizedEngineRequest, - *, - backend: str, - fallback_reason: str | None, - ) -> DirectTTSExecution: - return await self.api_facade._run_direct_tts_via_legacy_backend( - normalized, - backend=backend, - fallback_reason=fallback_reason, - ) - -class EngineRuntimeDelegates: - @staticmethod - def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None: - return EngineRuntimeFacade._safe_component_snapshot(component) - - def _build_stage_counters( - self, - request_registry: Dict[str, Any], - worker_state: Dict[str, Any], - ) -> Dict[str, Any]: - return self.runtime_facade._build_stage_counters(request_registry, worker_state) - - def _build_engine_policy_snapshot( - self, - request_registry: Dict[str, Any], - worker_state: Dict[str, Any], - ) -> Dict[str, Any]: - return self.runtime_facade._build_engine_policy_snapshot(request_registry, worker_state) - - async def _wait_for_engine_policy_admission( - self, - *, - request_id: str | None, - timeout_sec: float | None, - ) -> tuple[float, Dict[str, Any]]: - return await self.engine_policy_arbiter.wait_for_policy_admission( - request_id=request_id, - timeout_sec=timeout_sec, - ) - - def _build_stage_summary( - self, - request_registry: Dict[str, Any], - worker_state: Dict[str, Any], - ) -> Dict[str, Any]: - return self.runtime_facade._build_stage_summary(request_registry, worker_state) - - def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None: - self.runtime_facade._wait_for_safe_reload(timeout_sec=timeout_sec) +from GPT_SoVITS.TTS_infer_pack.unified_engine_api_delegates import EngineApiDelegates +from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge_delegates import EngineBridgeDelegates +from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime_delegates import EngineRuntimeDelegates + +__all__ = [ + "EngineApiDelegates", + "EngineBridgeDelegates", + "EngineRuntimeDelegates", +] diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py new file mode 100644 index 00000000..a71f7e4e --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from typing import Any, Callable, Dict + +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDecodeRuntimeOwner, EngineTaskQueueOwner +from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_executor import EngineStageExecutor +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker + + +class EngineStageOrchestrator: + def __init__( + self, + *, + executor: EngineStageExecutor, + scheduler_worker: UnifiedSchedulerWorker, + prepare_queue_owner: EngineTaskQueueOwner, + finalize_queue_owner: EngineTaskQueueOwner, + dispatch_queue_owner: EngineTaskQueueOwner, + decode_runtime_owner: EngineDecodeRuntimeOwner, + snapshot_engine_decode_runtime_state: Callable[[], Dict[str, Any]], + ) -> None: + self.executor = executor + self.scheduler_worker = scheduler_worker + self.prepare_queue_owner = prepare_queue_owner + self.finalize_queue_owner = finalize_queue_owner + self.dispatch_queue_owner = dispatch_queue_owner + self.decode_runtime_owner = decode_runtime_owner + self.snapshot_engine_decode_runtime_state = snapshot_engine_decode_runtime_state + self._select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]] | None = None + self._mark_arbiter_tick: Callable[[str, str, bool], None] | None = None + self._wait_arbiter: Callable[[], None] | None = None + + def bind_arbiter( + self, + *, + notify_arbiter: Callable[[], None], + select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]], + mark_arbiter_tick: Callable[[str, str, bool], None], + wait_arbiter: Callable[[], None], + ) -> None: + self.executor.bind_notify_arbiter(notify_arbiter) + self._select_stage = select_stage + self._mark_arbiter_tick = mark_arbiter_tick + self._wait_arbiter = wait_arbiter + + def peek_queue_age_ms(self, queue_name: str) -> float: + if queue_name == "prepare": + return self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time") + if queue_name == "finalize": + return self.finalize_queue_owner.peek_oldest_age_ms("enqueued_time") + if queue_name == "decode_runtime_pending": + return self.decode_runtime_owner.pending_age_ms() + return self.dispatch_queue_owner.peek_oldest_age_ms("enqueue_time") + + def has_pending_work(self) -> bool: + if self.scheduler_worker.is_engine_decode_control_enabled(): + if self.decode_runtime_owner.has_pending_jobs(): + return True + if self.scheduler_worker.is_engine_decode_control_enabled() and self.snapshot_engine_decode_runtime_state().get( + "active_request_count", 0 + ) > 0: + return True + if self.prepare_queue_owner.has_items(): + return True + if self.finalize_queue_owner.has_items(): + return True + return self.dispatch_queue_owner.has_items() + + def run_engine_arbiter_loop(self) -> None: + if self._select_stage is None or self._mark_arbiter_tick is None or self._wait_arbiter is None: + raise RuntimeError("arbiter callbacks are not bound") + while True: + if not self.has_pending_work(): + self._mark_arbiter_tick("idle", "no_pending_work", True) + self._wait_arbiter() + continue + stage, reason, policy_snapshot, worker_state = self._select_stage() + policy_allowed = bool(policy_snapshot.get("allowed", True)) + executed = False + if stage == "prepare": + executed = self.executor.run_engine_prepare_once() + elif stage == "finalize": + executed = self.executor.run_engine_finalize_once() + elif stage == "decode_dispatch": + executed = self.executor.run_engine_dispatch_once(policy_snapshot, worker_state) + elif stage == "decode_runtime": + executed = self.executor.run_engine_decode_runtime_once() + if not executed: + self._mark_arbiter_tick("idle", f"{stage}_not_ready", policy_allowed) + self._wait_arbiter() + continue + self._mark_arbiter_tick(stage, reason, policy_allowed) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_runtime_delegates.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_runtime_delegates.py new file mode 100644 index 00000000..96153196 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_runtime_delegates.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import Any, Dict + +from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime import EngineRuntimeFacade + + +class EngineRuntimeDelegates: + @staticmethod + def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None: + return EngineRuntimeFacade._safe_component_snapshot(component) + + def _build_stage_counters( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.runtime_facade._build_stage_counters(request_registry, worker_state) + + def _build_engine_policy_snapshot( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.runtime_facade._build_engine_policy_snapshot(request_registry, worker_state) + + async def _wait_for_engine_policy_admission( + self, + *, + request_id: str | None, + timeout_sec: float | None, + ) -> tuple[float, Dict[str, Any]]: + return await self.engine_policy_arbiter.wait_for_policy_admission( + request_id=request_id, + timeout_sec=timeout_sec, + ) + + def _build_stage_summary( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.runtime_facade._build_stage_summary(request_registry, worker_state) + + def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None: + self.runtime_facade._wait_for_safe_reload(timeout_sec=timeout_sec) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py index 65f0befe..9aad2fb8 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py @@ -1,20 +1,19 @@ from __future__ import annotations import asyncio -import time -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional from GPT_SoVITS.TTS_infer_pack.TTS import TTS from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem, T2SRequestState from GPT_SoVITS.TTS_infer_pack.unified_engine_components import ( EngineDecodeRuntimeOwner, EngineDispatchTask, - EngineGpuPrepareTask, - EngineStatus, EngineTaskQueueOwner, SchedulerFinalizeTask, SchedulerPendingJob, ) +from GPT_SoVITS.TTS_infer_pack.unified_engine_orchestration import EngineStageOrchestrator +from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_executor import EngineStageExecutor from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker @@ -42,29 +41,36 @@ class EngineStageCoordinator: snapshot_engine_dispatch_state: Callable[[], Dict[str, Any]], snapshot_engine_decode_runtime_state: Callable[[], Dict[str, Any]], ) -> None: - self.tts = tts - self.scheduler_worker = scheduler_worker - self.prepare_queue_owner = prepare_queue_owner - self.finalize_queue_owner = finalize_queue_owner - self.dispatch_queue_owner = dispatch_queue_owner - self.decode_runtime_owner = decode_runtime_owner - self.update_request_state = update_request_state - self.merge_request_state_profile = merge_request_state_profile - self.fail_request_state = fail_request_state - self.get_engine_job = get_engine_job - self.register_engine_job = register_engine_job - self.fail_engine_jobs = fail_engine_jobs - self.complete_engine_job = complete_engine_job - self.add_engine_prefill_time = add_engine_prefill_time - self.add_engine_merge_time = add_engine_merge_time - self.add_engine_decode_time = add_engine_decode_time - self.enqueue_engine_finished_items = enqueue_engine_finished_items - self.snapshot_engine_dispatch_state = snapshot_engine_dispatch_state - self.snapshot_engine_decode_runtime_state = snapshot_engine_decode_runtime_state - self._notify_arbiter: Callable[[], None] | None = None - self._select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]] | None = None - self._mark_arbiter_tick: Callable[[str, str, bool], None] | None = None - self._wait_arbiter: Callable[[], None] | None = None + self.executor = EngineStageExecutor( + tts=tts, + scheduler_worker=scheduler_worker, + prepare_queue_owner=prepare_queue_owner, + finalize_queue_owner=finalize_queue_owner, + dispatch_queue_owner=dispatch_queue_owner, + decode_runtime_owner=decode_runtime_owner, + update_request_state=update_request_state, + merge_request_state_profile=merge_request_state_profile, + fail_request_state=fail_request_state, + get_engine_job=get_engine_job, + register_engine_job=register_engine_job, + fail_engine_jobs=fail_engine_jobs, + complete_engine_job=complete_engine_job, + add_engine_prefill_time=add_engine_prefill_time, + add_engine_merge_time=add_engine_merge_time, + add_engine_decode_time=add_engine_decode_time, + enqueue_engine_finished_items=enqueue_engine_finished_items, + snapshot_engine_dispatch_state=snapshot_engine_dispatch_state, + snapshot_engine_decode_runtime_state=snapshot_engine_decode_runtime_state, + ) + self.orchestrator = EngineStageOrchestrator( + executor=self.executor, + scheduler_worker=scheduler_worker, + prepare_queue_owner=prepare_queue_owner, + finalize_queue_owner=finalize_queue_owner, + dispatch_queue_owner=dispatch_queue_owner, + decode_runtime_owner=decode_runtime_owner, + snapshot_engine_decode_runtime_state=snapshot_engine_decode_runtime_state, + ) def bind_arbiter( self, @@ -74,57 +80,12 @@ class EngineStageCoordinator: mark_arbiter_tick: Callable[[str, str, bool], None], wait_arbiter: Callable[[], None], ) -> None: - self._notify_arbiter = notify_arbiter - self._select_stage = select_stage - self._mark_arbiter_tick = mark_arbiter_tick - self._wait_arbiter = wait_arbiter - - def notify_arbiter(self) -> None: - if self._notify_arbiter is not None: - self._notify_arbiter() - - @staticmethod - def _resolve_dispatch_error_future(future: asyncio.Future, error: Exception) -> None: - if future.done(): - return - future.set_exception(error) - - def _notify_dispatch_error(self, task: EngineDispatchTask, error: Exception) -> None: - if task.done_loop is None or task.done_future is None: - return - try: - task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) - except RuntimeError: - pass - - @staticmethod - def _resolve_prepare_future( - future: asyncio.Future, - payload: tuple[T2SRequestState, float, float], - ) -> None: - if future.done(): - return - future.set_result(payload) - - def _notify_prepare_error(self, task: EngineGpuPrepareTask, error: Exception) -> None: - if task.done_loop is None or task.done_future is None: - return - try: - task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) - except RuntimeError: - pass - - def _notify_prepare_result( - self, - task: EngineGpuPrepareTask, - payload: tuple[T2SRequestState, float, float], - ) -> None: - if task.done_loop is None or task.done_future is None: - return - try: - task.done_loop.call_soon_threadsafe(self._resolve_prepare_future, task.done_future, payload) - except RuntimeError: - pass + self.orchestrator.bind_arbiter( + notify_arbiter=notify_arbiter, + select_stage=select_stage, + mark_arbiter_tick=mark_arbiter_tick, + wait_arbiter=wait_arbiter, + ) async def prepare_state_via_engine_gpu_queue( self, @@ -133,59 +94,17 @@ class EngineStageCoordinator: prepare_submit_at: float, engine_request_id: str | None, ) -> tuple[T2SRequestState, float, float]: - cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) - if engine_request_id not in [None, ""]: - self.update_request_state( - str(engine_request_id), - EngineStatus.GPU_PREPARING, - { - "prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms), - "prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms), - "text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms), - "text_cpu_run_ms": float(cpu_stage.target_cpu_profiled.run_ms), - }, - ) - loop = asyncio.get_running_loop() - done_future = loop.create_future() - task = EngineGpuPrepareTask( - request_id=spec.request_id, - cpu_stage=cpu_stage, - done_loop=loop, - done_future=done_future, - engine_request_id=engine_request_id or spec.request_id, - enqueue_time=time.perf_counter(), + return await self.executor.prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=prepare_submit_at, + engine_request_id=engine_request_id, ) - self.prepare_queue_owner.enqueue(task) - self.notify_arbiter() - state, prepare_exec_started_at, prepare_exec_finished_at = await done_future - return state, prepare_exec_started_at, prepare_exec_finished_at def enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: - if not tasks: - return - for task in tasks: - job = self.get_engine_job(task.request_id) - if job is not None: - self.update_request_state( - job.engine_request_id, - EngineStatus.READY_FOR_FINALIZE, - { - "finish_reason": task.item.finish_reason, - "semantic_len": int(task.item.semantic_tokens.shape[0]), - "finish_idx": int(task.item.finish_idx), - }, - ) - self.finalize_queue_owner.enqueue_many(tasks) - self.notify_arbiter() + self.executor.enqueue_worker_finished_for_finalize(tasks) def take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: - finalize_policy = self.scheduler_worker.get_finalize_batch_policy() - return self.finalize_queue_owner.take_finalize_batch( - finalize_mode=str(finalize_policy.get("finalize_mode", "async")), - batch_max_items=int(finalize_policy.get("finalize_batch_max_items", 1)), - batch_wait_s=float(finalize_policy.get("finalize_batch_wait_s", 0.0)), - use_vocoder=bool(self.tts.configs.use_vocoder), - ) + return self.executor.take_engine_finalize_batch_nonblocking() async def enqueue_prepared_state_for_dispatch( self, @@ -201,220 +120,36 @@ class EngineStageCoordinator: engine_request_id: str | None, timeout_sec: float | None, ) -> EngineDispatchTask: - task = EngineDispatchTask( - request_id=state.request_id, + return await self.executor.enqueue_prepared_state_for_dispatch( state=state, - speed_factor=float(speed_factor), - sample_steps=int(sample_steps), + speed_factor=speed_factor, + sample_steps=sample_steps, media_type=media_type, - prepare_wall_ms=float(prepare_wall_ms), - prepare_profile_total_ms=float(prepare_profile_total_ms), + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, done_loop=done_loop, done_future=done_future, - engine_request_id=engine_request_id or state.request_id, + engine_request_id=engine_request_id, timeout_sec=timeout_sec, - enqueue_time=time.perf_counter(), ) - self.dispatch_queue_owner.enqueue(task) - self.notify_arbiter() - self.merge_request_state_profile( - task.engine_request_id or task.request_id, - { - "engine_dispatch_queue_depth_on_enqueue": int( - self.snapshot_engine_dispatch_state()["waiting_count"] - ), - }, - ) - return task def peek_queue_age_ms(self, queue_name: str) -> float: - if queue_name == "prepare": - return self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time") - if queue_name == "finalize": - return self.finalize_queue_owner.peek_oldest_age_ms("enqueued_time") - if queue_name == "decode_runtime_pending": - return self.decode_runtime_owner.pending_age_ms() - return self.dispatch_queue_owner.peek_oldest_age_ms("enqueue_time") + return self.orchestrator.peek_queue_age_ms(queue_name) def has_pending_work(self) -> bool: - if self.scheduler_worker.is_engine_decode_control_enabled(): - if self.decode_runtime_owner.has_pending_jobs(): - return True - if self.scheduler_worker.is_engine_decode_control_enabled() and self.snapshot_engine_decode_runtime_state().get( - "active_request_count", 0 - ) > 0: - return True - if self.prepare_queue_owner.has_items(): - return True - if self.finalize_queue_owner.has_items(): - return True - return self.dispatch_queue_owner.has_items() + return self.orchestrator.has_pending_work() def run_engine_prepare_once(self) -> bool: - task = self.prepare_queue_owner.pop_left() - if task is None: - return False - queue_wait_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0) - try: - state, prepare_exec_started_at, prepare_exec_finished_at = asyncio.run( - self.scheduler_worker.prepare_gpu_stage_profiled_async(task.cpu_stage) - ) - state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms) - if task.engine_request_id not in [None, ""]: - self.merge_request_state_profile( - str(task.engine_request_id), - {"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms)}, - ) - self.prepare_queue_owner.mark_completed(1) - self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at)) - return True - except Exception as exc: - task.error = str(exc) - self.fail_request_state(task.engine_request_id or task.request_id, str(exc)) - self._notify_prepare_error(task, exc) - return True + return self.executor.run_engine_prepare_once() def run_engine_finalize_once(self) -> bool: - tasks = self.take_engine_finalize_batch_nonblocking() - if not tasks: - return False - self.scheduler_worker.begin_finalize_execution(len(tasks)) - try: - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] - for task in tasks: - job = self.get_engine_job(task.request_id) - if job is None: - continue - jobs_and_items.append((job, task.item)) - if not jobs_and_items: - return False - now = time.perf_counter() - for task in tasks: - job = self.get_engine_job(task.request_id) - if job is not None: - job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0) - for job, item in jobs_and_items: - self.update_request_state( - job.engine_request_id, - EngineStatus.FINALIZING, - { - "finish_reason": item.finish_reason, - "semantic_len": int(item.semantic_tokens.shape[0]), - }, - ) - synth_ms, batch_results = self.scheduler_worker.synthesize_finalize_jobs(jobs_and_items) - for job, _ in jobs_and_items: - job.synth_ms += float(synth_ms) - for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): - self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) - except Exception as exc: - self.fail_engine_jobs([task.request_id for task in tasks], str(exc)) - finally: - self.scheduler_worker.end_finalize_execution(len(tasks)) - self.finalize_queue_owner.mark_completed(len(tasks), notify=True) - return True + return self.executor.run_engine_finalize_once() def run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: - if not bool(policy_snapshot.get("allowed", True)): - return False - dispatch_task = self.dispatch_queue_owner.pop_left() - if dispatch_task is None: - return False - dispatched_at = time.perf_counter() - dispatch_wait_ms = max(0.0, (dispatched_at - dispatch_task.enqueue_time) * 1000.0) - dispatch_task.engine_policy_wait_ms = float(dispatch_wait_ms) - dispatch_task.engine_dispatch_wait_ms = float(dispatch_wait_ms) - dispatch_task.engine_policy_snapshot = dict(policy_snapshot) - try: - worker_job = self.scheduler_worker.submit( - state=dispatch_task.state, - speed_factor=dispatch_task.speed_factor, - sample_steps=dispatch_task.sample_steps, - media_type=dispatch_task.media_type, - prepare_wall_ms=dispatch_task.prepare_wall_ms, - prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms, - done_loop=dispatch_task.done_loop, - done_future=dispatch_task.done_future, - engine_request_id=dispatch_task.engine_request_id, - timeout_sec=dispatch_task.timeout_sec, - skip_capacity_wait=True, - admission_wait_ms_override=0.0, - admission_snapshot_override=dict(worker_state), - engine_policy_wait_ms=dispatch_task.engine_policy_wait_ms, - engine_dispatch_wait_ms=dispatch_task.engine_dispatch_wait_ms, - enqueue_pending=not self.scheduler_worker.is_engine_decode_control_enabled(), - ) - dispatch_task.worker_job = worker_job - self.register_engine_job(worker_job) - if self.scheduler_worker.is_engine_decode_control_enabled(): - self.decode_runtime_owner.enqueue_pending_job(worker_job) - self.notify_arbiter() - self.dispatch_queue_owner.mark_completed(1) - return True - except Exception as exc: - dispatch_task.error = str(exc) - self.fail_request_state(dispatch_task.engine_request_id or dispatch_task.request_id, str(exc)) - self._notify_dispatch_error(dispatch_task, exc) - return True + return self.executor.run_engine_dispatch_once(policy_snapshot, worker_state) def run_engine_decode_runtime_once(self) -> bool: - if not self.scheduler_worker.is_engine_decode_control_enabled(): - return False - runtime_state = self.snapshot_engine_decode_runtime_state() - pending_jobs = self.decode_runtime_owner.take_pending_jobs_nonblocking( - wait_for_batch=int(runtime_state.get("active_request_count", 0)) <= 0 - ) - result = self.scheduler_worker.execute_decode_cycle( - pending_jobs=pending_jobs, - active_batch=self.decode_runtime_owner.get_active_batch(), - external_bookkeeping=True, - ) - prefill_phase = dict(result.get("prefill_phase") or {}) - if prefill_phase.get("error"): - self.fail_engine_jobs(list(prefill_phase.get("error_request_ids") or []), str(prefill_phase.get("error"))) - else: - prefill_jobs = list(prefill_phase.get("pending_jobs") or []) - self.add_engine_prefill_time(prefill_jobs, float(prefill_phase.get("prefill_elapsed_s", 0.0))) - self.add_engine_merge_time( - [] if result.get("active_batch") is None else list(result["active_batch"].request_ids), - float(prefill_phase.get("merge_elapsed_s", 0.0)), - ) - self.enqueue_engine_finished_items(list(prefill_phase.get("finished_items") or [])) - decode_phase = dict(result.get("decode_phase") or {}) - if decode_phase.get("error"): - self.fail_engine_jobs(list(decode_phase.get("error_request_ids") or []), str(decode_phase.get("error"))) - else: - self.add_engine_decode_time( - list(decode_phase.get("request_ids") or []), - float(decode_phase.get("decode_elapsed_s", 0.0)), - ) - self.enqueue_engine_finished_items(list(decode_phase.get("finished_items") or [])) - self.decode_runtime_owner.set_active_batch(result.get("active_batch")) - if result.get("executed", False): - self.decode_runtime_owner.refresh_state("engine_decode_cycle") - return bool(result.get("executed", False)) + return self.executor.run_engine_decode_runtime_once() def run_engine_arbiter_loop(self) -> None: - if self._select_stage is None or self._mark_arbiter_tick is None or self._wait_arbiter is None: - raise RuntimeError("arbiter callbacks are not bound") - while True: - if not self.has_pending_work(): - self._mark_arbiter_tick("idle", "no_pending_work", True) - self._wait_arbiter() - continue - stage, reason, policy_snapshot, worker_state = self._select_stage() - policy_allowed = bool(policy_snapshot.get("allowed", True)) - executed = False - if stage == "prepare": - executed = self.run_engine_prepare_once() - elif stage == "finalize": - executed = self.run_engine_finalize_once() - elif stage == "decode_dispatch": - executed = self.run_engine_dispatch_once(policy_snapshot, worker_state) - elif stage == "decode_runtime": - executed = self.run_engine_decode_runtime_once() - if not executed: - self._mark_arbiter_tick("idle", f"{stage}_not_ready", policy_allowed) - self._wait_arbiter() - continue - self._mark_arbiter_tick(stage, reason, policy_allowed) + self.orchestrator.run_engine_arbiter_loop() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py new file mode 100644 index 00000000..77274056 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py @@ -0,0 +1,358 @@ +from __future__ import annotations + +import asyncio +import time +from typing import Any, Callable, Dict, List, Optional + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import ( + EngineDecodeRuntimeOwner, + EngineDispatchTask, + EngineGpuPrepareTask, + EngineStatus, + EngineTaskQueueOwner, + SchedulerFinalizeTask, + SchedulerPendingJob, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker + + +class EngineStageExecutor: + def __init__( + self, + *, + tts: TTS, + scheduler_worker: UnifiedSchedulerWorker, + prepare_queue_owner: EngineTaskQueueOwner, + finalize_queue_owner: EngineTaskQueueOwner, + dispatch_queue_owner: EngineTaskQueueOwner, + decode_runtime_owner: EngineDecodeRuntimeOwner, + update_request_state: Callable[[str, str, Optional[Dict[str, Any]]], None], + merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None], + fail_request_state: Callable[[str, str], None], + get_engine_job: Callable[[str], SchedulerPendingJob | None], + register_engine_job: Callable[[SchedulerPendingJob], None], + fail_engine_jobs: Callable[[List[str], str], None], + complete_engine_job: Callable[..., None], + add_engine_prefill_time: Callable[[List[SchedulerPendingJob], float], None], + add_engine_merge_time: Callable[[List[str], float], None], + add_engine_decode_time: Callable[[List[str], float], None], + enqueue_engine_finished_items: Callable[[List[T2SFinishedItem]], None], + snapshot_engine_dispatch_state: Callable[[], Dict[str, Any]], + snapshot_engine_decode_runtime_state: Callable[[], Dict[str, Any]], + ) -> None: + self.tts = tts + self.scheduler_worker = scheduler_worker + self.prepare_queue_owner = prepare_queue_owner + self.finalize_queue_owner = finalize_queue_owner + self.dispatch_queue_owner = dispatch_queue_owner + self.decode_runtime_owner = decode_runtime_owner + self.update_request_state = update_request_state + self.merge_request_state_profile = merge_request_state_profile + self.fail_request_state = fail_request_state + self.get_engine_job = get_engine_job + self.register_engine_job = register_engine_job + self.fail_engine_jobs = fail_engine_jobs + self.complete_engine_job = complete_engine_job + self.add_engine_prefill_time = add_engine_prefill_time + self.add_engine_merge_time = add_engine_merge_time + self.add_engine_decode_time = add_engine_decode_time + self.enqueue_engine_finished_items = enqueue_engine_finished_items + self.snapshot_engine_dispatch_state = snapshot_engine_dispatch_state + self.snapshot_engine_decode_runtime_state = snapshot_engine_decode_runtime_state + self._notify_arbiter: Callable[[], None] | None = None + + def bind_notify_arbiter(self, notify_arbiter: Callable[[], None]) -> None: + self._notify_arbiter = notify_arbiter + + def notify_arbiter(self) -> None: + if self._notify_arbiter is not None: + self._notify_arbiter() + + @staticmethod + def _resolve_dispatch_error_future(future: asyncio.Future, error: Exception) -> None: + if future.done(): + return + future.set_exception(error) + + def _notify_dispatch_error(self, task: EngineDispatchTask, error: Exception) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) + except RuntimeError: + pass + + @staticmethod + def _resolve_prepare_future( + future: asyncio.Future, + payload: tuple[T2SRequestState, float, float], + ) -> None: + if future.done(): + return + future.set_result(payload) + + def _notify_prepare_error(self, task: EngineGpuPrepareTask, error: Exception) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) + except RuntimeError: + pass + + def _notify_prepare_result( + self, + task: EngineGpuPrepareTask, + payload: tuple[T2SRequestState, float, float], + ) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_prepare_future, task.done_future, payload) + except RuntimeError: + pass + + async def prepare_state_via_engine_gpu_queue( + self, + *, + spec: Any, + prepare_submit_at: float, + engine_request_id: str | None, + ) -> tuple[T2SRequestState, float, float]: + cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + if engine_request_id not in [None, ""]: + self.update_request_state( + str(engine_request_id), + EngineStatus.GPU_PREPARING, + { + "prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms), + "prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms), + "text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms), + "text_cpu_run_ms": float(cpu_stage.target_cpu_profiled.run_ms), + }, + ) + loop = asyncio.get_running_loop() + done_future = loop.create_future() + task = EngineGpuPrepareTask( + request_id=spec.request_id, + cpu_stage=cpu_stage, + done_loop=loop, + done_future=done_future, + engine_request_id=engine_request_id or spec.request_id, + enqueue_time=time.perf_counter(), + ) + self.prepare_queue_owner.enqueue(task) + self.notify_arbiter() + return await done_future + + def enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: + if not tasks: + return + for task in tasks: + job = self.get_engine_job(task.request_id) + if job is not None: + self.update_request_state( + job.engine_request_id, + EngineStatus.READY_FOR_FINALIZE, + { + "finish_reason": task.item.finish_reason, + "semantic_len": int(task.item.semantic_tokens.shape[0]), + "finish_idx": int(task.item.finish_idx), + }, + ) + self.finalize_queue_owner.enqueue_many(tasks) + self.notify_arbiter() + + def take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: + finalize_policy = self.scheduler_worker.get_finalize_batch_policy() + return self.finalize_queue_owner.take_finalize_batch( + finalize_mode=str(finalize_policy.get("finalize_mode", "async")), + batch_max_items=int(finalize_policy.get("finalize_batch_max_items", 1)), + batch_wait_s=float(finalize_policy.get("finalize_batch_wait_s", 0.0)), + use_vocoder=bool(self.tts.configs.use_vocoder), + ) + + async def enqueue_prepared_state_for_dispatch( + self, + *, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None, + done_future: asyncio.Future | None, + engine_request_id: str | None, + timeout_sec: float | None, + ) -> EngineDispatchTask: + task = EngineDispatchTask( + request_id=state.request_id, + state=state, + speed_factor=float(speed_factor), + sample_steps=int(sample_steps), + media_type=media_type, + prepare_wall_ms=float(prepare_wall_ms), + prepare_profile_total_ms=float(prepare_profile_total_ms), + done_loop=done_loop, + done_future=done_future, + engine_request_id=engine_request_id or state.request_id, + timeout_sec=timeout_sec, + enqueue_time=time.perf_counter(), + ) + self.dispatch_queue_owner.enqueue(task) + self.notify_arbiter() + self.merge_request_state_profile( + task.engine_request_id or task.request_id, + { + "engine_dispatch_queue_depth_on_enqueue": int( + self.snapshot_engine_dispatch_state()["waiting_count"] + ), + }, + ) + return task + + def run_engine_prepare_once(self) -> bool: + task = self.prepare_queue_owner.pop_left() + if task is None: + return False + queue_wait_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0) + try: + state, prepare_exec_started_at, prepare_exec_finished_at = asyncio.run( + self.scheduler_worker.prepare_gpu_stage_profiled_async(task.cpu_stage) + ) + state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms) + if task.engine_request_id not in [None, ""]: + self.merge_request_state_profile( + str(task.engine_request_id), + {"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms)}, + ) + self.prepare_queue_owner.mark_completed(1) + self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at)) + return True + except Exception as exc: + task.error = str(exc) + self.fail_request_state(task.engine_request_id or task.request_id, str(exc)) + self._notify_prepare_error(task, exc) + return True + + def run_engine_finalize_once(self) -> bool: + tasks = self.take_engine_finalize_batch_nonblocking() + if not tasks: + return False + self.scheduler_worker.begin_finalize_execution(len(tasks)) + try: + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + for task in tasks: + job = self.get_engine_job(task.request_id) + if job is None: + continue + jobs_and_items.append((job, task.item)) + if not jobs_and_items: + return False + now = time.perf_counter() + for task in tasks: + job = self.get_engine_job(task.request_id) + if job is not None: + job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0) + for job, item in jobs_and_items: + self.update_request_state( + job.engine_request_id, + EngineStatus.FINALIZING, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + }, + ) + synth_ms, batch_results = self.scheduler_worker.synthesize_finalize_jobs(jobs_and_items) + for job, _ in jobs_and_items: + job.synth_ms += float(synth_ms) + for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): + self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) + except Exception as exc: + self.fail_engine_jobs([task.request_id for task in tasks], str(exc)) + finally: + self.scheduler_worker.end_finalize_execution(len(tasks)) + self.finalize_queue_owner.mark_completed(len(tasks), notify=True) + return True + + def run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: + if not bool(policy_snapshot.get("allowed", True)): + return False + dispatch_task = self.dispatch_queue_owner.pop_left() + if dispatch_task is None: + return False + dispatched_at = time.perf_counter() + dispatch_wait_ms = max(0.0, (dispatched_at - dispatch_task.enqueue_time) * 1000.0) + dispatch_task.engine_policy_wait_ms = float(dispatch_wait_ms) + dispatch_task.engine_dispatch_wait_ms = float(dispatch_wait_ms) + dispatch_task.engine_policy_snapshot = dict(policy_snapshot) + try: + worker_job = self.scheduler_worker.submit( + state=dispatch_task.state, + speed_factor=dispatch_task.speed_factor, + sample_steps=dispatch_task.sample_steps, + media_type=dispatch_task.media_type, + prepare_wall_ms=dispatch_task.prepare_wall_ms, + prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms, + done_loop=dispatch_task.done_loop, + done_future=dispatch_task.done_future, + engine_request_id=dispatch_task.engine_request_id, + timeout_sec=dispatch_task.timeout_sec, + skip_capacity_wait=True, + admission_wait_ms_override=0.0, + admission_snapshot_override=dict(worker_state), + engine_policy_wait_ms=dispatch_task.engine_policy_wait_ms, + engine_dispatch_wait_ms=dispatch_task.engine_dispatch_wait_ms, + enqueue_pending=not self.scheduler_worker.is_engine_decode_control_enabled(), + ) + dispatch_task.worker_job = worker_job + self.register_engine_job(worker_job) + if self.scheduler_worker.is_engine_decode_control_enabled(): + self.decode_runtime_owner.enqueue_pending_job(worker_job) + self.notify_arbiter() + self.dispatch_queue_owner.mark_completed(1) + return True + except Exception as exc: + dispatch_task.error = str(exc) + self.fail_request_state(dispatch_task.engine_request_id or dispatch_task.request_id, str(exc)) + self._notify_dispatch_error(dispatch_task, exc) + return True + + def run_engine_decode_runtime_once(self) -> bool: + if not self.scheduler_worker.is_engine_decode_control_enabled(): + return False + runtime_state = self.snapshot_engine_decode_runtime_state() + pending_jobs = self.decode_runtime_owner.take_pending_jobs_nonblocking( + wait_for_batch=int(runtime_state.get("active_request_count", 0)) <= 0 + ) + result = self.scheduler_worker.execute_decode_cycle( + pending_jobs=pending_jobs, + active_batch=self.decode_runtime_owner.get_active_batch(), + external_bookkeeping=True, + ) + prefill_phase = dict(result.get("prefill_phase") or {}) + if prefill_phase.get("error"): + self.fail_engine_jobs(list(prefill_phase.get("error_request_ids") or []), str(prefill_phase.get("error"))) + else: + prefill_jobs = list(prefill_phase.get("pending_jobs") or []) + self.add_engine_prefill_time(prefill_jobs, float(prefill_phase.get("prefill_elapsed_s", 0.0))) + self.add_engine_merge_time( + [] if result.get("active_batch") is None else list(result["active_batch"].request_ids), + float(prefill_phase.get("merge_elapsed_s", 0.0)), + ) + self.enqueue_engine_finished_items(list(prefill_phase.get("finished_items") or [])) + decode_phase = dict(result.get("decode_phase") or {}) + if decode_phase.get("error"): + self.fail_engine_jobs(list(decode_phase.get("error_request_ids") or []), str(decode_phase.get("error"))) + else: + self.add_engine_decode_time( + list(decode_phase.get("request_ids") or []), + float(decode_phase.get("decode_elapsed_s", 0.0)), + ) + self.enqueue_engine_finished_items(list(decode_phase.get("finished_items") or [])) + self.decode_runtime_owner.set_active_batch(result.get("active_batch")) + if result.get("executed", False): + self.decode_runtime_owner.refresh_state("engine_decode_cycle") + return bool(result.get("executed", False)) From 3fd4f48651e959388dd5215bd25d3c3e20969858 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Wed, 11 Mar 2026 18:36:24 +0800 Subject: [PATCH 16/24] Add unified engine API modules for direct and scheduler-based TTS processing Introduce new modules including unified_engine_api_direct, unified_engine_api_profile, unified_engine_api_request, and unified_engine_api_scheduler. These additions enhance the TTS system by providing structured interfaces for direct TTS execution and scheduler-based processing. The new components support improved request handling, profiling, and state management, significantly enhancing the architecture and maintainability of the TTS framework. --- .../TTS_infer_pack/unified_engine_api.py | 1222 ++--------------- .../unified_engine_api_direct.py | 441 ++++++ .../unified_engine_api_profile.py | 387 ++++++ .../unified_engine_api_request.py | 199 +++ .../unified_engine_api_scheduler.py | 283 ++++ 5 files changed, 1446 insertions(+), 1086 deletions(-) create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_api_direct.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_api_profile.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_api_request.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_api_scheduler.py diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py index ca76252d..ca372d5d 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py @@ -1,21 +1,37 @@ from __future__ import annotations -import asyncio -import time -import uuid -from io import BytesIO from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple -import numpy as np - -from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SFinishedItem, T2SRequestState, run_scheduler_continuous -from GPT_SoVITS.TTS_infer_pack.unified_engine_audio import pack_audio, set_scheduler_seed, wave_header_chunk +from GPT_SoVITS.TTS_infer_pack.unified_engine_api_direct import EngineApiDirectFlow +from GPT_SoVITS.TTS_infer_pack.unified_engine_api_profile import ( + aggregate_numeric_dicts, + build_direct_scheduler_profile, + build_direct_segment_trace, + build_legacy_direct_profile, + build_request_meta, + build_scheduler_debug_batch_profile, + build_scheduler_debug_request_profile, + build_scheduler_submit_headers, + build_scheduler_submit_profile, + format_ms_header, + sum_profile_field, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_api_request import ( + apply_default_reference, + base_request_defaults, + check_params, + is_aux_ref_enabled, + normalize_engine_request, + normalize_lang, + normalize_streaming_mode, + select_direct_backend, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_api_scheduler import EngineApiSchedulerFlow +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SFinishedItem, T2SRequestState from GPT_SoVITS.TTS_infer_pack.unified_engine_components import ( DirectTTSExecution, - EngineStatus, NormalizedEngineRequest, SchedulerDebugExecution, - SchedulerPendingJob, SchedulerSubmitExecution, ) @@ -23,6 +39,8 @@ from GPT_SoVITS.TTS_infer_pack.unified_engine_components import ( class EngineApiFacade: def __init__(self, owner: Any) -> None: self.owner = owner + self.direct_flow = EngineApiDirectFlow(self) + self.scheduler_flow = EngineApiSchedulerFlow(self) @property def tts(self): @@ -129,24 +147,11 @@ class EngineApiFacade: @staticmethod def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]: - text = payload.get("text") - prompt_text = payload.get("prompt_text") - return { - "text_len": 0 if text is None else len(str(text)), - "prompt_text_len": 0 if prompt_text is None else len(str(prompt_text)), - "text_lang": payload.get("text_lang"), - "prompt_lang": payload.get("prompt_lang"), - "ref_audio_path": payload.get("ref_audio_path"), - } + return build_request_meta(payload) @staticmethod def _sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float: - total = 0.0 - for item in items: - value = item.get(key, 0.0) - if isinstance(value, (int, float)): - total += float(value) - return total + return sum_profile_field(items, key) def _build_direct_segment_trace( self, @@ -154,38 +159,7 @@ class EngineApiFacade: prepare_profiles: Sequence[Dict[str, Any]], worker_profiles: Sequence[Dict[str, Any]], ) -> List[Dict[str, Any]]: - results: List[Dict[str, Any]] = [] - for index, segment_text in enumerate(segment_texts): - prepare_item = prepare_profiles[index] if index < len(prepare_profiles) else {} - worker_item = worker_profiles[index] if index < len(worker_profiles) else {} - prepare_profile = dict(prepare_item.get("prepare_profile", {})) - results.append( - { - "segment_index": index, - "request_id": prepare_item.get("request_id") or worker_item.get("request_id"), - "text_len": len(str(segment_text)), - "prepare_wall_ms": float(prepare_item.get("prepare_wall_ms", 0.0)), - "prepare_profile_total_ms": float(prepare_item.get("prepare_profile_total_ms", 0.0)), - "prepare_engine_gpu_queue_wait_ms": float( - dict(prepare_item.get("prepare_profile", {})).get("engine_gpu_prepare_queue_wait_ms", 0.0) - ), - "engine_policy_wait_ms": float(prepare_item.get("engine_policy_wait_ms", 0.0)), - "engine_dispatch_wait_ms": float(prepare_item.get("engine_dispatch_wait_ms", 0.0)), - "decode_admission_wait_ms": float(worker_item.get("decode_admission_wait_ms", 0.0)), - "queue_wait_ms": float(worker_item.get("queue_wait_ms", 0.0)), - "prefill_ms": float(worker_item.get("prefill_ms", 0.0)), - "merge_ms": float(worker_item.get("merge_ms", 0.0)), - "decode_ms": float(worker_item.get("decode_ms", 0.0)), - "finalize_wait_ms": float(worker_item.get("finalize_wait_ms", 0.0)), - "synth_ms": float(worker_item.get("synth_ms", 0.0)), - "worker_total_ms": float(worker_item.get("worker_total_ms", 0.0)), - "decode_steps": int(worker_item.get("decode_steps", 0)), - "semantic_len": int(worker_item.get("semantic_len", 0)), - "finish_reason": worker_item.get("finish_reason"), - "norm_text": prepare_profile.get("norm_text"), - } - ) - return results + return build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles) def _build_direct_scheduler_profile( self, @@ -201,57 +175,18 @@ class EngineApiFacade: pack_ms: float, response_overhead_ms: float, ) -> Dict[str, Any]: - segment_trace = self._build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles) - prepare_profile_dicts = [dict(item.get("prepare_profile", {})) for item in prepare_profiles] - request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) - prepare_wall_ms = self._sum_profile_field(prepare_profiles, "prepare_wall_ms") - prepare_profile_total_ms = self._sum_profile_field(prepare_profiles, "prepare_profile_total_ms") - engine_policy_wait_ms = self._sum_profile_field(prepare_profiles, "engine_policy_wait_ms") - engine_dispatch_wait_ms = self._sum_profile_field(prepare_profiles, "engine_dispatch_wait_ms") - decode_admission_wait_ms = self._sum_profile_field(worker_profiles, "decode_admission_wait_ms") - queue_wait_ms = self._sum_profile_field(worker_profiles, "queue_wait_ms") - prefill_ms = self._sum_profile_field(worker_profiles, "prefill_ms") - merge_ms = self._sum_profile_field(worker_profiles, "merge_ms") - decode_ms = self._sum_profile_field(worker_profiles, "decode_ms") - finalize_wait_ms = self._sum_profile_field(worker_profiles, "finalize_wait_ms") - synth_ms = self._sum_profile_field(worker_profiles, "synth_ms") - worker_total_ms = self._sum_profile_field(worker_profiles, "worker_total_ms") - decode_steps = sum(int(item.get("decode_steps", 0)) for item in worker_profiles) - semantic_len = sum(int(item.get("semantic_len", 0)) for item in worker_profiles) - request_other_ms = max( - 0.0, - request_total_ms - prepare_wall_ms - engine_policy_wait_ms - worker_total_ms - pack_ms - response_overhead_ms, + return build_direct_scheduler_profile( + backend=backend, + request_start=request_start, + response_ready_at=response_ready_at, + audio_bytes=audio_bytes, + sample_rate=sample_rate, + segment_texts=segment_texts, + prepare_profiles=prepare_profiles, + worker_profiles=worker_profiles, + pack_ms=pack_ms, + response_overhead_ms=response_overhead_ms, ) - return { - "backend": backend, - "backend_mode": backend, - "segment_count": len(segment_texts), - "sample_rate": int(sample_rate), - "audio_bytes": int(audio_bytes), - "request_total_ms": request_total_ms, - "prepare_ms": prepare_wall_ms, - "prepare_wall_ms": prepare_wall_ms, - "prepare_profile_total_ms": prepare_profile_total_ms, - "engine_policy_wait_ms": engine_policy_wait_ms, - "engine_dispatch_wait_ms": engine_dispatch_wait_ms, - "decode_admission_wait_ms": decode_admission_wait_ms, - "queue_wait_ms": queue_wait_ms, - "prefill_ms": prefill_ms, - "merge_ms": merge_ms, - "decode_ms": decode_ms, - "finalize_wait_ms": finalize_wait_ms, - "synth_ms": synth_ms, - "pack_ms": pack_ms, - "response_overhead_ms": response_overhead_ms, - "worker_total_ms": worker_total_ms, - "request_other_ms": request_other_ms, - "decode_steps": decode_steps, - "semantic_len": semantic_len, - "prepare_segments": list(prepare_profiles), - "worker_segments": list(worker_profiles), - "segment_trace": segment_trace, - "prepare_aggregate": self._aggregate_numeric_dicts(prepare_profile_dicts), - } def _build_legacy_direct_profile( self, @@ -267,30 +202,18 @@ class EngineApiFacade: stream_total_bytes: int = 0, first_chunk_ms: float | None = None, ) -> Dict[str, Any]: - request_total_ms = max(0.0, (finished_at - request_start) * 1000.0) - legacy_infer_ms = max(0.0, request_total_ms - pack_ms) - return { - "backend": backend, - "backend_mode": backend, - "fallback_reason": fallback_reason, - "request_total_ms": request_total_ms, - "prepare_ms": 0.0, - "queue_wait_ms": 0.0, - "prefill_ms": 0.0, - "merge_ms": 0.0, - "decode_ms": 0.0, - "finalize_wait_ms": 0.0, - "synth_ms": 0.0, - "pack_ms": pack_ms, - "worker_total_ms": legacy_infer_ms, - "request_other_ms": 0.0, - "legacy_infer_ms": legacy_infer_ms, - "sample_rate": int(sample_rate) if sample_rate is not None else None, - "audio_bytes": int(audio_bytes), - "chunk_count": int(chunk_count), - "stream_total_bytes": int(stream_total_bytes), - "first_chunk_ms": None if first_chunk_ms is None else float(first_chunk_ms), - } + return build_legacy_direct_profile( + backend=backend, + fallback_reason=fallback_reason, + request_start=request_start, + finished_at=finished_at, + sample_rate=sample_rate, + audio_bytes=audio_bytes, + pack_ms=pack_ms, + chunk_count=chunk_count, + stream_total_bytes=stream_total_bytes, + first_chunk_ms=first_chunk_ms, + ) def _build_scheduler_submit_profile( self, @@ -314,45 +237,30 @@ class EngineApiFacade: response_overhead_ms: float, worker_profile: Dict[str, Any], ) -> Dict[str, Any]: - worker_total_ms = float(worker_profile.get("worker_total_ms", 0.0)) - request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) - request_other_ms = max( - 0.0, - request_total_ms - - prepare_wall_ms - - engine_policy_wait_ms - - api_after_prepare_ms - - worker_total_ms - - api_wait_result_ms - - pack_ms, + return build_scheduler_submit_profile( + backend=backend, + request_start=request_start, + response_ready_at=response_ready_at, + audio_bytes=audio_bytes, + sample_rate=sample_rate, + prepare_spec_build_ms=prepare_spec_build_ms, + prepare_wall_ms=prepare_wall_ms, + prepare_executor_queue_ms=prepare_executor_queue_ms, + prepare_executor_run_ms=prepare_executor_run_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + prepare_profile_wall_ms=prepare_profile_wall_ms, + prepare_other_ms=prepare_other_ms, + engine_policy_wait_ms=engine_policy_wait_ms, + api_after_prepare_ms=api_after_prepare_ms, + api_wait_result_ms=api_wait_result_ms, + pack_ms=pack_ms, + response_overhead_ms=response_overhead_ms, + worker_profile=worker_profile, ) - result = { - "backend": backend, - "backend_mode": backend, - "audio_bytes": int(audio_bytes), - "sample_rate": int(sample_rate), - "prepare_spec_build_ms": prepare_spec_build_ms, - "prepare_ms": prepare_wall_ms, - "prepare_wall_ms": prepare_wall_ms, - "prepare_executor_queue_ms": prepare_executor_queue_ms, - "prepare_executor_run_ms": prepare_executor_run_ms, - "prepare_profile_total_ms": prepare_profile_total_ms, - "prepare_profile_wall_ms": prepare_profile_wall_ms, - "prepare_other_ms": prepare_other_ms, - "engine_policy_wait_ms": float(engine_policy_wait_ms), - "api_after_prepare_ms": api_after_prepare_ms, - "api_wait_result_ms": api_wait_result_ms, - "pack_ms": pack_ms, - "response_overhead_ms": response_overhead_ms, - "request_total_ms": request_total_ms, - "request_other_ms": request_other_ms, - } - result.update({key: value for key, value in worker_profile.items()}) - return result @staticmethod def _format_ms_header(value: Any) -> str: - return f"{float(value):.3f}" + return format_ms_header(value) def _build_scheduler_submit_headers( self, @@ -362,87 +270,12 @@ class EngineApiFacade: sample_rate: int, profile: Dict[str, Any], ) -> Dict[str, str]: - prepare_profile = dict(profile.get("prepare_profile", {})) - headers = { - "X-Request-Id": request_id, - "X-Semantic-Len": str(int(profile.get("semantic_len", 0))), - "X-Finish-Reason": str(profile.get("finish_reason", "unknown")), - "X-Queue-Wait-Ms": self._format_ms_header(profile.get("queue_wait_ms", 0.0)), - "X-Decode-Admission-Wait-Ms": self._format_ms_header(profile.get("decode_admission_wait_ms", 0.0)), - "X-Engine-Policy-Wait-Ms": self._format_ms_header(profile.get("engine_policy_wait_ms", 0.0)), - "X-Engine-Dispatch-Wait-Ms": self._format_ms_header(profile.get("engine_dispatch_wait_ms", 0.0)), - "X-Prepare-Ms": self._format_ms_header(profile.get("prepare_wall_ms", 0.0)), - "X-Prepare-Wall-Ms": self._format_ms_header(profile.get("prepare_wall_ms", 0.0)), - "X-Prepare-Spec-Build-Ms": self._format_ms_header(profile.get("prepare_spec_build_ms", 0.0)), - "X-Prepare-Executor-Queue-Ms": self._format_ms_header(profile.get("prepare_executor_queue_ms", 0.0)), - "X-Prepare-Admission-Wait-Ms": self._format_ms_header(prepare_profile.get("prepare_admission_wait_ms", 0.0)), - "X-Prepare-Executor-Run-Ms": self._format_ms_header(profile.get("prepare_executor_run_ms", 0.0)), - "X-Prepare-Profile-Total-Ms": self._format_ms_header(profile.get("prepare_profile_total_ms", 0.0)), - "X-Prepare-Profile-Wall-Ms": self._format_ms_header(profile.get("prepare_profile_wall_ms", 0.0)), - "X-Prepare-Other-Ms": self._format_ms_header(profile.get("prepare_other_ms", 0.0)), - "X-Api-After-Prepare-Ms": self._format_ms_header(profile.get("api_after_prepare_ms", 0.0)), - "X-Prefill-Ms": self._format_ms_header(profile.get("prefill_ms", 0.0)), - "X-Merge-Ms": self._format_ms_header(profile.get("merge_ms", 0.0)), - "X-Decode-Ms": self._format_ms_header(profile.get("decode_ms", 0.0)), - "X-Finalize-Wait-Ms": self._format_ms_header(profile.get("finalize_wait_ms", 0.0)), - "X-Synth-Ms": self._format_ms_header(profile.get("synth_ms", 0.0)), - "X-Worker-Residual-Ms": self._format_ms_header(profile.get("worker_residual_ms", 0.0)), - "X-Worker-Other-Ms": self._format_ms_header(profile.get("worker_other_ms", 0.0)), - "X-Pack-Ms": self._format_ms_header(profile.get("pack_ms", 0.0)), - "X-Worker-Total-Ms": self._format_ms_header(profile.get("worker_total_ms", 0.0)), - "X-Api-Wait-Result-Ms": self._format_ms_header(profile.get("api_wait_result_ms", 0.0)), - "X-Decode-Steps": str(int(profile.get("decode_steps", 0))), - "X-Sample-Rate": str(int(sample_rate)), - "X-Response-Overhead-Ms": self._format_ms_header(profile.get("response_overhead_ms", 0.0)), - "X-Request-Other-Ms": self._format_ms_header(profile.get("request_other_ms", 0.0)), - "X-Request-Total-Ms": self._format_ms_header(profile.get("request_total_ms", 0.0)), - } - headers.update( - { - "X-Prepare-Prompt-Text-Ms": self._format_ms_header(prepare_profile.get("prompt_text_features_ms", 0.0)), - "X-Prepare-Target-Text-Ms": self._format_ms_header(prepare_profile.get("text_features_ms", 0.0)), - "X-Prepare-Prompt-Text-CPU-Preprocess-Ms": self._format_ms_header(prepare_profile.get("prompt_text_cpu_preprocess_ms", 0.0)), - "X-Prepare-Target-Text-CPU-Preprocess-Ms": self._format_ms_header(prepare_profile.get("text_cpu_preprocess_ms", 0.0)), - "X-Prepare-Prompt-Text-CPU-Queue-Ms": self._format_ms_header(prepare_profile.get("prompt_text_cpu_queue_ms", 0.0)), - "X-Prepare-Target-Text-CPU-Queue-Ms": self._format_ms_header(prepare_profile.get("text_cpu_queue_ms", 0.0)), - "X-Prepare-Prompt-Text-Feature-Queue-Ms": self._format_ms_header(prepare_profile.get("prompt_text_feature_queue_ms", 0.0)), - "X-Prepare-Target-Text-Feature-Queue-Ms": self._format_ms_header(prepare_profile.get("text_feature_queue_ms", 0.0)), - "X-Prepare-Prompt-Bert-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_wait_ms", 0.0)), - "X-Prepare-Target-Bert-Wait-Ms": self._format_ms_header(prepare_profile.get("text_bert_wait_ms", 0.0)), - "X-Prepare-Prompt-Bert-Admission-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_admission_wait_ms", 0.0)), - "X-Prepare-Target-Bert-Admission-Wait-Ms": self._format_ms_header(prepare_profile.get("text_bert_admission_wait_ms", 0.0)), - "X-Prepare-Prompt-Bert-Queue-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_queue_wait_ms", 0.0)), - "X-Prepare-Target-Bert-Queue-Wait-Ms": self._format_ms_header(prepare_profile.get("text_bert_queue_wait_ms", 0.0)), - "X-Prepare-Prompt-Bert-Batch-Collect-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_batch_collect_wait_ms", 0.0)), - "X-Prepare-Target-Bert-Batch-Collect-Wait-Ms": self._format_ms_header(prepare_profile.get("text_bert_batch_collect_wait_ms", 0.0)), - "X-Prepare-Prompt-Bert-Forward-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_forward_ms", 0.0)), - "X-Prepare-Target-Bert-Forward-Ms": self._format_ms_header(prepare_profile.get("text_bert_forward_ms", 0.0)), - "X-Prepare-Prompt-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_enqueue_peak", 0.0))), - "X-Prepare-Target-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_enqueue_peak", 0.0))), - "X-Prepare-Prompt-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_collect_peak", 0.0))), - "X-Prepare-Target-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_collect_peak", 0.0))), - "X-Prepare-Prompt-Bert-High-Pressure-Peak": str(int(prepare_profile.get("prompt_text_bert_high_pressure_mode_peak", 0.0))), - "X-Prepare-Target-Bert-High-Pressure-Peak": str(int(prepare_profile.get("text_bert_high_pressure_mode_peak", 0.0))), - "X-Prepare-Prompt-Bert-Batch-Window-Ms": self._format_ms_header(prepare_profile.get("prompt_text_bert_batch_window_ms", 0.0)), - "X-Prepare-Target-Bert-Batch-Window-Ms": self._format_ms_header(prepare_profile.get("text_bert_batch_window_ms", 0.0)), - "X-Prepare-Text-Pair-Wall-Ms": self._format_ms_header(prepare_profile.get("text_feature_pair_ms", 0.0)), - "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))), - "X-Prepare-Engine-GPU-Queue-Wait-Ms": self._format_ms_header(prepare_profile.get("engine_gpu_prepare_queue_wait_ms", 0.0)), - "X-Prepare-Audio-Load-Ms": self._format_ms_header(prepare_profile.get("audio_load_ms", 0.0)), - "X-Prepare-Audio-Stage-Wait-Ms": self._format_ms_header(prepare_profile.get("audio_stage_wait_ms", 0.0)), - "X-Prepare-Prompt-Semantic-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_ms", 0.0)), - "X-Prepare-Prompt-Semantic-Wait-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_wait_ms", 0.0)), - "X-Prepare-Prompt-Semantic-CPU-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), - "X-Prepare-Prompt-Semantic-Forward-Ms": self._format_ms_header(prepare_profile.get("prompt_semantic_forward_ms", 0.0)), - "X-Prepare-Ref-Spec-Ms": self._format_ms_header(prepare_profile.get("ref_spec_ms", 0.0)), - "X-Prepare-Ref-Spec-Wait-Ms": self._format_ms_header(prepare_profile.get("ref_spec_wait_ms", 0.0)), - "X-Prepare-Ref-Bundle-Ms": self._format_ms_header(prepare_profile.get("ref_audio_bundle_ms", 0.0)), - "X-Prepare-Tensorize-Ms": self._format_ms_header(prepare_profile.get("tensorize_ms", 0.0)), - "X-Prepare-Inflight-On-Enter": str(int(prepare_profile.get("worker_prepare_inflight_on_enter", 0.0))), - "X-Prepare-Inflight-Peak": str(int(prepare_profile.get("worker_prepare_peak_inflight", 0.0))), - } + return build_scheduler_submit_headers( + request_id=request_id, + media_type=media_type, + sample_rate=sample_rate, + profile=profile, ) - return headers def _build_scheduler_debug_request_profile( self, @@ -454,26 +287,14 @@ class EngineApiFacade: decode_batch_wall_ms: float, batch_request_total_ms: float, ) -> Dict[str, Any]: - prepare_profile = dict(state.prepare_profile) - prepare_wall_ms = float(prepare_profile.get("wall_total_ms", 0.0)) - return { - "backend": "scheduler_debug", - "backend_mode": "scheduler_debug", - "batch_request_count": int(batch_request_count), - "batch_prepare_wall_ms": float(prepare_batch_wall_ms), - "batch_decode_wall_ms": float(decode_batch_wall_ms), - "batch_request_total_ms": float(batch_request_total_ms), - "prepare_ms": prepare_wall_ms, - "prepare_wall_ms": prepare_wall_ms, - "prepare_profile_total_ms": float(prepare_profile.get("wall_total_ms", prepare_wall_ms)), - "prepare_profile": prepare_profile, - "decode_steps": int(item.finish_idx), - "finish_idx": int(item.finish_idx), - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_reason": item.finish_reason, - "norm_text": state.norm_text, - "norm_prompt_text": state.norm_prompt_text, - } + return build_scheduler_debug_request_profile( + state=state, + item=item, + batch_request_count=batch_request_count, + prepare_batch_wall_ms=prepare_batch_wall_ms, + decode_batch_wall_ms=decode_batch_wall_ms, + batch_request_total_ms=batch_request_total_ms, + ) @staticmethod def _build_scheduler_debug_batch_profile( @@ -485,107 +306,31 @@ class EngineApiFacade: request_total_ms: float, finished_items: Sequence[T2SFinishedItem], ) -> Dict[str, Any]: - finish_reason_counts: Dict[str, int] = {} - total_semantic_len = 0 - for item in finished_items: - finish_reason_counts[item.finish_reason] = finish_reason_counts.get(item.finish_reason, 0) + 1 - total_semantic_len += int(item.semantic_tokens.shape[0]) - return { - "request_count": int(request_count), - "max_steps": int(max_steps), - "prepare_batch_wall_ms": float(prepare_batch_wall_ms), - "decode_batch_wall_ms": float(decode_batch_wall_ms), - "request_total_ms": float(request_total_ms), - "total_semantic_len": int(total_semantic_len), - "finish_reason_counts": finish_reason_counts, - } + return build_scheduler_debug_batch_profile( + request_count=request_count, + max_steps=max_steps, + prepare_batch_wall_ms=prepare_batch_wall_ms, + decode_batch_wall_ms=decode_batch_wall_ms, + request_total_ms=request_total_ms, + finished_items=finished_items, + ) def _normalize_lang(self, value: str | None) -> str | None: - if value in [None, ""]: - return value - return str(value).lower() + return normalize_lang(value) @staticmethod def _aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]: - totals: Dict[str, float] = {} - for item in items: - for key, value in item.items(): - if isinstance(value, (int, float)): - totals[key] = totals.get(key, 0.0) + float(value) - return totals + return aggregate_numeric_dicts(items) def _apply_default_reference(self, req: dict) -> dict: - normalized = dict(req) - default_ref = self.reference_registry.get_default() - if normalized.get("ref_audio_path") in [None, ""] and default_ref.ref_audio_path not in [None, ""]: - normalized["ref_audio_path"] = default_ref.ref_audio_path - if "text_lang" in normalized: - normalized["text_lang"] = self._normalize_lang(normalized.get("text_lang")) - if "prompt_lang" in normalized: - normalized["prompt_lang"] = self._normalize_lang(normalized.get("prompt_lang")) - return normalized + return apply_default_reference(self.reference_registry, req) def check_params(self, req: dict) -> Optional[str]: - text = req.get("text", "") - text_lang = req.get("text_lang", "") - ref_audio_path = req.get("ref_audio_path", "") - media_type = req.get("media_type", "wav") - prompt_lang = req.get("prompt_lang", "") - text_split_method = req.get("text_split_method", "cut5") - - if ref_audio_path in [None, ""]: - return "ref_audio_path is required" - if text in [None, ""]: - return "text is required" - if text_lang in [None, ""]: - return "text_lang is required" - if text_lang.lower() not in self.tts.configs.languages: - return f"text_lang: {text_lang} is not supported in version {self.tts.configs.version}" - if prompt_lang in [None, ""]: - return "prompt_lang is required" - if prompt_lang.lower() not in self.tts.configs.languages: - return f"prompt_lang: {prompt_lang} is not supported in version {self.tts.configs.version}" - if media_type not in ["wav", "raw", "ogg", "aac"]: - return f"media_type: {media_type} is not supported" - if text_split_method not in self.cut_method_names: - return f"text_split_method:{text_split_method} is not supported" - return None + return check_params(self.tts, self.cut_method_names, req) @staticmethod def _base_request_defaults() -> Dict[str, Any]: - return { - "request_id": None, - "text": None, - "text_lang": None, - "ref_audio_path": None, - "aux_ref_audio_paths": None, - "prompt_text": "", - "prompt_lang": None, - "top_k": 15, - "top_p": 1.0, - "temperature": 1.0, - "text_split_method": "cut5", - "batch_size": 1, - "batch_threshold": 0.75, - "speed_factor": 1.0, - "split_bucket": False, - "fragment_interval": 0.3, - "seed": -1, - "media_type": "wav", - "streaming_mode": False, - "return_fragment": False, - "fixed_length_chunk": False, - "response_streaming": False, - "parallel_infer": False, - "repetition_penalty": 1.35, - "sample_steps": 32, - "super_sampling": False, - "overlap_length": 2, - "min_chunk_length": 16, - "early_stop_num": -1, - "ready_step": 0, - "timeout_sec": None, - } + return base_request_defaults() def _normalize_engine_request( self, @@ -595,111 +340,26 @@ class EngineApiFacade: normalize_streaming: bool = False, error_prefix: str = "request 参数非法: ", ) -> NormalizedEngineRequest: - if isinstance(payload, NormalizedEngineRequest): - normalized_payload = payload.to_payload() - else: - normalized_payload = self._base_request_defaults() - normalized_payload.update(dict(payload)) - if request_id not in [None, ""]: - normalized_payload["request_id"] = str(request_id) - elif normalized_payload.get("request_id") in [None, ""]: - raise ValueError("request_id is required after normalization") - normalized_payload = self._apply_default_reference(normalized_payload) - if normalize_streaming: - normalized_payload = self._normalize_streaming_mode(normalized_payload) - error = self.check_params(normalized_payload) - if error is not None: - raise ValueError(f"{error_prefix}{error}") - timeout_sec = normalized_payload.get("timeout_sec") - if timeout_sec in [None, ""]: - parsed_timeout = None - else: - parsed_timeout = float(timeout_sec) - aux_ref_audio_paths = normalized_payload.get("aux_ref_audio_paths") - if aux_ref_audio_paths in [None, "", []]: - normalized_aux_ref_audio_paths = None - else: - normalized_aux_ref_audio_paths = [str(item) for item in aux_ref_audio_paths] - return NormalizedEngineRequest( - request_id=str(normalized_payload["request_id"]), - text=str(normalized_payload["text"]), - text_lang=str(normalized_payload["text_lang"]), - ref_audio_path=str(normalized_payload["ref_audio_path"]), - prompt_lang=str(normalized_payload["prompt_lang"]), - prompt_text="" if normalized_payload.get("prompt_text") is None else str(normalized_payload.get("prompt_text")), - aux_ref_audio_paths=normalized_aux_ref_audio_paths, - top_k=int(normalized_payload["top_k"]), - top_p=float(normalized_payload["top_p"]), - temperature=float(normalized_payload["temperature"]), - repetition_penalty=float(normalized_payload["repetition_penalty"]), - early_stop_num=int(normalized_payload.get("early_stop_num", -1)), - ready_step=int(normalized_payload.get("ready_step", 0)), - text_split_method=str(normalized_payload["text_split_method"]), - batch_size=int(normalized_payload["batch_size"]), - batch_threshold=float(normalized_payload["batch_threshold"]), - split_bucket=bool(normalized_payload["split_bucket"]), - speed_factor=float(normalized_payload["speed_factor"]), - fragment_interval=float(normalized_payload["fragment_interval"]), - seed=int(normalized_payload["seed"]), - media_type=str(normalized_payload["media_type"]), - streaming_mode=normalized_payload["streaming_mode"], - return_fragment=bool(normalized_payload.get("return_fragment", False)), - fixed_length_chunk=bool(normalized_payload.get("fixed_length_chunk", False)), - response_streaming=bool(normalized_payload.get("response_streaming", False)), - parallel_infer=bool(normalized_payload["parallel_infer"]), - sample_steps=int(normalized_payload["sample_steps"]), - super_sampling=bool(normalized_payload["super_sampling"]), - overlap_length=int(normalized_payload["overlap_length"]), - min_chunk_length=int(normalized_payload["min_chunk_length"]), - timeout_sec=parsed_timeout, + return normalize_engine_request( + tts=self.tts, + cut_method_names=self.cut_method_names, + reference_registry=self.reference_registry, + payload=payload, + request_id=request_id, + normalize_streaming=normalize_streaming, + error_prefix=error_prefix, ) @staticmethod def _normalize_streaming_mode(req: dict) -> dict: - normalized = dict(req) - streaming_mode = normalized.get("streaming_mode", False) - return_fragment = normalized.get("return_fragment", False) - if streaming_mode is False: - normalized["streaming_mode"] = False - normalized["return_fragment"] = False - normalized["fixed_length_chunk"] = False - elif streaming_mode == 0: - normalized["streaming_mode"] = False - normalized["return_fragment"] = False - normalized["fixed_length_chunk"] = False - elif streaming_mode == 1 or streaming_mode is True: - normalized["streaming_mode"] = False - normalized["return_fragment"] = True - normalized["fixed_length_chunk"] = False - elif streaming_mode == 2: - normalized["streaming_mode"] = True - normalized["return_fragment"] = False - normalized["fixed_length_chunk"] = False - elif streaming_mode == 3: - normalized["streaming_mode"] = True - normalized["return_fragment"] = False - normalized["fixed_length_chunk"] = True - else: - raise ValueError("the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)") - normalized["response_streaming"] = bool(normalized["streaming_mode"] or normalized["return_fragment"] or return_fragment) - return normalized + return normalize_streaming_mode(req) @staticmethod def _is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool: - return aux_ref_audio_paths not in [None, [], ()] + return is_aux_ref_enabled(aux_ref_audio_paths) def _select_direct_backend(self, normalized: NormalizedEngineRequest) -> Tuple[str, str | None]: - if normalized.response_streaming: - if normalized.return_fragment or normalized.fixed_length_chunk: - return "legacy_direct_fragment", "fragment_streaming_mode" - return "legacy_direct_streaming", "streaming_mode" - if self._is_aux_ref_enabled(normalized.aux_ref_audio_paths): - return "legacy_direct_aux_ref", "aux_ref_audio_paths" - if normalized.super_sampling: - return "legacy_direct_super_sampling", "super_sampling" - if normalized.prompt_text in [None, ""]: - return "legacy_direct_missing_prompt", "missing_prompt_text" - return "scheduler_v1_direct", None + return select_direct_backend(normalized) def _iter_legacy_direct_tts_bytes( self, @@ -708,88 +368,17 @@ class EngineApiFacade: backend: str, fallback_reason: str | None, ) -> Generator[bytes, None, None]: - payload = normalized.to_payload() - media_type = normalized.media_type - request_id = normalized.request_id - request_start = time.perf_counter() - chunk_count = 0 - stream_total_bytes = 0 - first_chunk_ms: float | None = None - self._update_request_state( - request_id, - EngineStatus.ACTIVE_DECODE, - {"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason}, - ) - try: - with self.direct_tts_lock: - tts_generator = self.tts.run(payload) - first_chunk = True - current_media_type = media_type - for sr, chunk in tts_generator: - if first_chunk: - first_chunk_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0) - self._update_request_state( - request_id, - EngineStatus.STREAMING, - { - "backend": backend, - "backend_mode": backend, - "fallback_reason": fallback_reason, - "sample_rate": int(sr), - }, - ) - if first_chunk and media_type == "wav": - header = wave_header_chunk(sample_rate=sr) - chunk_count += 1 - stream_total_bytes += len(header) - yield header - current_media_type = "raw" - first_chunk = False - elif first_chunk: - first_chunk = False - packed_chunk = pack_audio(BytesIO(), chunk, sr, current_media_type).getvalue() - chunk_count += 1 - stream_total_bytes += len(packed_chunk) - yield packed_chunk - except Exception as exc: - self._fail_request_state(request_id, str(exc)) - raise - self._complete_request_state( - request_id, - dict( - self._build_legacy_direct_profile( - backend=backend, - fallback_reason=fallback_reason, - request_start=request_start, - finished_at=time.perf_counter(), - audio_bytes=stream_total_bytes, - chunk_count=chunk_count, - stream_total_bytes=stream_total_bytes, - first_chunk_ms=first_chunk_ms, - ), - streaming_completed=True, - ), + yield from self.direct_flow._iter_legacy_direct_tts_bytes( + normalized, + backend=backend, + fallback_reason=fallback_reason, ) def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool: - if isinstance(req, NormalizedEngineRequest): - normalized = req - else: - normalized = self._normalize_engine_request( - req, - request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"), - normalize_streaming=True, - ) - backend, _ = self._select_direct_backend(normalized) - return backend == "scheduler_v1_direct" + return self.direct_flow._should_use_scheduler_backend_for_direct(req) def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]: - payload = normalized.to_payload() if isinstance(normalized, NormalizedEngineRequest) else normalized - return self.tts.text_preprocessor.pre_seg_text( - str(payload["text"]), - str(payload["text_lang"]), - str(payload.get("text_split_method", "cut5")), - ) + return self.direct_flow._segment_direct_text(normalized) def _build_segment_request( self, @@ -798,154 +387,14 @@ class EngineApiFacade: request_id: str, text: str, ) -> NormalizedEngineRequest: - payload = normalized.to_payload() - payload["request_id"] = request_id - payload["text"] = text - payload["streaming_mode"] = False - payload["return_fragment"] = False - payload["fixed_length_chunk"] = False - payload["response_streaming"] = False - return self._normalize_engine_request(payload, error_prefix="segment request 参数非法: ") + return self.direct_flow._build_segment_request( + normalized, + request_id=request_id, + text=text, + ) async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution: - request_start = time.perf_counter() - request_id = normalized.request_id - media_type = normalized.media_type - segment_texts = self._segment_direct_text(normalized) - if not segment_texts: - raise ValueError("text preprocessing returned no valid segments") - self._update_request_state( - request_id, - EngineStatus.CPU_PREPARING, - {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_texts)}, - ) - segment_specs: List[SchedulerRequestSpec] = [] - for segment_index, segment_text in enumerate(segment_texts): - segment_request = self._build_segment_request( - normalized, - request_id=f"{request_id}_seg_{segment_index:03d}", - text=segment_text, - ) - segment_specs.append(self._build_scheduler_submit_spec(segment_request)) - - prepared_items = await asyncio.gather( - *[ - self._prepare_state_via_engine_gpu_queue( - spec=spec, - prepare_submit_at=time.perf_counter(), - engine_request_id=None, - ) - for spec in segment_specs - ] - ) - prepare_profiles: List[Dict[str, Any]] = [] - loop = asyncio.get_running_loop() - done_futures: List[asyncio.Future] = [] - self._update_request_state( - request_id, - EngineStatus.READY_FOR_PREFILL, - {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_specs)}, - ) - for spec, (state, prepare_exec_started_at, prepare_exec_finished_at) in zip(segment_specs, prepared_items): - prepare_wall_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0) - prepare_profile_total_ms = float(state.prepare_profile.get("wall_total_ms", prepare_wall_ms)) - prepare_profiles.append( - { - "request_id": spec.request_id, - "prepare_wall_ms": prepare_wall_ms, - "prepare_profile_total_ms": prepare_profile_total_ms, - "prepare_profile": dict(state.prepare_profile), - } - ) - done_future = loop.create_future() - done_futures.append(done_future) - await self._enqueue_prepared_state_for_dispatch( - state=state, - speed_factor=float(normalized.speed_factor), - sample_steps=int(normalized.sample_steps), - media_type=media_type, - prepare_wall_ms=prepare_wall_ms, - prepare_profile_total_ms=prepare_profile_total_ms, - done_loop=loop, - done_future=done_future, - engine_request_id=None, - timeout_sec=normalized.timeout_sec, - ) - self._update_request_state( - request_id, - EngineStatus.ACTIVE_DECODE, - {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct"}, - ) - timeout_sec = float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0) - jobs: List[SchedulerPendingJob] = list(await asyncio.wait_for(asyncio.gather(*done_futures), timeout=timeout_sec)) - for profile_item, job in zip(prepare_profiles, jobs): - profile_item["engine_policy_wait_ms"] = float(job.engine_policy_wait_ms) - profile_item["engine_dispatch_wait_ms"] = float(job.engine_dispatch_wait_ms) - self._merge_request_state_profile( - request_id, - { - "engine_policy_wait_ms": sum(float(job.engine_policy_wait_ms) for job in jobs), - "engine_dispatch_wait_ms": sum(float(job.engine_dispatch_wait_ms) for job in jobs), - "prepare_aggregate": self._aggregate_numeric_dicts( - [item["prepare_profile"] for item in prepare_profiles] - ), - }, - ) - - sample_rate: int | None = None - audio_parts: List[np.ndarray] = [] - worker_profiles: List[Dict[str, Any]] = [] - fragment_interval = float(normalized.fragment_interval) - silence_chunk: Optional[np.ndarray] = None - for job in jobs: - if job.error is not None: - raise RuntimeError(job.error) - if job.audio_data is None or job.sample_rate is None or job.result is None: - raise RuntimeError(f"{job.request_id} finished without audio result") - if sample_rate is None: - sample_rate = int(job.sample_rate) - silence_samples = int(fragment_interval * float(sample_rate)) - if silence_samples > 0: - silence_chunk = np.zeros(silence_samples, dtype=np.int16) - elif int(job.sample_rate) != sample_rate: - raise RuntimeError("segment sample rate mismatch") - audio_parts.append(job.audio_data) - if silence_chunk is not None: - audio_parts.append(silence_chunk.copy()) - worker_profiles.append(dict(job.result)) - if sample_rate is None or not audio_parts: - raise RuntimeError("direct scheduler backend produced no audio") - self._update_request_state( - request_id, - EngineStatus.FINALIZING, - {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct"}, - ) - merged_audio = np.concatenate(audio_parts, axis=0) - pack_start = time.perf_counter() - audio_bytes = pack_audio(BytesIO(), merged_audio, sample_rate, media_type).getvalue() - pack_ms = max(0.0, (time.perf_counter() - pack_start) * 1000.0) - direct_profile = self._build_direct_scheduler_profile( - backend="scheduler_v1_direct", - request_start=request_start, - response_ready_at=time.perf_counter(), - audio_bytes=len(audio_bytes), - sample_rate=int(sample_rate), - segment_texts=segment_texts, - prepare_profiles=prepare_profiles, - worker_profiles=worker_profiles, - pack_ms=pack_ms, - response_overhead_ms=0.0, - ) - self._complete_request_state( - request_id, - dict(direct_profile, streaming_completed=False), - ) - return DirectTTSExecution( - media_type=media_type, - streaming=False, - audio_bytes=audio_bytes, - request_id=request_id, - ) + return await self.direct_flow._run_direct_tts_via_scheduler(normalized) def _run_legacy_direct_tts_blocking( self, @@ -954,50 +403,10 @@ class EngineApiFacade: backend: str, fallback_reason: str | None, ) -> DirectTTSExecution: - normalized_payload = normalized.to_payload() - request_id = normalized.request_id - media_type = normalized.media_type - request_start = time.perf_counter() - self._update_request_state( - request_id, - EngineStatus.ACTIVE_DECODE, - {"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason}, - ) - with self.direct_tts_lock: - tts_generator = self.tts.run(normalized_payload) - try: - sr, audio_data = next(tts_generator) - except Exception as exc: - self._fail_request_state(request_id, str(exc)) - raise - self._update_request_state( - request_id, - EngineStatus.FINALIZING, - {"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason}, - ) - pack_start = time.perf_counter() - packed_audio = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() - pack_ms = max(0.0, (time.perf_counter() - pack_start) * 1000.0) - self._complete_request_state( - request_id, - dict( - self._build_legacy_direct_profile( - backend=backend, - fallback_reason=fallback_reason, - request_start=request_start, - finished_at=time.perf_counter(), - sample_rate=int(sr), - audio_bytes=len(packed_audio), - pack_ms=pack_ms, - ), - streaming_completed=False, - ), - ) - return DirectTTSExecution( - media_type=media_type, - streaming=False, - audio_bytes=packed_audio, - request_id=request_id, + return self.direct_flow._run_legacy_direct_tts_blocking( + normalized, + backend=backend, + fallback_reason=fallback_reason, ) async def _run_direct_tts_via_legacy_backend( @@ -1007,393 +416,34 @@ class EngineApiFacade: backend: str, fallback_reason: str | None, ) -> DirectTTSExecution: - if normalized.response_streaming: - return DirectTTSExecution( - media_type=normalized.media_type, - streaming=True, - audio_generator=self._iter_legacy_direct_tts_bytes( - normalized, - backend=backend, - fallback_reason=fallback_reason, - ), - request_id=normalized.request_id, - ) - return await asyncio.to_thread( - self._run_legacy_direct_tts_blocking, + return await self.direct_flow._run_direct_tts_via_legacy_backend( normalized, backend=backend, fallback_reason=fallback_reason, ) async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution: - normalized = self._normalize_engine_request( - req, - request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"), - normalize_streaming=True, - error_prefix="", - ) - request_id = normalized.request_id - media_type = normalized.media_type - backend, fallback_reason = self._select_direct_backend(normalized) - self._register_request_state( - request_id=request_id, - api_mode="tts", - backend=backend, - media_type=media_type, - response_streaming=bool(normalized.response_streaming), - deadline_ts=( - time.perf_counter() + float(normalized.timeout_sec) - if normalized.timeout_sec is not None - else None - ), - meta=self._build_request_meta(normalized.to_payload()), - ) - self._update_request_state( - request_id, - EngineStatus.VALIDATED, - { - "request_source": "direct_tts", - "selected_backend": backend, - "fallback_reason": fallback_reason, - }, - ) - if backend == "scheduler_v1_direct": - try: - return await self._run_direct_tts_via_scheduler(normalized) - except Exception as exc: - self._fail_request_state(request_id, str(exc)) - raise - return await self._run_direct_tts_via_legacy_backend( - normalized, - backend=backend, - fallback_reason=fallback_reason, - ) + return await self.direct_flow.run_direct_tts_async(req) def run_direct_tts(self, req: dict) -> DirectTTSExecution: - normalized = self._normalize_engine_request( - req, - request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"), - normalize_streaming=True, - error_prefix="", - ) - request_id = normalized.request_id - media_type = normalized.media_type - backend, fallback_reason = self._select_direct_backend(normalized) - if not self._has_active_request(request_id): - self._register_request_state( - request_id=request_id, - api_mode="tts", - backend=backend, - media_type=media_type, - response_streaming=bool(normalized.response_streaming), - meta=self._build_request_meta(normalized.to_payload()), - ) - self._update_request_state( - request_id, - EngineStatus.VALIDATED, - { - "request_source": "direct_tts", - "selected_backend": backend, - "fallback_reason": fallback_reason, - }, - ) - if backend != "scheduler_v1_direct": - if normalized.response_streaming: - return DirectTTSExecution( - media_type=media_type, - streaming=True, - audio_generator=self._iter_legacy_direct_tts_bytes( - normalized, - backend=backend, - fallback_reason=fallback_reason, - ), - request_id=request_id, - ) - return self._run_legacy_direct_tts_blocking( - normalized, - backend=backend, - fallback_reason=fallback_reason, - ) - if normalized.response_streaming: - return DirectTTSExecution( - media_type=media_type, - streaming=True, - audio_generator=self._iter_legacy_direct_tts_bytes( - normalized, - backend="legacy_direct_sync_compat", - fallback_reason="sync_direct_compat", - ), - request_id=request_id, - ) - return self._run_legacy_direct_tts_blocking( - normalized, - backend="legacy_direct_sync_compat", - fallback_reason="sync_direct_compat", - ) + return self.direct_flow.run_direct_tts(req) def _build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]: - specs: List[SchedulerRequestSpec] = [] - for index, payload in enumerate(request_items): - normalized = self._normalize_engine_request( - payload, - request_id=str(payload.get("request_id") or f"req_{index:03d}"), - error_prefix=f"request[{index}] 参数非法: ", - ) - specs.append(normalized.to_scheduler_spec()) - return specs + return self.scheduler_flow._build_scheduler_request_specs(request_items) def _build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec: - normalized = self._normalize_engine_request( - payload, - request_id=( - payload.request_id - if isinstance(payload, NormalizedEngineRequest) - else str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}") - ), - ) - return normalized.to_scheduler_spec() + return self.scheduler_flow._build_scheduler_submit_spec(payload) @staticmethod def _summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: - return [ - { - "request_id": state.request_id, - "ready_step": int(state.ready_step), - "ref_audio_path": str(state.ref_audio_path), - "prompt_semantic_len": int(state.prompt_semantic.shape[0]), - "all_phone_len": int(state.all_phones.shape[0]), - "bert_len": int(state.all_bert_features.shape[-1]), - "norm_text": state.norm_text, - } - for state in states - ] + return EngineApiSchedulerFlow._summarize_scheduler_states(states) @staticmethod def _summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: - return [ - { - "request_id": item.request_id, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - "finish_reason": item.finish_reason, - } - for item in items - ] + return EngineApiSchedulerFlow._summarize_scheduler_finished(items) async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution: - request_start = time.perf_counter() - set_scheduler_seed(seed) - specs = self._build_scheduler_request_specs(request_items) - request_ids = [spec.request_id for spec in specs] - for spec in specs: - self._register_request_state( - request_id=spec.request_id, - api_mode="scheduler_debug", - backend="scheduler_debug", - media_type="wav", - response_streaming=False, - meta={ - "text_len": len(spec.text), - "prompt_text_len": len(spec.prompt_text), - "text_lang": spec.text_lang, - "prompt_lang": spec.prompt_lang, - "ref_audio_path": str(spec.ref_audio_path), - "ready_step": int(spec.ready_step), - }, - ) - self._update_request_state(spec.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_debug"}) - self._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, None) - prepare_started_at = time.perf_counter() - try: - states = await self.scheduler_worker.prepare_states_batch_async(specs) - except Exception as exc: - for request_id in request_ids: - self._fail_request_state(request_id, str(exc)) - raise - prepare_finished_at = time.perf_counter() - prepare_batch_wall_ms = max(0.0, (prepare_finished_at - prepare_started_at) * 1000.0) - for state in states: - self._update_request_state( - state.request_id, - EngineStatus.ACTIVE_DECODE, - { - "prepare_profile": dict(state.prepare_profile), - "norm_text": state.norm_text, - "norm_prompt_text": state.norm_prompt_text, - }, - ) - decode_started_at = time.perf_counter() - try: - finished = run_scheduler_continuous(self.tts.t2s_model.model, states, max_steps=int(max_steps)) - except Exception as exc: - for request_id in request_ids: - self._fail_request_state(request_id, str(exc)) - raise - decode_finished_at = time.perf_counter() - decode_batch_wall_ms = max(0.0, (decode_finished_at - decode_started_at) * 1000.0) - request_total_ms = max(0.0, (decode_finished_at - request_start) * 1000.0) - finished_map = {item.request_id: item for item in finished} - request_profiles: List[Dict[str, Any]] = [] - for state in states: - item = finished_map.get(state.request_id) - if item is None: - self._fail_request_state(state.request_id, "scheduler_debug finished without result") - continue - request_profile = self._build_scheduler_debug_request_profile( - state=state, - item=item, - batch_request_count=len(states), - prepare_batch_wall_ms=prepare_batch_wall_ms, - decode_batch_wall_ms=decode_batch_wall_ms, - batch_request_total_ms=request_total_ms, - ) - request_profiles.append( - { - "request_id": state.request_id, - "profile": dict(request_profile), - } - ) - self._complete_request_state( - state.request_id, - dict(request_profile), - ) - return SchedulerDebugExecution( - payload={ - "message": "success", - "request_count": len(states), - "max_steps": int(max_steps), - "batch_profile": self._build_scheduler_debug_batch_profile( - request_count=len(states), - max_steps=int(max_steps), - prepare_batch_wall_ms=prepare_batch_wall_ms, - decode_batch_wall_ms=decode_batch_wall_ms, - request_total_ms=request_total_ms, - finished_items=finished, - ), - "requests": self._summarize_scheduler_states(states), - "finished": self._summarize_scheduler_finished(finished), - "request_profiles": request_profiles, - "request_traces": self._collect_request_summaries(request_ids), - } - ) + return await self.scheduler_flow.run_scheduler_debug(request_items, max_steps, seed) async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution: - request_start = time.perf_counter() - prepare_start = request_start - normalized = self._normalize_engine_request( - payload, - request_id=str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}"), - ) - spec = self._build_scheduler_submit_spec(normalized) - deadline_ts = None - timeout_sec = normalized.timeout_sec - if timeout_sec is not None: - try: - deadline_ts = request_start + float(timeout_sec) - except Exception: - deadline_ts = None - self._register_request_state( - request_id=spec.request_id, - api_mode="scheduler_submit", - backend="scheduler_v1", - media_type=normalized.media_type, - response_streaming=False, - deadline_ts=deadline_ts, - meta=self._build_request_meta(normalized.to_payload()), - ) - self._update_request_state(spec.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_submit"}) - spec_ready_at = time.perf_counter() - prepare_spec_build_ms = max(0.0, (spec_ready_at - prepare_start) * 1000.0) - self._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, {"prepare_spec_build_ms": prepare_spec_build_ms}) - try: - state, prepare_exec_started_at, prepare_exec_finished_at = await self._prepare_state_via_engine_gpu_queue( - spec=spec, - prepare_submit_at=spec_ready_at, - engine_request_id=spec.request_id, - ) - except Exception as exc: - self._fail_request_state(spec.request_id, str(exc)) - raise - prepare_wall_ms = max(0.0, (prepare_exec_finished_at - spec_ready_at) * 1000.0) - prepare_executor_queue_ms = max(0.0, (prepare_exec_started_at - spec_ready_at) * 1000.0) - prepare_executor_run_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0) - prepare_profile = dict(state.prepare_profile) - prepare_profile_total_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms)) - prepare_profile_wall_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms)) - prepare_other_ms = max(0.0, prepare_wall_ms - prepare_spec_build_ms - prepare_executor_queue_ms - prepare_executor_run_ms) - self._update_request_state( - spec.request_id, - EngineStatus.READY_FOR_PREFILL, - { - "prepare_wall_ms": prepare_wall_ms, - "prepare_profile_total_ms": prepare_profile_total_ms, - "prepare_profile": prepare_profile, - }, - ) - api_after_prepare_start = time.perf_counter() - loop = asyncio.get_running_loop() - done_future = loop.create_future() - await self._enqueue_prepared_state_for_dispatch( - state=state, - speed_factor=float(normalized.speed_factor), - sample_steps=int(normalized.sample_steps), - media_type=normalized.media_type, - prepare_wall_ms=prepare_wall_ms, - prepare_profile_total_ms=prepare_profile_total_ms, - done_loop=loop, - done_future=done_future, - engine_request_id=spec.request_id, - timeout_sec=normalized.timeout_sec, - ) - api_after_prepare_ms = max(0.0, (time.perf_counter() - api_after_prepare_start) * 1000.0) - try: - job = await asyncio.wait_for(done_future, timeout=float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0)) - except Exception as exc: - self._fail_request_state(spec.request_id, str(exc)) - raise - wait_return_at = time.perf_counter() - if job.error is not None: - raise RuntimeError(job.error) - if job.audio_data is None or job.sample_rate is None or job.result is None: - self._fail_request_state(spec.request_id, f"{job.request_id} finished without audio result") - raise RuntimeError(f"{job.request_id} finished without audio result") - pack_start = time.perf_counter() - audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue() - pack_end = time.perf_counter() - pack_ms = (pack_end - pack_start) * 1000.0 - api_wait_result_ms = 0.0 - if job.result_ready_time is not None: - api_wait_result_ms = max(0.0, (wait_return_at - job.result_ready_time) * 1000.0) - response_ready_at = time.perf_counter() - response_overhead_ms = max(0.0, (response_ready_at - pack_end) * 1000.0) - submit_profile = self._build_scheduler_submit_profile( - backend="scheduler_v1", - request_start=request_start, - response_ready_at=response_ready_at, - audio_bytes=len(audio_data), - sample_rate=int(job.sample_rate), - prepare_spec_build_ms=prepare_spec_build_ms, - prepare_wall_ms=prepare_wall_ms, - prepare_executor_queue_ms=prepare_executor_queue_ms, - prepare_executor_run_ms=prepare_executor_run_ms, - prepare_profile_total_ms=prepare_profile_total_ms, - prepare_profile_wall_ms=prepare_profile_wall_ms, - prepare_other_ms=prepare_other_ms, - engine_policy_wait_ms=float(job.result.get("engine_policy_wait_ms", 0.0)), - api_after_prepare_ms=api_after_prepare_ms, - api_wait_result_ms=api_wait_result_ms, - pack_ms=pack_ms, - response_overhead_ms=response_overhead_ms, - worker_profile=dict(job.result or {}), - ) - headers = self._build_scheduler_submit_headers( - request_id=job.request_id, - media_type=job.media_type, - sample_rate=int(job.sample_rate), - profile=submit_profile, - ) - self._merge_request_state_profile( - spec.request_id, - dict(submit_profile, response_headers_emitted=True), - ) - return SchedulerSubmitExecution(audio_bytes=audio_data, media_type=f"audio/{job.media_type}", headers=headers) + return await self.scheduler_flow.run_scheduler_submit(payload) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api_direct.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_direct.py new file mode 100644 index 00000000..b2a308df --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_direct.py @@ -0,0 +1,441 @@ +from __future__ import annotations + +import asyncio +import time +import uuid +from io import BytesIO +from typing import Any, Dict, Generator, List, Optional + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.unified_engine_audio import pack_audio, wave_header_chunk +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, EngineStatus, NormalizedEngineRequest, SchedulerPendingJob + + +class EngineApiDirectFlow: + def __init__(self, api: Any) -> None: + self.api = api + + def _iter_legacy_direct_tts_bytes( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> Generator[bytes, None, None]: + payload = normalized.to_payload() + media_type = normalized.media_type + request_id = normalized.request_id + request_start = time.perf_counter() + chunk_count = 0 + stream_total_bytes = 0 + first_chunk_ms: float | None = None + self.api._update_request_state( + request_id, + EngineStatus.ACTIVE_DECODE, + {"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason}, + ) + try: + with self.api.direct_tts_lock: + tts_generator = self.api.tts.run(payload) + first_chunk = True + current_media_type = media_type + for sr, chunk in tts_generator: + if first_chunk: + first_chunk_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0) + self.api._update_request_state( + request_id, + EngineStatus.STREAMING, + { + "backend": backend, + "backend_mode": backend, + "fallback_reason": fallback_reason, + "sample_rate": int(sr), + }, + ) + if first_chunk and media_type == "wav": + header = wave_header_chunk(sample_rate=sr) + chunk_count += 1 + stream_total_bytes += len(header) + yield header + current_media_type = "raw" + first_chunk = False + elif first_chunk: + first_chunk = False + packed_chunk = pack_audio(BytesIO(), chunk, sr, current_media_type).getvalue() + chunk_count += 1 + stream_total_bytes += len(packed_chunk) + yield packed_chunk + except Exception as exc: + self.api._fail_request_state(request_id, str(exc)) + raise + self.api._complete_request_state( + request_id, + dict( + self.api._build_legacy_direct_profile( + backend=backend, + fallback_reason=fallback_reason, + request_start=request_start, + finished_at=time.perf_counter(), + audio_bytes=stream_total_bytes, + chunk_count=chunk_count, + stream_total_bytes=stream_total_bytes, + first_chunk_ms=first_chunk_ms, + ), + streaming_completed=True, + ), + ) + + def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool: + if isinstance(req, NormalizedEngineRequest): + normalized = req + else: + normalized = self.api._normalize_engine_request( + req, + request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"), + normalize_streaming=True, + ) + backend, _ = self.api._select_direct_backend(normalized) + return backend == "scheduler_v1_direct" + + def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]: + payload = normalized.to_payload() if isinstance(normalized, NormalizedEngineRequest) else normalized + return self.api.tts.text_preprocessor.pre_seg_text( + str(payload["text"]), + str(payload["text_lang"]), + str(payload.get("text_split_method", "cut5")), + ) + + def _build_segment_request( + self, + normalized: NormalizedEngineRequest, + *, + request_id: str, + text: str, + ) -> NormalizedEngineRequest: + payload = normalized.to_payload() + payload["request_id"] = request_id + payload["text"] = text + payload["streaming_mode"] = False + payload["return_fragment"] = False + payload["fixed_length_chunk"] = False + payload["response_streaming"] = False + return self.api._normalize_engine_request(payload, error_prefix="segment request 参数非法: ") + + async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution: + request_start = time.perf_counter() + request_id = normalized.request_id + media_type = normalized.media_type + segment_texts = self._segment_direct_text(normalized) + if not segment_texts: + raise ValueError("text preprocessing returned no valid segments") + self.api._update_request_state( + request_id, + EngineStatus.CPU_PREPARING, + {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_texts)}, + ) + segment_specs = [] + for segment_index, segment_text in enumerate(segment_texts): + segment_request = self._build_segment_request( + normalized, + request_id=f"{request_id}_seg_{segment_index:03d}", + text=segment_text, + ) + segment_specs.append(self.api._build_scheduler_submit_spec(segment_request)) + + prepared_items = await asyncio.gather( + *[ + self.api._prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=time.perf_counter(), + engine_request_id=None, + ) + for spec in segment_specs + ] + ) + prepare_profiles: List[Dict[str, Any]] = [] + loop = asyncio.get_running_loop() + done_futures: List[asyncio.Future] = [] + self.api._update_request_state( + request_id, + EngineStatus.READY_FOR_PREFILL, + {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_specs)}, + ) + for spec, (state, prepare_exec_started_at, prepare_exec_finished_at) in zip(segment_specs, prepared_items): + prepare_wall_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0) + prepare_profile_total_ms = float(state.prepare_profile.get("wall_total_ms", prepare_wall_ms)) + prepare_profiles.append( + { + "request_id": spec.request_id, + "prepare_wall_ms": prepare_wall_ms, + "prepare_profile_total_ms": prepare_profile_total_ms, + "prepare_profile": dict(state.prepare_profile), + } + ) + done_future = loop.create_future() + done_futures.append(done_future) + await self.api._enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=float(normalized.speed_factor), + sample_steps=int(normalized.sample_steps), + media_type=media_type, + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=loop, + done_future=done_future, + engine_request_id=None, + timeout_sec=normalized.timeout_sec, + ) + self.api._update_request_state( + request_id, + EngineStatus.ACTIVE_DECODE, + {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct"}, + ) + timeout_sec = float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0) + jobs: List[SchedulerPendingJob] = list(await asyncio.wait_for(asyncio.gather(*done_futures), timeout=timeout_sec)) + for profile_item, job in zip(prepare_profiles, jobs): + profile_item["engine_policy_wait_ms"] = float(job.engine_policy_wait_ms) + profile_item["engine_dispatch_wait_ms"] = float(job.engine_dispatch_wait_ms) + self.api._merge_request_state_profile( + request_id, + { + "engine_policy_wait_ms": sum(float(job.engine_policy_wait_ms) for job in jobs), + "engine_dispatch_wait_ms": sum(float(job.engine_dispatch_wait_ms) for job in jobs), + "prepare_aggregate": self.api._aggregate_numeric_dicts([item["prepare_profile"] for item in prepare_profiles]), + }, + ) + + sample_rate: int | None = None + audio_parts: List[np.ndarray] = [] + worker_profiles: List[Dict[str, Any]] = [] + fragment_interval = float(normalized.fragment_interval) + silence_chunk: Optional[np.ndarray] = None + for job in jobs: + if job.error is not None: + raise RuntimeError(job.error) + if job.audio_data is None or job.sample_rate is None or job.result is None: + raise RuntimeError(f"{job.request_id} finished without audio result") + if sample_rate is None: + sample_rate = int(job.sample_rate) + silence_samples = int(fragment_interval * float(sample_rate)) + if silence_samples > 0: + silence_chunk = np.zeros(silence_samples, dtype=np.int16) + elif int(job.sample_rate) != sample_rate: + raise RuntimeError("segment sample rate mismatch") + audio_parts.append(job.audio_data) + if silence_chunk is not None: + audio_parts.append(silence_chunk.copy()) + worker_profiles.append(dict(job.result)) + if sample_rate is None or not audio_parts: + raise RuntimeError("direct scheduler backend produced no audio") + self.api._update_request_state( + request_id, + EngineStatus.FINALIZING, + {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct"}, + ) + merged_audio = np.concatenate(audio_parts, axis=0) + pack_start = time.perf_counter() + audio_bytes = pack_audio(BytesIO(), merged_audio, sample_rate, media_type).getvalue() + pack_ms = max(0.0, (time.perf_counter() - pack_start) * 1000.0) + direct_profile = self.api._build_direct_scheduler_profile( + backend="scheduler_v1_direct", + request_start=request_start, + response_ready_at=time.perf_counter(), + audio_bytes=len(audio_bytes), + sample_rate=int(sample_rate), + segment_texts=segment_texts, + prepare_profiles=prepare_profiles, + worker_profiles=worker_profiles, + pack_ms=pack_ms, + response_overhead_ms=0.0, + ) + self.api._complete_request_state( + request_id, + dict(direct_profile, streaming_completed=False), + ) + return DirectTTSExecution( + media_type=media_type, + streaming=False, + audio_bytes=audio_bytes, + request_id=request_id, + ) + + def _run_legacy_direct_tts_blocking( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> DirectTTSExecution: + normalized_payload = normalized.to_payload() + request_id = normalized.request_id + media_type = normalized.media_type + request_start = time.perf_counter() + self.api._update_request_state( + request_id, + EngineStatus.ACTIVE_DECODE, + {"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason}, + ) + with self.api.direct_tts_lock: + tts_generator = self.api.tts.run(normalized_payload) + try: + sr, audio_data = next(tts_generator) + except Exception as exc: + self.api._fail_request_state(request_id, str(exc)) + raise + self.api._update_request_state( + request_id, + EngineStatus.FINALIZING, + {"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason}, + ) + pack_start = time.perf_counter() + packed_audio = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() + pack_ms = max(0.0, (time.perf_counter() - pack_start) * 1000.0) + self.api._complete_request_state( + request_id, + dict( + self.api._build_legacy_direct_profile( + backend=backend, + fallback_reason=fallback_reason, + request_start=request_start, + finished_at=time.perf_counter(), + sample_rate=int(sr), + audio_bytes=len(packed_audio), + pack_ms=pack_ms, + ), + streaming_completed=False, + ), + ) + return DirectTTSExecution( + media_type=media_type, + streaming=False, + audio_bytes=packed_audio, + request_id=request_id, + ) + + async def _run_direct_tts_via_legacy_backend( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> DirectTTSExecution: + if normalized.response_streaming: + return DirectTTSExecution( + media_type=normalized.media_type, + streaming=True, + audio_generator=self._iter_legacy_direct_tts_bytes( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ), + request_id=normalized.request_id, + ) + return await asyncio.to_thread( + self._run_legacy_direct_tts_blocking, + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + + async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution: + normalized = self.api._normalize_engine_request( + req, + request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"), + normalize_streaming=True, + error_prefix="", + ) + request_id = normalized.request_id + media_type = normalized.media_type + backend, fallback_reason = self.api._select_direct_backend(normalized) + self.api._register_request_state( + request_id=request_id, + api_mode="tts", + backend=backend, + media_type=media_type, + response_streaming=bool(normalized.response_streaming), + deadline_ts=(time.perf_counter() + float(normalized.timeout_sec) if normalized.timeout_sec is not None else None), + meta=self.api._build_request_meta(normalized.to_payload()), + ) + self.api._update_request_state( + request_id, + EngineStatus.VALIDATED, + { + "request_source": "direct_tts", + "selected_backend": backend, + "fallback_reason": fallback_reason, + }, + ) + if backend == "scheduler_v1_direct": + try: + return await self._run_direct_tts_via_scheduler(normalized) + except Exception as exc: + self.api._fail_request_state(request_id, str(exc)) + raise + return await self._run_direct_tts_via_legacy_backend( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + + def run_direct_tts(self, req: dict) -> DirectTTSExecution: + normalized = self.api._normalize_engine_request( + req, + request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"), + normalize_streaming=True, + error_prefix="", + ) + request_id = normalized.request_id + media_type = normalized.media_type + backend, fallback_reason = self.api._select_direct_backend(normalized) + if not self.api._has_active_request(request_id): + self.api._register_request_state( + request_id=request_id, + api_mode="tts", + backend=backend, + media_type=media_type, + response_streaming=bool(normalized.response_streaming), + meta=self.api._build_request_meta(normalized.to_payload()), + ) + self.api._update_request_state( + request_id, + EngineStatus.VALIDATED, + { + "request_source": "direct_tts", + "selected_backend": backend, + "fallback_reason": fallback_reason, + }, + ) + if backend != "scheduler_v1_direct": + if normalized.response_streaming: + return DirectTTSExecution( + media_type=media_type, + streaming=True, + audio_generator=self._iter_legacy_direct_tts_bytes( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ), + request_id=request_id, + ) + return self._run_legacy_direct_tts_blocking( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + if normalized.response_streaming: + return DirectTTSExecution( + media_type=media_type, + streaming=True, + audio_generator=self._iter_legacy_direct_tts_bytes( + normalized, + backend="legacy_direct_sync_compat", + fallback_reason="sync_direct_compat", + ), + request_id=request_id, + ) + return self._run_legacy_direct_tts_blocking( + normalized, + backend="legacy_direct_sync_compat", + fallback_reason="sync_direct_compat", + ) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api_profile.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_profile.py new file mode 100644 index 00000000..f950c68d --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_profile.py @@ -0,0 +1,387 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Sequence + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem, T2SRequestState + + +def build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]: + text = payload.get("text") + prompt_text = payload.get("prompt_text") + return { + "text_len": 0 if text is None else len(str(text)), + "prompt_text_len": 0 if prompt_text is None else len(str(prompt_text)), + "text_lang": payload.get("text_lang"), + "prompt_lang": payload.get("prompt_lang"), + "ref_audio_path": payload.get("ref_audio_path"), + } + + +def sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float: + total = 0.0 + for item in items: + value = item.get(key, 0.0) + if isinstance(value, (int, float)): + total += float(value) + return total + + +def aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]: + totals: Dict[str, float] = {} + for item in items: + for key, value in item.items(): + if isinstance(value, (int, float)): + totals[key] = totals.get(key, 0.0) + float(value) + return totals + + +def build_direct_segment_trace( + segment_texts: Sequence[str], + prepare_profiles: Sequence[Dict[str, Any]], + worker_profiles: Sequence[Dict[str, Any]], +) -> List[Dict[str, Any]]: + results: List[Dict[str, Any]] = [] + for index, segment_text in enumerate(segment_texts): + prepare_item = prepare_profiles[index] if index < len(prepare_profiles) else {} + worker_item = worker_profiles[index] if index < len(worker_profiles) else {} + prepare_profile = dict(prepare_item.get("prepare_profile", {})) + results.append( + { + "segment_index": index, + "request_id": prepare_item.get("request_id") or worker_item.get("request_id"), + "text_len": len(str(segment_text)), + "prepare_wall_ms": float(prepare_item.get("prepare_wall_ms", 0.0)), + "prepare_profile_total_ms": float(prepare_item.get("prepare_profile_total_ms", 0.0)), + "prepare_engine_gpu_queue_wait_ms": float( + dict(prepare_item.get("prepare_profile", {})).get("engine_gpu_prepare_queue_wait_ms", 0.0) + ), + "engine_policy_wait_ms": float(prepare_item.get("engine_policy_wait_ms", 0.0)), + "engine_dispatch_wait_ms": float(prepare_item.get("engine_dispatch_wait_ms", 0.0)), + "decode_admission_wait_ms": float(worker_item.get("decode_admission_wait_ms", 0.0)), + "queue_wait_ms": float(worker_item.get("queue_wait_ms", 0.0)), + "prefill_ms": float(worker_item.get("prefill_ms", 0.0)), + "merge_ms": float(worker_item.get("merge_ms", 0.0)), + "decode_ms": float(worker_item.get("decode_ms", 0.0)), + "finalize_wait_ms": float(worker_item.get("finalize_wait_ms", 0.0)), + "synth_ms": float(worker_item.get("synth_ms", 0.0)), + "worker_total_ms": float(worker_item.get("worker_total_ms", 0.0)), + "decode_steps": int(worker_item.get("decode_steps", 0)), + "semantic_len": int(worker_item.get("semantic_len", 0)), + "finish_reason": worker_item.get("finish_reason"), + "norm_text": prepare_profile.get("norm_text"), + } + ) + return results + + +def build_direct_scheduler_profile( + *, + backend: str, + request_start: float, + response_ready_at: float, + audio_bytes: int, + sample_rate: int, + segment_texts: Sequence[str], + prepare_profiles: Sequence[Dict[str, Any]], + worker_profiles: Sequence[Dict[str, Any]], + pack_ms: float, + response_overhead_ms: float, +) -> Dict[str, Any]: + segment_trace = build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles) + prepare_profile_dicts = [dict(item.get("prepare_profile", {})) for item in prepare_profiles] + request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) + prepare_wall_ms = sum_profile_field(prepare_profiles, "prepare_wall_ms") + prepare_profile_total_ms = sum_profile_field(prepare_profiles, "prepare_profile_total_ms") + engine_policy_wait_ms = sum_profile_field(prepare_profiles, "engine_policy_wait_ms") + engine_dispatch_wait_ms = sum_profile_field(prepare_profiles, "engine_dispatch_wait_ms") + decode_admission_wait_ms = sum_profile_field(worker_profiles, "decode_admission_wait_ms") + queue_wait_ms = sum_profile_field(worker_profiles, "queue_wait_ms") + prefill_ms = sum_profile_field(worker_profiles, "prefill_ms") + merge_ms = sum_profile_field(worker_profiles, "merge_ms") + decode_ms = sum_profile_field(worker_profiles, "decode_ms") + finalize_wait_ms = sum_profile_field(worker_profiles, "finalize_wait_ms") + synth_ms = sum_profile_field(worker_profiles, "synth_ms") + worker_total_ms = sum_profile_field(worker_profiles, "worker_total_ms") + decode_steps = sum(int(item.get("decode_steps", 0)) for item in worker_profiles) + semantic_len = sum(int(item.get("semantic_len", 0)) for item in worker_profiles) + request_other_ms = max( + 0.0, + request_total_ms - prepare_wall_ms - engine_policy_wait_ms - worker_total_ms - pack_ms - response_overhead_ms, + ) + return { + "backend": backend, + "backend_mode": backend, + "segment_count": len(segment_texts), + "sample_rate": int(sample_rate), + "audio_bytes": int(audio_bytes), + "request_total_ms": request_total_ms, + "prepare_ms": prepare_wall_ms, + "prepare_wall_ms": prepare_wall_ms, + "prepare_profile_total_ms": prepare_profile_total_ms, + "engine_policy_wait_ms": engine_policy_wait_ms, + "engine_dispatch_wait_ms": engine_dispatch_wait_ms, + "decode_admission_wait_ms": decode_admission_wait_ms, + "queue_wait_ms": queue_wait_ms, + "prefill_ms": prefill_ms, + "merge_ms": merge_ms, + "decode_ms": decode_ms, + "finalize_wait_ms": finalize_wait_ms, + "synth_ms": synth_ms, + "pack_ms": pack_ms, + "response_overhead_ms": response_overhead_ms, + "worker_total_ms": worker_total_ms, + "request_other_ms": request_other_ms, + "decode_steps": decode_steps, + "semantic_len": semantic_len, + "prepare_segments": list(prepare_profiles), + "worker_segments": list(worker_profiles), + "segment_trace": segment_trace, + "prepare_aggregate": aggregate_numeric_dicts(prepare_profile_dicts), + } + + +def build_legacy_direct_profile( + *, + backend: str, + fallback_reason: str | None, + request_start: float, + finished_at: float, + sample_rate: int | None = None, + audio_bytes: int = 0, + pack_ms: float = 0.0, + chunk_count: int = 0, + stream_total_bytes: int = 0, + first_chunk_ms: float | None = None, +) -> Dict[str, Any]: + request_total_ms = max(0.0, (finished_at - request_start) * 1000.0) + legacy_infer_ms = max(0.0, request_total_ms - pack_ms) + return { + "backend": backend, + "backend_mode": backend, + "fallback_reason": fallback_reason, + "request_total_ms": request_total_ms, + "prepare_ms": 0.0, + "queue_wait_ms": 0.0, + "prefill_ms": 0.0, + "merge_ms": 0.0, + "decode_ms": 0.0, + "finalize_wait_ms": 0.0, + "synth_ms": 0.0, + "pack_ms": pack_ms, + "worker_total_ms": legacy_infer_ms, + "request_other_ms": 0.0, + "legacy_infer_ms": legacy_infer_ms, + "sample_rate": int(sample_rate) if sample_rate is not None else None, + "audio_bytes": int(audio_bytes), + "chunk_count": int(chunk_count), + "stream_total_bytes": int(stream_total_bytes), + "first_chunk_ms": None if first_chunk_ms is None else float(first_chunk_ms), + } + + +def build_scheduler_submit_profile( + *, + backend: str, + request_start: float, + response_ready_at: float, + audio_bytes: int, + sample_rate: int, + prepare_spec_build_ms: float, + prepare_wall_ms: float, + prepare_executor_queue_ms: float, + prepare_executor_run_ms: float, + prepare_profile_total_ms: float, + prepare_profile_wall_ms: float, + prepare_other_ms: float, + engine_policy_wait_ms: float, + api_after_prepare_ms: float, + api_wait_result_ms: float, + pack_ms: float, + response_overhead_ms: float, + worker_profile: Dict[str, Any], +) -> Dict[str, Any]: + worker_total_ms = float(worker_profile.get("worker_total_ms", 0.0)) + request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) + request_other_ms = max( + 0.0, + request_total_ms + - prepare_wall_ms + - engine_policy_wait_ms + - api_after_prepare_ms + - worker_total_ms + - api_wait_result_ms + - pack_ms, + ) + result = { + "backend": backend, + "backend_mode": backend, + "audio_bytes": int(audio_bytes), + "sample_rate": int(sample_rate), + "prepare_spec_build_ms": prepare_spec_build_ms, + "prepare_ms": prepare_wall_ms, + "prepare_wall_ms": prepare_wall_ms, + "prepare_executor_queue_ms": prepare_executor_queue_ms, + "prepare_executor_run_ms": prepare_executor_run_ms, + "prepare_profile_total_ms": prepare_profile_total_ms, + "prepare_profile_wall_ms": prepare_profile_wall_ms, + "prepare_other_ms": prepare_other_ms, + "engine_policy_wait_ms": float(engine_policy_wait_ms), + "api_after_prepare_ms": api_after_prepare_ms, + "api_wait_result_ms": api_wait_result_ms, + "pack_ms": pack_ms, + "response_overhead_ms": response_overhead_ms, + "request_total_ms": request_total_ms, + "request_other_ms": request_other_ms, + } + result.update({key: value for key, value in worker_profile.items()}) + return result + + +def format_ms_header(value: Any) -> str: + return f"{float(value):.3f}" + + +def build_scheduler_submit_headers( + *, + request_id: str, + media_type: str, + sample_rate: int, + profile: Dict[str, Any], +) -> Dict[str, str]: + prepare_profile = dict(profile.get("prepare_profile", {})) + headers = { + "X-Request-Id": request_id, + "X-Semantic-Len": str(int(profile.get("semantic_len", 0))), + "X-Finish-Reason": str(profile.get("finish_reason", "unknown")), + "X-Queue-Wait-Ms": format_ms_header(profile.get("queue_wait_ms", 0.0)), + "X-Decode-Admission-Wait-Ms": format_ms_header(profile.get("decode_admission_wait_ms", 0.0)), + "X-Engine-Policy-Wait-Ms": format_ms_header(profile.get("engine_policy_wait_ms", 0.0)), + "X-Engine-Dispatch-Wait-Ms": format_ms_header(profile.get("engine_dispatch_wait_ms", 0.0)), + "X-Prepare-Ms": format_ms_header(profile.get("prepare_wall_ms", 0.0)), + "X-Prepare-Wall-Ms": format_ms_header(profile.get("prepare_wall_ms", 0.0)), + "X-Prepare-Spec-Build-Ms": format_ms_header(profile.get("prepare_spec_build_ms", 0.0)), + "X-Prepare-Executor-Queue-Ms": format_ms_header(profile.get("prepare_executor_queue_ms", 0.0)), + "X-Prepare-Admission-Wait-Ms": format_ms_header(prepare_profile.get("prepare_admission_wait_ms", 0.0)), + "X-Prepare-Executor-Run-Ms": format_ms_header(profile.get("prepare_executor_run_ms", 0.0)), + "X-Prepare-Profile-Total-Ms": format_ms_header(profile.get("prepare_profile_total_ms", 0.0)), + "X-Prepare-Profile-Wall-Ms": format_ms_header(profile.get("prepare_profile_wall_ms", 0.0)), + "X-Prepare-Other-Ms": format_ms_header(profile.get("prepare_other_ms", 0.0)), + "X-Api-After-Prepare-Ms": format_ms_header(profile.get("api_after_prepare_ms", 0.0)), + "X-Prefill-Ms": format_ms_header(profile.get("prefill_ms", 0.0)), + "X-Merge-Ms": format_ms_header(profile.get("merge_ms", 0.0)), + "X-Decode-Ms": format_ms_header(profile.get("decode_ms", 0.0)), + "X-Finalize-Wait-Ms": format_ms_header(profile.get("finalize_wait_ms", 0.0)), + "X-Synth-Ms": format_ms_header(profile.get("synth_ms", 0.0)), + "X-Worker-Residual-Ms": format_ms_header(profile.get("worker_residual_ms", 0.0)), + "X-Worker-Other-Ms": format_ms_header(profile.get("worker_other_ms", 0.0)), + "X-Pack-Ms": format_ms_header(profile.get("pack_ms", 0.0)), + "X-Worker-Total-Ms": format_ms_header(profile.get("worker_total_ms", 0.0)), + "X-Api-Wait-Result-Ms": format_ms_header(profile.get("api_wait_result_ms", 0.0)), + "X-Decode-Steps": str(int(profile.get("decode_steps", 0))), + "X-Sample-Rate": str(int(sample_rate)), + "X-Response-Overhead-Ms": format_ms_header(profile.get("response_overhead_ms", 0.0)), + "X-Request-Other-Ms": format_ms_header(profile.get("request_other_ms", 0.0)), + "X-Request-Total-Ms": format_ms_header(profile.get("request_total_ms", 0.0)), + } + headers.update( + { + "X-Prepare-Prompt-Text-Ms": format_ms_header(prepare_profile.get("prompt_text_features_ms", 0.0)), + "X-Prepare-Target-Text-Ms": format_ms_header(prepare_profile.get("text_features_ms", 0.0)), + "X-Prepare-Prompt-Text-CPU-Preprocess-Ms": format_ms_header(prepare_profile.get("prompt_text_cpu_preprocess_ms", 0.0)), + "X-Prepare-Target-Text-CPU-Preprocess-Ms": format_ms_header(prepare_profile.get("text_cpu_preprocess_ms", 0.0)), + "X-Prepare-Prompt-Text-CPU-Queue-Ms": format_ms_header(prepare_profile.get("prompt_text_cpu_queue_ms", 0.0)), + "X-Prepare-Target-Text-CPU-Queue-Ms": format_ms_header(prepare_profile.get("text_cpu_queue_ms", 0.0)), + "X-Prepare-Prompt-Text-Feature-Queue-Ms": format_ms_header(prepare_profile.get("prompt_text_feature_queue_ms", 0.0)), + "X-Prepare-Target-Text-Feature-Queue-Ms": format_ms_header(prepare_profile.get("text_feature_queue_ms", 0.0)), + "X-Prepare-Prompt-Bert-Wait-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_wait_ms", 0.0)), + "X-Prepare-Target-Bert-Wait-Ms": format_ms_header(prepare_profile.get("text_bert_wait_ms", 0.0)), + "X-Prepare-Prompt-Bert-Admission-Wait-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_admission_wait_ms", 0.0)), + "X-Prepare-Target-Bert-Admission-Wait-Ms": format_ms_header(prepare_profile.get("text_bert_admission_wait_ms", 0.0)), + "X-Prepare-Prompt-Bert-Queue-Wait-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_queue_wait_ms", 0.0)), + "X-Prepare-Target-Bert-Queue-Wait-Ms": format_ms_header(prepare_profile.get("text_bert_queue_wait_ms", 0.0)), + "X-Prepare-Prompt-Bert-Batch-Collect-Wait-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_batch_collect_wait_ms", 0.0)), + "X-Prepare-Target-Bert-Batch-Collect-Wait-Ms": format_ms_header(prepare_profile.get("text_bert_batch_collect_wait_ms", 0.0)), + "X-Prepare-Prompt-Bert-Forward-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_forward_ms", 0.0)), + "X-Prepare-Target-Bert-Forward-Ms": format_ms_header(prepare_profile.get("text_bert_forward_ms", 0.0)), + "X-Prepare-Prompt-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_enqueue_peak", 0.0))), + "X-Prepare-Target-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_enqueue_peak", 0.0))), + "X-Prepare-Prompt-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_collect_peak", 0.0))), + "X-Prepare-Target-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_collect_peak", 0.0))), + "X-Prepare-Prompt-Bert-High-Pressure-Peak": str(int(prepare_profile.get("prompt_text_bert_high_pressure_mode_peak", 0.0))), + "X-Prepare-Target-Bert-High-Pressure-Peak": str(int(prepare_profile.get("text_bert_high_pressure_mode_peak", 0.0))), + "X-Prepare-Prompt-Bert-Batch-Window-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_batch_window_ms", 0.0)), + "X-Prepare-Target-Bert-Batch-Window-Ms": format_ms_header(prepare_profile.get("text_bert_batch_window_ms", 0.0)), + "X-Prepare-Text-Pair-Wall-Ms": format_ms_header(prepare_profile.get("text_feature_pair_ms", 0.0)), + "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))), + "X-Prepare-Engine-GPU-Queue-Wait-Ms": format_ms_header(prepare_profile.get("engine_gpu_prepare_queue_wait_ms", 0.0)), + "X-Prepare-Audio-Load-Ms": format_ms_header(prepare_profile.get("audio_load_ms", 0.0)), + "X-Prepare-Audio-Stage-Wait-Ms": format_ms_header(prepare_profile.get("audio_stage_wait_ms", 0.0)), + "X-Prepare-Prompt-Semantic-Ms": format_ms_header(prepare_profile.get("prompt_semantic_ms", 0.0)), + "X-Prepare-Prompt-Semantic-Wait-Ms": format_ms_header(prepare_profile.get("prompt_semantic_wait_ms", 0.0)), + "X-Prepare-Prompt-Semantic-CPU-Ms": format_ms_header(prepare_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), + "X-Prepare-Prompt-Semantic-Forward-Ms": format_ms_header(prepare_profile.get("prompt_semantic_forward_ms", 0.0)), + "X-Prepare-Ref-Spec-Ms": format_ms_header(prepare_profile.get("ref_spec_ms", 0.0)), + "X-Prepare-Ref-Spec-Wait-Ms": format_ms_header(prepare_profile.get("ref_spec_wait_ms", 0.0)), + "X-Prepare-Ref-Bundle-Ms": format_ms_header(prepare_profile.get("ref_audio_bundle_ms", 0.0)), + "X-Prepare-Tensorize-Ms": format_ms_header(prepare_profile.get("tensorize_ms", 0.0)), + "X-Prepare-Inflight-On-Enter": str(int(prepare_profile.get("worker_prepare_inflight_on_enter", 0.0))), + "X-Prepare-Inflight-Peak": str(int(prepare_profile.get("worker_prepare_peak_inflight", 0.0))), + } + ) + return headers + + +def build_scheduler_debug_request_profile( + *, + state: T2SRequestState, + item: T2SFinishedItem, + batch_request_count: int, + prepare_batch_wall_ms: float, + decode_batch_wall_ms: float, + batch_request_total_ms: float, +) -> Dict[str, Any]: + prepare_profile = dict(state.prepare_profile) + prepare_wall_ms = float(prepare_profile.get("wall_total_ms", 0.0)) + return { + "backend": "scheduler_debug", + "backend_mode": "scheduler_debug", + "batch_request_count": int(batch_request_count), + "batch_prepare_wall_ms": float(prepare_batch_wall_ms), + "batch_decode_wall_ms": float(decode_batch_wall_ms), + "batch_request_total_ms": float(batch_request_total_ms), + "prepare_ms": prepare_wall_ms, + "prepare_wall_ms": prepare_wall_ms, + "prepare_profile_total_ms": float(prepare_profile.get("wall_total_ms", prepare_wall_ms)), + "prepare_profile": prepare_profile, + "decode_steps": int(item.finish_idx), + "finish_idx": int(item.finish_idx), + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_reason": item.finish_reason, + "norm_text": state.norm_text, + "norm_prompt_text": state.norm_prompt_text, + } + + +def build_scheduler_debug_batch_profile( + *, + request_count: int, + max_steps: int, + prepare_batch_wall_ms: float, + decode_batch_wall_ms: float, + request_total_ms: float, + finished_items: Sequence[T2SFinishedItem], +) -> Dict[str, Any]: + finish_reason_counts: Dict[str, int] = {} + total_semantic_len = 0 + for item in finished_items: + finish_reason_counts[item.finish_reason] = finish_reason_counts.get(item.finish_reason, 0) + 1 + total_semantic_len += int(item.semantic_tokens.shape[0]) + return { + "request_count": int(request_count), + "max_steps": int(max_steps), + "prepare_batch_wall_ms": float(prepare_batch_wall_ms), + "decode_batch_wall_ms": float(decode_batch_wall_ms), + "request_total_ms": float(request_total_ms), + "total_semantic_len": int(total_semantic_len), + "finish_reason_counts": finish_reason_counts, + } diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api_request.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_request.py new file mode 100644 index 00000000..974b9612 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_request.py @@ -0,0 +1,199 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Sequence, Tuple + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import NormalizedEngineRequest, ReferenceRegistry + + +def normalize_lang(value: str | None) -> str | None: + if value in [None, ""]: + return value + return str(value).lower() + + +def apply_default_reference(reference_registry: ReferenceRegistry, req: dict) -> dict: + normalized = dict(req) + default_ref = reference_registry.get_default() + if normalized.get("ref_audio_path") in [None, ""] and default_ref.ref_audio_path not in [None, ""]: + normalized["ref_audio_path"] = default_ref.ref_audio_path + if "text_lang" in normalized: + normalized["text_lang"] = normalize_lang(normalized.get("text_lang")) + if "prompt_lang" in normalized: + normalized["prompt_lang"] = normalize_lang(normalized.get("prompt_lang")) + return normalized + + +def check_params(tts: TTS, cut_method_names: Sequence[str], req: dict) -> Optional[str]: + text = req.get("text", "") + text_lang = req.get("text_lang", "") + ref_audio_path = req.get("ref_audio_path", "") + media_type = req.get("media_type", "wav") + prompt_lang = req.get("prompt_lang", "") + text_split_method = req.get("text_split_method", "cut5") + + if ref_audio_path in [None, ""]: + return "ref_audio_path is required" + if text in [None, ""]: + return "text is required" + if text_lang in [None, ""]: + return "text_lang is required" + if text_lang.lower() not in tts.configs.languages: + return f"text_lang: {text_lang} is not supported in version {tts.configs.version}" + if prompt_lang in [None, ""]: + return "prompt_lang is required" + if prompt_lang.lower() not in tts.configs.languages: + return f"prompt_lang: {prompt_lang} is not supported in version {tts.configs.version}" + if media_type not in ["wav", "raw", "ogg", "aac"]: + return f"media_type: {media_type} is not supported" + if text_split_method not in cut_method_names: + return f"text_split_method:{text_split_method} is not supported" + return None + + +def base_request_defaults() -> Dict[str, Any]: + return { + "request_id": None, + "text": None, + "text_lang": None, + "ref_audio_path": None, + "aux_ref_audio_paths": None, + "prompt_text": "", + "prompt_lang": None, + "top_k": 15, + "top_p": 1.0, + "temperature": 1.0, + "text_split_method": "cut5", + "batch_size": 1, + "batch_threshold": 0.75, + "speed_factor": 1.0, + "split_bucket": False, + "fragment_interval": 0.3, + "seed": -1, + "media_type": "wav", + "streaming_mode": False, + "return_fragment": False, + "fixed_length_chunk": False, + "response_streaming": False, + "parallel_infer": False, + "repetition_penalty": 1.35, + "sample_steps": 32, + "super_sampling": False, + "overlap_length": 2, + "min_chunk_length": 16, + "early_stop_num": -1, + "ready_step": 0, + "timeout_sec": None, + } + + +def normalize_streaming_mode(req: dict) -> dict: + normalized = dict(req) + streaming_mode = normalized.get("streaming_mode", False) + return_fragment = normalized.get("return_fragment", False) + if streaming_mode is False: + normalized["streaming_mode"] = False + normalized["return_fragment"] = False + normalized["fixed_length_chunk"] = False + elif streaming_mode == 0: + normalized["streaming_mode"] = False + normalized["return_fragment"] = False + normalized["fixed_length_chunk"] = False + elif streaming_mode == 1 or streaming_mode is True: + normalized["streaming_mode"] = False + normalized["return_fragment"] = True + normalized["fixed_length_chunk"] = False + elif streaming_mode == 2: + normalized["streaming_mode"] = True + normalized["return_fragment"] = False + normalized["fixed_length_chunk"] = False + elif streaming_mode == 3: + normalized["streaming_mode"] = True + normalized["return_fragment"] = False + normalized["fixed_length_chunk"] = True + else: + raise ValueError("the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)") + normalized["response_streaming"] = bool(normalized["streaming_mode"] or normalized["return_fragment"] or return_fragment) + return normalized + + +def is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool: + return aux_ref_audio_paths not in [None, [], ()] + + +def select_direct_backend(normalized: NormalizedEngineRequest) -> Tuple[str, str | None]: + if normalized.response_streaming: + if normalized.return_fragment or normalized.fixed_length_chunk: + return "legacy_direct_fragment", "fragment_streaming_mode" + return "legacy_direct_streaming", "streaming_mode" + if is_aux_ref_enabled(normalized.aux_ref_audio_paths): + return "legacy_direct_aux_ref", "aux_ref_audio_paths" + if normalized.super_sampling: + return "legacy_direct_super_sampling", "super_sampling" + if normalized.prompt_text in [None, ""]: + return "legacy_direct_missing_prompt", "missing_prompt_text" + return "scheduler_v1_direct", None + + +def normalize_engine_request( + *, + tts: TTS, + cut_method_names: Sequence[str], + reference_registry: ReferenceRegistry, + payload: dict | NormalizedEngineRequest, + request_id: str | None = None, + normalize_streaming: bool = False, + error_prefix: str = "request 参数非法: ", +) -> NormalizedEngineRequest: + if isinstance(payload, NormalizedEngineRequest): + normalized_payload = payload.to_payload() + else: + normalized_payload = base_request_defaults() + normalized_payload.update(dict(payload)) + if request_id not in [None, ""]: + normalized_payload["request_id"] = str(request_id) + elif normalized_payload.get("request_id") in [None, ""]: + raise ValueError("request_id is required after normalization") + normalized_payload = apply_default_reference(reference_registry, normalized_payload) + if normalize_streaming: + normalized_payload = normalize_streaming_mode(normalized_payload) + error = check_params(tts, cut_method_names, normalized_payload) + if error is not None: + raise ValueError(f"{error_prefix}{error}") + timeout_sec = normalized_payload.get("timeout_sec") + parsed_timeout = None if timeout_sec in [None, ""] else float(timeout_sec) + aux_ref_audio_paths = normalized_payload.get("aux_ref_audio_paths") + normalized_aux_ref_audio_paths = None if aux_ref_audio_paths in [None, "", []] else [str(item) for item in aux_ref_audio_paths] + return NormalizedEngineRequest( + request_id=str(normalized_payload["request_id"]), + text=str(normalized_payload["text"]), + text_lang=str(normalized_payload["text_lang"]), + ref_audio_path=str(normalized_payload["ref_audio_path"]), + prompt_lang=str(normalized_payload["prompt_lang"]), + prompt_text="" if normalized_payload.get("prompt_text") is None else str(normalized_payload.get("prompt_text")), + aux_ref_audio_paths=normalized_aux_ref_audio_paths, + top_k=int(normalized_payload["top_k"]), + top_p=float(normalized_payload["top_p"]), + temperature=float(normalized_payload["temperature"]), + repetition_penalty=float(normalized_payload["repetition_penalty"]), + early_stop_num=int(normalized_payload.get("early_stop_num", -1)), + ready_step=int(normalized_payload.get("ready_step", 0)), + text_split_method=str(normalized_payload["text_split_method"]), + batch_size=int(normalized_payload["batch_size"]), + batch_threshold=float(normalized_payload["batch_threshold"]), + split_bucket=bool(normalized_payload["split_bucket"]), + speed_factor=float(normalized_payload["speed_factor"]), + fragment_interval=float(normalized_payload["fragment_interval"]), + seed=int(normalized_payload["seed"]), + media_type=str(normalized_payload["media_type"]), + streaming_mode=normalized_payload["streaming_mode"], + return_fragment=bool(normalized_payload.get("return_fragment", False)), + fixed_length_chunk=bool(normalized_payload.get("fixed_length_chunk", False)), + response_streaming=bool(normalized_payload.get("response_streaming", False)), + parallel_infer=bool(normalized_payload["parallel_infer"]), + sample_steps=int(normalized_payload["sample_steps"]), + super_sampling=bool(normalized_payload["super_sampling"]), + overlap_length=int(normalized_payload["overlap_length"]), + min_chunk_length=int(normalized_payload["min_chunk_length"]), + timeout_sec=parsed_timeout, + ) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api_scheduler.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_scheduler.py new file mode 100644 index 00000000..646b5b45 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_scheduler.py @@ -0,0 +1,283 @@ +from __future__ import annotations + +import asyncio +import time +import uuid +from io import BytesIO +from typing import Any, Dict, List + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SFinishedItem, T2SRequestState, run_scheduler_continuous +from GPT_SoVITS.TTS_infer_pack.unified_engine_audio import pack_audio, set_scheduler_seed +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, NormalizedEngineRequest, SchedulerDebugExecution, SchedulerSubmitExecution + + +class EngineApiSchedulerFlow: + def __init__(self, api: Any) -> None: + self.api = api + + def _build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]: + specs: List[SchedulerRequestSpec] = [] + for index, payload in enumerate(request_items): + normalized = self.api._normalize_engine_request( + payload, + request_id=str(payload.get("request_id") or f"req_{index:03d}"), + error_prefix=f"request[{index}] 参数非法: ", + ) + specs.append(normalized.to_scheduler_spec()) + return specs + + def _build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec: + normalized = self.api._normalize_engine_request( + payload, + request_id=( + payload.request_id + if isinstance(payload, NormalizedEngineRequest) + else str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}") + ), + ) + return normalized.to_scheduler_spec() + + @staticmethod + def _summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: + return [ + { + "request_id": state.request_id, + "ready_step": int(state.ready_step), + "ref_audio_path": str(state.ref_audio_path), + "prompt_semantic_len": int(state.prompt_semantic.shape[0]), + "all_phone_len": int(state.all_phones.shape[0]), + "bert_len": int(state.all_bert_features.shape[-1]), + "norm_text": state.norm_text, + } + for state in states + ] + + @staticmethod + def _summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: + return [ + { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + } + for item in items + ] + + async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution: + request_start = time.perf_counter() + set_scheduler_seed(seed) + specs = self._build_scheduler_request_specs(request_items) + request_ids = [spec.request_id for spec in specs] + for spec in specs: + self.api._register_request_state( + request_id=spec.request_id, + api_mode="scheduler_debug", + backend="scheduler_debug", + media_type="wav", + response_streaming=False, + meta={ + "text_len": len(spec.text), + "prompt_text_len": len(spec.prompt_text), + "text_lang": spec.text_lang, + "prompt_lang": spec.prompt_lang, + "ref_audio_path": str(spec.ref_audio_path), + "ready_step": int(spec.ready_step), + }, + ) + self.api._update_request_state(spec.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_debug"}) + self.api._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, None) + prepare_started_at = time.perf_counter() + try: + states = await self.api.scheduler_worker.prepare_states_batch_async(specs) + except Exception as exc: + for request_id in request_ids: + self.api._fail_request_state(request_id, str(exc)) + raise + prepare_finished_at = time.perf_counter() + prepare_batch_wall_ms = max(0.0, (prepare_finished_at - prepare_started_at) * 1000.0) + for state in states: + self.api._update_request_state( + state.request_id, + EngineStatus.ACTIVE_DECODE, + { + "prepare_profile": dict(state.prepare_profile), + "norm_text": state.norm_text, + "norm_prompt_text": state.norm_prompt_text, + }, + ) + decode_started_at = time.perf_counter() + try: + finished = run_scheduler_continuous(self.api.tts.t2s_model.model, states, max_steps=int(max_steps)) + except Exception as exc: + for request_id in request_ids: + self.api._fail_request_state(request_id, str(exc)) + raise + decode_finished_at = time.perf_counter() + decode_batch_wall_ms = max(0.0, (decode_finished_at - decode_started_at) * 1000.0) + request_total_ms = max(0.0, (decode_finished_at - request_start) * 1000.0) + finished_map = {item.request_id: item for item in finished} + request_profiles: List[Dict[str, Any]] = [] + for state in states: + item = finished_map.get(state.request_id) + if item is None: + self.api._fail_request_state(state.request_id, "scheduler_debug finished without result") + continue + request_profile = self.api._build_scheduler_debug_request_profile( + state=state, + item=item, + batch_request_count=len(states), + prepare_batch_wall_ms=prepare_batch_wall_ms, + decode_batch_wall_ms=decode_batch_wall_ms, + batch_request_total_ms=request_total_ms, + ) + request_profiles.append( + { + "request_id": state.request_id, + "profile": dict(request_profile), + } + ) + self.api._complete_request_state( + state.request_id, + dict(request_profile), + ) + return SchedulerDebugExecution( + payload={ + "message": "success", + "request_count": len(states), + "max_steps": int(max_steps), + "batch_profile": self.api._build_scheduler_debug_batch_profile( + request_count=len(states), + max_steps=int(max_steps), + prepare_batch_wall_ms=prepare_batch_wall_ms, + decode_batch_wall_ms=decode_batch_wall_ms, + request_total_ms=request_total_ms, + finished_items=finished, + ), + "requests": self._summarize_scheduler_states(states), + "finished": self._summarize_scheduler_finished(finished), + "request_profiles": request_profiles, + "request_traces": self.api._collect_request_summaries(request_ids), + } + ) + + async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution: + request_start = time.perf_counter() + prepare_start = request_start + normalized = self.api._normalize_engine_request( + payload, + request_id=str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}"), + ) + spec = self._build_scheduler_submit_spec(normalized) + deadline_ts = None + timeout_sec = normalized.timeout_sec + if timeout_sec is not None: + try: + deadline_ts = request_start + float(timeout_sec) + except Exception: + deadline_ts = None + self.api._register_request_state( + request_id=spec.request_id, + api_mode="scheduler_submit", + backend="scheduler_v1", + media_type=normalized.media_type, + response_streaming=False, + deadline_ts=deadline_ts, + meta=self.api._build_request_meta(normalized.to_payload()), + ) + self.api._update_request_state(spec.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_submit"}) + spec_ready_at = time.perf_counter() + prepare_spec_build_ms = max(0.0, (spec_ready_at - prepare_start) * 1000.0) + self.api._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, {"prepare_spec_build_ms": prepare_spec_build_ms}) + try: + state, prepare_exec_started_at, prepare_exec_finished_at = await self.api._prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=spec_ready_at, + engine_request_id=spec.request_id, + ) + except Exception as exc: + self.api._fail_request_state(spec.request_id, str(exc)) + raise + prepare_wall_ms = max(0.0, (prepare_exec_finished_at - spec_ready_at) * 1000.0) + prepare_executor_queue_ms = max(0.0, (prepare_exec_started_at - spec_ready_at) * 1000.0) + prepare_executor_run_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0) + prepare_profile = dict(state.prepare_profile) + prepare_profile_total_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms)) + prepare_profile_wall_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms)) + prepare_other_ms = max(0.0, prepare_wall_ms - prepare_spec_build_ms - prepare_executor_queue_ms - prepare_executor_run_ms) + self.api._update_request_state( + spec.request_id, + EngineStatus.READY_FOR_PREFILL, + { + "prepare_wall_ms": prepare_wall_ms, + "prepare_profile_total_ms": prepare_profile_total_ms, + "prepare_profile": prepare_profile, + }, + ) + api_after_prepare_start = time.perf_counter() + loop = asyncio.get_running_loop() + done_future = loop.create_future() + await self.api._enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=float(normalized.speed_factor), + sample_steps=int(normalized.sample_steps), + media_type=normalized.media_type, + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=loop, + done_future=done_future, + engine_request_id=spec.request_id, + timeout_sec=normalized.timeout_sec, + ) + api_after_prepare_ms = max(0.0, (time.perf_counter() - api_after_prepare_start) * 1000.0) + try: + job = await asyncio.wait_for(done_future, timeout=float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0)) + except Exception as exc: + self.api._fail_request_state(spec.request_id, str(exc)) + raise + wait_return_at = time.perf_counter() + if job.error is not None: + raise RuntimeError(job.error) + if job.audio_data is None or job.sample_rate is None or job.result is None: + self.api._fail_request_state(spec.request_id, f"{job.request_id} finished without audio result") + raise RuntimeError(f"{job.request_id} finished without audio result") + pack_start = time.perf_counter() + audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue() + pack_end = time.perf_counter() + pack_ms = (pack_end - pack_start) * 1000.0 + api_wait_result_ms = 0.0 + if job.result_ready_time is not None: + api_wait_result_ms = max(0.0, (wait_return_at - job.result_ready_time) * 1000.0) + response_ready_at = time.perf_counter() + response_overhead_ms = max(0.0, (response_ready_at - pack_end) * 1000.0) + submit_profile = self.api._build_scheduler_submit_profile( + backend="scheduler_v1", + request_start=request_start, + response_ready_at=response_ready_at, + audio_bytes=len(audio_data), + sample_rate=int(job.sample_rate), + prepare_spec_build_ms=prepare_spec_build_ms, + prepare_wall_ms=prepare_wall_ms, + prepare_executor_queue_ms=prepare_executor_queue_ms, + prepare_executor_run_ms=prepare_executor_run_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + prepare_profile_wall_ms=prepare_profile_wall_ms, + prepare_other_ms=prepare_other_ms, + engine_policy_wait_ms=float(job.result.get("engine_policy_wait_ms", 0.0)), + api_after_prepare_ms=api_after_prepare_ms, + api_wait_result_ms=api_wait_result_ms, + pack_ms=pack_ms, + response_overhead_ms=response_overhead_ms, + worker_profile=dict(job.result or {}), + ) + headers = self.api._build_scheduler_submit_headers( + request_id=job.request_id, + media_type=job.media_type, + sample_rate=int(job.sample_rate), + profile=submit_profile, + ) + self.api._merge_request_state_profile( + spec.request_id, + dict(submit_profile, response_headers_emitted=True), + ) + return SchedulerSubmitExecution(audio_bytes=audio_data, media_type=f"audio/{job.media_type}", headers=headers) From a3a5aad15707ac984d0dae58794f18c99eab77c1 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Wed, 11 Mar 2026 20:49:41 +0800 Subject: [PATCH 17/24] Add unified engine components for TTS processing and state management Introduce new modules including unified_engine_component_models, unified_engine_component_policy, unified_engine_component_registry, unified_engine_component_runtime, unified_engine_worker_completion, and unified_engine_worker_decode. These additions enhance the TTS framework by providing structured models for request handling, engine policies, and worker execution, significantly improving the architecture and maintainability of the system. The new components support asynchronous operations and optimize overall performance through better state management and processing capabilities. --- .../unified_engine_component_models.py | 120 ++ .../unified_engine_component_policy.py | 335 ++++ .../unified_engine_component_registry.py | 381 +++++ .../unified_engine_component_runtime.py | 334 ++++ .../unified_engine_components.py | 1213 +------------- .../TTS_infer_pack/unified_engine_worker.py | 1471 +---------------- .../unified_engine_worker_completion.py | 198 +++ .../unified_engine_worker_decode.py | 430 +++++ .../unified_engine_worker_execution.py | 164 ++ .../unified_engine_worker_finalize.py | 234 +++ .../unified_engine_worker_prepare.py | 71 + .../unified_engine_worker_runtime.py | 170 ++ .../unified_engine_worker_submit.py | 256 +++ 13 files changed, 2772 insertions(+), 2605 deletions(-) create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_component_models.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_component_registry.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_worker_completion.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_worker_decode.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_worker_execution.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_worker_runtime.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_models.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_models.py new file mode 100644 index 00000000..2c0cc9ac --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_models.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Dict, Generator, List, Optional + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec + + +@dataclass +class RuntimeControlCallbacks: + restart: Callable[[], None] | None = None + exit: Callable[[], None] | None = None + + +@dataclass +class DirectTTSExecution: + media_type: str + streaming: bool + audio_generator: Optional[Generator[bytes, None, None]] = None + audio_bytes: Optional[bytes] = None + request_id: Optional[str] = None + + +@dataclass +class NormalizedEngineRequest: + request_id: str + text: str + text_lang: str + ref_audio_path: str + prompt_lang: str + prompt_text: str = "" + aux_ref_audio_paths: List[str] | None = None + top_k: int = 15 + top_p: float = 1.0 + temperature: float = 1.0 + repetition_penalty: float = 1.35 + early_stop_num: int = -1 + ready_step: int = 0 + text_split_method: str = "cut5" + batch_size: int = 1 + batch_threshold: float = 0.75 + split_bucket: bool = False + speed_factor: float = 1.0 + fragment_interval: float = 0.3 + seed: int = -1 + media_type: str = "wav" + streaming_mode: bool | int = False + return_fragment: bool = False + fixed_length_chunk: bool = False + response_streaming: bool = False + parallel_infer: bool = False + sample_steps: int = 32 + super_sampling: bool = False + overlap_length: int = 2 + min_chunk_length: int = 16 + timeout_sec: float | None = None + + def to_payload(self) -> Dict[str, Any]: + return { + "request_id": self.request_id, + "text": self.text, + "text_lang": self.text_lang, + "ref_audio_path": self.ref_audio_path, + "aux_ref_audio_paths": list(self.aux_ref_audio_paths) if self.aux_ref_audio_paths else None, + "prompt_text": self.prompt_text, + "prompt_lang": self.prompt_lang, + "top_k": self.top_k, + "top_p": self.top_p, + "temperature": self.temperature, + "text_split_method": self.text_split_method, + "batch_size": self.batch_size, + "batch_threshold": self.batch_threshold, + "speed_factor": self.speed_factor, + "split_bucket": self.split_bucket, + "fragment_interval": self.fragment_interval, + "seed": self.seed, + "media_type": self.media_type, + "streaming_mode": self.streaming_mode, + "return_fragment": self.return_fragment, + "fixed_length_chunk": self.fixed_length_chunk, + "response_streaming": self.response_streaming, + "parallel_infer": self.parallel_infer, + "repetition_penalty": self.repetition_penalty, + "sample_steps": self.sample_steps, + "super_sampling": self.super_sampling, + "overlap_length": self.overlap_length, + "min_chunk_length": self.min_chunk_length, + "early_stop_num": self.early_stop_num, + "ready_step": self.ready_step, + "timeout_sec": self.timeout_sec, + } + + def to_scheduler_spec(self) -> SchedulerRequestSpec: + return SchedulerRequestSpec( + request_id=self.request_id, + ref_audio_path=Path(self.ref_audio_path), + prompt_text=self.prompt_text, + prompt_lang=self.prompt_lang, + text=self.text, + text_lang=self.text_lang, + top_k=self.top_k, + top_p=self.top_p, + temperature=self.temperature, + repetition_penalty=self.repetition_penalty, + early_stop_num=self.early_stop_num, + ready_step=self.ready_step, + ) + + +@dataclass +class SchedulerDebugExecution: + payload: Dict[str, Any] + + +@dataclass +class SchedulerSubmitExecution: + audio_bytes: bytes + media_type: str + headers: Dict[str, str] diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py new file mode 100644 index 00000000..b6c5ca4d --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py @@ -0,0 +1,335 @@ +from __future__ import annotations + +import asyncio +import threading +import time +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional + +from GPT_SoVITS.TTS_infer_pack.unified_engine_component_registry import EngineStatus + + +@dataclass +class EnginePolicyConfig: + enabled: bool = True + poll_wait_ms: float = 5.0 + decode_backlog_soft_max: int = 0 + finalize_pending_soft_max: int = 0 + prepare_inflight_soft_max: int = 0 + active_decode_soft_max: int = 0 + ready_for_prefill_soft_max: int = 0 + active_request_soft_max: int = 0 + + def to_dict(self) -> Dict[str, Any]: + return { + "enabled": bool(self.enabled), + "poll_wait_ms": float(self.poll_wait_ms), + "decode_backlog_soft_max": int(self.decode_backlog_soft_max), + "finalize_pending_soft_max": int(self.finalize_pending_soft_max), + "prepare_inflight_soft_max": int(self.prepare_inflight_soft_max), + "active_decode_soft_max": int(self.active_decode_soft_max), + "ready_for_prefill_soft_max": int(self.ready_for_prefill_soft_max), + "active_request_soft_max": int(self.active_request_soft_max), + } + + +@dataclass +class EngineArbiterConfig: + poll_wait_ms: float = 5.0 + decode_burst: int = 4 + prepare_aging_ms: float = 10.0 + finalize_aging_ms: float = 10.0 + + def to_dict(self) -> Dict[str, Any]: + return { + "poll_wait_ms": float(self.poll_wait_ms), + "decode_burst": int(self.decode_burst), + "prepare_aging_ms": float(self.prepare_aging_ms), + "finalize_aging_ms": float(self.finalize_aging_ms), + } + + +@dataclass +class EngineArbiterState: + total_ticks: int = 0 + total_idle_ticks: int = 0 + total_prepare_dispatches: int = 0 + total_decode_dispatches: int = 0 + total_decode_runtime_ticks: int = 0 + total_finalize_dispatches: int = 0 + decode_budget_remaining: int = 0 + last_stage: str = "idle" + last_reason: str = "init" + last_observed_at: float = 0.0 + last_policy_allowed: bool = True + + +class EnginePolicyArbiterController: + def __init__( + self, + *, + policy_config: EnginePolicyConfig, + arbiter_config: EngineArbiterConfig, + snapshot_request_registry: Callable[[], Dict[str, Any]], + get_worker_state: Callable[[], Dict[str, Any]], + snapshot_prepare_state: Callable[[], Dict[str, Any]], + snapshot_finalize_state: Callable[[], Dict[str, Any]], + snapshot_dispatch_state: Callable[[], Dict[str, Any]], + snapshot_decode_runtime_state: Callable[[], Dict[str, Any]], + snapshot_job_registry: Callable[[], Dict[str, Any]], + peek_queue_age_ms: Callable[[str], float], + merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None], + ) -> None: + self.policy_config = policy_config + self.policy_poll_s = max(0.001, float(self.policy_config.poll_wait_ms) / 1000.0) + self.arbiter_config = arbiter_config + self.arbiter_poll_s = max(0.001, float(self.arbiter_config.poll_wait_ms) / 1000.0) + self.condition = threading.Condition() + self.state = EngineArbiterState( + decode_budget_remaining=int(self.arbiter_config.decode_burst), + last_observed_at=time.perf_counter(), + ) + self.snapshot_request_registry = snapshot_request_registry + self.get_worker_state = get_worker_state + self.snapshot_prepare_state = snapshot_prepare_state + self.snapshot_finalize_state = snapshot_finalize_state + self.snapshot_dispatch_state = snapshot_dispatch_state + self.snapshot_decode_runtime_state = snapshot_decode_runtime_state + self.snapshot_job_registry = snapshot_job_registry + self.peek_queue_age_ms = peek_queue_age_ms + self.merge_request_state_profile = merge_request_state_profile + + def snapshot_state(self) -> Dict[str, Any]: + with self.condition: + return { + "config": self.arbiter_config.to_dict(), + "total_ticks": int(self.state.total_ticks), + "total_idle_ticks": int(self.state.total_idle_ticks), + "total_prepare_dispatches": int(self.state.total_prepare_dispatches), + "total_decode_dispatches": int(self.state.total_decode_dispatches), + "total_decode_runtime_ticks": int(self.state.total_decode_runtime_ticks), + "total_finalize_dispatches": int(self.state.total_finalize_dispatches), + "decode_budget_remaining": int(self.state.decode_budget_remaining), + "last_stage": str(self.state.last_stage), + "last_reason": str(self.state.last_reason), + "last_policy_allowed": bool(self.state.last_policy_allowed), + "last_observed_at": float(self.state.last_observed_at), + } + + def notify(self) -> None: + with self.condition: + self.condition.notify_all() + + def wait(self) -> None: + with self.condition: + self.condition.wait(timeout=self.arbiter_poll_s) + + def mark_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: + with self.condition: + self.state.total_ticks += 1 + if stage == "idle": + self.state.total_idle_ticks += 1 + elif stage == "prepare": + self.state.total_prepare_dispatches += 1 + self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) + elif stage == "finalize": + self.state.total_finalize_dispatches += 1 + self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) + elif stage == "decode_dispatch": + self.state.total_decode_dispatches += 1 + elif stage == "decode_runtime": + self.state.total_decode_runtime_ticks += 1 + self.state.decode_budget_remaining = max(0, int(self.state.decode_budget_remaining) - 1) + self.state.last_stage = str(stage) + self.state.last_reason = str(reason) + self.state.last_policy_allowed = bool(policy_allowed) + self.state.last_observed_at = time.perf_counter() + + def build_stage_counters( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + prepare_dispatcher_state = self.snapshot_prepare_state() + finalize_dispatcher_state = self.snapshot_finalize_state() + dispatcher_state = self.snapshot_dispatch_state() + active_requests = list(request_registry.get("active_requests", [])) + status_counts: Dict[str, int] = {} + for item in active_requests: + status = str(item.get("status", "UNKNOWN")) + status_counts[status] = status_counts.get(status, 0) + 1 + + worker_pending_jobs = int(worker_state.get("pending_jobs", 0)) + worker_decode_active_size = int(worker_state.get("running_requests", 0)) + worker_prepare_inflight = int(worker_state.get("prepare_inflight", 0)) + worker_finalize_pending = int(worker_state.get("finalize_pending", 0)) + worker_finalize_inflight = int(worker_state.get("finalize_inflight", 0)) + engine_decode_runtime_state = self.snapshot_decode_runtime_state() + engine_job_registry = self.snapshot_job_registry() + decode_runtime_pending_jobs = int(engine_decode_runtime_state.get("pending_jobs", 0)) + decode_runtime_active_size = int(engine_decode_runtime_state.get("active_request_count", 0)) + return { + "active_request_count": int(len(active_requests)), + "status_counts": status_counts, + "queued_request_count": int(status_counts.get(EngineStatus.QUEUED, 0)), + "cpu_prepare_request_count": int(status_counts.get(EngineStatus.CPU_PREPARING, 0)), + "gpu_prepare_request_count": int(status_counts.get(EngineStatus.GPU_PREPARING, 0)), + "ready_for_prefill_request_count": int(status_counts.get(EngineStatus.READY_FOR_PREFILL, 0)), + "active_decode_request_count": int(status_counts.get(EngineStatus.ACTIVE_DECODE, 0)), + "ready_for_finalize_request_count": int(status_counts.get(EngineStatus.READY_FOR_FINALIZE, 0)), + "finalizing_request_count": int(status_counts.get(EngineStatus.FINALIZING, 0)), + "streaming_request_count": int(status_counts.get(EngineStatus.STREAMING, 0)), + "worker_pending_jobs": worker_pending_jobs, + "worker_decode_active_size": worker_decode_active_size, + "worker_decode_control_enabled": bool(worker_state.get("engine_decode_control_enabled", False)), + "worker_decode_runtime_has_work": bool(worker_state.get("decode_runtime_has_work", False)), + "engine_decode_runtime_pending_jobs": decode_runtime_pending_jobs, + "engine_decode_runtime_active_request_count": decode_runtime_active_size, + "engine_decode_runtime_has_work": bool(engine_decode_runtime_state.get("has_work", False)), + "engine_job_registry_count": int(engine_job_registry.get("job_count", 0)), + "worker_prepare_inflight": worker_prepare_inflight, + "worker_finalize_pending": worker_finalize_pending, + "worker_finalize_inflight": worker_finalize_inflight, + "engine_gpu_prepare_queue_count": int(prepare_dispatcher_state.get("waiting_count", 0)), + "engine_finalize_queue_count": int(finalize_dispatcher_state.get("waiting_count", 0)), + "engine_decode_waiting_queue_count": int(dispatcher_state.get("waiting_count", 0)), + "decode_backlog": int( + decode_runtime_pending_jobs + decode_runtime_active_size + if bool(worker_state.get("engine_decode_control_enabled", False)) + else worker_pending_jobs + worker_decode_active_size + ), + } + + def build_policy_snapshot( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + counters = self.build_stage_counters(request_registry, worker_state) + config = self.policy_config.to_dict() + blocked_reasons: List[Dict[str, Any]] = [] + finalize_pending_total = int(counters["worker_finalize_pending"]) + int(counters.get("engine_finalize_queue_count", 0)) + limit_checks = [ + ("decode_backlog", counters["decode_backlog"], int(config["decode_backlog_soft_max"])), + ("finalize_pending", finalize_pending_total, int(config["finalize_pending_soft_max"])), + ("prepare_inflight", counters["worker_prepare_inflight"], int(config["prepare_inflight_soft_max"])), + ("active_decode_requests", counters["active_decode_request_count"], int(config["active_decode_soft_max"])), + ("ready_for_prefill_requests", counters["ready_for_prefill_request_count"], int(config["ready_for_prefill_soft_max"])), + ("active_requests", counters["active_request_count"], int(config["active_request_soft_max"])), + ] + if bool(config["enabled"]): + for name, value, limit in limit_checks: + if limit > 0 and int(value) >= int(limit): + blocked_reasons.append({"metric": name, "value": int(value), "limit": int(limit)}) + return { + "enabled": bool(config["enabled"]), + "allowed": (not bool(config["enabled"])) or not blocked_reasons, + "blocked_reasons": blocked_reasons, + "config": config, + "metrics": { + "active_request_count": int(counters["active_request_count"]), + "queued_request_count": int(counters["queued_request_count"]), + "ready_for_prefill_request_count": int(counters["ready_for_prefill_request_count"]), + "active_decode_request_count": int(counters["active_decode_request_count"]), + "engine_gpu_prepare_queue_count": int(counters["engine_gpu_prepare_queue_count"]), + "engine_decode_waiting_queue_count": int(counters["engine_decode_waiting_queue_count"]), + "decode_backlog": int(counters["decode_backlog"]), + "prepare_inflight": int(counters["worker_prepare_inflight"]), + "finalize_pending": int(finalize_pending_total), + "engine_finalize_queue_count": int(counters.get("engine_finalize_queue_count", 0)), + "finalize_inflight": int(counters["worker_finalize_inflight"]), + }, + "observed_at": time.perf_counter(), + } + + async def wait_for_policy_admission( + self, + *, + request_id: str | None, + timeout_sec: float | None, + ) -> tuple[float, Dict[str, Any]]: + request_registry = self.snapshot_request_registry() + worker_state = self.get_worker_state() + snapshot = self.build_policy_snapshot(request_registry, worker_state) + if not self.policy_config.enabled: + return 0.0, snapshot + start = time.perf_counter() + deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) + while True: + request_registry = self.snapshot_request_registry() + worker_state = self.get_worker_state() + snapshot = self.build_policy_snapshot(request_registry, worker_state) + if snapshot["allowed"]: + wait_ms = max(0.0, (time.perf_counter() - start) * 1000.0) + if request_id not in [None, ""]: + self.merge_request_state_profile( + str(request_id), + { + "engine_policy_wait_ms": float(wait_ms), + "engine_policy_snapshot": snapshot, + }, + ) + return wait_ms, snapshot + now = time.perf_counter() + if deadline is not None and now >= deadline: + blocked_summary = ", ".join( + f"{item['metric']}={item['value']}/{item['limit']}" for item in snapshot.get("blocked_reasons", []) + ) + raise TimeoutError(f"engine policy admission timeout ({blocked_summary})") + await asyncio.sleep(self.policy_poll_s) + + def select_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: + request_registry = self.snapshot_request_registry() + worker_state = self.get_worker_state() + policy_snapshot = self.build_policy_snapshot(request_registry, worker_state) + prepare_waiting = int(self.snapshot_prepare_state().get("waiting_count", 0)) + finalize_waiting = int(self.snapshot_finalize_state().get("waiting_count", 0)) + decode_waiting = int(self.snapshot_dispatch_state().get("waiting_count", 0)) + decode_runtime_state = self.snapshot_decode_runtime_state() + worker_decode_has_work = bool(decode_runtime_state.get("has_work", False)) + worker_decode_control_enabled = bool(worker_state.get("engine_decode_control_enabled", False)) + worker_pending_jobs = int(decode_runtime_state.get("pending_jobs", 0)) + worker_running_requests = int(decode_runtime_state.get("active_request_count", 0)) + prepare_age_ms = float(self.peek_queue_age_ms("prepare")) + finalize_age_ms = float(self.peek_queue_age_ms("finalize")) + decode_runtime_pending_age_ms = float(self.peek_queue_age_ms("decode_runtime_pending")) + decode_budget_remaining = int(self.snapshot_state().get("decode_budget_remaining", 0)) + policy_allowed = bool(policy_snapshot.get("allowed", True)) + if ( + worker_decode_control_enabled + and worker_decode_has_work + and policy_allowed + and decode_budget_remaining > 0 + and (worker_running_requests > 0 or worker_pending_jobs > 0) + ): + return "decode_runtime", "worker_active_batch_progress", policy_snapshot, worker_state + if ( + worker_decode_control_enabled + and worker_pending_jobs > 0 + and policy_allowed + and decode_runtime_pending_age_ms >= float(self.arbiter_config.prepare_aging_ms) + ): + return "decode_runtime", "decode_runtime_pending_aging", policy_snapshot, worker_state + if ( + decode_waiting > 0 + and policy_allowed + and (not worker_decode_control_enabled or not worker_decode_has_work or worker_pending_jobs <= 0) + ): + return "decode_dispatch", "dispatch_prepared_state", policy_snapshot, worker_state + if finalize_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): + return "finalize", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state + if finalize_waiting > 0 and finalize_age_ms >= float(self.arbiter_config.finalize_aging_ms): + return "finalize", "finalize_aging", policy_snapshot, worker_state + if prepare_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): + return "prepare", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state + if prepare_waiting > 0 and prepare_age_ms >= float(self.arbiter_config.prepare_aging_ms): + return "prepare", "prepare_aging", policy_snapshot, worker_state + if worker_decode_control_enabled and worker_decode_has_work and policy_allowed: + return "decode_runtime", "worker_active_batch_progress_fallback", policy_snapshot, worker_state + if decode_waiting > 0 and policy_allowed: + return "decode_dispatch", "decode_priority_fallback", policy_snapshot, worker_state + if finalize_waiting > 0: + return "finalize", "finalize_fallback", policy_snapshot, worker_state + if prepare_waiting > 0: + return "prepare", "prepare_fallback", policy_snapshot, worker_state + return "idle", "no_pending_work", policy_snapshot, worker_state diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_registry.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_registry.py new file mode 100644 index 00000000..111ca500 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_registry.py @@ -0,0 +1,381 @@ +from __future__ import annotations + +import asyncio +import threading +import time +from collections import deque +from dataclasses import dataclass, field +from typing import Any, Deque, Dict, Optional, Sequence + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SRequestState + + +@dataclass +class DefaultReferenceState: + ref_audio_path: str | None = None + updated_at: float = 0.0 + + +class ReferenceRegistry: + def __init__(self) -> None: + self._lock = threading.Lock() + self._state = DefaultReferenceState() + + def set_default(self, ref_audio_path: str) -> DefaultReferenceState: + with self._lock: + self._state = DefaultReferenceState(ref_audio_path=str(ref_audio_path), updated_at=time.time()) + return self._state + + def clear(self) -> DefaultReferenceState: + with self._lock: + self._state = DefaultReferenceState() + return self._state + + def get_default(self) -> DefaultReferenceState: + with self._lock: + return DefaultReferenceState( + ref_audio_path=self._state.ref_audio_path, + updated_at=self._state.updated_at, + ) + + +@dataclass +class ModelRegistryState: + t2s_weights_path: str + vits_weights_path: str + generation: int = 0 + t2s_generation: int = 0 + vits_generation: int = 0 + updated_at: float = field(default_factory=time.time) + + +class ModelRegistry: + def __init__(self, t2s_weights_path: str, vits_weights_path: str) -> None: + self._lock = threading.Lock() + self._state = ModelRegistryState( + t2s_weights_path=str(t2s_weights_path), + vits_weights_path=str(vits_weights_path), + ) + + def snapshot(self) -> ModelRegistryState: + with self._lock: + return ModelRegistryState( + t2s_weights_path=self._state.t2s_weights_path, + vits_weights_path=self._state.vits_weights_path, + generation=self._state.generation, + t2s_generation=self._state.t2s_generation, + vits_generation=self._state.vits_generation, + updated_at=self._state.updated_at, + ) + + def mark_t2s_reload(self, weights_path: str) -> ModelRegistryState: + with self._lock: + self._state.t2s_weights_path = str(weights_path) + self._state.generation += 1 + self._state.t2s_generation += 1 + self._state.updated_at = time.time() + return ModelRegistryState( + t2s_weights_path=self._state.t2s_weights_path, + vits_weights_path=self._state.vits_weights_path, + generation=self._state.generation, + t2s_generation=self._state.t2s_generation, + vits_generation=self._state.vits_generation, + updated_at=self._state.updated_at, + ) + + def mark_vits_reload(self, weights_path: str) -> ModelRegistryState: + with self._lock: + self._state.vits_weights_path = str(weights_path) + self._state.generation += 1 + self._state.vits_generation += 1 + self._state.updated_at = time.time() + return ModelRegistryState( + t2s_weights_path=self._state.t2s_weights_path, + vits_weights_path=self._state.vits_weights_path, + generation=self._state.generation, + t2s_generation=self._state.t2s_generation, + vits_generation=self._state.vits_generation, + updated_at=self._state.updated_at, + ) + + +class EngineStatus: + NEW = "NEW" + QUEUED = "QUEUED" + VALIDATED = "VALIDATED" + CPU_PREPARING = "CPU_PREPARING" + GPU_PREPARING = "GPU_PREPARING" + READY_FOR_PREFILL = "READY_FOR_PREFILL" + ACTIVE_DECODE = "ACTIVE_DECODE" + READY_FOR_FINALIZE = "READY_FOR_FINALIZE" + FINALIZING = "FINALIZING" + STREAMING = "STREAMING" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + + +@dataclass +class EngineRequestState: + request_id: str + api_mode: str + backend: str + media_type: str + response_streaming: bool + submit_ts: float + deadline_ts: float | None = None + status: str = EngineStatus.NEW + updated_ts: float = 0.0 + error: str | None = None + finish_reason: str | None = None + meta: Dict[str, Any] = field(default_factory=dict) + profile: Dict[str, Any] = field(default_factory=dict) + lifecycle_timestamps: Dict[str, float] = field(default_factory=dict) + + def to_summary(self) -> Dict[str, Any]: + return { + "request_id": self.request_id, + "api_mode": self.api_mode, + "backend": self.backend, + "media_type": self.media_type, + "response_streaming": self.response_streaming, + "status": self.status, + "submit_ts": self.submit_ts, + "updated_ts": self.updated_ts, + "deadline_ts": self.deadline_ts, + "error": self.error, + "finish_reason": self.finish_reason, + "meta": dict(self.meta), + "profile": dict(self.profile), + "lifecycle_timestamps": dict(self.lifecycle_timestamps), + } + + +class EngineRequestRegistry: + def __init__(self, recent_limit: int) -> None: + self.lock = threading.Lock() + self.active_requests: Dict[str, EngineRequestState] = {} + self.recent_requests: Deque[EngineRequestState] = deque() + self.recent_limit = max(1, int(recent_limit)) + + def register( + self, + *, + request_id: str, + api_mode: str, + backend: str, + media_type: str, + response_streaming: bool, + deadline_ts: float | None = None, + meta: Optional[Dict[str, Any]] = None, + ) -> EngineRequestState: + now = time.perf_counter() + state = EngineRequestState( + request_id=request_id, + api_mode=api_mode, + backend=backend, + media_type=media_type, + response_streaming=bool(response_streaming), + submit_ts=now, + deadline_ts=deadline_ts, + updated_ts=now, + meta=dict(meta or {}), + lifecycle_timestamps={EngineStatus.NEW: now}, + ) + with self.lock: + self.active_requests[request_id] = state + return state + + def _move_to_recent_locked(self, state: EngineRequestState) -> None: + self.recent_requests.appendleft(state) + while len(self.recent_requests) > self.recent_limit: + self.recent_requests.pop() + + @staticmethod + def _apply_state_extra(state: EngineRequestState, extra: Optional[Dict[str, Any]]) -> None: + if not extra: + return + payload = dict(extra) + backend = payload.pop("backend", None) + if backend is not None: + state.backend = str(backend) + finish_reason = payload.pop("finish_reason", None) + if finish_reason is not None: + state.finish_reason = str(finish_reason) + error = payload.pop("error", None) + if error is not None: + state.error = str(error) + state.profile.update(payload) + + def update(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None: + now = time.perf_counter() + with self.lock: + state = self.active_requests.get(request_id) + if state is None: + return + state.status = str(status) + state.updated_ts = now + state.lifecycle_timestamps[str(status)] = now + self._apply_state_extra(state, extra) + + def merge_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + if not extra: + return + now = time.perf_counter() + with self.lock: + state = self.active_requests.get(request_id) + if state is None: + for recent_state in self.recent_requests: + if recent_state.request_id == request_id: + state = recent_state + break + if state is None: + return + state.updated_ts = now + self._apply_state_extra(state, extra) + + def complete(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + now = time.perf_counter() + with self.lock: + state = self.active_requests.pop(request_id, None) + if state is None: + return + state.status = EngineStatus.COMPLETED + state.updated_ts = now + state.lifecycle_timestamps[EngineStatus.COMPLETED] = now + self._apply_state_extra(state, extra) + self._move_to_recent_locked(state) + + def fail(self, request_id: str, error: str) -> None: + now = time.perf_counter() + with self.lock: + state = self.active_requests.pop(request_id, None) + if state is None: + return + state.status = EngineStatus.FAILED + state.updated_ts = now + state.error = str(error) + state.lifecycle_timestamps[EngineStatus.FAILED] = now + self._move_to_recent_locked(state) + + def snapshot(self) -> Dict[str, Any]: + with self.lock: + active = [state.to_summary() for state in self.active_requests.values()] + recent = [state.to_summary() for state in list(self.recent_requests)] + recent_limit = self.recent_limit + active.sort(key=lambda item: item["submit_ts"]) + return { + "active_count": len(active), + "recent_count": len(recent), + "recent_limit": recent_limit, + "active_requests": active, + "recent_requests": recent, + } + + def collect_summaries(self, request_ids: Sequence[str]) -> list[Dict[str, Any]]: + requested = set(request_ids) + results: list[Dict[str, Any]] = [] + with self.lock: + for state in self.active_requests.values(): + if state.request_id in requested: + results.append(state.to_summary()) + existing_ids = {item["request_id"] for item in results} + for state in self.recent_requests: + if state.request_id in requested and state.request_id not in existing_ids: + results.append(state.to_summary()) + results.sort(key=lambda item: item["request_id"]) + return results + + def has_active(self, request_id: str) -> bool: + with self.lock: + return request_id in self.active_requests + + +@dataclass +class SchedulerPendingJob: + request_id: str + state: T2SRequestState + done_event: threading.Event + done_loop: asyncio.AbstractEventLoop | None + done_future: asyncio.Future | None + enqueue_time: float + speed_factor: float + sample_steps: int + media_type: str + admission_wait_ms: float = 0.0 + engine_policy_wait_ms: float = 0.0 + engine_dispatch_wait_ms: float = 0.0 + prepare_wall_ms: float = 0.0 + prepare_profile_total_ms: float = 0.0 + first_schedule_time: float | None = None + prefill_ms: float = 0.0 + merge_ms: float = 0.0 + decode_ms: float = 0.0 + finalize_wait_ms: float = 0.0 + synth_ms: float = 0.0 + pack_ms: float = 0.0 + decode_steps: int = 0 + result_ready_time: float | None = None + result: dict | None = None + sample_rate: int | None = None + audio_data: np.ndarray | None = None + error: str | None = None + engine_request_id: str | None = None + + +class SchedulerJobRegistry: + def __init__(self, lock: threading.Lock | threading.RLock | threading.Condition) -> None: + self._lock = lock + self._job_map: Dict[str, SchedulerPendingJob] = {} + self._total_submitted = 0 + self._total_finished = 0 + + def register(self, job: SchedulerPendingJob, *, keep_job: bool = True) -> None: + with self._lock: + if keep_job: + self._job_map[job.request_id] = job + self._total_submitted += 1 + + def get(self, request_id: str) -> SchedulerPendingJob | None: + with self._lock: + return self._job_map.get(request_id) + + def pop(self, request_id: str) -> SchedulerPendingJob | None: + with self._lock: + return self._job_map.pop(request_id, None) + + def remove(self, request_id: str) -> None: + with self._lock: + self._job_map.pop(request_id, None) + + def mark_finished(self) -> None: + with self._lock: + self._total_finished += 1 + + def mark_finished_and_remove(self, request_id: str) -> None: + with self._lock: + self._job_map.pop(request_id, None) + self._total_finished += 1 + + def is_empty(self) -> bool: + with self._lock: + return not self._job_map + + def submitted_count(self) -> int: + with self._lock: + return int(self._total_submitted) + + def finished_count(self) -> int: + with self._lock: + return int(self._total_finished) + + def snapshot(self, max_request_ids: int = 32) -> Dict[str, Any]: + with self._lock: + request_ids = list(self._job_map.keys()) + return { + "job_count": int(len(request_ids)), + "request_ids": request_ids[: max(0, int(max_request_ids))], + "total_submitted": int(self._total_submitted), + "total_finished": int(self._total_finished), + } diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py new file mode 100644 index 00000000..db03a0c3 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py @@ -0,0 +1,334 @@ +from __future__ import annotations + +import asyncio +import threading +import time +from collections import deque +from dataclasses import dataclass, field +from typing import Any, Callable, Deque, Dict, List, Optional, Sequence + +from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PreparedCpuStage +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SActiveBatch, T2SFinishedItem, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_component_registry import SchedulerPendingJob + + +class EngineTaskQueueOwner: + def __init__(self, completion_key: str = "total_completed") -> None: + self.condition = threading.Condition() + self.queue: Deque[Any] = deque() + self.total_submitted = 0 + self.total_completed = 0 + self.peak_waiting = 0 + self.completion_key = str(completion_key) + + def enqueue(self, item: Any) -> None: + with self.condition: + self.queue.append(item) + self.total_submitted += 1 + self.peak_waiting = max(self.peak_waiting, len(self.queue)) + self.condition.notify_all() + + def enqueue_many(self, items: Sequence[Any]) -> None: + if not items: + return + with self.condition: + for item in items: + self.queue.append(item) + self.total_submitted += len(items) + self.peak_waiting = max(self.peak_waiting, len(self.queue)) + self.condition.notify_all() + + def pop_left(self) -> Any | None: + with self.condition: + if not self.queue: + return None + return self.queue.popleft() + + def mark_completed(self, count: int = 1, *, notify: bool = False) -> None: + if count <= 0: + return + with self.condition: + self.total_completed += int(count) + if notify: + self.condition.notify_all() + + def has_items(self) -> bool: + with self.condition: + return bool(self.queue) + + def waiting_count(self) -> int: + with self.condition: + return int(len(self.queue)) + + def snapshot(self, *, max_request_ids: int = 16, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + with self.condition: + waiting_items = list(self.queue)[: max(0, int(max_request_ids))] + snapshot = { + "waiting_count": int(len(self.queue)), + "waiting_request_ids": [str(getattr(item, "request_id", "")) for item in waiting_items], + "peak_waiting": int(self.peak_waiting), + "total_submitted": int(self.total_submitted), + self.completion_key: int(self.total_completed), + } + if extra: + snapshot.update(dict(extra)) + return snapshot + + def peek_oldest_age_ms(self, timestamp_attr: str) -> float: + with self.condition: + if not self.queue: + return 0.0 + enqueue_time = float(getattr(self.queue[0], timestamp_attr)) + return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0) + + def is_drained(self) -> bool: + with self.condition: + return not self.queue and self.total_submitted == self.total_completed + + def take_finalize_batch( + self, + *, + finalize_mode: str, + batch_max_items: int, + batch_wait_s: float, + use_vocoder: bool, + ) -> List[SchedulerFinalizeTask]: + with self.condition: + if not self.queue: + return [] + selected_tasks = [self.queue.popleft()] + if finalize_mode == "sync" or use_vocoder: + return selected_tasks + if batch_max_items <= 1: + return selected_tasks + first_task = selected_tasks[0] + oldest_age_s = max(0.0, time.perf_counter() - first_task.enqueued_time) + if len(self.queue) + 1 < batch_max_items and oldest_age_s < batch_wait_s: + self.queue.appendleft(first_task) + return [] + while len(selected_tasks) < batch_max_items: + if not self.queue: + break + matched_index = None + for index, task in enumerate(self.queue): + if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: + matched_index = index + break + if matched_index is None: + break + selected_tasks.append(self.queue[matched_index]) + del self.queue[matched_index] + return selected_tasks + + +@dataclass +class EngineDecodeRuntimeState: + pending_jobs: int = 0 + pending_request_ids: List[str] = field(default_factory=list) + active_request_count: int = 0 + active_request_ids: List[str] = field(default_factory=list) + prefill_done: bool = False + decode_step_index_max: int = 0 + total_cycles: int = 0 + prefill_cycles: int = 0 + step_cycles: int = 0 + has_work: bool = False + last_event: str = "init" + updated_at: float = 0.0 + + +class EngineDecodeRuntimeOwner: + def __init__( + self, + *, + get_decode_runtime_counters: Callable[[], Dict[str, int]], + get_micro_batch_wait_s: Callable[[], float], + ) -> None: + self.get_decode_runtime_counters = get_decode_runtime_counters + self.get_micro_batch_wait_s = get_micro_batch_wait_s + self.condition = threading.Condition() + self.pending_jobs: Deque[SchedulerPendingJob] = deque() + self.active_batch: T2SActiveBatch | None = None + self.state_lock = threading.Lock() + self.state = EngineDecodeRuntimeState(updated_at=time.perf_counter()) + + @staticmethod + def summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: + if active_batch is None: + return {} + decode_step_index_max = 0 + if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0: + decode_step_index_max = int(active_batch.step_indices.max().item()) + return { + "request_count": int(len(active_batch.request_ids)), + "request_ids": list(active_batch.request_ids), + "prefill_done": bool(active_batch.prefill_done), + "decode_step_index_max": int(decode_step_index_max), + } + + def snapshot_pending_queue_state(self) -> Dict[str, Any]: + with self.condition: + return { + "pending_jobs": int(len(self.pending_jobs)), + "pending_request_ids": [job.request_id for job in list(self.pending_jobs)[:32]], + } + + def enqueue_pending_job(self, job: SchedulerPendingJob) -> None: + with self.condition: + self.pending_jobs.append(job) + self.condition.notify_all() + self.refresh_state("engine_decode_pending_enqueue") + + def take_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs: + return [] + if wait_for_batch: + oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time) + if (time.perf_counter() - oldest_enqueue_time) < self.get_micro_batch_wait_s(): + return [] + pending_jobs = list(self.pending_jobs) + self.pending_jobs.clear() + self.refresh_state("engine_decode_pending_dequeue") + return pending_jobs + + def pending_age_ms(self) -> float: + with self.condition: + if not self.pending_jobs: + return 0.0 + enqueue_time = float(self.pending_jobs[0].enqueue_time) + return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0) + + def has_pending_jobs(self) -> bool: + with self.condition: + return bool(self.pending_jobs) + + def get_active_batch(self) -> T2SActiveBatch | None: + return self.active_batch + + def set_active_batch(self, active_batch: T2SActiveBatch | None) -> None: + self.active_batch = active_batch + + def active_batch_summary(self) -> Dict[str, Any]: + return self.summarize_active_batch(self.active_batch) + + def refresh_state(self, last_event: str) -> None: + pending_state = self.snapshot_pending_queue_state() + active_batch_summary = self.active_batch_summary() + worker_decode_counters = self.get_decode_runtime_counters() + with self.state_lock: + self.state.pending_jobs = int(pending_state.get("pending_jobs", 0)) + self.state.pending_request_ids = list(pending_state.get("pending_request_ids", [])) + self.state.active_request_count = int(active_batch_summary.get("request_count", 0)) + self.state.active_request_ids = list(active_batch_summary.get("request_ids", []))[:32] + self.state.prefill_done = bool(active_batch_summary.get("prefill_done", False)) + self.state.decode_step_index_max = int(active_batch_summary.get("decode_step_index_max", 0)) + self.state.total_cycles = int(worker_decode_counters.get("total_cycles", 0)) + self.state.prefill_cycles = int(worker_decode_counters.get("prefill_cycles", 0)) + self.state.step_cycles = int(worker_decode_counters.get("step_cycles", 0)) + self.state.has_work = bool(pending_state.get("pending_jobs", 0) or active_batch_summary.get("request_count", 0)) + self.state.last_event = str(last_event) + self.state.updated_at = float(time.perf_counter()) + + def update_from_worker_snapshot(self, snapshot: Dict[str, Any]) -> None: + if not snapshot: + return + pending_state = self.snapshot_pending_queue_state() + with self.state_lock: + self.state.pending_jobs = int(pending_state.get("pending_jobs", 0)) + self.state.pending_request_ids = list(pending_state.get("pending_request_ids", [])) + self.state.active_request_count = int(snapshot.get("active_request_count", 0)) + self.state.active_request_ids = list(snapshot.get("active_request_ids", []))[:32] + self.state.prefill_done = bool(snapshot.get("prefill_done", False)) + self.state.decode_step_index_max = int(snapshot.get("decode_step_index_max", 0)) + self.state.total_cycles = int(snapshot.get("total_cycles", 0)) + self.state.prefill_cycles = int(snapshot.get("prefill_cycles", 0)) + self.state.step_cycles = int(snapshot.get("step_cycles", 0)) + self.state.has_work = bool( + pending_state.get("pending_jobs", 0) + or snapshot.get("active_request_count", 0) + or snapshot.get("has_work", False) + ) + self.state.last_event = str(snapshot.get("last_event", "unknown")) + self.state.updated_at = float(snapshot.get("updated_at", time.perf_counter())) + + def snapshot_state(self) -> Dict[str, Any]: + pending_state = self.snapshot_pending_queue_state() + active_batch_summary = self.active_batch_summary() + worker_decode_counters = self.get_decode_runtime_counters() + with self.state_lock: + return { + "pending_jobs": int(pending_state.get("pending_jobs", self.state.pending_jobs)), + "pending_request_ids": list(pending_state.get("pending_request_ids", self.state.pending_request_ids)), + "active_request_count": int(active_batch_summary.get("request_count", self.state.active_request_count)), + "active_request_ids": list(active_batch_summary.get("request_ids", self.state.active_request_ids)), + "prefill_done": bool(active_batch_summary.get("prefill_done", self.state.prefill_done)), + "decode_step_index_max": int(active_batch_summary.get("decode_step_index_max", self.state.decode_step_index_max)), + "total_cycles": int(worker_decode_counters.get("total_cycles", 0)), + "prefill_cycles": int(worker_decode_counters.get("prefill_cycles", 0)), + "step_cycles": int(worker_decode_counters.get("step_cycles", 0)), + "has_work": bool( + pending_state.get("pending_jobs", 0) + or active_batch_summary.get("request_count", self.state.active_request_count) + or self.state.has_work + ), + "last_event": str(self.state.last_event), + "updated_at": float(self.state.updated_at), + } + + +@dataclass +class SchedulerFinalizeTask: + request_id: str + item: T2SFinishedItem + enqueued_time: float + + +@dataclass +class EngineDispatchTask: + request_id: str + state: T2SRequestState + speed_factor: float + sample_steps: int + media_type: str + prepare_wall_ms: float + prepare_profile_total_ms: float + done_loop: asyncio.AbstractEventLoop | None + done_future: asyncio.Future | None + engine_request_id: str | None + timeout_sec: float | None + enqueue_time: float + worker_job: SchedulerPendingJob | None = None + engine_policy_wait_ms: float = 0.0 + engine_dispatch_wait_ms: float = 0.0 + engine_policy_snapshot: Dict[str, Any] | None = None + error: str | None = None + + +@dataclass +class EngineGpuPrepareTask: + request_id: str + cpu_stage: PreparedCpuStage + done_loop: asyncio.AbstractEventLoop | None + done_future: asyncio.Future | None + engine_request_id: str | None + enqueue_time: float + queue_wait_ms: float = 0.0 + error: str | None = None + + +@dataclass +class EngineFinalizeQueueState: + waiting_count: int + waiting_request_ids: List[str] + peak_waiting: int + total_submitted: int + total_completed: int + + +@dataclass +class RuntimeStateCallbacks: + update: Callable[[str, str, Optional[Dict[str, Any]]], None] | None = None + complete: Callable[[str, Optional[Dict[str, Any]]], None] | None = None + fail: Callable[[str, str], None] | None = None + decode_runtime_update: Callable[[Dict[str, Any]], None] | None = None diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_components.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_components.py index 3a124f4e..ac1adac5 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_components.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_components.py @@ -1,1150 +1,63 @@ -from __future__ import annotations - -import asyncio -import os -import threading -import time -import uuid -from collections import deque -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Callable, Deque, Dict, List, Optional, Sequence, Tuple, Union - -import numpy as np -import torch - -from GPT_SoVITS.TTS_infer_pack.TTS import TTS -from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState - - -@dataclass -class RuntimeControlCallbacks: - restart: Callable[[], None] | None = None - exit: Callable[[], None] | None = None - - -@dataclass -class DefaultReferenceState: - ref_audio_path: str | None = None - updated_at: float = 0.0 - - -class ReferenceRegistry: - def __init__(self) -> None: - self._lock = threading.Lock() - self._state = DefaultReferenceState() - - def set_default(self, ref_audio_path: str) -> DefaultReferenceState: - with self._lock: - self._state = DefaultReferenceState(ref_audio_path=str(ref_audio_path), updated_at=time.time()) - return self._state - - def clear(self) -> DefaultReferenceState: - with self._lock: - self._state = DefaultReferenceState() - return self._state - - def get_default(self) -> DefaultReferenceState: - with self._lock: - return DefaultReferenceState( - ref_audio_path=self._state.ref_audio_path, - updated_at=self._state.updated_at, - ) - - -@dataclass -class ModelRegistryState: - t2s_weights_path: str - vits_weights_path: str - generation: int = 0 - t2s_generation: int = 0 - vits_generation: int = 0 - updated_at: float = field(default_factory=time.time) - - -class ModelRegistry: - def __init__(self, t2s_weights_path: str, vits_weights_path: str) -> None: - self._lock = threading.Lock() - self._state = ModelRegistryState( - t2s_weights_path=str(t2s_weights_path), - vits_weights_path=str(vits_weights_path), - ) - - def snapshot(self) -> ModelRegistryState: - with self._lock: - return ModelRegistryState( - t2s_weights_path=self._state.t2s_weights_path, - vits_weights_path=self._state.vits_weights_path, - generation=self._state.generation, - t2s_generation=self._state.t2s_generation, - vits_generation=self._state.vits_generation, - updated_at=self._state.updated_at, - ) - - def mark_t2s_reload(self, weights_path: str) -> ModelRegistryState: - with self._lock: - self._state.t2s_weights_path = str(weights_path) - self._state.generation += 1 - self._state.t2s_generation += 1 - self._state.updated_at = time.time() - return ModelRegistryState( - t2s_weights_path=self._state.t2s_weights_path, - vits_weights_path=self._state.vits_weights_path, - generation=self._state.generation, - t2s_generation=self._state.t2s_generation, - vits_generation=self._state.vits_generation, - updated_at=self._state.updated_at, - ) - - def mark_vits_reload(self, weights_path: str) -> ModelRegistryState: - with self._lock: - self._state.vits_weights_path = str(weights_path) - self._state.generation += 1 - self._state.vits_generation += 1 - self._state.updated_at = time.time() - return ModelRegistryState( - t2s_weights_path=self._state.t2s_weights_path, - vits_weights_path=self._state.vits_weights_path, - generation=self._state.generation, - t2s_generation=self._state.t2s_generation, - vits_generation=self._state.vits_generation, - updated_at=self._state.updated_at, - ) - - -@dataclass -class DirectTTSExecution: - media_type: str - streaming: bool - audio_generator: Optional[Generator[bytes, None, None]] = None - audio_bytes: Optional[bytes] = None - request_id: Optional[str] = None - - -@dataclass -class NormalizedEngineRequest: - request_id: str - text: str - text_lang: str - ref_audio_path: str - prompt_lang: str - prompt_text: str = "" - aux_ref_audio_paths: List[str] | None = None - top_k: int = 15 - top_p: float = 1.0 - temperature: float = 1.0 - repetition_penalty: float = 1.35 - early_stop_num: int = -1 - ready_step: int = 0 - text_split_method: str = "cut5" - batch_size: int = 1 - batch_threshold: float = 0.75 - split_bucket: bool = False - speed_factor: float = 1.0 - fragment_interval: float = 0.3 - seed: int = -1 - media_type: str = "wav" - streaming_mode: bool | int = False - return_fragment: bool = False - fixed_length_chunk: bool = False - response_streaming: bool = False - parallel_infer: bool = False - sample_steps: int = 32 - super_sampling: bool = False - overlap_length: int = 2 - min_chunk_length: int = 16 - timeout_sec: float | None = None - - def to_payload(self) -> Dict[str, Any]: - return { - "request_id": self.request_id, - "text": self.text, - "text_lang": self.text_lang, - "ref_audio_path": self.ref_audio_path, - "aux_ref_audio_paths": list(self.aux_ref_audio_paths) if self.aux_ref_audio_paths else None, - "prompt_text": self.prompt_text, - "prompt_lang": self.prompt_lang, - "top_k": self.top_k, - "top_p": self.top_p, - "temperature": self.temperature, - "text_split_method": self.text_split_method, - "batch_size": self.batch_size, - "batch_threshold": self.batch_threshold, - "speed_factor": self.speed_factor, - "split_bucket": self.split_bucket, - "fragment_interval": self.fragment_interval, - "seed": self.seed, - "media_type": self.media_type, - "streaming_mode": self.streaming_mode, - "return_fragment": self.return_fragment, - "fixed_length_chunk": self.fixed_length_chunk, - "response_streaming": self.response_streaming, - "parallel_infer": self.parallel_infer, - "repetition_penalty": self.repetition_penalty, - "sample_steps": self.sample_steps, - "super_sampling": self.super_sampling, - "overlap_length": self.overlap_length, - "min_chunk_length": self.min_chunk_length, - "early_stop_num": self.early_stop_num, - "ready_step": self.ready_step, - "timeout_sec": self.timeout_sec, - } - - def to_scheduler_spec(self) -> SchedulerRequestSpec: - return SchedulerRequestSpec( - request_id=self.request_id, - ref_audio_path=Path(self.ref_audio_path), - prompt_text=self.prompt_text, - prompt_lang=self.prompt_lang, - text=self.text, - text_lang=self.text_lang, - top_k=self.top_k, - top_p=self.top_p, - temperature=self.temperature, - repetition_penalty=self.repetition_penalty, - early_stop_num=self.early_stop_num, - ready_step=self.ready_step, - ) - - -@dataclass -class SchedulerDebugExecution: - payload: Dict[str, Any] - - -@dataclass -class SchedulerSubmitExecution: - audio_bytes: bytes - media_type: str - headers: Dict[str, str] - - -@dataclass -class EnginePolicyConfig: - enabled: bool = True - poll_wait_ms: float = 5.0 - decode_backlog_soft_max: int = 0 - finalize_pending_soft_max: int = 0 - prepare_inflight_soft_max: int = 0 - active_decode_soft_max: int = 0 - ready_for_prefill_soft_max: int = 0 - active_request_soft_max: int = 0 - - def to_dict(self) -> Dict[str, Any]: - return { - "enabled": bool(self.enabled), - "poll_wait_ms": float(self.poll_wait_ms), - "decode_backlog_soft_max": int(self.decode_backlog_soft_max), - "finalize_pending_soft_max": int(self.finalize_pending_soft_max), - "prepare_inflight_soft_max": int(self.prepare_inflight_soft_max), - "active_decode_soft_max": int(self.active_decode_soft_max), - "ready_for_prefill_soft_max": int(self.ready_for_prefill_soft_max), - "active_request_soft_max": int(self.active_request_soft_max), - } - - -@dataclass -class EngineArbiterConfig: - poll_wait_ms: float = 5.0 - decode_burst: int = 4 - prepare_aging_ms: float = 10.0 - finalize_aging_ms: float = 10.0 - - def to_dict(self) -> Dict[str, Any]: - return { - "poll_wait_ms": float(self.poll_wait_ms), - "decode_burst": int(self.decode_burst), - "prepare_aging_ms": float(self.prepare_aging_ms), - "finalize_aging_ms": float(self.finalize_aging_ms), - } - - -class EngineStatus: - NEW = "NEW" - QUEUED = "QUEUED" - VALIDATED = "VALIDATED" - CPU_PREPARING = "CPU_PREPARING" - GPU_PREPARING = "GPU_PREPARING" - READY_FOR_PREFILL = "READY_FOR_PREFILL" - ACTIVE_DECODE = "ACTIVE_DECODE" - READY_FOR_FINALIZE = "READY_FOR_FINALIZE" - FINALIZING = "FINALIZING" - STREAMING = "STREAMING" - COMPLETED = "COMPLETED" - FAILED = "FAILED" - - -@dataclass -class EngineRequestState: - request_id: str - api_mode: str - backend: str - media_type: str - response_streaming: bool - submit_ts: float - deadline_ts: float | None = None - status: str = EngineStatus.NEW - updated_ts: float = 0.0 - error: str | None = None - finish_reason: str | None = None - meta: Dict[str, Any] = field(default_factory=dict) - profile: Dict[str, Any] = field(default_factory=dict) - lifecycle_timestamps: Dict[str, float] = field(default_factory=dict) - - def to_summary(self) -> Dict[str, Any]: - return { - "request_id": self.request_id, - "api_mode": self.api_mode, - "backend": self.backend, - "media_type": self.media_type, - "response_streaming": self.response_streaming, - "status": self.status, - "submit_ts": self.submit_ts, - "updated_ts": self.updated_ts, - "deadline_ts": self.deadline_ts, - "error": self.error, - "finish_reason": self.finish_reason, - "meta": dict(self.meta), - "profile": dict(self.profile), - "lifecycle_timestamps": dict(self.lifecycle_timestamps), - } - - -class EngineRequestRegistry: - def __init__(self, recent_limit: int) -> None: - self.lock = threading.Lock() - self.active_requests: Dict[str, EngineRequestState] = {} - self.recent_requests: Deque[EngineRequestState] = deque() - self.recent_limit = max(1, int(recent_limit)) - - def register( - self, - *, - request_id: str, - api_mode: str, - backend: str, - media_type: str, - response_streaming: bool, - deadline_ts: float | None = None, - meta: Optional[Dict[str, Any]] = None, - ) -> EngineRequestState: - now = time.perf_counter() - state = EngineRequestState( - request_id=request_id, - api_mode=api_mode, - backend=backend, - media_type=media_type, - response_streaming=bool(response_streaming), - submit_ts=now, - deadline_ts=deadline_ts, - updated_ts=now, - meta=dict(meta or {}), - lifecycle_timestamps={EngineStatus.NEW: now}, - ) - with self.lock: - self.active_requests[request_id] = state - return state - - def _move_to_recent_locked(self, state: EngineRequestState) -> None: - self.recent_requests.appendleft(state) - while len(self.recent_requests) > self.recent_limit: - self.recent_requests.pop() - - @staticmethod - def _apply_state_extra(state: EngineRequestState, extra: Optional[Dict[str, Any]]) -> None: - if not extra: - return - payload = dict(extra) - backend = payload.pop("backend", None) - if backend is not None: - state.backend = str(backend) - finish_reason = payload.pop("finish_reason", None) - if finish_reason is not None: - state.finish_reason = str(finish_reason) - error = payload.pop("error", None) - if error is not None: - state.error = str(error) - state.profile.update(payload) - - def update(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None: - now = time.perf_counter() - with self.lock: - state = self.active_requests.get(request_id) - if state is None: - return - state.status = str(status) - state.updated_ts = now - state.lifecycle_timestamps[str(status)] = now - self._apply_state_extra(state, extra) - - def merge_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - if not extra: - return - now = time.perf_counter() - with self.lock: - state = self.active_requests.get(request_id) - if state is None: - for recent_state in self.recent_requests: - if recent_state.request_id == request_id: - state = recent_state - break - if state is None: - return - state.updated_ts = now - self._apply_state_extra(state, extra) - - def complete(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: - now = time.perf_counter() - with self.lock: - state = self.active_requests.pop(request_id, None) - if state is None: - return - state.status = EngineStatus.COMPLETED - state.updated_ts = now - state.lifecycle_timestamps[EngineStatus.COMPLETED] = now - self._apply_state_extra(state, extra) - self._move_to_recent_locked(state) - - def fail(self, request_id: str, error: str) -> None: - now = time.perf_counter() - with self.lock: - state = self.active_requests.pop(request_id, None) - if state is None: - return - state.status = EngineStatus.FAILED - state.updated_ts = now - state.error = str(error) - state.lifecycle_timestamps[EngineStatus.FAILED] = now - self._move_to_recent_locked(state) - - def snapshot(self) -> Dict[str, Any]: - with self.lock: - active = [state.to_summary() for state in self.active_requests.values()] - recent = [state.to_summary() for state in list(self.recent_requests)] - recent_limit = self.recent_limit - active.sort(key=lambda item: item["submit_ts"]) - return { - "active_count": len(active), - "recent_count": len(recent), - "recent_limit": recent_limit, - "active_requests": active, - "recent_requests": recent, - } - - def collect_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]: - requested = set(request_ids) - results: List[Dict[str, Any]] = [] - with self.lock: - for state in self.active_requests.values(): - if state.request_id in requested: - results.append(state.to_summary()) - existing_ids = {item["request_id"] for item in results} - for state in self.recent_requests: - if state.request_id in requested and state.request_id not in existing_ids: - results.append(state.to_summary()) - results.sort(key=lambda item: item["request_id"]) - return results - - def has_active(self, request_id: str) -> bool: - with self.lock: - return request_id in self.active_requests - - -@dataclass -class SchedulerPendingJob: - request_id: str - state: T2SRequestState - done_event: threading.Event - done_loop: asyncio.AbstractEventLoop | None - done_future: asyncio.Future | None - enqueue_time: float - speed_factor: float - sample_steps: int - media_type: str - admission_wait_ms: float = 0.0 - engine_policy_wait_ms: float = 0.0 - engine_dispatch_wait_ms: float = 0.0 - prepare_wall_ms: float = 0.0 - prepare_profile_total_ms: float = 0.0 - first_schedule_time: float | None = None - prefill_ms: float = 0.0 - merge_ms: float = 0.0 - decode_ms: float = 0.0 - finalize_wait_ms: float = 0.0 - synth_ms: float = 0.0 - pack_ms: float = 0.0 - decode_steps: int = 0 - result_ready_time: float | None = None - result: dict | None = None - sample_rate: int | None = None - audio_data: np.ndarray | None = None - error: str | None = None - engine_request_id: str | None = None - - -class SchedulerJobRegistry: - def __init__(self, lock: threading.Lock | threading.RLock | threading.Condition) -> None: - self._lock = lock - self._job_map: Dict[str, SchedulerPendingJob] = {} - self._total_submitted = 0 - self._total_finished = 0 - - def register(self, job: SchedulerPendingJob, *, keep_job: bool = True) -> None: - with self._lock: - if keep_job: - self._job_map[job.request_id] = job - self._total_submitted += 1 - - def get(self, request_id: str) -> SchedulerPendingJob | None: - with self._lock: - return self._job_map.get(request_id) - - def pop(self, request_id: str) -> SchedulerPendingJob | None: - with self._lock: - return self._job_map.pop(request_id, None) - - def remove(self, request_id: str) -> None: - with self._lock: - self._job_map.pop(request_id, None) - - def mark_finished(self) -> None: - with self._lock: - self._total_finished += 1 - - def mark_finished_and_remove(self, request_id: str) -> None: - with self._lock: - self._job_map.pop(request_id, None) - self._total_finished += 1 - - def is_empty(self) -> bool: - with self._lock: - return not self._job_map - - def submitted_count(self) -> int: - with self._lock: - return int(self._total_submitted) - - def finished_count(self) -> int: - with self._lock: - return int(self._total_finished) - - def snapshot(self, max_request_ids: int = 32) -> Dict[str, Any]: - with self._lock: - request_ids = list(self._job_map.keys()) - return { - "job_count": int(len(request_ids)), - "request_ids": request_ids[: max(0, int(max_request_ids))], - "total_submitted": int(self._total_submitted), - "total_finished": int(self._total_finished), - } - - -class EngineTaskQueueOwner: - def __init__(self, completion_key: str = "total_completed") -> None: - self.condition = threading.Condition() - self.queue: Deque[Any] = deque() - self.total_submitted = 0 - self.total_completed = 0 - self.peak_waiting = 0 - self.completion_key = str(completion_key) - - def enqueue(self, item: Any) -> None: - with self.condition: - self.queue.append(item) - self.total_submitted += 1 - self.peak_waiting = max(self.peak_waiting, len(self.queue)) - self.condition.notify_all() - - def enqueue_many(self, items: Sequence[Any]) -> None: - if not items: - return - with self.condition: - for item in items: - self.queue.append(item) - self.total_submitted += len(items) - self.peak_waiting = max(self.peak_waiting, len(self.queue)) - self.condition.notify_all() - - def pop_left(self) -> Any | None: - with self.condition: - if not self.queue: - return None - return self.queue.popleft() - - def mark_completed(self, count: int = 1, *, notify: bool = False) -> None: - if count <= 0: - return - with self.condition: - self.total_completed += int(count) - if notify: - self.condition.notify_all() - - def has_items(self) -> bool: - with self.condition: - return bool(self.queue) - - def waiting_count(self) -> int: - with self.condition: - return int(len(self.queue)) - - def snapshot(self, *, max_request_ids: int = 16, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: - with self.condition: - waiting_items = list(self.queue)[: max(0, int(max_request_ids))] - snapshot = { - "waiting_count": int(len(self.queue)), - "waiting_request_ids": [str(getattr(item, "request_id", "")) for item in waiting_items], - "peak_waiting": int(self.peak_waiting), - "total_submitted": int(self.total_submitted), - self.completion_key: int(self.total_completed), - } - if extra: - snapshot.update(dict(extra)) - return snapshot - - def peek_oldest_age_ms(self, timestamp_attr: str) -> float: - with self.condition: - if not self.queue: - return 0.0 - enqueue_time = float(getattr(self.queue[0], timestamp_attr)) - return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0) - - def is_drained(self) -> bool: - with self.condition: - return not self.queue and self.total_submitted == self.total_completed - - def take_finalize_batch( - self, - *, - finalize_mode: str, - batch_max_items: int, - batch_wait_s: float, - use_vocoder: bool, - ) -> List[SchedulerFinalizeTask]: - with self.condition: - if not self.queue: - return [] - selected_tasks = [self.queue.popleft()] - if finalize_mode == "sync" or use_vocoder: - return selected_tasks - if batch_max_items <= 1: - return selected_tasks - first_task = selected_tasks[0] - oldest_age_s = max(0.0, time.perf_counter() - first_task.enqueued_time) - if len(self.queue) + 1 < batch_max_items and oldest_age_s < batch_wait_s: - self.queue.appendleft(first_task) - return [] - while len(selected_tasks) < batch_max_items: - if not self.queue: - break - matched_index = None - for index, task in enumerate(self.queue): - if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: - matched_index = index - break - if matched_index is None: - break - selected_tasks.append(self.queue[matched_index]) - del self.queue[matched_index] - return selected_tasks - - -class EnginePolicyArbiterController: - def __init__( - self, - *, - policy_config: EnginePolicyConfig, - arbiter_config: EngineArbiterConfig, - snapshot_request_registry: Callable[[], Dict[str, Any]], - get_worker_state: Callable[[], Dict[str, Any]], - snapshot_prepare_state: Callable[[], Dict[str, Any]], - snapshot_finalize_state: Callable[[], Dict[str, Any]], - snapshot_dispatch_state: Callable[[], Dict[str, Any]], - snapshot_decode_runtime_state: Callable[[], Dict[str, Any]], - snapshot_job_registry: Callable[[], Dict[str, Any]], - peek_queue_age_ms: Callable[[str], float], - merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None], - ) -> None: - self.policy_config = policy_config - self.policy_poll_s = max(0.001, float(self.policy_config.poll_wait_ms) / 1000.0) - self.arbiter_config = arbiter_config - self.arbiter_poll_s = max(0.001, float(self.arbiter_config.poll_wait_ms) / 1000.0) - self.condition = threading.Condition() - self.state = EngineArbiterState( - decode_budget_remaining=int(self.arbiter_config.decode_burst), - last_observed_at=time.perf_counter(), - ) - self.snapshot_request_registry = snapshot_request_registry - self.get_worker_state = get_worker_state - self.snapshot_prepare_state = snapshot_prepare_state - self.snapshot_finalize_state = snapshot_finalize_state - self.snapshot_dispatch_state = snapshot_dispatch_state - self.snapshot_decode_runtime_state = snapshot_decode_runtime_state - self.snapshot_job_registry = snapshot_job_registry - self.peek_queue_age_ms = peek_queue_age_ms - self.merge_request_state_profile = merge_request_state_profile - - def snapshot_state(self) -> Dict[str, Any]: - with self.condition: - return { - "config": self.arbiter_config.to_dict(), - "total_ticks": int(self.state.total_ticks), - "total_idle_ticks": int(self.state.total_idle_ticks), - "total_prepare_dispatches": int(self.state.total_prepare_dispatches), - "total_decode_dispatches": int(self.state.total_decode_dispatches), - "total_decode_runtime_ticks": int(self.state.total_decode_runtime_ticks), - "total_finalize_dispatches": int(self.state.total_finalize_dispatches), - "decode_budget_remaining": int(self.state.decode_budget_remaining), - "last_stage": str(self.state.last_stage), - "last_reason": str(self.state.last_reason), - "last_policy_allowed": bool(self.state.last_policy_allowed), - "last_observed_at": float(self.state.last_observed_at), - } - - def notify(self) -> None: - with self.condition: - self.condition.notify_all() - - def wait(self) -> None: - with self.condition: - self.condition.wait(timeout=self.arbiter_poll_s) - - def mark_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: - with self.condition: - self.state.total_ticks += 1 - if stage == "idle": - self.state.total_idle_ticks += 1 - elif stage == "prepare": - self.state.total_prepare_dispatches += 1 - self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) - elif stage == "finalize": - self.state.total_finalize_dispatches += 1 - self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) - elif stage == "decode_dispatch": - self.state.total_decode_dispatches += 1 - elif stage == "decode_runtime": - self.state.total_decode_runtime_ticks += 1 - self.state.decode_budget_remaining = max(0, int(self.state.decode_budget_remaining) - 1) - self.state.last_stage = str(stage) - self.state.last_reason = str(reason) - self.state.last_policy_allowed = bool(policy_allowed) - self.state.last_observed_at = time.perf_counter() - - def build_stage_counters( - self, - request_registry: Dict[str, Any], - worker_state: Dict[str, Any], - ) -> Dict[str, Any]: - prepare_dispatcher_state = self.snapshot_prepare_state() - finalize_dispatcher_state = self.snapshot_finalize_state() - dispatcher_state = self.snapshot_dispatch_state() - active_requests = list(request_registry.get("active_requests", [])) - status_counts: Dict[str, int] = {} - for item in active_requests: - status = str(item.get("status", "UNKNOWN")) - status_counts[status] = status_counts.get(status, 0) + 1 - - worker_pending_jobs = int(worker_state.get("pending_jobs", 0)) - worker_decode_active_size = int(worker_state.get("running_requests", 0)) - worker_prepare_inflight = int(worker_state.get("prepare_inflight", 0)) - worker_finalize_pending = int(worker_state.get("finalize_pending", 0)) - worker_finalize_inflight = int(worker_state.get("finalize_inflight", 0)) - engine_decode_runtime_state = self.snapshot_decode_runtime_state() - engine_job_registry = self.snapshot_job_registry() - decode_runtime_pending_jobs = int(engine_decode_runtime_state.get("pending_jobs", 0)) - decode_runtime_active_size = int(engine_decode_runtime_state.get("active_request_count", 0)) - return { - "active_request_count": int(len(active_requests)), - "status_counts": status_counts, - "queued_request_count": int(status_counts.get(EngineStatus.QUEUED, 0)), - "cpu_prepare_request_count": int(status_counts.get(EngineStatus.CPU_PREPARING, 0)), - "gpu_prepare_request_count": int(status_counts.get(EngineStatus.GPU_PREPARING, 0)), - "ready_for_prefill_request_count": int(status_counts.get(EngineStatus.READY_FOR_PREFILL, 0)), - "active_decode_request_count": int(status_counts.get(EngineStatus.ACTIVE_DECODE, 0)), - "ready_for_finalize_request_count": int(status_counts.get(EngineStatus.READY_FOR_FINALIZE, 0)), - "finalizing_request_count": int(status_counts.get(EngineStatus.FINALIZING, 0)), - "streaming_request_count": int(status_counts.get(EngineStatus.STREAMING, 0)), - "worker_pending_jobs": worker_pending_jobs, - "worker_decode_active_size": worker_decode_active_size, - "worker_decode_control_enabled": bool(worker_state.get("engine_decode_control_enabled", False)), - "worker_decode_runtime_has_work": bool(worker_state.get("decode_runtime_has_work", False)), - "engine_decode_runtime_pending_jobs": decode_runtime_pending_jobs, - "engine_decode_runtime_active_request_count": decode_runtime_active_size, - "engine_decode_runtime_has_work": bool(engine_decode_runtime_state.get("has_work", False)), - "engine_job_registry_count": int(engine_job_registry.get("job_count", 0)), - "worker_prepare_inflight": worker_prepare_inflight, - "worker_finalize_pending": worker_finalize_pending, - "worker_finalize_inflight": worker_finalize_inflight, - "engine_gpu_prepare_queue_count": int(prepare_dispatcher_state.get("waiting_count", 0)), - "engine_finalize_queue_count": int(finalize_dispatcher_state.get("waiting_count", 0)), - "engine_decode_waiting_queue_count": int(dispatcher_state.get("waiting_count", 0)), - "decode_backlog": int( - decode_runtime_pending_jobs + decode_runtime_active_size - if bool(worker_state.get("engine_decode_control_enabled", False)) - else worker_pending_jobs + worker_decode_active_size - ), - } - - def build_policy_snapshot( - self, - request_registry: Dict[str, Any], - worker_state: Dict[str, Any], - ) -> Dict[str, Any]: - counters = self.build_stage_counters(request_registry, worker_state) - config = self.policy_config.to_dict() - blocked_reasons: List[Dict[str, Any]] = [] - finalize_pending_total = int(counters["worker_finalize_pending"]) + int(counters.get("engine_finalize_queue_count", 0)) - limit_checks = [ - ("decode_backlog", counters["decode_backlog"], int(config["decode_backlog_soft_max"])), - ("finalize_pending", finalize_pending_total, int(config["finalize_pending_soft_max"])), - ("prepare_inflight", counters["worker_prepare_inflight"], int(config["prepare_inflight_soft_max"])), - ("active_decode_requests", counters["active_decode_request_count"], int(config["active_decode_soft_max"])), - ("ready_for_prefill_requests", counters["ready_for_prefill_request_count"], int(config["ready_for_prefill_soft_max"])), - ("active_requests", counters["active_request_count"], int(config["active_request_soft_max"])), - ] - if bool(config["enabled"]): - for name, value, limit in limit_checks: - if limit > 0 and int(value) >= int(limit): - blocked_reasons.append({"metric": name, "value": int(value), "limit": int(limit)}) - return { - "enabled": bool(config["enabled"]), - "allowed": (not bool(config["enabled"])) or not blocked_reasons, - "blocked_reasons": blocked_reasons, - "config": config, - "metrics": { - "active_request_count": int(counters["active_request_count"]), - "queued_request_count": int(counters["queued_request_count"]), - "ready_for_prefill_request_count": int(counters["ready_for_prefill_request_count"]), - "active_decode_request_count": int(counters["active_decode_request_count"]), - "engine_gpu_prepare_queue_count": int(counters["engine_gpu_prepare_queue_count"]), - "engine_decode_waiting_queue_count": int(counters["engine_decode_waiting_queue_count"]), - "decode_backlog": int(counters["decode_backlog"]), - "prepare_inflight": int(counters["worker_prepare_inflight"]), - "finalize_pending": int(finalize_pending_total), - "engine_finalize_queue_count": int(counters.get("engine_finalize_queue_count", 0)), - "finalize_inflight": int(counters["worker_finalize_inflight"]), - }, - "observed_at": time.perf_counter(), - } - - async def wait_for_policy_admission( - self, - *, - request_id: str | None, - timeout_sec: float | None, - ) -> tuple[float, Dict[str, Any]]: - request_registry = self.snapshot_request_registry() - worker_state = self.get_worker_state() - snapshot = self.build_policy_snapshot(request_registry, worker_state) - if not self.policy_config.enabled: - return 0.0, snapshot - start = time.perf_counter() - deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) - while True: - request_registry = self.snapshot_request_registry() - worker_state = self.get_worker_state() - snapshot = self.build_policy_snapshot(request_registry, worker_state) - if snapshot["allowed"]: - wait_ms = max(0.0, (time.perf_counter() - start) * 1000.0) - if request_id not in [None, ""]: - self.merge_request_state_profile( - str(request_id), - { - "engine_policy_wait_ms": float(wait_ms), - "engine_policy_snapshot": snapshot, - }, - ) - return wait_ms, snapshot - now = time.perf_counter() - if deadline is not None and now >= deadline: - blocked_summary = ", ".join( - f"{item['metric']}={item['value']}/{item['limit']}" for item in snapshot.get("blocked_reasons", []) - ) - raise TimeoutError(f"engine policy admission timeout ({blocked_summary})") - await asyncio.sleep(self.policy_poll_s) - - def select_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: - request_registry = self.snapshot_request_registry() - worker_state = self.get_worker_state() - policy_snapshot = self.build_policy_snapshot(request_registry, worker_state) - prepare_waiting = int(self.snapshot_prepare_state().get("waiting_count", 0)) - finalize_waiting = int(self.snapshot_finalize_state().get("waiting_count", 0)) - decode_waiting = int(self.snapshot_dispatch_state().get("waiting_count", 0)) - decode_runtime_state = self.snapshot_decode_runtime_state() - worker_decode_has_work = bool(decode_runtime_state.get("has_work", False)) - worker_decode_control_enabled = bool(worker_state.get("engine_decode_control_enabled", False)) - worker_pending_jobs = int(decode_runtime_state.get("pending_jobs", 0)) - worker_running_requests = int(decode_runtime_state.get("active_request_count", 0)) - prepare_age_ms = float(self.peek_queue_age_ms("prepare")) - finalize_age_ms = float(self.peek_queue_age_ms("finalize")) - decode_runtime_pending_age_ms = float(self.peek_queue_age_ms("decode_runtime_pending")) - decode_budget_remaining = int(self.snapshot_state().get("decode_budget_remaining", 0)) - policy_allowed = bool(policy_snapshot.get("allowed", True)) - if ( - worker_decode_control_enabled - and worker_decode_has_work - and policy_allowed - and decode_budget_remaining > 0 - and (worker_running_requests > 0 or worker_pending_jobs > 0) - ): - return "decode_runtime", "worker_active_batch_progress", policy_snapshot, worker_state - if ( - worker_decode_control_enabled - and worker_pending_jobs > 0 - and policy_allowed - and decode_runtime_pending_age_ms >= float(self.arbiter_config.prepare_aging_ms) - ): - return "decode_runtime", "decode_runtime_pending_aging", policy_snapshot, worker_state - if ( - decode_waiting > 0 - and policy_allowed - and (not worker_decode_control_enabled or not worker_decode_has_work or worker_pending_jobs <= 0) - ): - return "decode_dispatch", "dispatch_prepared_state", policy_snapshot, worker_state - if finalize_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): - return "finalize", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state - if finalize_waiting > 0 and finalize_age_ms >= float(self.arbiter_config.finalize_aging_ms): - return "finalize", "finalize_aging", policy_snapshot, worker_state - if prepare_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): - return "prepare", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state - if prepare_waiting > 0 and prepare_age_ms >= float(self.arbiter_config.prepare_aging_ms): - return "prepare", "prepare_aging", policy_snapshot, worker_state - if worker_decode_control_enabled and worker_decode_has_work and policy_allowed: - return "decode_runtime", "worker_active_batch_progress_fallback", policy_snapshot, worker_state - if decode_waiting > 0 and policy_allowed: - return "decode_dispatch", "decode_priority_fallback", policy_snapshot, worker_state - if finalize_waiting > 0: - return "finalize", "finalize_fallback", policy_snapshot, worker_state - if prepare_waiting > 0: - return "prepare", "prepare_fallback", policy_snapshot, worker_state - return "idle", "no_pending_work", policy_snapshot, worker_state - - -class EngineDecodeRuntimeOwner: - def __init__( - self, - *, - get_decode_runtime_counters: Callable[[], Dict[str, int]], - get_micro_batch_wait_s: Callable[[], float], - ) -> None: - self.get_decode_runtime_counters = get_decode_runtime_counters - self.get_micro_batch_wait_s = get_micro_batch_wait_s - self.condition = threading.Condition() - self.pending_jobs: Deque[SchedulerPendingJob] = deque() - self.active_batch: T2SActiveBatch | None = None - self.state_lock = threading.Lock() - self.state = EngineDecodeRuntimeState(updated_at=time.perf_counter()) - - @staticmethod - def summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: - if active_batch is None: - return {} - decode_step_index_max = 0 - if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0: - decode_step_index_max = int(active_batch.step_indices.max().item()) - return { - "request_count": int(len(active_batch.request_ids)), - "request_ids": list(active_batch.request_ids), - "prefill_done": bool(active_batch.prefill_done), - "decode_step_index_max": int(decode_step_index_max), - } - - def snapshot_pending_queue_state(self) -> Dict[str, Any]: - with self.condition: - return { - "pending_jobs": int(len(self.pending_jobs)), - "pending_request_ids": [job.request_id for job in list(self.pending_jobs)[:32]], - } - - def enqueue_pending_job(self, job: SchedulerPendingJob) -> None: - with self.condition: - self.pending_jobs.append(job) - self.condition.notify_all() - self.refresh_state("engine_decode_pending_enqueue") - - def take_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - with self.condition: - if not self.pending_jobs: - return [] - if wait_for_batch: - oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time) - if (time.perf_counter() - oldest_enqueue_time) < self.get_micro_batch_wait_s(): - return [] - pending_jobs = list(self.pending_jobs) - self.pending_jobs.clear() - self.refresh_state("engine_decode_pending_dequeue") - return pending_jobs - - def pending_age_ms(self) -> float: - with self.condition: - if not self.pending_jobs: - return 0.0 - enqueue_time = float(self.pending_jobs[0].enqueue_time) - return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0) - - def has_pending_jobs(self) -> bool: - with self.condition: - return bool(self.pending_jobs) - - def get_active_batch(self) -> T2SActiveBatch | None: - return self.active_batch - - def set_active_batch(self, active_batch: T2SActiveBatch | None) -> None: - self.active_batch = active_batch - - def active_batch_summary(self) -> Dict[str, Any]: - return self.summarize_active_batch(self.active_batch) - - def refresh_state(self, last_event: str) -> None: - pending_state = self.snapshot_pending_queue_state() - active_batch_summary = self.active_batch_summary() - worker_decode_counters = self.get_decode_runtime_counters() - with self.state_lock: - self.state.pending_jobs = int(pending_state.get("pending_jobs", 0)) - self.state.pending_request_ids = list(pending_state.get("pending_request_ids", [])) - self.state.active_request_count = int(active_batch_summary.get("request_count", 0)) - self.state.active_request_ids = list(active_batch_summary.get("request_ids", []))[:32] - self.state.prefill_done = bool(active_batch_summary.get("prefill_done", False)) - self.state.decode_step_index_max = int(active_batch_summary.get("decode_step_index_max", 0)) - self.state.total_cycles = int(worker_decode_counters.get("total_cycles", 0)) - self.state.prefill_cycles = int(worker_decode_counters.get("prefill_cycles", 0)) - self.state.step_cycles = int(worker_decode_counters.get("step_cycles", 0)) - self.state.has_work = bool(pending_state.get("pending_jobs", 0) or active_batch_summary.get("request_count", 0)) - self.state.last_event = str(last_event) - self.state.updated_at = float(time.perf_counter()) - - def update_from_worker_snapshot(self, snapshot: Dict[str, Any]) -> None: - if not snapshot: - return - pending_state = self.snapshot_pending_queue_state() - with self.state_lock: - self.state.pending_jobs = int(pending_state.get("pending_jobs", 0)) - self.state.pending_request_ids = list(pending_state.get("pending_request_ids", [])) - self.state.active_request_count = int(snapshot.get("active_request_count", 0)) - self.state.active_request_ids = list(snapshot.get("active_request_ids", []))[:32] - self.state.prefill_done = bool(snapshot.get("prefill_done", False)) - self.state.decode_step_index_max = int(snapshot.get("decode_step_index_max", 0)) - self.state.total_cycles = int(snapshot.get("total_cycles", 0)) - self.state.prefill_cycles = int(snapshot.get("prefill_cycles", 0)) - self.state.step_cycles = int(snapshot.get("step_cycles", 0)) - self.state.has_work = bool( - pending_state.get("pending_jobs", 0) - or snapshot.get("active_request_count", 0) - or snapshot.get("has_work", False) - ) - self.state.last_event = str(snapshot.get("last_event", "unknown")) - self.state.updated_at = float(snapshot.get("updated_at", time.perf_counter())) - - def snapshot_state(self) -> Dict[str, Any]: - pending_state = self.snapshot_pending_queue_state() - active_batch_summary = self.active_batch_summary() - worker_decode_counters = self.get_decode_runtime_counters() - with self.state_lock: - return { - "pending_jobs": int(pending_state.get("pending_jobs", self.state.pending_jobs)), - "pending_request_ids": list(pending_state.get("pending_request_ids", self.state.pending_request_ids)), - "active_request_count": int(active_batch_summary.get("request_count", self.state.active_request_count)), - "active_request_ids": list(active_batch_summary.get("request_ids", self.state.active_request_ids)), - "prefill_done": bool(active_batch_summary.get("prefill_done", self.state.prefill_done)), - "decode_step_index_max": int( - active_batch_summary.get("decode_step_index_max", self.state.decode_step_index_max) - ), - "total_cycles": int(worker_decode_counters.get("total_cycles", 0)), - "prefill_cycles": int(worker_decode_counters.get("prefill_cycles", 0)), - "step_cycles": int(worker_decode_counters.get("step_cycles", 0)), - "has_work": bool( - pending_state.get("pending_jobs", 0) - or active_batch_summary.get("request_count", self.state.active_request_count) - or self.state.has_work - ), - "last_event": str(self.state.last_event), - "updated_at": float(self.state.updated_at), - } - -@dataclass -class SchedulerFinalizeTask: - request_id: str - item: T2SFinishedItem - enqueued_time: float - - -@dataclass -class EngineDispatchTask: - request_id: str - state: T2SRequestState - speed_factor: float - sample_steps: int - media_type: str - prepare_wall_ms: float - prepare_profile_total_ms: float - done_loop: asyncio.AbstractEventLoop | None - done_future: asyncio.Future | None - engine_request_id: str | None - timeout_sec: float | None - enqueue_time: float - worker_job: SchedulerPendingJob | None = None - engine_policy_wait_ms: float = 0.0 - engine_dispatch_wait_ms: float = 0.0 - engine_policy_snapshot: Dict[str, Any] | None = None - error: str | None = None - - -@dataclass -class EngineGpuPrepareTask: - request_id: str - cpu_stage: PreparedCpuStage - done_loop: asyncio.AbstractEventLoop | None - done_future: asyncio.Future | None - engine_request_id: str | None - enqueue_time: float - queue_wait_ms: float = 0.0 - error: str | None = None - - -@dataclass -class EngineFinalizeQueueState: - waiting_count: int - waiting_request_ids: List[str] - peak_waiting: int - total_submitted: int - total_completed: int - - -@dataclass -class EngineArbiterState: - total_ticks: int = 0 - total_idle_ticks: int = 0 - total_prepare_dispatches: int = 0 - total_decode_dispatches: int = 0 - total_decode_runtime_ticks: int = 0 - total_finalize_dispatches: int = 0 - decode_budget_remaining: int = 0 - last_stage: str = "idle" - last_reason: str = "init" - last_observed_at: float = 0.0 - last_policy_allowed: bool = True - - -@dataclass -class EngineDecodeRuntimeState: - pending_jobs: int = 0 - pending_request_ids: List[str] = field(default_factory=list) - active_request_count: int = 0 - active_request_ids: List[str] = field(default_factory=list) - prefill_done: bool = False - decode_step_index_max: int = 0 - total_cycles: int = 0 - prefill_cycles: int = 0 - step_cycles: int = 0 - has_work: bool = False - last_event: str = "init" - updated_at: float = 0.0 - - -@dataclass -class RuntimeStateCallbacks: - update: Callable[[str, str, Optional[Dict[str, Any]]], None] | None = None - complete: Callable[[str, Optional[Dict[str, Any]]], None] | None = None - fail: Callable[[str, str], None] | None = None - decode_runtime_update: Callable[[Dict[str, Any]], None] | None = None - - +from GPT_SoVITS.TTS_infer_pack.unified_engine_component_models import ( + DirectTTSExecution, + NormalizedEngineRequest, + RuntimeControlCallbacks, + SchedulerDebugExecution, + SchedulerSubmitExecution, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_component_policy import ( + EngineArbiterConfig, + EngineArbiterState, + EnginePolicyArbiterController, + EnginePolicyConfig, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_component_registry import ( + DefaultReferenceState, + EngineRequestRegistry, + EngineRequestState, + EngineStatus, + ModelRegistry, + ModelRegistryState, + ReferenceRegistry, + SchedulerJobRegistry, + SchedulerPendingJob, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_component_runtime import ( + EngineDecodeRuntimeOwner, + EngineDecodeRuntimeState, + EngineDispatchTask, + EngineFinalizeQueueState, + EngineGpuPrepareTask, + EngineTaskQueueOwner, + RuntimeStateCallbacks, + SchedulerFinalizeTask, +) + +__all__ = [ + "DefaultReferenceState", + "DirectTTSExecution", + "EngineArbiterConfig", + "EngineArbiterState", + "EngineDecodeRuntimeOwner", + "EngineDecodeRuntimeState", + "EngineDispatchTask", + "EngineFinalizeQueueState", + "EngineGpuPrepareTask", + "EnginePolicyArbiterController", + "EnginePolicyConfig", + "EngineRequestRegistry", + "EngineRequestState", + "EngineStatus", + "EngineTaskQueueOwner", + "ModelRegistry", + "ModelRegistryState", + "NormalizedEngineRequest", + "ReferenceRegistry", + "RuntimeControlCallbacks", + "RuntimeStateCallbacks", + "SchedulerDebugExecution", + "SchedulerFinalizeTask", + "SchedulerJobRegistry", + "SchedulerPendingJob", + "SchedulerSubmitExecution", +] diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py index 04d9090f..934ccf52 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py @@ -1,905 +1,25 @@ from __future__ import annotations -import asyncio import os import threading -import time -from collections import deque -from typing import Any, Callable, Deque, Dict, List, Optional - -import numpy as np -import torch +from typing import Callable, List from GPT_SoVITS.TTS_infer_pack.TTS import TTS -from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator, PreparedCpuStage -from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState, decode_one_step, merge_active_batches, run_prefill_active_batch, run_scheduler_continuous -from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, RuntimeStateCallbacks, SchedulerFinalizeTask, SchedulerJobRegistry, SchedulerPendingJob - - -class WorkerPrepareExecutor: - def __init__( - self, - tts: TTS, - on_state_change: Callable[[], None] | None = None, - ) -> None: - self.coordinator = PrepareCoordinator(tts) - self.on_state_change = on_state_change - - def _notify_state_change(self) -> None: - if self.on_state_change is None: - return - try: - self.on_state_change() - except Exception: - pass - - def snapshot(self) -> Dict[str, int]: - return dict(self.coordinator.snapshot()) - - def get_max_inflight(self) -> int: - return int(self.coordinator.snapshot().get("max_inflight", 0)) - - def is_idle(self) -> bool: - return int(self.coordinator.snapshot().get("inflight", 0)) <= 0 - - async def prepare_state_profiled_async( - self, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - ) -> tuple[T2SRequestState, float, float]: - try: - return await self.coordinator.prepare_state_profiled_async(spec, prepare_submit_at) - finally: - self._notify_state_change() - - async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: - results = await asyncio.gather( - *[self.prepare_state_profiled_async(spec, time.perf_counter()) for spec in specs] - ) - return [state for state, _, _ in results] - - async def prepare_cpu_stage_profiled_async( - self, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - ) -> PreparedCpuStage: - try: - return await self.coordinator.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) - finally: - self._notify_state_change() - - async def prepare_gpu_stage_profiled_async( - self, - cpu_stage: PreparedCpuStage, - ) -> tuple[T2SRequestState, float, float]: - try: - return await self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) - finally: - self._notify_state_change() - - -class WorkerFinalizeExecutor: - def __init__( - self, - tts: TTS, - on_state_change: Callable[[], None] | None = None, - external_submit: Callable[[List[SchedulerFinalizeTask]], None] | None = None, - ) -> None: - self.tts = tts - self.on_state_change = on_state_change - self.external_submit = external_submit - self.condition = threading.Condition() - self.pending_tasks: Deque[SchedulerFinalizeTask] = deque() - self.pending_peak = 0 - self.inflight = 0 - self.inflight_peak = 0 - self.worker_count = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1))) - self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower() - self.batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16))) - self.batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0) - - def _notify_state_change(self) -> None: - if self.on_state_change is None: - return - try: - self.on_state_change() - except Exception: - pass - - def get_worker_count(self) -> int: - return int(self.worker_count) - - def get_batch_policy(self) -> Dict[str, Any]: - return { - "finalize_mode": str(self.finalize_mode), - "finalize_batch_max_items": int(self.batch_max_items), - "finalize_batch_wait_s": float(self.batch_wait_s), - } - - def get_pending_count(self) -> int: - with self.condition: - return int(len(self.pending_tasks)) - - def snapshot(self) -> Dict[str, Any]: - with self.condition: - return { - "finalize_pending": int(len(self.pending_tasks)), - "finalize_pending_peak": int(self.pending_peak), - "finalize_inflight": int(self.inflight), - "finalize_inflight_peak": int(self.inflight_peak), - "finalize_workers": int(self.worker_count), - "finalize_mode": str(self.finalize_mode), - "finalize_batch_max_items": int(self.batch_max_items), - "finalize_batch_wait_ms": float(self.batch_wait_s * 1000.0), - } - - def is_idle(self) -> bool: - with self.condition: - return self.inflight <= 0 and not self.pending_tasks - - def enqueue_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None: - if not tasks: - return - if self.external_submit is not None: - self.external_submit(tasks) - self._notify_state_change() - return - with self.condition: - for task in tasks: - self.pending_tasks.append(task) - self.pending_peak = max(self.pending_peak, len(self.pending_tasks)) - self.condition.notify_all() - self._notify_state_change() - - def begin_execution(self, task_count: int) -> None: - if task_count <= 0: - return - with self.condition: - self.inflight += int(task_count) - self.inflight_peak = max(self.inflight_peak, self.inflight) - self.condition.notify_all() - self._notify_state_change() - - def end_execution(self, task_count: int) -> None: - with self.condition: - self.inflight = max(0, self.inflight - int(task_count)) - self.condition.notify_all() - self._notify_state_change() - - def take_task_batch_blocking(self) -> List[SchedulerFinalizeTask]: - with self.condition: - while not self.pending_tasks: - self.condition.wait() - selected_tasks = [self.pending_tasks.popleft()] - if self.finalize_mode == "sync" or self.tts.configs.use_vocoder: - self.inflight += len(selected_tasks) - self.inflight_peak = max(self.inflight_peak, self.inflight) - self._notify_state_change() - return selected_tasks - batch_deadline = time.perf_counter() + self.batch_wait_s - while len(selected_tasks) < self.batch_max_items: - if not self.pending_tasks: - remaining = batch_deadline - time.perf_counter() - if remaining <= 0: - break - self.condition.wait(timeout=remaining) - continue - first_task = selected_tasks[0] - matched_index = None - for index, task in enumerate(self.pending_tasks): - if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: - matched_index = index - break - if matched_index is not None: - selected_tasks.append(self.pending_tasks[matched_index]) - del self.pending_tasks[matched_index] - continue - remaining = batch_deadline - time.perf_counter() - if remaining <= 0: - break - self.condition.wait(timeout=remaining) - self.inflight += len(selected_tasks) - self.inflight_peak = max(self.inflight_peak, self.inflight) - self._notify_state_change() - return selected_tasks - - def _sync_device(self) -> None: - try: - device_str = str(self.tts.configs.device) - if device_str.startswith("cuda") and torch.cuda.is_available(): - torch.cuda.synchronize(self.tts.configs.device) - elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): - torch.mps.synchronize() - except Exception: - pass - - def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]: - audio_fragment = self.tts.synthesize_audio_request_local( - semantic_tokens=item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0), - phones=job.state.phones.detach().clone().unsqueeze(0), - prompt_semantic=job.state.prompt_semantic.detach().clone(), - prompt_phones=job.state.prompt_phones.detach().clone(), - refer_spec=( - job.state.refer_spec[0].detach().clone(), - None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), - ), - raw_audio=job.state.raw_audio.detach().clone(), - raw_sr=int(job.state.raw_sr), - speed=float(job.speed_factor), - sample_steps=int(job.sample_steps), - ) - output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] - return self.tts.audio_postprocess( - audio=[[audio_fragment]], - sr=int(output_sr), - batch_index_list=None, - speed_factor=float(job.speed_factor), - split_bucket=False, - fragment_interval=0.0, - super_sampling=False, - ) - - def _synthesize_finished_audio_batch( - self, - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], - ) -> List[tuple[int, np.ndarray]]: - semantic_tokens_list = [item.semantic_tokens.detach().clone() for _, item in jobs_and_items] - phones_list = [job.state.phones.detach().clone() for job, _ in jobs_and_items] - refer_specs = [] - speeds = [] - sample_steps_list = [] - for job, _ in jobs_and_items: - refer_specs.append( - ( - job.state.refer_spec[0].detach().clone(), - None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), - ) - ) - speeds.append(float(job.speed_factor)) - sample_steps_list.append(int(job.sample_steps)) - audio_fragments = self.tts.synthesize_audio_requests_local_batched( - semantic_tokens_list=semantic_tokens_list, - phones_list=phones_list, - refer_specs=refer_specs, - speeds=speeds, - sample_steps_list=sample_steps_list, - ) - output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] - results: List[tuple[int, np.ndarray]] = [] - for (job, _), audio_fragment in zip(jobs_and_items, audio_fragments): - results.append( - self.tts.audio_postprocess( - audio=[[audio_fragment]], - sr=int(output_sr), - batch_index_list=None, - speed_factor=float(job.speed_factor), - split_bucket=False, - fragment_interval=0.0, - super_sampling=False, - ) - ) - return results - - def synthesize_finalize_jobs( - self, - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], - ) -> tuple[float, List[tuple[int, np.ndarray]]]: - if not jobs_and_items: - return 0.0, [] - self._sync_device() - synth_start = time.perf_counter() - if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder: - job, item = jobs_and_items[0] - batch_results = [self._synthesize_finished_audio(job, item)] - else: - batch_results = self._synthesize_finished_audio_batch(jobs_and_items) - self._sync_device() - synth_ms = (time.perf_counter() - synth_start) * 1000.0 - return float(synth_ms), batch_results - - -class WorkerCompletionBridge: - def __init__(self, runtime_callbacks: RuntimeStateCallbacks | None = None) -> None: - self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() - - @staticmethod - def _resolve_done_future(job: SchedulerPendingJob) -> None: - future = job.done_future - if future is None or future.done(): - return - future.set_result(job) - - def notify_done_future(self, job: SchedulerPendingJob) -> None: - if job.done_loop is None or job.done_future is None: - return - try: - job.done_loop.call_soon_threadsafe(self._resolve_done_future, job) - except RuntimeError: - pass - - def runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: - if request_id is None or self.runtime_callbacks.complete is None: - return - self.runtime_callbacks.complete(request_id, extra) - - def runtime_fail(self, request_id: str | None, error: str) -> None: - if request_id is None or self.runtime_callbacks.fail is None: - return - self.runtime_callbacks.fail(request_id, error) - - @staticmethod - def build_completed_job_result( - job: SchedulerPendingJob, - item: T2SFinishedItem, - *, - sample_rate: int, - audio_data: np.ndarray, - finished_at: float | None = None, - ) -> Dict[str, Any]: - finished_at = float(time.perf_counter() if finished_at is None else finished_at) - queue_wait_ms = 0.0 - if job.first_schedule_time is not None: - queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) - worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) - worker_residual_ms = max( - 0.0, - worker_total_ms - - queue_wait_ms - - job.prefill_ms - - job.merge_ms - - job.decode_ms - - job.finalize_wait_ms - - job.synth_ms, - ) - worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms) - job.sample_rate = int(sample_rate) - job.audio_data = audio_data - job.result_ready_time = finished_at - prepare_profile = dict(job.state.prepare_profile) - result = { - "request_id": item.request_id, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - "finish_reason": item.finish_reason, - "decode_admission_wait_ms": float(job.admission_wait_ms), - "engine_policy_wait_ms": float(job.engine_policy_wait_ms), - "engine_dispatch_wait_ms": float(job.engine_dispatch_wait_ms), - "prepare_ms": job.prepare_wall_ms, - "prepare_wall_ms": job.prepare_wall_ms, - "prepare_profile_total_ms": job.prepare_profile_total_ms, - "prepare_profile": prepare_profile, - "queue_wait_ms": queue_wait_ms, - "prefill_ms": job.prefill_ms, - "merge_ms": job.merge_ms, - "decode_ms": job.decode_ms, - "finalize_wait_ms": job.finalize_wait_ms, - "synth_ms": job.synth_ms, - "worker_residual_ms": worker_residual_ms, - "worker_other_ms": worker_other_ms, - "worker_total_ms": worker_total_ms, - "decode_steps": int(job.decode_steps), - "sample_rate": int(sample_rate), - "media_type": job.media_type, - } - job.result = result - return result - - @staticmethod - def build_runtime_complete_payload( - job: SchedulerPendingJob, - item: T2SFinishedItem, - *, - sample_rate: int, - ) -> Dict[str, Any]: - return { - "finish_reason": item.finish_reason, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - "sample_rate": int(sample_rate), - "worker_profile": dict(job.result or {}), - } - - def complete_job( - self, - job: SchedulerPendingJob, - *, - runtime_request_id: str | None, - runtime_extra: Optional[Dict[str, Any]] = None, - remove_job: Callable[[], None] | None = None, - on_job_finished: Callable[[], None] | None = None, - notify_waiters: Callable[[], None] | None = None, - ) -> None: - job.done_event.set() - self.notify_done_future(job) - if remove_job is not None: - remove_job() - if on_job_finished is not None: - on_job_finished() - if notify_waiters is not None: - notify_waiters() - self.runtime_complete(runtime_request_id, runtime_extra) - - def fail_job( - self, - job: SchedulerPendingJob, - *, - error: str, - remove_job: Callable[[], None] | None = None, - on_job_finished: Callable[[], None] | None = None, - notify_waiters: Callable[[], None] | None = None, - ) -> None: - job.error = str(error) - job.done_event.set() - self.notify_done_future(job) - if remove_job is not None: - remove_job() - if on_job_finished is not None: - on_job_finished() - if notify_waiters is not None: - notify_waiters() - self.runtime_fail(job.engine_request_id, str(error)) - - def complete_finalize_task( - self, - *, - condition: threading.Condition, - job_registry: SchedulerJobRegistry, - job: SchedulerPendingJob, - item: T2SFinishedItem, - sample_rate: int, - audio_data: np.ndarray, - ) -> None: - runtime_extra: Optional[Dict[str, Any]] = None - with condition: - if job_registry.get(item.request_id) is not job: - return - self.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data) - runtime_extra = self.build_runtime_complete_payload(job, item, sample_rate=sample_rate) - self.complete_job( - job, - runtime_request_id=job.engine_request_id, - runtime_extra=runtime_extra, - on_job_finished=lambda: job_registry.mark_finished_and_remove(item.request_id), - notify_waiters=condition.notify_all, - ) - - def fail_jobs( - self, - *, - condition: threading.Condition, - job_registry: SchedulerJobRegistry, - request_ids: List[str], - error: str, - ) -> None: - if not request_ids: - return - with condition: - for request_id in request_ids: - job = job_registry.get(request_id) - if job is None: - continue - self.fail_job( - job, - error=error, - on_job_finished=lambda rid=request_id: job_registry.mark_finished_and_remove(rid), - ) - condition.notify_all() - - -class WorkerDecodeExecutor: - def __init__(self, tts: TTS, max_steps: int) -> None: - self.tts = tts - self.max_steps = int(max_steps) - - def _sync_device(self) -> None: - try: - device_str = str(self.tts.configs.device) - if device_str.startswith("cuda") and torch.cuda.is_available(): - torch.cuda.synchronize(self.tts.configs.device) - elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): - torch.mps.synchronize() - except Exception: - pass - - def execute_prefill_merge( - self, - *, - pending_jobs: List[SchedulerPendingJob], - active_batch: Optional[T2SActiveBatch], - mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], - add_prefill_time: Callable[[List[str], float], None] | None, - add_merge_time: Callable[[List[str], float], None] | None, - enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, - finalize_error: Callable[[List[str], str], None] | None, - ) -> Dict[str, Any]: - if not pending_jobs: - return { - "executed": False, - "active_batch": active_batch, - "pending_jobs": [], - "prefill_elapsed_s": 0.0, - "merge_elapsed_s": 0.0, - "finished_items": [], - "error": None, - "error_request_ids": [], - } - admitted_finished: List[T2SFinishedItem] = [] - prefill_elapsed_s = 0.0 - merge_elapsed_s = 0.0 - error: str | None = None - error_request_ids: List[str] = [] - try: - self._sync_device() - prefill_start = time.perf_counter() - mark_prefill_started(pending_jobs, prefill_start) - admitted_active_batch, admitted_finished = run_prefill_active_batch( - self.tts.t2s_model.model, - [job.state for job in pending_jobs], - max_steps=self.max_steps, - ) - self._sync_device() - prefill_elapsed_s = time.perf_counter() - prefill_start - if add_prefill_time is not None: - add_prefill_time([job.request_id for job in pending_jobs], prefill_elapsed_s) - if enqueue_finished is not None: - enqueue_finished(admitted_finished) - merge_start = time.perf_counter() - active_batch = merge_active_batches( - self.tts.t2s_model.model, - active_batch, - admitted_active_batch, - ) - merge_elapsed_s = time.perf_counter() - merge_start - if add_merge_time is not None: - add_merge_time( - [] if active_batch is None else list(active_batch.request_ids), - merge_elapsed_s, - ) - except Exception as exc: - error = str(exc) - error_request_ids = [job.request_id for job in pending_jobs] - if finalize_error is not None: - finalize_error(error_request_ids, error) - return { - "executed": True, - "active_batch": active_batch, - "pending_jobs": list(pending_jobs), - "prefill_elapsed_s": float(prefill_elapsed_s), - "merge_elapsed_s": float(merge_elapsed_s), - "finished_items": list(admitted_finished), - "error": error, - "error_request_ids": error_request_ids, - } - - def execute_decode_step( - self, - *, - active_batch: Optional[T2SActiveBatch], - add_decode_time: Callable[[List[str], float], None] | None, - enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, - finalize_error: Callable[[List[str], str], None] | None, - ) -> Dict[str, Any]: - if active_batch is None: - return { - "executed": False, - "active_batch": None, - "request_ids": [], - "decode_elapsed_s": 0.0, - "finished_items": [], - "error": None, - "error_request_ids": [], - } - active_request_ids: List[str] = [] - step_finished: List[T2SFinishedItem] = [] - decode_elapsed_s = 0.0 - error: str | None = None - error_request_ids: List[str] = [] - try: - active_request_ids = [state.request_id for state in active_batch.states] - self._sync_device() - decode_start = time.perf_counter() - active_batch, step_finished = decode_one_step( - self.tts.t2s_model.model, - active_batch, - max_steps=self.max_steps, - ) - self._sync_device() - decode_elapsed_s = time.perf_counter() - decode_start - if add_decode_time is not None: - add_decode_time(active_request_ids, decode_elapsed_s) - if enqueue_finished is not None: - enqueue_finished(step_finished) - except Exception as exc: - error = str(exc) - error_request_ids = list(active_request_ids) - if finalize_error is not None: - finalize_error(error_request_ids, error) - active_batch = None - return { - "executed": True, - "active_batch": active_batch, - "request_ids": active_request_ids, - "decode_elapsed_s": float(decode_elapsed_s), - "finished_items": list(step_finished), - "error": error, - "error_request_ids": error_request_ids, - } - - def execute_decode_cycle( - self, - *, - pending_jobs: List[SchedulerPendingJob], - active_batch: Optional[T2SActiveBatch], - mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], - add_prefill_time: Callable[[List[str], float], None] | None, - add_merge_time: Callable[[List[str], float], None] | None, - add_decode_time: Callable[[List[str], float], None] | None, - enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, - finalize_error: Callable[[List[str], str], None] | None, - ) -> Dict[str, Any]: - result = { - "executed": False, - "prefill_merge_executed": False, - "decode_step_executed": False, - "active_batch": active_batch, - "prefill_phase": {}, - "decode_phase": {}, - } - prefill_phase = self.execute_prefill_merge( - pending_jobs=list(pending_jobs), - active_batch=result["active_batch"], - mark_prefill_started=mark_prefill_started, - add_prefill_time=add_prefill_time, - add_merge_time=add_merge_time, - enqueue_finished=enqueue_finished, - finalize_error=finalize_error, - ) - prefill_executed = bool(prefill_phase.get("executed", False)) - result["prefill_phase"] = prefill_phase - result["active_batch"] = prefill_phase.get("active_batch") - if prefill_executed: - result["executed"] = True - result["prefill_merge_executed"] = True - decode_phase = self.execute_decode_step( - active_batch=result["active_batch"], - add_decode_time=add_decode_time, - enqueue_finished=enqueue_finished, - finalize_error=finalize_error, - ) - decode_executed = bool(decode_phase.get("executed", False)) - result["decode_phase"] = decode_phase - result["active_batch"] = decode_phase.get("active_batch") - if decode_executed: - result["executed"] = True - result["decode_step_executed"] = True - return result - - -class WorkerDecodeLegacyShell: - def __init__(self, condition: threading.Condition, micro_batch_wait_s: float) -> None: - self.condition = condition - self.micro_batch_wait_s = float(micro_batch_wait_s) - self.pending_jobs: List[SchedulerPendingJob] = [] - self.active_batch: T2SActiveBatch | None = None - - @staticmethod - def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any] | None: - if active_batch is None: - return None - return { - "request_count": int(len(active_batch.request_ids)), - "request_ids": list(active_batch.request_ids), - "prefill_done": bool(active_batch.prefill_done), - "decode_step_index_max": ( - int(active_batch.step_indices.max().item()) - if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0 - else 0 - ), - } - - def current_backlog_locked(self) -> int: - running_requests = 0 if self.active_batch is None else len(self.active_batch.request_ids) - return int(len(self.pending_jobs) + running_requests) - - def enqueue_pending_job_locked(self, job: SchedulerPendingJob) -> None: - self.pending_jobs.append(job) - - def snapshot_locked(self) -> Dict[str, Any]: - active_batch_summary = self._summarize_active_batch(self.active_batch) - executor_local_pending_jobs = int(len(self.pending_jobs)) - executor_local_running_requests = 0 if self.active_batch is None else int(len(self.active_batch.request_ids)) - executor_local_has_work = bool(self.pending_jobs or self.active_batch is not None) - return { - "executor_local_pending_jobs": executor_local_pending_jobs, - "executor_local_running_requests": executor_local_running_requests, - "executor_local_has_work": executor_local_has_work, - "executor_local_active_batch": active_batch_summary, - } - - def is_idle_locked(self) -> bool: - return self.active_batch is None and not self.pending_jobs - - def take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - with self.condition: - if not self.pending_jobs and self.active_batch is None: - self.condition.wait(timeout=self.micro_batch_wait_s) - elif wait_for_batch and self.pending_jobs: - self.condition.wait(timeout=self.micro_batch_wait_s) - if not self.pending_jobs: - return [] - pending = list(self.pending_jobs) - self.pending_jobs.clear() - return pending - - def take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - with self.condition: - if not self.pending_jobs: - return [] - if wait_for_batch: - oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time) - if (time.perf_counter() - oldest_enqueue_time) < self.micro_batch_wait_s: - return [] - pending = list(self.pending_jobs) - self.pending_jobs.clear() - return pending - - def has_decode_runtime_work(self) -> bool: - with self.condition: - return bool(self.pending_jobs or self.active_batch is not None) - - def build_runtime_summary_locked(self, *, total_cycles: int, prefill_cycles: int, step_cycles: int, last_event: str) -> Dict[str, Any]: - active_request_ids = [] if self.active_batch is None else list(self.active_batch.request_ids) - decode_step_index_max = 0 - prefill_done = False - if self.active_batch is not None: - prefill_done = bool(self.active_batch.prefill_done) - if self.active_batch.step_indices is not None and self.active_batch.step_indices.numel() > 0: - decode_step_index_max = int(self.active_batch.step_indices.max().item()) - return { - "pending_jobs": int(len(self.pending_jobs)), - "active_request_count": int(len(active_request_ids)), - "active_request_ids": active_request_ids[:32], - "prefill_done": bool(prefill_done), - "decode_step_index_max": int(decode_step_index_max), - "total_cycles": int(total_cycles), - "prefill_cycles": int(prefill_cycles), - "step_cycles": int(step_cycles), - "has_work": bool(self.pending_jobs or self.active_batch is not None), - "last_event": str(last_event), - "updated_at": float(time.perf_counter()), - } - - def run_prefill_merge_once_nonblocking( - self, - *, - external_pending_jobs: Optional[List[SchedulerPendingJob]], - external_active_batch: Optional[T2SActiveBatch], - execute_prefill_merge: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], - ) -> Dict[str, Any]: - pending_jobs = ( - list(external_pending_jobs) - if external_pending_jobs is not None - else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) - ) - active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch - result = execute_prefill_merge(pending_jobs, active_batch) - if external_pending_jobs is None: - with self.condition: - self.active_batch = result.get("active_batch") - self.condition.notify_all() - return result - - def run_decode_step_once_nonblocking( - self, - *, - external_active_batch: Optional[T2SActiveBatch], - execute_decode_step: Callable[[Optional[T2SActiveBatch]], Dict[str, Any]], - ) -> Dict[str, Any]: - active_batch = self.active_batch if external_active_batch is None else external_active_batch - result = execute_decode_step(active_batch) - if external_active_batch is None: - with self.condition: - self.active_batch = result.get("active_batch") - self.condition.notify_all() - return result - - def run_decode_cycle_nonblocking( - self, - *, - external_pending_jobs: Optional[List[SchedulerPendingJob]], - external_active_batch: Optional[T2SActiveBatch], - execute_decode_cycle: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], - on_cycle_executed: Callable[[Dict[str, Any]], None] | None, - ) -> Dict[str, Any]: - pending_jobs = ( - list(external_pending_jobs) - if external_pending_jobs is not None - else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) - ) - active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch - result = execute_decode_cycle(pending_jobs, active_batch) - if external_pending_jobs is None: - with self.condition: - self.active_batch = result.get("active_batch") - self.condition.notify_all() - if result.get("executed") and on_cycle_executed is not None: - on_cycle_executed(result) - return result - - def run_loop( - self, - *, - run_decode_cycle_nonblocking: Callable[[], Dict[str, Any]], - ) -> None: - while True: - executed = run_decode_cycle_nonblocking() - if executed.get("executed"): - continue - wait_for_batch = self.active_batch is None - pending_jobs = self.take_pending_snapshot(wait_for_batch=wait_for_batch) - if pending_jobs: - with self.condition: - self.pending_jobs = pending_jobs + self.pending_jobs - self.condition.notify_all() - continue - time.sleep(self.micro_batch_wait_s) - - -class WorkerDecodeRuntimeTracker: - def __init__( - self, - runtime_callbacks: RuntimeStateCallbacks | None = None, - ) -> None: - self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() - self.total_cycles = 0 - self.prefill_cycles = 0 - self.step_cycles = 0 - - def get_counters(self) -> Dict[str, int]: - return { - "total_cycles": int(self.total_cycles), - "prefill_cycles": int(self.prefill_cycles), - "step_cycles": int(self.step_cycles), - } - - def record_cycle(self, result: Dict[str, Any]) -> None: - if not bool(result.get("executed")): - return - self.total_cycles += 1 - if bool(result.get("prefill_merge_executed")): - self.prefill_cycles += 1 - if bool(result.get("decode_step_executed")): - self.step_cycles += 1 - - def build_runtime_summary_locked( - self, - *, - legacy_shell: WorkerDecodeLegacyShell, - last_event: str, - ) -> Dict[str, Any]: - return legacy_shell.build_runtime_summary_locked( - total_cycles=int(self.total_cycles), - prefill_cycles=int(self.prefill_cycles), - step_cycles=int(self.step_cycles), - last_event=str(last_event), - ) - - def notify_runtime_update_locked( - self, - *, - legacy_shell: WorkerDecodeLegacyShell, - last_event: str, - ) -> None: - if self.runtime_callbacks.decode_runtime_update is None: - return - snapshot = self.build_runtime_summary_locked( - legacy_shell=legacy_shell, - last_event=last_event, - ) - self.runtime_callbacks.decode_runtime_update(snapshot) - - -class UnifiedSchedulerWorker: +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeStateCallbacks, SchedulerFinalizeTask, SchedulerJobRegistry +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_completion import WorkerCompletionBridge +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_decode import WorkerDecodeExecutor, WorkerDecodeLegacyShell, WorkerDecodeRuntimeTracker +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_execution import WorkerExecutionMixin +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_finalize import WorkerFinalizeExecutor +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_prepare import WorkerPrepareExecutor +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_runtime import WorkerRuntimeBookkeepingMixin +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_submit import WorkerSubmitLifecycleMixin + + +class UnifiedSchedulerWorker( + WorkerSubmitLifecycleMixin, + WorkerRuntimeBookkeepingMixin, + WorkerExecutionMixin, +): def __init__( self, tts: TTS, @@ -949,562 +69,3 @@ class UnifiedSchedulerWorker: def _notify_worker_state_change(self) -> None: with self.condition: self.condition.notify_all() - - def _current_decode_backlog_locked(self) -> int: - return self.decode_legacy_shell.current_backlog_locked() - - def get_micro_batch_wait_s(self) -> float: - return float(self.micro_batch_wait_s) - - def is_engine_decode_control_enabled(self) -> bool: - return bool(self.engine_decode_control_enabled) - - def get_prepare_max_inflight(self) -> int: - return int(self.prepare_executor.get_max_inflight()) - - def get_capacity_limits(self) -> Dict[str, int]: - return { - "decode_backlog_max": int(self.decode_backlog_max), - "finalize_pending_max": int(self.finalize_pending_max), - } - - def get_finalize_batch_policy(self) -> Dict[str, Any]: - return dict(self.finalize_executor.get_batch_policy()) - - def get_decode_runtime_counters(self) -> Dict[str, int]: - with self.condition: - return self.decode_runtime_tracker.get_counters() - - def _can_accept_submit_locked(self) -> tuple[bool, Dict[str, int]]: - decode_backlog = self._current_decode_backlog_locked() - finalize_pending = int(self.finalize_executor.get_pending_count()) - prepare_inflight = int(self.prepare_executor.snapshot()["inflight"]) - blocked_decode = self.decode_backlog_max > 0 and decode_backlog >= self.decode_backlog_max - blocked_finalize = self.finalize_pending_max > 0 and finalize_pending >= self.finalize_pending_max - return ( - not blocked_decode and not blocked_finalize, - { - "decode_backlog": decode_backlog, - "finalize_pending": finalize_pending, - "prepare_inflight": prepare_inflight, - "decode_backlog_max": int(self.decode_backlog_max), - "finalize_pending_max": int(self.finalize_pending_max), - }, - ) - - def wait_for_submit_capacity_blocking(self, timeout_sec: float | None = None) -> tuple[float, Dict[str, int]]: - start = time.perf_counter() - deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) - last_snapshot: Dict[str, int] = {} - while True: - with self.condition: - allowed, snapshot = self._can_accept_submit_locked() - last_snapshot = snapshot - if allowed: - return max(0.0, (time.perf_counter() - start) * 1000.0), snapshot - if deadline is not None and time.perf_counter() >= deadline: - raise TimeoutError( - "scheduler submit admission timeout " - f"(decode_backlog={snapshot['decode_backlog']}, finalize_pending={snapshot['finalize_pending']})" - ) - self.condition.wait(timeout=self.micro_batch_wait_s) - - def _admission_snapshot_locked(self) -> Dict[str, int]: - _, snapshot = self._can_accept_submit_locked() - return snapshot - - async def submit_async( - self, - state: T2SRequestState, - speed_factor: float, - sample_steps: int, - media_type: str, - prepare_wall_ms: float, - prepare_profile_total_ms: float, - done_loop: asyncio.AbstractEventLoop | None = None, - done_future: asyncio.Future | None = None, - engine_request_id: str | None = None, - timeout_sec: float | None = None, - skip_capacity_wait: bool = False, - admission_wait_ms_override: float | None = None, - admission_snapshot_override: Dict[str, Any] | None = None, - engine_policy_wait_ms: float = 0.0, - engine_dispatch_wait_ms: float = 0.0, - enqueue_pending: bool = True, - ) -> SchedulerPendingJob: - return await asyncio.to_thread( - self.submit, - state, - speed_factor, - sample_steps, - media_type, - prepare_wall_ms, - prepare_profile_total_ms, - done_loop, - done_future, - engine_request_id, - timeout_sec, - skip_capacity_wait, - admission_wait_ms_override, - admission_snapshot_override, - engine_policy_wait_ms, - engine_dispatch_wait_ms, - enqueue_pending, - ) - - def snapshot(self) -> dict: - with self.condition: - prepare_state = self.prepare_executor.snapshot() - finalize_state = self.finalize_executor.snapshot() - shell_state = self.decode_legacy_shell.snapshot_locked() - decode_runtime_counters = self.decode_runtime_tracker.get_counters() - engine_owned_decode_state = bool(self.engine_decode_control_enabled) - active_batch_summary = shell_state.get("executor_local_active_batch") - executor_local_pending_jobs = int(shell_state.get("executor_local_pending_jobs", 0)) - executor_local_running_requests = int(shell_state.get("executor_local_running_requests", 0)) - executor_local_has_work = bool(shell_state.get("executor_local_has_work", False)) - return { - "pending_jobs": 0 if engine_owned_decode_state else executor_local_pending_jobs, - "running_requests": 0 if engine_owned_decode_state else executor_local_running_requests, - "engine_decode_control_enabled": bool(self.engine_decode_control_enabled), - "legacy_state_owner_mode": not engine_owned_decode_state, - "decode_state_owner": "engine" if engine_owned_decode_state else "worker", - "decode_runtime_has_work": False if engine_owned_decode_state else executor_local_has_work, - "executor_local_pending_jobs": executor_local_pending_jobs, - "executor_local_running_requests": executor_local_running_requests, - "executor_local_has_work": executor_local_has_work, - "decode_runtime_total_cycles": int(decode_runtime_counters.get("total_cycles", 0)), - "decode_runtime_prefill_cycles": int(decode_runtime_counters.get("prefill_cycles", 0)), - "decode_runtime_step_cycles": int(decode_runtime_counters.get("step_cycles", 0)), - "prepare_inflight": prepare_state["inflight"], - "prepare_peak_inflight": prepare_state["peak_inflight"], - "prepare_max_inflight": prepare_state.get("max_inflight", 0), - "prepare_state": dict(prepare_state), - **finalize_state, - "decode_backlog_max": self.decode_backlog_max, - "finalize_pending_max": self.finalize_pending_max, - "active_batch": {} if engine_owned_decode_state else active_batch_summary, - "executor_local_active_batch": active_batch_summary if engine_owned_decode_state else None, - "total_submitted": self.job_registry.submitted_count(), - "total_finished": self.job_registry.finished_count(), - "drained": self.is_drained(), - } - - def is_drained(self) -> bool: - with self.condition: - return ( - self.decode_legacy_shell.is_idle_locked() - and self.job_registry.is_empty() - and self.prepare_executor.is_idle() - and self.finalize_executor.is_idle() - ) - - def wait_until_idle(self, timeout_sec: float = 60.0, poll_interval_sec: float = 0.01) -> bool: - deadline = time.perf_counter() + max(0.0, timeout_sec) - while time.perf_counter() < deadline: - if self.is_drained(): - return True - time.sleep(poll_interval_sec) - return self.is_drained() - - def submit( - self, - state: T2SRequestState, - speed_factor: float, - sample_steps: int, - media_type: str, - prepare_wall_ms: float, - prepare_profile_total_ms: float, - done_loop: asyncio.AbstractEventLoop | None = None, - done_future: asyncio.Future | None = None, - engine_request_id: str | None = None, - timeout_sec: float | None = None, - skip_capacity_wait: bool = False, - admission_wait_ms_override: float | None = None, - admission_snapshot_override: Dict[str, Any] | None = None, - engine_policy_wait_ms: float = 0.0, - engine_dispatch_wait_ms: float = 0.0, - enqueue_pending: bool = True, - ) -> SchedulerPendingJob: - if skip_capacity_wait: - with self.condition: - admission_snapshot = ( - dict(admission_snapshot_override) - if admission_snapshot_override is not None - else dict(self._admission_snapshot_locked()) - ) - admission_wait_ms = 0.0 if admission_wait_ms_override is None else float(admission_wait_ms_override) - else: - admission_wait_ms, admission_snapshot = self.wait_for_submit_capacity_blocking(timeout_sec=timeout_sec) - job = SchedulerPendingJob( - request_id=state.request_id, - state=state, - done_event=threading.Event(), - done_loop=done_loop, - done_future=done_future, - enqueue_time=time.perf_counter(), - speed_factor=float(speed_factor), - sample_steps=int(sample_steps), - media_type=media_type, - admission_wait_ms=float(admission_wait_ms), - engine_policy_wait_ms=float(engine_policy_wait_ms), - engine_dispatch_wait_ms=float(engine_dispatch_wait_ms), - prepare_wall_ms=float(prepare_wall_ms), - prepare_profile_total_ms=float(prepare_profile_total_ms), - engine_request_id=engine_request_id or state.request_id, - ) - with self.condition: - self.job_registry.register(job, keep_job=not self.engine_decode_control_enabled) - if enqueue_pending: - self.decode_legacy_shell.enqueue_pending_job_locked(job) - self.condition.notify_all() - if enqueue_pending: - self._notify_decode_runtime_state("submit") - self._runtime_update( - job.engine_request_id, - EngineStatus.QUEUED, - { - "scheduler_request_id": job.request_id, - "decode_admission_wait_ms": float(admission_wait_ms), - "engine_policy_wait_ms": float(engine_policy_wait_ms), - "engine_dispatch_wait_ms": float(engine_dispatch_wait_ms), - "admission_snapshot": dict(admission_snapshot), - }, - ) - return job - - async def prepare_state_profiled_async( - self, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - ) -> tuple[T2SRequestState, float, float]: - return await self.prepare_executor.prepare_state_profiled_async(spec, prepare_submit_at) - - async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: - return await self.prepare_executor.prepare_states_batch_async(specs) - - async def prepare_cpu_stage_profiled_async( - self, - spec: SchedulerRequestSpec, - prepare_submit_at: float, - ) -> PreparedCpuStage: - return await self.prepare_executor.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) - - async def prepare_gpu_stage_profiled_async( - self, - cpu_stage: PreparedCpuStage, - ) -> tuple[T2SRequestState, float, float]: - return await self.prepare_executor.prepare_gpu_stage_profiled_async(cpu_stage) - - def _mark_prefill_started(self, pending_jobs: List[SchedulerPendingJob], started_at: float) -> None: - with self.condition: - for job in pending_jobs: - job.first_schedule_time = float(started_at) - self._runtime_update( - job.engine_request_id, - EngineStatus.GPU_PREPARING, - {"scheduler_request_id": job.request_id, "prefill_started_at": float(started_at)}, - ) - - def _add_prefill_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - if not request_ids: - return - with self.condition: - for request_id in request_ids: - job = self.job_registry.get(request_id) - if job is not None: - job.prefill_ms += delta_ms - - def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - if not request_ids: - return - with self.condition: - for request_id in request_ids: - job = self.job_registry.get(request_id) - if job is not None: - job.merge_ms += delta_ms - - def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: - delta_ms = float(elapsed_s) * 1000.0 - if not request_ids: - return - activate_request_ids: List[str] = [] - with self.condition: - for request_id in request_ids: - job = self.job_registry.get(request_id) - if job is not None: - if job.decode_steps == 0: - activate_request_ids.append(job.engine_request_id) - job.decode_ms += delta_ms - job.decode_steps += 1 - for engine_request_id in activate_request_ids: - self._runtime_update(engine_request_id, EngineStatus.ACTIVE_DECODE, None) - - def _add_finalize_wait_ms(self, request_ids: List[str], delta_ms: float) -> None: - if not request_ids: - return - with self.condition: - for request_id in request_ids: - job = self.job_registry.get(request_id) - if job is not None: - job.finalize_wait_ms += float(delta_ms) - - def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None: - if not items: - return - enqueued_at = time.perf_counter() - tasks: List[SchedulerFinalizeTask] = [] - with self.condition: - for item in items: - job = self.job_registry.get(item.request_id) - if job is not None: - self._runtime_update( - job.engine_request_id, - EngineStatus.READY_FOR_FINALIZE, - { - "finish_reason": item.finish_reason, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - }, - ) - tasks.append(SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at)) - self.finalize_executor.enqueue_tasks(tasks) - - def begin_finalize_execution(self, task_count: int) -> None: - self.finalize_executor.begin_execution(task_count) - - def end_finalize_execution(self, task_count: int) -> None: - self.finalize_executor.end_execution(task_count) - - def record_external_job_done(self, request_id: str) -> None: - with self.condition: - self.job_registry.mark_finished_and_remove(request_id) - self.condition.notify_all() - - def synthesize_finalize_jobs( - self, - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], - ) -> tuple[float, List[tuple[int, np.ndarray]]]: - return self.finalize_executor.synthesize_finalize_jobs(jobs_and_items) - - def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None: - self.completion_bridge.complete_finalize_task( - condition=self.condition, - job_registry=self.job_registry, - job=job, - item=item, - sample_rate=sample_rate, - audio_data=audio_data, - ) - - def _finalize_error(self, request_ids: List[str], error: str) -> None: - self.completion_bridge.fail_jobs( - condition=self.condition, - job_registry=self.job_registry, - request_ids=request_ids, - error=error, - ) - - @staticmethod - def _resolve_done_future(job: SchedulerPendingJob) -> None: - future = job.done_future - if future is None or future.done(): - return - future.set_result(job) - - def _notify_done_future(self, job: SchedulerPendingJob) -> None: - self.completion_bridge.notify_done_future(job) - - def _runtime_update(self, request_id: str | None, status: str, extra: Optional[Dict[str, Any]] = None) -> None: - if request_id is None or self.runtime_callbacks.update is None: - return - self.runtime_callbacks.update(request_id, status, extra) - - def _runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: - self.completion_bridge.runtime_complete(request_id, extra) - - def _runtime_fail(self, request_id: str | None, error: str) -> None: - self.completion_bridge.runtime_fail(request_id, error) - - def _build_decode_runtime_summary_locked(self, last_event: str) -> Dict[str, Any]: - return self.decode_runtime_tracker.build_runtime_summary_locked( - legacy_shell=self.decode_legacy_shell, - last_event=str(last_event), - ) - - def _notify_decode_runtime_state(self, last_event: str) -> None: - with self.condition: - self.decode_runtime_tracker.notify_runtime_update_locked( - legacy_shell=self.decode_legacy_shell, - last_event=str(last_event), - ) - - def _record_decode_runtime_cycle(self, result: Dict[str, Any]) -> None: - with self.condition: - self.decode_runtime_tracker.record_cycle(result) - - def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - return self.decode_legacy_shell.take_pending_snapshot(wait_for_batch) - - def _take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: - return self.decode_legacy_shell.take_pending_snapshot_nonblocking(wait_for_batch) - - def has_decode_runtime_work(self) -> bool: - return self.decode_legacy_shell.has_decode_runtime_work() - - def execute_prefill_merge( - self, - pending_jobs: List[SchedulerPendingJob], - active_batch: Optional[T2SActiveBatch], - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - return self.decode_executor.execute_prefill_merge( - pending_jobs=pending_jobs, - active_batch=active_batch, - mark_prefill_started=self._mark_prefill_started, - add_prefill_time=None if external_bookkeeping else self._add_prefill_time, - add_merge_time=None if external_bookkeeping else self._add_merge_time, - enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, - finalize_error=None if external_bookkeeping else self._finalize_error, - ) - - def execute_decode_step( - self, - active_batch: Optional[T2SActiveBatch], - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - return self.decode_executor.execute_decode_step( - active_batch=active_batch, - add_decode_time=None if external_bookkeeping else self._add_decode_time, - enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, - finalize_error=None if external_bookkeeping else self._finalize_error, - ) - - def execute_decode_cycle( - self, - pending_jobs: List[SchedulerPendingJob], - active_batch: Optional[T2SActiveBatch], - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - result = self.decode_executor.execute_decode_cycle( - pending_jobs=pending_jobs, - active_batch=active_batch, - mark_prefill_started=self._mark_prefill_started, - add_prefill_time=None if external_bookkeeping else self._add_prefill_time, - add_merge_time=None if external_bookkeeping else self._add_merge_time, - add_decode_time=None if external_bookkeeping else self._add_decode_time, - enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, - finalize_error=None if external_bookkeeping else self._finalize_error, - ) - self._record_decode_runtime_cycle(result) - return result - - def run_prefill_merge_once_nonblocking( - self, - external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, - external_active_batch: Optional[T2SActiveBatch] = None, - emit_runtime_state: bool = True, - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - result = self.decode_legacy_shell.run_prefill_merge_once_nonblocking( - external_pending_jobs=external_pending_jobs, - external_active_batch=external_active_batch, - execute_prefill_merge=lambda batch_jobs, batch_state: self.execute_prefill_merge( - pending_jobs=batch_jobs, - active_batch=batch_state, - external_bookkeeping=external_bookkeeping, - ), - ) - if emit_runtime_state: - self._notify_decode_runtime_state("prefill_merge") - return result - - def run_decode_step_once_nonblocking( - self, - external_active_batch: Optional[T2SActiveBatch] = None, - emit_runtime_state: bool = True, - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - result = self.decode_legacy_shell.run_decode_step_once_nonblocking( - external_active_batch=external_active_batch, - execute_decode_step=lambda batch_state: self.execute_decode_step( - active_batch=batch_state, - external_bookkeeping=external_bookkeeping, - ), - ) - if emit_runtime_state: - self._notify_decode_runtime_state("decode_step") - return result - - def run_decode_cycle_nonblocking( - self, - external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, - external_active_batch: Optional[T2SActiveBatch] = None, - emit_runtime_state: bool = True, - external_bookkeeping: bool = False, - ) -> Dict[str, Any]: - result = self.decode_legacy_shell.run_decode_cycle_nonblocking( - external_pending_jobs=external_pending_jobs, - external_active_batch=external_active_batch, - execute_decode_cycle=lambda batch_jobs, batch_state: self.execute_decode_cycle( - pending_jobs=batch_jobs, - active_batch=batch_state, - external_bookkeeping=external_bookkeeping, - ), - on_cycle_executed=None, - ) - if result.get("executed") and emit_runtime_state: - self._notify_decode_runtime_state("decode_cycle") - return result - - def execute_finalize_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None: - if not tasks: - return - try: - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] - with self.condition: - for task in tasks: - job = self.job_registry.get(task.request_id) - if job is None: - continue - jobs_and_items.append((job, task.item)) - if not jobs_and_items: - return - now = time.perf_counter() - for task in tasks: - self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0)) - for job, item in jobs_and_items: - self._runtime_update( - job.engine_request_id, - EngineStatus.FINALIZING, - { - "finish_reason": item.finish_reason, - "semantic_len": int(item.semantic_tokens.shape[0]), - }, - ) - synth_ms, batch_results = self.synthesize_finalize_jobs(jobs_and_items) - with self.condition: - for job, _ in jobs_and_items: - tracked_job = self.job_registry.get(job.request_id) - if tracked_job is not None: - tracked_job.synth_ms += synth_ms - for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): - self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data) - except Exception as exc: - self._finalize_error([task.request_id for task in tasks], str(exc)) - finally: - self.finalize_executor.end_execution(len(tasks)) - - def _run_finalize_loop(self) -> None: - while True: - tasks = self.finalize_executor.take_task_batch_blocking() - self.execute_finalize_tasks(tasks) - - def _run_loop(self) -> None: - self.decode_legacy_shell.run_loop( - run_decode_cycle_nonblocking=lambda: self.run_decode_cycle_nonblocking() - ) - - diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_completion.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_completion.py new file mode 100644 index 00000000..da2c057a --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_completion.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +import threading +import time +from typing import Any, Callable, Dict, List, Optional + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeStateCallbacks, SchedulerJobRegistry, SchedulerPendingJob + + +class WorkerCompletionBridge: + def __init__(self, runtime_callbacks: RuntimeStateCallbacks | None = None) -> None: + self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() + + @staticmethod + def _resolve_done_future(job: SchedulerPendingJob) -> None: + future = job.done_future + if future is None or future.done(): + return + future.set_result(job) + + def notify_done_future(self, job: SchedulerPendingJob) -> None: + if job.done_loop is None or job.done_future is None: + return + try: + job.done_loop.call_soon_threadsafe(self._resolve_done_future, job) + except RuntimeError: + pass + + def runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: + if request_id is None or self.runtime_callbacks.complete is None: + return + self.runtime_callbacks.complete(request_id, extra) + + def runtime_fail(self, request_id: str | None, error: str) -> None: + if request_id is None or self.runtime_callbacks.fail is None: + return + self.runtime_callbacks.fail(request_id, error) + + @staticmethod + def build_completed_job_result( + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + audio_data: np.ndarray, + finished_at: float | None = None, + ) -> Dict[str, Any]: + finished_at = float(time.perf_counter() if finished_at is None else finished_at) + queue_wait_ms = 0.0 + if job.first_schedule_time is not None: + queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) + worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) + worker_residual_ms = max( + 0.0, + worker_total_ms + - queue_wait_ms + - job.prefill_ms + - job.merge_ms + - job.decode_ms + - job.finalize_wait_ms + - job.synth_ms, + ) + worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms) + job.sample_rate = int(sample_rate) + job.audio_data = audio_data + job.result_ready_time = finished_at + prepare_profile = dict(job.state.prepare_profile) + result = { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + "decode_admission_wait_ms": float(job.admission_wait_ms), + "engine_policy_wait_ms": float(job.engine_policy_wait_ms), + "engine_dispatch_wait_ms": float(job.engine_dispatch_wait_ms), + "prepare_ms": job.prepare_wall_ms, + "prepare_wall_ms": job.prepare_wall_ms, + "prepare_profile_total_ms": job.prepare_profile_total_ms, + "prepare_profile": prepare_profile, + "queue_wait_ms": queue_wait_ms, + "prefill_ms": job.prefill_ms, + "merge_ms": job.merge_ms, + "decode_ms": job.decode_ms, + "finalize_wait_ms": job.finalize_wait_ms, + "synth_ms": job.synth_ms, + "worker_residual_ms": worker_residual_ms, + "worker_other_ms": worker_other_ms, + "worker_total_ms": worker_total_ms, + "decode_steps": int(job.decode_steps), + "sample_rate": int(sample_rate), + "media_type": job.media_type, + } + job.result = result + return result + + @staticmethod + def build_runtime_complete_payload( + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + ) -> Dict[str, Any]: + return { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "sample_rate": int(sample_rate), + "worker_profile": dict(job.result or {}), + } + + def complete_job( + self, + job: SchedulerPendingJob, + *, + runtime_request_id: str | None, + runtime_extra: Optional[Dict[str, Any]] = None, + remove_job: Callable[[], None] | None = None, + on_job_finished: Callable[[], None] | None = None, + notify_waiters: Callable[[], None] | None = None, + ) -> None: + job.done_event.set() + self.notify_done_future(job) + if remove_job is not None: + remove_job() + if on_job_finished is not None: + on_job_finished() + if notify_waiters is not None: + notify_waiters() + self.runtime_complete(runtime_request_id, runtime_extra) + + def fail_job( + self, + job: SchedulerPendingJob, + *, + error: str, + remove_job: Callable[[], None] | None = None, + on_job_finished: Callable[[], None] | None = None, + notify_waiters: Callable[[], None] | None = None, + ) -> None: + job.error = str(error) + job.done_event.set() + self.notify_done_future(job) + if remove_job is not None: + remove_job() + if on_job_finished is not None: + on_job_finished() + if notify_waiters is not None: + notify_waiters() + self.runtime_fail(job.engine_request_id, str(error)) + + def complete_finalize_task( + self, + *, + condition: threading.Condition, + job_registry: SchedulerJobRegistry, + job: SchedulerPendingJob, + item: T2SFinishedItem, + sample_rate: int, + audio_data: np.ndarray, + ) -> None: + runtime_extra: Optional[Dict[str, Any]] = None + with condition: + if job_registry.get(item.request_id) is not job: + return + self.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data) + runtime_extra = self.build_runtime_complete_payload(job, item, sample_rate=sample_rate) + self.complete_job( + job, + runtime_request_id=job.engine_request_id, + runtime_extra=runtime_extra, + on_job_finished=lambda: job_registry.mark_finished_and_remove(item.request_id), + notify_waiters=condition.notify_all, + ) + + def fail_jobs( + self, + *, + condition: threading.Condition, + job_registry: SchedulerJobRegistry, + request_ids: List[str], + error: str, + ) -> None: + if not request_ids: + return + with condition: + for request_id in request_ids: + job = job_registry.get(request_id) + if job is None: + continue + self.fail_job( + job, + error=error, + on_job_finished=lambda rid=request_id: job_registry.mark_finished_and_remove(rid), + ) + condition.notify_all() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_decode.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_decode.py new file mode 100644 index 00000000..784f71d0 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_decode.py @@ -0,0 +1,430 @@ +from __future__ import annotations + +import threading +import time +from typing import Any, Callable, Dict, List, Optional + +import torch + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( + T2SActiveBatch, + T2SFinishedItem, + decode_one_step, + merge_active_batches, + run_prefill_active_batch, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeStateCallbacks, SchedulerPendingJob + + +class WorkerDecodeExecutor: + def __init__(self, tts: TTS, max_steps: int) -> None: + self.tts = tts + self.max_steps = int(max_steps) + + def _sync_device(self) -> None: + try: + device_str = str(self.tts.configs.device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(self.tts.configs.device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + def execute_prefill_merge( + self, + *, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], + add_prefill_time: Callable[[List[str], float], None] | None, + add_merge_time: Callable[[List[str], float], None] | None, + enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, + finalize_error: Callable[[List[str], str], None] | None, + ) -> Dict[str, Any]: + if not pending_jobs: + return { + "executed": False, + "active_batch": active_batch, + "pending_jobs": [], + "prefill_elapsed_s": 0.0, + "merge_elapsed_s": 0.0, + "finished_items": [], + "error": None, + "error_request_ids": [], + } + admitted_finished: List[T2SFinishedItem] = [] + prefill_elapsed_s = 0.0 + merge_elapsed_s = 0.0 + error: str | None = None + error_request_ids: List[str] = [] + try: + self._sync_device() + prefill_start = time.perf_counter() + mark_prefill_started(pending_jobs, prefill_start) + admitted_active_batch, admitted_finished = run_prefill_active_batch( + self.tts.t2s_model.model, + [job.state for job in pending_jobs], + max_steps=self.max_steps, + ) + self._sync_device() + prefill_elapsed_s = time.perf_counter() - prefill_start + if add_prefill_time is not None: + add_prefill_time([job.request_id for job in pending_jobs], prefill_elapsed_s) + if enqueue_finished is not None: + enqueue_finished(admitted_finished) + merge_start = time.perf_counter() + active_batch = merge_active_batches( + self.tts.t2s_model.model, + active_batch, + admitted_active_batch, + ) + merge_elapsed_s = time.perf_counter() - merge_start + if add_merge_time is not None: + add_merge_time( + [] if active_batch is None else list(active_batch.request_ids), + merge_elapsed_s, + ) + except Exception as exc: + error = str(exc) + error_request_ids = [job.request_id for job in pending_jobs] + if finalize_error is not None: + finalize_error(error_request_ids, error) + return { + "executed": True, + "active_batch": active_batch, + "pending_jobs": list(pending_jobs), + "prefill_elapsed_s": float(prefill_elapsed_s), + "merge_elapsed_s": float(merge_elapsed_s), + "finished_items": list(admitted_finished), + "error": error, + "error_request_ids": error_request_ids, + } + + def execute_decode_step( + self, + *, + active_batch: Optional[T2SActiveBatch], + add_decode_time: Callable[[List[str], float], None] | None, + enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, + finalize_error: Callable[[List[str], str], None] | None, + ) -> Dict[str, Any]: + if active_batch is None: + return { + "executed": False, + "active_batch": None, + "request_ids": [], + "decode_elapsed_s": 0.0, + "finished_items": [], + "error": None, + "error_request_ids": [], + } + active_request_ids: List[str] = [] + step_finished: List[T2SFinishedItem] = [] + decode_elapsed_s = 0.0 + error: str | None = None + error_request_ids: List[str] = [] + try: + active_request_ids = [state.request_id for state in active_batch.states] + self._sync_device() + decode_start = time.perf_counter() + active_batch, step_finished = decode_one_step( + self.tts.t2s_model.model, + active_batch, + max_steps=self.max_steps, + ) + self._sync_device() + decode_elapsed_s = time.perf_counter() - decode_start + if add_decode_time is not None: + add_decode_time(active_request_ids, decode_elapsed_s) + if enqueue_finished is not None: + enqueue_finished(step_finished) + except Exception as exc: + error = str(exc) + error_request_ids = list(active_request_ids) + if finalize_error is not None: + finalize_error(error_request_ids, error) + active_batch = None + return { + "executed": True, + "active_batch": active_batch, + "request_ids": active_request_ids, + "decode_elapsed_s": float(decode_elapsed_s), + "finished_items": list(step_finished), + "error": error, + "error_request_ids": error_request_ids, + } + + def execute_decode_cycle( + self, + *, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], + add_prefill_time: Callable[[List[str], float], None] | None, + add_merge_time: Callable[[List[str], float], None] | None, + add_decode_time: Callable[[List[str], float], None] | None, + enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, + finalize_error: Callable[[List[str], str], None] | None, + ) -> Dict[str, Any]: + result = { + "executed": False, + "prefill_merge_executed": False, + "decode_step_executed": False, + "active_batch": active_batch, + "prefill_phase": {}, + "decode_phase": {}, + } + prefill_phase = self.execute_prefill_merge( + pending_jobs=list(pending_jobs), + active_batch=result["active_batch"], + mark_prefill_started=mark_prefill_started, + add_prefill_time=add_prefill_time, + add_merge_time=add_merge_time, + enqueue_finished=enqueue_finished, + finalize_error=finalize_error, + ) + prefill_executed = bool(prefill_phase.get("executed", False)) + result["prefill_phase"] = prefill_phase + result["active_batch"] = prefill_phase.get("active_batch") + if prefill_executed: + result["executed"] = True + result["prefill_merge_executed"] = True + decode_phase = self.execute_decode_step( + active_batch=result["active_batch"], + add_decode_time=add_decode_time, + enqueue_finished=enqueue_finished, + finalize_error=finalize_error, + ) + decode_executed = bool(decode_phase.get("executed", False)) + result["decode_phase"] = decode_phase + result["active_batch"] = decode_phase.get("active_batch") + if decode_executed: + result["executed"] = True + result["decode_step_executed"] = True + return result + + +class WorkerDecodeLegacyShell: + def __init__(self, condition: threading.Condition, micro_batch_wait_s: float) -> None: + self.condition = condition + self.micro_batch_wait_s = float(micro_batch_wait_s) + self.pending_jobs: List[SchedulerPendingJob] = [] + self.active_batch: T2SActiveBatch | None = None + + @staticmethod + def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any] | None: + if active_batch is None: + return None + return { + "request_count": int(len(active_batch.request_ids)), + "request_ids": list(active_batch.request_ids), + "prefill_done": bool(active_batch.prefill_done), + "decode_step_index_max": ( + int(active_batch.step_indices.max().item()) + if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0 + else 0 + ), + } + + def current_backlog_locked(self) -> int: + running_requests = 0 if self.active_batch is None else len(self.active_batch.request_ids) + return int(len(self.pending_jobs) + running_requests) + + def enqueue_pending_job_locked(self, job: SchedulerPendingJob) -> None: + self.pending_jobs.append(job) + + def snapshot_locked(self) -> Dict[str, Any]: + active_batch_summary = self._summarize_active_batch(self.active_batch) + executor_local_pending_jobs = int(len(self.pending_jobs)) + executor_local_running_requests = 0 if self.active_batch is None else int(len(self.active_batch.request_ids)) + executor_local_has_work = bool(self.pending_jobs or self.active_batch is not None) + return { + "executor_local_pending_jobs": executor_local_pending_jobs, + "executor_local_running_requests": executor_local_running_requests, + "executor_local_has_work": executor_local_has_work, + "executor_local_active_batch": active_batch_summary, + } + + def is_idle_locked(self) -> bool: + return self.active_batch is None and not self.pending_jobs + + def take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs and self.active_batch is None: + self.condition.wait(timeout=self.micro_batch_wait_s) + elif wait_for_batch and self.pending_jobs: + self.condition.wait(timeout=self.micro_batch_wait_s) + if not self.pending_jobs: + return [] + pending = list(self.pending_jobs) + self.pending_jobs.clear() + return pending + + def take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs: + return [] + if wait_for_batch: + oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time) + if (time.perf_counter() - oldest_enqueue_time) < self.micro_batch_wait_s: + return [] + pending = list(self.pending_jobs) + self.pending_jobs.clear() + return pending + + def has_decode_runtime_work(self) -> bool: + with self.condition: + return bool(self.pending_jobs or self.active_batch is not None) + + def build_runtime_summary_locked(self, *, total_cycles: int, prefill_cycles: int, step_cycles: int, last_event: str) -> Dict[str, Any]: + active_request_ids = [] if self.active_batch is None else list(self.active_batch.request_ids) + decode_step_index_max = 0 + prefill_done = False + if self.active_batch is not None: + prefill_done = bool(self.active_batch.prefill_done) + if self.active_batch.step_indices is not None and self.active_batch.step_indices.numel() > 0: + decode_step_index_max = int(self.active_batch.step_indices.max().item()) + return { + "pending_jobs": int(len(self.pending_jobs)), + "active_request_count": int(len(active_request_ids)), + "active_request_ids": active_request_ids[:32], + "prefill_done": bool(prefill_done), + "decode_step_index_max": int(decode_step_index_max), + "total_cycles": int(total_cycles), + "prefill_cycles": int(prefill_cycles), + "step_cycles": int(step_cycles), + "has_work": bool(self.pending_jobs or self.active_batch is not None), + "last_event": str(last_event), + "updated_at": float(time.perf_counter()), + } + + def run_prefill_merge_once_nonblocking( + self, + *, + external_pending_jobs: Optional[List[SchedulerPendingJob]], + external_active_batch: Optional[T2SActiveBatch], + execute_prefill_merge: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], + ) -> Dict[str, Any]: + pending_jobs = ( + list(external_pending_jobs) + if external_pending_jobs is not None + else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) + ) + active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch + result = execute_prefill_merge(pending_jobs, active_batch) + if external_pending_jobs is None: + with self.condition: + self.active_batch = result.get("active_batch") + self.condition.notify_all() + return result + + def run_decode_step_once_nonblocking( + self, + *, + external_active_batch: Optional[T2SActiveBatch], + execute_decode_step: Callable[[Optional[T2SActiveBatch]], Dict[str, Any]], + ) -> Dict[str, Any]: + active_batch = self.active_batch if external_active_batch is None else external_active_batch + result = execute_decode_step(active_batch) + if external_active_batch is None: + with self.condition: + self.active_batch = result.get("active_batch") + self.condition.notify_all() + return result + + def run_decode_cycle_nonblocking( + self, + *, + external_pending_jobs: Optional[List[SchedulerPendingJob]], + external_active_batch: Optional[T2SActiveBatch], + execute_decode_cycle: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], + on_cycle_executed: Callable[[Dict[str, Any]], None] | None, + ) -> Dict[str, Any]: + pending_jobs = ( + list(external_pending_jobs) + if external_pending_jobs is not None + else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) + ) + active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch + result = execute_decode_cycle(pending_jobs, active_batch) + if external_pending_jobs is None: + with self.condition: + self.active_batch = result.get("active_batch") + self.condition.notify_all() + if result.get("executed") and on_cycle_executed is not None: + on_cycle_executed(result) + return result + + def run_loop( + self, + *, + run_decode_cycle_nonblocking: Callable[[], Dict[str, Any]], + ) -> None: + while True: + executed = run_decode_cycle_nonblocking() + if executed.get("executed"): + continue + wait_for_batch = self.active_batch is None + pending_jobs = self.take_pending_snapshot(wait_for_batch=wait_for_batch) + if pending_jobs: + with self.condition: + self.pending_jobs = pending_jobs + self.pending_jobs + self.condition.notify_all() + continue + time.sleep(self.micro_batch_wait_s) + + +class WorkerDecodeRuntimeTracker: + def __init__( + self, + runtime_callbacks: RuntimeStateCallbacks | None = None, + ) -> None: + self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() + self.total_cycles = 0 + self.prefill_cycles = 0 + self.step_cycles = 0 + + def get_counters(self) -> Dict[str, int]: + return { + "total_cycles": int(self.total_cycles), + "prefill_cycles": int(self.prefill_cycles), + "step_cycles": int(self.step_cycles), + } + + def record_cycle(self, result: Dict[str, Any]) -> None: + if not bool(result.get("executed")): + return + self.total_cycles += 1 + if bool(result.get("prefill_merge_executed")): + self.prefill_cycles += 1 + if bool(result.get("decode_step_executed")): + self.step_cycles += 1 + + def build_runtime_summary_locked( + self, + *, + legacy_shell: WorkerDecodeLegacyShell, + last_event: str, + ) -> Dict[str, Any]: + return legacy_shell.build_runtime_summary_locked( + total_cycles=int(self.total_cycles), + prefill_cycles=int(self.prefill_cycles), + step_cycles=int(self.step_cycles), + last_event=str(last_event), + ) + + def notify_runtime_update_locked( + self, + *, + legacy_shell: WorkerDecodeLegacyShell, + last_event: str, + ) -> None: + if self.runtime_callbacks.decode_runtime_update is None: + return + snapshot = self.build_runtime_summary_locked( + legacy_shell=legacy_shell, + last_event=last_event, + ) + self.runtime_callbacks.decode_runtime_update(snapshot) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_execution.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_execution.py new file mode 100644 index 00000000..465f7a2c --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_execution.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +import time +from typing import Any, Dict, List, Optional + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SActiveBatch, T2SFinishedItem +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob + + +class WorkerExecutionMixin: + def execute_prefill_merge( + self, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + return self.decode_executor.execute_prefill_merge( + pending_jobs=pending_jobs, + active_batch=active_batch, + mark_prefill_started=self._mark_prefill_started, + add_prefill_time=None if external_bookkeeping else self._add_prefill_time, + add_merge_time=None if external_bookkeeping else self._add_merge_time, + enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, + finalize_error=None if external_bookkeeping else self._finalize_error, + ) + + def execute_decode_step( + self, + active_batch: Optional[T2SActiveBatch], + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + return self.decode_executor.execute_decode_step( + active_batch=active_batch, + add_decode_time=None if external_bookkeeping else self._add_decode_time, + enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, + finalize_error=None if external_bookkeeping else self._finalize_error, + ) + + def execute_decode_cycle( + self, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_executor.execute_decode_cycle( + pending_jobs=pending_jobs, + active_batch=active_batch, + mark_prefill_started=self._mark_prefill_started, + add_prefill_time=None if external_bookkeeping else self._add_prefill_time, + add_merge_time=None if external_bookkeeping else self._add_merge_time, + add_decode_time=None if external_bookkeeping else self._add_decode_time, + enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, + finalize_error=None if external_bookkeeping else self._finalize_error, + ) + self._record_decode_runtime_cycle(result) + return result + + def run_prefill_merge_once_nonblocking( + self, + external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, + external_active_batch: Optional[T2SActiveBatch] = None, + emit_runtime_state: bool = True, + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_legacy_shell.run_prefill_merge_once_nonblocking( + external_pending_jobs=external_pending_jobs, + external_active_batch=external_active_batch, + execute_prefill_merge=lambda batch_jobs, batch_state: self.execute_prefill_merge( + pending_jobs=batch_jobs, + active_batch=batch_state, + external_bookkeeping=external_bookkeeping, + ), + ) + if emit_runtime_state: + self._notify_decode_runtime_state("prefill_merge") + return result + + def run_decode_step_once_nonblocking( + self, + external_active_batch: Optional[T2SActiveBatch] = None, + emit_runtime_state: bool = True, + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_legacy_shell.run_decode_step_once_nonblocking( + external_active_batch=external_active_batch, + execute_decode_step=lambda batch_state: self.execute_decode_step( + active_batch=batch_state, + external_bookkeeping=external_bookkeeping, + ), + ) + if emit_runtime_state: + self._notify_decode_runtime_state("decode_step") + return result + + def run_decode_cycle_nonblocking( + self, + external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, + external_active_batch: Optional[T2SActiveBatch] = None, + emit_runtime_state: bool = True, + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_legacy_shell.run_decode_cycle_nonblocking( + external_pending_jobs=external_pending_jobs, + external_active_batch=external_active_batch, + execute_decode_cycle=lambda batch_jobs, batch_state: self.execute_decode_cycle( + pending_jobs=batch_jobs, + active_batch=batch_state, + external_bookkeeping=external_bookkeeping, + ), + on_cycle_executed=None, + ) + if result.get("executed") and emit_runtime_state: + self._notify_decode_runtime_state("decode_cycle") + return result + + def execute_finalize_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None: + if not tasks: + return + try: + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + with self.condition: + for task in tasks: + job = self.job_registry.get(task.request_id) + if job is None: + continue + jobs_and_items.append((job, task.item)) + if not jobs_and_items: + return + now = time.perf_counter() + for task in tasks: + self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0)) + for job, item in jobs_and_items: + self._runtime_update( + job.engine_request_id, + EngineStatus.FINALIZING, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + }, + ) + synth_ms, batch_results = self.synthesize_finalize_jobs(jobs_and_items) + with self.condition: + for job, _ in jobs_and_items: + tracked_job = self.job_registry.get(job.request_id) + if tracked_job is not None: + tracked_job.synth_ms += synth_ms + for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): + self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data) + except Exception as exc: + self._finalize_error([task.request_id for task in tasks], str(exc)) + finally: + self.finalize_executor.end_execution(len(tasks)) + + def _run_finalize_loop(self) -> None: + while True: + tasks = self.finalize_executor.take_task_batch_blocking() + self.execute_finalize_tasks(tasks) + + def _run_loop(self) -> None: + self.decode_legacy_shell.run_loop( + run_decode_cycle_nonblocking=lambda: self.run_decode_cycle_nonblocking() + ) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py new file mode 100644 index 00000000..4f5833fd --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py @@ -0,0 +1,234 @@ +from __future__ import annotations + +import os +import threading +import time +from collections import deque +from typing import Any, Callable, Deque, Dict, List + +import numpy as np +import torch + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import SchedulerFinalizeTask, SchedulerPendingJob + + +class WorkerFinalizeExecutor: + def __init__( + self, + tts: TTS, + on_state_change: Callable[[], None] | None = None, + external_submit: Callable[[List[SchedulerFinalizeTask]], None] | None = None, + ) -> None: + self.tts = tts + self.on_state_change = on_state_change + self.external_submit = external_submit + self.condition = threading.Condition() + self.pending_tasks: Deque[SchedulerFinalizeTask] = deque() + self.pending_peak = 0 + self.inflight = 0 + self.inflight_peak = 0 + self.worker_count = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1))) + self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower() + self.batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16))) + self.batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0) + + def _notify_state_change(self) -> None: + if self.on_state_change is None: + return + try: + self.on_state_change() + except Exception: + pass + + def get_worker_count(self) -> int: + return int(self.worker_count) + + def get_batch_policy(self) -> Dict[str, Any]: + return { + "finalize_mode": str(self.finalize_mode), + "finalize_batch_max_items": int(self.batch_max_items), + "finalize_batch_wait_s": float(self.batch_wait_s), + } + + def get_pending_count(self) -> int: + with self.condition: + return int(len(self.pending_tasks)) + + def snapshot(self) -> Dict[str, Any]: + with self.condition: + return { + "finalize_pending": int(len(self.pending_tasks)), + "finalize_pending_peak": int(self.pending_peak), + "finalize_inflight": int(self.inflight), + "finalize_inflight_peak": int(self.inflight_peak), + "finalize_workers": int(self.worker_count), + "finalize_mode": str(self.finalize_mode), + "finalize_batch_max_items": int(self.batch_max_items), + "finalize_batch_wait_ms": float(self.batch_wait_s * 1000.0), + } + + def is_idle(self) -> bool: + with self.condition: + return self.inflight <= 0 and not self.pending_tasks + + def enqueue_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None: + if not tasks: + return + if self.external_submit is not None: + self.external_submit(tasks) + self._notify_state_change() + return + with self.condition: + for task in tasks: + self.pending_tasks.append(task) + self.pending_peak = max(self.pending_peak, len(self.pending_tasks)) + self.condition.notify_all() + self._notify_state_change() + + def begin_execution(self, task_count: int) -> None: + if task_count <= 0: + return + with self.condition: + self.inflight += int(task_count) + self.inflight_peak = max(self.inflight_peak, self.inflight) + self.condition.notify_all() + self._notify_state_change() + + def end_execution(self, task_count: int) -> None: + with self.condition: + self.inflight = max(0, self.inflight - int(task_count)) + self.condition.notify_all() + self._notify_state_change() + + def take_task_batch_blocking(self) -> List[SchedulerFinalizeTask]: + with self.condition: + while not self.pending_tasks: + self.condition.wait() + selected_tasks = [self.pending_tasks.popleft()] + if self.finalize_mode == "sync" or self.tts.configs.use_vocoder: + self.inflight += len(selected_tasks) + self.inflight_peak = max(self.inflight_peak, self.inflight) + self._notify_state_change() + return selected_tasks + batch_deadline = time.perf_counter() + self.batch_wait_s + while len(selected_tasks) < self.batch_max_items: + if not self.pending_tasks: + remaining = batch_deadline - time.perf_counter() + if remaining <= 0: + break + self.condition.wait(timeout=remaining) + continue + first_task = selected_tasks[0] + matched_index = None + for index, task in enumerate(self.pending_tasks): + if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: + matched_index = index + break + if matched_index is not None: + selected_tasks.append(self.pending_tasks[matched_index]) + del self.pending_tasks[matched_index] + continue + remaining = batch_deadline - time.perf_counter() + if remaining <= 0: + break + self.condition.wait(timeout=remaining) + self.inflight += len(selected_tasks) + self.inflight_peak = max(self.inflight_peak, self.inflight) + self._notify_state_change() + return selected_tasks + + def _sync_device(self) -> None: + try: + device_str = str(self.tts.configs.device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(self.tts.configs.device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]: + audio_fragment = self.tts.synthesize_audio_request_local( + semantic_tokens=item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0), + phones=job.state.phones.detach().clone().unsqueeze(0), + prompt_semantic=job.state.prompt_semantic.detach().clone(), + prompt_phones=job.state.prompt_phones.detach().clone(), + refer_spec=( + job.state.refer_spec[0].detach().clone(), + None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), + ), + raw_audio=job.state.raw_audio.detach().clone(), + raw_sr=int(job.state.raw_sr), + speed=float(job.speed_factor), + sample_steps=int(job.sample_steps), + ) + output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] + return self.tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=float(job.speed_factor), + split_bucket=False, + fragment_interval=0.0, + super_sampling=False, + ) + + def _synthesize_finished_audio_batch( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> List[tuple[int, np.ndarray]]: + semantic_tokens_list = [item.semantic_tokens.detach().clone() for _, item in jobs_and_items] + phones_list = [job.state.phones.detach().clone() for job, _ in jobs_and_items] + refer_specs = [] + speeds = [] + sample_steps_list = [] + for job, _ in jobs_and_items: + refer_specs.append( + ( + job.state.refer_spec[0].detach().clone(), + None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), + ) + ) + speeds.append(float(job.speed_factor)) + sample_steps_list.append(int(job.sample_steps)) + audio_fragments = self.tts.synthesize_audio_requests_local_batched( + semantic_tokens_list=semantic_tokens_list, + phones_list=phones_list, + refer_specs=refer_specs, + speeds=speeds, + sample_steps_list=sample_steps_list, + ) + output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] + results: List[tuple[int, np.ndarray]] = [] + for (job, _), audio_fragment in zip(jobs_and_items, audio_fragments): + results.append( + self.tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=float(job.speed_factor), + split_bucket=False, + fragment_interval=0.0, + super_sampling=False, + ) + ) + return results + + def synthesize_finalize_jobs( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> tuple[float, List[tuple[int, np.ndarray]]]: + if not jobs_and_items: + return 0.0, [] + self._sync_device() + synth_start = time.perf_counter() + if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder: + job, item = jobs_and_items[0] + batch_results = [self._synthesize_finished_audio(job, item)] + else: + batch_results = self._synthesize_finished_audio_batch(jobs_and_items) + self._sync_device() + synth_ms = (time.perf_counter() - synth_start) * 1000.0 + return float(synth_ms), batch_results diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py new file mode 100644 index 00000000..28da24ee --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import asyncio +import time +from typing import Callable, Dict, List + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator, PreparedCpuStage +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SRequestState + + +class WorkerPrepareExecutor: + def __init__( + self, + tts: TTS, + on_state_change: Callable[[], None] | None = None, + ) -> None: + self.coordinator = PrepareCoordinator(tts) + self.on_state_change = on_state_change + + def _notify_state_change(self) -> None: + if self.on_state_change is None: + return + try: + self.on_state_change() + except Exception: + pass + + def snapshot(self) -> Dict[str, int]: + return dict(self.coordinator.snapshot()) + + def get_max_inflight(self) -> int: + return int(self.coordinator.snapshot().get("max_inflight", 0)) + + def is_idle(self) -> bool: + return int(self.coordinator.snapshot().get("inflight", 0)) <= 0 + + async def prepare_state_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> tuple[T2SRequestState, float, float]: + try: + return await self.coordinator.prepare_state_profiled_async(spec, prepare_submit_at) + finally: + self._notify_state_change() + + async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: + results = await asyncio.gather( + *[self.prepare_state_profiled_async(spec, time.perf_counter()) for spec in specs] + ) + return [state for state, _, _ in results] + + async def prepare_cpu_stage_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> PreparedCpuStage: + try: + return await self.coordinator.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + finally: + self._notify_state_change() + + async def prepare_gpu_stage_profiled_async( + self, + cpu_stage: PreparedCpuStage, + ) -> tuple[T2SRequestState, float, float]: + try: + return await self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) + finally: + self._notify_state_change() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_runtime.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_runtime.py new file mode 100644 index 00000000..de12f5e1 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_runtime.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +import threading +import time +from typing import Any, Dict, List, Optional + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob + + +class WorkerRuntimeBookkeepingMixin: + def _mark_prefill_started(self, pending_jobs: List[SchedulerPendingJob], started_at: float) -> None: + with self.condition: + for job in pending_jobs: + job.first_schedule_time = float(started_at) + self._runtime_update( + job.engine_request_id, + EngineStatus.GPU_PREPARING, + {"scheduler_request_id": job.request_id, "prefill_started_at": float(started_at)}, + ) + + def _add_prefill_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + job.prefill_ms += delta_ms + + def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + job.merge_ms += delta_ms + + def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + activate_request_ids: List[str] = [] + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + if job.decode_steps == 0: + activate_request_ids.append(job.engine_request_id) + job.decode_ms += delta_ms + job.decode_steps += 1 + for engine_request_id in activate_request_ids: + self._runtime_update(engine_request_id, EngineStatus.ACTIVE_DECODE, None) + + def _add_finalize_wait_ms(self, request_ids: List[str], delta_ms: float) -> None: + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + job.finalize_wait_ms += float(delta_ms) + + def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None: + if not items: + return + enqueued_at = time.perf_counter() + tasks: List[SchedulerFinalizeTask] = [] + with self.condition: + for item in items: + job = self.job_registry.get(item.request_id) + if job is not None: + self._runtime_update( + job.engine_request_id, + EngineStatus.READY_FOR_FINALIZE, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + }, + ) + tasks.append(SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at)) + self.finalize_executor.enqueue_tasks(tasks) + + def begin_finalize_execution(self, task_count: int) -> None: + self.finalize_executor.begin_execution(task_count) + + def end_finalize_execution(self, task_count: int) -> None: + self.finalize_executor.end_execution(task_count) + + def record_external_job_done(self, request_id: str) -> None: + with self.condition: + self.job_registry.mark_finished_and_remove(request_id) + self.condition.notify_all() + + def synthesize_finalize_jobs( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> tuple[float, List[tuple[int, np.ndarray]]]: + return self.finalize_executor.synthesize_finalize_jobs(jobs_and_items) + + def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None: + self.completion_bridge.complete_finalize_task( + condition=self.condition, + job_registry=self.job_registry, + job=job, + item=item, + sample_rate=sample_rate, + audio_data=audio_data, + ) + + def _finalize_error(self, request_ids: List[str], error: str) -> None: + self.completion_bridge.fail_jobs( + condition=self.condition, + job_registry=self.job_registry, + request_ids=request_ids, + error=error, + ) + + @staticmethod + def _resolve_done_future(job: SchedulerPendingJob) -> None: + future = job.done_future + if future is None or future.done(): + return + future.set_result(job) + + def _notify_done_future(self, job: SchedulerPendingJob) -> None: + self.completion_bridge.notify_done_future(job) + + def _runtime_update(self, request_id: str | None, status: str, extra: Optional[Dict[str, Any]] = None) -> None: + if request_id is None or self.runtime_callbacks.update is None: + return + self.runtime_callbacks.update(request_id, status, extra) + + def _runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: + self.completion_bridge.runtime_complete(request_id, extra) + + def _runtime_fail(self, request_id: str | None, error: str) -> None: + self.completion_bridge.runtime_fail(request_id, error) + + def _build_decode_runtime_summary_locked(self, last_event: str) -> Dict[str, Any]: + return self.decode_runtime_tracker.build_runtime_summary_locked( + legacy_shell=self.decode_legacy_shell, + last_event=str(last_event), + ) + + def _notify_decode_runtime_state(self, last_event: str) -> None: + with self.condition: + self.decode_runtime_tracker.notify_runtime_update_locked( + legacy_shell=self.decode_legacy_shell, + last_event=str(last_event), + ) + + def _record_decode_runtime_cycle(self, result: Dict[str, Any]) -> None: + with self.condition: + self.decode_runtime_tracker.record_cycle(result) + + def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.decode_legacy_shell.take_pending_snapshot(wait_for_batch) + + def _take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.decode_legacy_shell.take_pending_snapshot_nonblocking(wait_for_batch) + + def has_decode_runtime_work(self) -> bool: + return self.decode_legacy_shell.has_decode_runtime_work() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py new file mode 100644 index 00000000..1e67f8d3 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py @@ -0,0 +1,256 @@ +from __future__ import annotations + +import asyncio +import threading +import time +from typing import Any, Dict, List + +from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PreparedCpuStage +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerPendingJob + + +class WorkerSubmitLifecycleMixin: + def _current_decode_backlog_locked(self) -> int: + return self.decode_legacy_shell.current_backlog_locked() + + def get_micro_batch_wait_s(self) -> float: + return float(self.micro_batch_wait_s) + + def is_engine_decode_control_enabled(self) -> bool: + return bool(self.engine_decode_control_enabled) + + def get_prepare_max_inflight(self) -> int: + return int(self.prepare_executor.get_max_inflight()) + + def get_capacity_limits(self) -> Dict[str, int]: + return { + "decode_backlog_max": int(self.decode_backlog_max), + "finalize_pending_max": int(self.finalize_pending_max), + } + + def get_finalize_batch_policy(self) -> Dict[str, Any]: + return dict(self.finalize_executor.get_batch_policy()) + + def get_decode_runtime_counters(self) -> Dict[str, int]: + with self.condition: + return self.decode_runtime_tracker.get_counters() + + def _can_accept_submit_locked(self) -> tuple[bool, Dict[str, int]]: + decode_backlog = self._current_decode_backlog_locked() + finalize_pending = int(self.finalize_executor.get_pending_count()) + prepare_inflight = int(self.prepare_executor.snapshot()["inflight"]) + blocked_decode = self.decode_backlog_max > 0 and decode_backlog >= self.decode_backlog_max + blocked_finalize = self.finalize_pending_max > 0 and finalize_pending >= self.finalize_pending_max + return ( + not blocked_decode and not blocked_finalize, + { + "decode_backlog": decode_backlog, + "finalize_pending": finalize_pending, + "prepare_inflight": prepare_inflight, + "decode_backlog_max": int(self.decode_backlog_max), + "finalize_pending_max": int(self.finalize_pending_max), + }, + ) + + def wait_for_submit_capacity_blocking(self, timeout_sec: float | None = None) -> tuple[float, Dict[str, int]]: + start = time.perf_counter() + deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) + while True: + with self.condition: + allowed, snapshot = self._can_accept_submit_locked() + if allowed: + return max(0.0, (time.perf_counter() - start) * 1000.0), snapshot + if deadline is not None and time.perf_counter() >= deadline: + raise TimeoutError( + "scheduler submit admission timeout " + f"(decode_backlog={snapshot['decode_backlog']}, finalize_pending={snapshot['finalize_pending']})" + ) + self.condition.wait(timeout=self.micro_batch_wait_s) + + def _admission_snapshot_locked(self) -> Dict[str, int]: + _, snapshot = self._can_accept_submit_locked() + return snapshot + + async def submit_async( + self, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None = None, + done_future: asyncio.Future | None = None, + engine_request_id: str | None = None, + timeout_sec: float | None = None, + skip_capacity_wait: bool = False, + admission_wait_ms_override: float | None = None, + admission_snapshot_override: Dict[str, Any] | None = None, + engine_policy_wait_ms: float = 0.0, + engine_dispatch_wait_ms: float = 0.0, + enqueue_pending: bool = True, + ) -> SchedulerPendingJob: + return await asyncio.to_thread( + self.submit, + state, + speed_factor, + sample_steps, + media_type, + prepare_wall_ms, + prepare_profile_total_ms, + done_loop, + done_future, + engine_request_id, + timeout_sec, + skip_capacity_wait, + admission_wait_ms_override, + admission_snapshot_override, + engine_policy_wait_ms, + engine_dispatch_wait_ms, + enqueue_pending, + ) + + def snapshot(self) -> dict: + with self.condition: + prepare_state = self.prepare_executor.snapshot() + finalize_state = self.finalize_executor.snapshot() + shell_state = self.decode_legacy_shell.snapshot_locked() + decode_runtime_counters = self.decode_runtime_tracker.get_counters() + engine_owned_decode_state = bool(self.engine_decode_control_enabled) + active_batch_summary = shell_state.get("executor_local_active_batch") + executor_local_pending_jobs = int(shell_state.get("executor_local_pending_jobs", 0)) + executor_local_running_requests = int(shell_state.get("executor_local_running_requests", 0)) + executor_local_has_work = bool(shell_state.get("executor_local_has_work", False)) + return { + "pending_jobs": 0 if engine_owned_decode_state else executor_local_pending_jobs, + "running_requests": 0 if engine_owned_decode_state else executor_local_running_requests, + "engine_decode_control_enabled": bool(self.engine_decode_control_enabled), + "legacy_state_owner_mode": not engine_owned_decode_state, + "decode_state_owner": "engine" if engine_owned_decode_state else "worker", + "decode_runtime_has_work": False if engine_owned_decode_state else executor_local_has_work, + "executor_local_pending_jobs": executor_local_pending_jobs, + "executor_local_running_requests": executor_local_running_requests, + "executor_local_has_work": executor_local_has_work, + "decode_runtime_total_cycles": int(decode_runtime_counters.get("total_cycles", 0)), + "decode_runtime_prefill_cycles": int(decode_runtime_counters.get("prefill_cycles", 0)), + "decode_runtime_step_cycles": int(decode_runtime_counters.get("step_cycles", 0)), + "prepare_inflight": prepare_state["inflight"], + "prepare_peak_inflight": prepare_state["peak_inflight"], + "prepare_max_inflight": prepare_state.get("max_inflight", 0), + "prepare_state": dict(prepare_state), + **finalize_state, + "decode_backlog_max": self.decode_backlog_max, + "finalize_pending_max": self.finalize_pending_max, + "active_batch": {} if engine_owned_decode_state else active_batch_summary, + "executor_local_active_batch": active_batch_summary if engine_owned_decode_state else None, + "total_submitted": self.job_registry.submitted_count(), + "total_finished": self.job_registry.finished_count(), + "drained": self.is_drained(), + } + + def is_drained(self) -> bool: + with self.condition: + return ( + self.decode_legacy_shell.is_idle_locked() + and self.job_registry.is_empty() + and self.prepare_executor.is_idle() + and self.finalize_executor.is_idle() + ) + + def wait_until_idle(self, timeout_sec: float = 60.0, poll_interval_sec: float = 0.01) -> bool: + deadline = time.perf_counter() + max(0.0, timeout_sec) + while time.perf_counter() < deadline: + if self.is_drained(): + return True + time.sleep(poll_interval_sec) + return self.is_drained() + + def submit( + self, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None = None, + done_future: asyncio.Future | None = None, + engine_request_id: str | None = None, + timeout_sec: float | None = None, + skip_capacity_wait: bool = False, + admission_wait_ms_override: float | None = None, + admission_snapshot_override: Dict[str, Any] | None = None, + engine_policy_wait_ms: float = 0.0, + engine_dispatch_wait_ms: float = 0.0, + enqueue_pending: bool = True, + ) -> SchedulerPendingJob: + if skip_capacity_wait: + with self.condition: + admission_snapshot = ( + dict(admission_snapshot_override) + if admission_snapshot_override is not None + else dict(self._admission_snapshot_locked()) + ) + admission_wait_ms = 0.0 if admission_wait_ms_override is None else float(admission_wait_ms_override) + else: + admission_wait_ms, admission_snapshot = self.wait_for_submit_capacity_blocking(timeout_sec=timeout_sec) + job = SchedulerPendingJob( + request_id=state.request_id, + state=state, + done_event=threading.Event(), + done_loop=done_loop, + done_future=done_future, + enqueue_time=time.perf_counter(), + speed_factor=float(speed_factor), + sample_steps=int(sample_steps), + media_type=media_type, + admission_wait_ms=float(admission_wait_ms), + engine_policy_wait_ms=float(engine_policy_wait_ms), + engine_dispatch_wait_ms=float(engine_dispatch_wait_ms), + prepare_wall_ms=float(prepare_wall_ms), + prepare_profile_total_ms=float(prepare_profile_total_ms), + engine_request_id=engine_request_id or state.request_id, + ) + with self.condition: + self.job_registry.register(job, keep_job=not self.engine_decode_control_enabled) + if enqueue_pending: + self.decode_legacy_shell.enqueue_pending_job_locked(job) + self.condition.notify_all() + if enqueue_pending: + self._notify_decode_runtime_state("submit") + self._runtime_update( + job.engine_request_id, + EngineStatus.QUEUED, + { + "scheduler_request_id": job.request_id, + "decode_admission_wait_ms": float(admission_wait_ms), + "engine_policy_wait_ms": float(engine_policy_wait_ms), + "engine_dispatch_wait_ms": float(engine_dispatch_wait_ms), + "admission_snapshot": dict(admission_snapshot), + }, + ) + return job + + async def prepare_state_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> tuple[T2SRequestState, float, float]: + return await self.prepare_executor.prepare_state_profiled_async(spec, prepare_submit_at) + + async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: + return await self.prepare_executor.prepare_states_batch_async(specs) + + async def prepare_cpu_stage_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> PreparedCpuStage: + return await self.prepare_executor.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + + async def prepare_gpu_stage_profiled_async( + self, + cpu_stage: PreparedCpuStage, + ) -> tuple[T2SRequestState, float, float]: + return await self.prepare_executor.prepare_gpu_stage_profiled_async(cpu_stage) From d453a8e47c15fa861715e8a75c4d6b5f3f370ba1 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Wed, 11 Mar 2026 21:15:19 +0800 Subject: [PATCH 18/24] Add unified engine stage components for TTS processing orchestration Introduce new modules including EngineDecodeStageMixin, EngineDispatchStageMixin, EngineFinalizeStageMixin, EnginePrepareStageMixin, and EngineStageFutureMixin. These components enhance the TTS framework by providing structured methods for managing engine stages, including decoding, dispatching, finalizing, and preparing tasks. The new architecture supports improved state management and asynchronous operations, significantly enhancing the maintainability and performance of the TTS system. --- .../unified_engine_api_scheduler.py | 2 +- .../unified_engine_stage_decode.py | 40 +++ .../unified_engine_stage_dispatch.py | 93 ++++++ .../unified_engine_stage_executor.py | 314 +----------------- .../unified_engine_stage_finalize.py | 76 +++++ .../unified_engine_stage_futures.py | 59 ++++ .../unified_engine_stage_prepare.py | 67 ++++ 7 files changed, 349 insertions(+), 302 deletions(-) create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_stage_decode.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_stage_finalize.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_stage_futures.py create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api_scheduler.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_scheduler.py index 646b5b45..1e934f16 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_api_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_scheduler.py @@ -280,4 +280,4 @@ class EngineApiSchedulerFlow: spec.request_id, dict(submit_profile, response_headers_emitted=True), ) - return SchedulerSubmitExecution(audio_bytes=audio_data, media_type=f"audio/{job.media_type}", headers=headers) + return SchedulerSubmitExecution(audio_bytes=audio_data, media_type=str(job.media_type), headers=headers) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_decode.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_decode.py new file mode 100644 index 00000000..d3a7a8cf --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_decode.py @@ -0,0 +1,40 @@ +from __future__ import annotations + + +class EngineDecodeStageMixin: + def run_engine_decode_runtime_once(self) -> bool: + if not self.scheduler_worker.is_engine_decode_control_enabled(): + return False + runtime_state = self.snapshot_engine_decode_runtime_state() + pending_jobs = self.decode_runtime_owner.take_pending_jobs_nonblocking( + wait_for_batch=int(runtime_state.get("active_request_count", 0)) <= 0 + ) + result = self.scheduler_worker.execute_decode_cycle( + pending_jobs=pending_jobs, + active_batch=self.decode_runtime_owner.get_active_batch(), + external_bookkeeping=True, + ) + prefill_phase = dict(result.get("prefill_phase") or {}) + if prefill_phase.get("error"): + self.fail_engine_jobs(list(prefill_phase.get("error_request_ids") or []), str(prefill_phase.get("error"))) + else: + prefill_jobs = list(prefill_phase.get("pending_jobs") or []) + self.add_engine_prefill_time(prefill_jobs, float(prefill_phase.get("prefill_elapsed_s", 0.0))) + self.add_engine_merge_time( + [] if result.get("active_batch") is None else list(result["active_batch"].request_ids), + float(prefill_phase.get("merge_elapsed_s", 0.0)), + ) + self.enqueue_engine_finished_items(list(prefill_phase.get("finished_items") or [])) + decode_phase = dict(result.get("decode_phase") or {}) + if decode_phase.get("error"): + self.fail_engine_jobs(list(decode_phase.get("error_request_ids") or []), str(decode_phase.get("error"))) + else: + self.add_engine_decode_time( + list(decode_phase.get("request_ids") or []), + float(decode_phase.get("decode_elapsed_s", 0.0)), + ) + self.enqueue_engine_finished_items(list(decode_phase.get("finished_items") or [])) + self.decode_runtime_owner.set_active_batch(result.get("active_batch")) + if result.get("executed", False): + self.decode_runtime_owner.refresh_state("engine_decode_cycle") + return bool(result.get("executed", False)) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py new file mode 100644 index 00000000..53ebd793 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import asyncio +import time +from typing import Dict + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDispatchTask + + +class EngineDispatchStageMixin: + async def enqueue_prepared_state_for_dispatch( + self, + *, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None, + done_future: asyncio.Future | None, + engine_request_id: str | None, + timeout_sec: float | None, + ) -> EngineDispatchTask: + task = EngineDispatchTask( + request_id=state.request_id, + state=state, + speed_factor=float(speed_factor), + sample_steps=int(sample_steps), + media_type=media_type, + prepare_wall_ms=float(prepare_wall_ms), + prepare_profile_total_ms=float(prepare_profile_total_ms), + done_loop=done_loop, + done_future=done_future, + engine_request_id=engine_request_id or state.request_id, + timeout_sec=timeout_sec, + enqueue_time=time.perf_counter(), + ) + self.dispatch_queue_owner.enqueue(task) + self.notify_arbiter() + self.merge_request_state_profile( + task.engine_request_id or task.request_id, + { + "engine_dispatch_queue_depth_on_enqueue": int( + self.snapshot_engine_dispatch_state()["waiting_count"] + ), + }, + ) + return task + + def run_engine_dispatch_once(self, policy_snapshot: Dict[str, object], worker_state: Dict[str, object]) -> bool: + if not bool(policy_snapshot.get("allowed", True)): + return False + dispatch_task = self.dispatch_queue_owner.pop_left() + if dispatch_task is None: + return False + dispatched_at = time.perf_counter() + dispatch_wait_ms = max(0.0, (dispatched_at - dispatch_task.enqueue_time) * 1000.0) + dispatch_task.engine_policy_wait_ms = float(dispatch_wait_ms) + dispatch_task.engine_dispatch_wait_ms = float(dispatch_wait_ms) + dispatch_task.engine_policy_snapshot = dict(policy_snapshot) + try: + worker_job = self.scheduler_worker.submit( + state=dispatch_task.state, + speed_factor=dispatch_task.speed_factor, + sample_steps=dispatch_task.sample_steps, + media_type=dispatch_task.media_type, + prepare_wall_ms=dispatch_task.prepare_wall_ms, + prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms, + done_loop=dispatch_task.done_loop, + done_future=dispatch_task.done_future, + engine_request_id=dispatch_task.engine_request_id, + timeout_sec=dispatch_task.timeout_sec, + skip_capacity_wait=True, + admission_wait_ms_override=0.0, + admission_snapshot_override=dict(worker_state), + engine_policy_wait_ms=dispatch_task.engine_policy_wait_ms, + engine_dispatch_wait_ms=dispatch_task.engine_dispatch_wait_ms, + enqueue_pending=not self.scheduler_worker.is_engine_decode_control_enabled(), + ) + dispatch_task.worker_job = worker_job + self.register_engine_job(worker_job) + if self.scheduler_worker.is_engine_decode_control_enabled(): + self.decode_runtime_owner.enqueue_pending_job(worker_job) + self.notify_arbiter() + self.dispatch_queue_owner.mark_completed(1) + return True + except Exception as exc: + dispatch_task.error = str(exc) + self.fail_request_state(dispatch_task.engine_request_id or dispatch_task.request_id, str(exc)) + self._notify_dispatch_error(dispatch_task, exc) + return True diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py index 77274056..01921d51 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py @@ -1,24 +1,30 @@ from __future__ import annotations -import asyncio -import time from typing import Any, Callable, Dict, List, Optional from GPT_SoVITS.TTS_infer_pack.TTS import TTS -from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem from GPT_SoVITS.TTS_infer_pack.unified_engine_components import ( EngineDecodeRuntimeOwner, - EngineDispatchTask, - EngineGpuPrepareTask, - EngineStatus, EngineTaskQueueOwner, SchedulerFinalizeTask, SchedulerPendingJob, ) +from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_decode import EngineDecodeStageMixin +from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_dispatch import EngineDispatchStageMixin +from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_finalize import EngineFinalizeStageMixin +from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_futures import EngineStageFutureMixin +from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_prepare import EnginePrepareStageMixin from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker -class EngineStageExecutor: +class EngineStageExecutor( + EngineStageFutureMixin, + EnginePrepareStageMixin, + EngineFinalizeStageMixin, + EngineDispatchStageMixin, + EngineDecodeStageMixin, +): def __init__( self, *, @@ -62,297 +68,3 @@ class EngineStageExecutor: self.snapshot_engine_dispatch_state = snapshot_engine_dispatch_state self.snapshot_engine_decode_runtime_state = snapshot_engine_decode_runtime_state self._notify_arbiter: Callable[[], None] | None = None - - def bind_notify_arbiter(self, notify_arbiter: Callable[[], None]) -> None: - self._notify_arbiter = notify_arbiter - - def notify_arbiter(self) -> None: - if self._notify_arbiter is not None: - self._notify_arbiter() - - @staticmethod - def _resolve_dispatch_error_future(future: asyncio.Future, error: Exception) -> None: - if future.done(): - return - future.set_exception(error) - - def _notify_dispatch_error(self, task: EngineDispatchTask, error: Exception) -> None: - if task.done_loop is None or task.done_future is None: - return - try: - task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) - except RuntimeError: - pass - - @staticmethod - def _resolve_prepare_future( - future: asyncio.Future, - payload: tuple[T2SRequestState, float, float], - ) -> None: - if future.done(): - return - future.set_result(payload) - - def _notify_prepare_error(self, task: EngineGpuPrepareTask, error: Exception) -> None: - if task.done_loop is None or task.done_future is None: - return - try: - task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) - except RuntimeError: - pass - - def _notify_prepare_result( - self, - task: EngineGpuPrepareTask, - payload: tuple[T2SRequestState, float, float], - ) -> None: - if task.done_loop is None or task.done_future is None: - return - try: - task.done_loop.call_soon_threadsafe(self._resolve_prepare_future, task.done_future, payload) - except RuntimeError: - pass - - async def prepare_state_via_engine_gpu_queue( - self, - *, - spec: Any, - prepare_submit_at: float, - engine_request_id: str | None, - ) -> tuple[T2SRequestState, float, float]: - cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) - if engine_request_id not in [None, ""]: - self.update_request_state( - str(engine_request_id), - EngineStatus.GPU_PREPARING, - { - "prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms), - "prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms), - "text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms), - "text_cpu_run_ms": float(cpu_stage.target_cpu_profiled.run_ms), - }, - ) - loop = asyncio.get_running_loop() - done_future = loop.create_future() - task = EngineGpuPrepareTask( - request_id=spec.request_id, - cpu_stage=cpu_stage, - done_loop=loop, - done_future=done_future, - engine_request_id=engine_request_id or spec.request_id, - enqueue_time=time.perf_counter(), - ) - self.prepare_queue_owner.enqueue(task) - self.notify_arbiter() - return await done_future - - def enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: - if not tasks: - return - for task in tasks: - job = self.get_engine_job(task.request_id) - if job is not None: - self.update_request_state( - job.engine_request_id, - EngineStatus.READY_FOR_FINALIZE, - { - "finish_reason": task.item.finish_reason, - "semantic_len": int(task.item.semantic_tokens.shape[0]), - "finish_idx": int(task.item.finish_idx), - }, - ) - self.finalize_queue_owner.enqueue_many(tasks) - self.notify_arbiter() - - def take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: - finalize_policy = self.scheduler_worker.get_finalize_batch_policy() - return self.finalize_queue_owner.take_finalize_batch( - finalize_mode=str(finalize_policy.get("finalize_mode", "async")), - batch_max_items=int(finalize_policy.get("finalize_batch_max_items", 1)), - batch_wait_s=float(finalize_policy.get("finalize_batch_wait_s", 0.0)), - use_vocoder=bool(self.tts.configs.use_vocoder), - ) - - async def enqueue_prepared_state_for_dispatch( - self, - *, - state: T2SRequestState, - speed_factor: float, - sample_steps: int, - media_type: str, - prepare_wall_ms: float, - prepare_profile_total_ms: float, - done_loop: asyncio.AbstractEventLoop | None, - done_future: asyncio.Future | None, - engine_request_id: str | None, - timeout_sec: float | None, - ) -> EngineDispatchTask: - task = EngineDispatchTask( - request_id=state.request_id, - state=state, - speed_factor=float(speed_factor), - sample_steps=int(sample_steps), - media_type=media_type, - prepare_wall_ms=float(prepare_wall_ms), - prepare_profile_total_ms=float(prepare_profile_total_ms), - done_loop=done_loop, - done_future=done_future, - engine_request_id=engine_request_id or state.request_id, - timeout_sec=timeout_sec, - enqueue_time=time.perf_counter(), - ) - self.dispatch_queue_owner.enqueue(task) - self.notify_arbiter() - self.merge_request_state_profile( - task.engine_request_id or task.request_id, - { - "engine_dispatch_queue_depth_on_enqueue": int( - self.snapshot_engine_dispatch_state()["waiting_count"] - ), - }, - ) - return task - - def run_engine_prepare_once(self) -> bool: - task = self.prepare_queue_owner.pop_left() - if task is None: - return False - queue_wait_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0) - try: - state, prepare_exec_started_at, prepare_exec_finished_at = asyncio.run( - self.scheduler_worker.prepare_gpu_stage_profiled_async(task.cpu_stage) - ) - state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms) - if task.engine_request_id not in [None, ""]: - self.merge_request_state_profile( - str(task.engine_request_id), - {"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms)}, - ) - self.prepare_queue_owner.mark_completed(1) - self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at)) - return True - except Exception as exc: - task.error = str(exc) - self.fail_request_state(task.engine_request_id or task.request_id, str(exc)) - self._notify_prepare_error(task, exc) - return True - - def run_engine_finalize_once(self) -> bool: - tasks = self.take_engine_finalize_batch_nonblocking() - if not tasks: - return False - self.scheduler_worker.begin_finalize_execution(len(tasks)) - try: - jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] - for task in tasks: - job = self.get_engine_job(task.request_id) - if job is None: - continue - jobs_and_items.append((job, task.item)) - if not jobs_and_items: - return False - now = time.perf_counter() - for task in tasks: - job = self.get_engine_job(task.request_id) - if job is not None: - job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0) - for job, item in jobs_and_items: - self.update_request_state( - job.engine_request_id, - EngineStatus.FINALIZING, - { - "finish_reason": item.finish_reason, - "semantic_len": int(item.semantic_tokens.shape[0]), - }, - ) - synth_ms, batch_results = self.scheduler_worker.synthesize_finalize_jobs(jobs_and_items) - for job, _ in jobs_and_items: - job.synth_ms += float(synth_ms) - for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): - self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) - except Exception as exc: - self.fail_engine_jobs([task.request_id for task in tasks], str(exc)) - finally: - self.scheduler_worker.end_finalize_execution(len(tasks)) - self.finalize_queue_owner.mark_completed(len(tasks), notify=True) - return True - - def run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: - if not bool(policy_snapshot.get("allowed", True)): - return False - dispatch_task = self.dispatch_queue_owner.pop_left() - if dispatch_task is None: - return False - dispatched_at = time.perf_counter() - dispatch_wait_ms = max(0.0, (dispatched_at - dispatch_task.enqueue_time) * 1000.0) - dispatch_task.engine_policy_wait_ms = float(dispatch_wait_ms) - dispatch_task.engine_dispatch_wait_ms = float(dispatch_wait_ms) - dispatch_task.engine_policy_snapshot = dict(policy_snapshot) - try: - worker_job = self.scheduler_worker.submit( - state=dispatch_task.state, - speed_factor=dispatch_task.speed_factor, - sample_steps=dispatch_task.sample_steps, - media_type=dispatch_task.media_type, - prepare_wall_ms=dispatch_task.prepare_wall_ms, - prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms, - done_loop=dispatch_task.done_loop, - done_future=dispatch_task.done_future, - engine_request_id=dispatch_task.engine_request_id, - timeout_sec=dispatch_task.timeout_sec, - skip_capacity_wait=True, - admission_wait_ms_override=0.0, - admission_snapshot_override=dict(worker_state), - engine_policy_wait_ms=dispatch_task.engine_policy_wait_ms, - engine_dispatch_wait_ms=dispatch_task.engine_dispatch_wait_ms, - enqueue_pending=not self.scheduler_worker.is_engine_decode_control_enabled(), - ) - dispatch_task.worker_job = worker_job - self.register_engine_job(worker_job) - if self.scheduler_worker.is_engine_decode_control_enabled(): - self.decode_runtime_owner.enqueue_pending_job(worker_job) - self.notify_arbiter() - self.dispatch_queue_owner.mark_completed(1) - return True - except Exception as exc: - dispatch_task.error = str(exc) - self.fail_request_state(dispatch_task.engine_request_id or dispatch_task.request_id, str(exc)) - self._notify_dispatch_error(dispatch_task, exc) - return True - - def run_engine_decode_runtime_once(self) -> bool: - if not self.scheduler_worker.is_engine_decode_control_enabled(): - return False - runtime_state = self.snapshot_engine_decode_runtime_state() - pending_jobs = self.decode_runtime_owner.take_pending_jobs_nonblocking( - wait_for_batch=int(runtime_state.get("active_request_count", 0)) <= 0 - ) - result = self.scheduler_worker.execute_decode_cycle( - pending_jobs=pending_jobs, - active_batch=self.decode_runtime_owner.get_active_batch(), - external_bookkeeping=True, - ) - prefill_phase = dict(result.get("prefill_phase") or {}) - if prefill_phase.get("error"): - self.fail_engine_jobs(list(prefill_phase.get("error_request_ids") or []), str(prefill_phase.get("error"))) - else: - prefill_jobs = list(prefill_phase.get("pending_jobs") or []) - self.add_engine_prefill_time(prefill_jobs, float(prefill_phase.get("prefill_elapsed_s", 0.0))) - self.add_engine_merge_time( - [] if result.get("active_batch") is None else list(result["active_batch"].request_ids), - float(prefill_phase.get("merge_elapsed_s", 0.0)), - ) - self.enqueue_engine_finished_items(list(prefill_phase.get("finished_items") or [])) - decode_phase = dict(result.get("decode_phase") or {}) - if decode_phase.get("error"): - self.fail_engine_jobs(list(decode_phase.get("error_request_ids") or []), str(decode_phase.get("error"))) - else: - self.add_engine_decode_time( - list(decode_phase.get("request_ids") or []), - float(decode_phase.get("decode_elapsed_s", 0.0)), - ) - self.enqueue_engine_finished_items(list(decode_phase.get("finished_items") or [])) - self.decode_runtime_owner.set_active_batch(result.get("active_batch")) - if result.get("executed", False): - self.decode_runtime_owner.refresh_state("engine_decode_cycle") - return bool(result.get("executed", False)) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_finalize.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_finalize.py new file mode 100644 index 00000000..8e66f76e --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_finalize.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import time +from typing import List + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob + + +class EngineFinalizeStageMixin: + def enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: + if not tasks: + return + for task in tasks: + job = self.get_engine_job(task.request_id) + if job is not None: + self.update_request_state( + job.engine_request_id, + EngineStatus.READY_FOR_FINALIZE, + { + "finish_reason": task.item.finish_reason, + "semantic_len": int(task.item.semantic_tokens.shape[0]), + "finish_idx": int(task.item.finish_idx), + }, + ) + self.finalize_queue_owner.enqueue_many(tasks) + self.notify_arbiter() + + def take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: + finalize_policy = self.scheduler_worker.get_finalize_batch_policy() + return self.finalize_queue_owner.take_finalize_batch( + finalize_mode=str(finalize_policy.get("finalize_mode", "async")), + batch_max_items=int(finalize_policy.get("finalize_batch_max_items", 1)), + batch_wait_s=float(finalize_policy.get("finalize_batch_wait_s", 0.0)), + use_vocoder=bool(self.tts.configs.use_vocoder), + ) + + def run_engine_finalize_once(self) -> bool: + tasks = self.take_engine_finalize_batch_nonblocking() + if not tasks: + return False + self.scheduler_worker.begin_finalize_execution(len(tasks)) + try: + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + for task in tasks: + job = self.get_engine_job(task.request_id) + if job is None: + continue + jobs_and_items.append((job, task.item)) + if not jobs_and_items: + return False + now = time.perf_counter() + for task in tasks: + job = self.get_engine_job(task.request_id) + if job is not None: + job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0) + for job, item in jobs_and_items: + self.update_request_state( + job.engine_request_id, + EngineStatus.FINALIZING, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + }, + ) + synth_ms, batch_results = self.scheduler_worker.synthesize_finalize_jobs(jobs_and_items) + for job, _ in jobs_and_items: + job.synth_ms += float(synth_ms) + for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): + self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) + except Exception as exc: + self.fail_engine_jobs([task.request_id for task in tasks], str(exc)) + finally: + self.scheduler_worker.end_finalize_execution(len(tasks)) + self.finalize_queue_owner.mark_completed(len(tasks), notify=True) + return True diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_futures.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_futures.py new file mode 100644 index 00000000..43fdd0bf --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_futures.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import asyncio +from typing import Callable + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDispatchTask, EngineGpuPrepareTask + + +class EngineStageFutureMixin: + def bind_notify_arbiter(self, notify_arbiter: Callable[[], None]) -> None: + self._notify_arbiter = notify_arbiter + + def notify_arbiter(self) -> None: + if self._notify_arbiter is not None: + self._notify_arbiter() + + @staticmethod + def _resolve_dispatch_error_future(future: asyncio.Future, error: Exception) -> None: + if future.done(): + return + future.set_exception(error) + + @staticmethod + def _resolve_prepare_future( + future: asyncio.Future, + payload: tuple[T2SRequestState, float, float], + ) -> None: + if future.done(): + return + future.set_result(payload) + + def _notify_dispatch_error(self, task: EngineDispatchTask, error: Exception) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) + except RuntimeError: + pass + + def _notify_prepare_error(self, task: EngineGpuPrepareTask, error: Exception) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) + except RuntimeError: + pass + + def _notify_prepare_result( + self, + task: EngineGpuPrepareTask, + payload: tuple[T2SRequestState, float, float], + ) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_prepare_future, task.done_future, payload) + except RuntimeError: + pass diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py new file mode 100644 index 00000000..bb3e8b06 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import asyncio +import time +from typing import Any + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineGpuPrepareTask, EngineStatus + + +class EnginePrepareStageMixin: + async def prepare_state_via_engine_gpu_queue( + self, + *, + spec: Any, + prepare_submit_at: float, + engine_request_id: str | None, + ) -> tuple[T2SRequestState, float, float]: + cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + if engine_request_id not in [None, ""]: + self.update_request_state( + str(engine_request_id), + EngineStatus.GPU_PREPARING, + { + "prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms), + "prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms), + "text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms), + "text_cpu_run_ms": float(cpu_stage.target_cpu_profiled.run_ms), + }, + ) + loop = asyncio.get_running_loop() + done_future = loop.create_future() + task = EngineGpuPrepareTask( + request_id=spec.request_id, + cpu_stage=cpu_stage, + done_loop=loop, + done_future=done_future, + engine_request_id=engine_request_id or spec.request_id, + enqueue_time=time.perf_counter(), + ) + self.prepare_queue_owner.enqueue(task) + self.notify_arbiter() + return await done_future + + def run_engine_prepare_once(self) -> bool: + task = self.prepare_queue_owner.pop_left() + if task is None: + return False + queue_wait_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0) + try: + state, prepare_exec_started_at, prepare_exec_finished_at = asyncio.run( + self.scheduler_worker.prepare_gpu_stage_profiled_async(task.cpu_stage) + ) + state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms) + if task.engine_request_id not in [None, ""]: + self.merge_request_state_profile( + str(task.engine_request_id), + {"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms)}, + ) + self.prepare_queue_owner.mark_completed(1) + self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at)) + return True + except Exception as exc: + task.error = str(exc) + self.fail_request_state(task.engine_request_id or task.request_id, str(exc)) + self._notify_prepare_error(task, exc) + return True From 6a822b28c3c54d86d4a984d0a774b2177164230b Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Thu, 12 Mar 2026 01:27:19 +0800 Subject: [PATCH 19/24] Enhance TTS API with improved request handling and asynchronous processing Refactor api_v2.py and api_v3.py to update sampling parameters and weight paths for better clarity and support for v3/v4 vocoders. Introduce new methods in PrepareCoordinator for handling empty text features and improve profiling capabilities. Additionally, update unified engine components to streamline audio processing and state management, enhancing overall performance and maintainability of the TTS system. --- GPT_SoVITS/TTS_infer_pack/TTS.py | 41 ++-- .../TTS_infer_pack/prepare_coordinator.py | 85 ++++++- GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py | 43 +++- .../TTS_infer_pack/unified_engine_api.py | 2 + .../unified_engine_api_direct.py | 232 +++++++++++++++--- .../unified_engine_api_request.py | 10 - .../unified_engine_api_scheduler.py | 151 ++++++++---- .../unified_engine_bridge_delegates.py | 2 + .../unified_engine_bridge_stage.py | 2 + .../unified_engine_component_models.py | 1 + .../unified_engine_component_registry.py | 1 + .../unified_engine_component_runtime.py | 1 + .../TTS_infer_pack/unified_engine_stage.py | 2 + .../unified_engine_stage_dispatch.py | 3 + .../TTS_infer_pack/unified_engine_worker.py | 2 +- .../unified_engine_worker_finalize.py | 41 +++- .../unified_engine_worker_submit.py | 4 + api_v2.py | 10 +- api_v3.py | 10 +- 19 files changed, 494 insertions(+), 149 deletions(-) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index c7ae465c..bd811d8a 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -323,7 +323,7 @@ class TTS_Config: assert isinstance(configs, dict) configs_ = deepcopy(self.default_configs) configs_.update(configs) - self.configs: dict = configs_.get("custom", configs_["v2"]) + self.configs: dict = configs_.get("custom", configs_["v2ProPlus"]) self.default_configs = deepcopy(configs_) self.device = self.configs.get("device", torch.device("cpu")) @@ -1872,19 +1872,19 @@ class TTS: self.init_sr_model() if not self.sr_model_not_exist: audio, sr = self.sr_model(audio.unsqueeze(0), sr) - max_audio = np.abs(audio).max() + if isinstance(audio, torch.Tensor): + max_audio = float(torch.abs(audio).max().item()) + else: + max_audio = float(np.abs(audio).max()) if max_audio > 1: audio /= max_audio - audio = (audio * 32768).astype(np.int16) t2 = time.perf_counter() print(f"超采样用时:{t2 - t1:.3f}s") + if isinstance(audio, torch.Tensor): + audio = audio.detach().float().cpu().numpy() else: - # audio = audio.float() * 32768 - # audio = audio.to(dtype=torch.int16).clamp(-32768, 32767).cpu().numpy() - - audio = audio.cpu().numpy() - - audio = (audio * 32768).astype(np.int16) + audio = np.asarray(audio) + audio = (audio.reshape(-1) * 32768).astype(np.int16) # try: @@ -2036,20 +2036,23 @@ class TTS: phones: torch.Tensor, prompt_semantic: torch.Tensor, prompt_phones: torch.Tensor, - refer_spec: tuple, + refer_spec: tuple | List[tuple], raw_audio: torch.Tensor, raw_sr: int, speed: float = 1.0, sample_steps: int = 32, ): - refer_audio_spec, audio_tensor = refer_spec + refer_specs = list(refer_spec) if isinstance(refer_spec, list) else [refer_spec] + refer_audio_spec, audio_tensor = refer_specs[0] if not self.configs.use_vocoder: - refer_audio_spec_list = [refer_audio_spec.to(dtype=self.precision, device=self.configs.device)] + refer_audio_spec_list = [item[0].to(dtype=self.precision, device=self.configs.device) for item in refer_specs] sv_emb = None if self.is_v2pro: - if audio_tensor is None: - raise ValueError(i18n("v2Pro request-local synthesis 缺少 16k 参考音频")) - sv_emb = self.sv_model.compute_embedding3(audio_tensor).to(self.configs.device) + sv_emb = [] + for _, audio_tensor_item in refer_specs: + if audio_tensor_item is None: + raise ValueError(i18n("v2Pro request-local synthesis 缺少 16k 参考音频")) + sv_emb.append(self.sv_model.compute_embedding3(audio_tensor_item).to(self.configs.device)) return self.vits_model.decode( semantic_tokens, phones, @@ -2075,7 +2078,7 @@ class TTS: self, semantic_tokens_list: List[torch.Tensor], phones_list: List[torch.Tensor], - refer_specs: List[tuple], + refer_specs: List[tuple | List[tuple]], speeds: List[float] | None = None, sample_steps_list: List[int] | None = None, ) -> List[torch.Tensor]: @@ -2118,7 +2121,11 @@ class TTS: semantic_lengths.append(semantic_len) phone_lengths.append(phone_len) - refer_audio_spec, audio_tensor = refer_specs[batch_index] + refer_spec_item = refer_specs[batch_index] + refer_spec_group = list(refer_spec_item) if isinstance(refer_spec_item, list) else [refer_spec_item] + if len(refer_spec_group) != 1: + raise ValueError("batched request-local synthesis 暂不支持单请求多参考音频") + refer_audio_spec, audio_tensor = refer_spec_group[0] refer_audio_specs.append(refer_audio_spec.to(dtype=self.precision, device=device)) if self.is_v2pro: if audio_tensor is None: diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py index 306b1b45..06a5e1b8 100644 --- a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py +++ b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py @@ -12,6 +12,7 @@ from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( PreparedTextFeatures, SchedulerRequestSpec, T2SRequestState, + build_empty_text_features, build_request_state_from_parts, normalize_sentence, ) @@ -118,6 +119,21 @@ class PrepareCoordinator: def _prepare_text_cpu(self, text: str, language: str): return self.tts.prepare_text_segments(text, language) + @staticmethod + def _build_empty_text_features_like(reference: PreparedTextFeatures | None = None) -> PreparedTextFeatures: + feature_dim = 1024 + dtype = None + if reference is not None: + try: + feature_dim = int(reference.bert_features.shape[0]) + dtype = reference.bert_features.dtype + except Exception: + pass + return build_empty_text_features( + feature_dim=int(feature_dim), + dtype=(dtype if dtype is not None else None) or __import__("torch").float32, + ) + def _build_text_features(self, prepared_segments, language: str, cpu_run_ms: float) -> PreparedTextFeatures: profile: Dict[str, float] = {"cpu_preprocess_ms": float(cpu_run_ms)} branch_start = time.perf_counter() @@ -139,6 +155,9 @@ class PrepareCoordinator: return await loop.run_in_executor(executor, self._run_profiled, fn, float(submit_at), *args) async def _run_text_cpu_stage(self, text: str, language: str) -> ProfiledResult: + if text in [None, ""]: + submit_at = time.perf_counter() + return ProfiledResult(result=[], submit_at=submit_at, started_at=submit_at, finished_at=submit_at) executor = getattr(self.tts, "prepare_text_cpu_executor", None) if executor is None: submit_at = time.perf_counter() @@ -164,19 +183,71 @@ class PrepareCoordinator: prompt_cpu_run_ms: float, target_cpu_run_ms: float, ) -> tuple[ProfiledResult, ProfiledResult]: + prompt_is_empty = len(prompt_segments or []) == 0 if self.text_feature_executor is not None: - prompt_feature_task = asyncio.create_task( - self._run_text_feature_stage(prompt_segments, None, prompt_cpu_run_ms) + target_feature_task = asyncio.create_task(self._run_text_feature_stage(target_segments, None, target_cpu_run_ms)) + if not prompt_is_empty: + prompt_feature_task = asyncio.create_task(self._run_text_feature_stage(prompt_segments, None, prompt_cpu_run_ms)) + return await asyncio.gather(prompt_feature_task, target_feature_task) + target_profiled = await target_feature_task + submit_at = time.perf_counter() + prompt_profiled = ProfiledResult( + result=self._build_empty_text_features_like(target_profiled.result), + submit_at=float(submit_at), + started_at=float(submit_at), + finished_at=float(submit_at), ) - target_feature_task = asyncio.create_task( - self._run_text_feature_stage(target_segments, None, target_cpu_run_ms) - ) - return await asyncio.gather(prompt_feature_task, target_feature_task) + return prompt_profiled, target_profiled - prompt_profile: Dict[str, float] = {"cpu_preprocess_ms": float(prompt_cpu_run_ms)} target_profile: Dict[str, float] = {"cpu_preprocess_ms": float(target_cpu_run_ms)} submit_at = time.perf_counter() started_at = float(submit_at) + if prompt_is_empty: + target_result_raw = await self.tts.build_text_features_from_segments_async( + target_segments, + profile=target_profile, + ) + prompt_result = self._build_empty_text_features_like( + PreparedTextFeatures( + phones=target_result_raw[0], + bert_features=target_result_raw[1], + norm_text=target_result_raw[2], + profile=target_profile, + total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)), + cpu_preprocess_ms=float(target_cpu_run_ms), + ) + ) + finished_at = time.perf_counter() + prompt_profiled = ProfiledResult( + result=prompt_result, + submit_at=float(submit_at), + started_at=float(submit_at), + finished_at=float(submit_at), + ) + target_result = PreparedTextFeatures( + phones=target_result_raw[0], + bert_features=target_result_raw[1], + norm_text=target_result_raw[2], + profile=target_profile, + total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)), + cpu_preprocess_ms=float(target_cpu_run_ms), + ) + target_profiled = ProfiledResult( + result=target_result, + submit_at=float(submit_at), + started_at=started_at, + finished_at=float(submit_at + self._estimate_text_feature_run_ms(target_profile) / 1000.0), + ) + if finished_at > target_profiled.finished_at: + target_result.profile["bert_total_ms"] = max( + self._estimate_text_feature_run_ms(target_profile), + (finished_at - submit_at) * 1000.0, + ) + else: + target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile) + return prompt_profiled, target_profiled + + prompt_profile: Dict[str, float] = {"cpu_preprocess_ms": float(prompt_cpu_run_ms)} prompt_result_raw, target_result_raw = await self.tts.build_text_feature_pair_from_segments_async( prompt_segments, target_segments, diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py index 8aabd286..78dbfc36 100644 --- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -1,6 +1,7 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field +import os from pathlib import Path import time from typing import Any, Dict, List, Optional, Sequence, Tuple @@ -35,6 +36,7 @@ class SchedulerRequestSpec: temperature: float repetition_penalty: float early_stop_num: int + aux_ref_audio_paths: List[str] = field(default_factory=list) ready_step: int = 0 @@ -54,6 +56,7 @@ class T2SRequestState: all_bert_features: torch.Tensor prompt_semantic: torch.LongTensor refer_spec: Tuple[torch.Tensor, Optional[torch.Tensor]] + aux_refer_specs: List[Tuple[torch.Tensor, Optional[torch.Tensor]]] raw_audio: torch.Tensor raw_sr: int top_k: int @@ -113,6 +116,21 @@ class PreparedTextFeatures: cpu_preprocess_ms: float +def build_empty_text_features( + *, + feature_dim: int = 1024, + dtype: torch.dtype = torch.float32, +) -> PreparedTextFeatures: + return PreparedTextFeatures( + phones=[], + bert_features=torch.empty((int(feature_dim), 0), dtype=dtype), + norm_text="", + profile={"cpu_preprocess_ms": 0.0, "bert_total_ms": 0.0}, + total_ms=0.0, + cpu_preprocess_ms=0.0, + ) + + def normalize_sentence(text: str, language: str) -> str: text = text.strip("\n").strip() if not text: @@ -171,6 +189,14 @@ def build_request_state_from_parts( bundle_profile = ref_audio_bundle.get("profile", {}) prompt_semantic = ref_audio_bundle["prompt_semantic"].long() spec_audio, audio_16k = ref_audio_bundle["refer_spec"] + aux_refer_specs: List[Tuple[torch.Tensor, Optional[torch.Tensor]]] = [] + for aux_ref_audio_path in list(getattr(spec, "aux_ref_audio_paths", []) or []): + if aux_ref_audio_path in [None, ""]: + continue + if not os.path.exists(str(aux_ref_audio_path)): + continue + aux_spec_audio, aux_audio_16k, _, _ = tts.extract_ref_spec(str(aux_ref_audio_path)) + aux_refer_specs.append((aux_spec_audio, aux_audio_16k)) raw_audio = ref_audio_bundle["raw_audio"] raw_sr = int(ref_audio_bundle["raw_sr"]) prompt_semantic_ms = float(bundle_profile.get("prompt_semantic_ms", ref_audio_bundle_ms)) @@ -182,9 +208,9 @@ def build_request_state_from_parts( phones_tensor = torch.LongTensor(target_result.phones).to(tts.configs.device) prompt_phones_tensor = torch.LongTensor(prompt_result.phones).to(tts.configs.device) all_phones = torch.LongTensor(prompt_result.phones + target_result.phones).to(tts.configs.device) - all_bert_features = torch.cat([prompt_result.bert_features, target_result.bert_features], dim=1).to( - dtype=tts.precision, device=tts.configs.device - ) + prompt_bert_features = prompt_result.bert_features.to(dtype=tts.precision, device=tts.configs.device) + target_bert_features = target_result.bert_features.to(dtype=tts.precision, device=tts.configs.device) + all_bert_features = torch.cat([prompt_bert_features, target_bert_features], dim=1) _sync_device(device) tensorize_ms = (time.perf_counter() - tensorize_start) * 1000.0 @@ -280,6 +306,7 @@ def build_request_state_from_parts( all_bert_features=all_bert_features, prompt_semantic=prompt_semantic, refer_spec=(spec_audio, audio_16k), + aux_refer_specs=aux_refer_specs, raw_audio=raw_audio, raw_sr=raw_sr, top_k=spec.top_k, @@ -301,10 +328,16 @@ def prepare_request_state( prepare_sync_start = time.perf_counter() prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) text = spec.text.strip("\n") - prompt_result = prepare_text_features(tts, prompt_text, spec.prompt_lang) target_result = prepare_text_features(tts, text, spec.text_lang) if target_result.phones is None: raise ValueError(f"{spec.request_id} text preprocessing returned no phones") + if prompt_text in [None, ""]: + prompt_result = build_empty_text_features( + feature_dim=int(target_result.bert_features.shape[0]), + dtype=target_result.bert_features.dtype, + ) + else: + prompt_result = prepare_text_features(tts, prompt_text, spec.prompt_lang) ref_audio_bundle = tts.extract_ref_audio_bundle(str(spec.ref_audio_path)) return build_request_state_from_parts( tts=tts, diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py index ca372d5d..0895cef4 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py @@ -119,6 +119,7 @@ class EngineApiFacade: speed_factor: float, sample_steps: int, media_type: str, + super_sampling: bool, prepare_wall_ms: float, prepare_profile_total_ms: float, done_loop: asyncio.AbstractEventLoop | None, @@ -131,6 +132,7 @@ class EngineApiFacade: speed_factor=speed_factor, sample_steps=sample_steps, media_type=media_type, + super_sampling=super_sampling, prepare_wall_ms=prepare_wall_ms, prepare_profile_total_ms=prepare_profile_total_ms, done_loop=done_loop, diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api_direct.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_direct.py index b2a308df..f55da45d 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_api_direct.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_direct.py @@ -1,6 +1,8 @@ from __future__ import annotations import asyncio +import queue +import threading import time import uuid from io import BytesIO @@ -122,6 +124,173 @@ class EngineApiDirectFlow: payload["response_streaming"] = False return self.api._normalize_engine_request(payload, error_prefix="segment request 参数非法: ") + async def _execute_single_segment_scheduler_job( + self, + normalized: NormalizedEngineRequest, + *, + segment_request: NormalizedEngineRequest, + ) -> tuple[SchedulerPendingJob, Dict[str, Any]]: + spec = self.api._build_scheduler_submit_spec(segment_request) + state, prepare_exec_started_at, prepare_exec_finished_at = await self.api._prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=time.perf_counter(), + engine_request_id=None, + ) + prepare_wall_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0) + prepare_profile_total_ms = float(state.prepare_profile.get("wall_total_ms", prepare_wall_ms)) + loop = asyncio.get_running_loop() + done_future = loop.create_future() + await self.api._enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=float(normalized.speed_factor), + sample_steps=int(normalized.sample_steps), + media_type=normalized.media_type, + super_sampling=bool(normalized.super_sampling), + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=loop, + done_future=done_future, + engine_request_id=None, + timeout_sec=normalized.timeout_sec, + ) + timeout_sec = float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0) + job: SchedulerPendingJob = await asyncio.wait_for(done_future, timeout=timeout_sec) + return job, { + "request_id": spec.request_id, + "prepare_wall_ms": prepare_wall_ms, + "prepare_profile_total_ms": prepare_profile_total_ms, + "prepare_profile": dict(state.prepare_profile), + } + + def _iter_scheduler_direct_tts_bytes(self, normalized: NormalizedEngineRequest) -> Generator[bytes, None, None]: + request_start = time.perf_counter() + request_id = normalized.request_id + media_type = normalized.media_type + segment_texts = self._segment_direct_text(normalized) + if not segment_texts: + raise ValueError("text preprocessing returned no valid segments") + chunk_queue: queue.Queue[object] = queue.Queue(maxsize=8) + done_marker = object() + + async def _produce_chunks() -> None: + self.api._update_request_state( + request_id, + EngineStatus.CPU_PREPARING, + {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_texts)}, + ) + sample_rate: int | None = None + current_media_type = media_type + chunk_count = 0 + stream_total_bytes = 0 + first_chunk_ms: float | None = None + prepare_profiles: List[Dict[str, Any]] = [] + worker_profiles: List[Dict[str, Any]] = [] + try: + for segment_index, segment_text in enumerate(segment_texts): + segment_request = self._build_segment_request( + normalized, + request_id=f"{request_id}_seg_{segment_index:03d}", + text=segment_text, + ) + self.api._update_request_state( + request_id, + EngineStatus.READY_FOR_PREFILL, + { + "backend": "scheduler_v1_direct", + "backend_mode": "scheduler_v1_direct", + "segment_index": segment_index, + "segment_count": len(segment_texts), + }, + ) + job, prepare_profile = await self._execute_single_segment_scheduler_job( + normalized, + segment_request=segment_request, + ) + prepare_profiles.append(prepare_profile) + if job.error is not None: + raise RuntimeError(job.error) + if job.audio_data is None or job.sample_rate is None or job.result is None: + raise RuntimeError(f"{job.request_id} finished without audio result") + worker_profiles.append(dict(job.result)) + if sample_rate is None: + sample_rate = int(job.sample_rate) + first_chunk_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0) + self.api._update_request_state( + request_id, + EngineStatus.STREAMING, + { + "backend": "scheduler_v1_direct", + "backend_mode": "scheduler_v1_direct", + "sample_rate": int(sample_rate), + }, + ) + if media_type == "wav": + header = wave_header_chunk(sample_rate=int(sample_rate)) + chunk_count += 1 + stream_total_bytes += len(header) + chunk_queue.put(header) + current_media_type = "raw" + packed_chunk = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), current_media_type).getvalue() + chunk_count += 1 + stream_total_bytes += len(packed_chunk) + chunk_queue.put(packed_chunk) + if segment_index + 1 < len(segment_texts): + silence_samples = int(float(normalized.fragment_interval) * float(job.sample_rate)) + if silence_samples > 0: + silence_chunk = np.zeros(silence_samples, dtype=np.int16) + packed_silence = pack_audio( + BytesIO(), silence_chunk, int(job.sample_rate), current_media_type + ).getvalue() + chunk_count += 1 + stream_total_bytes += len(packed_silence) + chunk_queue.put(packed_silence) + except Exception as exc: + self.api._fail_request_state(request_id, str(exc)) + chunk_queue.put(exc) + else: + self.api._merge_request_state_profile( + request_id, + { + "prepare_aggregate": self.api._aggregate_numeric_dicts( + [item["prepare_profile"] for item in prepare_profiles] + ), + "engine_policy_wait_ms": sum( + float(item.get("engine_policy_wait_ms", 0.0)) for item in worker_profiles + ), + "engine_dispatch_wait_ms": sum( + float(item.get("engine_dispatch_wait_ms", 0.0)) for item in worker_profiles + ), + }, + ) + direct_profile = self.api._build_direct_scheduler_profile( + backend="scheduler_v1_direct", + request_start=request_start, + response_ready_at=time.perf_counter(), + audio_bytes=stream_total_bytes, + sample_rate=int(sample_rate or 0), + segment_texts=segment_texts, + prepare_profiles=prepare_profiles, + worker_profiles=worker_profiles, + pack_ms=0.0, + response_overhead_ms=0.0, + ) + self.api._complete_request_state( + request_id, + dict(direct_profile, streaming_completed=True, first_chunk_ms=first_chunk_ms), + ) + finally: + chunk_queue.put(done_marker) + + producer_thread = threading.Thread(target=lambda: asyncio.run(_produce_chunks()), daemon=True) + producer_thread.start() + while True: + item = chunk_queue.get() + if item is done_marker: + break + if isinstance(item, Exception): + raise item + yield item + async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution: request_start = time.perf_counter() request_id = normalized.request_id @@ -129,63 +298,48 @@ class EngineApiDirectFlow: segment_texts = self._segment_direct_text(normalized) if not segment_texts: raise ValueError("text preprocessing returned no valid segments") + if normalized.response_streaming: + return DirectTTSExecution( + media_type=media_type, + streaming=True, + audio_generator=self._iter_scheduler_direct_tts_bytes(normalized), + request_id=request_id, + ) self.api._update_request_state( request_id, EngineStatus.CPU_PREPARING, {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_texts)}, ) - segment_specs = [] - for segment_index, segment_text in enumerate(segment_texts): - segment_request = self._build_segment_request( + segment_requests = [ + self._build_segment_request( normalized, request_id=f"{request_id}_seg_{segment_index:03d}", text=segment_text, ) - segment_specs.append(self.api._build_scheduler_submit_spec(segment_request)) - - prepared_items = await asyncio.gather( - *[ - self.api._prepare_state_via_engine_gpu_queue( - spec=spec, - prepare_submit_at=time.perf_counter(), - engine_request_id=None, - ) - for spec in segment_specs - ] - ) + for segment_index, segment_text in enumerate(segment_texts) + ] prepare_profiles: List[Dict[str, Any]] = [] loop = asyncio.get_running_loop() done_futures: List[asyncio.Future] = [] self.api._update_request_state( request_id, EngineStatus.READY_FOR_PREFILL, - {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_specs)}, + {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_requests)}, ) - for spec, (state, prepare_exec_started_at, prepare_exec_finished_at) in zip(segment_specs, prepared_items): - prepare_wall_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0) - prepare_profile_total_ms = float(state.prepare_profile.get("wall_total_ms", prepare_wall_ms)) - prepare_profiles.append( - { - "request_id": spec.request_id, - "prepare_wall_ms": prepare_wall_ms, - "prepare_profile_total_ms": prepare_profile_total_ms, - "prepare_profile": dict(state.prepare_profile), - } - ) + prepared_items = await asyncio.gather( + *[ + self._execute_single_segment_scheduler_job( + normalized, + segment_request=segment_request, + ) + for segment_request in segment_requests + ] + ) + for job, prepare_profile in prepared_items: + prepare_profiles.append(prepare_profile) done_future = loop.create_future() + done_future.set_result(job) done_futures.append(done_future) - await self.api._enqueue_prepared_state_for_dispatch( - state=state, - speed_factor=float(normalized.speed_factor), - sample_steps=int(normalized.sample_steps), - media_type=media_type, - prepare_wall_ms=prepare_wall_ms, - prepare_profile_total_ms=prepare_profile_total_ms, - done_loop=loop, - done_future=done_future, - engine_request_id=None, - timeout_sec=normalized.timeout_sec, - ) self.api._update_request_state( request_id, EngineStatus.ACTIVE_DECODE, diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api_request.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_request.py index 974b9612..14d59f12 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_api_request.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_request.py @@ -122,16 +122,6 @@ def is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool: def select_direct_backend(normalized: NormalizedEngineRequest) -> Tuple[str, str | None]: - if normalized.response_streaming: - if normalized.return_fragment or normalized.fixed_length_chunk: - return "legacy_direct_fragment", "fragment_streaming_mode" - return "legacy_direct_streaming", "streaming_mode" - if is_aux_ref_enabled(normalized.aux_ref_audio_paths): - return "legacy_direct_aux_ref", "aux_ref_audio_paths" - if normalized.super_sampling: - return "legacy_direct_super_sampling", "super_sampling" - if normalized.prompt_text in [None, ""]: - return "legacy_direct_missing_prompt", "missing_prompt_text" return "scheduler_v1_direct", None diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api_scheduler.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_scheduler.py index 1e934f16..cf6677fb 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_api_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_scheduler.py @@ -6,7 +6,7 @@ import uuid from io import BytesIO from typing import Any, Dict, List -from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SFinishedItem, T2SRequestState, run_scheduler_continuous +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SFinishedItem, T2SRequestState from GPT_SoVITS.TTS_infer_pack.unified_engine_audio import pack_audio, set_scheduler_seed from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, NormalizedEngineRequest, SchedulerDebugExecution, SchedulerSubmitExecution @@ -67,39 +67,58 @@ class EngineApiSchedulerFlow: async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution: request_start = time.perf_counter() set_scheduler_seed(seed) - specs = self._build_scheduler_request_specs(request_items) - request_ids = [spec.request_id for spec in specs] - for spec in specs: + normalized_requests: List[NormalizedEngineRequest] = [] + for index, payload in enumerate(request_items): + normalized_requests.append( + self.api._normalize_engine_request( + payload, + request_id=str(payload.get("request_id") or f"req_{index:03d}"), + error_prefix=f"request[{index}] 参数非法: ", + ) + ) + specs = [normalized.to_scheduler_spec() for normalized in normalized_requests] + request_ids = [normalized.request_id for normalized in normalized_requests] + for normalized, spec in zip(normalized_requests, specs): self.api._register_request_state( - request_id=spec.request_id, + request_id=normalized.request_id, api_mode="scheduler_debug", backend="scheduler_debug", - media_type="wav", + media_type=normalized.media_type, response_streaming=False, - meta={ - "text_len": len(spec.text), - "prompt_text_len": len(spec.prompt_text), - "text_lang": spec.text_lang, - "prompt_lang": spec.prompt_lang, - "ref_audio_path": str(spec.ref_audio_path), - "ready_step": int(spec.ready_step), - }, + meta=self.api._build_request_meta(normalized.to_payload()), ) - self.api._update_request_state(spec.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_debug"}) - self.api._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, None) + self.api._update_request_state(normalized.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_debug"}) + self.api._update_request_state(normalized.request_id, EngineStatus.CPU_PREPARING, None) prepare_started_at = time.perf_counter() + original_worker_max_steps = int(self.api.scheduler_worker.max_steps) + original_decode_max_steps = int(self.api.scheduler_worker.decode_executor.max_steps) try: - states = await self.api.scheduler_worker.prepare_states_batch_async(specs) + self.api.scheduler_worker.max_steps = int(max_steps) + self.api.scheduler_worker.decode_executor.max_steps = int(max_steps) + prepared_payloads = await asyncio.gather( + *[ + self.api._prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=time.perf_counter(), + engine_request_id=normalized.request_id, + ) + for normalized, spec in zip(normalized_requests, specs) + ] + ) except Exception as exc: for request_id in request_ids: self.api._fail_request_state(request_id, str(exc)) raise + finally: + self.api.scheduler_worker.max_steps = int(original_worker_max_steps) + self.api.scheduler_worker.decode_executor.max_steps = int(original_decode_max_steps) prepare_finished_at = time.perf_counter() prepare_batch_wall_ms = max(0.0, (prepare_finished_at - prepare_started_at) * 1000.0) + states = [payload[0] for payload in prepared_payloads] for state in states: self.api._update_request_state( state.request_id, - EngineStatus.ACTIVE_DECODE, + EngineStatus.READY_FOR_PREFILL, { "prepare_profile": dict(state.prepare_profile), "norm_text": state.norm_text, @@ -108,7 +127,27 @@ class EngineApiSchedulerFlow: ) decode_started_at = time.perf_counter() try: - finished = run_scheduler_continuous(self.api.tts.t2s_model.model, states, max_steps=int(max_steps)) + loop = asyncio.get_running_loop() + done_futures: List[asyncio.Future] = [] + for normalized, state in zip(normalized_requests, states): + done_future = loop.create_future() + done_futures.append(done_future) + await self.api._enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=float(normalized.speed_factor), + sample_steps=int(normalized.sample_steps), + media_type=normalized.media_type, + super_sampling=bool(normalized.super_sampling), + prepare_wall_ms=float(state.prepare_profile.get("wall_total_ms", 0.0)), + prepare_profile_total_ms=float(state.prepare_profile.get("wall_total_ms", 0.0)), + done_loop=loop, + done_future=done_future, + engine_request_id=normalized.request_id, + timeout_sec=normalized.timeout_sec, + ) + timeout_candidates = [float(item.timeout_sec) for item in normalized_requests if item.timeout_sec not in [None, ""]] + timeout_sec = max(timeout_candidates) if timeout_candidates else 60.0 + jobs = list(await asyncio.wait_for(asyncio.gather(*done_futures), timeout=float(timeout_sec))) except Exception as exc: for request_id in request_ids: self.api._fail_request_state(request_id, str(exc)) @@ -116,46 +155,63 @@ class EngineApiSchedulerFlow: decode_finished_at = time.perf_counter() decode_batch_wall_ms = max(0.0, (decode_finished_at - decode_started_at) * 1000.0) request_total_ms = max(0.0, (decode_finished_at - request_start) * 1000.0) - finished_map = {item.request_id: item for item in finished} request_profiles: List[Dict[str, Any]] = [] - for state in states: - item = finished_map.get(state.request_id) - if item is None: + finished: List[Dict[str, Any]] = [] + finish_reason_counts: Dict[str, int] = {} + total_semantic_len = 0 + for state, job in zip(states, jobs): + if job.error is not None: + self.api._fail_request_state(state.request_id, str(job.error)) + raise RuntimeError(str(job.error)) + if job.result is None: self.api._fail_request_state(state.request_id, "scheduler_debug finished without result") - continue - request_profile = self.api._build_scheduler_debug_request_profile( - state=state, - item=item, - batch_request_count=len(states), - prepare_batch_wall_ms=prepare_batch_wall_ms, - decode_batch_wall_ms=decode_batch_wall_ms, - batch_request_total_ms=request_total_ms, - ) - request_profiles.append( + raise RuntimeError(f"{state.request_id} finished without result") + job_result = dict(job.result) + request_profile = { + **job_result, + "backend": "scheduler_debug", + "backend_mode": "scheduler_debug", + "batch_request_count": int(len(states)), + "batch_prepare_wall_ms": float(prepare_batch_wall_ms), + "batch_decode_wall_ms": float(decode_batch_wall_ms), + "batch_request_total_ms": float(request_total_ms), + "prepare_ms": float(state.prepare_profile.get("wall_total_ms", 0.0)), + "prepare_wall_ms": float(state.prepare_profile.get("wall_total_ms", 0.0)), + "prepare_profile_total_ms": float(state.prepare_profile.get("wall_total_ms", 0.0)), + "prepare_profile": dict(state.prepare_profile), + "norm_text": state.norm_text, + "norm_prompt_text": state.norm_prompt_text, + } + request_profiles.append({"request_id": state.request_id, "profile": dict(request_profile)}) + self.api._merge_request_state_profile(state.request_id, request_profile) + semantic_len = int(job_result.get("semantic_len", 0)) + finish_reason = str(job_result.get("finish_reason", "unknown")) + finished.append( { "request_id": state.request_id, - "profile": dict(request_profile), + "semantic_len": semantic_len, + "finish_idx": int(job_result.get("finish_idx", job_result.get("decode_steps", 0))), + "finish_reason": finish_reason, } ) - self.api._complete_request_state( - state.request_id, - dict(request_profile), - ) + finish_reason_counts[finish_reason] = finish_reason_counts.get(finish_reason, 0) + 1 + total_semantic_len += semantic_len return SchedulerDebugExecution( payload={ "message": "success", "request_count": len(states), "max_steps": int(max_steps), - "batch_profile": self.api._build_scheduler_debug_batch_profile( - request_count=len(states), - max_steps=int(max_steps), - prepare_batch_wall_ms=prepare_batch_wall_ms, - decode_batch_wall_ms=decode_batch_wall_ms, - request_total_ms=request_total_ms, - finished_items=finished, - ), + "batch_profile": { + "request_count": int(len(states)), + "max_steps": int(max_steps), + "prepare_batch_wall_ms": float(prepare_batch_wall_ms), + "decode_batch_wall_ms": float(decode_batch_wall_ms), + "request_total_ms": float(request_total_ms), + "total_semantic_len": int(total_semantic_len), + "finish_reason_counts": finish_reason_counts, + }, "requests": self._summarize_scheduler_states(states), - "finished": self._summarize_scheduler_finished(finished), + "finished": finished, "request_profiles": request_profiles, "request_traces": self.api._collect_request_summaries(request_ids), } @@ -222,6 +278,7 @@ class EngineApiSchedulerFlow: speed_factor=float(normalized.speed_factor), sample_steps=int(normalized.sample_steps), media_type=normalized.media_type, + super_sampling=bool(normalized.super_sampling), prepare_wall_ms=prepare_wall_ms, prepare_profile_total_ms=prepare_profile_total_ms, done_loop=loop, diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_delegates.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_delegates.py index 92714750..e2044ec4 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_delegates.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_delegates.py @@ -149,6 +149,7 @@ class EngineBridgeDelegates: speed_factor: float, sample_steps: int, media_type: str, + super_sampling: bool, prepare_wall_ms: float, prepare_profile_total_ms: float, done_loop: asyncio.AbstractEventLoop | None, @@ -161,6 +162,7 @@ class EngineBridgeDelegates: speed_factor=speed_factor, sample_steps=sample_steps, media_type=media_type, + super_sampling=super_sampling, prepare_wall_ms=prepare_wall_ms, prepare_profile_total_ms=prepare_profile_total_ms, done_loop=done_loop, diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_stage.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_stage.py index 29b5aaab..2a52e779 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_stage.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_stage.py @@ -78,6 +78,7 @@ class EngineStageBridgeFacade: speed_factor: float, sample_steps: int, media_type: str, + super_sampling: bool, prepare_wall_ms: float, prepare_profile_total_ms: float, done_loop: asyncio.AbstractEventLoop | None, @@ -90,6 +91,7 @@ class EngineStageBridgeFacade: speed_factor=speed_factor, sample_steps=sample_steps, media_type=media_type, + super_sampling=super_sampling, prepare_wall_ms=prepare_wall_ms, prepare_profile_total_ms=prepare_profile_total_ms, done_loop=done_loop, diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_models.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_models.py index 2c0cc9ac..7b5ea5f8 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_models.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_models.py @@ -104,6 +104,7 @@ class NormalizedEngineRequest: temperature=self.temperature, repetition_penalty=self.repetition_penalty, early_stop_num=self.early_stop_num, + aux_ref_audio_paths=list(self.aux_ref_audio_paths or []), ready_step=self.ready_step, ) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_registry.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_registry.py index 111ca500..1aaa89c1 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_registry.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_registry.py @@ -303,6 +303,7 @@ class SchedulerPendingJob: speed_factor: float sample_steps: int media_type: str + super_sampling: bool = False admission_wait_ms: float = 0.0 engine_policy_wait_ms: float = 0.0 engine_dispatch_wait_ms: float = 0.0 diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py index db03a0c3..7f4e485f 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py @@ -291,6 +291,7 @@ class EngineDispatchTask: speed_factor: float sample_steps: int media_type: str + super_sampling: bool prepare_wall_ms: float prepare_profile_total_ms: float done_loop: asyncio.AbstractEventLoop | None diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py index 9aad2fb8..1b872dfa 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py @@ -113,6 +113,7 @@ class EngineStageCoordinator: speed_factor: float, sample_steps: int, media_type: str, + super_sampling: bool, prepare_wall_ms: float, prepare_profile_total_ms: float, done_loop: asyncio.AbstractEventLoop | None, @@ -125,6 +126,7 @@ class EngineStageCoordinator: speed_factor=speed_factor, sample_steps=sample_steps, media_type=media_type, + super_sampling=super_sampling, prepare_wall_ms=prepare_wall_ms, prepare_profile_total_ms=prepare_profile_total_ms, done_loop=done_loop, diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py index 53ebd793..644c35f6 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py @@ -16,6 +16,7 @@ class EngineDispatchStageMixin: speed_factor: float, sample_steps: int, media_type: str, + super_sampling: bool, prepare_wall_ms: float, prepare_profile_total_ms: float, done_loop: asyncio.AbstractEventLoop | None, @@ -29,6 +30,7 @@ class EngineDispatchStageMixin: speed_factor=float(speed_factor), sample_steps=int(sample_steps), media_type=media_type, + super_sampling=bool(super_sampling), prepare_wall_ms=float(prepare_wall_ms), prepare_profile_total_ms=float(prepare_profile_total_ms), done_loop=done_loop, @@ -66,6 +68,7 @@ class EngineDispatchStageMixin: speed_factor=dispatch_task.speed_factor, sample_steps=dispatch_task.sample_steps, media_type=dispatch_task.media_type, + super_sampling=dispatch_task.super_sampling, prepare_wall_ms=dispatch_task.prepare_wall_ms, prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms, done_loop=dispatch_task.done_loop, diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py index 934ccf52..ae46536f 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py @@ -46,7 +46,7 @@ class UnifiedSchedulerWorker( self.decode_backlog_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_DECODE_BACKLOG_MAX", "0"))) self.finalize_pending_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_FINALIZE_PENDING_MAX", "0"))) self.engine_decode_control_enabled = ( - str(os.environ.get("GPTSOVITS_ENGINE_DRIVE_DECODE", "0")).strip().lower() in {"1", "true", "yes", "on"} + str(os.environ.get("GPTSOVITS_ENGINE_DRIVE_DECODE", "1")).strip().lower() in {"1", "true", "yes", "on"} ) self.job_registry = SchedulerJobRegistry(self.condition) self.worker_thread: threading.Thread | None = None diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py index 4f5833fd..3a675cbe 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py @@ -149,16 +149,25 @@ class WorkerFinalizeExecutor: except Exception: pass + @staticmethod + def _collect_job_refer_specs(job: SchedulerPendingJob) -> List[tuple]: + refer_specs = [job.state.refer_spec] + refer_specs.extend(list(getattr(job.state, "aux_refer_specs", []) or [])) + return refer_specs + def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]: audio_fragment = self.tts.synthesize_audio_request_local( semantic_tokens=item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0), phones=job.state.phones.detach().clone().unsqueeze(0), prompt_semantic=job.state.prompt_semantic.detach().clone(), prompt_phones=job.state.prompt_phones.detach().clone(), - refer_spec=( - job.state.refer_spec[0].detach().clone(), - None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), - ), + refer_spec=[ + ( + refer_spec_item[0].detach().clone(), + None if refer_spec_item[1] is None else refer_spec_item[1].detach().clone(), + ) + for refer_spec_item in self._collect_job_refer_specs(job) + ], raw_audio=job.state.raw_audio.detach().clone(), raw_sr=int(job.state.raw_sr), speed=float(job.speed_factor), @@ -172,7 +181,7 @@ class WorkerFinalizeExecutor: speed_factor=float(job.speed_factor), split_bucket=False, fragment_interval=0.0, - super_sampling=False, + super_sampling=bool(job.super_sampling), ) def _synthesize_finished_audio_batch( @@ -185,11 +194,14 @@ class WorkerFinalizeExecutor: speeds = [] sample_steps_list = [] for job, _ in jobs_and_items: + refer_spec_group = self._collect_job_refer_specs(job) + if len(refer_spec_group) != 1: + raise ValueError("batched finalize 暂不支持单请求多参考音频") refer_specs.append( - ( - job.state.refer_spec[0].detach().clone(), - None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), - ) + [( + refer_spec_group[0][0].detach().clone(), + None if refer_spec_group[0][1] is None else refer_spec_group[0][1].detach().clone(), + )] ) speeds.append(float(job.speed_factor)) sample_steps_list.append(int(job.sample_steps)) @@ -211,7 +223,7 @@ class WorkerFinalizeExecutor: speed_factor=float(job.speed_factor), split_bucket=False, fragment_interval=0.0, - super_sampling=False, + super_sampling=bool(job.super_sampling), ) ) return results @@ -224,9 +236,12 @@ class WorkerFinalizeExecutor: return 0.0, [] self._sync_device() synth_start = time.perf_counter() - if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder: - job, item = jobs_and_items[0] - batch_results = [self._synthesize_finished_audio(job, item)] + if ( + len(jobs_and_items) == 1 + or self.tts.configs.use_vocoder + or any(len(self._collect_job_refer_specs(job)) != 1 for job, _ in jobs_and_items) + ): + batch_results = [self._synthesize_finished_audio(job, item) for job, item in jobs_and_items] else: batch_results = self._synthesize_finished_audio_batch(jobs_and_items) self._sync_device() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py index 1e67f8d3..f1910409 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py @@ -78,6 +78,7 @@ class WorkerSubmitLifecycleMixin: speed_factor: float, sample_steps: int, media_type: str, + super_sampling: bool, prepare_wall_ms: float, prepare_profile_total_ms: float, done_loop: asyncio.AbstractEventLoop | None = None, @@ -97,6 +98,7 @@ class WorkerSubmitLifecycleMixin: speed_factor, sample_steps, media_type, + super_sampling, prepare_wall_ms, prepare_profile_total_ms, done_loop, @@ -172,6 +174,7 @@ class WorkerSubmitLifecycleMixin: speed_factor: float, sample_steps: int, media_type: str, + super_sampling: bool, prepare_wall_ms: float, prepare_profile_total_ms: float, done_loop: asyncio.AbstractEventLoop | None = None, @@ -205,6 +208,7 @@ class WorkerSubmitLifecycleMixin: speed_factor=float(speed_factor), sample_steps=int(sample_steps), media_type=media_type, + super_sampling=bool(super_sampling), admission_wait_ms=float(admission_wait_ms), engine_policy_wait_ms=float(engine_policy_wait_ms), engine_dispatch_wait_ms=float(engine_dispatch_wait_ms), diff --git a/api_v2.py b/api_v2.py index 35b70c8e..9c29989f 100644 --- a/api_v2.py +++ b/api_v2.py @@ -39,8 +39,8 @@ POST: "seed": -1, # int. random seed for reproducibility. "parallel_infer": True, # bool. whether to use parallel inference. "repetition_penalty": 1.35, # float. repetition penalty for T2S model. - "sample_steps": 32, # int. number of sampling steps for VITS model V3. - "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "sample_steps": 32, # int. 仅 v3/v4 vocoder 路径使用;当前 v2/v2ProPlus 主线可忽略。 + "super_sampling": False, # bool. 仅 v3/v4 路径使用;不属于当前 v2/v2ProPlus 正式支持目标。 "streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed ) "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) @@ -79,7 +79,7 @@ endpoint: `/set_gpt_weights` GET: ``` -http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt +http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1v3.ckpt ``` RESP: 成功: 返回"success", http code 200 @@ -92,7 +92,7 @@ endpoint: `/set_sovits_weights` GET: ``` -http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/s2G488k.pth +http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth ``` RESP: @@ -211,7 +211,7 @@ async def tts_handle(req: dict): "parallel_infer": True, # bool. whether to use parallel inference. "repetition_penalty": 1.35, # float. repetition penalty for T2S model. "sample_steps": 32, # int. number of sampling steps for VITS model V3. - "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "super_sampling": False, # bool. only for v3/v4; not part of current v2/v2ProPlus mainline. "streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed ) "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) diff --git a/api_v3.py b/api_v3.py index 1a6457ec..35ecf240 100644 --- a/api_v3.py +++ b/api_v3.py @@ -39,8 +39,8 @@ POST: "seed": -1, # int. random seed for reproducibility. "parallel_infer": True, # bool. whether to use parallel inference. "repetition_penalty": 1.35, # float. repetition penalty for T2S model. - "sample_steps": 32, # int. number of sampling steps for VITS model V3. - "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "sample_steps": 32, # int. 仅 v3/v4 vocoder 路径使用;当前 v2/v2ProPlus 主线可忽略。 + "super_sampling": False, # bool. 仅 v3/v4 路径使用;不属于当前 v2/v2ProPlus 正式支持目标。 "streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed ) "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) @@ -79,7 +79,7 @@ endpoint: `/set_gpt_weights` GET: ``` -http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt +http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1v3.ckpt ``` RESP: 成功: 返回"success", http code 200 @@ -92,7 +92,7 @@ endpoint: `/set_sovits_weights` GET: ``` -http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/s2G488k.pth +http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth ``` RESP: @@ -280,7 +280,7 @@ async def tts_handle(req: dict): "parallel_infer": True, # bool. whether to use parallel inference. "repetition_penalty": 1.35, # float. repetition penalty for T2S model. "sample_steps": 32, # int. number of sampling steps for VITS model V3. - "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "super_sampling": False, # bool. only for v3/v4; not part of current v2/v2ProPlus mainline. "streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed ) "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) From 5cf68a91d3b24299bf05edeec3907e46c4c6c0b5 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Thu, 12 Mar 2026 23:03:33 +0800 Subject: [PATCH 20/24] Add g2pw submodule and enhance TTS processing with AsyncStageGate Introduce a new submodule for g2pw and implement AsyncStageGate in PrepareCoordinator to manage concurrent task inflight limits. Update PrepareTextCpuWorker and PrepareRefSemanticBatchWorker to support asynchronous task submission and completion notifications. Enhance profiling capabilities in TTS to track g2pw processing times, improving overall performance and maintainability of the TTS system. --- .gitmodules | 3 + GPT_SoVITS/TTS_infer_pack/TTS.py | 172 ++++++- GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py | 9 + .../TTS_infer_pack/prepare_coordinator.py | 423 ++++++++++++++---- .../prepare_ref_semantic_batch_worker.py | 42 +- .../TTS_infer_pack/prepare_text_cpu_worker.py | 215 +++++++++ GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py | 122 ++++- .../unified_engine_api_profile.py | 1 + .../unified_engine_component_runtime.py | 11 + .../unified_engine_stage_prepare.py | 59 ++- .../unified_engine_worker_prepare.py | 20 + .../unified_engine_worker_submit.py | 9 + third_party/g2pw-cu | 1 + 13 files changed, 965 insertions(+), 122 deletions(-) create mode 100644 .gitmodules create mode 100644 GPT_SoVITS/TTS_infer_pack/prepare_text_cpu_worker.py create mode 160000 third_party/g2pw-cu diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..570e9d7b --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/g2pw-cu"] + path = third_party/g2pw-cu + url = https://github.com/baicai-1145/g2pw-cu.git diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index bd811d8a..92c829a1 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1,4 +1,5 @@ import gc +import asyncio import concurrent.futures import math import os @@ -42,6 +43,7 @@ from TTS_infer_pack.prepare_ref_semantic_batch_worker import ( PrepareRefSemanticBatchWorker, prepare_prompt_semantic_wav16k, ) +from TTS_infer_pack.prepare_text_cpu_worker import PrepareTextCpuWorker from sv import SV resample_transform_dict = {} @@ -454,18 +456,12 @@ class TTS: self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4"))) self.prepare_bert_batch_worker = None self.prepare_ref_semantic_batch_worker = None + self.prepare_text_cpu_worker = None self.prepare_text_cpu_workers = max( 0, int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_WORKERS", "0")), ) - self.prepare_text_cpu_executor = ( - concurrent.futures.ThreadPoolExecutor( - max_workers=self.prepare_text_cpu_workers, - thread_name_prefix="prepare-text-cpu", - ) - if self.prepare_text_cpu_workers > 0 - else None - ) + self.prepare_text_cpu_executor = None self._init_models() self.refresh_runtime_components() @@ -488,6 +484,7 @@ class TTS: def refresh_runtime_components(self): self.prepare_bert_batch_worker = None self.prepare_ref_semantic_batch_worker = None + self.prepare_text_cpu_worker = None if os.environ.get("GPTSOVITS_PREPARE_BERT_BATCHING", "1") != "0": self.prepare_bert_batch_worker = PrepareBertBatchWorker( bert_model=self.bert_model, @@ -535,6 +532,92 @@ class TTS: bert_stage_limiter=self.prepare_bert_stage_limiter, bert_batch_worker=self.prepare_bert_batch_worker, ) + if self.prepare_text_cpu_workers > 0: + self.prepare_text_cpu_worker = PrepareTextCpuWorker( + process_fn=lambda text, language: self.text_preprocessor.preprocess_text_segments( + text, + language, + self.configs.version, + ), + worker_count=self.prepare_text_cpu_workers, + max_pending_tasks=int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_MAX_PENDING_TASKS", "0")), + admission_poll_ms=int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_ADMISSION_POLL_MS", "1")), + admission_controller=self._build_text_cpu_admission_state, + ) + + @staticmethod + def _safe_queue_qsize(executor) -> int | None: + if executor is None: + return None + queue = getattr(executor, "_work_queue", None) + if queue is None or not hasattr(queue, "qsize"): + return None + try: + return int(queue.qsize()) + except Exception: + return None + + def snapshot_prepare_runtime_components(self) -> dict: + return { + "text_cpu": { + "workers": int(self.prepare_text_cpu_workers), + "queue_size": self._safe_queue_qsize(self.prepare_text_cpu_executor), + "enabled": bool(self.prepare_text_cpu_worker is not None or self.prepare_text_cpu_executor is not None), + "worker": ( + None if self.prepare_text_cpu_worker is None else dict(self.prepare_text_cpu_worker.snapshot()) + ), + "admission": self._build_text_cpu_admission_state(), + }, + "bert": { + "stage_limiter": dict(self.prepare_bert_stage_limiter.snapshot()), + "batch_worker": ( + None if self.prepare_bert_batch_worker is None else dict(self.prepare_bert_batch_worker.snapshot()) + ), + "batching_enabled": bool(self.prepare_bert_batch_worker is not None), + }, + "ref_semantic": { + "stage_limiter": dict(self.prepare_ref_audio_stage_limiter.snapshot()), + "batch_worker": ( + None + if self.prepare_ref_semantic_batch_worker is None + else dict(self.prepare_ref_semantic_batch_worker.snapshot()) + ), + "batching_enabled": bool(self.prepare_ref_semantic_batch_worker is not None), + }, + "text_preprocessor": ( + None if self.text_preprocessor is None or not hasattr(self.text_preprocessor, "snapshot") else self.text_preprocessor.snapshot() + ), + } + + def _build_text_cpu_admission_state(self) -> dict: + bert_pending_soft_max = max( + 0, + int( + os.environ.get( + "GPTSOVITS_PREPARE_TEXT_CPU_BERT_PENDING_SOFT_MAX", + os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_PENDING_THRESHOLD", "32"), + ) + ), + ) + if self.prepare_bert_batch_worker is None or bert_pending_soft_max <= 0: + return { + "blocked": False, + "reason": "", + "bert_pending": 0, + "bert_active_batch_size": 0, + "bert_pending_soft_max": int(bert_pending_soft_max), + } + bert_state = dict(self.prepare_bert_batch_worker.snapshot()) + bert_pending = int(bert_state.get("pending", 0)) + bert_active_batch_size = int(bert_state.get("active_batch_size", 0)) + blocked = bert_pending >= bert_pending_soft_max + return { + "blocked": bool(blocked), + "reason": ("bert_pending" if blocked else ""), + "bert_pending": int(bert_pending), + "bert_active_batch_size": int(bert_active_batch_size), + "bert_pending_soft_max": int(bert_pending_soft_max), + } def _init_models( self, @@ -1040,6 +1123,79 @@ class TTS: }, } + async def extract_ref_audio_bundle_async(self, ref_audio_path: str): + if self.prepare_ref_semantic_batch_worker is None: + return await asyncio.to_thread(self.extract_ref_audio_bundle, ref_audio_path) + + load_start = time.perf_counter() + raw_audio, raw_sr = await asyncio.to_thread(self._load_ref_audio_raw, ref_audio_path) + load_ms = (time.perf_counter() - load_start) * 1000.0 + + prompt_semantic_task = asyncio.create_task( + self.prepare_ref_semantic_batch_worker.submit_async(raw_audio, raw_sr) + ) + + def _build_ref_spec_profile(): + with self.prepare_ref_audio_stage_limiter.enter() as ref_spec_limiter_stats: + ref_spec_start = time.perf_counter() + refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2] + ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0 + return refer_spec, { + "ref_spec_wait_ms": float(ref_spec_limiter_stats["wait_ms"]), + "ref_spec_ms": float(ref_spec_ms), + "audio_stage_slots": float(ref_spec_limiter_stats["slots"]), + "audio_stage_inflight_peak": float(ref_spec_limiter_stats["peak_inflight"]), + } + + ref_spec_task = asyncio.create_task(asyncio.to_thread(_build_ref_spec_profile)) + (prompt_semantic, prompt_semantic_profile), (refer_spec, ref_spec_profile) = await asyncio.gather( + prompt_semantic_task, + ref_spec_task, + ) + + prompt_semantic_ms = ( + float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)) + ) + audio_stage_wait_ms = float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)) + float( + ref_spec_profile.get("ref_spec_wait_ms", 0.0) + ) + audio_stage_slots = max( + float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + float(ref_spec_profile.get("audio_stage_slots", 0.0)), + ) + audio_stage_inflight_peak = max( + float(prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)), + float(ref_spec_profile.get("audio_stage_inflight_peak", 0.0)), + ) + return { + "prompt_semantic": prompt_semantic, + "refer_spec": refer_spec, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + "audio_load_ms": float(load_ms), + "audio_stage_wait_ms": float(audio_stage_wait_ms), + "audio_stage_slots": float(audio_stage_slots), + "audio_stage_inflight_peak": float(audio_stage_inflight_peak), + "prompt_semantic_ms": float(prompt_semantic_ms), + "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_ms": float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), + "prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)), + "prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)), + "prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + "prompt_semantic_stage_inflight_peak": float( + prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0) + ), + "prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)), + "prompt_semantic_batch_samples": float(prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0)), + "ref_spec_wait_ms": float(ref_spec_profile.get("ref_spec_wait_ms", 0.0)), + "ref_spec_ms": float(ref_spec_profile.get("ref_spec_ms", 0.0)), + "bundle_total_ms": float(load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_profile.get("ref_spec_ms", 0.0)), + }, + } + def extract_text_features(self, text: str, language: str, profile: dict | None = None): return self.text_preprocessor.segment_and_extract_feature_for_text( text, language, self.configs.version, profile=profile diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py index 6bee49be..01a8ea4d 100644 --- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py +++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py @@ -118,6 +118,15 @@ class TextPreprocessor: self.bert_stage_limiter = bert_stage_limiter self.bert_batch_worker = bert_batch_worker + def snapshot(self) -> Dict[str, object]: + return { + "device": str(self.device), + "bert_stage_limiter": ( + None if self.bert_stage_limiter is None else dict(self.bert_stage_limiter.snapshot()) + ), + "bert_batch_worker": None if self.bert_batch_worker is None else dict(self.bert_batch_worker.snapshot()), + } + def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]: print(f"############ {i18n('切分文本')} ############") text = self.replace_consecutive_punctuation(text) diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py index 06a5e1b8..71134268 100644 --- a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py +++ b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py @@ -24,6 +24,7 @@ class ProfiledResult: submit_at: float started_at: float finished_at: float + profile: Dict[str, float] | None = None @property def queue_ms(self) -> float: @@ -48,6 +49,52 @@ class PreparedCpuStage: target_cpu_profiled: ProfiledResult +class AsyncStageGate: + def __init__(self, max_inflight: int, poll_ms: int = 1): + self.max_inflight = max(0, int(max_inflight)) + self.lock = threading.Lock() + self.poll_s = max(0.0005, float(max(1, int(poll_ms))) / 1000.0) + self.inflight = 0 + self.peak_inflight = 0 + self.total_entered = 0 + self.total_wait_ms = 0.0 + self.wait_peak_ms = 0.0 + + async def acquire(self) -> Dict[str, float]: + wait_start = time.perf_counter() + while True: + with self.lock: + if self.max_inflight <= 0 or self.inflight < self.max_inflight: + self.inflight += 1 + self.total_entered += 1 + wait_ms = max(0.0, (time.perf_counter() - wait_start) * 1000.0) + self.total_wait_ms += float(wait_ms) + self.wait_peak_ms = max(self.wait_peak_ms, float(wait_ms)) + self.peak_inflight = max(self.peak_inflight, self.inflight) + return { + "wait_ms": float(wait_ms), + "inflight": float(self.inflight), + "peak_inflight": float(self.peak_inflight), + "max_inflight": float(self.max_inflight), + } + await asyncio.sleep(self.poll_s) + + def release(self) -> None: + with self.lock: + self.inflight = max(0, self.inflight - 1) + + def snapshot(self) -> Dict[str, float]: + with self.lock: + return { + "max_inflight": float(self.max_inflight), + "inflight": float(self.inflight), + "peak_inflight": float(self.peak_inflight), + "total_entered": float(self.total_entered), + "total_wait_ms": float(self.total_wait_ms), + "wait_peak_ms": float(self.wait_peak_ms), + } + + class PrepareCoordinator: def __init__(self, tts: Any): self.tts = tts @@ -59,7 +106,8 @@ class PrepareCoordinator: and os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_DIRECT", "0") != "0" ) self.max_inflight = max(0, int(os.environ.get("GPTSOVITS_PREPARE_MAX_INFLIGHT", "0"))) - self._inflight_semaphore = asyncio.Semaphore(self.max_inflight) if self.max_inflight > 0 else None + gate_poll_ms = int(os.environ.get("GPTSOVITS_PREPARE_GATE_POLL_MS", "1")) + self._inflight_gate = AsyncStageGate(self.max_inflight, poll_ms=gate_poll_ms) self.text_feature_workers = 0 self.text_feature_executor = None if not self.use_async_text_feature_path: @@ -81,6 +129,29 @@ class PrepareCoordinator: max_workers=self.ref_audio_workers, thread_name_prefix="prepare-ref-audio", ) + text_cpu_gate_default = max(0, int(getattr(tts, "prepare_text_cpu_workers", 0) or 0)) + text_feature_gate_default = max(0, int(self.text_feature_workers)) + ref_audio_gate_default = max(0, int(self.ref_audio_workers)) + self.text_cpu_gate = AsyncStageGate( + int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_MAX_INFLIGHT", str(text_cpu_gate_default))), + poll_ms=gate_poll_ms, + ) + self.text_feature_gate = AsyncStageGate( + int(os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_MAX_INFLIGHT", str(text_feature_gate_default))), + poll_ms=gate_poll_ms, + ) + self.ref_audio_gate = AsyncStageGate( + int(os.environ.get("GPTSOVITS_PREPARE_REF_MAX_INFLIGHT", str(ref_audio_gate_default))), + poll_ms=gate_poll_ms, + ) + self.ref_load_gate = AsyncStageGate( + int(os.environ.get("GPTSOVITS_PREPARE_REF_LOAD_MAX_INFLIGHT", str(ref_audio_gate_default))), + poll_ms=gate_poll_ms, + ) + self.ref_spec_gate = AsyncStageGate( + int(os.environ.get("GPTSOVITS_PREPARE_REF_SPEC_MAX_INFLIGHT", str(ref_audio_gate_default))), + poll_ms=gate_poll_ms, + ) def _mark_enter(self) -> Tuple[int, int]: with self.lock: @@ -94,15 +165,29 @@ class PrepareCoordinator: with self.lock: self.inflight = max(0, self.inflight - 1) - def snapshot(self) -> Dict[str, int]: + def snapshot(self) -> Dict[str, Any]: with self.lock: - return { + snapshot: Dict[str, Any] = { "inflight": int(self.inflight), "peak_inflight": int(self.peak_inflight), "max_inflight": int(self.max_inflight), "text_feature_workers": int(self.text_feature_workers), "ref_audio_workers": int(self.ref_audio_workers), } + runtime_snapshot_fn = getattr(self.tts, "snapshot_prepare_runtime_components", None) + if callable(runtime_snapshot_fn): + try: + snapshot["prepare_runtime_state"] = runtime_snapshot_fn() + except Exception: + snapshot["prepare_runtime_state"] = None + snapshot["prepare_stage_gates"] = { + "text_cpu": self.text_cpu_gate.snapshot(), + "text_feature": self.text_feature_gate.snapshot(), + "ref_audio": self.ref_audio_gate.snapshot(), + "ref_load": self.ref_load_gate.snapshot(), + "ref_spec": self.ref_spec_gate.snapshot(), + } + return snapshot @staticmethod def _run_profiled(fn, submit_at: float, *args) -> ProfiledResult: @@ -119,6 +204,12 @@ class PrepareCoordinator: def _prepare_text_cpu(self, text: str, language: str): return self.tts.prepare_text_segments(text, language) + def _load_ref_audio_raw(self, ref_audio_path: str): + return self.tts._load_ref_audio_raw(ref_audio_path) + + def _extract_ref_spec_from_raw(self, raw_audio, raw_sr: int): + return self.tts._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2] + @staticmethod def _build_empty_text_features_like(reference: PreparedTextFeatures | None = None) -> PreparedTextFeatures: feature_dim = 1024 @@ -155,17 +246,54 @@ class PrepareCoordinator: return await loop.run_in_executor(executor, self._run_profiled, fn, float(submit_at), *args) async def _run_text_cpu_stage(self, text: str, language: str) -> ProfiledResult: + await self.text_cpu_gate.acquire() if text in [None, ""]: - submit_at = time.perf_counter() - return ProfiledResult(result=[], submit_at=submit_at, started_at=submit_at, finished_at=submit_at) + try: + submit_at = time.perf_counter() + return ProfiledResult(result=[], submit_at=submit_at, started_at=submit_at, finished_at=submit_at) + finally: + self.text_cpu_gate.release() + text_cpu_worker = getattr(self.tts, "prepare_text_cpu_worker", None) executor = getattr(self.tts, "prepare_text_cpu_executor", None) - if executor is None: - submit_at = time.perf_counter() - return self._run_profiled(self._prepare_text_cpu, submit_at, text, language) - return await self._run_on_executor(executor, self._prepare_text_cpu, text, language) + try: + if text_cpu_worker is not None: + submit_at = time.perf_counter() + result, worker_profile = await text_cpu_worker.submit_async(text, language) + started_at = float( + submit_at + + ( + float(worker_profile.get("text_cpu_admission_wait_ms", 0.0)) + + float(worker_profile.get("text_cpu_queue_wait_ms", 0.0)) + ) + / 1000.0 + ) + finished_at = float(started_at + float(worker_profile.get("text_cpu_run_ms", 0.0)) / 1000.0) + return ProfiledResult( + result=result, + submit_at=float(submit_at), + started_at=started_at, + finished_at=finished_at, + profile=dict(worker_profile), + ) + if executor is None: + submit_at = time.perf_counter() + return self._run_profiled(self._prepare_text_cpu, submit_at, text, language) + return await self._run_on_executor(executor, self._prepare_text_cpu, text, language) + finally: + self.text_cpu_gate.release() async def _run_text_feature_stage(self, prepared_segments, language: str, cpu_run_ms: float) -> ProfiledResult: - return await self._run_on_executor(self.text_feature_executor, self._build_text_features, prepared_segments, language, cpu_run_ms) + await self.text_feature_gate.acquire() + try: + return await self._run_on_executor( + self.text_feature_executor, + self._build_text_features, + prepared_segments, + language, + cpu_run_ms, + ) + finally: + self.text_feature_gate.release() @staticmethod def _estimate_text_feature_run_ms(profile: Dict[str, float]) -> float: @@ -199,16 +327,34 @@ class PrepareCoordinator: ) return prompt_profiled, target_profiled + await self.text_feature_gate.acquire() target_profile: Dict[str, float] = {"cpu_preprocess_ms": float(target_cpu_run_ms)} submit_at = time.perf_counter() started_at = float(submit_at) - if prompt_is_empty: - target_result_raw = await self.tts.build_text_features_from_segments_async( - target_segments, - profile=target_profile, - ) - prompt_result = self._build_empty_text_features_like( - PreparedTextFeatures( + try: + if prompt_is_empty: + target_result_raw = await self.tts.build_text_features_from_segments_async( + target_segments, + profile=target_profile, + ) + prompt_result = self._build_empty_text_features_like( + PreparedTextFeatures( + phones=target_result_raw[0], + bert_features=target_result_raw[1], + norm_text=target_result_raw[2], + profile=target_profile, + total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)), + cpu_preprocess_ms=float(target_cpu_run_ms), + ) + ) + finished_at = time.perf_counter() + prompt_profiled = ProfiledResult( + result=prompt_result, + submit_at=float(submit_at), + started_at=float(submit_at), + finished_at=float(submit_at), + ) + target_result = PreparedTextFeatures( phones=target_result_raw[0], bert_features=target_result_raw[1], norm_text=target_result_raw[2], @@ -216,13 +362,37 @@ class PrepareCoordinator: total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)), cpu_preprocess_ms=float(target_cpu_run_ms), ) + target_profiled = ProfiledResult( + result=target_result, + submit_at=float(submit_at), + started_at=started_at, + finished_at=float(submit_at + self._estimate_text_feature_run_ms(target_profile) / 1000.0), + ) + if finished_at > target_profiled.finished_at: + target_result.profile["bert_total_ms"] = max( + self._estimate_text_feature_run_ms(target_profile), + (finished_at - submit_at) * 1000.0, + ) + else: + target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile) + return prompt_profiled, target_profiled + + prompt_profile: Dict[str, float] = {"cpu_preprocess_ms": float(prompt_cpu_run_ms)} + prompt_result_raw, target_result_raw = await self.tts.build_text_feature_pair_from_segments_async( + prompt_segments, + target_segments, + prompt_profile=prompt_profile, + target_profile=target_profile, ) finished_at = time.perf_counter() - prompt_profiled = ProfiledResult( - result=prompt_result, - submit_at=float(submit_at), - started_at=float(submit_at), - finished_at=float(submit_at), + + prompt_result = PreparedTextFeatures( + phones=prompt_result_raw[0], + bert_features=prompt_result_raw[1], + norm_text=prompt_result_raw[2], + profile=prompt_profile, + total_ms=float(prompt_cpu_run_ms + self._estimate_text_feature_run_ms(prompt_profile)), + cpu_preprocess_ms=float(prompt_cpu_run_ms), ) target_result = PreparedTextFeatures( phones=target_result_raw[0], @@ -232,79 +402,152 @@ class PrepareCoordinator: total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)), cpu_preprocess_ms=float(target_cpu_run_ms), ) + prompt_profiled = ProfiledResult( + result=prompt_result, + submit_at=float(submit_at), + started_at=started_at, + finished_at=float(submit_at + self._estimate_text_feature_run_ms(prompt_profile) / 1000.0), + ) target_profiled = ProfiledResult( result=target_result, submit_at=float(submit_at), started_at=started_at, finished_at=float(submit_at + self._estimate_text_feature_run_ms(target_profile) / 1000.0), ) - if finished_at > target_profiled.finished_at: + if finished_at > prompt_profiled.finished_at: + prompt_result.profile["bert_total_ms"] = max( + self._estimate_text_feature_run_ms(prompt_profile), + (finished_at - submit_at) * 1000.0, + ) target_result.profile["bert_total_ms"] = max( self._estimate_text_feature_run_ms(target_profile), (finished_at - submit_at) * 1000.0, ) else: + prompt_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(prompt_profile) target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile) return prompt_profiled, target_profiled - - prompt_profile: Dict[str, float] = {"cpu_preprocess_ms": float(prompt_cpu_run_ms)} - prompt_result_raw, target_result_raw = await self.tts.build_text_feature_pair_from_segments_async( - prompt_segments, - target_segments, - prompt_profile=prompt_profile, - target_profile=target_profile, - ) - finished_at = time.perf_counter() - - prompt_result = PreparedTextFeatures( - phones=prompt_result_raw[0], - bert_features=prompt_result_raw[1], - norm_text=prompt_result_raw[2], - profile=prompt_profile, - total_ms=float(prompt_cpu_run_ms + self._estimate_text_feature_run_ms(prompt_profile)), - cpu_preprocess_ms=float(prompt_cpu_run_ms), - ) - target_result = PreparedTextFeatures( - phones=target_result_raw[0], - bert_features=target_result_raw[1], - norm_text=target_result_raw[2], - profile=target_profile, - total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)), - cpu_preprocess_ms=float(target_cpu_run_ms), - ) - prompt_profiled = ProfiledResult( - result=prompt_result, - submit_at=float(submit_at), - started_at=started_at, - finished_at=float(submit_at + self._estimate_text_feature_run_ms(prompt_profile) / 1000.0), - ) - target_profiled = ProfiledResult( - result=target_result, - submit_at=float(submit_at), - started_at=started_at, - finished_at=float(submit_at + self._estimate_text_feature_run_ms(target_profile) / 1000.0), - ) - if finished_at > prompt_profiled.finished_at: - prompt_result.profile["bert_total_ms"] = max( - self._estimate_text_feature_run_ms(prompt_profile), - (finished_at - submit_at) * 1000.0, - ) - target_result.profile["bert_total_ms"] = max( - self._estimate_text_feature_run_ms(target_profile), - (finished_at - submit_at) * 1000.0, - ) - else: - prompt_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(prompt_profile) - target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile) - return prompt_profiled, target_profiled + finally: + self.text_feature_gate.release() async def _run_ref_audio_stage(self, ref_audio_path: str) -> ProfiledResult: - return await self._run_on_executor(self.ref_audio_executor, self.tts.extract_ref_audio_bundle, ref_audio_path) + if getattr(self.tts, "prepare_ref_semantic_batch_worker", None) is not None: + submit_at = time.perf_counter() + started_at = float(submit_at) + + await self.ref_load_gate.acquire() + try: + load_profiled = await self._run_on_executor(self.ref_audio_executor, self._load_ref_audio_raw, ref_audio_path) + finally: + self.ref_load_gate.release() + + raw_audio, raw_sr = load_profiled.result + prompt_semantic_task = asyncio.create_task( + self.tts.prepare_ref_semantic_batch_worker.submit_async(raw_audio, raw_sr) + ) + await self.ref_spec_gate.acquire() + try: + ref_spec_task = asyncio.create_task( + self._run_on_executor(self.ref_audio_executor, self._extract_ref_spec_from_raw, raw_audio, raw_sr) + ) + (prompt_semantic, prompt_semantic_profile), ref_spec_profiled = await asyncio.gather( + prompt_semantic_task, + ref_spec_task, + ) + finally: + self.ref_spec_gate.release() + + refer_spec = ref_spec_profiled.result + limiter_snapshot = ( + self.tts.prepare_ref_audio_stage_limiter.snapshot() + if getattr(self.tts, "prepare_ref_audio_stage_limiter", None) is not None + else {} + ) + prompt_semantic_ms = ( + float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)) + ) + audio_stage_wait_ms = ( + float(load_profiled.queue_ms) + + float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)) + + float(ref_spec_profiled.queue_ms) + ) + finished_at = time.perf_counter() + result = { + "prompt_semantic": prompt_semantic, + "refer_spec": refer_spec, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + "audio_load_queue_ms": float(load_profiled.queue_ms), + "audio_load_ms": float(load_profiled.run_ms), + "audio_stage_wait_ms": float(audio_stage_wait_ms), + "audio_stage_slots": float( + max( + float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + float(limiter_snapshot.get("slots", 0.0)), + ) + ), + "audio_stage_inflight_peak": float( + max( + float(prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)), + float(limiter_snapshot.get("peak_inflight", 0.0)), + ) + ), + "prompt_semantic_ms": float(prompt_semantic_ms), + "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_ms": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) + ), + "prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)), + "prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)), + "prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + "prompt_semantic_stage_inflight_peak": float( + prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0) + ), + "prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)), + "prompt_semantic_batch_samples": float( + prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0) + ), + "ref_spec_wait_ms": float(ref_spec_profiled.queue_ms), + "ref_spec_ms": float(ref_spec_profiled.run_ms), + "bundle_total_ms": float( + load_profiled.queue_ms + + load_profiled.run_ms + + prompt_semantic_ms + + ref_spec_profiled.queue_ms + + ref_spec_profiled.run_ms + ), + }, + } + return ProfiledResult( + result=result, + submit_at=float(submit_at), + started_at=started_at, + finished_at=float(finished_at), + ) + + await self.ref_audio_gate.acquire() + try: + if hasattr(self.tts, "extract_ref_audio_bundle_async"): + submit_at = time.perf_counter() + started_at = time.perf_counter() + result = await self.tts.extract_ref_audio_bundle_async(ref_audio_path) + finished_at = time.perf_counter() + return ProfiledResult( + result=result, + submit_at=float(submit_at), + started_at=float(started_at), + finished_at=float(finished_at), + ) + return await self._run_on_executor(self.ref_audio_executor, self.tts.extract_ref_audio_bundle, ref_audio_path) + finally: + self.ref_audio_gate.release() def _release_split_stage_slot(self) -> None: self._mark_leave() - if self._inflight_semaphore is not None: - self._inflight_semaphore.release() + self._inflight_gate.release() async def prepare_cpu_stage_profiled_async( self, @@ -312,9 +555,11 @@ class PrepareCoordinator: prepare_submit_at: float, ) -> PreparedCpuStage: admission_start = time.perf_counter() - if self._inflight_semaphore is not None: - await self._inflight_semaphore.acquire() - prepare_admission_wait_ms = max(0.0, (time.perf_counter() - admission_start) * 1000.0) + admission_stats = await self._inflight_gate.acquire() + prepare_admission_wait_ms = max( + float(admission_stats.get("wait_ms", 0.0)), + (time.perf_counter() - admission_start) * 1000.0, + ) current_inflight, peak_inflight = self._mark_enter() prepare_start = time.perf_counter() prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) @@ -382,10 +627,28 @@ class PrepareCoordinator: "prompt_text_parallel_future_run_tail_after_target_ms": 0.0, "prompt_text_cpu_queue_ms": cpu_stage.prompt_cpu_profiled.queue_ms, "prompt_text_cpu_run_ms": cpu_stage.prompt_cpu_profiled.run_ms, + "prompt_text_cpu_admission_wait_ms": float( + (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0) + ), + "prompt_text_cpu_backpressure_wait_ms": float( + (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0) + ), + "prompt_text_cpu_capacity_wait_ms": float( + (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0) + ), "prompt_text_feature_queue_ms": prompt_feature_profiled.queue_ms, "prompt_text_feature_run_ms": prompt_feature_profiled.run_ms, "text_cpu_queue_ms": cpu_stage.target_cpu_profiled.queue_ms, "text_cpu_run_ms": cpu_stage.target_cpu_profiled.run_ms, + "text_cpu_admission_wait_ms": float( + (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0) + ), + "text_cpu_backpressure_wait_ms": float( + (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0) + ), + "text_cpu_capacity_wait_ms": float( + (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0) + ), "text_feature_queue_ms": target_feature_profiled.queue_ms, "text_feature_run_ms": target_feature_profiled.run_ms, "ref_audio_task_queue_ms": ref_audio_profiled.queue_ms, diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py index 7a1f9a53..64ca133f 100644 --- a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py +++ b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py @@ -1,3 +1,4 @@ +import asyncio import threading import time import uuid @@ -51,6 +52,8 @@ class RefSemanticTask: task_id: str = field(default_factory=lambda: uuid.uuid4().hex) created_at: float = field(default_factory=time.perf_counter) done_event: threading.Event = field(default_factory=threading.Event) + done_loop: asyncio.AbstractEventLoop | None = None + done_future: asyncio.Future | None = None result_prompt_semantic: torch.Tensor | None = None error: Exception | None = None profile: Dict[str, float] = field(default_factory=dict) @@ -115,6 +118,41 @@ class PrepareRefSemanticBatchWorker: assert task.result_prompt_semantic is not None return task.result_prompt_semantic, dict(task.profile) + async def submit_async(self, raw_audio: torch.Tensor, raw_sr: int) -> Tuple[torch.Tensor, Dict[str, float]]: + loop = asyncio.get_running_loop() + task = RefSemanticTask( + raw_audio=raw_audio, + raw_sr=int(raw_sr), + done_loop=loop, + done_future=loop.create_future(), + ) + with self.condition: + self.pending_tasks.append(task) + self.total_submitted += 1 + if len(self.pending_tasks) > self.pending_peak: + self.pending_peak = len(self.pending_tasks) + self.condition.notify_all() + return await task.done_future + + @staticmethod + def _resolve_done_future(task: RefSemanticTask) -> None: + if task.done_future is None or task.done_future.done(): + return + if task.error is not None: + task.done_future.set_exception(task.error) + return + assert task.result_prompt_semantic is not None + task.done_future.set_result((task.result_prompt_semantic, dict(task.profile))) + + def _notify_task_done(self, task: RefSemanticTask) -> None: + task.done_event.set() + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_done_future, task) + except RuntimeError: + pass + def snapshot(self) -> Dict[str, int]: with self.condition: return { @@ -247,7 +285,7 @@ class PrepareRefSemanticBatchWorker: for task in batch: if task.result_prompt_semantic is not None: task.profile["prompt_semantic_scatter_ms"] = float(scatter_ms) - task.done_event.set() + self._notify_task_done(task) def _run_loop(self) -> None: while True: @@ -257,6 +295,6 @@ class PrepareRefSemanticBatchWorker: except Exception as exc: # noqa: PERF203 for task in batch: task.error = exc - task.done_event.set() + self._notify_task_done(task) finally: self._finalize_batch(batch) diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_text_cpu_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_text_cpu_worker.py new file mode 100644 index 00000000..b7c985f0 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/prepare_text_cpu_worker.py @@ -0,0 +1,215 @@ +import asyncio +import threading +import time +import uuid +from collections import deque +from dataclasses import dataclass, field +from typing import Any, Callable, Deque, Dict, Tuple + + +@dataclass +class TextCpuTask: + text: str + language: str + task_id: str = field(default_factory=lambda: uuid.uuid4().hex) + created_at: float = field(default_factory=time.perf_counter) + enqueued_at: float = 0.0 + admission_wait_ms: float = 0.0 + backpressure_wait_ms: float = 0.0 + capacity_wait_ms: float = 0.0 + pending_depth_on_enqueue: int = 0 + done_event: threading.Event = field(default_factory=threading.Event) + done_loop: asyncio.AbstractEventLoop | None = None + done_future: asyncio.Future | None = None + result: Any = None + error: Exception | None = None + profile: Dict[str, float] = field(default_factory=dict) + + +class PrepareTextCpuWorker: + def __init__( + self, + process_fn: Callable[[str, str], Any], + worker_count: int, + max_pending_tasks: int = 0, + admission_poll_ms: int = 1, + admission_controller: Callable[[], Dict[str, float | int | bool]] | None = None, + ) -> None: + self.process_fn = process_fn + self.worker_count = max(1, int(worker_count)) + self.max_pending_tasks = max(0, int(max_pending_tasks)) + self.admission_poll_s = max(0.0005, float(max(1, int(admission_poll_ms))) / 1000.0) + self.admission_controller = admission_controller + + self.condition = threading.Condition() + self.pending_tasks: Deque[TextCpuTask] = deque() + self.pending_peak = 0 + self.total_submitted = 0 + self.total_finished = 0 + self.active_workers = 0 + self.active_workers_peak = 0 + self.admission_wait_total_ms = 0.0 + self.admission_wait_peak_ms = 0.0 + self.backpressure_wait_total_ms = 0.0 + self.backpressure_wait_peak_ms = 0.0 + self.capacity_wait_total_ms = 0.0 + self.capacity_wait_peak_ms = 0.0 + self.backpressure_blocked_total = 0 + + self.worker_threads = [ + threading.Thread(target=self._run_loop, name=f"prepare-text-cpu-worker-{index}", daemon=True) + for index in range(self.worker_count) + ] + for thread in self.worker_threads: + thread.start() + + def _can_enqueue_locked(self) -> bool: + if self.max_pending_tasks <= 0: + return True + return (len(self.pending_tasks) + self.active_workers) < self.max_pending_tasks + + def _get_admission_state(self) -> Dict[str, float | int | bool]: + if self.admission_controller is None: + return {"blocked": False} + try: + state = dict(self.admission_controller() or {}) + except Exception: + return {"blocked": False} + state["blocked"] = bool(state.get("blocked", False)) + return state + + def _record_enqueue_locked( + self, + task: TextCpuTask, + *, + admission_wait_ms: float, + backpressure_wait_ms: float, + capacity_wait_ms: float, + ) -> None: + task.admission_wait_ms = float(max(0.0, admission_wait_ms)) + task.backpressure_wait_ms = float(max(0.0, backpressure_wait_ms)) + task.capacity_wait_ms = float(max(0.0, capacity_wait_ms)) + task.enqueued_at = time.perf_counter() + task.pending_depth_on_enqueue = int(len(self.pending_tasks)) + self.pending_tasks.append(task) + self.total_submitted += 1 + self.admission_wait_total_ms += task.admission_wait_ms + self.admission_wait_peak_ms = max(self.admission_wait_peak_ms, task.admission_wait_ms) + self.backpressure_wait_total_ms += task.backpressure_wait_ms + self.backpressure_wait_peak_ms = max(self.backpressure_wait_peak_ms, task.backpressure_wait_ms) + self.capacity_wait_total_ms += task.capacity_wait_ms + self.capacity_wait_peak_ms = max(self.capacity_wait_peak_ms, task.capacity_wait_ms) + if task.backpressure_wait_ms > 0.0: + self.backpressure_blocked_total += 1 + if len(self.pending_tasks) > self.pending_peak: + self.pending_peak = len(self.pending_tasks) + self.condition.notify_all() + + async def _enqueue_task_async(self, task: TextCpuTask) -> None: + admission_started = time.perf_counter() + backpressure_wait_ms = 0.0 + capacity_wait_ms = 0.0 + while True: + loop_start = time.perf_counter() + admission_state = self._get_admission_state() + blocked = bool(admission_state.get("blocked", False)) + with self.condition: + if not blocked and self._can_enqueue_locked(): + self._record_enqueue_locked( + task, + admission_wait_ms=(time.perf_counter() - admission_started) * 1000.0, + backpressure_wait_ms=backpressure_wait_ms, + capacity_wait_ms=capacity_wait_ms, + ) + return + await asyncio.sleep(self.admission_poll_s) + waited_ms = (time.perf_counter() - loop_start) * 1000.0 + if blocked: + backpressure_wait_ms += waited_ms + else: + capacity_wait_ms += waited_ms + + def submit(self, text: str, language: str) -> Tuple[Any, Dict[str, float]]: + task = TextCpuTask(text=str(text), language=str(language)) + asyncio.run(self._enqueue_task_async(task)) + task.done_event.wait() + if task.error is not None: + raise task.error + return task.result, dict(task.profile) + + async def submit_async(self, text: str, language: str) -> Tuple[Any, Dict[str, float]]: + loop = asyncio.get_running_loop() + task = TextCpuTask( + text=str(text), + language=str(language), + done_loop=loop, + done_future=loop.create_future(), + ) + await self._enqueue_task_async(task) + return await task.done_future + + @staticmethod + def _resolve_done_future(task: TextCpuTask) -> None: + if task.done_future is None or task.done_future.done(): + return + if task.error is not None: + task.done_future.set_exception(task.error) + return + task.done_future.set_result((task.result, dict(task.profile))) + + def _notify_task_done(self, task: TextCpuTask) -> None: + task.done_event.set() + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_done_future, task) + except RuntimeError: + pass + + def snapshot(self) -> Dict[str, int | float]: + with self.condition: + return { + "worker_count": int(self.worker_count), + "pending": int(len(self.pending_tasks)), + "pending_peak": int(self.pending_peak), + "active_workers": int(self.active_workers), + "active_workers_peak": int(self.active_workers_peak), + "total_submitted": int(self.total_submitted), + "total_finished": int(self.total_finished), + "max_pending_tasks": int(self.max_pending_tasks), + "admission_wait_total_ms": float(self.admission_wait_total_ms), + "admission_wait_peak_ms": float(self.admission_wait_peak_ms), + "backpressure_wait_total_ms": float(self.backpressure_wait_total_ms), + "backpressure_wait_peak_ms": float(self.backpressure_wait_peak_ms), + "capacity_wait_total_ms": float(self.capacity_wait_total_ms), + "capacity_wait_peak_ms": float(self.capacity_wait_peak_ms), + "backpressure_blocked_total": int(self.backpressure_blocked_total), + } + + def _run_loop(self) -> None: + while True: + with self.condition: + while not self.pending_tasks: + self.condition.wait() + task = self.pending_tasks.popleft() + self.active_workers += 1 + self.active_workers_peak = max(self.active_workers_peak, self.active_workers) + started_at = time.perf_counter() + try: + task.result = self.process_fn(task.text, task.language) + task.profile = { + "text_cpu_admission_wait_ms": float(task.admission_wait_ms), + "text_cpu_backpressure_wait_ms": float(task.backpressure_wait_ms), + "text_cpu_capacity_wait_ms": float(task.capacity_wait_ms), + "text_cpu_queue_wait_ms": max(0.0, (started_at - task.enqueued_at) * 1000.0), + "text_cpu_pending_depth_on_enqueue": float(task.pending_depth_on_enqueue), + "text_cpu_run_ms": max(0.0, (time.perf_counter() - started_at) * 1000.0), + } + except Exception as exc: # noqa: PERF203 + task.error = exc + finally: + with self.condition: + self.active_workers = max(0, self.active_workers - 1) + self.total_finished += 1 + self.condition.notify_all() + self._notify_task_done(task) diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py index 78dbfc36..ed465f69 100644 --- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -421,6 +421,55 @@ def _iter_contiguous_sampling_groups( return groups +def _uniform_sampling_group_key(active_batch: T2SActiveBatch) -> Optional[Tuple[int, float, float, float, bool]]: + if not active_batch.states: + return None + if active_batch.step_indices.numel() <= 0: + return None + first_step_index = int(active_batch.step_indices[0].item()) + if bool((active_batch.step_indices != first_step_index).any().item()): + return None + first_state = active_batch.states[0] + first_key = _sampling_group_key( + top_k=first_state.top_k, + top_p=first_state.top_p, + temperature=first_state.temperature, + repetition_penalty=first_state.repetition_penalty, + trim_eos=first_step_index < 11, + ) + for state in active_batch.states[1:]: + if ( + state.top_k != first_state.top_k + or state.top_p != first_state.top_p + or state.temperature != first_state.temperature + or state.repetition_penalty != first_state.repetition_penalty + ): + return None + return first_key + + +def _batched_sample_uniform( + logits: torch.Tensor, + histories: Sequence[torch.LongTensor], + sampling_key: Tuple[int, float, float, float, bool], +) -> Tuple[torch.Tensor, torch.Tensor]: + top_k, top_p, temperature, repetition_penalty, trim_eos = sampling_key + sample_logits = logits[:, :-1] if trim_eos else logits + padded_histories, history_mask = _pad_token_sequences(histories) + probs = logits_to_probs( + logits=sample_logits, + previous_tokens=padded_histories, + previous_token_mask=history_mask, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + ) + sampled = multinomial_sample_one_no_sync(probs) + argmax_tokens = torch.argmax(sample_logits, dim=-1) + return sampled, argmax_tokens + + def _batched_sample_by_group( logits: torch.Tensor, histories: Sequence[torch.LongTensor], @@ -594,27 +643,59 @@ def _sample_per_request( keep_indices: List[int] = [] updated_sequences: List[torch.LongTensor] = [] - sampling_keys = [ - _sampling_group_key( - top_k=state.top_k, - top_p=state.top_p, - temperature=state.temperature, - repetition_penalty=state.repetition_penalty, - trim_eos=int(active_batch.step_indices[batch_index].item()) < 11, + uniform_sampling_key = _uniform_sampling_group_key(active_batch) + sampled_items: List[torch.Tensor] + argmax_tokens: List[int] + sampled_token_tensor: Optional[torch.Tensor] = None + argmax_token_tensor: Optional[torch.Tensor] = None + if uniform_sampling_key is not None: + sampled_tensor, argmax_tensor = _batched_sample_uniform( + logits=logits, + histories=active_batch.y_sequences, + sampling_key=uniform_sampling_key, + ) + sampled_token_tensor = sampled_tensor.view(-1) + argmax_token_tensor = argmax_tensor.view(-1) + if ( + all(state.early_stop_num == -1 for state in active_batch.states) + and int(active_batch.step_indices[0].item()) + 1 < max_steps + and not bool(sampled_token_tensor.eq(model.EOS).any().item()) + and not bool(argmax_token_tensor.eq(model.EOS).any().item()) + ): + return ( + [], + list(range(len(active_batch.states))), + [torch.cat([history, sampled_token_tensor[index : index + 1]], dim=0) for index, history in enumerate(active_batch.y_sequences)], + ) + sampled_items = [sampled_tensor[index : index + 1] for index in range(sampled_tensor.shape[0])] + argmax_tokens = [int(item) for item in argmax_tensor.tolist()] + else: + sampling_keys = [ + _sampling_group_key( + top_k=state.top_k, + top_p=state.top_p, + temperature=state.temperature, + repetition_penalty=state.repetition_penalty, + trim_eos=int(active_batch.step_indices[batch_index].item()) < 11, + ) + for batch_index, state in enumerate(active_batch.states) + ] + sampled_items, argmax_tokens = _batched_sample_by_group( + logits=logits, + histories=active_batch.y_sequences, + sampling_keys=sampling_keys, ) - for batch_index, state in enumerate(active_batch.states) - ] - sampled_items, argmax_tokens = _batched_sample_by_group( - logits=logits, - histories=active_batch.y_sequences, - sampling_keys=sampling_keys, - ) for batch_index, state in enumerate(active_batch.states): step_index = int(active_batch.step_indices[batch_index].item()) current_history = active_batch.y_sequences[batch_index] - sampled = sampled_items[batch_index] - sampled_token = int(sampled[0, 0].item()) - argmax_token = argmax_tokens[batch_index] + if sampled_token_tensor is not None and argmax_token_tensor is not None: + sampled = sampled_token_tensor[batch_index : batch_index + 1] + sampled_token = int(sampled_token_tensor[batch_index].item()) + argmax_token = int(argmax_token_tensor[batch_index].item()) + else: + sampled = sampled_items[batch_index] + sampled_token = int(sampled[0, 0].item()) + argmax_token = argmax_tokens[batch_index] new_history = torch.cat([current_history, sampled.view(-1)], dim=0) finish_reason: Optional[str] = None @@ -690,6 +771,13 @@ def decode_one_step( finished_items, keep_indices, updated_sequences = _sample_per_request(model, active_batch, logits, max_steps=max_steps) if len(keep_indices) == 0: return None, finished_items + if len(keep_indices) == len(active_batch.request_ids): + active_batch.y_sequences = updated_sequences + active_batch.step_indices = active_batch.step_indices + 1 + if not was_prefill and active_batch.kv_lens is not None: + active_batch.kv_lens = active_batch.kv_lens + 1 + active_batch.xy_pos = build_next_xy_pos(model, active_batch.y_sequences) + return active_batch, finished_items device = logits.device keep_tensor = torch.LongTensor(keep_indices).to(device) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api_profile.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_profile.py index f950c68d..e31c5dfe 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_api_profile.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_profile.py @@ -314,6 +314,7 @@ def build_scheduler_submit_headers( "X-Prepare-Text-Pair-Wall-Ms": format_ms_header(prepare_profile.get("text_feature_pair_ms", 0.0)), "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))), "X-Prepare-Engine-GPU-Queue-Wait-Ms": format_ms_header(prepare_profile.get("engine_gpu_prepare_queue_wait_ms", 0.0)), + "X-Prepare-Engine-GPU-Batch-Size": str(int(prepare_profile.get("engine_gpu_prepare_batch_size", 0.0))), "X-Prepare-Audio-Load-Ms": format_ms_header(prepare_profile.get("audio_load_ms", 0.0)), "X-Prepare-Audio-Stage-Wait-Ms": format_ms_header(prepare_profile.get("audio_stage_wait_ms", 0.0)), "X-Prepare-Prompt-Semantic-Ms": format_ms_header(prepare_profile.get("prompt_semantic_ms", 0.0)), diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py index 7f4e485f..15eedeca 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py @@ -44,6 +44,16 @@ class EngineTaskQueueOwner: return None return self.queue.popleft() + def pop_left_many(self, max_items: int) -> List[Any]: + limit = max(1, int(max_items)) + with self.condition: + if not self.queue: + return [] + selected: List[Any] = [] + while self.queue and len(selected) < limit: + selected.append(self.queue.popleft()) + return selected + def mark_completed(self, count: int = 1, *, notify: bool = False) -> None: if count <= 0: return @@ -315,6 +325,7 @@ class EngineGpuPrepareTask: engine_request_id: str | None enqueue_time: float queue_wait_ms: float = 0.0 + admission_wait_ms: float = 0.0 error: str | None = None diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py index bb3e8b06..b9095d2c 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import os import time from typing import Any @@ -9,6 +10,19 @@ from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineGpuPrepare class EnginePrepareStageMixin: + async def _wait_prepare_queue_admission(self) -> float: + soft_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_SOFT_MAX", "0"))) + if soft_max <= 0: + return 0.0 + poll_s = max( + 0.0005, + float(max(1, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_ADMISSION_POLL_MS", "1")))) / 1000.0, + ) + wait_start = time.perf_counter() + while self.prepare_queue_owner.waiting_count() >= soft_max: + await asyncio.sleep(poll_s) + return max(0.0, (time.perf_counter() - wait_start) * 1000.0) + async def prepare_state_via_engine_gpu_queue( self, *, @@ -16,12 +30,14 @@ class EnginePrepareStageMixin: prepare_submit_at: float, engine_request_id: str | None, ) -> tuple[T2SRequestState, float, float]: + prepare_queue_admission_wait_ms = await self._wait_prepare_queue_admission() cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) if engine_request_id not in [None, ""]: self.update_request_state( str(engine_request_id), EngineStatus.GPU_PREPARING, { + "engine_prepare_queue_admission_wait_ms": float(prepare_queue_admission_wait_ms), "prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms), "prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms), "text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms), @@ -37,31 +53,44 @@ class EnginePrepareStageMixin: done_future=done_future, engine_request_id=engine_request_id or spec.request_id, enqueue_time=time.perf_counter(), + admission_wait_ms=float(prepare_queue_admission_wait_ms), ) self.prepare_queue_owner.enqueue(task) self.notify_arbiter() return await done_future def run_engine_prepare_once(self) -> bool: - task = self.prepare_queue_owner.pop_left() - if task is None: + prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy() + tasks = self.prepare_queue_owner.pop_left_many(int(prepare_batch_policy.get("prepare_batch_max_items", 1))) + if not tasks: return False - queue_wait_ms = max(0.0, (time.perf_counter() - task.enqueue_time) * 1000.0) - try: - state, prepare_exec_started_at, prepare_exec_finished_at = asyncio.run( - self.scheduler_worker.prepare_gpu_stage_profiled_async(task.cpu_stage) - ) + now = time.perf_counter() + queue_wait_ms_list = [max(0.0, (now - task.enqueue_time) * 1000.0) for task in tasks] + batch_results = asyncio.run( + self.scheduler_worker.prepare_gpu_stages_profiled_async([task.cpu_stage for task in tasks]) + ) + completed_count = 0 + for task, queue_wait_ms, result in zip(tasks, queue_wait_ms_list, batch_results): + if isinstance(result, Exception): + task.error = str(result) + self.fail_request_state(task.engine_request_id or task.request_id, str(result)) + self._notify_prepare_error(task, result) + completed_count += 1 + continue + state, prepare_exec_started_at, prepare_exec_finished_at = result + state.prepare_profile["engine_prepare_queue_admission_wait_ms"] = float(task.admission_wait_ms) state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms) + state.prepare_profile["engine_gpu_prepare_batch_size"] = float(len(tasks)) if task.engine_request_id not in [None, ""]: self.merge_request_state_profile( str(task.engine_request_id), - {"engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms)}, + { + "engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms), + "engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms), + "engine_gpu_prepare_batch_size": float(len(tasks)), + }, ) - self.prepare_queue_owner.mark_completed(1) self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at)) - return True - except Exception as exc: - task.error = str(exc) - self.fail_request_state(task.engine_request_id or task.request_id, str(exc)) - self._notify_prepare_error(task, exc) - return True + completed_count += 1 + self.prepare_queue_owner.mark_completed(completed_count) + return True diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py index 28da24ee..8b3db1fa 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import os import time from typing import Callable, Dict, List @@ -32,6 +33,11 @@ class WorkerPrepareExecutor: def get_max_inflight(self) -> int: return int(self.coordinator.snapshot().get("max_inflight", 0)) + def get_batch_policy(self) -> Dict[str, int]: + return { + "prepare_batch_max_items": max(1, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_BATCH_MAX_ITEMS", 8))), + } + def is_idle(self) -> bool: return int(self.coordinator.snapshot().get("inflight", 0)) <= 0 @@ -69,3 +75,17 @@ class WorkerPrepareExecutor: return await self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) finally: self._notify_state_change() + + async def prepare_gpu_stages_profiled_async( + self, + cpu_stages: List[PreparedCpuStage], + ) -> List[tuple[T2SRequestState, float, float] | Exception]: + try: + return list( + await asyncio.gather( + *[self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) for cpu_stage in cpu_stages], + return_exceptions=True, + ) + ) + finally: + self._notify_state_change() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py index f1910409..e498e9ea 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py @@ -32,6 +32,9 @@ class WorkerSubmitLifecycleMixin: def get_finalize_batch_policy(self) -> Dict[str, Any]: return dict(self.finalize_executor.get_batch_policy()) + def get_prepare_batch_policy(self) -> Dict[str, int]: + return dict(self.prepare_executor.get_batch_policy()) + def get_decode_runtime_counters(self) -> Dict[str, int]: with self.condition: return self.decode_runtime_tracker.get_counters() @@ -258,3 +261,9 @@ class WorkerSubmitLifecycleMixin: cpu_stage: PreparedCpuStage, ) -> tuple[T2SRequestState, float, float]: return await self.prepare_executor.prepare_gpu_stage_profiled_async(cpu_stage) + + async def prepare_gpu_stages_profiled_async( + self, + cpu_stages: List[PreparedCpuStage], + ) -> List[tuple[T2SRequestState, float, float] | Exception]: + return await self.prepare_executor.prepare_gpu_stages_profiled_async(cpu_stages) diff --git a/third_party/g2pw-cu b/third_party/g2pw-cu new file mode 160000 index 00000000..a53cf4ee --- /dev/null +++ b/third_party/g2pw-cu @@ -0,0 +1 @@ +Subproject commit a53cf4eed5759f7b5d4563ce6e4b13557e054d98 From 17cb2e5acfc6d1c4c2aae7a31b296bb4eb4cd21e Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Thu, 12 Mar 2026 23:04:39 +0800 Subject: [PATCH 21/24] Implement G2PW processing enhancements in TTS framework Add support for G2PW processing in the TTS system by introducing new methods and classes for handling G2PW segments. Update PrepareCoordinator to manage G2PW worker threads and integrate G2PW profiling into the existing framework. Enhance text preprocessing to identify segments requiring G2PW and streamline the resolution of these segments. This update improves the overall performance and maintainability of the TTS system by optimizing the handling of Chinese text processing. --- GPT_SoVITS/TTS_infer_pack/TTS.py | 15 + GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py | 57 ++ .../TTS_infer_pack/prepare_coordinator.py | 143 +++- GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py | 30 + .../TTS_infer_pack/text_cpu_preprocess.py | 18 +- GPT_SoVITS/text/chinese2.py | 322 +++++---- GPT_SoVITS/text/g2pw/cuda_api.py | 670 ++++++++++++++++++ GPT_SoVITS/text/g2pw/g2pw.py | 37 +- GPT_SoVITS/text/g2pw/g2pw_cuda_bridge.cpp | 183 +++++ GPT_SoVITS/text/g2pw/onnx_api.py | 91 ++- 10 files changed, 1417 insertions(+), 149 deletions(-) create mode 100644 GPT_SoVITS/text/g2pw/cuda_api.py create mode 100644 GPT_SoVITS/text/g2pw/g2pw_cuda_bridge.cpp diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 92c829a1..81c1ca1e 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -529,6 +529,7 @@ class TTS: self.bert_model, self.bert_tokenizer, self.configs.device, + version=self.configs.version, bert_stage_limiter=self.prepare_bert_stage_limiter, bert_batch_worker=self.prepare_bert_batch_worker, ) @@ -558,6 +559,16 @@ class TTS: return None def snapshot_prepare_runtime_components(self) -> dict: + g2pw_runtime = None + try: + from text import chinese2 + + g2pw_instance = getattr(chinese2, "g2pw", None) + g2pw_backend = None if g2pw_instance is None else getattr(g2pw_instance, "_g2pw", None) + if g2pw_backend is not None and hasattr(g2pw_backend, "snapshot"): + g2pw_runtime = dict(g2pw_backend.snapshot()) + except Exception: + g2pw_runtime = None return { "text_cpu": { "workers": int(self.prepare_text_cpu_workers), @@ -587,6 +598,7 @@ class TTS: "text_preprocessor": ( None if self.text_preprocessor is None or not hasattr(self.text_preprocessor, "snapshot") else self.text_preprocessor.snapshot() ), + "g2pw": g2pw_runtime, } def _build_text_cpu_admission_state(self) -> dict: @@ -1204,6 +1216,9 @@ class TTS: def prepare_text_segments(self, text: str, language: str): return self.text_preprocessor.preprocess_text_segments(text, language, self.configs.version) + def resolve_g2pw_segments(self, prepared_segments, profile: dict | None = None): + return self.text_preprocessor.resolve_g2pw_segments(prepared_segments, profile=profile) + def build_text_features_from_segments(self, prepared_segments, profile: dict | None = None): return self.text_preprocessor.build_phones_and_bert_from_segments(prepared_segments, profile=profile) diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py index 01a8ea4d..c30e3195 100644 --- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py +++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py @@ -101,6 +101,7 @@ class PreparedTextSegment: phones: List[int] word2ph: Optional[List[int]] norm_text: str + needs_g2pw: bool = False class TextPreprocessor: @@ -109,12 +110,14 @@ class TextPreprocessor: bert_model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, device: torch.device, + version: str = "v2", bert_stage_limiter: StageLimiter | None = None, bert_batch_worker: PrepareBertBatchWorker | None = None, ): self.bert_model = bert_model self.tokenizer = tokenizer self.device = device + self.version = str(version) self.bert_stage_limiter = bert_stage_limiter self.bert_batch_worker = bert_batch_worker @@ -261,15 +264,66 @@ class TextPreprocessor: phones=list(payload["phones"]), word2ph=None if payload["word2ph"] is None else list(payload["word2ph"]), norm_text=str(payload["norm_text"]), + needs_g2pw=bool(payload.get("needs_g2pw", False)), ) for payload in payloads ] + def resolve_g2pw_segments( + self, + prepared_segments: List[PreparedTextSegment], + profile: Dict | None = None, + ) -> List[PreparedTextSegment]: + zh_indices = [index for index, segment in enumerate(prepared_segments) if bool(segment.needs_g2pw)] + if not zh_indices: + return prepared_segments + from text import chinese2 + + normalized_segments = [prepared_segments[index].norm_text for index in zh_indices] + resolved_segments, g2pw_profile = chinese2.g2p_segments(normalized_segments, return_profile=True) + self._accumulate_profile(profile, "g2pw_prepare_ms", g2pw_profile.get("g2pw_prepare_ms", 0.0)) + self._accumulate_profile(profile, "g2pw_predict_ms", g2pw_profile.get("g2pw_predict_ms", 0.0)) + self._accumulate_profile(profile, "g2pw_post_ms", g2pw_profile.get("g2pw_post_ms", 0.0)) + self._accumulate_profile(profile, "g2pw_total_ms", g2pw_profile.get("g2pw_total_ms", 0.0)) + self._accumulate_profile(profile, "g2pw_runtime_total_ms", g2pw_profile.get("g2pw_runtime_total_ms", 0.0)) + self._accumulate_profile(profile, "g2pw_runtime_queue_wait_ms", g2pw_profile.get("g2pw_runtime_queue_wait_ms", 0.0)) + self._accumulate_profile( + profile, + "g2pw_runtime_collect_wait_ms", + g2pw_profile.get("g2pw_runtime_collect_wait_ms", 0.0), + ) + self._accumulate_profile(profile, "g2pw_runtime_run_ms", g2pw_profile.get("g2pw_runtime_run_ms", 0.0)) + self._update_profile_peak( + profile, + "g2pw_runtime_batch_rows_peak", + g2pw_profile.get("g2pw_runtime_batch_rows", 0.0), + ) + self._update_profile_peak( + profile, + "g2pw_runtime_batch_requests_peak", + g2pw_profile.get("g2pw_runtime_batch_requests", 0.0), + ) + self._update_profile_peak( + profile, + "g2pw_runtime_pool_workers", + g2pw_profile.get("g2pw_runtime_pool_workers", 0.0), + ) + for index, (phones, word2ph, norm_text) in zip(zh_indices, resolved_segments): + prepared_segments[index] = PreparedTextSegment( + language=prepared_segments[index].language, + phones=list(cleaned_text_to_sequence(phones, self.version)), + word2ph=None if word2ph is None else list(word2ph), + norm_text=str(norm_text), + needs_g2pw=False, + ) + return prepared_segments + def build_phones_and_bert_from_segments( self, prepared_segments: List[PreparedTextSegment], profile: Dict | None = None, ) -> Tuple[list, torch.Tensor, str]: + prepared_segments = self.resolve_g2pw_segments(prepared_segments, profile=profile) phones_list: List[List[int]] = [] bert_list: List[torch.Tensor] = [] norm_text_list: List[str] = [] @@ -402,6 +456,7 @@ class TextPreprocessor: prepared_segments: List[PreparedTextSegment], profile: Dict | None = None, ) -> Tuple[list, torch.Tensor, str]: + prepared_segments = self.resolve_g2pw_segments(prepared_segments, profile=profile) segment_jobs = self._build_async_segment_jobs(prepared_segments, profile) pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = [] for segment_index, segment in enumerate(prepared_segments): @@ -473,6 +528,8 @@ class TextPreprocessor: prompt_profile: Dict | None = None, target_profile: Dict | None = None, ) -> Tuple[Tuple[list, torch.Tensor, str], Tuple[list, torch.Tensor, str]]: + prompt_segments = self.resolve_g2pw_segments(prompt_segments, profile=prompt_profile) + target_segments = self.resolve_g2pw_segments(target_segments, profile=target_profile) prompt_jobs = self._build_async_segment_jobs(prompt_segments, prompt_profile) target_jobs = self._build_async_segment_jobs(target_segments, target_profile) pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = [] diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py index 71134268..65bcbb51 100644 --- a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py +++ b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py @@ -120,6 +120,15 @@ class PrepareCoordinator: max_workers=self.text_feature_workers, thread_name_prefix="prepare-text-feature", ) + g2pw_default_workers = max(8, int(getattr(tts, "prepare_text_cpu_workers", 8) or 8)) + self.g2pw_workers = max( + 1, + int(os.environ.get("GPTSOVITS_PREPARE_G2PW_WORKERS", str(g2pw_default_workers))), + ) + self.g2pw_executor = concurrent.futures.ThreadPoolExecutor( + max_workers=self.g2pw_workers, + thread_name_prefix="prepare-g2pw", + ) ref_audio_default_workers = max(1, int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4"))) self.ref_audio_workers = max( 1, @@ -130,12 +139,17 @@ class PrepareCoordinator: thread_name_prefix="prepare-ref-audio", ) text_cpu_gate_default = max(0, int(getattr(tts, "prepare_text_cpu_workers", 0) or 0)) + g2pw_gate_default = max(0, int(self.g2pw_workers)) text_feature_gate_default = max(0, int(self.text_feature_workers)) ref_audio_gate_default = max(0, int(self.ref_audio_workers)) self.text_cpu_gate = AsyncStageGate( int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_MAX_INFLIGHT", str(text_cpu_gate_default))), poll_ms=gate_poll_ms, ) + self.g2pw_gate = AsyncStageGate( + int(os.environ.get("GPTSOVITS_PREPARE_G2PW_MAX_INFLIGHT", str(g2pw_gate_default))), + poll_ms=gate_poll_ms, + ) self.text_feature_gate = AsyncStageGate( int(os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_MAX_INFLIGHT", str(text_feature_gate_default))), poll_ms=gate_poll_ms, @@ -172,6 +186,7 @@ class PrepareCoordinator: "peak_inflight": int(self.peak_inflight), "max_inflight": int(self.max_inflight), "text_feature_workers": int(self.text_feature_workers), + "g2pw_workers": int(self.g2pw_workers), "ref_audio_workers": int(self.ref_audio_workers), } runtime_snapshot_fn = getattr(self.tts, "snapshot_prepare_runtime_components", None) @@ -182,6 +197,7 @@ class PrepareCoordinator: snapshot["prepare_runtime_state"] = None snapshot["prepare_stage_gates"] = { "text_cpu": self.text_cpu_gate.snapshot(), + "g2pw": self.g2pw_gate.snapshot(), "text_feature": self.text_feature_gate.snapshot(), "ref_audio": self.ref_audio_gate.snapshot(), "ref_load": self.ref_load_gate.snapshot(), @@ -204,6 +220,11 @@ class PrepareCoordinator: def _prepare_text_cpu(self, text: str, language: str): return self.tts.prepare_text_segments(text, language) + def _resolve_g2pw_segments(self, prepared_segments): + profile: Dict[str, float] = {} + resolved_segments = self.tts.resolve_g2pw_segments(prepared_segments, profile=profile) + return resolved_segments, profile + def _load_ref_audio_raw(self, ref_audio_path: str): return self.tts._load_ref_audio_raw(ref_audio_path) @@ -225,8 +246,15 @@ class PrepareCoordinator: dtype=(dtype if dtype is not None else None) or __import__("torch").float32, ) - def _build_text_features(self, prepared_segments, language: str, cpu_run_ms: float) -> PreparedTextFeatures: - profile: Dict[str, float] = {"cpu_preprocess_ms": float(cpu_run_ms)} + def _build_text_features( + self, + prepared_segments, + language: str, + cpu_run_ms: float, + base_profile: Dict[str, float] | None = None, + ) -> PreparedTextFeatures: + profile: Dict[str, float] = dict(base_profile or {}) + profile["cpu_preprocess_ms"] = float(cpu_run_ms) branch_start = time.perf_counter() phones, bert_features, norm_text = self.tts.build_text_features_from_segments(prepared_segments, profile=profile) total_ms = float(cpu_run_ms + (time.perf_counter() - branch_start) * 1000.0) @@ -291,10 +319,53 @@ class PrepareCoordinator: prepared_segments, language, cpu_run_ms, + None, ) finally: self.text_feature_gate.release() + async def _run_g2pw_stage(self, prepared_segments) -> ProfiledResult: + has_pending = any(bool(getattr(segment, "needs_g2pw", False)) for segment in (prepared_segments or [])) + if not has_pending: + submit_at = time.perf_counter() + return ProfiledResult( + result=prepared_segments, + submit_at=float(submit_at), + started_at=float(submit_at), + finished_at=float(submit_at), + profile={}, + ) + await self.g2pw_gate.acquire() + try: + profiled = await self._run_on_executor(self.g2pw_executor, self._resolve_g2pw_segments, prepared_segments) + result, stage_profile = profiled.result + return ProfiledResult( + result=result, + submit_at=float(profiled.submit_at), + started_at=float(profiled.started_at), + finished_at=float(profiled.finished_at), + profile=dict(stage_profile), + ) + finally: + self.g2pw_gate.release() + + async def _run_g2pw_pair_stage(self, prompt_segments, target_segments) -> tuple[ProfiledResult, ProfiledResult]: + prompt_is_empty = len(prompt_segments or []) == 0 + target_task = asyncio.create_task(self._run_g2pw_stage(target_segments)) + if not prompt_is_empty: + prompt_task = asyncio.create_task(self._run_g2pw_stage(prompt_segments)) + return await asyncio.gather(prompt_task, target_task) + target_profiled = await target_task + submit_at = time.perf_counter() + prompt_profiled = ProfiledResult( + result=prompt_segments, + submit_at=float(submit_at), + started_at=float(submit_at), + finished_at=float(submit_at), + profile={}, + ) + return prompt_profiled, target_profiled + @staticmethod def _estimate_text_feature_run_ms(profile: Dict[str, float]) -> float: return float( @@ -310,12 +381,32 @@ class PrepareCoordinator: target_segments, prompt_cpu_run_ms: float, target_cpu_run_ms: float, + prompt_base_profile: Dict[str, float] | None = None, + target_base_profile: Dict[str, float] | None = None, ) -> tuple[ProfiledResult, ProfiledResult]: prompt_is_empty = len(prompt_segments or []) == 0 if self.text_feature_executor is not None: - target_feature_task = asyncio.create_task(self._run_text_feature_stage(target_segments, None, target_cpu_run_ms)) + target_feature_task = asyncio.create_task( + self._run_on_executor( + self.text_feature_executor, + self._build_text_features, + target_segments, + None, + target_cpu_run_ms, + target_base_profile, + ) + ) if not prompt_is_empty: - prompt_feature_task = asyncio.create_task(self._run_text_feature_stage(prompt_segments, None, prompt_cpu_run_ms)) + prompt_feature_task = asyncio.create_task( + self._run_on_executor( + self.text_feature_executor, + self._build_text_features, + prompt_segments, + None, + prompt_cpu_run_ms, + prompt_base_profile, + ) + ) return await asyncio.gather(prompt_feature_task, target_feature_task) target_profiled = await target_feature_task submit_at = time.perf_counter() @@ -328,7 +419,8 @@ class PrepareCoordinator: return prompt_profiled, target_profiled await self.text_feature_gate.acquire() - target_profile: Dict[str, float] = {"cpu_preprocess_ms": float(target_cpu_run_ms)} + target_profile: Dict[str, float] = dict(target_base_profile or {}) + target_profile["cpu_preprocess_ms"] = float(target_cpu_run_ms) submit_at = time.perf_counter() started_at = float(submit_at) try: @@ -377,7 +469,8 @@ class PrepareCoordinator: target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile) return prompt_profiled, target_profiled - prompt_profile: Dict[str, float] = {"cpu_preprocess_ms": float(prompt_cpu_run_ms)} + prompt_profile: Dict[str, float] = dict(prompt_base_profile or {}) + prompt_profile["cpu_preprocess_ms"] = float(prompt_cpu_run_ms) prompt_result_raw, target_result_raw = await self.tts.build_text_feature_pair_from_segments_async( prompt_segments, target_segments, @@ -589,20 +682,31 @@ class PrepareCoordinator: cpu_stage: PreparedCpuStage, ) -> tuple[T2SRequestState, float, float]: try: - text_pair_start = time.perf_counter() - ref_audio_task = asyncio.create_task(self._run_ref_audio_stage(str(cpu_stage.spec.ref_audio_path))) - text_feature_pair_task = asyncio.create_task( - self._run_text_feature_pair_stage( + g2pw_pair_start = time.perf_counter() + g2pw_pair_task = asyncio.create_task( + self._run_g2pw_pair_stage( cpu_stage.prompt_cpu_profiled.result, cpu_stage.target_cpu_profiled.result, - cpu_stage.prompt_cpu_profiled.run_ms, - cpu_stage.target_cpu_profiled.run_ms, ) ) - (prompt_feature_profiled, target_feature_profiled), ref_audio_profiled = await asyncio.gather( - text_feature_pair_task, + ref_audio_task = asyncio.create_task(self._run_ref_audio_stage(str(cpu_stage.spec.ref_audio_path))) + (prompt_g2pw_profiled, target_g2pw_profiled), ref_audio_profiled = await asyncio.gather( + g2pw_pair_task, ref_audio_task, ) + g2pw_pair_end = time.perf_counter() + text_pair_start = time.perf_counter() + text_feature_pair_task = asyncio.create_task( + self._run_text_feature_pair_stage( + prompt_g2pw_profiled.result, + target_g2pw_profiled.result, + cpu_stage.prompt_cpu_profiled.run_ms, + cpu_stage.target_cpu_profiled.run_ms, + prompt_base_profile=dict(prompt_g2pw_profiled.profile or {}), + target_base_profile=dict(target_g2pw_profiled.profile or {}), + ) + ) + prompt_feature_profiled, target_feature_profiled = await text_feature_pair_task text_pair_end = time.perf_counter() state = build_request_state_from_parts( tts=self.tts, @@ -619,6 +723,17 @@ class PrepareCoordinator: "prepare_admission_wait_ms": cpu_stage.prepare_admission_wait_ms, "executor_run_wall_ms": max(0.0, (time.perf_counter() - cpu_stage.prepare_start) * 1000.0), "text_feature_pair_ms": max(0.0, (text_pair_end - text_pair_start) * 1000.0), + "g2pw_pair_ms": max(0.0, (g2pw_pair_end - g2pw_pair_start) * 1000.0), + "prompt_text_g2pw_queue_ms": prompt_g2pw_profiled.queue_ms, + "prompt_text_g2pw_run_ms": prompt_g2pw_profiled.run_ms, + "prompt_text_g2pw_prepare_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)), + "prompt_text_g2pw_predict_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)), + "prompt_text_g2pw_post_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)), + "text_g2pw_queue_ms": target_g2pw_profiled.queue_ms, + "text_g2pw_run_ms": target_g2pw_profiled.run_ms, + "text_g2pw_prepare_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)), + "text_g2pw_predict_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)), + "text_g2pw_post_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)), "prompt_text_parallel_future_wait_ms": 0.0, "prompt_text_parallel_future_executor_queue_ms": 0.0, "prompt_text_parallel_future_run_ms": 0.0, diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py index ed465f69..e993a1ef 100644 --- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -241,6 +241,24 @@ def build_request_state_from_parts( prompt_result.profile.get("bert_high_pressure_mode_peak", 0.0) ), "prompt_text_bert_batch_window_ms": float(prompt_result.profile.get("bert_batch_window_ms", 0.0)), + "prompt_text_g2pw_total_ms": float(prompt_result.profile.get("g2pw_total_ms", 0.0)), + "prompt_text_g2pw_prepare_ms": float(prompt_result.profile.get("g2pw_prepare_ms", 0.0)), + "prompt_text_g2pw_predict_ms": float(prompt_result.profile.get("g2pw_predict_ms", 0.0)), + "prompt_text_g2pw_post_ms": float(prompt_result.profile.get("g2pw_post_ms", 0.0)), + "prompt_text_g2pw_runtime_total_ms": float(prompt_result.profile.get("g2pw_runtime_total_ms", 0.0)), + "prompt_text_g2pw_runtime_queue_wait_ms": float( + prompt_result.profile.get("g2pw_runtime_queue_wait_ms", 0.0) + ), + "prompt_text_g2pw_runtime_collect_wait_ms": float( + prompt_result.profile.get("g2pw_runtime_collect_wait_ms", 0.0) + ), + "prompt_text_g2pw_runtime_run_ms": float(prompt_result.profile.get("g2pw_runtime_run_ms", 0.0)), + "prompt_text_g2pw_runtime_batch_rows_peak": float( + prompt_result.profile.get("g2pw_runtime_batch_rows_peak", 0.0) + ), + "prompt_text_g2pw_runtime_batch_requests_peak": float( + prompt_result.profile.get("g2pw_runtime_batch_requests_peak", 0.0) + ), "prompt_text_parallel_future_wait_ms": 0.0, "prompt_text_parallel_future_executor_queue_ms": 0.0, "prompt_text_parallel_future_run_ms": float(prompt_result.total_ms), @@ -267,6 +285,18 @@ def build_request_state_from_parts( ), "text_bert_high_pressure_mode_peak": float(target_result.profile.get("bert_high_pressure_mode_peak", 0.0)), "text_bert_batch_window_ms": float(target_result.profile.get("bert_batch_window_ms", 0.0)), + "text_g2pw_total_ms": float(target_result.profile.get("g2pw_total_ms", 0.0)), + "text_g2pw_prepare_ms": float(target_result.profile.get("g2pw_prepare_ms", 0.0)), + "text_g2pw_predict_ms": float(target_result.profile.get("g2pw_predict_ms", 0.0)), + "text_g2pw_post_ms": float(target_result.profile.get("g2pw_post_ms", 0.0)), + "text_g2pw_runtime_total_ms": float(target_result.profile.get("g2pw_runtime_total_ms", 0.0)), + "text_g2pw_runtime_queue_wait_ms": float(target_result.profile.get("g2pw_runtime_queue_wait_ms", 0.0)), + "text_g2pw_runtime_collect_wait_ms": float(target_result.profile.get("g2pw_runtime_collect_wait_ms", 0.0)), + "text_g2pw_runtime_run_ms": float(target_result.profile.get("g2pw_runtime_run_ms", 0.0)), + "text_g2pw_runtime_batch_rows_peak": float(target_result.profile.get("g2pw_runtime_batch_rows_peak", 0.0)), + "text_g2pw_runtime_batch_requests_peak": float( + target_result.profile.get("g2pw_runtime_batch_requests_peak", 0.0) + ), "text_feature_pair_ms": float(max(prompt_result.total_ms, target_result.total_ms)), "text_cpu_parallel_workers": float(getattr(tts, "prepare_text_cpu_workers", 0)), "audio_load_ms": audio_load_ms, diff --git a/GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py b/GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py index e2398251..3d5b2de5 100644 --- a/GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py +++ b/GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py @@ -8,6 +8,7 @@ sys.path.append(now_dir) from text.LangSegmenter import LangSegmenter from text import cleaned_text_to_sequence +from text import chinese2 from text.cleaner import clean_text @@ -83,16 +84,27 @@ def preprocess_text_segments_payload( payloads: List[PreparedTextSegmentPayload] = [] total_phones_len = 0 for segment_text, segment_lang in zip(textlist, langlist): - phones, word2ph, norm_text = clean_text_segment(segment_text, segment_lang, version) + normalized_language = segment_lang.replace("all_", "") + if normalized_language == "zh": + norm_text = chinese2.text_normalize(segment_text) + phones = [] + word2ph = None + needs_g2pw = True + estimated_phones_len = max(0, len(norm_text) * 2) + else: + phones, word2ph, norm_text = clean_text_segment(segment_text, segment_lang, version) + needs_g2pw = False + estimated_phones_len = len(phones) payloads.append( { - "language": segment_lang.replace("all_", ""), + "language": normalized_language, "phones": phones, "word2ph": word2ph, "norm_text": norm_text, + "needs_g2pw": needs_g2pw, } ) - total_phones_len += len(phones) + total_phones_len += int(estimated_phones_len) if not final and total_phones_len < 6: return preprocess_text_segments_payload("." + text, language, version, final=True) diff --git a/GPT_SoVITS/text/chinese2.py b/GPT_SoVITS/text/chinese2.py index acfebfe2..a5d32490 100644 --- a/GPT_SoVITS/text/chinese2.py +++ b/GPT_SoVITS/text/chinese2.py @@ -1,5 +1,6 @@ import os import re +import time import cn2an from pypinyin import lazy_pinyin, Style @@ -77,6 +78,205 @@ def g2p(text): return phones, word2ph +def _prepare_g2p_segments(segments): + prepared_segments = [] + batch_inputs = [] + for segment in segments: + processed_segment = re.sub("[a-zA-Z]+", "", segment) + seg_cut = psg.lcut(processed_segment) + seg_cut = tone_modifier.pre_merge_for_modify(seg_cut) + prepared_segments.append( + { + "segment": processed_segment, + "seg_cut": seg_cut, + } + ) + if processed_segment: + batch_inputs.append(processed_segment) + return prepared_segments, batch_inputs + + +def _build_segment_from_g2pw(segment: str, seg_cut, pinyins): + phones_list = [] + word2ph = [] + initials = [] + finals = [] + pre_word_length = 0 + for word, pos in seg_cut: + sub_initials = [] + sub_finals = [] + now_word_length = pre_word_length + len(word) + + if pos == "eng": + pre_word_length = now_word_length + continue + + word_pinyins = pinyins[pre_word_length:now_word_length] + word_pinyins = correct_pronunciation(word, word_pinyins) + + for pinyin in word_pinyins: + if pinyin[0].isalpha(): + sub_initials.append(to_initials(pinyin)) + sub_finals.append(to_finals_tone3(pinyin, neutral_tone_with_five=True)) + else: + sub_initials.append(pinyin) + sub_finals.append(pinyin) + + pre_word_length = now_word_length + sub_finals = tone_modifier.modified_tone(word, pos, sub_finals) + sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos) + initials.append(sub_initials) + finals.append(sub_finals) + + initials = sum(initials, []) + finals = sum(finals, []) + for c, v in zip(initials, finals): + raw_pinyin = c + v + if c == v: + assert c in punctuation + phone = [c] + word2ph.append(1) + else: + v_without_tone = v[:-1] + tone = v[-1] + + pinyin = c + v_without_tone + assert tone in "12345" + + if c: + v_rep_map = { + "uei": "ui", + "iou": "iu", + "uen": "un", + } + if v_without_tone in v_rep_map.keys(): + pinyin = c + v_rep_map[v_without_tone] + else: + pinyin_rep_map = { + "ing": "ying", + "i": "yi", + "in": "yin", + "u": "wu", + } + if pinyin in pinyin_rep_map.keys(): + pinyin = pinyin_rep_map[pinyin] + else: + single_rep_map = { + "v": "yu", + "e": "e", + "i": "y", + "u": "w", + } + if pinyin[0] in single_rep_map.keys(): + pinyin = single_rep_map[pinyin[0]] + pinyin[1:] + + assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, segment, raw_pinyin) + new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ") + new_v = new_v + tone + phone = [new_c, new_v] + word2ph.append(len(phone)) + + phones_list += phone + return phones_list, word2ph + + +def _build_segment_without_g2pw(segment: str, seg_cut): + initials = [] + finals = [] + for word, pos in seg_cut: + if pos == "eng": + continue + sub_initials, sub_finals = _get_initials_finals(word) + sub_finals = tone_modifier.modified_tone(word, pos, sub_finals) + sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos) + initials.append(sub_initials) + finals.append(sub_finals) + phones_list = [] + word2ph = [] + for c, v in zip(sum(initials, []), sum(finals, [])): + raw_pinyin = c + v + if c == v: + assert c in punctuation + phone = [c] + word2ph.append(1) + else: + v_without_tone = v[:-1] + tone = v[-1] + pinyin = c + v_without_tone + assert tone in "12345" + if c: + v_rep_map = {"uei": "ui", "iou": "iu", "uen": "un"} + if v_without_tone in v_rep_map: + pinyin = c + v_rep_map[v_without_tone] + else: + pinyin_rep_map = {"ing": "ying", "i": "yi", "in": "yin", "u": "wu"} + if pinyin in pinyin_rep_map: + pinyin = pinyin_rep_map[pinyin] + else: + single_rep_map = {"v": "yu", "e": "e", "i": "y", "u": "w"} + if pinyin[0] in single_rep_map: + pinyin = single_rep_map[pinyin[0]] + pinyin[1:] + assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, segment, raw_pinyin) + new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ") + new_v = new_v + tone + phone = [new_c, new_v] + word2ph.append(len(phone)) + phones_list += phone + return phones_list, word2ph + + +def g2p_segments(segments, return_profile: bool = False): + prepare_start = time.perf_counter() + prepared_segments, batch_inputs = _prepare_g2p_segments(segments) + profile = { + "g2pw_prepare_ms": 0.0, + "g2pw_predict_ms": 0.0, + "g2pw_post_ms": 0.0, + "g2pw_runtime_total_ms": 0.0, + "g2pw_runtime_queue_wait_ms": 0.0, + "g2pw_runtime_collect_wait_ms": 0.0, + "g2pw_runtime_run_ms": 0.0, + "g2pw_runtime_batch_rows": 0.0, + "g2pw_runtime_batch_requests": 0.0, + "g2pw_runtime_pool_workers": 0.0, + "g2pw_runtime_shard_index": 0.0, + } + profile["g2pw_prepare_ms"] = float((time.perf_counter() - prepare_start) * 1000.0) + if is_g2pw and batch_inputs: + converter = g2pw._g2pw + if hasattr(converter, "predict_sentences_with_profile"): + g2pw_batch_results, predict_profile = converter.predict_sentences_with_profile(batch_inputs) + for key, value in dict(predict_profile or {}).items(): + profile[key] = float(value) + else: + predict_start = time.perf_counter() + g2pw_batch_results = converter(batch_inputs) + profile["g2pw_predict_ms"] = float((time.perf_counter() - predict_start) * 1000.0) + else: + g2pw_batch_results = [] + post_start = time.perf_counter() + results = [] + batch_cursor = 0 + for item in prepared_segments: + segment = item["segment"] + if not segment: + results.append(([], [], segment)) + continue + if not is_g2pw: + phones, word2ph = _build_segment_without_g2pw(segment, item["seg_cut"]) + results.append((phones, word2ph, segment)) + continue + pinyins = g2pw_batch_results[batch_cursor] + batch_cursor += 1 + phones, word2ph = _build_segment_from_g2pw(segment, item["seg_cut"], pinyins) + results.append((phones, word2ph, segment)) + profile["g2pw_post_ms"] = float((time.perf_counter() - post_start) * 1000.0) + profile["g2pw_total_ms"] = float(profile["g2pw_prepare_ms"] + profile["g2pw_predict_ms"] + profile["g2pw_post_ms"]) + if return_profile: + return results, profile + return results + + def _get_initials_finals(word): initials = [] finals = [] @@ -180,125 +380,9 @@ def _merge_erhua(initials: list[str], finals: list[str], word: str, pos: str) -> def _g2p(segments): phones_list = [] word2ph = [] - g2pw_batch_results = [] - g2pw_batch_cursor = 0 - processed_segments = [re.sub("[a-zA-Z]+", "", seg) for seg in segments] - if is_g2pw: - batch_inputs = [seg for seg in processed_segments if seg] - g2pw_batch_results = g2pw._g2pw(batch_inputs) if batch_inputs else [] - - for seg in processed_segments: - pinyins = [] - seg_cut = psg.lcut(seg) - seg_cut = tone_modifier.pre_merge_for_modify(seg_cut) - initials = [] - finals = [] - - if not is_g2pw: - for word, pos in seg_cut: - if pos == "eng": - continue - sub_initials, sub_finals = _get_initials_finals(word) - sub_finals = tone_modifier.modified_tone(word, pos, sub_finals) - # 儿化 - sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos) - initials.append(sub_initials) - finals.append(sub_finals) - # assert len(sub_initials) == len(sub_finals) == len(word) - initials = sum(initials, []) - finals = sum(finals, []) - print("pypinyin结果", initials, finals) - else: - # g2pw采用整句推理(批量推理,逐句取结果) - if seg: - pinyins = g2pw_batch_results[g2pw_batch_cursor] - g2pw_batch_cursor += 1 - - pre_word_length = 0 - for word, pos in seg_cut: - sub_initials = [] - sub_finals = [] - now_word_length = pre_word_length + len(word) - - if pos == "eng": - pre_word_length = now_word_length - continue - - word_pinyins = pinyins[pre_word_length:now_word_length] - - # 多音字消歧 - word_pinyins = correct_pronunciation(word, word_pinyins) - - for pinyin in word_pinyins: - if pinyin[0].isalpha(): - sub_initials.append(to_initials(pinyin)) - sub_finals.append(to_finals_tone3(pinyin, neutral_tone_with_five=True)) - else: - sub_initials.append(pinyin) - sub_finals.append(pinyin) - - pre_word_length = now_word_length - sub_finals = tone_modifier.modified_tone(word, pos, sub_finals) - # 儿化 - sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos) - initials.append(sub_initials) - finals.append(sub_finals) - - initials = sum(initials, []) - finals = sum(finals, []) - # print("g2pw结果",initials,finals) - - for c, v in zip(initials, finals): - raw_pinyin = c + v - # NOTE: post process for pypinyin outputs - # we discriminate i, ii and iii - if c == v: - assert c in punctuation - phone = [c] - word2ph.append(1) - else: - v_without_tone = v[:-1] - tone = v[-1] - - pinyin = c + v_without_tone - assert tone in "12345" - - if c: - # 多音节 - v_rep_map = { - "uei": "ui", - "iou": "iu", - "uen": "un", - } - if v_without_tone in v_rep_map.keys(): - pinyin = c + v_rep_map[v_without_tone] - else: - # 单音节 - pinyin_rep_map = { - "ing": "ying", - "i": "yi", - "in": "yin", - "u": "wu", - } - if pinyin in pinyin_rep_map.keys(): - pinyin = pinyin_rep_map[pinyin] - else: - single_rep_map = { - "v": "yu", - "e": "e", - "i": "y", - "u": "w", - } - if pinyin[0] in single_rep_map.keys(): - pinyin = single_rep_map[pinyin[0]] + pinyin[1:] - - assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin) - new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ") - new_v = new_v + tone - phone = [new_c, new_v] - word2ph.append(len(phone)) - - phones_list += phone + for phones, item_word2ph, _segment in g2p_segments(segments): + phones_list += phones + word2ph += item_word2ph return phones_list, word2ph diff --git a/GPT_SoVITS/text/g2pw/cuda_api.py b/GPT_SoVITS/text/g2pw/cuda_api.py new file mode 100644 index 00000000..e1a84748 --- /dev/null +++ b/GPT_SoVITS/text/g2pw/cuda_api.py @@ -0,0 +1,670 @@ +import ctypes +import fcntl +import os +import subprocess +import threading +import time +from collections import deque +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Deque, Dict, List, Tuple + +import numpy as np + +from .onnx_api import _G2PWBaseOnnxConverter + + +class G2PWCudaError(RuntimeError): + pass + + +@dataclass +class G2PWBatchTask: + model_input: Dict[str, np.ndarray] + created_at: float = field(default_factory=time.perf_counter) + enqueued_at: float = 0.0 + done_event: threading.Event = field(default_factory=threading.Event) + output: np.ndarray | None = None + profile: Dict[str, float] = field(default_factory=dict) + error: Exception | None = None + + +_ROOT_DIR = Path(__file__).resolve().parents[3] +_PACKAGE_DIR = Path(__file__).resolve().parent +_OUTPUT_DIR = _ROOT_DIR / "outputs" / "g2pw_cuda_bridge" +_WRAPPER_SOURCE = _PACKAGE_DIR / "g2pw_cuda_bridge.cpp" +_LOCK_PATH = _OUTPUT_DIR / "build.lock" + + +def _env_flag(name: str, default: bool) -> int: + raw = os.environ.get(name) + if raw is None: + return 1 if default else 0 + return 0 if raw.strip().lower() in {"0", "false", "no", "off"} else 1 + + +def _env_int(name: str, default: int) -> int: + raw = os.environ.get(name) + if raw is None or raw.strip() == "": + return int(default) + return int(raw) + + +def _resolve_cuda_root() -> Path: + env_root = os.environ.get("GPTSOVITS_G2PW_CUDA_ROOT", "").strip() + candidates = [ + env_root, + _ROOT_DIR / "third_party" / "g2pw-cu", + ] + for candidate in candidates: + if not candidate: + continue + path = Path(candidate).expanduser().resolve() + if path.exists(): + return path + checked = [ + str(Path(candidate).expanduser().resolve()) + for candidate in candidates + if str(candidate).strip() != "" + ] + raise G2PWCudaError( + "Cannot locate g2pw-cu root. " + "Expected one of: " + f"{checked}. " + "Recommended: clone https://github.com/baicai-1145/g2pw-cu.git into " + f"{(_ROOT_DIR / 'third_party' / 'g2pw-cu').as_posix()} " + "or set GPTSOVITS_G2PW_CUDA_ROOT explicitly." + ) + + +def _resolve_runtime_paths() -> tuple[Path, Path, Path]: + cuda_root = _resolve_cuda_root() + runtime_lib = Path( + os.environ.get("GPTSOVITS_G2PW_CUDA_RUNTIME_LIB", str(cuda_root / "build" / "libg2pw_runtime.so")) + ).expanduser() + manifest_path = Path( + os.environ.get("GPTSOVITS_G2PW_CUDA_MANIFEST", str(cuda_root / "artifacts" / "model" / "manifest.txt")) + ).expanduser() + weights_path = Path( + os.environ.get("GPTSOVITS_G2PW_CUDA_WEIGHTS", str(cuda_root / "artifacts" / "model" / "weights.bin")) + ).expanduser() + for path in (runtime_lib, manifest_path, weights_path): + if not path.exists(): + raise G2PWCudaError(f"Missing g2pw-cu artifact: {path}") + return runtime_lib.resolve(), manifest_path.resolve(), weights_path.resolve() + + +def _build_bridge(wrapper_output: Path, runtime_lib: Path) -> None: + _OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + compile_cmd = [ + os.environ.get("CXX", "g++"), + "-O3", + "-std=c++17", + "-shared", + "-fPIC", + str(_WRAPPER_SOURCE), + "-I", + str(runtime_lib.parent.parent / "include"), + "-L", + str(runtime_lib.parent), + "-lg2pw_runtime", + f"-Wl,-rpath,{runtime_lib.parent}", + "-o", + str(wrapper_output), + ] + result = subprocess.run(compile_cmd, capture_output=True, text=True, check=False) + if result.returncode != 0: + raise G2PWCudaError( + "Failed to build g2pw-cu bridge:\n" + f"cmd={' '.join(compile_cmd)}\n" + f"stdout={result.stdout}\n" + f"stderr={result.stderr}" + ) + + +def _ensure_bridge_built(runtime_lib: Path) -> Path: + wrapper_output = _OUTPUT_DIR / "g2pw_cuda_bridge.so" + _OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + with _LOCK_PATH.open("w", encoding="utf-8") as lock_file: + fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) + needs_build = not wrapper_output.exists() + if not needs_build: + so_mtime = wrapper_output.stat().st_mtime + needs_build = so_mtime < _WRAPPER_SOURCE.stat().st_mtime or so_mtime < runtime_lib.stat().st_mtime + if needs_build: + tmp_output = wrapper_output.with_suffix(".tmp.so") + if tmp_output.exists(): + tmp_output.unlink() + _build_bridge(tmp_output, runtime_lib) + tmp_output.replace(wrapper_output) + return wrapper_output + + +def _load_bridge(): + runtime_lib, manifest_path, weights_path = _resolve_runtime_paths() + bridge_path = _ensure_bridge_built(runtime_lib) + global_mode = getattr(ctypes, "RTLD_GLOBAL", getattr(os, "RTLD_GLOBAL", 0)) + ctypes.CDLL(str(runtime_lib), mode=global_mode) + lib = ctypes.CDLL(str(bridge_path)) + lib.g2pw_runtime_create.argtypes = [ + ctypes.c_char_p, + ctypes.c_char_p, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ] + lib.g2pw_runtime_create.restype = ctypes.c_void_p + lib.g2pw_runtime_destroy.argtypes = [ctypes.c_void_p] + lib.g2pw_runtime_destroy.restype = None + lib.g2pw_runtime_last_error.argtypes = [ctypes.c_void_p] + lib.g2pw_runtime_last_error.restype = ctypes.c_char_p + lib.g2pw_runtime_num_labels.argtypes = [ctypes.c_void_p] + lib.g2pw_runtime_num_labels.restype = ctypes.c_int + lib.g2pw_runtime_run.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_int32, + ctypes.c_int32, + ctypes.c_void_p, + ] + lib.g2pw_runtime_run.restype = ctypes.c_int + return lib, manifest_path, weights_path, runtime_lib + + +def _gemm_precision_value() -> int: + precision = os.environ.get("GPTSOVITS_G2PW_CUDA_GEMM_PRECISION", "fp32").strip().lower() + if precision == "fp16": + return 1 + if precision == "bf16": + return 2 + return 0 + + +class G2PWRuntimeWrapper: + def __init__(self, shard_index: int = 0) -> None: + self.lib, self.manifest_path, self.weights_path, self.runtime_lib = _load_bridge() + self.shard_index = int(shard_index) + self.device_ordinal = _env_int("GPTSOVITS_G2PW_CUDA_DEVICE", 0) + self.allow_tensor_cores = _env_flag("GPTSOVITS_G2PW_CUDA_ALLOW_TENSOR_CORES", False) + self.use_cublaslt_bias_epilogue = _env_flag("GPTSOVITS_G2PW_CUDA_USE_CUBLASLT_BIAS_EPILOGUE", False) + self.enable_profiling = _env_flag("GPTSOVITS_G2PW_CUDA_ENABLE_PROFILE", False) + self.enable_cuda_graph = _env_flag("GPTSOVITS_G2PW_CUDA_ENABLE_GRAPH", True) + self.dump_graph_cache_stats = _env_flag("GPTSOVITS_G2PW_CUDA_DUMP_GRAPH_CACHE_STATS", False) + self.full_graph_cache_limit = _env_int("GPTSOVITS_G2PW_CUDA_FULL_GRAPH_CACHE_LIMIT", 0) + self.tail_graph_cache_limit = _env_int("GPTSOVITS_G2PW_CUDA_TAIL_GRAPH_CACHE_LIMIT", 0) + self.gemm_precision = _gemm_precision_value() + self.lock = threading.Lock() + self.handle = None + self.max_batch_size = 0 + self.max_seq_len = 0 + self.num_labels = 0 + self.batch_enabled = _env_flag("GPTSOVITS_G2PW_CUDA_BATCHING", True) != 0 + self.batch_window_s = max(0.0, float(_env_int("GPTSOVITS_G2PW_CUDA_BATCH_WINDOW_MS", 1)) / 1000.0) + self.batch_max_requests = max(1, _env_int("GPTSOVITS_G2PW_CUDA_BATCH_MAX_REQUESTS", 64)) + self.batch_max_rows = max(1, _env_int("GPTSOVITS_G2PW_CUDA_BATCH_MAX_ROWS", 96)) + self.batch_max_tokens = max(1, _env_int("GPTSOVITS_G2PW_CUDA_BATCH_MAX_TOKENS", 4096)) + self.batch_condition = threading.Condition() + self.pending_tasks: Deque[G2PWBatchTask] = deque() + self.batch_total_tasks = 0 + self.batch_total_batches = 0 + self.batch_total_rows = 0 + self.batch_total_queue_wait_ms = 0.0 + self.batch_queue_wait_peak_ms = 0.0 + self.batch_total_collect_wait_ms = 0.0 + self.batch_collect_wait_peak_ms = 0.0 + self.batch_total_run_ms = 0.0 + self.batch_run_peak_ms = 0.0 + self.batch_rows_peak = 0 + self.batch_requests_peak = 0 + self.batch_pending_peak = 0 + self.closed = False + self._ensure_capacity( + batch_size=max(1, _env_int("GPTSOVITS_G2PW_CUDA_MAX_BATCH_SIZE", 256)), + seq_len=max(1, _env_int("GPTSOVITS_G2PW_CUDA_MAX_SEQ_LEN", 128)), + ) + self.batch_worker = None + if self.batch_enabled: + self.batch_worker = threading.Thread( + target=self._batch_loop, + name=f"g2pw-cuda-batch-worker-{self.shard_index}", + daemon=True, + ) + self.batch_worker.start() + + def _destroy_handle(self) -> None: + if self.handle: + self.lib.g2pw_runtime_destroy(self.handle) + self.handle = None + + def close(self) -> None: + with self.batch_condition: + self.closed = True + self.batch_condition.notify_all() + self._destroy_handle() + + def __del__(self): + try: + self.close() + except Exception: + pass + + def _last_error(self) -> str: + if not self.handle: + return "uninitialized runtime" + message = self.lib.g2pw_runtime_last_error(self.handle) + return "" if not message else message.decode("utf-8", errors="replace") + + def _create_handle(self, batch_size: int, seq_len: int) -> None: + new_handle = self.lib.g2pw_runtime_create( + str(self.manifest_path).encode("utf-8"), + str(self.weights_path).encode("utf-8"), + int(self.device_ordinal), + int(batch_size), + int(seq_len), + int(self.full_graph_cache_limit), + int(self.tail_graph_cache_limit), + int(self.allow_tensor_cores), + int(self.use_cublaslt_bias_epilogue), + int(self.enable_profiling), + int(self.enable_cuda_graph), + int(self.dump_graph_cache_stats), + int(self.gemm_precision), + ) + if not new_handle: + raise G2PWCudaError("g2pw-cu returned null runtime handle") + self.handle = new_handle + self.max_batch_size = int(batch_size) + self.max_seq_len = int(seq_len) + self.num_labels = int(self.lib.g2pw_runtime_num_labels(self.handle)) + last_error = self._last_error() + if self.num_labels <= 0 or last_error: + self.close() + raise G2PWCudaError(f"Failed to initialize g2pw-cu runtime: {last_error or 'num_labels <= 0'}") + + def _ensure_capacity(self, batch_size: int, seq_len: int) -> None: + target_batch = max(1, int(batch_size)) + target_seq = max(1, int(seq_len)) + if self.handle and target_batch <= self.max_batch_size and target_seq <= self.max_seq_len: + return + next_batch = max(target_batch, self.max_batch_size * 2 if self.max_batch_size else 0) + next_seq = max(target_seq, self.max_seq_len * 2 if self.max_seq_len else 0) + self._destroy_handle() + self._create_handle(batch_size=next_batch, seq_len=next_seq) + + @staticmethod + def _normalize_model_input(model_input: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + input_ids = np.ascontiguousarray(model_input["input_ids"], dtype=np.int64) + token_type_ids = np.ascontiguousarray(model_input["token_type_ids"], dtype=np.int64) + attention_masks = np.ascontiguousarray(model_input["attention_masks"], dtype=np.int64) + phoneme_masks = np.ascontiguousarray(model_input["phoneme_masks"], dtype=np.float32) + char_ids = np.ascontiguousarray(model_input["char_ids"], dtype=np.int64) + position_ids = np.ascontiguousarray(model_input["position_ids"], dtype=np.int64) + batch_size = int(char_ids.shape[0]) + if input_ids.shape[0] == 1 and batch_size > 1: + input_ids = np.ascontiguousarray(np.repeat(input_ids, batch_size, axis=0), dtype=np.int64) + token_type_ids = np.ascontiguousarray(np.repeat(token_type_ids, batch_size, axis=0), dtype=np.int64) + attention_masks = np.ascontiguousarray(np.repeat(attention_masks, batch_size, axis=0), dtype=np.int64) + return { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + "attention_masks": attention_masks, + "phoneme_masks": phoneme_masks, + "char_ids": char_ids, + "position_ids": position_ids, + } + + def _run_direct(self, model_input: Dict[str, np.ndarray]) -> np.ndarray: + normalized = self._normalize_model_input(model_input) + input_ids = normalized["input_ids"] + token_type_ids = normalized["token_type_ids"] + attention_masks = normalized["attention_masks"] + phoneme_masks = normalized["phoneme_masks"] + char_ids = normalized["char_ids"] + position_ids = normalized["position_ids"] + batch_size = int(char_ids.shape[0]) + seq_len = int(input_ids.shape[1]) + probs = np.empty((batch_size, self.num_labels), dtype=np.float32) + with self.lock: + self._ensure_capacity(batch_size=batch_size, seq_len=seq_len) + status = self.lib.g2pw_runtime_run( + self.handle, + input_ids.ctypes.data_as(ctypes.c_void_p), + token_type_ids.ctypes.data_as(ctypes.c_void_p), + attention_masks.ctypes.data_as(ctypes.c_void_p), + phoneme_masks.ctypes.data_as(ctypes.c_void_p), + char_ids.ctypes.data_as(ctypes.c_void_p), + position_ids.ctypes.data_as(ctypes.c_void_p), + batch_size, + seq_len, + probs.ctypes.data_as(ctypes.c_void_p), + ) + if int(status) != 0: + raise G2PWCudaError(f"g2pw-cu inference failed: {self._last_error()}") + return probs + + def _can_append_task(self, tasks: List[G2PWBatchTask], candidate: G2PWBatchTask) -> bool: + request_count = len(tasks) + 1 + if request_count > self.batch_max_requests: + return False + total_rows = sum(int(item.model_input["char_ids"].shape[0]) for item in tasks) + int( + candidate.model_input["char_ids"].shape[0] + ) + if total_rows > self.batch_max_rows: + return False + total_tokens = sum( + int(item.model_input["char_ids"].shape[0]) * int(item.model_input["input_ids"].shape[1]) for item in tasks + ) + int(candidate.model_input["char_ids"].shape[0]) * int(candidate.model_input["input_ids"].shape[1]) + return total_tokens <= self.batch_max_tokens + + def _merge_batch_inputs(self, tasks: List[G2PWBatchTask]) -> Tuple[Dict[str, np.ndarray], List[Tuple[int, int]]]: + normalized_inputs = [self._normalize_model_input(task.model_input) for task in tasks] + total_rows = sum(int(item["char_ids"].shape[0]) for item in normalized_inputs) + max_seq_len = max(int(item["input_ids"].shape[1]) for item in normalized_inputs) + input_ids = np.zeros((total_rows, max_seq_len), dtype=np.int64) + token_type_ids = np.zeros((total_rows, max_seq_len), dtype=np.int64) + attention_masks = np.zeros((total_rows, max_seq_len), dtype=np.int64) + phoneme_masks = np.zeros((total_rows, normalized_inputs[0]["phoneme_masks"].shape[1]), dtype=np.float32) + char_ids = np.zeros((total_rows,), dtype=np.int64) + position_ids = np.zeros((total_rows,), dtype=np.int64) + slices: List[Tuple[int, int]] = [] + cursor = 0 + for item in normalized_inputs: + rows = int(item["char_ids"].shape[0]) + seq_len = int(item["input_ids"].shape[1]) + next_cursor = cursor + rows + input_ids[cursor:next_cursor, :seq_len] = item["input_ids"] + token_type_ids[cursor:next_cursor, :seq_len] = item["token_type_ids"] + attention_masks[cursor:next_cursor, :seq_len] = item["attention_masks"] + phoneme_masks[cursor:next_cursor] = item["phoneme_masks"] + char_ids[cursor:next_cursor] = item["char_ids"] + position_ids[cursor:next_cursor] = item["position_ids"] + slices.append((cursor, next_cursor)) + cursor = next_cursor + return { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + "attention_masks": attention_masks, + "phoneme_masks": phoneme_masks, + "char_ids": char_ids, + "position_ids": position_ids, + }, slices + + def _finish_task( + self, + task: G2PWBatchTask, + output: np.ndarray | None = None, + profile: Dict[str, float] | None = None, + error: Exception | None = None, + ) -> None: + task.output = output + task.profile = dict(profile or {}) + task.error = error + task.done_event.set() + + def _batch_loop(self) -> None: + while True: + with self.batch_condition: + while not self.pending_tasks and not self.closed: + self.batch_condition.wait() + if self.closed and not self.pending_tasks: + return + first_task = self.pending_tasks.popleft() + batch_tasks = [first_task] + collect_started = time.perf_counter() + deadline = collect_started + self.batch_window_s + while True: + if len(batch_tasks) >= self.batch_max_requests: + break + remaining = deadline - time.perf_counter() + if remaining <= 0.0: + break + if not self.pending_tasks: + self.batch_condition.wait(timeout=remaining) + continue + candidate = self.pending_tasks[0] + if not self._can_append_task(batch_tasks, candidate): + break + batch_tasks.append(self.pending_tasks.popleft()) + collect_wait_ms = max(0.0, (time.perf_counter() - collect_started) * 1000.0) + + now = time.perf_counter() + queue_wait_values = [max(0.0, (now - task.enqueued_at) * 1000.0) for task in batch_tasks] + try: + merged_input, row_slices = self._merge_batch_inputs(batch_tasks) + run_started = time.perf_counter() + merged_output = self._run_direct(merged_input) + run_ms = max(0.0, (time.perf_counter() - run_started) * 1000.0) + for task, (start, end) in zip(batch_tasks, row_slices): + task_rows = int(task.model_input["char_ids"].shape[0]) + task_seq_len = int(task.model_input["input_ids"].shape[1]) + self._finish_task( + task, + output=np.ascontiguousarray(merged_output[start:end]), + profile={ + "g2pw_runtime_queue_wait_ms": float(max(0.0, (run_started - task.enqueued_at) * 1000.0)), + "g2pw_runtime_collect_wait_ms": float(collect_wait_ms), + "g2pw_runtime_run_ms": float(run_ms), + "g2pw_runtime_batch_rows": float(sum(int(item.model_input["char_ids"].shape[0]) for item in batch_tasks)), + "g2pw_runtime_batch_requests": float(len(batch_tasks)), + "g2pw_runtime_task_rows": float(task_rows), + "g2pw_runtime_task_seq_len": float(task_seq_len), + "g2pw_runtime_shard_index": float(self.shard_index), + }, + ) + except Exception as exc: + run_ms = 0.0 + for task in batch_tasks: + self._finish_task(task, error=exc) + finally: + with self.batch_condition: + self.batch_total_batches += 1 + self.batch_total_tasks += len(batch_tasks) + self.batch_total_rows += sum(int(task.model_input["char_ids"].shape[0]) for task in batch_tasks) + self.batch_total_queue_wait_ms += float(sum(queue_wait_values)) + self.batch_queue_wait_peak_ms = max(self.batch_queue_wait_peak_ms, max(queue_wait_values or [0.0])) + self.batch_total_collect_wait_ms += float(collect_wait_ms) * float(len(batch_tasks)) + self.batch_collect_wait_peak_ms = max(self.batch_collect_wait_peak_ms, float(collect_wait_ms)) + self.batch_total_run_ms += float(run_ms) + self.batch_run_peak_ms = max(self.batch_run_peak_ms, float(run_ms)) + self.batch_rows_peak = max( + self.batch_rows_peak, sum(int(task.model_input["char_ids"].shape[0]) for task in batch_tasks) + ) + self.batch_requests_peak = max(self.batch_requests_peak, len(batch_tasks)) + + def _submit_batched(self, model_input: Dict[str, np.ndarray]) -> tuple[np.ndarray, Dict[str, float]]: + task = G2PWBatchTask(model_input=model_input) + with self.batch_condition: + if self.closed: + raise G2PWCudaError("g2pw-cu batch worker already closed") + task.enqueued_at = time.perf_counter() + self.pending_tasks.append(task) + self.batch_pending_peak = max(self.batch_pending_peak, len(self.pending_tasks)) + self.batch_condition.notify_all() + task.done_event.wait() + if task.error is not None: + raise task.error + assert task.output is not None + return task.output, dict(task.profile) + + def snapshot(self) -> Dict[str, float | int | bool]: + with self.batch_condition: + average_tasks_per_batch = ( + float(self.batch_total_tasks) / float(self.batch_total_batches) if self.batch_total_batches > 0 else 0.0 + ) + average_rows_per_batch = ( + float(self.batch_total_rows) / float(self.batch_total_batches) if self.batch_total_batches > 0 else 0.0 + ) + average_queue_wait_ms = ( + float(self.batch_total_queue_wait_ms) / float(self.batch_total_tasks) if self.batch_total_tasks > 0 else 0.0 + ) + average_collect_wait_ms = ( + float(self.batch_total_collect_wait_ms) / float(self.batch_total_tasks) + if self.batch_total_tasks > 0 + else 0.0 + ) + return { + "shard_index": int(self.shard_index), + "enabled": bool(self.batch_enabled), + "window_ms": float(self.batch_window_s * 1000.0), + "max_requests": int(self.batch_max_requests), + "max_rows": int(self.batch_max_rows), + "max_tokens": int(self.batch_max_tokens), + "pending": int(len(self.pending_tasks)), + "pending_peak": int(self.batch_pending_peak), + "total_batches": int(self.batch_total_batches), + "total_tasks": int(self.batch_total_tasks), + "total_rows": int(self.batch_total_rows), + "avg_tasks_per_batch": float(average_tasks_per_batch), + "avg_rows_per_batch": float(average_rows_per_batch), + "avg_queue_wait_ms": float(average_queue_wait_ms), + "queue_wait_peak_ms": float(self.batch_queue_wait_peak_ms), + "avg_collect_wait_ms": float(average_collect_wait_ms), + "collect_wait_peak_ms": float(self.batch_collect_wait_peak_ms), + "run_total_ms": float(self.batch_total_run_ms), + "run_peak_ms": float(self.batch_run_peak_ms), + "batch_rows_peak": int(self.batch_rows_peak), + "batch_requests_peak": int(self.batch_requests_peak), + } + + def pending_rows(self) -> int: + with self.batch_condition: + return int(sum(int(task.model_input["char_ids"].shape[0]) for task in self.pending_tasks)) + + def pending_count(self) -> int: + with self.batch_condition: + return int(len(self.pending_tasks)) + + def run_with_profile(self, model_input: Dict[str, np.ndarray]) -> tuple[np.ndarray, Dict[str, float]]: + if not self.batch_enabled: + started = time.perf_counter() + output = self._run_direct(model_input) + return output, { + "g2pw_runtime_queue_wait_ms": 0.0, + "g2pw_runtime_collect_wait_ms": 0.0, + "g2pw_runtime_run_ms": float((time.perf_counter() - started) * 1000.0), + "g2pw_runtime_batch_rows": float(model_input["char_ids"].shape[0]), + "g2pw_runtime_batch_requests": 1.0, + "g2pw_runtime_task_rows": float(model_input["char_ids"].shape[0]), + "g2pw_runtime_task_seq_len": float(model_input["input_ids"].shape[1]), + "g2pw_runtime_shard_index": float(self.shard_index), + } + return self._submit_batched(model_input) + + def run(self, model_input: Dict[str, np.ndarray]) -> np.ndarray: + output, _profile = self.run_with_profile(model_input) + return output + + +class G2PWRuntimePool: + def __init__(self) -> None: + self.worker_count = max(1, _env_int("GPTSOVITS_G2PW_CUDA_WORKERS", 2)) + self.shards = [G2PWRuntimeWrapper(shard_index=index) for index in range(self.worker_count)] + self.lock = threading.Lock() + + def _pick_shard(self) -> G2PWRuntimeWrapper: + with self.lock: + return min( + self.shards, + key=lambda shard: ( + shard.pending_rows(), + shard.pending_count(), + shard.snapshot().get("avg_queue_wait_ms", 0.0), + ), + ) + + def run_with_profile(self, model_input: Dict[str, np.ndarray]) -> tuple[np.ndarray, Dict[str, float]]: + shard = self._pick_shard() + output, profile = shard.run_with_profile(model_input) + profile["g2pw_runtime_pool_workers"] = float(self.worker_count) + return output, profile + + def run(self, model_input: Dict[str, np.ndarray]) -> np.ndarray: + output, _profile = self.run_with_profile(model_input) + return output + + def snapshot(self) -> Dict[str, float | int | bool | List[Dict[str, float | int | bool]]]: + shard_snapshots = [dict(shard.snapshot()) for shard in self.shards] + avg_queue_wait_ms = 0.0 + total_tasks = 0.0 + pending = 0 + pending_peak = 0 + total_batches = 0 + total_rows = 0 + batch_rows_peak = 0 + batch_requests_peak = 0 + for snapshot in shard_snapshots: + tasks = float(snapshot.get("total_tasks", 0.0)) + avg_queue_wait_ms += float(snapshot.get("avg_queue_wait_ms", 0.0)) * tasks + total_tasks += tasks + pending += int(snapshot.get("pending", 0)) + pending_peak = max(pending_peak, int(snapshot.get("pending_peak", 0))) + total_batches += int(snapshot.get("total_batches", 0)) + total_rows += int(snapshot.get("total_rows", 0)) + batch_rows_peak = max(batch_rows_peak, int(snapshot.get("batch_rows_peak", 0))) + batch_requests_peak = max(batch_requests_peak, int(snapshot.get("batch_requests_peak", 0))) + return { + "worker_count": int(self.worker_count), + "pending": int(pending), + "pending_peak": int(pending_peak), + "total_batches": int(total_batches), + "total_tasks": int(total_tasks), + "total_rows": int(total_rows), + "avg_queue_wait_ms": float(avg_queue_wait_ms / total_tasks) if total_tasks > 0 else 0.0, + "batch_rows_peak": int(batch_rows_peak), + "batch_requests_peak": int(batch_requests_peak), + "shards": shard_snapshots, + } + + +class G2PWCudaConverter(_G2PWBaseOnnxConverter): + def __init__( + self, + model_dir: str = "G2PWModel/", + style: str = "bopomofo", + model_source: str = None, + enable_non_tradional_chinese: bool = False, + ): + super().__init__( + model_dir=model_dir, + style=style, + model_source=model_source, + enable_non_tradional_chinese=enable_non_tradional_chinese, + ) + self.runtime = G2PWRuntimePool() + self.backend = "cuda" + primary_runtime = self.runtime.shards[0] + self.device = f"cuda:{primary_runtime.device_ordinal}" + self.checkpoint_path = str(primary_runtime.weights_path) + self.providers = ["g2pw-cu"] + + def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]: + probs = self.runtime.run(model_input) + preds = np.argmax(probs, axis=1).tolist() + confidences = probs[np.arange(len(preds)), preds].astype(np.float32, copy=False).tolist() + return [self.labels[pred] for pred in preds], confidences + + def _predict_with_profile(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float], Dict[str, float]]: + started = time.perf_counter() + probs, runtime_profile = self.runtime.run_with_profile(model_input) + preds = np.argmax(probs, axis=1).tolist() + confidences = probs[np.arange(len(preds)), preds].astype(np.float32, copy=False).tolist() + profile = dict(runtime_profile) + profile["g2pw_runtime_total_ms"] = float((time.perf_counter() - started) * 1000.0) + profile["g2pw_predict_ms"] = float(profile["g2pw_runtime_total_ms"]) + return [self.labels[pred] for pred in preds], confidences, profile + + def snapshot(self) -> Dict[str, float | int | bool]: + return dict(self.runtime.snapshot()) diff --git a/GPT_SoVITS/text/g2pw/g2pw.py b/GPT_SoVITS/text/g2pw/g2pw.py index 08525e91..ccd05a1b 100644 --- a/GPT_SoVITS/text/g2pw/g2pw.py +++ b/GPT_SoVITS/text/g2pw/g2pw.py @@ -8,6 +8,7 @@ from pypinyin.core import Pinyin, Style from pypinyin.seg.simpleseg import simple_seg from pypinyin.converter import UltimateConverter from pypinyin.contrib.tone_convert import to_tone +from .cuda_api import G2PWCudaConverter from .onnx_api import G2PWOnnxConverter current_file_path = os.path.dirname(__file__) @@ -27,12 +28,36 @@ class G2PWPinyin(Pinyin): tone_sandhi=False, **kwargs, ): - self._g2pw = G2PWOnnxConverter( - model_dir=model_dir, - style="pinyin", - model_source=model_source, - enable_non_tradional_chinese=enable_non_tradional_chinese, - ) + backend = os.environ.get("GPTSOVITS_G2PW_BACKEND", "cuda").strip().lower() + last_error = None + self._g2pw = None + if backend in {"cuda", "auto"}: + try: + self._g2pw = G2PWCudaConverter( + model_dir=model_dir, + style="pinyin", + model_source=model_source, + enable_non_tradional_chinese=enable_non_tradional_chinese, + ) + except Exception as exc: + last_error = exc + strict_mode = os.environ.get("GPTSOVITS_G2PW_CUDA_STRICT", "0").strip().lower() in { + "1", + "true", + "yes", + "on", + } + if backend == "cuda" and strict_mode: + raise + if self._g2pw is None: + self._g2pw = G2PWOnnxConverter( + model_dir=model_dir, + style="pinyin", + model_source=model_source, + enable_non_tradional_chinese=enable_non_tradional_chinese, + ) + if last_error is not None: + print(f"[g2pw] cuda backend unavailable, fallback to onnx: {last_error}") self._converter = Converter( self._g2pw, v_to_u=v_to_u, diff --git a/GPT_SoVITS/text/g2pw/g2pw_cuda_bridge.cpp b/GPT_SoVITS/text/g2pw/g2pw_cuda_bridge.cpp new file mode 100644 index 00000000..dc8f29a8 --- /dev/null +++ b/GPT_SoVITS/text/g2pw/g2pw_cuda_bridge.cpp @@ -0,0 +1,183 @@ +#include +#include +#include +#include + +#include "g2pw/runtime.h" + +namespace { + +struct G2PWRuntimeHandle { + std::unique_ptr runtime; + std::string last_error; + int num_labels = 0; +}; + +void SetError(G2PWRuntimeHandle* handle, const g2pw::Status& status) { + if (handle == nullptr) { + return; + } + handle->last_error = status.message; +} + +g2pw::RuntimeConfig BuildConfig( + int device_ordinal, + int max_batch_size, + int max_seq_len, + int full_graph_cache_limit, + int tail_graph_cache_limit, + int allow_tensor_cores, + int use_cublaslt_bias_epilogue, + int enable_profiling, + int enable_cuda_graph, + int dump_graph_cache_stats, + int gemm_precision) { + g2pw::RuntimeConfig config{}; + config.device_ordinal = device_ordinal; + config.max_batch_size = max_batch_size; + config.max_seq_len = max_seq_len; + config.full_graph_cache_limit = full_graph_cache_limit; + config.tail_graph_cache_limit = tail_graph_cache_limit; + config.allow_tensor_cores = allow_tensor_cores != 0; + config.use_cublaslt_bias_epilogue = use_cublaslt_bias_epilogue != 0; + config.enable_profiling = enable_profiling != 0; + config.enable_cuda_graph = enable_cuda_graph != 0; + config.dump_graph_cache_stats = dump_graph_cache_stats != 0; + switch (gemm_precision) { + case 1: + config.gemm_precision = g2pw::GemmPrecision::kFp16; + break; + case 2: + config.gemm_precision = g2pw::GemmPrecision::kBf16; + break; + default: + config.gemm_precision = g2pw::GemmPrecision::kFp32; + break; + } + return config; +} + +} // namespace + +extern "C" { + +void* g2pw_runtime_create( + const char* manifest_path, + const char* binary_path, + int device_ordinal, + int max_batch_size, + int max_seq_len, + int full_graph_cache_limit, + int tail_graph_cache_limit, + int allow_tensor_cores, + int use_cublaslt_bias_epilogue, + int enable_profiling, + int enable_cuda_graph, + int dump_graph_cache_stats, + int gemm_precision) { + auto* handle = new G2PWRuntimeHandle(); + try { + if (manifest_path == nullptr || binary_path == nullptr) { + handle->last_error = "manifest_path and binary_path must be non-null"; + return handle; + } + g2pw::RuntimeConfig config = BuildConfig( + device_ordinal, + max_batch_size, + max_seq_len, + full_graph_cache_limit, + tail_graph_cache_limit, + allow_tensor_cores, + use_cublaslt_bias_epilogue, + enable_profiling, + enable_cuda_graph, + dump_graph_cache_stats, + gemm_precision); + g2pw::Status status = g2pw::Runtime::Create( + config, + std::string(manifest_path), + std::string(binary_path), + &handle->runtime); + if (!status.ok()) { + SetError(handle, status); + return handle; + } + handle->num_labels = handle->runtime != nullptr ? handle->runtime->weights().manifest().num_labels : 0; + handle->last_error.clear(); + return handle; + } catch (const std::exception& exc) { + handle->last_error = exc.what(); + return handle; + } catch (...) { + handle->last_error = "unknown exception"; + return handle; + } +} + +void g2pw_runtime_destroy(void* raw_handle) { + auto* handle = static_cast(raw_handle); + delete handle; +} + +const char* g2pw_runtime_last_error(void* raw_handle) { + auto* handle = static_cast(raw_handle); + if (handle == nullptr) { + return "invalid runtime handle"; + } + return handle->last_error.c_str(); +} + +int g2pw_runtime_num_labels(void* raw_handle) { + auto* handle = static_cast(raw_handle); + if (handle == nullptr || handle->runtime == nullptr) { + return 0; + } + return handle->num_labels; +} + +int g2pw_runtime_run( + void* raw_handle, + const std::int64_t* input_ids, + const std::int64_t* token_type_ids, + const std::int64_t* attention_mask, + const float* phoneme_mask, + const std::int64_t* char_ids, + const std::int64_t* position_ids, + std::int32_t batch_size, + std::int32_t seq_len, + float* probs) { + auto* handle = static_cast(raw_handle); + if (handle == nullptr || handle->runtime == nullptr) { + return static_cast(g2pw::StatusCode::kInvalidArgument); + } + try { + g2pw::InferenceInputs inputs{}; + inputs.input_ids = input_ids; + inputs.token_type_ids = token_type_ids; + inputs.attention_mask = attention_mask; + inputs.phoneme_mask = phoneme_mask; + inputs.char_ids = char_ids; + inputs.position_ids = position_ids; + inputs.batch_size = batch_size; + inputs.seq_len = seq_len; + + g2pw::InferenceOutputs outputs{}; + outputs.probs = probs; + + const g2pw::Status status = handle->runtime->Run(inputs, outputs); + if (!status.ok()) { + SetError(handle, status); + return static_cast(status.code); + } + handle->last_error.clear(); + return static_cast(g2pw::StatusCode::kOk); + } catch (const std::exception& exc) { + handle->last_error = exc.what(); + return static_cast(g2pw::StatusCode::kInternalError); + } catch (...) { + handle->last_error = "unknown exception"; + return static_cast(g2pw::StatusCode::kInternalError); + } +} + +} diff --git a/GPT_SoVITS/text/g2pw/onnx_api.py b/GPT_SoVITS/text/g2pw/onnx_api.py index 3c2b0169..f6d7fab7 100644 --- a/GPT_SoVITS/text/g2pw/onnx_api.py +++ b/GPT_SoVITS/text/g2pw/onnx_api.py @@ -3,6 +3,7 @@ import json import os +import time import warnings import zipfile from typing import Any, Dict, List, Tuple @@ -71,6 +72,23 @@ def _find_first_existing_file(*paths: str) -> str: raise FileNotFoundError(f"Files not found: {paths}") +def _resolve_tokenizer_source(model_source: str | None) -> str: + candidate_paths = [] + if model_source: + candidate_paths.append(model_source) + repo_root = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..")) + candidate_paths.extend( + [ + os.path.join(repo_root, "pretrained_models", "g2pw-chinese"), + os.path.join(repo_root, "pretrained_models", "chinese-roberta-wwm-ext-large"), + ] + ) + for candidate in candidate_paths: + if candidate and os.path.exists(candidate): + return candidate + return model_source or "bert-base-chinese" + + def download_and_decompress(model_dir: str = "G2PWModel/"): if not os.path.exists(model_dir): parent_directory = os.path.dirname(model_dir) @@ -106,9 +124,9 @@ class _G2PWBaseOnnxConverter: self.model_dir = download_and_decompress(model_dir) self.config = load_config(config_path=os.path.join(self.model_dir, "config.py"), use_default=True) - self.model_source = model_source if model_source else self.config.model_source + self.model_source = _resolve_tokenizer_source(model_source if model_source else self.config.model_source) self.enable_opencc = enable_non_tradional_chinese - self.tokenizer = AutoTokenizer.from_pretrained(self.model_source) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_source, local_files_only=True) polyphonic_chars_path = os.path.join(self.model_dir, "POLYPHONIC_CHARS.txt") monophonic_chars_path = os.path.join(self.model_dir, "MONOPHONIC_CHARS.txt") @@ -200,6 +218,10 @@ class _G2PWBaseOnnxConverter: return None def __call__(self, sentences: List[str]) -> List[List[str]]: + results, _profile = self.predict_sentences_with_profile(sentences) + return results + + def predict_sentences_with_profile(self, sentences: List[str]) -> Tuple[List[List[str]], Dict[str, float]]: if isinstance(sentences, str): sentences = [sentences] @@ -213,7 +235,7 @@ class _G2PWBaseOnnxConverter: texts, model_query_ids, result_query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences) if len(texts) == 0: - return partial_results + return partial_results, {} model_input = prepare_onnx_input( tokenizer=self.tokenizer, @@ -229,12 +251,21 @@ class _G2PWBaseOnnxConverter: ) if not model_input: - return partial_results + return partial_results, {} + predict_profile: Dict[str, float] = {} if self.enable_sentence_dedup: - preds, _confidences = self._predict_with_sentence_dedup(model_input=model_input, texts=texts) + preds, _confidences, predict_profile = self._predict_with_sentence_dedup_profiled( + model_input=model_input, + texts=texts, + ) else: - preds, _confidences = self._predict(model_input=model_input) + if hasattr(self, "_predict_with_profile"): + preds, _confidences, predict_profile = self._predict_with_profile(model_input=model_input) + else: + predict_started = time.perf_counter() + preds, _confidences = self._predict(model_input=model_input) + predict_profile["g2pw_predict_ms"] = float((time.perf_counter() - predict_started) * 1000.0) if self.config.use_char_phoneme: preds = [pred.split(" ")[1] for pred in preds] @@ -243,7 +274,7 @@ class _G2PWBaseOnnxConverter: for sent_id, query_id, pred in zip(sent_ids, result_query_ids, preds): results[sent_id][query_id] = self.style_convert_func(pred) - return results + return results, predict_profile def _prepare_data( self, sentences: List[str] @@ -314,6 +345,52 @@ class _G2PWBaseOnnxConverter: return preds, confidences + def _predict_with_sentence_dedup_profiled( + self, + model_input: Dict[str, Any], + texts: List[str], + ) -> Tuple[List[str], List[float], Dict[str, float]]: + if len(texts) <= 1: + if hasattr(self, "_predict_with_profile"): + return self._predict_with_profile(model_input=model_input) + predict_started = time.perf_counter() + preds, confidences = self._predict(model_input=model_input) + return preds, confidences, {"g2pw_predict_ms": float((time.perf_counter() - predict_started) * 1000.0)} + + grouped_indices: Dict[str, List[int]] = {} + for idx, text in enumerate(texts): + grouped_indices.setdefault(text, []).append(idx) + + if all(len(indices) == 1 for indices in grouped_indices.values()): + if hasattr(self, "_predict_with_profile"): + return self._predict_with_profile(model_input=model_input) + predict_started = time.perf_counter() + preds, confidences = self._predict(model_input=model_input) + return preds, confidences, {"g2pw_predict_ms": float((time.perf_counter() - predict_started) * 1000.0)} + + preds: List[str] = [""] * len(texts) + confidences: List[float] = [0.0] * len(texts) + merged_profile: Dict[str, float] = {} + for indices in grouped_indices.values(): + group_input = {name: value[indices] for name, value in model_input.items()} + if len(indices) > 1: + for name in ("input_ids", "token_type_ids", "attention_masks"): + group_input[name] = group_input[name][:1] + if hasattr(self, "_predict_with_profile"): + group_preds, group_confidences, group_profile = self._predict_with_profile(model_input=group_input) + for key, value in dict(group_profile or {}).items(): + merged_profile[key] = float(merged_profile.get(key, 0.0)) + float(value) + else: + predict_started = time.perf_counter() + group_preds, group_confidences = self._predict(model_input=group_input) + merged_profile["g2pw_predict_ms"] = float( + merged_profile.get("g2pw_predict_ms", 0.0) + (time.perf_counter() - predict_started) * 1000.0 + ) + for output_idx, pred, confidence in zip(indices, group_preds, group_confidences): + preds[output_idx] = pred + confidences[output_idx] = confidence + return preds, confidences, merged_profile + class G2PWOnnxConverter(_G2PWBaseOnnxConverter): def __init__( From bc1f3f32de89e1ec6e093b8e27a6c16842a67753 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Fri, 13 Mar 2026 02:03:25 +0800 Subject: [PATCH 22/24] Enhance audio processing in TTS framework with resampling and profiling improvements Add resampling capabilities using torchaudio to prepare reference audio at 16kHz, replacing librosa for better performance. Introduce a caching mechanism for resampling transforms to optimize resource usage. Update batch processing methods to include timing metrics for profiling, enhancing the ability to monitor and improve performance in the TTS system. This update improves the maintainability and efficiency of audio preparation workflows. --- GPT_SoVITS/TTS_infer_pack/TTS.py | 66 +++++++++++++++++-- .../prepare_ref_semantic_batch_worker.py | 32 ++++++--- GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py | 12 ++++ 3 files changed, 97 insertions(+), 13 deletions(-) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 81c1ca1e..0140eff3 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -945,6 +945,13 @@ class TTS: codes = self.vits_model.extract_latent(hubert_feature) return codes[0, 0].to(self.configs.device) + @torch.inference_mode() + def _extract_prompt_semantic_profile_from_prepared_wav16k(self, wav16k: torch.Tensor): + forward_start = time.perf_counter() + prompt_semantic = self._extract_prompt_semantic_from_prepared_wav16k(wav16k) + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + return prompt_semantic, forward_ms + @torch.inference_mode() def _extract_prompt_semantic_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): cpu_prepare_start = time.perf_counter() @@ -954,9 +961,7 @@ class TTS: zero_wav_samples=int(self.configs.sampling_rate * 0.3), ) cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0 - forward_start = time.perf_counter() - prompt_semantic = self._extract_prompt_semantic_from_prepared_wav16k(wav16k) - forward_ms = (time.perf_counter() - forward_start) * 1000.0 + prompt_semantic, forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k(wav16k) return prompt_semantic, cpu_prepare_ms, forward_ms @torch.inference_mode() @@ -1011,10 +1016,17 @@ class TTS: raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path) load_ms = (time.perf_counter() - load_start) * 1000.0 if self.prepare_ref_semantic_batch_worker is None: + prompt_semantic_cpu_prepare_start = time.perf_counter() + wav16k = prepare_prompt_semantic_wav16k( + raw_audio=raw_audio, + raw_sr=raw_sr, + zero_wav_samples=int(self.configs.sampling_rate * 0.3), + ) + prompt_semantic_cpu_prepare_ms = (time.perf_counter() - prompt_semantic_cpu_prepare_start) * 1000.0 with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats: prompt_semantic_start = time.perf_counter() - prompt_semantic, prompt_semantic_cpu_prepare_ms, prompt_semantic_forward_ms = ( - self._extract_prompt_semantic_profile_from_raw(raw_audio, raw_sr) + prompt_semantic, prompt_semantic_forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k( + wav16k ) prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0 ref_spec_start = time.perf_counter() @@ -1025,6 +1037,10 @@ class TTS: audio_stage_inflight_peak = float(limiter_stats["peak_inflight"]) prompt_semantic_profile = { "prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]), + "prompt_semantic_worker_queue_wait_ms": 0.0, + "prompt_semantic_batch_collect_wait_ms": 0.0, + "prompt_semantic_stage_limiter_wait_ms": float(limiter_stats["wait_ms"]), + "prompt_semantic_batch_dispatch_delay_ms": 0.0, "prompt_semantic_cpu_prepare_ms": float(prompt_semantic_cpu_prepare_ms), "prompt_semantic_forward_ms": float(prompt_semantic_forward_ms), "prompt_semantic_scatter_ms": 0.0, @@ -1046,6 +1062,18 @@ class TTS: "audio_stage_inflight_peak": audio_stage_inflight_peak, "prompt_semantic_ms": prompt_semantic_ms, "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_worker_queue_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0) + ), + "prompt_semantic_batch_collect_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0) + ), + "prompt_semantic_stage_limiter_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0) + ), + "prompt_semantic_batch_dispatch_delay_ms": float( + prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0) + ), "prompt_semantic_cpu_prepare_ms": float( prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) ), @@ -1073,6 +1101,10 @@ class TTS: prompt_semantic_profile = { "prompt_semantic_wait_ms": 0.0, + "prompt_semantic_worker_queue_wait_ms": 0.0, + "prompt_semantic_batch_collect_wait_ms": 0.0, + "prompt_semantic_stage_limiter_wait_ms": 0.0, + "prompt_semantic_batch_dispatch_delay_ms": 0.0, "prompt_semantic_cpu_prepare_ms": 0.0, "prompt_semantic_forward_ms": 0.0, "prompt_semantic_scatter_ms": 0.0, @@ -1116,6 +1148,18 @@ class TTS: "audio_stage_inflight_peak": audio_stage_inflight_peak, "prompt_semantic_ms": prompt_semantic_ms, "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_worker_queue_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0) + ), + "prompt_semantic_batch_collect_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0) + ), + "prompt_semantic_stage_limiter_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0) + ), + "prompt_semantic_batch_dispatch_delay_ms": float( + prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0) + ), "prompt_semantic_cpu_prepare_ms": float( prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) ), @@ -1193,6 +1237,18 @@ class TTS: "audio_stage_inflight_peak": float(audio_stage_inflight_peak), "prompt_semantic_ms": float(prompt_semantic_ms), "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_worker_queue_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0) + ), + "prompt_semantic_batch_collect_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0) + ), + "prompt_semantic_stage_limiter_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0) + ), + "prompt_semantic_batch_dispatch_delay_ms": float( + prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0) + ), "prompt_semantic_cpu_prepare_ms": float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), "prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)), "prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)), diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py index 64ca133f..d46352a7 100644 --- a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py +++ b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py @@ -51,6 +51,7 @@ class RefSemanticTask: raw_sr: int task_id: str = field(default_factory=lambda: uuid.uuid4().hex) created_at: float = field(default_factory=time.perf_counter) + batch_popped_at: float = 0.0 done_event: threading.Event = field(default_factory=threading.Event) done_loop: asyncio.AbstractEventLoop | None = None done_future: asyncio.Future | None = None @@ -170,12 +171,14 @@ class PrepareRefSemanticBatchWorker: "max_batch_samples": self.max_batch_samples, } - def _collect_batch(self) -> List[RefSemanticTask]: + def _collect_batch(self) -> tuple[List[RefSemanticTask], float]: with self.condition: while not self.pending_tasks: self.condition.wait() - batch: List[RefSemanticTask] = [self.pending_tasks.popleft()] + first_task = self.pending_tasks.popleft() + first_task.batch_popped_at = time.perf_counter() + batch: List[RefSemanticTask] = [first_task] batch_samples = self._estimate_task_samples(batch[0]) deadline = time.perf_counter() + self.batch_window_s @@ -190,7 +193,9 @@ class PrepareRefSemanticBatchWorker: next_samples = self._estimate_task_samples(next_task) if len(batch) >= self.max_batch_items or (batch_samples + next_samples) > self.max_batch_samples: break - batch.append(self.pending_tasks.popleft()) + popped_task = self.pending_tasks.popleft() + popped_task.batch_popped_at = time.perf_counter() + batch.append(popped_task) batch_samples += next_samples self.active_batch_size = len(batch) @@ -199,7 +204,7 @@ class PrepareRefSemanticBatchWorker: self.active_batch_peak = self.active_batch_size if self.active_batch_samples > self.active_batch_samples_peak: self.active_batch_samples_peak = self.active_batch_samples - return batch + return batch, time.perf_counter() def _finalize_batch(self, batch: List[RefSemanticTask]) -> None: with self.condition: @@ -219,7 +224,7 @@ class PrepareRefSemanticBatchWorker: return torch.full((attention_mask.shape[0],), int(hidden_length), dtype=torch.long, device=attention_mask.device) @torch.inference_mode() - def _run_batch(self, batch: List[RefSemanticTask]) -> None: + def _run_batch(self, batch: List[RefSemanticTask], batch_collected_at: float) -> None: batch_started = time.perf_counter() prepared_start = time.perf_counter() prepared_wavs = [ @@ -268,8 +273,19 @@ class PrepareRefSemanticBatchWorker: try: code_len = int(code_lengths[batch_index].item()) task.result_prompt_semantic = codes[batch_index, 0, :code_len].detach().clone() + worker_queue_wait_ms = max(0.0, (float(task.batch_popped_at) - float(task.created_at)) * 1000.0) + batch_collect_wait_ms = max(0.0, (float(batch_collected_at) - float(task.batch_popped_at)) * 1000.0) + stage_limiter_wait_ms = float(limiter_stats["wait_ms"]) task.profile = { - "prompt_semantic_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]), + "prompt_semantic_wait_ms": worker_queue_wait_ms + + batch_collect_wait_ms + + stage_limiter_wait_ms, + "prompt_semantic_worker_queue_wait_ms": worker_queue_wait_ms, + "prompt_semantic_batch_collect_wait_ms": batch_collect_wait_ms, + "prompt_semantic_stage_limiter_wait_ms": stage_limiter_wait_ms, + "prompt_semantic_batch_dispatch_delay_ms": max( + 0.0, (float(batch_started) - float(batch_collected_at)) * 1000.0 + ), "prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms), "prompt_semantic_forward_ms": float(forward_ms), "prompt_semantic_scatter_ms": 0.0, @@ -289,9 +305,9 @@ class PrepareRefSemanticBatchWorker: def _run_loop(self) -> None: while True: - batch = self._collect_batch() + batch, batch_collected_at = self._collect_batch() try: - self._run_batch(batch) + self._run_batch(batch, batch_collected_at) except Exception as exc: # noqa: PERF203 for task in batch: task.error = exc diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py index e993a1ef..43290af7 100644 --- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -305,6 +305,18 @@ def build_request_state_from_parts( "audio_stage_inflight_peak": float(bundle_profile.get("audio_stage_inflight_peak", 0.0)), "prompt_semantic_ms": prompt_semantic_ms, "prompt_semantic_wait_ms": float(bundle_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_worker_queue_wait_ms": float( + bundle_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0) + ), + "prompt_semantic_batch_collect_wait_ms": float( + bundle_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0) + ), + "prompt_semantic_stage_limiter_wait_ms": float( + bundle_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0) + ), + "prompt_semantic_batch_dispatch_delay_ms": float( + bundle_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0) + ), "prompt_semantic_cpu_prepare_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), "prompt_semantic_forward_ms": float(bundle_profile.get("prompt_semantic_forward_ms", 0.0)), "prompt_semantic_scatter_ms": float(bundle_profile.get("prompt_semantic_scatter_ms", 0.0)), From c94de2f2cb748b25b84d5d990532a0153e76d954 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Fri, 13 Mar 2026 16:45:00 +0800 Subject: [PATCH 23/24] Enhance TTS audio processing with improved resampling and profiling metrics Refactor the audio preparation workflow to utilize torchaudio for resampling, replacing librosa for better performance. Introduce a caching mechanism for resampling transforms and update the PrepareRefSemanticBatchWorker to include detailed timing metrics for profiling. Additionally, implement a new CPU limiter for managing resource allocation during audio processing. These changes improve the efficiency and maintainability of the TTS system. --- GPT_SoVITS/TTS_infer_pack/TTS.py | 70 +++++++++++++++---- .../prepare_ref_semantic_batch_worker.py | 35 ++++++++-- GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py | 5 ++ 3 files changed, 90 insertions(+), 20 deletions(-) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 0140eff3..78bf7178 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -454,6 +454,7 @@ class TTS: } self.prepare_bert_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_BERT_SLOTS", "1"))) self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4"))) + self.prepare_ref_audio_cpu_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_CPU_SLOTS", "8"))) self.prepare_bert_batch_worker = None self.prepare_ref_semantic_batch_worker = None self.prepare_text_cpu_worker = None @@ -952,15 +953,36 @@ class TTS: forward_ms = (time.perf_counter() - forward_start) * 1000.0 return prompt_semantic, forward_ms + @torch.inference_mode() + def _prepare_prompt_semantic_wav16k_profile(self, raw_audio: torch.Tensor, raw_sr: int): + limiter = getattr(self, "prepare_ref_audio_cpu_limiter", None) + if limiter is None: + cpu_prepare_start = time.perf_counter() + wav16k = prepare_prompt_semantic_wav16k( + raw_audio=raw_audio, + raw_sr=raw_sr, + zero_wav_samples=int(self.configs.sampling_rate * 0.3), + ) + cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0 + return wav16k, cpu_prepare_ms, {"wait_ms": 0.0, "slots": 0.0, "peak_inflight": 0.0} + + with limiter.enter() as limiter_stats: + cpu_prepare_start = time.perf_counter() + wav16k = prepare_prompt_semantic_wav16k( + raw_audio=raw_audio, + raw_sr=raw_sr, + zero_wav_samples=int(self.configs.sampling_rate * 0.3), + ) + cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0 + return wav16k, cpu_prepare_ms, { + "wait_ms": float(limiter_stats.get("wait_ms", 0.0)), + "slots": float(limiter_stats.get("slots", 0.0)), + "peak_inflight": float(limiter_stats.get("peak_inflight", 0.0)), + } + @torch.inference_mode() def _extract_prompt_semantic_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): - cpu_prepare_start = time.perf_counter() - wav16k = prepare_prompt_semantic_wav16k( - raw_audio=raw_audio, - raw_sr=raw_sr, - zero_wav_samples=int(self.configs.sampling_rate * 0.3), - ) - cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0 + wav16k, cpu_prepare_ms, _ = self._prepare_prompt_semantic_wav16k_profile(raw_audio, raw_sr) prompt_semantic, forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k(wav16k) return prompt_semantic, cpu_prepare_ms, forward_ms @@ -1016,13 +1038,9 @@ class TTS: raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path) load_ms = (time.perf_counter() - load_start) * 1000.0 if self.prepare_ref_semantic_batch_worker is None: - prompt_semantic_cpu_prepare_start = time.perf_counter() - wav16k = prepare_prompt_semantic_wav16k( - raw_audio=raw_audio, - raw_sr=raw_sr, - zero_wav_samples=int(self.configs.sampling_rate * 0.3), + wav16k, prompt_semantic_cpu_prepare_ms, prompt_semantic_cpu_limiter_stats = ( + self._prepare_prompt_semantic_wav16k_profile(raw_audio, raw_sr) ) - prompt_semantic_cpu_prepare_ms = (time.perf_counter() - prompt_semantic_cpu_prepare_start) * 1000.0 with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats: prompt_semantic_start = time.perf_counter() prompt_semantic, prompt_semantic_forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k( @@ -1037,6 +1055,11 @@ class TTS: audio_stage_inflight_peak = float(limiter_stats["peak_inflight"]) prompt_semantic_profile = { "prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]), + "prompt_semantic_cpu_prepare_wait_ms": float(prompt_semantic_cpu_limiter_stats.get("wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_slots": float(prompt_semantic_cpu_limiter_stats.get("slots", 0.0)), + "prompt_semantic_cpu_prepare_inflight_peak": float( + prompt_semantic_cpu_limiter_stats.get("peak_inflight", 0.0) + ), "prompt_semantic_worker_queue_wait_ms": 0.0, "prompt_semantic_batch_collect_wait_ms": 0.0, "prompt_semantic_stage_limiter_wait_ms": float(limiter_stats["wait_ms"]), @@ -1062,6 +1085,15 @@ class TTS: "audio_stage_inflight_peak": audio_stage_inflight_peak, "prompt_semantic_ms": prompt_semantic_ms, "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_wait_ms", 0.0) + ), + "prompt_semantic_cpu_prepare_slots": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_slots", 0.0) + ), + "prompt_semantic_cpu_prepare_inflight_peak": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_inflight_peak", 0.0) + ), "prompt_semantic_worker_queue_wait_ms": float( prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0) ), @@ -1101,6 +1133,9 @@ class TTS: prompt_semantic_profile = { "prompt_semantic_wait_ms": 0.0, + "prompt_semantic_cpu_prepare_wait_ms": 0.0, + "prompt_semantic_cpu_prepare_slots": float(getattr(self.prepare_ref_audio_cpu_limiter, "slots", 0.0)), + "prompt_semantic_cpu_prepare_inflight_peak": 0.0, "prompt_semantic_worker_queue_wait_ms": 0.0, "prompt_semantic_batch_collect_wait_ms": 0.0, "prompt_semantic_stage_limiter_wait_ms": 0.0, @@ -1148,6 +1183,15 @@ class TTS: "audio_stage_inflight_peak": audio_stage_inflight_peak, "prompt_semantic_ms": prompt_semantic_ms, "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_wait_ms", 0.0) + ), + "prompt_semantic_cpu_prepare_slots": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_slots", 0.0) + ), + "prompt_semantic_cpu_prepare_inflight_peak": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_inflight_peak", 0.0) + ), "prompt_semantic_worker_queue_wait_ms": float( prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0) ), diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py index d46352a7..ff5591b2 100644 --- a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py +++ b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py @@ -1,4 +1,5 @@ import asyncio +import os import threading import time import uuid @@ -6,28 +7,48 @@ from collections import deque from dataclasses import dataclass, field from typing import Deque, Dict, List, Tuple -import librosa -import numpy as np import torch +import torchaudio REF_AUDIO_MIN_SAMPLES_16K = 48000 REF_AUDIO_MAX_SAMPLES_16K = 160000 +_RESAMPLE_CACHE_LOCK = threading.Lock() +_RESAMPLE_CACHE: Dict[Tuple[int, int, str], torchaudio.transforms.Resample] = {} + + +def _get_resampler(orig_sr: int, target_sr: int, device: str) -> torchaudio.transforms.Resample: + device_key = str(device) + key = (int(orig_sr), int(target_sr), device_key) + with _RESAMPLE_CACHE_LOCK: + transform = _RESAMPLE_CACHE.get(key) + if transform is None: + transform = torchaudio.transforms.Resample(orig_freq=int(orig_sr), new_freq=int(target_sr)).to(device_key) + _RESAMPLE_CACHE[key] = transform + return transform def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wav_samples: int) -> torch.Tensor: + resample_device = os.environ.get("GPTSOVITS_PREPARE_REF_RESAMPLE_DEVICE", "cpu").strip().lower() or "cpu" + if resample_device not in {"cpu", "cuda"}: + resample_device = "cpu" + if resample_device == "cuda" and not torch.cuda.is_available(): + resample_device = "cpu" wav_mono = raw_audio if wav_mono.dim() == 2 and wav_mono.shape[0] != 1: wav_mono = wav_mono.mean(0, keepdim=True) - wav16k = wav_mono.squeeze(0).cpu().numpy() + wav16k = wav_mono.to(dtype=torch.float32, device=resample_device) if raw_sr != 16000: - wav16k = librosa.resample(wav16k, orig_sr=raw_sr, target_sr=16000) + wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k) + wav16k = wav16k.squeeze(0).contiguous() if wav16k.shape[0] > REF_AUDIO_MAX_SAMPLES_16K or wav16k.shape[0] < REF_AUDIO_MIN_SAMPLES_16K: raise OSError("参考音频在3~10秒范围外,请更换!") - wav16k = np.ascontiguousarray(wav16k, dtype=np.float32) if zero_wav_samples > 0: - wav16k = np.concatenate([wav16k, np.zeros(int(zero_wav_samples), dtype=np.float32)], axis=0) - return torch.from_numpy(wav16k) + wav16k = torch.cat( + [wav16k, torch.zeros(int(zero_wav_samples), dtype=torch.float32, device=wav16k.device)], + dim=0, + ) + return wav16k.contiguous() def conv1d_output_lengths(input_lengths: torch.Tensor, conv1d: torch.nn.Conv1d | None) -> torch.Tensor: diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py index 43290af7..73e2a2c7 100644 --- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -305,6 +305,11 @@ def build_request_state_from_parts( "audio_stage_inflight_peak": float(bundle_profile.get("audio_stage_inflight_peak", 0.0)), "prompt_semantic_ms": prompt_semantic_ms, "prompt_semantic_wait_ms": float(bundle_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_wait_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_slots": float(bundle_profile.get("prompt_semantic_cpu_prepare_slots", 0.0)), + "prompt_semantic_cpu_prepare_inflight_peak": float( + bundle_profile.get("prompt_semantic_cpu_prepare_inflight_peak", 0.0) + ), "prompt_semantic_worker_queue_wait_ms": float( bundle_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0) ), From 8a444c10b72aadeb1f8be9210b45f2a519d383e4 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Fri, 13 Mar 2026 16:45:38 +0800 Subject: [PATCH 24/24] Enhance TTS processing with new reference specification handling and profiling metrics Refactor the PrepareCoordinator and related components to improve the handling of reference specifications in the TTS system. Introduce new methods for building and extracting reference prompts and specifications, along with detailed profiling metrics for performance monitoring. Update the PrepareRefSemanticBatchWorker to include additional timing metrics and caching mechanisms for resampling. These changes enhance the efficiency and maintainability of the TTS framework, particularly in managing audio processing and reference data. --- GPT_SoVITS/TTS_infer_pack/TTS.py | 23 +- .../TTS_infer_pack/prepare_coordinator.py | 537 +++++++++++++----- .../prepare_ref_semantic_batch_worker.py | 61 +- GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py | 20 +- .../unified_engine_bridge_registry.py | 40 +- .../TTS_infer_pack/unified_engine_builder.py | 4 + .../unified_engine_component_policy.py | 38 +- .../unified_engine_component_runtime.py | 18 +- .../unified_engine_orchestration.py | 24 + .../TTS_infer_pack/unified_engine_stage.py | 15 + .../unified_engine_stage_dispatch.py | 4 + .../unified_engine_stage_executor.py | 4 + .../unified_engine_stage_finalize.py | 39 +- .../unified_engine_stage_prepare.py | 234 +++++++- .../unified_engine_worker_finalize.py | 4 +- .../unified_engine_worker_prepare.py | 61 +- .../unified_engine_worker_submit.py | 39 ++ GPT_SoVITS/text/g2pw/cuda_api.py | 15 + 18 files changed, 1006 insertions(+), 174 deletions(-) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 78bf7178..16bc8db8 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -996,21 +996,39 @@ class TTS: return self._extract_prompt_semantic_from_raw(raw_audio, raw_sr) def _extract_ref_spec_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + spec, audio, _, _, _ = self._extract_ref_spec_profile_from_raw(raw_audio, raw_sr) + return spec, audio, raw_audio, raw_sr + + def _extract_ref_spec_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + profile = { + "ref_spec_to_device_ms": 0.0, + "ref_spec_main_resample_ms": 0.0, + "ref_spec_norm_ms": 0.0, + "ref_spec_spectrogram_ms": 0.0, + "ref_spec_post_resample_ms": 0.0, + } + to_device_start = time.perf_counter() raw_audio_device = raw_audio.to(self.configs.device).float() + profile["ref_spec_to_device_ms"] = (time.perf_counter() - to_device_start) * 1000.0 if raw_sr != self.configs.sampling_rate: + resample_start = time.perf_counter() audio = raw_audio_device if audio.shape[0] == 2: audio = audio.mean(0).unsqueeze(0) audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device) + profile["ref_spec_main_resample_ms"] = (time.perf_counter() - resample_start) * 1000.0 else: audio = raw_audio_device if audio.shape[0] == 2: audio = audio.mean(0).unsqueeze(0) + norm_start = time.perf_counter() maxx = audio.abs().max() if maxx > 1: audio /= min(2, maxx) + profile["ref_spec_norm_ms"] = (time.perf_counter() - norm_start) * 1000.0 + spec_start = time.perf_counter() spec = spectrogram_torch( audio, self.configs.filter_length, @@ -1019,15 +1037,18 @@ class TTS: self.configs.win_length, center=False, ) + profile["ref_spec_spectrogram_ms"] = (time.perf_counter() - spec_start) * 1000.0 if self.configs.is_half: spec = spec.half() if self.is_v2pro == True: + post_resample_start = time.perf_counter() audio = resample(audio, self.configs.sampling_rate, 16000, self.configs.device) + profile["ref_spec_post_resample_ms"] = (time.perf_counter() - post_resample_start) * 1000.0 if self.configs.is_half: audio = audio.half() else: audio = None - return spec, audio, raw_audio, raw_sr + return spec, audio, raw_audio, raw_sr, profile def extract_ref_spec(self, ref_audio_path: str): raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path) diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py index 65bcbb51..e74e0de4 100644 --- a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py +++ b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py @@ -228,8 +228,74 @@ class PrepareCoordinator: def _load_ref_audio_raw(self, ref_audio_path: str): return self.tts._load_ref_audio_raw(ref_audio_path) + def _build_ref_prompt_semantic_from_raw(self, raw_audio, raw_sr: int): + load_profile = {"audio_load_ms": 0.0} + if getattr(self.tts, "prepare_ref_semantic_batch_worker", None) is not None: + prompt_semantic, worker_profile = self.tts.prepare_ref_semantic_batch_worker.submit(raw_audio, raw_sr) + return { + "prompt_semantic": prompt_semantic, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + **load_profile, + "audio_stage_wait_ms": float(worker_profile.get("prompt_semantic_wait_ms", 0.0)), + "audio_stage_slots": float(worker_profile.get("prompt_semantic_stage_slots", 0.0)), + "audio_stage_inflight_peak": float(worker_profile.get("prompt_semantic_stage_inflight_peak", 0.0)), + "prompt_semantic_ms": float( + worker_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) + + worker_profile.get("prompt_semantic_forward_ms", 0.0) + + worker_profile.get("prompt_semantic_scatter_ms", 0.0) + ), + **{key: float(value) for key, value in worker_profile.items()}, + "ref_spec_wait_ms": 0.0, + "ref_spec_ms": 0.0, + "bundle_total_ms": float(worker_profile.get("prompt_semantic_wait_ms", 0.0)) + + float(worker_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)) + + float(worker_profile.get("prompt_semantic_forward_ms", 0.0)) + + float(worker_profile.get("prompt_semantic_scatter_ms", 0.0)), + }, + } + wav16k, cpu_prepare_ms, limiter_stats = self.tts._prepare_prompt_semantic_wav16k_profile(raw_audio, raw_sr) + with self.tts.prepare_ref_audio_stage_limiter.enter() as stage_stats: + prompt_semantic, forward_ms = self.tts._extract_prompt_semantic_profile_from_prepared_wav16k(wav16k) + return { + "prompt_semantic": prompt_semantic, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + "audio_load_ms": 0.0, + "audio_stage_wait_ms": float(stage_stats.get("wait_ms", 0.0)), + "audio_stage_slots": float(stage_stats.get("slots", 0.0)), + "audio_stage_inflight_peak": float(stage_stats.get("peak_inflight", 0.0)), + "prompt_semantic_wait_ms": float(stage_stats.get("wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_wait_ms": float(limiter_stats.get("wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_slots": float(limiter_stats.get("slots", 0.0)), + "prompt_semantic_cpu_prepare_inflight_peak": float(limiter_stats.get("peak_inflight", 0.0)), + "prompt_semantic_worker_queue_wait_ms": 0.0, + "prompt_semantic_batch_collect_wait_ms": 0.0, + "prompt_semantic_stage_limiter_wait_ms": float(stage_stats.get("wait_ms", 0.0)), + "prompt_semantic_batch_dispatch_delay_ms": 0.0, + "prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms), + "prompt_semantic_pack_ms": 0.0, + "prompt_semantic_h2d_ms": 0.0, + "prompt_semantic_ssl_forward_ms": 0.0, + "prompt_semantic_hidden_length_ms": 0.0, + "prompt_semantic_extract_latent_ms": 0.0, + "prompt_semantic_forward_ms": float(forward_ms), + "prompt_semantic_scatter_ms": 0.0, + "prompt_semantic_stage_slots": float(stage_stats.get("slots", 0.0)), + "prompt_semantic_stage_inflight_peak": float(stage_stats.get("peak_inflight", 0.0)), + "prompt_semantic_batch_size": 1.0, + "prompt_semantic_batch_samples": 0.0, + "ref_spec_wait_ms": 0.0, + "ref_spec_ms": 0.0, + "bundle_total_ms": float(cpu_prepare_ms + forward_ms + stage_stats.get("wait_ms", 0.0)), + }, + } + def _extract_ref_spec_from_raw(self, raw_audio, raw_sr: int): - return self.tts._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2] + spec, audio, _, _, profile = self.tts._extract_ref_spec_profile_from_raw(raw_audio, raw_sr) + return (spec, audio), profile @staticmethod def _build_empty_text_features_like(reference: PreparedTextFeatures | None = None) -> PreparedTextFeatures: @@ -523,7 +589,7 @@ class PrepareCoordinator: finally: self.text_feature_gate.release() - async def _run_ref_audio_stage(self, ref_audio_path: str) -> ProfiledResult: + async def _run_ref_prompt_semantic_stage(self, ref_audio_path: str) -> ProfiledResult: if getattr(self.tts, "prepare_ref_semantic_batch_worker", None) is not None: submit_at = time.perf_counter() started_at = float(submit_at) @@ -538,19 +604,7 @@ class PrepareCoordinator: prompt_semantic_task = asyncio.create_task( self.tts.prepare_ref_semantic_batch_worker.submit_async(raw_audio, raw_sr) ) - await self.ref_spec_gate.acquire() - try: - ref_spec_task = asyncio.create_task( - self._run_on_executor(self.ref_audio_executor, self._extract_ref_spec_from_raw, raw_audio, raw_sr) - ) - (prompt_semantic, prompt_semantic_profile), ref_spec_profiled = await asyncio.gather( - prompt_semantic_task, - ref_spec_task, - ) - finally: - self.ref_spec_gate.release() - - refer_spec = ref_spec_profiled.result + prompt_semantic, prompt_semantic_profile = await prompt_semantic_task limiter_snapshot = ( self.tts.prepare_ref_audio_stage_limiter.snapshot() if getattr(self.tts, "prepare_ref_audio_stage_limiter", None) is not None @@ -561,21 +615,15 @@ class PrepareCoordinator: + float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)) + float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)) ) - audio_stage_wait_ms = ( - float(load_profiled.queue_ms) - + float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)) - + float(ref_spec_profiled.queue_ms) - ) finished_at = time.perf_counter() result = { "prompt_semantic": prompt_semantic, - "refer_spec": refer_spec, "raw_audio": raw_audio, "raw_sr": raw_sr, "profile": { "audio_load_queue_ms": float(load_profiled.queue_ms), "audio_load_ms": float(load_profiled.run_ms), - "audio_stage_wait_ms": float(audio_stage_wait_ms), + "audio_stage_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), "audio_stage_slots": float( max( float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), @@ -593,6 +641,17 @@ class PrepareCoordinator: "prompt_semantic_cpu_prepare_ms": float( prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) ), + "prompt_semantic_pack_ms": float(prompt_semantic_profile.get("prompt_semantic_pack_ms", 0.0)), + "prompt_semantic_h2d_ms": float(prompt_semantic_profile.get("prompt_semantic_h2d_ms", 0.0)), + "prompt_semantic_ssl_forward_ms": float( + prompt_semantic_profile.get("prompt_semantic_ssl_forward_ms", 0.0) + ), + "prompt_semantic_hidden_length_ms": float( + prompt_semantic_profile.get("prompt_semantic_hidden_length_ms", 0.0) + ), + "prompt_semantic_extract_latent_ms": float( + prompt_semantic_profile.get("prompt_semantic_extract_latent_ms", 0.0) + ), "prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)), "prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)), "prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), @@ -603,14 +662,10 @@ class PrepareCoordinator: "prompt_semantic_batch_samples": float( prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0) ), - "ref_spec_wait_ms": float(ref_spec_profiled.queue_ms), - "ref_spec_ms": float(ref_spec_profiled.run_ms), "bundle_total_ms": float( load_profiled.queue_ms + load_profiled.run_ms + prompt_semantic_ms - + ref_spec_profiled.queue_ms - + ref_spec_profiled.run_ms ), }, } @@ -623,21 +678,26 @@ class PrepareCoordinator: await self.ref_audio_gate.acquire() try: - if hasattr(self.tts, "extract_ref_audio_bundle_async"): - submit_at = time.perf_counter() - started_at = time.perf_counter() - result = await self.tts.extract_ref_audio_bundle_async(ref_audio_path) - finished_at = time.perf_counter() - return ProfiledResult( - result=result, - submit_at=float(submit_at), - started_at=float(started_at), - finished_at=float(finished_at), - ) - return await self._run_on_executor(self.ref_audio_executor, self.tts.extract_ref_audio_bundle, ref_audio_path) + load_profiled = await self._run_on_executor(self.ref_audio_executor, self._load_ref_audio_raw, ref_audio_path) + raw_audio, raw_sr = load_profiled.result + submit_at = time.perf_counter() + started_at = time.perf_counter() + result = await asyncio.to_thread(self._build_ref_prompt_semantic_from_raw, raw_audio, raw_sr) + result.setdefault("profile", {}) + result["profile"]["audio_load_queue_ms"] = float(load_profiled.queue_ms) + result["profile"]["audio_load_ms"] = float(load_profiled.run_ms) + finished_at = time.perf_counter() + return ProfiledResult(result=result, submit_at=float(submit_at), started_at=float(started_at), finished_at=float(finished_at)) finally: self.ref_audio_gate.release() + async def _run_ref_spec_stage(self, raw_audio, raw_sr: int) -> ProfiledResult: + await self.ref_spec_gate.acquire() + try: + return await self._run_on_executor(self.ref_audio_executor, self._extract_ref_spec_from_raw, raw_audio, raw_sr) + finally: + self.ref_spec_gate.release() + def _release_split_stage_slot(self) -> None: self._mark_leave() self._inflight_gate.release() @@ -682,101 +742,318 @@ class PrepareCoordinator: cpu_stage: PreparedCpuStage, ) -> tuple[T2SRequestState, float, float]: try: - g2pw_pair_start = time.perf_counter() - g2pw_pair_task = asyncio.create_task( - self._run_g2pw_pair_stage( - cpu_stage.prompt_cpu_profiled.result, - cpu_stage.target_cpu_profiled.result, - ) - ) - ref_audio_task = asyncio.create_task(self._run_ref_audio_stage(str(cpu_stage.spec.ref_audio_path))) - (prompt_g2pw_profiled, target_g2pw_profiled), ref_audio_profiled = await asyncio.gather( - g2pw_pair_task, - ref_audio_task, - ) - g2pw_pair_end = time.perf_counter() - text_pair_start = time.perf_counter() - text_feature_pair_task = asyncio.create_task( - self._run_text_feature_pair_stage( - prompt_g2pw_profiled.result, - target_g2pw_profiled.result, - cpu_stage.prompt_cpu_profiled.run_ms, - cpu_stage.target_cpu_profiled.run_ms, - prompt_base_profile=dict(prompt_g2pw_profiled.profile or {}), - target_base_profile=dict(target_g2pw_profiled.profile or {}), - ) - ) - prompt_feature_profiled, target_feature_profiled = await text_feature_pair_task - text_pair_end = time.perf_counter() - state = build_request_state_from_parts( - tts=self.tts, - spec=cpu_stage.spec, - prompt_text=cpu_stage.prompt_text, - text=cpu_stage.text, - prompt_result=prompt_feature_profiled.result, - target_result=target_feature_profiled.result, - ref_audio_bundle=ref_audio_profiled.result, - prepare_start=cpu_stage.prepare_start, - prepare_sync_start=cpu_stage.prepare_start, - profile_overrides={ - "executor_queue_ms": max(0.0, (cpu_stage.prepare_start - cpu_stage.prepare_submit_at) * 1000.0), - "prepare_admission_wait_ms": cpu_stage.prepare_admission_wait_ms, - "executor_run_wall_ms": max(0.0, (time.perf_counter() - cpu_stage.prepare_start) * 1000.0), - "text_feature_pair_ms": max(0.0, (text_pair_end - text_pair_start) * 1000.0), - "g2pw_pair_ms": max(0.0, (g2pw_pair_end - g2pw_pair_start) * 1000.0), - "prompt_text_g2pw_queue_ms": prompt_g2pw_profiled.queue_ms, - "prompt_text_g2pw_run_ms": prompt_g2pw_profiled.run_ms, - "prompt_text_g2pw_prepare_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)), - "prompt_text_g2pw_predict_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)), - "prompt_text_g2pw_post_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)), - "text_g2pw_queue_ms": target_g2pw_profiled.queue_ms, - "text_g2pw_run_ms": target_g2pw_profiled.run_ms, - "text_g2pw_prepare_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)), - "text_g2pw_predict_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)), - "text_g2pw_post_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)), - "prompt_text_parallel_future_wait_ms": 0.0, - "prompt_text_parallel_future_executor_queue_ms": 0.0, - "prompt_text_parallel_future_run_ms": 0.0, - "prompt_text_parallel_future_finish_after_submit_ms": 0.0, - "prompt_text_parallel_future_queue_tail_after_target_ms": 0.0, - "prompt_text_parallel_future_run_tail_after_target_ms": 0.0, - "prompt_text_cpu_queue_ms": cpu_stage.prompt_cpu_profiled.queue_ms, - "prompt_text_cpu_run_ms": cpu_stage.prompt_cpu_profiled.run_ms, - "prompt_text_cpu_admission_wait_ms": float( - (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0) - ), - "prompt_text_cpu_backpressure_wait_ms": float( - (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0) - ), - "prompt_text_cpu_capacity_wait_ms": float( - (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0) - ), - "prompt_text_feature_queue_ms": prompt_feature_profiled.queue_ms, - "prompt_text_feature_run_ms": prompt_feature_profiled.run_ms, - "text_cpu_queue_ms": cpu_stage.target_cpu_profiled.queue_ms, - "text_cpu_run_ms": cpu_stage.target_cpu_profiled.run_ms, - "text_cpu_admission_wait_ms": float( - (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0) - ), - "text_cpu_backpressure_wait_ms": float( - (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0) - ), - "text_cpu_capacity_wait_ms": float( - (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0) - ), - "text_feature_queue_ms": target_feature_profiled.queue_ms, - "text_feature_run_ms": target_feature_profiled.run_ms, - "ref_audio_task_queue_ms": ref_audio_profiled.queue_ms, - "ref_audio_task_run_ms": ref_audio_profiled.run_ms, - "worker_prepare_inflight_on_enter": float(cpu_stage.current_inflight), - "worker_prepare_peak_inflight": float(cpu_stage.peak_inflight), + phase_one = await self._prepare_gpu_phase_one(cpu_stage) + phase_two = await self._prepare_gpu_phase_two(cpu_stage, phase_one) + return self._build_gpu_prepare_result( + cpu_stage, + phase_one, + phase_two, + extra_profile={ + "engine_prepare_audio_phase_mode": 0.0, + "engine_prepare_audio_phase_wall_ms": float(phase_one["phase_wall_ms"]), + "engine_prepare_audio_phase_batch_size": 1.0, + "engine_prepare_text_phase_wall_ms": float(phase_two["phase_wall_ms"]), + "engine_prepare_text_phase_batch_size": 1.0, }, ) - prepare_exec_finished_at = time.perf_counter() - state.prepare_profile["executor_run_wall_ms"] = max( - 0.0, (prepare_exec_finished_at - cpu_stage.prepare_start) * 1000.0 + finally: + self._release_split_stage_slot() + + async def _prepare_gpu_phase_one(self, cpu_stage: PreparedCpuStage) -> Dict[str, Any]: + phase_start = time.perf_counter() + g2pw_pair_task = asyncio.create_task( + self._run_g2pw_pair_stage( + cpu_stage.prompt_cpu_profiled.result, + cpu_stage.target_cpu_profiled.result, ) - return state, cpu_stage.prepare_start, prepare_exec_finished_at + ) + ref_audio_task = asyncio.create_task(self._run_ref_prompt_semantic_stage(str(cpu_stage.spec.ref_audio_path))) + prompt_g2pw_profiled, target_g2pw_profiled = await g2pw_pair_task + g2pw_pair_end = time.perf_counter() + ref_audio_profiled = await ref_audio_task + phase_end = time.perf_counter() + return { + "prompt_g2pw_profiled": prompt_g2pw_profiled, + "target_g2pw_profiled": target_g2pw_profiled, + "ref_audio_profiled": ref_audio_profiled, + "ref_spec_result": None, + "g2pw_pair_ms": max(0.0, (g2pw_pair_end - phase_start) * 1000.0), + "phase_wall_ms": max(0.0, (phase_end - phase_start) * 1000.0), + } + + async def _prepare_gpu_phase_two( + self, + cpu_stage: PreparedCpuStage, + phase_one: Dict[str, Any], + ) -> Dict[str, Any]: + phase_start = time.perf_counter() + prompt_g2pw_profiled = phase_one["prompt_g2pw_profiled"] + target_g2pw_profiled = phase_one["target_g2pw_profiled"] + prompt_feature_profiled, target_feature_profiled = await self._run_text_feature_pair_stage( + prompt_g2pw_profiled.result, + target_g2pw_profiled.result, + cpu_stage.prompt_cpu_profiled.run_ms, + cpu_stage.target_cpu_profiled.run_ms, + prompt_base_profile=dict(prompt_g2pw_profiled.profile or {}), + target_base_profile=dict(target_g2pw_profiled.profile or {}), + ) + phase_end = time.perf_counter() + return { + "prompt_feature_profiled": prompt_feature_profiled, + "target_feature_profiled": target_feature_profiled, + "phase_wall_ms": max(0.0, (phase_end - phase_start) * 1000.0), + } + + def _build_gpu_prepare_result( + self, + cpu_stage: PreparedCpuStage, + phase_one: Dict[str, Any], + phase_two: Dict[str, Any], + extra_profile: Dict[str, float] | None = None, + ) -> tuple[T2SRequestState, float, float]: + prompt_g2pw_profiled = phase_one["prompt_g2pw_profiled"] + target_g2pw_profiled = phase_one["target_g2pw_profiled"] + ref_audio_profiled = phase_one["ref_audio_profiled"] + ref_spec_result = phase_one.get("ref_spec_result") + prompt_feature_profiled = phase_two["prompt_feature_profiled"] + target_feature_profiled = phase_two["target_feature_profiled"] + profile_overrides = { + "executor_queue_ms": max(0.0, (cpu_stage.prepare_start - cpu_stage.prepare_submit_at) * 1000.0), + "prepare_admission_wait_ms": cpu_stage.prepare_admission_wait_ms, + "prepare_submit_ts": float(cpu_stage.prepare_submit_at), + "prepare_cpu_start_ts": float(cpu_stage.prepare_start), + "prepare_cpu_done_ts": float( + max(cpu_stage.prompt_cpu_profiled.finished_at, cpu_stage.target_cpu_profiled.finished_at) + ), + "prompt_text_cpu_start_ts": float(cpu_stage.prompt_cpu_profiled.started_at), + "prompt_text_cpu_end_ts": float(cpu_stage.prompt_cpu_profiled.finished_at), + "text_cpu_start_ts": float(cpu_stage.target_cpu_profiled.started_at), + "text_cpu_end_ts": float(cpu_stage.target_cpu_profiled.finished_at), + "executor_run_wall_ms": max(0.0, (time.perf_counter() - cpu_stage.prepare_start) * 1000.0), + "text_feature_pair_ms": float(phase_two["phase_wall_ms"]), + "g2pw_pair_ms": float(phase_one["g2pw_pair_ms"]), + "prompt_text_g2pw_queue_ms": prompt_g2pw_profiled.queue_ms, + "prompt_text_g2pw_run_ms": prompt_g2pw_profiled.run_ms, + "prompt_text_g2pw_prepare_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)), + "prompt_text_g2pw_predict_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)), + "prompt_text_g2pw_post_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)), + "text_g2pw_queue_ms": target_g2pw_profiled.queue_ms, + "text_g2pw_run_ms": target_g2pw_profiled.run_ms, + "text_g2pw_prepare_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)), + "text_g2pw_predict_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)), + "text_g2pw_post_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)), + "prompt_text_parallel_future_wait_ms": 0.0, + "prompt_text_parallel_future_executor_queue_ms": 0.0, + "prompt_text_parallel_future_run_ms": 0.0, + "prompt_text_parallel_future_finish_after_submit_ms": 0.0, + "prompt_text_parallel_future_queue_tail_after_target_ms": 0.0, + "prompt_text_parallel_future_run_tail_after_target_ms": 0.0, + "prompt_text_cpu_queue_ms": cpu_stage.prompt_cpu_profiled.queue_ms, + "prompt_text_cpu_run_ms": cpu_stage.prompt_cpu_profiled.run_ms, + "prompt_text_cpu_admission_wait_ms": float( + (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0) + ), + "prompt_text_cpu_backpressure_wait_ms": float( + (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0) + ), + "prompt_text_cpu_capacity_wait_ms": float( + (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0) + ), + "prompt_text_feature_queue_ms": prompt_feature_profiled.queue_ms, + "prompt_text_feature_run_ms": prompt_feature_profiled.run_ms, + "text_cpu_queue_ms": cpu_stage.target_cpu_profiled.queue_ms, + "text_cpu_run_ms": cpu_stage.target_cpu_profiled.run_ms, + "text_cpu_admission_wait_ms": float( + (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0) + ), + "text_cpu_backpressure_wait_ms": float( + (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0) + ), + "text_cpu_capacity_wait_ms": float( + (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0) + ), + "text_feature_queue_ms": target_feature_profiled.queue_ms, + "text_feature_run_ms": target_feature_profiled.run_ms, + "ref_audio_task_queue_ms": ref_audio_profiled.queue_ms, + "ref_audio_task_run_ms": ref_audio_profiled.run_ms, + "worker_prepare_inflight_on_enter": float(cpu_stage.current_inflight), + "worker_prepare_peak_inflight": float(cpu_stage.peak_inflight), + } + if extra_profile: + profile_overrides.update({key: float(value) for key, value in extra_profile.items()}) + ref_audio_bundle = dict(ref_audio_profiled.result) + ref_audio_profile = dict(ref_audio_bundle.get("profile", {})) + if ref_spec_result is not None: + refer_spec, ref_spec_profiled = ref_spec_result + ref_audio_bundle["refer_spec"] = refer_spec + ref_audio_profile.update( + { + "ref_spec_wait_ms": float(ref_spec_profiled.get("ref_spec_wait_ms", 0.0)), + "ref_spec_ms": float(ref_spec_profiled.get("ref_spec_ms", 0.0)), + "ref_spec_to_device_ms": float(ref_spec_profiled.get("ref_spec_to_device_ms", 0.0)), + "ref_spec_main_resample_ms": float(ref_spec_profiled.get("ref_spec_main_resample_ms", 0.0)), + "ref_spec_norm_ms": float(ref_spec_profiled.get("ref_spec_norm_ms", 0.0)), + "ref_spec_spectrogram_ms": float(ref_spec_profiled.get("ref_spec_spectrogram_ms", 0.0)), + "ref_spec_post_resample_ms": float(ref_spec_profiled.get("ref_spec_post_resample_ms", 0.0)), + } + ) + else: + ref_audio_bundle["refer_spec"] = None + ref_audio_profile.setdefault("ref_spec_wait_ms", 0.0) + ref_audio_profile.setdefault("ref_spec_ms", 0.0) + ref_audio_profile.setdefault("ref_spec_to_device_ms", 0.0) + ref_audio_profile.setdefault("ref_spec_main_resample_ms", 0.0) + ref_audio_profile.setdefault("ref_spec_norm_ms", 0.0) + ref_audio_profile.setdefault("ref_spec_spectrogram_ms", 0.0) + ref_audio_profile.setdefault("ref_spec_post_resample_ms", 0.0) + ref_audio_bundle["profile"] = ref_audio_profile + state = build_request_state_from_parts( + tts=self.tts, + spec=cpu_stage.spec, + prompt_text=cpu_stage.prompt_text, + text=cpu_stage.text, + prompt_result=prompt_feature_profiled.result, + target_result=target_feature_profiled.result, + ref_audio_bundle=ref_audio_bundle, + prepare_start=cpu_stage.prepare_start, + prepare_sync_start=cpu_stage.prepare_start, + profile_overrides=profile_overrides, + ) + prepare_exec_finished_at = time.perf_counter() + state.prepare_profile["executor_run_wall_ms"] = max(0.0, (prepare_exec_finished_at - cpu_stage.prepare_start) * 1000.0) + return state, cpu_stage.prepare_start, prepare_exec_finished_at + + async def prepare_ref_spec_stages_async( + self, + phase_ones: list[Dict[str, Any]], + ) -> list[tuple[tuple[Any, Any], Dict[str, float]] | Exception]: + async def _one(phase_one: Dict[str, Any]): + ref_audio_profiled = phase_one["ref_audio_profiled"] + raw_audio = ref_audio_profiled.result["raw_audio"] + raw_sr = int(ref_audio_profiled.result["raw_sr"]) + profiled = await self._run_ref_spec_stage(raw_audio, raw_sr) + refer_spec, profile = profiled.result + merged_profile = dict(profile) + merged_profile["ref_spec_wait_ms"] = float(profiled.queue_ms) + merged_profile["ref_spec_ms"] = float(profiled.run_ms) + return refer_spec, merged_profile + + if not phase_ones: + return [] + return list(await asyncio.gather(*[_one(phase_one) for phase_one in phase_ones], return_exceptions=True)) + + def apply_ref_spec_result_to_state( + self, + state: T2SRequestState, + ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]], + ) -> None: + refer_spec, profile = ref_spec_result + state.refer_spec = refer_spec + state.prepare_profile["ref_spec_wait_ms"] = float(profile.get("ref_spec_wait_ms", 0.0)) + state.prepare_profile["ref_spec_ms"] = float(profile.get("ref_spec_ms", 0.0)) + state.prepare_profile["ref_spec_to_device_ms"] = float(profile.get("ref_spec_to_device_ms", 0.0)) + state.prepare_profile["ref_spec_main_resample_ms"] = float(profile.get("ref_spec_main_resample_ms", 0.0)) + state.prepare_profile["ref_spec_norm_ms"] = float(profile.get("ref_spec_norm_ms", 0.0)) + state.prepare_profile["ref_spec_spectrogram_ms"] = float(profile.get("ref_spec_spectrogram_ms", 0.0)) + state.prepare_profile["ref_spec_post_resample_ms"] = float(profile.get("ref_spec_post_resample_ms", 0.0)) + + async def prepare_gpu_stages_profiled_async( + self, + cpu_stages: list[PreparedCpuStage], + ) -> list[tuple[T2SRequestState, float, float] | Exception]: + if not cpu_stages: + return [] + if len(cpu_stages) == 1: + single_stage = cpu_stages[0] + try: + return [await self.prepare_gpu_stage_profiled_async(single_stage)] + except Exception as exc: # noqa: PERF203 + return [exc] + + phase_one_started_at = time.perf_counter() + phase_one_results = await asyncio.gather( + *[self._prepare_gpu_phase_one(cpu_stage) for cpu_stage in cpu_stages], + return_exceptions=True, + ) + phase_one_finished_at = time.perf_counter() + phase_one_wall_ms = max(0.0, (phase_one_finished_at - phase_one_started_at) * 1000.0) + + outputs: list[tuple[T2SRequestState, float, float] | Exception | None] = [None] * len(cpu_stages) + pending_phase_two: list[tuple[int, PreparedCpuStage, Dict[str, Any]]] = [] + for index, (cpu_stage, phase_one) in enumerate(zip(cpu_stages, phase_one_results)): + if isinstance(phase_one, Exception): + outputs[index] = phase_one + self._release_split_stage_slot() + continue + pending_phase_two.append((index, cpu_stage, phase_one)) + + phase_two_started_at = time.perf_counter() + phase_two_results = await asyncio.gather( + *[self._prepare_gpu_phase_two(cpu_stage, phase_one) for _, cpu_stage, phase_one in pending_phase_two], + return_exceptions=True, + ) + phase_two_finished_at = time.perf_counter() + phase_two_wall_ms = max(0.0, (phase_two_finished_at - phase_two_started_at) * 1000.0) + + for (index, cpu_stage, phase_one), phase_two in zip(pending_phase_two, phase_two_results): + try: + if isinstance(phase_two, Exception): + outputs[index] = phase_two + continue + outputs[index] = self._build_gpu_prepare_result( + cpu_stage, + phase_one, + phase_two, + extra_profile={ + "engine_prepare_audio_phase_mode": 1.0, + "engine_prepare_audio_phase_wall_ms": float(phase_one_wall_ms), + "engine_prepare_audio_phase_batch_size": float(len(cpu_stages)), + "engine_prepare_text_phase_wall_ms": float(phase_two_wall_ms), + "engine_prepare_text_phase_batch_size": float(len(pending_phase_two)), + }, + ) + except Exception as exc: # noqa: PERF203 + outputs[index] = exc + finally: + self._release_split_stage_slot() + + return [item if item is not None else RuntimeError("prepare batch result missing") for item in outputs] + + async def prepare_gpu_audio_phases_async( + self, + cpu_stages: list[PreparedCpuStage], + ) -> list[Dict[str, Any] | Exception]: + if not cpu_stages: + return [] + return list( + await asyncio.gather( + *[self._prepare_gpu_phase_one(cpu_stage) for cpu_stage in cpu_stages], + return_exceptions=True, + ) + ) + + async def prepare_gpu_text_phases_async( + self, + items: list[tuple[PreparedCpuStage, Dict[str, Any]]], + ) -> list[Dict[str, Any] | Exception]: + if not items: + return [] + return list( + await asyncio.gather( + *[self._prepare_gpu_phase_two(cpu_stage, phase_one) for cpu_stage, phase_one in items], + return_exceptions=True, + ) + ) + + def build_gpu_prepare_result_from_phases( + self, + cpu_stage: PreparedCpuStage, + phase_one: Dict[str, Any], + phase_two: Dict[str, Any], + extra_profile: Dict[str, float] | None = None, + ) -> tuple[T2SRequestState, float, float]: + try: + return self._build_gpu_prepare_result(cpu_stage, phase_one, phase_two, extra_profile=extra_profile) finally: self._release_split_stage_slot() diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py index ff5591b2..4628a2a2 100644 --- a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py +++ b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py @@ -15,6 +15,7 @@ REF_AUDIO_MIN_SAMPLES_16K = 48000 REF_AUDIO_MAX_SAMPLES_16K = 160000 _RESAMPLE_CACHE_LOCK = threading.Lock() _RESAMPLE_CACHE: Dict[Tuple[int, int, str], torchaudio.transforms.Resample] = {} +_RESAMPLE_STREAM_CACHE: Dict[str, torch.cuda.Stream] = {} def _get_resampler(orig_sr: int, target_sr: int, device: str) -> torchaudio.transforms.Resample: @@ -28,6 +29,16 @@ def _get_resampler(orig_sr: int, target_sr: int, device: str) -> torchaudio.tran return transform +def _get_resample_stream(device: str) -> torch.cuda.Stream: + device_key = str(device) + with _RESAMPLE_CACHE_LOCK: + stream = _RESAMPLE_STREAM_CACHE.get(device_key) + if stream is None: + stream = torch.cuda.Stream(device=device_key) + _RESAMPLE_STREAM_CACHE[device_key] = stream + return stream + + def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wav_samples: int) -> torch.Tensor: resample_device = os.environ.get("GPTSOVITS_PREPARE_REF_RESAMPLE_DEVICE", "cpu").strip().lower() or "cpu" if resample_device not in {"cpu", "cuda"}: @@ -37,10 +48,20 @@ def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wa wav_mono = raw_audio if wav_mono.dim() == 2 and wav_mono.shape[0] != 1: wav_mono = wav_mono.mean(0, keepdim=True) - wav16k = wav_mono.to(dtype=torch.float32, device=resample_device) - if raw_sr != 16000: - wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k) - wav16k = wav16k.squeeze(0).contiguous() + if resample_device == "cuda": + stream = _get_resample_stream(resample_device) + with torch.cuda.stream(stream): + wav16k = wav_mono.to(dtype=torch.float32, device=resample_device) + if raw_sr != 16000: + wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k) + wav16k = wav16k.squeeze(0).contiguous() + stream.synchronize() + wav16k = wav16k.detach().to(device="cpu", dtype=torch.float32).contiguous() + else: + wav16k = wav_mono.to(dtype=torch.float32, device=resample_device) + if raw_sr != 16000: + wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k) + wav16k = wav16k.squeeze(0).contiguous() if wav16k.shape[0] > REF_AUDIO_MAX_SAMPLES_16K or wav16k.shape[0] < REF_AUDIO_MIN_SAMPLES_16K: raise OSError("参考音频在3~10秒范围外,请更换!") if zero_wav_samples > 0: @@ -256,37 +277,56 @@ class PrepareRefSemanticBatchWorker: batch_samples = int(wav_lengths.sum().item()) max_wav_len = int(wav_lengths.max().item()) + pack_start = time.perf_counter() input_values_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.float32) attention_mask_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.long) for batch_index, wav in enumerate(prepared_wavs): wav_len = int(wav.shape[0]) input_values_cpu[batch_index, :wav_len] = wav attention_mask_cpu[batch_index, :wav_len] = 1 + pack_ms = (time.perf_counter() - pack_start) * 1000.0 limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0} + h2d_ms = 0.0 + ssl_forward_ms = 0.0 + hidden_length_ms = 0.0 + extract_latent_ms = 0.0 if self.stage_limiter is None: + h2d_start = time.perf_counter() input_values = input_values_cpu.to(self.device) attention_mask = attention_mask_cpu.to(self.device) if self.is_half: input_values = input_values.half() - forward_start = time.perf_counter() + h2d_ms = (time.perf_counter() - h2d_start) * 1000.0 + ssl_start = time.perf_counter() outputs = self.ssl_model.model(input_values, attention_mask=attention_mask) + ssl_forward_ms = (time.perf_counter() - ssl_start) * 1000.0 hubert_feature = outputs["last_hidden_state"].transpose(1, 2) + hidden_length_start = time.perf_counter() hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1])) + hidden_length_ms = (time.perf_counter() - hidden_length_start) * 1000.0 + latent_start = time.perf_counter() codes = self.vits_model.extract_latent(hubert_feature) - forward_ms = (time.perf_counter() - forward_start) * 1000.0 + extract_latent_ms = (time.perf_counter() - latent_start) * 1000.0 else: with self.stage_limiter.enter() as limiter_stats: + h2d_start = time.perf_counter() input_values = input_values_cpu.to(self.device) attention_mask = attention_mask_cpu.to(self.device) if self.is_half: input_values = input_values.half() - forward_start = time.perf_counter() + h2d_ms = (time.perf_counter() - h2d_start) * 1000.0 + ssl_start = time.perf_counter() outputs = self.ssl_model.model(input_values, attention_mask=attention_mask) + ssl_forward_ms = (time.perf_counter() - ssl_start) * 1000.0 hubert_feature = outputs["last_hidden_state"].transpose(1, 2) + hidden_length_start = time.perf_counter() hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1])) + hidden_length_ms = (time.perf_counter() - hidden_length_start) * 1000.0 + latent_start = time.perf_counter() codes = self.vits_model.extract_latent(hubert_feature) - forward_ms = (time.perf_counter() - forward_start) * 1000.0 + extract_latent_ms = (time.perf_counter() - latent_start) * 1000.0 + forward_ms = float(h2d_ms + ssl_forward_ms + hidden_length_ms + extract_latent_ms) code_lengths = conv1d_output_lengths(hidden_lengths.detach().cpu(), getattr(self.vits_model, "ssl_proj", None)) scatter_start = time.perf_counter() @@ -308,6 +348,11 @@ class PrepareRefSemanticBatchWorker: 0.0, (float(batch_started) - float(batch_collected_at)) * 1000.0 ), "prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms), + "prompt_semantic_pack_ms": float(pack_ms), + "prompt_semantic_h2d_ms": float(h2d_ms), + "prompt_semantic_ssl_forward_ms": float(ssl_forward_ms), + "prompt_semantic_hidden_length_ms": float(hidden_length_ms), + "prompt_semantic_extract_latent_ms": float(extract_latent_ms), "prompt_semantic_forward_ms": float(forward_ms), "prompt_semantic_scatter_ms": 0.0, "prompt_semantic_calls": 1.0, diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py index 73e2a2c7..a4d462d0 100644 --- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -55,7 +55,7 @@ class T2SRequestState: all_phones: torch.LongTensor all_bert_features: torch.Tensor prompt_semantic: torch.LongTensor - refer_spec: Tuple[torch.Tensor, Optional[torch.Tensor]] + refer_spec: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] aux_refer_specs: List[Tuple[torch.Tensor, Optional[torch.Tensor]]] raw_audio: torch.Tensor raw_sr: int @@ -188,7 +188,11 @@ def build_request_state_from_parts( ref_audio_bundle_ms = float(ref_audio_bundle.get("profile", {}).get("bundle_total_ms", 0.0)) bundle_profile = ref_audio_bundle.get("profile", {}) prompt_semantic = ref_audio_bundle["prompt_semantic"].long() - spec_audio, audio_16k = ref_audio_bundle["refer_spec"] + refer_spec_value = ref_audio_bundle.get("refer_spec") + if refer_spec_value in [None, ()]: + spec_audio, audio_16k = None, None + else: + spec_audio, audio_16k = refer_spec_value aux_refer_specs: List[Tuple[torch.Tensor, Optional[torch.Tensor]]] = [] for aux_ref_audio_path in list(getattr(spec, "aux_ref_audio_paths", []) or []): if aux_ref_audio_path in [None, ""]: @@ -323,6 +327,11 @@ def build_request_state_from_parts( bundle_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0) ), "prompt_semantic_cpu_prepare_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), + "prompt_semantic_pack_ms": float(bundle_profile.get("prompt_semantic_pack_ms", 0.0)), + "prompt_semantic_h2d_ms": float(bundle_profile.get("prompt_semantic_h2d_ms", 0.0)), + "prompt_semantic_ssl_forward_ms": float(bundle_profile.get("prompt_semantic_ssl_forward_ms", 0.0)), + "prompt_semantic_hidden_length_ms": float(bundle_profile.get("prompt_semantic_hidden_length_ms", 0.0)), + "prompt_semantic_extract_latent_ms": float(bundle_profile.get("prompt_semantic_extract_latent_ms", 0.0)), "prompt_semantic_forward_ms": float(bundle_profile.get("prompt_semantic_forward_ms", 0.0)), "prompt_semantic_scatter_ms": float(bundle_profile.get("prompt_semantic_scatter_ms", 0.0)), "prompt_semantic_stage_slots": float(bundle_profile.get("prompt_semantic_stage_slots", 0.0)), @@ -331,6 +340,11 @@ def build_request_state_from_parts( "prompt_semantic_batch_samples": float(bundle_profile.get("prompt_semantic_batch_samples", 0.0)), "ref_spec_wait_ms": float(bundle_profile.get("ref_spec_wait_ms", 0.0)), "ref_spec_ms": ref_spec_ms, + "ref_spec_to_device_ms": float(bundle_profile.get("ref_spec_to_device_ms", 0.0)), + "ref_spec_main_resample_ms": float(bundle_profile.get("ref_spec_main_resample_ms", 0.0)), + "ref_spec_norm_ms": float(bundle_profile.get("ref_spec_norm_ms", 0.0)), + "ref_spec_spectrogram_ms": float(bundle_profile.get("ref_spec_spectrogram_ms", 0.0)), + "ref_spec_post_resample_ms": float(bundle_profile.get("ref_spec_post_resample_ms", 0.0)), "ref_audio_bundle_ms": ref_audio_bundle_ms, "tensorize_ms": tensorize_ms, "total_ms": (time.perf_counter() - prepare_sync_start) * 1000.0, @@ -352,7 +366,7 @@ def build_request_state_from_parts( all_phones=all_phones, all_bert_features=all_bert_features, prompt_semantic=prompt_semantic, - refer_spec=(spec_audio, audio_16k), + refer_spec=(None if spec_audio is None else (spec_audio, audio_16k)), aux_refer_specs=aux_refer_specs, raw_audio=raw_audio, raw_sr=raw_sr, diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py index 88b8cc5d..f07250e1 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py @@ -21,6 +21,14 @@ class EngineRegistryBridgeFacade: def engine_prepare_queue_owner(self): return self.owner.engine_prepare_queue_owner + @property + def engine_prepare_text_queue_owner(self): + return self.owner.engine_prepare_text_queue_owner + + @property + def engine_prepare_ref_spec_queue_owner(self): + return self.owner.engine_prepare_ref_spec_queue_owner + @property def engine_finalize_queue_owner(self): return self.owner.engine_finalize_queue_owner @@ -82,7 +90,33 @@ class EngineRegistryBridgeFacade: return self.request_registry.snapshot() def _snapshot_engine_prepare_state(self) -> Dict[str, Any]: - return self.engine_prepare_queue_owner.snapshot(max_request_ids=16) + audio_snapshot = self.engine_prepare_queue_owner.snapshot(max_request_ids=16) + text_snapshot = self.engine_prepare_text_queue_owner.snapshot(max_request_ids=16) + ref_spec_snapshot = self.engine_prepare_ref_spec_queue_owner.snapshot(max_request_ids=16) + return { + "waiting_count": int(audio_snapshot.get("waiting_count", 0)) + + int(text_snapshot.get("waiting_count", 0)) + + int(ref_spec_snapshot.get("waiting_count", 0)), + "audio_waiting_count": int(audio_snapshot.get("waiting_count", 0)), + "text_waiting_count": int(text_snapshot.get("waiting_count", 0)), + "ref_spec_waiting_count": int(ref_spec_snapshot.get("waiting_count", 0)), + "audio_waiting_request_ids": list(audio_snapshot.get("waiting_request_ids", [])), + "text_waiting_request_ids": list(text_snapshot.get("waiting_request_ids", [])), + "ref_spec_waiting_request_ids": list(ref_spec_snapshot.get("waiting_request_ids", [])), + "peak_waiting": int( + max( + int(audio_snapshot.get("peak_waiting", 0)), + int(text_snapshot.get("peak_waiting", 0)), + int(ref_spec_snapshot.get("peak_waiting", 0)), + ) + ), + "total_submitted": int(audio_snapshot.get("total_submitted", 0)), + "total_completed": int(audio_snapshot.get("total_completed", 0)), + "text_total_submitted": int(text_snapshot.get("total_submitted", 0)), + "text_total_completed": int(text_snapshot.get("total_completed", 0)), + "ref_spec_total_submitted": int(ref_spec_snapshot.get("total_submitted", 0)), + "ref_spec_total_completed": int(ref_spec_snapshot.get("total_completed", 0)), + } def _snapshot_engine_finalize_state(self) -> Dict[str, Any]: return self.engine_finalize_queue_owner.snapshot(max_request_ids=16) @@ -107,6 +141,8 @@ class EngineRegistryBridgeFacade: def _is_engine_drained(self) -> bool: prepare_empty = self.engine_prepare_queue_owner.is_drained() + prepare_text_empty = self.engine_prepare_text_queue_owner.is_drained() + prepare_ref_spec_empty = self.engine_prepare_ref_spec_queue_owner.is_drained() dispatch_empty = self.engine_dispatch_queue_owner.is_drained() finalize_empty = self.engine_finalize_queue_owner.is_drained() decode_pending_empty = not self.engine_decode_runtime_owner.has_pending_jobs() @@ -114,6 +150,8 @@ class EngineRegistryBridgeFacade: worker_state = self.scheduler_worker.snapshot() return bool( prepare_empty + and prepare_text_empty + and prepare_ref_spec_empty and dispatch_empty and finalize_empty and decode_pending_empty diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py index 0e93c442..2cb1e175 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py @@ -110,6 +110,8 @@ class EngineCompositionBuilder: get_micro_batch_wait_s=owner.scheduler_worker.get_micro_batch_wait_s, ) owner.engine_prepare_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") + owner.engine_prepare_text_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") + owner.engine_prepare_ref_spec_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") owner.engine_finalize_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") owner.engine_dispatch_queue_owner = EngineTaskQueueOwner(completion_key="total_dispatched") @@ -119,6 +121,8 @@ class EngineCompositionBuilder: tts=owner.tts, scheduler_worker=owner.scheduler_worker, prepare_queue_owner=owner.engine_prepare_queue_owner, + prepare_text_queue_owner=owner.engine_prepare_text_queue_owner, + prepare_ref_spec_queue_owner=owner.engine_prepare_ref_spec_queue_owner, finalize_queue_owner=owner.engine_finalize_queue_owner, dispatch_queue_owner=owner.engine_dispatch_queue_owner, decode_runtime_owner=owner.engine_decode_runtime_owner, diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py index b6c5ca4d..65953dd9 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py @@ -129,7 +129,7 @@ class EnginePolicyArbiterController: self.state.total_ticks += 1 if stage == "idle": self.state.total_idle_ticks += 1 - elif stage == "prepare": + elif stage in {"prepare", "prepare_audio", "prepare_text", "prepare_ref_spec"}: self.state.total_prepare_dispatches += 1 self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) elif stage == "finalize": @@ -282,7 +282,11 @@ class EnginePolicyArbiterController: request_registry = self.snapshot_request_registry() worker_state = self.get_worker_state() policy_snapshot = self.build_policy_snapshot(request_registry, worker_state) - prepare_waiting = int(self.snapshot_prepare_state().get("waiting_count", 0)) + prepare_state = self.snapshot_prepare_state() + prepare_waiting = int(prepare_state.get("waiting_count", 0)) + prepare_audio_waiting = int(prepare_state.get("audio_waiting_count", 0)) + prepare_text_waiting = int(prepare_state.get("text_waiting_count", 0)) + prepare_ref_spec_waiting = int(prepare_state.get("ref_spec_waiting_count", 0)) finalize_waiting = int(self.snapshot_finalize_state().get("waiting_count", 0)) decode_waiting = int(self.snapshot_dispatch_state().get("waiting_count", 0)) decode_runtime_state = self.snapshot_decode_runtime_state() @@ -291,6 +295,9 @@ class EnginePolicyArbiterController: worker_pending_jobs = int(decode_runtime_state.get("pending_jobs", 0)) worker_running_requests = int(decode_runtime_state.get("active_request_count", 0)) prepare_age_ms = float(self.peek_queue_age_ms("prepare")) + prepare_audio_age_ms = float(self.peek_queue_age_ms("prepare_audio")) + prepare_text_age_ms = float(self.peek_queue_age_ms("prepare_text")) + prepare_ref_spec_age_ms = float(self.peek_queue_age_ms("prepare_ref_spec")) finalize_age_ms = float(self.peek_queue_age_ms("finalize")) decode_runtime_pending_age_ms = float(self.peek_queue_age_ms("decode_runtime_pending")) decode_budget_remaining = int(self.snapshot_state().get("decode_budget_remaining", 0)) @@ -316,14 +323,31 @@ class EnginePolicyArbiterController: and (not worker_decode_control_enabled or not worker_decode_has_work or worker_pending_jobs <= 0) ): return "decode_dispatch", "dispatch_prepared_state", policy_snapshot, worker_state + if ( + finalize_waiting > 0 + and prepare_ref_spec_waiting > 0 + and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0) + ): + return "prepare_ref_spec", "finalize_waiting_for_ref_spec", policy_snapshot, worker_state if finalize_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): return "finalize", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state if finalize_waiting > 0 and finalize_age_ms >= float(self.arbiter_config.finalize_aging_ms): return "finalize", "finalize_aging", policy_snapshot, worker_state if prepare_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): - return "prepare", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state + if prepare_text_waiting > 0 and (prepare_audio_waiting <= 0 or prepare_text_age_ms >= prepare_audio_age_ms): + return "prepare_text", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state + if prepare_ref_spec_waiting > 0 and prepare_audio_waiting <= 0 and prepare_text_waiting <= 0: + return "prepare_ref_spec", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state + return "prepare_audio", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state if prepare_waiting > 0 and prepare_age_ms >= float(self.arbiter_config.prepare_aging_ms): - return "prepare", "prepare_aging", policy_snapshot, worker_state + if prepare_text_waiting > 0 and prepare_text_age_ms >= max(prepare_audio_age_ms, prepare_age_ms - 1e-6): + return "prepare_text", "prepare_aging", policy_snapshot, worker_state + if ( + prepare_ref_spec_waiting > 0 + and prepare_ref_spec_age_ms >= max(prepare_audio_age_ms, prepare_text_age_ms, prepare_age_ms - 1e-6) + ): + return "prepare_ref_spec", "prepare_aging", policy_snapshot, worker_state + return "prepare_audio", "prepare_aging", policy_snapshot, worker_state if worker_decode_control_enabled and worker_decode_has_work and policy_allowed: return "decode_runtime", "worker_active_batch_progress_fallback", policy_snapshot, worker_state if decode_waiting > 0 and policy_allowed: @@ -331,5 +355,9 @@ class EnginePolicyArbiterController: if finalize_waiting > 0: return "finalize", "finalize_fallback", policy_snapshot, worker_state if prepare_waiting > 0: - return "prepare", "prepare_fallback", policy_snapshot, worker_state + if prepare_text_waiting > 0 and (prepare_audio_waiting <= 0 or prepare_text_age_ms >= prepare_audio_age_ms): + return "prepare_text", "prepare_fallback", policy_snapshot, worker_state + if prepare_ref_spec_waiting > 0 and prepare_audio_waiting <= 0: + return "prepare_ref_spec", "prepare_fallback", policy_snapshot, worker_state + return "prepare_audio", "prepare_fallback", policy_snapshot, worker_state return "idle", "no_pending_work", policy_snapshot, worker_state diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py index 15eedeca..600e1e83 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py @@ -324,8 +324,24 @@ class EngineGpuPrepareTask: done_future: asyncio.Future | None engine_request_id: str | None enqueue_time: float - queue_wait_ms: float = 0.0 + phase: str = "audio" + audio_enqueue_time: float = 0.0 + audio_start_time: float = 0.0 + audio_end_time: float = 0.0 + text_enqueue_time: float = 0.0 + text_start_time: float = 0.0 + text_end_time: float = 0.0 + ref_spec_enqueue_time: float = 0.0 + ref_spec_start_time: float = 0.0 + ref_spec_end_time: float = 0.0 + audio_queue_wait_ms: float = 0.0 + text_queue_wait_ms: float = 0.0 + ref_spec_queue_wait_ms: float = 0.0 admission_wait_ms: float = 0.0 + phase_one: Dict[str, Any] | None = None + ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]] | None = None + state_result: T2SRequestState | None = None + cancelled: bool = False error: str | None = None diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py index a71f7e4e..0c73616f 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py @@ -14,6 +14,8 @@ class EngineStageOrchestrator: executor: EngineStageExecutor, scheduler_worker: UnifiedSchedulerWorker, prepare_queue_owner: EngineTaskQueueOwner, + prepare_text_queue_owner: EngineTaskQueueOwner, + prepare_ref_spec_queue_owner: EngineTaskQueueOwner, finalize_queue_owner: EngineTaskQueueOwner, dispatch_queue_owner: EngineTaskQueueOwner, decode_runtime_owner: EngineDecodeRuntimeOwner, @@ -22,6 +24,8 @@ class EngineStageOrchestrator: self.executor = executor self.scheduler_worker = scheduler_worker self.prepare_queue_owner = prepare_queue_owner + self.prepare_text_queue_owner = prepare_text_queue_owner + self.prepare_ref_spec_queue_owner = prepare_ref_spec_queue_owner self.finalize_queue_owner = finalize_queue_owner self.dispatch_queue_owner = dispatch_queue_owner self.decode_runtime_owner = decode_runtime_owner @@ -45,7 +49,17 @@ class EngineStageOrchestrator: def peek_queue_age_ms(self, queue_name: str) -> float: if queue_name == "prepare": + return max( + self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time"), + self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time"), + self.prepare_ref_spec_queue_owner.peek_oldest_age_ms("enqueue_time"), + ) + if queue_name == "prepare_audio": return self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time") + if queue_name == "prepare_text": + return self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time") + if queue_name == "prepare_ref_spec": + return self.prepare_ref_spec_queue_owner.peek_oldest_age_ms("enqueue_time") if queue_name == "finalize": return self.finalize_queue_owner.peek_oldest_age_ms("enqueued_time") if queue_name == "decode_runtime_pending": @@ -62,6 +76,10 @@ class EngineStageOrchestrator: return True if self.prepare_queue_owner.has_items(): return True + if self.prepare_text_queue_owner.has_items(): + return True + if self.prepare_ref_spec_queue_owner.has_items(): + return True if self.finalize_queue_owner.has_items(): return True return self.dispatch_queue_owner.has_items() @@ -79,6 +97,12 @@ class EngineStageOrchestrator: executed = False if stage == "prepare": executed = self.executor.run_engine_prepare_once() + elif stage == "prepare_audio": + executed = self.executor.run_engine_prepare_audio_once() + elif stage == "prepare_text": + executed = self.executor.run_engine_prepare_text_once() + elif stage == "prepare_ref_spec": + executed = self.executor.run_engine_prepare_ref_spec_once() elif stage == "finalize": executed = self.executor.run_engine_finalize_once() elif stage == "decode_dispatch": diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py index 1b872dfa..27ed3bf5 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py @@ -24,6 +24,8 @@ class EngineStageCoordinator: tts: TTS, scheduler_worker: UnifiedSchedulerWorker, prepare_queue_owner: EngineTaskQueueOwner, + prepare_text_queue_owner: EngineTaskQueueOwner, + prepare_ref_spec_queue_owner: EngineTaskQueueOwner, finalize_queue_owner: EngineTaskQueueOwner, dispatch_queue_owner: EngineTaskQueueOwner, decode_runtime_owner: EngineDecodeRuntimeOwner, @@ -45,6 +47,8 @@ class EngineStageCoordinator: tts=tts, scheduler_worker=scheduler_worker, prepare_queue_owner=prepare_queue_owner, + prepare_text_queue_owner=prepare_text_queue_owner, + prepare_ref_spec_queue_owner=prepare_ref_spec_queue_owner, finalize_queue_owner=finalize_queue_owner, dispatch_queue_owner=dispatch_queue_owner, decode_runtime_owner=decode_runtime_owner, @@ -66,6 +70,8 @@ class EngineStageCoordinator: executor=self.executor, scheduler_worker=scheduler_worker, prepare_queue_owner=prepare_queue_owner, + prepare_text_queue_owner=prepare_text_queue_owner, + prepare_ref_spec_queue_owner=prepare_ref_spec_queue_owner, finalize_queue_owner=finalize_queue_owner, dispatch_queue_owner=dispatch_queue_owner, decode_runtime_owner=decode_runtime_owner, @@ -144,6 +150,15 @@ class EngineStageCoordinator: def run_engine_prepare_once(self) -> bool: return self.executor.run_engine_prepare_once() + def run_engine_prepare_audio_once(self) -> bool: + return self.executor.run_engine_prepare_audio_once() + + def run_engine_prepare_text_once(self) -> bool: + return self.executor.run_engine_prepare_text_once() + + def run_engine_prepare_ref_spec_once(self) -> bool: + return self.executor.run_engine_prepare_ref_spec_once() + def run_engine_finalize_once(self) -> bool: return self.executor.run_engine_finalize_once() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py index 644c35f6..f6a249fa 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py @@ -24,6 +24,10 @@ class EngineDispatchStageMixin: engine_request_id: str | None, timeout_sec: float | None, ) -> EngineDispatchTask: + if float(state.prepare_profile.get("ref_spec_async_failed", 0.0) or 0.0) > 0.0: + error = RuntimeError("ref_spec async stage failed before dispatch") + self.fail_request_state(engine_request_id or state.request_id, str(error)) + raise error task = EngineDispatchTask( request_id=state.request_id, state=state, diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py index 01921d51..6d06f0c6 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py @@ -31,6 +31,8 @@ class EngineStageExecutor( tts: TTS, scheduler_worker: UnifiedSchedulerWorker, prepare_queue_owner: EngineTaskQueueOwner, + prepare_text_queue_owner: EngineTaskQueueOwner, + prepare_ref_spec_queue_owner: EngineTaskQueueOwner, finalize_queue_owner: EngineTaskQueueOwner, dispatch_queue_owner: EngineTaskQueueOwner, decode_runtime_owner: EngineDecodeRuntimeOwner, @@ -51,6 +53,8 @@ class EngineStageExecutor( self.tts = tts self.scheduler_worker = scheduler_worker self.prepare_queue_owner = prepare_queue_owner + self.prepare_text_queue_owner = prepare_text_queue_owner + self.prepare_ref_spec_queue_owner = prepare_ref_spec_queue_owner self.finalize_queue_owner = finalize_queue_owner self.dispatch_queue_owner = dispatch_queue_owner self.decode_runtime_owner = decode_runtime_owner diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_finalize.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_finalize.py index 8e66f76e..4b61993e 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_finalize.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_finalize.py @@ -39,10 +39,37 @@ class EngineFinalizeStageMixin: tasks = self.take_engine_finalize_batch_nonblocking() if not tasks: return False - self.scheduler_worker.begin_finalize_execution(len(tasks)) + ready_tasks: List[SchedulerFinalizeTask] = [] + failed_tasks: List[SchedulerFinalizeTask] = [] + deferred_tasks: List[SchedulerFinalizeTask] = [] + for task in tasks: + job = self.get_engine_job(task.request_id) + if job is None: + continue + if float(job.state.prepare_profile.get("ref_spec_async_failed", 0.0) or 0.0) > 0.0: + failed_tasks.append(task) + continue + if job.state.refer_spec is None: + deferred_tasks.append(task) + self.merge_request_state_profile( + job.engine_request_id or job.request_id, + { + "engine_finalize_ref_spec_blocked": 1.0, + }, + ) + continue + ready_tasks.append(task) + if deferred_tasks: + self.finalize_queue_owner.enqueue_many(deferred_tasks) + if failed_tasks: + self.fail_engine_jobs([task.request_id for task in failed_tasks], "ref_spec async stage failed") + if not ready_tasks: + self.finalize_queue_owner.mark_completed(len(failed_tasks), notify=True) + return False + self.scheduler_worker.begin_finalize_execution(len(ready_tasks)) try: jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] - for task in tasks: + for task in ready_tasks: job = self.get_engine_job(task.request_id) if job is None: continue @@ -50,7 +77,7 @@ class EngineFinalizeStageMixin: if not jobs_and_items: return False now = time.perf_counter() - for task in tasks: + for task in ready_tasks: job = self.get_engine_job(task.request_id) if job is not None: job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0) @@ -69,8 +96,8 @@ class EngineFinalizeStageMixin: for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) except Exception as exc: - self.fail_engine_jobs([task.request_id for task in tasks], str(exc)) + self.fail_engine_jobs([task.request_id for task in ready_tasks], str(exc)) finally: - self.scheduler_worker.end_finalize_execution(len(tasks)) - self.finalize_queue_owner.mark_completed(len(tasks), notify=True) + self.scheduler_worker.end_finalize_execution(len(ready_tasks)) + self.finalize_queue_owner.mark_completed(len(ready_tasks) + len(failed_tasks), notify=True) return True diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py index b9095d2c..1ea7c45b 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py @@ -10,6 +10,13 @@ from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineGpuPrepare class EnginePrepareStageMixin: + def _prepare_waiting_total(self) -> int: + return ( + int(self.prepare_queue_owner.waiting_count()) + + int(self.prepare_text_queue_owner.waiting_count()) + + int(self.prepare_ref_spec_queue_owner.waiting_count()) + ) + async def _wait_prepare_queue_admission(self) -> float: soft_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_SOFT_MAX", "0"))) if soft_max <= 0: @@ -19,7 +26,7 @@ class EnginePrepareStageMixin: float(max(1, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_ADMISSION_POLL_MS", "1")))) / 1000.0, ) wait_start = time.perf_counter() - while self.prepare_queue_owner.waiting_count() >= soft_max: + while self._prepare_waiting_total() >= soft_max: await asyncio.sleep(poll_s) return max(0.0, (time.perf_counter() - wait_start) * 1000.0) @@ -53,44 +60,247 @@ class EnginePrepareStageMixin: done_future=done_future, engine_request_id=engine_request_id or spec.request_id, enqueue_time=time.perf_counter(), + phase="audio", + audio_enqueue_time=time.perf_counter(), admission_wait_ms=float(prepare_queue_admission_wait_ms), ) self.prepare_queue_owner.enqueue(task) self.notify_arbiter() return await done_future - def run_engine_prepare_once(self) -> bool: - prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy() - tasks = self.prepare_queue_owner.pop_left_many(int(prepare_batch_policy.get("prepare_batch_max_items", 1))) + def _should_chain_prepare_text_after_audio(self) -> bool: + if str(os.environ.get("GPTSOVITS_ENGINE_PREPARE_CHAIN_TEXT", "1")).strip().lower() in {"0", "false", "no", "off"}: + return False + if self.finalize_queue_owner.has_items() or self.dispatch_queue_owner.has_items(): + return False + decode_runtime_state = self.snapshot_engine_decode_runtime_state() + if bool(decode_runtime_state.get("has_work", False)): + return False + return True + + def _maybe_apply_ref_spec_to_state(self, task: EngineGpuPrepareTask) -> None: + if task.state_result is None or task.ref_spec_result is None: + return + self.scheduler_worker.apply_ref_spec_result_to_state(task.state_result, task.ref_spec_result) + if task.engine_request_id not in [None, ""]: + self.merge_request_state_profile( + str(task.engine_request_id), + { + "engine_prepare_ref_spec_queue_wait_ms": float(task.ref_spec_queue_wait_ms), + "ref_spec_wait_ms": float(task.ref_spec_result[1].get("ref_spec_wait_ms", 0.0)), + "ref_spec_ms": float(task.ref_spec_result[1].get("ref_spec_ms", 0.0)), + "ref_spec_to_device_ms": float(task.ref_spec_result[1].get("ref_spec_to_device_ms", 0.0)), + "ref_spec_main_resample_ms": float(task.ref_spec_result[1].get("ref_spec_main_resample_ms", 0.0)), + "ref_spec_norm_ms": float(task.ref_spec_result[1].get("ref_spec_norm_ms", 0.0)), + "ref_spec_spectrogram_ms": float(task.ref_spec_result[1].get("ref_spec_spectrogram_ms", 0.0)), + "ref_spec_post_resample_ms": float(task.ref_spec_result[1].get("ref_spec_post_resample_ms", 0.0)), + }, + ) + + def _mark_ref_spec_async_failed( + self, + task: EngineGpuPrepareTask, + error: Exception, + *, + queue_wait_ms: float, + ) -> None: + task.error = str(error) + task.cancelled = True + if task.state_result is not None: + task.state_result.prepare_profile["ref_spec_async_failed"] = 1.0 + task.state_result.prepare_profile["engine_prepare_ref_spec_queue_wait_ms"] = float(queue_wait_ms) + if task.engine_request_id not in [None, ""]: + self.merge_request_state_profile( + str(task.engine_request_id), + { + "ref_spec_async_failed": 1.0, + "engine_prepare_ref_spec_queue_wait_ms": float(queue_wait_ms), + }, + ) + self.fail_request_state(task.engine_request_id or task.request_id, str(error)) + self.fail_engine_jobs([task.request_id], str(error)) + self.notify_arbiter() + + def _run_engine_prepare_audio_once(self, batch_max_items: int) -> bool: + tasks = self.prepare_queue_owner.pop_left_many(batch_max_items) if not tasks: return False now = time.perf_counter() queue_wait_ms_list = [max(0.0, (now - task.enqueue_time) * 1000.0) for task in tasks] - batch_results = asyncio.run( - self.scheduler_worker.prepare_gpu_stages_profiled_async([task.cpu_stage for task in tasks]) - ) + for task in tasks: + task.audio_start_time = float(now) + batch_results = asyncio.run(self.scheduler_worker.prepare_gpu_audio_phases_async([task.cpu_stage for task in tasks])) completed_count = 0 for task, queue_wait_ms, result in zip(tasks, queue_wait_ms_list, batch_results): + task.audio_end_time = time.perf_counter() if isinstance(result, Exception): task.error = str(result) self.fail_request_state(task.engine_request_id or task.request_id, str(result)) self._notify_prepare_error(task, result) completed_count += 1 continue - state, prepare_exec_started_at, prepare_exec_finished_at = result - state.prepare_profile["engine_prepare_queue_admission_wait_ms"] = float(task.admission_wait_ms) - state.prepare_profile["engine_gpu_prepare_queue_wait_ms"] = float(queue_wait_ms) + task.audio_queue_wait_ms = float(queue_wait_ms) + task.phase_one = result + task.phase = "text" + task.enqueue_time = time.perf_counter() + task.text_enqueue_time = float(task.enqueue_time) + task.ref_spec_enqueue_time = float(task.enqueue_time) + self.prepare_text_queue_owner.enqueue(task) + self.prepare_ref_spec_queue_owner.enqueue(task) + if task.engine_request_id not in [None, ""]: + self.merge_request_state_profile( + str(task.engine_request_id), + { + "engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms), + "engine_prepare_audio_queue_wait_ms": float(queue_wait_ms), + "engine_prepare_audio_batch_size": float(len(tasks)), + "engine_prepare_audio_phase_wall_ms": float(result.get("phase_wall_ms", 0.0)), + "engine_prepare_audio_enqueue_ts": float(task.audio_enqueue_time), + "engine_prepare_audio_start_ts": float(task.audio_start_time), + "engine_prepare_audio_end_ts": float(task.audio_end_time), + "engine_prepare_text_enqueue_ts": float(task.text_enqueue_time), + "engine_prepare_ref_spec_enqueue_ts": float(task.ref_spec_enqueue_time), + }, + ) + completed_count += 1 + self.prepare_queue_owner.mark_completed(completed_count) + if completed_count > 0 and self._should_chain_prepare_text_after_audio(): + self._run_engine_prepare_text_once(min(batch_max_items, completed_count)) + return True + if completed_count > 0: + self.notify_arbiter() + return True + + def _run_engine_prepare_text_once(self, batch_max_items: int) -> bool: + tasks = self.prepare_text_queue_owner.pop_left_many(batch_max_items) + if not tasks: + return False + now = time.perf_counter() + queue_wait_ms_list = [max(0.0, (now - task.enqueue_time) * 1000.0) for task in tasks] + for task in tasks: + task.text_start_time = float(now) + items = [(task.cpu_stage, task.phase_one) for task in tasks if task.phase_one is not None] + batch_results = asyncio.run(self.scheduler_worker.prepare_gpu_text_phases_async(items)) + completed_count = 0 + for task, queue_wait_ms, result in zip(tasks, queue_wait_ms_list, batch_results): + task.text_end_time = time.perf_counter() + if isinstance(result, Exception): + task.error = str(result) + task.cancelled = True + self.fail_request_state(task.engine_request_id or task.request_id, str(result)) + self._notify_prepare_error(task, result) + completed_count += 1 + continue + task.text_queue_wait_ms = float(queue_wait_ms) + state, prepare_exec_started_at, prepare_exec_finished_at = self.scheduler_worker.build_gpu_prepare_result_from_phases( + task.cpu_stage, + task.phase_one or {}, + result, + extra_profile={ + "engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms), + "engine_prepare_audio_queue_wait_ms": float(task.audio_queue_wait_ms), + "engine_prepare_text_queue_wait_ms": float(task.text_queue_wait_ms), + "engine_gpu_prepare_queue_wait_ms": float(task.audio_queue_wait_ms + task.text_queue_wait_ms), + "engine_prepare_audio_batch_size": float(len(tasks)), + "engine_prepare_text_batch_size": float(len(tasks)), + "engine_prepare_audio_phase_mode": 2.0, + "engine_prepare_audio_phase_wall_ms": float((task.phase_one or {}).get("phase_wall_ms", 0.0)), + "engine_prepare_text_phase_wall_ms": float(result.get("phase_wall_ms", 0.0)), + "engine_prepare_text_phase_batch_size": float(len(tasks)), + "engine_prepare_audio_enqueue_ts": float(task.audio_enqueue_time), + "engine_prepare_audio_start_ts": float(task.audio_start_time), + "engine_prepare_audio_end_ts": float(task.audio_end_time), + "engine_prepare_text_enqueue_ts": float(task.text_enqueue_time), + "engine_prepare_text_start_ts": float(task.text_start_time), + "engine_prepare_text_end_ts": float(task.text_end_time), + "engine_prepare_ref_spec_enqueue_ts": float(task.ref_spec_enqueue_time), + }, + ) + task.state_result = state + self._maybe_apply_ref_spec_to_state(task) state.prepare_profile["engine_gpu_prepare_batch_size"] = float(len(tasks)) if task.engine_request_id not in [None, ""]: self.merge_request_state_profile( str(task.engine_request_id), { "engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms), - "engine_gpu_prepare_queue_wait_ms": float(queue_wait_ms), + "engine_prepare_audio_queue_wait_ms": float(task.audio_queue_wait_ms), + "engine_prepare_text_queue_wait_ms": float(task.text_queue_wait_ms), + "engine_gpu_prepare_queue_wait_ms": float(task.audio_queue_wait_ms + task.text_queue_wait_ms), "engine_gpu_prepare_batch_size": float(len(tasks)), }, ) self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at)) completed_count += 1 - self.prepare_queue_owner.mark_completed(completed_count) + self.prepare_text_queue_owner.mark_completed(completed_count) return True + + def _run_engine_prepare_ref_spec_once(self, batch_max_items: int) -> bool: + tasks = self.prepare_ref_spec_queue_owner.pop_left_many(batch_max_items) + if not tasks: + return False + now = time.perf_counter() + runnable_tasks: list[EngineGpuPrepareTask] = [] + queue_wait_ms_list: list[float] = [] + completed_count = 0 + for task in tasks: + if task.cancelled or task.phase_one is None: + completed_count += 1 + continue + task.ref_spec_start_time = float(now) + runnable_tasks.append(task) + queue_wait_ms_list.append(max(0.0, (now - task.ref_spec_enqueue_time) * 1000.0)) + if not runnable_tasks: + self.prepare_ref_spec_queue_owner.mark_completed(completed_count) + return True + batch_results = asyncio.run( + self.scheduler_worker.prepare_ref_spec_stages_async([task.phase_one or {} for task in runnable_tasks]) + ) + for task, queue_wait_ms, result in zip(runnable_tasks, queue_wait_ms_list, batch_results): + task.ref_spec_end_time = time.perf_counter() + task.ref_spec_queue_wait_ms = float(queue_wait_ms) + if isinstance(result, Exception): + self._mark_ref_spec_async_failed(task, result, queue_wait_ms=float(queue_wait_ms)) + completed_count += 1 + continue + task.ref_spec_result = result + self._maybe_apply_ref_spec_to_state(task) + if task.state_result is not None: + task.state_result.prepare_profile["engine_prepare_ref_spec_queue_wait_ms"] = float(queue_wait_ms) + task.state_result.prepare_profile["engine_prepare_ref_spec_enqueue_ts"] = float(task.ref_spec_enqueue_time) + task.state_result.prepare_profile["engine_prepare_ref_spec_start_ts"] = float(task.ref_spec_start_time) + task.state_result.prepare_profile["engine_prepare_ref_spec_end_ts"] = float(task.ref_spec_end_time) + completed_count += 1 + self.prepare_ref_spec_queue_owner.mark_completed(completed_count) + return True + + def run_engine_prepare_once(self) -> bool: + prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy() + batch_max_items = int(prepare_batch_policy.get("prepare_batch_max_items", 1)) + audio_age_ms = self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time") + text_age_ms = self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time") + if self.prepare_text_queue_owner.has_items() and ( + not self.prepare_queue_owner.has_items() or text_age_ms >= audio_age_ms + ): + return self._run_engine_prepare_text_once(batch_max_items) + if self.prepare_queue_owner.has_items(): + return self._run_engine_prepare_audio_once(batch_max_items) + if self.prepare_ref_spec_queue_owner.has_items(): + return self._run_engine_prepare_ref_spec_once(batch_max_items) + if self.prepare_text_queue_owner.has_items(): + return self._run_engine_prepare_text_once(batch_max_items) + if self.prepare_ref_spec_queue_owner.has_items(): + return self._run_engine_prepare_ref_spec_once(batch_max_items) + return False + + def run_engine_prepare_audio_once(self) -> bool: + prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy() + return self._run_engine_prepare_audio_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1))) + + def run_engine_prepare_text_once(self) -> bool: + prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy() + return self._run_engine_prepare_text_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1))) + + def run_engine_prepare_ref_spec_once(self) -> bool: + prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy() + return self._run_engine_prepare_ref_spec_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1))) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py index 3a675cbe..bb9cb3cb 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py @@ -151,7 +151,9 @@ class WorkerFinalizeExecutor: @staticmethod def _collect_job_refer_specs(job: SchedulerPendingJob) -> List[tuple]: - refer_specs = [job.state.refer_spec] + refer_specs = [] + if job.state.refer_spec is not None: + refer_specs.append(job.state.refer_spec) refer_specs.extend(list(getattr(job.state, "aux_refer_specs", []) or [])) return refer_specs diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py index 8b3db1fa..9fb7c8d9 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio import os import time -from typing import Callable, Dict, List +from typing import Any, Callable, Dict, List from GPT_SoVITS.TTS_infer_pack.TTS import TTS from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator, PreparedCpuStage @@ -81,11 +81,60 @@ class WorkerPrepareExecutor: cpu_stages: List[PreparedCpuStage], ) -> List[tuple[T2SRequestState, float, float] | Exception]: try: - return list( - await asyncio.gather( - *[self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) for cpu_stage in cpu_stages], - return_exceptions=True, - ) + return await self.coordinator.prepare_gpu_stages_profiled_async(cpu_stages) + finally: + self._notify_state_change() + + async def prepare_gpu_audio_phases_async( + self, + cpu_stages: List[PreparedCpuStage], + ) -> List[Dict[str, Any] | Exception]: + try: + return await self.coordinator.prepare_gpu_audio_phases_async(cpu_stages) + finally: + self._notify_state_change() + + async def prepare_gpu_text_phases_async( + self, + items: List[tuple[PreparedCpuStage, Dict[str, Any]]], + ) -> List[Dict[str, Any] | Exception]: + try: + return await self.coordinator.prepare_gpu_text_phases_async(items) + finally: + self._notify_state_change() + + def build_gpu_prepare_result_from_phases( + self, + cpu_stage: PreparedCpuStage, + phase_one: Dict[str, Any], + phase_two: Dict[str, Any], + extra_profile: Dict[str, float] | None = None, + ) -> tuple[T2SRequestState, float, float]: + try: + return self.coordinator.build_gpu_prepare_result_from_phases( + cpu_stage, + phase_one, + phase_two, + extra_profile=extra_profile, ) finally: self._notify_state_change() + + async def prepare_ref_spec_stages_async( + self, + phase_ones: List[Dict[str, Any]], + ) -> List[tuple[tuple[Any, Any], Dict[str, float]] | Exception]: + try: + return await self.coordinator.prepare_ref_spec_stages_async(phase_ones) + finally: + self._notify_state_change() + + def apply_ref_spec_result_to_state( + self, + state: T2SRequestState, + ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]], + ) -> None: + try: + self.coordinator.apply_ref_spec_result_to_state(state, ref_spec_result) + finally: + self._notify_state_change() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py index e498e9ea..2ac636fe 100644 --- a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py @@ -267,3 +267,42 @@ class WorkerSubmitLifecycleMixin: cpu_stages: List[PreparedCpuStage], ) -> List[tuple[T2SRequestState, float, float] | Exception]: return await self.prepare_executor.prepare_gpu_stages_profiled_async(cpu_stages) + + async def prepare_gpu_audio_phases_async( + self, + cpu_stages: List[PreparedCpuStage], + ) -> List[Dict[str, Any] | Exception]: + return await self.prepare_executor.prepare_gpu_audio_phases_async(cpu_stages) + + async def prepare_gpu_text_phases_async( + self, + items: List[tuple[PreparedCpuStage, Dict[str, Any]]], + ) -> List[Dict[str, Any] | Exception]: + return await self.prepare_executor.prepare_gpu_text_phases_async(items) + + def build_gpu_prepare_result_from_phases( + self, + cpu_stage: PreparedCpuStage, + phase_one: Dict[str, Any], + phase_two: Dict[str, Any], + extra_profile: Dict[str, float] | None = None, + ) -> tuple[T2SRequestState, float, float]: + return self.prepare_executor.build_gpu_prepare_result_from_phases( + cpu_stage, + phase_one, + phase_two, + extra_profile=extra_profile, + ) + + async def prepare_ref_spec_stages_async( + self, + phase_ones: List[Dict[str, Any]], + ) -> List[tuple[tuple[Any, Any], Dict[str, float]] | Exception]: + return await self.prepare_executor.prepare_ref_spec_stages_async(phase_ones) + + def apply_ref_spec_result_to_state( + self, + state: T2SRequestState, + ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]], + ) -> None: + self.prepare_executor.apply_ref_spec_result_to_state(state, ref_spec_result) diff --git a/GPT_SoVITS/text/g2pw/cuda_api.py b/GPT_SoVITS/text/g2pw/cuda_api.py index e1a84748..881d6123 100644 --- a/GPT_SoVITS/text/g2pw/cuda_api.py +++ b/GPT_SoVITS/text/g2pw/cuda_api.py @@ -244,6 +244,16 @@ class G2PWRuntimeWrapper: ) self.batch_worker.start() + def _sync_runtime_env_overrides(self) -> None: + os.environ["G2PW_ENABLE_CUDA_GRAPH"] = "1" if self.enable_cuda_graph else "0" + os.environ["G2PW_ENABLE_PROFILE"] = "1" if self.enable_profiling else "0" + os.environ["G2PW_DUMP_GRAPH_CACHE_STATS"] = "1" if self.dump_graph_cache_stats else "0" + os.environ["G2PW_FULL_GRAPH_CACHE_LIMIT"] = str(int(self.full_graph_cache_limit)) + os.environ["G2PW_TAIL_GRAPH_CACHE_LIMIT"] = str(int(self.tail_graph_cache_limit)) + os.environ["G2PW_ALLOW_TENSOR_CORES"] = "1" if self.allow_tensor_cores else "0" + os.environ["G2PW_USE_CUBLASLT_BIAS_EPILOGUE"] = "1" if self.use_cublaslt_bias_epilogue else "0" + os.environ["G2PW_GEMM_PRECISION"] = {0: "fp32", 1: "fp16", 2: "bf16"}.get(int(self.gemm_precision), "fp32") + def _destroy_handle(self) -> None: if self.handle: self.lib.g2pw_runtime_destroy(self.handle) @@ -268,6 +278,7 @@ class G2PWRuntimeWrapper: return "" if not message else message.decode("utf-8", errors="replace") def _create_handle(self, batch_size: int, seq_len: int) -> None: + self._sync_runtime_env_overrides() new_handle = self.lib.g2pw_runtime_create( str(self.manifest_path).encode("utf-8"), str(self.weights_path).encode("utf-8"), @@ -518,6 +529,10 @@ class G2PWRuntimeWrapper: return { "shard_index": int(self.shard_index), "enabled": bool(self.batch_enabled), + "enable_cuda_graph": bool(self.enable_cuda_graph), + "enable_profiling": bool(self.enable_profiling), + "full_graph_cache_limit": int(self.full_graph_cache_limit), + "tail_graph_cache_limit": int(self.tail_graph_cache_limit), "window_ms": float(self.batch_window_s * 1000.0), "max_requests": int(self.batch_max_requests), "max_rows": int(self.batch_max_rows),