From 800acd45ff554dd63f8de365e111c625c8923be6 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Sat, 7 Mar 2026 05:47:22 +0800 Subject: [PATCH 1/9] Enhance G2P processing by implementing batch input handling in _g2p function, improving efficiency. Update prepare_onnx_input to utilize caching for tokenization and add optional parameters for character ID mapping and phoneme masks. Refactor G2PWOnnxConverter to streamline model loading and configuration management. --- GPT_SoVITS/text/chinese2.py | 17 ++- GPT_SoVITS/text/g2pw/dataset.py | 107 ++++++++++++++---- GPT_SoVITS/text/g2pw/onnx_api.py | 181 +++++++++++++++++++++++-------- 3 files changed, 233 insertions(+), 72 deletions(-) diff --git a/GPT_SoVITS/text/chinese2.py b/GPT_SoVITS/text/chinese2.py index dcce0d96..acfebfe2 100644 --- a/GPT_SoVITS/text/chinese2.py +++ b/GPT_SoVITS/text/chinese2.py @@ -180,10 +180,15 @@ def _merge_erhua(initials: list[str], finals: list[str], word: str, pos: str) -> def _g2p(segments): phones_list = [] word2ph = [] - for seg in segments: + g2pw_batch_results = [] + g2pw_batch_cursor = 0 + processed_segments = [re.sub("[a-zA-Z]+", "", seg) for seg in segments] + if is_g2pw: + batch_inputs = [seg for seg in processed_segments if seg] + g2pw_batch_results = g2pw._g2pw(batch_inputs) if batch_inputs else [] + + for seg in processed_segments: pinyins = [] - # Replace all English words in the sentence - seg = re.sub("[a-zA-Z]+", "", seg) seg_cut = psg.lcut(seg) seg_cut = tone_modifier.pre_merge_for_modify(seg_cut) initials = [] @@ -204,8 +209,10 @@ def _g2p(segments): finals = sum(finals, []) print("pypinyin结果", initials, finals) else: - # g2pw采用整句推理 - pinyins = g2pw.lazy_pinyin(seg, neutral_tone_with_five=True, style=Style.TONE3) + # g2pw采用整句推理(批量推理,逐句取结果) + if seg: + pinyins = g2pw_batch_results[g2pw_batch_cursor] + g2pw_batch_cursor += 1 pre_word_length = 0 for word, pos in seg_cut: diff --git a/GPT_SoVITS/text/g2pw/dataset.py b/GPT_SoVITS/text/g2pw/dataset.py index ff09cbc2..e464c29a 100644 --- a/GPT_SoVITS/text/g2pw/dataset.py +++ b/GPT_SoVITS/text/g2pw/dataset.py @@ -18,6 +18,7 @@ Credits from typing import Dict from typing import List +from typing import Optional from typing import Tuple import numpy as np @@ -37,6 +38,8 @@ def prepare_onnx_input( use_mask: bool = False, window_size: int = None, max_len: int = 512, + char2id: Optional[Dict[str, int]] = None, + char_phoneme_masks: Optional[Dict[str, List[int]]] = None, ) -> Dict[str, np.array]: if window_size is not None: truncated_texts, truncated_query_ids = _truncate_texts( @@ -48,33 +51,88 @@ def prepare_onnx_input( phoneme_masks = [] char_ids = [] position_ids = [] + tokenized_cache = {} + + if char2id is None: + char2id = {char: idx for idx, char in enumerate(chars)} + if use_mask: + if char_phoneme_masks is None: + char_phoneme_masks = { + char: [1 if i in char2phonemes[char] else 0 for i in range(len(labels))] + for char in char2phonemes + } + else: + full_phoneme_mask = [1] * len(labels) for idx in range(len(texts)): text = (truncated_texts if window_size else texts)[idx].lower() query_id = (truncated_query_ids if window_size else query_ids)[idx] - try: - tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text) - except Exception: - print(f'warning: text "{text}" is invalid') - return {} + cached = tokenized_cache.get(text) + if cached is None: + try: + tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text) + except Exception: + print(f'warning: text "{text}" is invalid') + return {} - text, query_id, tokens, text2token, token2text = _truncate( - max_len=max_len, text=text, query_id=query_id, tokens=tokens, text2token=text2token, token2text=token2text - ) + if len(tokens) <= max_len - 2: + processed_tokens = ["[CLS]"] + tokens + ["[SEP]"] + shared_input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens))) + shared_token_type_id = list(np.zeros((len(processed_tokens),), dtype=int)) + shared_attention_mask = list(np.ones((len(processed_tokens),), dtype=int)) + cached = { + "is_short": True, + "tokens": tokens, + "text2token": text2token, + "token2text": token2text, + "input_id": shared_input_id, + "token_type_id": shared_token_type_id, + "attention_mask": shared_attention_mask, + } + else: + cached = { + "is_short": False, + "tokens": tokens, + "text2token": text2token, + "token2text": token2text, + } + tokenized_cache[text] = cached - processed_tokens = ["[CLS]"] + tokens + ["[SEP]"] + if cached["is_short"]: + text_for_query = text + query_id_for_query = query_id + text2token_for_query = cached["text2token"] + input_id = cached["input_id"] + token_type_id = cached["token_type_id"] + attention_mask = cached["attention_mask"] + else: + ( + text_for_query, + query_id_for_query, + tokens_for_query, + text2token_for_query, + _token2text_for_query, + ) = _truncate( + max_len=max_len, + text=text, + query_id=query_id, + tokens=cached["tokens"], + text2token=cached["text2token"], + token2text=cached["token2text"], + ) + processed_tokens = ["[CLS]"] + tokens_for_query + ["[SEP]"] + input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens))) + token_type_id = list(np.zeros((len(processed_tokens),), dtype=int)) + attention_mask = list(np.ones((len(processed_tokens),), dtype=int)) - input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens))) - token_type_id = list(np.zeros((len(processed_tokens),), dtype=int)) - attention_mask = list(np.ones((len(processed_tokens),), dtype=int)) - - query_char = text[query_id] - phoneme_mask = ( - [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] if use_mask else [1] * len(labels) - ) - char_id = chars.index(query_char) - position_id = text2token[query_id] + 1 # [CLS] token locate at first place + query_char = text_for_query[query_id_for_query] + if use_mask: + phoneme_mask = char_phoneme_masks[query_char] + else: + phoneme_mask = full_phoneme_mask + char_id = char2id[query_char] + position_id = text2token_for_query[query_id_for_query] + 1 # [CLS] token locate at first place input_ids.append(input_id) token_type_ids.append(token_type_id) @@ -83,10 +141,15 @@ def prepare_onnx_input( char_ids.append(char_id) position_ids.append(position_id) + max_token_length = max(len(seq) for seq in input_ids) + + def _pad_sequences(sequences, pad_value=0): + return [seq + [pad_value] * (max_token_length - len(seq)) for seq in sequences] + outputs = { - "input_ids": np.array(input_ids).astype(np.int64), - "token_type_ids": np.array(token_type_ids).astype(np.int64), - "attention_masks": np.array(attention_masks).astype(np.int64), + "input_ids": np.array(_pad_sequences(input_ids, pad_value=0)).astype(np.int64), + "token_type_ids": np.array(_pad_sequences(token_type_ids, pad_value=0)).astype(np.int64), + "attention_masks": np.array(_pad_sequences(attention_masks, pad_value=0)).astype(np.int64), "phoneme_masks": np.array(phoneme_masks).astype(np.float32), "char_ids": np.array(char_ids).astype(np.int64), "position_ids": np.array(position_ids).astype(np.int64), diff --git a/GPT_SoVITS/text/g2pw/onnx_api.py b/GPT_SoVITS/text/g2pw/onnx_api.py index 1d5e4231..5fcf1ae2 100644 --- a/GPT_SoVITS/text/g2pw/onnx_api.py +++ b/GPT_SoVITS/text/g2pw/onnx_api.py @@ -10,7 +10,6 @@ from typing import Any, Dict, List, Tuple import numpy as np import onnxruntime import requests -import torch from opencc import OpenCC from pypinyin import Style, pinyin from transformers.models.auto.tokenization_auto import AutoTokenizer @@ -22,9 +21,8 @@ from .utils import load_config onnxruntime.set_default_logger_severity(3) try: onnxruntime.preload_dlls() -except: +except Exception: pass - # traceback.print_exc() warnings.filterwarnings("ignore") model_version = "1.1" @@ -55,6 +53,24 @@ def predict(session, onnx_input: Dict[str, Any], labels: List[str]) -> Tuple[Lis return all_preds, all_confidences +def _load_json_from_candidates(filename: str, candidate_dirs: List[str]) -> Dict[str, Any]: + for candidate_dir in candidate_dirs: + if not candidate_dir: + continue + json_path = os.path.join(candidate_dir, filename) + if os.path.exists(json_path): + with open(json_path, "r", encoding="utf-8") as fr: + return json.load(fr) + raise FileNotFoundError(f"Cannot locate {filename} in candidate dirs: {candidate_dirs}") + + +def _find_first_existing_file(*paths: str) -> str: + for path in paths: + if path and os.path.exists(path): + return path + raise FileNotFoundError(f"Files not found: {paths}") + + def download_and_decompress(model_dir: str = "G2PWModel/"): if not os.path.exists(model_dir): parent_directory = os.path.dirname(model_dir) @@ -62,7 +78,7 @@ def download_and_decompress(model_dir: str = "G2PWModel/"): extract_dir = os.path.join(parent_directory, "G2PWModel_1.1") extract_dir_new = os.path.join(parent_directory, "G2PWModel") print("Downloading g2pw model...") - modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip" # "https://paddlespeech.cdn.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip" + modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip" with requests.get(modelscope_url, stream=True) as r: r.raise_for_status() with open(zip_dir, "wb") as f: @@ -79,7 +95,7 @@ def download_and_decompress(model_dir: str = "G2PWModel/"): return model_dir -class G2PWOnnxConverter: +class _G2PWBaseOnnxConverter: def __init__( self, model_dir: str = "G2PWModel/", @@ -87,33 +103,16 @@ class G2PWOnnxConverter: model_source: str = None, enable_non_tradional_chinese: bool = False, ): - uncompress_path = download_and_decompress(model_dir) - - sess_options = onnxruntime.SessionOptions() - sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL - sess_options.intra_op_num_threads = 2 if torch.cuda.is_available() else 0 - if "CUDAExecutionProvider" in onnxruntime.get_available_providers(): - self.session_g2pW = onnxruntime.InferenceSession( - os.path.join(uncompress_path, "g2pW.onnx"), - sess_options=sess_options, - providers=["CUDAExecutionProvider", "CPUExecutionProvider"], - ) - else: - self.session_g2pW = onnxruntime.InferenceSession( - os.path.join(uncompress_path, "g2pW.onnx"), - sess_options=sess_options, - providers=["CPUExecutionProvider"], - ) - self.config = load_config(config_path=os.path.join(uncompress_path, "config.py"), use_default=True) + self.model_dir = download_and_decompress(model_dir) + self.config = load_config(config_path=os.path.join(self.model_dir, "config.py"), use_default=True) self.model_source = model_source if model_source else self.config.model_source self.enable_opencc = enable_non_tradional_chinese - self.tokenizer = AutoTokenizer.from_pretrained(self.model_source) - polyphonic_chars_path = os.path.join(uncompress_path, "POLYPHONIC_CHARS.txt") - monophonic_chars_path = os.path.join(uncompress_path, "MONOPHONIC_CHARS.txt") + polyphonic_chars_path = os.path.join(self.model_dir, "POLYPHONIC_CHARS.txt") + monophonic_chars_path = os.path.join(self.model_dir, "MONOPHONIC_CHARS.txt") + self.polyphonic_chars = [ line.split("\t") for line in open(polyphonic_chars_path, encoding="utf-8").read().strip().split("\n") ] @@ -149,31 +148,45 @@ class G2PWOnnxConverter: ) self.chars = sorted(list(self.char2phonemes.keys())) + self.char2id = {char: idx for idx, char in enumerate(self.chars)} + self.char_phoneme_masks = ( + { + char: [1 if i in self.char2phonemes[char] else 0 for i in range(len(self.labels))] + for char in self.char2phonemes + } + if self.config.use_mask + else None + ) self.polyphonic_chars_new = set(self.chars) for char in self.non_polyphonic: - if char in self.polyphonic_chars_new: - self.polyphonic_chars_new.remove(char) + self.polyphonic_chars_new.discard(char) self.monophonic_chars_dict = {char: phoneme for char, phoneme in self.monophonic_chars} for char in self.non_monophonic: - if char in self.monophonic_chars_dict: - self.monophonic_chars_dict.pop(char) + self.monophonic_chars_dict.pop(char, None) - self.pos_tags = ["UNK", "A", "C", "D", "I", "N", "P", "T", "V", "DE", "SHI"] + default_asset_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "G2PWModel")) + candidate_asset_dirs = [self.model_dir, default_asset_dir] + self.bopomofo_convert_dict = _load_json_from_candidates( + "bopomofo_to_pinyin_wo_tune_dict.json", candidate_asset_dirs + ) + self.char_bopomofo_dict = _load_json_from_candidates("char_bopomofo_dict.json", candidate_asset_dirs) - with open(os.path.join(uncompress_path, "bopomofo_to_pinyin_wo_tune_dict.json"), "r", encoding="utf-8") as fr: - self.bopomofo_convert_dict = json.load(fr) self.style_convert_func = { "bopomofo": lambda x: x, "pinyin": self._convert_bopomofo_to_pinyin, }[style] - with open(os.path.join(uncompress_path, "char_bopomofo_dict.json"), "r", encoding="utf-8") as fr: - self.char_bopomofo_dict = json.load(fr) - if self.enable_opencc: self.cc = OpenCC("s2tw") + self.enable_sentence_dedup = os.getenv("g2pw_sentence_dedup", "true").strip().lower() in { + "1", + "true", + "yes", + "y", + "on", + } def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str: tone = bopomofo[-1] @@ -181,9 +194,8 @@ class G2PWOnnxConverter: component = self.bopomofo_convert_dict.get(bopomofo[:-1]) if component: return component + tone - else: - print(f'Warning: "{bopomofo}" cannot convert to pinyin') - return None + print(f'Warning: "{bopomofo}" cannot convert to pinyin') + return None def __call__(self, sentences: List[str]) -> List[List[str]]: if isinstance(sentences, str): @@ -199,10 +211,9 @@ class G2PWOnnxConverter: texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences) if len(texts) == 0: - # sentences no polyphonic words return partial_results - onnx_input = prepare_onnx_input( + model_input = prepare_onnx_input( tokenizer=self.tokenizer, labels=self.labels, char2phonemes=self.char2phonemes, @@ -211,9 +222,18 @@ class G2PWOnnxConverter: query_ids=query_ids, use_mask=self.config.use_mask, window_size=None, + char2id=self.char2id, + char_phoneme_masks=self.char_phoneme_masks, ) - preds, confidences = predict(session=self.session_g2pW, onnx_input=onnx_input, labels=self.labels) + if not model_input: + return partial_results + + if self.enable_sentence_dedup: + preds, _confidences = self._predict_with_sentence_dedup(model_input=model_input, texts=texts) + else: + preds, _confidences = self._predict(model_input=model_input) + if self.config.use_char_phoneme: preds = [pred.split(" ")[1] for pred in preds] @@ -226,7 +246,6 @@ class G2PWOnnxConverter: def _prepare_data(self, sentences: List[str]) -> Tuple[List[str], List[int], List[int], List[List[str]]]: texts, query_ids, sent_ids, partial_results = [], [], [], [] for sent_id, sent in enumerate(sentences): - # pypinyin works well for Simplified Chinese than Traditional Chinese sent_s = tranditional_to_simplified(sent) pypinyin_result = pinyin(sent_s, neutral_tone_with_five=True, style=Style.TONE3) partial_result = [None] * len(sent) @@ -239,9 +258,81 @@ class G2PWOnnxConverter: partial_result[i] = self.style_convert_func(self.monophonic_chars_dict[char]) elif char in self.char_bopomofo_dict: partial_result[i] = pypinyin_result[i][0] - # partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0]) else: partial_result[i] = pypinyin_result[i][0] partial_results.append(partial_result) return texts, query_ids, sent_ids, partial_results + + def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]: + raise NotImplementedError + + def _predict_with_sentence_dedup( + self, model_input: Dict[str, Any], texts: List[str] + ) -> Tuple[List[str], List[float]]: + if len(texts) <= 1: + return self._predict(model_input=model_input) + + grouped_indices: Dict[str, List[int]] = {} + for idx, text in enumerate(texts): + grouped_indices.setdefault(text, []).append(idx) + + if all(len(indices) == 1 for indices in grouped_indices.values()): + return self._predict(model_input=model_input) + + preds: List[str] = [""] * len(texts) + confidences: List[float] = [0.0] * len(texts) + for indices in grouped_indices.values(): + group_input = {name: value[indices] for name, value in model_input.items()} + if len(indices) > 1: + for name in ("input_ids", "token_type_ids", "attention_masks"): + group_input[name] = group_input[name][:1] + + group_preds, group_confidences = self._predict(model_input=group_input) + for output_idx, pred, confidence in zip(indices, group_preds, group_confidences): + preds[output_idx] = pred + confidences[output_idx] = confidence + + return preds, confidences + + +class G2PWOnnxConverter(_G2PWBaseOnnxConverter): + def __init__( + self, + model_dir: str = "G2PWModel/", + style: str = "bopomofo", + model_source: str = None, + enable_non_tradional_chinese: bool = False, + ): + super().__init__( + model_dir=model_dir, + style=style, + model_source=model_source, + enable_non_tradional_chinese=enable_non_tradional_chinese, + ) + + sess_options = onnxruntime.SessionOptions() + sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL + sess_options.intra_op_num_threads = 2 + + onnx_path = _find_first_existing_file( + os.path.join(self.model_dir, "g2pW.onnx"), + os.path.join(self.model_dir, "g2pw.onnx"), + ) + + if "CUDAExecutionProvider" in onnxruntime.get_available_providers(): + self.session_g2pw = onnxruntime.InferenceSession( + onnx_path, + sess_options=sess_options, + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], + ) + else: + self.session_g2pw = onnxruntime.InferenceSession( + onnx_path, + sess_options=sess_options, + providers=["CPUExecutionProvider"], + ) + + def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]: + return predict(session=self.session_g2pw, onnx_input=model_input, labels=self.labels) From b250e62402893b4f355d58599cd946150880ef7f Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Sun, 8 Mar 2026 03:01:20 +0800 Subject: [PATCH 2/9] Enhance G2PW model input handling by introducing polyphonic context character support and updating the data preparation method to return additional query IDs. This improves the processing of polyphonic characters in sentences. --- GPT_SoVITS/text/g2pw/onnx_api.py | 37 ++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/GPT_SoVITS/text/g2pw/onnx_api.py b/GPT_SoVITS/text/g2pw/onnx_api.py index 5fcf1ae2..3c2b0169 100644 --- a/GPT_SoVITS/text/g2pw/onnx_api.py +++ b/GPT_SoVITS/text/g2pw/onnx_api.py @@ -187,6 +187,8 @@ class _G2PWBaseOnnxConverter: "y", "on", } + # 聚焦到多音字附近上下文,默认左右各16字;设为0表示关闭裁剪(整句)。 + self.polyphonic_context_chars = max(0, int(os.getenv("g2pw_polyphonic_context_chars", "16"))) def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str: tone = bopomofo[-1] @@ -209,7 +211,7 @@ class _G2PWBaseOnnxConverter: translated_sentences.append(translated_sent) sentences = translated_sentences - texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences) + texts, model_query_ids, result_query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences) if len(texts) == 0: return partial_results @@ -219,7 +221,7 @@ class _G2PWBaseOnnxConverter: char2phonemes=self.char2phonemes, chars=self.chars, texts=texts, - query_ids=query_ids, + query_ids=model_query_ids, use_mask=self.config.use_mask, window_size=None, char2id=self.char2id, @@ -238,22 +240,23 @@ class _G2PWBaseOnnxConverter: preds = [pred.split(" ")[1] for pred in preds] results = partial_results - for sent_id, query_id, pred in zip(sent_ids, query_ids, preds): + for sent_id, query_id, pred in zip(sent_ids, result_query_ids, preds): results[sent_id][query_id] = self.style_convert_func(pred) return results - def _prepare_data(self, sentences: List[str]) -> Tuple[List[str], List[int], List[int], List[List[str]]]: - texts, query_ids, sent_ids, partial_results = [], [], [], [] + def _prepare_data( + self, sentences: List[str] + ) -> Tuple[List[str], List[int], List[int], List[int], List[List[str]]]: + texts, model_query_ids, result_query_ids, sent_ids, partial_results = [], [], [], [], [] for sent_id, sent in enumerate(sentences): sent_s = tranditional_to_simplified(sent) pypinyin_result = pinyin(sent_s, neutral_tone_with_five=True, style=Style.TONE3) partial_result = [None] * len(sent) + polyphonic_indices: List[int] = [] for i, char in enumerate(sent): if char in self.polyphonic_chars_new: - texts.append(sent) - query_ids.append(i) - sent_ids.append(sent_id) + polyphonic_indices.append(i) elif char in self.monophonic_chars_dict: partial_result[i] = self.style_convert_func(self.monophonic_chars_dict[char]) elif char in self.char_bopomofo_dict: @@ -261,8 +264,24 @@ class _G2PWBaseOnnxConverter: else: partial_result[i] = pypinyin_result[i][0] + if polyphonic_indices: + if self.polyphonic_context_chars > 0: + left = max(0, polyphonic_indices[0] - self.polyphonic_context_chars) + right = min(len(sent), polyphonic_indices[-1] + self.polyphonic_context_chars + 1) + sent_for_predict = sent[left:right] + query_offset = left + else: + sent_for_predict = sent + query_offset = 0 + + for index in polyphonic_indices: + texts.append(sent_for_predict) + model_query_ids.append(index - query_offset) + result_query_ids.append(index) + sent_ids.append(sent_id) + partial_results.append(partial_result) - return texts, query_ids, sent_ids, partial_results + return texts, model_query_ids, result_query_ids, sent_ids, partial_results def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]: raise NotImplementedError From 30a4557d8db2ebbac91dc264a063cc7175172932 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Sun, 8 Mar 2026 23:08:27 +0800 Subject: [PATCH 3/9] Implement last inference statistics tracking in Text2SemanticDecoder and enhance TTS class with prompt semantic extraction. This includes methods for setting and retrieving inference stats, as well as improvements to audio processing and feature extraction in TTS. --- GPT_SoVITS/AR/models/t2s_model.py | 76 +++++++++++++++++++++++++++++-- GPT_SoVITS/TTS_infer_pack/TTS.py | 38 ++++++++++++++++ 2 files changed, 111 insertions(+), 3 deletions(-) diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index ac905f4b..f55b7508 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -351,6 +351,13 @@ class Text2SemanticDecoder(nn.Module): blocks.append(block) self.t2s_transformer = T2STransformer(self.num_layers, blocks) + self.last_infer_stats = {} + + def _set_last_infer_stats(self, stats): + self.last_infer_stats = stats + + def get_last_infer_stats(self): + return dict(self.last_infer_stats) def make_input_data(self, x, x_lens, y, y_lens, bert_feature): x = self.ar_text_embedding(x) @@ -593,7 +600,19 @@ class Text2SemanticDecoder(nn.Module): repetition_penalty: float = 1.35, **kwargs, ): + requested_enable_mask_free_fastpath = bool(kwargs.get("enable_mask_free_fastpath", True)) if prompts is None: + self._set_last_infer_stats( + { + "infer_mode": "batch_infer_prompt_free_fallback", + "requested_enable_mask_free_fastpath": requested_enable_mask_free_fastpath, + "batch_size": int(len(x)), + "prefill_after_mask_all_visible": None, + "fastpath_hit": False, + "generated_token_count": 0, + "generated_token_count_list": [], + } + ) print("Warning: Prompt free is not supported batch_infer! switch to naive_infer") return self.infer_panel_naive_batched( x, @@ -608,6 +627,7 @@ class Text2SemanticDecoder(nn.Module): ) max_len = kwargs.get("max_len", x_lens.max()) + enable_mask_free_fastpath = requested_enable_mask_free_fastpath x_list = [] for x_item, bert_item in zip(x, bert_feature): # max_len = max(max_len, x_item.shape[0], bert_item.shape[1]) @@ -698,17 +718,30 @@ class Text2SemanticDecoder(nn.Module): y_list = [None] * y.shape[0] batch_idx_map = list(range(y.shape[0])) idx_list = [None] * y.shape[0] + decode_attn_mask = attn_mask + prefill_after_mask_all_visible = None + fastpath_hit = False for idx in tqdm(range(1500)): if idx == 0: xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None) else: - xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask) + xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token( + xy_pos, k_cache, v_cache, decode_attn_mask + ) logits = self.ar_predict_layer(xy_dec[:, -1]) if idx == 0: attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False) + prefill_after_mask_all_visible = not attn_mask.any().item() + if enable_mask_free_fastpath and y.shape[0] == 1 and prefill_after_mask_all_visible: + decode_attn_mask = None + fastpath_hit = True + else: + decode_attn_mask = attn_mask else: - attn_mask = F.pad(attn_mask, (0, 1), value=False) + if decode_attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1), value=False) + decode_attn_mask = attn_mask if idx < 11: ###至少预测出10个token不然不给停止(0.4s) logits = logits[:, :-1] @@ -740,7 +773,9 @@ class Text2SemanticDecoder(nn.Module): if reserved_idx_of_batch_for_y is not None: # index = torch.LongTensor(batch_idx_map).to(y.device) y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y) - attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y) + if decode_attn_mask is not None: + attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y) + decode_attn_mask = attn_mask if k_cache is not None: for i in range(len(k_cache)): k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y) @@ -775,6 +810,18 @@ class Text2SemanticDecoder(nn.Module): if idx_list[i] is None: idx_list[i] = 1500 - 1 ###如果没有生成到EOS,就用最大长度代替 + self._set_last_infer_stats( + { + "infer_mode": "batch_infer", + "requested_enable_mask_free_fastpath": enable_mask_free_fastpath, + "batch_size": int(len(x)), + "prefill_after_mask_all_visible": prefill_after_mask_all_visible, + "fastpath_hit": fastpath_hit, + "generated_token_count": int(sum(idx_list)), + "generated_token_count_list": [int(item) for item in idx_list], + "max_len": int(max_len), + } + ) if ref_free: return y_list, [0] * x.shape[0] # print(idx_list) @@ -811,6 +858,17 @@ class Text2SemanticDecoder(nn.Module): y_list.append(y[0]) idx_list.append(idx) + self._set_last_infer_stats( + { + "infer_mode": "naive_batched", + "requested_enable_mask_free_fastpath": bool(kwargs.get("enable_mask_free_fastpath", True)), + "batch_size": int(len(x)), + "prefill_after_mask_all_visible": None, + "fastpath_hit": False, + "generated_token_count": int(sum(idx_list)), + "generated_token_count_list": [int(item) for item in idx_list], + } + ) return y_list, idx_list def infer_panel_naive( @@ -957,6 +1015,18 @@ class Text2SemanticDecoder(nn.Module): if not streaming_mode: + generated_token_count = max(int(y.shape[1] - prefix_len), 0) + self._set_last_infer_stats( + { + "infer_mode": "naive", + "requested_enable_mask_free_fastpath": bool(kwargs.get("enable_mask_free_fastpath", True)), + "batch_size": int(x.shape[0]), + "prefill_after_mask_all_visible": True if prompts is not None else None, + "fastpath_hit": True if prompts is not None else False, + "generated_token_count": generated_token_count, + "generated_token_count_list": [generated_token_count], + } + ) if ref_free: yield y, 0 yield y, idx diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 9c8344b0..e1efd973 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1227,6 +1227,9 @@ class TTS: ###### inference ###### t_34 = 0.0 t_45 = 0.0 + t2s_observe_batch_count = 0 + t2s_observe_fastpath_hits = 0 + t2s_observe_generated_tokens = 0 audio = [] is_first_package = True output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"] @@ -1280,6 +1283,29 @@ class TTS: ) t4 = time.perf_counter() t_34 += t4 - t3 + if hasattr(self.t2s_model.model, "get_last_infer_stats"): + t2s_stats = self.t2s_model.model.get_last_infer_stats() + if t2s_stats: + generated_token_count = int(t2s_stats.get("generated_token_count", 0)) + t2s_total_ms = (t4 - t3) * 1000.0 + avg_decode_ms_per_token = ( + t2s_total_ms / generated_token_count if generated_token_count > 0 else 0.0 + ) + t2s_observe_batch_count += 1 + t2s_observe_generated_tokens += generated_token_count + if bool(t2s_stats.get("fastpath_hit", False)): + t2s_observe_fastpath_hits += 1 + print( + "[t2s_observe] " + f"mode={t2s_stats.get('infer_mode')} " + f"batch_size={t2s_stats.get('batch_size')} " + f"tokens={generated_token_count} " + f"t2s_ms={t2s_total_ms:.3f} " + f"avg_decode_ms_per_token={avg_decode_ms_per_token:.3f} " + f"requested_fastpath={t2s_stats.get('requested_enable_mask_free_fastpath')} " + f"prefill_all_visible={t2s_stats.get('prefill_after_mask_all_visible')} " + f"fastpath_hit={t2s_stats.get('fastpath_hit')}" + ) batch_audio_fragment = [] @@ -1500,6 +1526,18 @@ class TTS: if not (return_fragment or streaming_mode): print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45)) + if t2s_observe_batch_count > 0: + request_avg_decode_ms_per_token = ( + (t_34 * 1000.0) / t2s_observe_generated_tokens if t2s_observe_generated_tokens > 0 else 0.0 + ) + print( + "[t2s_request_observe] " + f"batches={t2s_observe_batch_count} " + f"fastpath_hits={t2s_observe_fastpath_hits} " + f"generated_tokens={t2s_observe_generated_tokens} " + f"t2s_total_ms={t_34 * 1000.0:.3f} " + f"avg_decode_ms_per_token={request_avg_decode_ms_per_token:.3f}" + ) if len(audio) == 0: yield output_sr, np.zeros(int(output_sr), dtype=np.int16) return From dc37b0b9ef9f29b852fa37e6cd661369b60d6f86 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Mon, 9 Mar 2026 00:22:59 +0800 Subject: [PATCH 4/9] Add WebAPI documentation and implement TTS API with endpoints for text-to-speech inference, control commands, and model switching. Enhance TTS class with methods for extracting prompt semantics and reference audio specifications. Introduce a scheduler prototype for managing T2S requests. --- GPT_SoVITS/TTS_infer_pack/TTS.py | 190 +++- GPT_SoVITS/TTS_infer_pack/__init__.py | 12 +- GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py | 631 +++++++++++ api_v3.py | 1170 ++++++++++++++++++++ tools/t2s_scheduler_prototype.py | 180 +++ 5 files changed, 2147 insertions(+), 36 deletions(-) create mode 100644 GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py create mode 100644 api_v3.py create mode 100644 tools/t2s_scheduler_prototype.py diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index e1efd973..bd4953df 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -759,21 +759,35 @@ class TTS: self._set_ref_spec(ref_audio_path) self._set_ref_audio_path(ref_audio_path) - def _set_ref_audio_path(self, ref_audio_path): - self.prompt_cache["ref_audio_path"] = ref_audio_path + def extract_prompt_semantic(self, ref_wav_path: str): + zero_wav = np.zeros( + int(self.configs.sampling_rate * 0.3), + dtype=np.float16 if self.configs.is_half else np.float32, + ) + with torch.no_grad(): + wav16k, sr = librosa.load(ref_wav_path, sr=16000) + if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000: + raise OSError(i18n("参考音频在3~10秒范围外,请更换!")) + wav16k = torch.from_numpy(wav16k) + zero_wav_torch = torch.from_numpy(zero_wav) + wav16k = wav16k.to(self.configs.device) + zero_wav_torch = zero_wav_torch.to(self.configs.device) + if self.configs.is_half: + wav16k = wav16k.half() + zero_wav_torch = zero_wav_torch.half() - def _set_ref_spec(self, ref_audio_path): - spec_audio = self._get_ref_spec(ref_audio_path) - if self.prompt_cache["refer_spec"] in [[], None]: - self.prompt_cache["refer_spec"] = [spec_audio] - else: - self.prompt_cache["refer_spec"][0] = spec_audio + wav16k = torch.cat([wav16k, zero_wav_torch]) + hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose( + 1, 2 + ) # .float() + codes = self.vits_model.extract_latent(hubert_feature) - def _get_ref_spec(self, ref_audio_path): + prompt_semantic = codes[0, 0].to(self.configs.device) + return prompt_semantic + + def extract_ref_spec(self, ref_audio_path: str): raw_audio, raw_sr = torchaudio.load(ref_audio_path) raw_audio = raw_audio.to(self.configs.device).float() - self.prompt_cache["raw_audio"] = raw_audio - self.prompt_cache["raw_sr"] = raw_sr if raw_sr != self.configs.sampling_rate: audio = raw_audio.to(self.configs.device) @@ -804,33 +818,30 @@ class TTS: audio = audio.half() else: audio = None + return spec, audio, raw_audio, raw_sr + + def extract_text_features(self, text: str, language: str): + return self.text_preprocessor.segment_and_extract_feature_for_text(text, language, self.configs.version) + + def _set_ref_audio_path(self, ref_audio_path): + self.prompt_cache["ref_audio_path"] = ref_audio_path + + def _set_ref_spec(self, ref_audio_path): + spec_audio = self._get_ref_spec(ref_audio_path) + if self.prompt_cache["refer_spec"] in [[], None]: + self.prompt_cache["refer_spec"] = [spec_audio] + else: + self.prompt_cache["refer_spec"][0] = spec_audio + + def _get_ref_spec(self, ref_audio_path): + spec, audio, raw_audio, raw_sr = self.extract_ref_spec(ref_audio_path) + self.prompt_cache["raw_audio"] = raw_audio + self.prompt_cache["raw_sr"] = raw_sr return spec, audio def _set_prompt_semantic(self, ref_wav_path: str): - zero_wav = np.zeros( - int(self.configs.sampling_rate * 0.3), - dtype=np.float16 if self.configs.is_half else np.float32, - ) - with torch.no_grad(): - wav16k, sr = librosa.load(ref_wav_path, sr=16000) - if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000: - raise OSError(i18n("参考音频在3~10秒范围外,请更换!")) - wav16k = torch.from_numpy(wav16k) - zero_wav_torch = torch.from_numpy(zero_wav) - wav16k = wav16k.to(self.configs.device) - zero_wav_torch = zero_wav_torch.to(self.configs.device) - if self.configs.is_half: - wav16k = wav16k.half() - zero_wav_torch = zero_wav_torch.half() - - wav16k = torch.cat([wav16k, zero_wav_torch]) - hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose( - 1, 2 - ) # .float() - codes = self.vits_model.extract_latent(hubert_feature) - - prompt_semantic = codes[0, 0].to(self.configs.device) - self.prompt_cache["prompt_semantic"] = prompt_semantic + prompt_semantic = self.extract_prompt_semantic(ref_wav_path) + self.prompt_cache["prompt_semantic"] = prompt_semantic def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length: int = None): seq = sequences[0] @@ -1701,6 +1712,115 @@ class TTS: return audio + def using_vocoder_synthesis_request_local( + self, + semantic_tokens: torch.Tensor, + phones: torch.Tensor, + prompt_semantic: torch.Tensor, + prompt_phones: torch.Tensor, + refer_audio_spec: torch.Tensor, + raw_audio: torch.Tensor, + raw_sr: int, + speed: float = 1.0, + sample_steps: int = 32, + ): + prompt_semantic_tokens = prompt_semantic.unsqueeze(0).unsqueeze(0).to(self.configs.device) + prompt_phones = prompt_phones.unsqueeze(0).to(self.configs.device) + refer_audio_spec = refer_audio_spec.to(dtype=self.precision, device=self.configs.device) + + fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec) + ref_audio = raw_audio.to(self.configs.device).float() + if ref_audio.shape[0] == 2: + ref_audio = ref_audio.mean(0).unsqueeze(0) + + tgt_sr = 24000 if self.configs.version == "v3" else 32000 + if raw_sr != tgt_sr: + ref_audio = resample(ref_audio, raw_sr, tgt_sr, self.configs.device) + + mel_spec_fn = mel_fn if self.configs.version == "v3" else mel_fn_v4 + mel2 = mel_spec_fn(ref_audio) + mel2 = norm_spec(mel2) + T_min = min(mel2.shape[2], fea_ref.shape[2]) + mel2 = mel2[:, :, :T_min] + fea_ref = fea_ref[:, :, :T_min] + T_ref = self.vocoder_configs["T_ref"] + T_chunk = self.vocoder_configs["T_chunk"] + if T_min > T_ref: + mel2 = mel2[:, :, -T_ref:] + fea_ref = fea_ref[:, :, -T_ref:] + T_min = T_ref + chunk_len = T_chunk - T_min + + mel2 = mel2.to(self.precision) + fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed) + + cfm_resss = [] + idx = 0 + while 1: + fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len] + if fea_todo_chunk.shape[-1] == 0: + break + idx += chunk_len + fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1) + + cfm_res = self.vits_model.cfm.inference( + fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0 + ) + cfm_res = cfm_res[:, :, mel2.shape[2] :] + + mel2 = cfm_res[:, :, -T_min:] + fea_ref = fea_todo_chunk[:, :, -T_min:] + + cfm_resss.append(cfm_res) + cfm_res = torch.cat(cfm_resss, 2) + cfm_res = denorm_spec(cfm_res) + + with torch.inference_mode(): + wav_gen = self.vocoder(cfm_res) + audio = wav_gen[0][0] + + return audio + + def synthesize_audio_request_local( + self, + semantic_tokens: torch.Tensor, + phones: torch.Tensor, + prompt_semantic: torch.Tensor, + prompt_phones: torch.Tensor, + refer_spec: tuple, + raw_audio: torch.Tensor, + raw_sr: int, + speed: float = 1.0, + sample_steps: int = 32, + ): + refer_audio_spec, audio_tensor = refer_spec + if not self.configs.use_vocoder: + refer_audio_spec_list = [refer_audio_spec.to(dtype=self.precision, device=self.configs.device)] + sv_emb = None + if self.is_v2pro: + if audio_tensor is None: + raise ValueError(i18n("v2Pro request-local synthesis 缺少 16k 参考音频")) + sv_emb = self.sv_model.compute_embedding3(audio_tensor).to(self.configs.device) + return self.vits_model.decode( + semantic_tokens, + phones, + refer_audio_spec_list, + speed=speed, + sv_emb=sv_emb, + ).detach()[0, 0, :] + + return self.using_vocoder_synthesis_request_local( + semantic_tokens=semantic_tokens, + phones=phones, + prompt_semantic=prompt_semantic, + prompt_phones=prompt_phones, + refer_audio_spec=refer_audio_spec, + raw_audio=raw_audio, + raw_sr=raw_sr, + speed=speed, + sample_steps=sample_steps, + ) + def using_vocoder_synthesis_batched_infer( self, idx_list: List[int], diff --git a/GPT_SoVITS/TTS_infer_pack/__init__.py b/GPT_SoVITS/TTS_infer_pack/__init__.py index 8579a632..09a257b2 100644 --- a/GPT_SoVITS/TTS_infer_pack/__init__.py +++ b/GPT_SoVITS/TTS_infer_pack/__init__.py @@ -1 +1,11 @@ -from . import TTS, text_segmentation_method +from __future__ import annotations + +import importlib + +__all__ = ["TTS", "TextPreprocessor", "text_segmentation_method", "t2s_scheduler"] + + +def __getattr__(name: str): + if name in __all__: + return importlib.import_module(f"{__name__}.{name}") + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py new file mode 100644 index 00000000..e94a72c7 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -0,0 +1,631 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +import time +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import torch +import torch.nn.functional as F + +from AR.models.utils import make_pad_mask_left, sample + + +def _sync_device(device: Any) -> None: + try: + device_str = str(device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + +@dataclass +class SchedulerRequestSpec: + request_id: str + ref_audio_path: Path + prompt_text: str + prompt_lang: str + text: str + text_lang: str + top_k: int + top_p: float + temperature: float + repetition_penalty: float + early_stop_num: int + ready_step: int = 0 + + +@dataclass +class T2SRequestState: + request_id: str + ref_audio_path: Path + prompt_text: str + prompt_lang: str + text: str + text_lang: str + norm_prompt_text: str + norm_text: str + phones: torch.LongTensor + prompt_phones: torch.LongTensor + all_phones: torch.LongTensor + all_bert_features: torch.Tensor + prompt_semantic: torch.LongTensor + refer_spec: Tuple[torch.Tensor, Optional[torch.Tensor]] + raw_audio: torch.Tensor + raw_sr: int + top_k: int + top_p: float + temperature: float + repetition_penalty: float + early_stop_num: int + ready_step: int + prepare_profile: Dict[str, float] + + +@dataclass +class T2SRunningRequest: + state: T2SRequestState + y_sequence: torch.LongTensor + prefix_len: int + decode_attn_mask: torch.Tensor + k_cache: List[torch.Tensor] + v_cache: List[torch.Tensor] + step_idx: int + + +@dataclass +class T2SFinishedItem: + request_id: str + semantic_tokens: torch.LongTensor + finish_idx: int + finish_reason: str + + +@dataclass +class T2SActiveBatch: + request_ids: List[str] + states: List[T2SRequestState] + x: torch.Tensor + x_lens: torch.LongTensor + y_sequences: List[torch.LongTensor] + prefix_lens: torch.LongTensor + xy_pos: torch.Tensor + prefill_attn_mask: torch.Tensor + decode_attn_mask: Optional[torch.Tensor] + k_cache: Optional[List[torch.Tensor]] + v_cache: Optional[List[torch.Tensor]] + step_idx: int + prefill_done: bool + + +def normalize_sentence(text: str, language: str) -> str: + text = text.strip("\n").strip() + if not text: + return text + if text[-1] not in {",", ".", "?", "!", ",", "。", "?", "!", "…", ";", ";", ":"}: + text += "。" if language != "en" else "." + return text + + +def prepare_request_state( + tts: Any, + spec: SchedulerRequestSpec, +) -> T2SRequestState: + device = tts.configs.device + prepare_start = time.perf_counter() + _sync_device(device) + prepare_sync_start = time.perf_counter() + prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) + text = spec.text.strip("\n") + + _sync_device(device) + prompt_text_features_start = time.perf_counter() + prompt_phones, prompt_bert_features, prompt_norm_text = tts.extract_text_features(prompt_text, spec.prompt_lang) + _sync_device(device) + prompt_text_features_ms = (time.perf_counter() - prompt_text_features_start) * 1000.0 + + _sync_device(device) + text_features_start = time.perf_counter() + phones, bert_features, norm_text = tts.extract_text_features(text, spec.text_lang) + _sync_device(device) + text_features_ms = (time.perf_counter() - text_features_start) * 1000.0 + if phones is None: + raise ValueError(f"{spec.request_id} text preprocessing returned no phones") + + _sync_device(device) + prompt_semantic_start = time.perf_counter() + prompt_semantic = tts.extract_prompt_semantic(str(spec.ref_audio_path)).long() + _sync_device(device) + prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0 + + _sync_device(device) + ref_spec_start = time.perf_counter() + spec_audio, audio_16k, raw_audio, raw_sr = tts.extract_ref_spec(str(spec.ref_audio_path)) + _sync_device(device) + ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0 + + _sync_device(device) + tensorize_start = time.perf_counter() + phones_tensor = torch.LongTensor(phones).to(tts.configs.device) + prompt_phones_tensor = torch.LongTensor(prompt_phones).to(tts.configs.device) + all_phones = torch.LongTensor(prompt_phones + phones).to(tts.configs.device) + all_bert_features = torch.cat([prompt_bert_features, bert_features], dim=1).to( + dtype=tts.precision, device=tts.configs.device + ) + _sync_device(device) + tensorize_ms = (time.perf_counter() - tensorize_start) * 1000.0 + + _sync_device(device) + prepare_profile = { + "prompt_text_features_ms": prompt_text_features_ms, + "text_features_ms": text_features_ms, + "prompt_semantic_ms": prompt_semantic_ms, + "ref_spec_ms": ref_spec_ms, + "tensorize_ms": tensorize_ms, + "total_ms": (time.perf_counter() - prepare_sync_start) * 1000.0, + "wall_total_ms": (time.perf_counter() - prepare_start) * 1000.0, + } + return T2SRequestState( + request_id=spec.request_id, + ref_audio_path=spec.ref_audio_path, + prompt_text=prompt_text, + prompt_lang=spec.prompt_lang, + text=text, + text_lang=spec.text_lang, + norm_prompt_text=prompt_norm_text, + norm_text=norm_text, + phones=phones_tensor, + prompt_phones=prompt_phones_tensor, + all_phones=all_phones, + all_bert_features=all_bert_features, + prompt_semantic=prompt_semantic, + refer_spec=(spec_audio, audio_16k), + raw_audio=raw_audio, + raw_sr=int(raw_sr), + top_k=spec.top_k, + top_p=spec.top_p, + temperature=spec.temperature, + repetition_penalty=spec.repetition_penalty, + early_stop_num=spec.early_stop_num, + ready_step=spec.ready_step, + prepare_profile=prepare_profile, + ) + + +def _left_pad_hidden(hidden: torch.Tensor, target_len: int) -> torch.Tensor: + if hidden.shape[0] >= target_len: + return hidden + return F.pad(hidden, (0, 0, target_len - hidden.shape[0], 0), value=0) + + +def _ensure_audio_pe(model: Any, max_position: int, dtype: torch.dtype, device: torch.device) -> None: + required_len = max_position + 1 + if model.ar_audio_position.pe is not None and model.ar_audio_position.pe.size(1) >= required_len: + if model.ar_audio_position.pe.dtype != dtype or model.ar_audio_position.pe.device != device: + model.ar_audio_position.pe = model.ar_audio_position.pe.to(dtype=dtype, device=device) + return + model.ar_audio_position.extend_pe( + torch.zeros(1, required_len, model.ar_audio_position.embedding_dim, device=device, dtype=dtype) + ) + + +def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SActiveBatch: + x_items: List[torch.Tensor] = [] + y_pos_items: List[torch.Tensor] = [] + x_lens: List[int] = [] + prefix_lens: List[int] = [] + y_sequences: List[torch.LongTensor] = [] + + for state in states: + text_emb = model.ar_text_embedding(state.all_phones.unsqueeze(0)) + bert_proj = model.bert_proj(state.all_bert_features.transpose(0, 1).unsqueeze(0)) + x_pos = model.ar_text_position(text_emb + bert_proj).squeeze(0) + y_emb = model.ar_audio_embedding(state.prompt_semantic.unsqueeze(0)) + y_pos = model.ar_audio_position(y_emb).squeeze(0) + x_items.append(x_pos) + y_pos_items.append(y_pos) + x_lens.append(x_pos.shape[0]) + prefix_lens.append(y_pos.shape[0]) + y_sequences.append(state.prompt_semantic.clone()) + + max_x_len = max(x_lens) + max_prefix_len = max(prefix_lens) + x_batch = torch.stack([_left_pad_hidden(item, max_x_len) for item in x_items], dim=0) + y_pos_batch = torch.stack([_left_pad_hidden(item, max_prefix_len) for item in y_pos_items], dim=0) + xy_pos = torch.cat([x_batch, y_pos_batch], dim=1) + + device = x_batch.device + x_lens_tensor = torch.LongTensor(x_lens).to(device) + prefix_lens_tensor = torch.LongTensor(prefix_lens).to(device) + batch_size = len(states) + src_len = max_x_len + max_prefix_len + + x_padding_mask = make_pad_mask_left(x_lens_tensor, max_x_len) + y_padding_mask = make_pad_mask_left(prefix_lens_tensor, max_prefix_len) + padding_mask = torch.cat([x_padding_mask, y_padding_mask], dim=1) + x_mask = F.pad(torch.zeros(max_x_len, max_x_len, dtype=torch.bool, device=device), (0, max_prefix_len), value=True) + y_mask = F.pad( + torch.triu(torch.ones(max_prefix_len, max_prefix_len, dtype=torch.bool, device=device), diagonal=1), + (max_x_len, 0), + value=False, + ) + causal_mask = torch.cat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(batch_size, 1, 1) + padding_mask = padding_mask.view(batch_size, 1, src_len).repeat(1, src_len, 1) + attn_mask = causal_mask.logical_or(padding_mask) + attn_mask = attn_mask.unsqueeze(1).expand(-1, model.num_head, -1, -1).bool() + + return T2SActiveBatch( + request_ids=[state.request_id for state in states], + states=list(states), + x=x_batch, + x_lens=x_lens_tensor, + y_sequences=y_sequences, + prefix_lens=prefix_lens_tensor, + xy_pos=xy_pos, + prefill_attn_mask=attn_mask, + decode_attn_mask=None, + k_cache=None, + v_cache=None, + step_idx=0, + prefill_done=False, + ) + + +def build_next_xy_pos(model: Any, y_sequences: Sequence[torch.LongTensor]) -> torch.Tensor: + last_tokens = torch.stack([seq[-1:] for seq in y_sequences], dim=0) + y_emb = model.ar_audio_embedding(last_tokens) + position_ids = torch.LongTensor([int(seq.shape[0] - 1) for seq in y_sequences]).to(y_emb.device) + _ensure_audio_pe(model, int(position_ids.max().item()), y_emb.dtype, y_emb.device) + pos_emb = model.ar_audio_position.pe[0].index_select(0, position_ids).unsqueeze(1) + return y_emb * model.ar_audio_position.x_scale + model.ar_audio_position.alpha * pos_emb.to( + dtype=y_emb.dtype, device=y_emb.device + ) + + +def _sample_per_request( + model: Any, + active_batch: T2SActiveBatch, + logits: torch.Tensor, + max_steps: int, +) -> Tuple[List[T2SFinishedItem], List[int], List[torch.LongTensor]]: + finished_items: List[T2SFinishedItem] = [] + keep_indices: List[int] = [] + updated_sequences: List[torch.LongTensor] = [] + + step_idx = active_batch.step_idx + for batch_index, state in enumerate(active_batch.states): + logits_i = logits[batch_index : batch_index + 1].clone() + current_history = active_batch.y_sequences[batch_index] + sampled = sample( + logits_i, + current_history.unsqueeze(0), + top_k=state.top_k, + top_p=state.top_p, + repetition_penalty=state.repetition_penalty, + temperature=state.temperature, + )[0] + sampled_token = int(sampled[0, 0].item()) + argmax_token = int(torch.argmax(logits[batch_index], dim=-1).item()) + new_history = torch.cat([current_history, sampled.view(-1)], dim=0) + + finish_reason: Optional[str] = None + if state.early_stop_num != -1 and (new_history.shape[0] - int(active_batch.prefix_lens[batch_index].item())) > state.early_stop_num: + finish_reason = "early_stop" + elif step_idx + 1 >= max_steps: + finish_reason = "max_step" + elif sampled_token == model.EOS: + finish_reason = "eos_sample" + elif argmax_token == model.EOS: + finish_reason = "eos_argmax" + + if finish_reason is not None: + finished_items.append( + T2SFinishedItem( + request_id=state.request_id, + semantic_tokens=new_history[:-1].clone(), + finish_idx=step_idx, + finish_reason=finish_reason, + ) + ) + else: + keep_indices.append(batch_index) + updated_sequences.append(new_history) + + return finished_items, keep_indices, updated_sequences + + +def decode_one_step( + model: Any, + active_batch: T2SActiveBatch, + max_steps: int, +) -> Tuple[Optional[T2SActiveBatch], List[T2SFinishedItem]]: + if not active_batch.prefill_done: + xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.process_prompt( + active_batch.xy_pos, active_batch.prefill_attn_mask, None + ) + active_batch.decode_attn_mask = F.pad( + active_batch.prefill_attn_mask[:, :, -1].unsqueeze(-2), + (0, 1), + value=False, + ) + active_batch.prefill_done = True + else: + xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.decode_next_token( + active_batch.xy_pos, + active_batch.k_cache, + active_batch.v_cache, + active_batch.decode_attn_mask, + ) + if active_batch.decode_attn_mask is not None: + active_batch.decode_attn_mask = F.pad(active_batch.decode_attn_mask, (0, 1), value=False) + + logits = model.ar_predict_layer(xy_dec[:, -1]) + if active_batch.step_idx < 11: + logits = logits[:, :-1] + + finished_items, keep_indices, updated_sequences = _sample_per_request(model, active_batch, logits, max_steps=max_steps) + if len(keep_indices) == 0: + return None, finished_items + + device = logits.device + keep_tensor = torch.LongTensor(keep_indices).to(device) + active_batch.request_ids = [active_batch.request_ids[i] for i in keep_indices] + active_batch.states = [active_batch.states[i] for i in keep_indices] + active_batch.y_sequences = updated_sequences + active_batch.prefix_lens = torch.index_select(active_batch.prefix_lens, dim=0, index=keep_tensor) + + if active_batch.decode_attn_mask is not None: + active_batch.decode_attn_mask = torch.index_select(active_batch.decode_attn_mask, dim=0, index=keep_tensor) + if active_batch.k_cache is not None and active_batch.v_cache is not None: + for cache_index in range(len(active_batch.k_cache)): + active_batch.k_cache[cache_index] = torch.index_select(active_batch.k_cache[cache_index], dim=0, index=keep_tensor) + active_batch.v_cache[cache_index] = torch.index_select(active_batch.v_cache[cache_index], dim=0, index=keep_tensor) + + active_batch.xy_pos = build_next_xy_pos(model, active_batch.y_sequences) + active_batch.step_idx += 1 + return active_batch, finished_items + + +def run_scheduler_batch( + model: Any, + states: Sequence[T2SRequestState], + max_steps: int, +) -> List[T2SFinishedItem]: + return run_scheduler_continuous(model, states, max_steps=max_steps) + + +def _pad_cache_left(cache: torch.Tensor, target_len: int) -> torch.Tensor: + pad_len = target_len - cache.shape[1] + if pad_len <= 0: + return cache + return F.pad(cache, (0, 0, pad_len, 0), value=0) + + +def _pad_decode_mask_left(mask: torch.Tensor, target_len: int) -> torch.Tensor: + pad_len = target_len - mask.shape[-1] + if pad_len <= 0: + return mask + return F.pad(mask, (pad_len, 0), value=True) + + +def run_prefill_step( + model: Any, + states: Sequence[T2SRequestState], + max_steps: int, +) -> Tuple[List[T2SRunningRequest], List[T2SFinishedItem]]: + if not states: + return [], [] + + active_batch = build_prefill_batch(model, states) + xy_dec, k_cache, v_cache = model.t2s_transformer.process_prompt(active_batch.xy_pos, active_batch.prefill_attn_mask, None) + decode_attn_mask = F.pad( + active_batch.prefill_attn_mask[:, :, -1].unsqueeze(-2), + (0, 1), + value=False, + ) + logits = model.ar_predict_layer(xy_dec[:, -1]) + + running_requests: List[T2SRunningRequest] = [] + finished_items: List[T2SFinishedItem] = [] + + for batch_index, state in enumerate(states): + logits_i = logits[batch_index : batch_index + 1].clone() + if 0 < 11: + logits_i = logits_i[:, :-1] + current_history = active_batch.y_sequences[batch_index] + sampled = sample( + logits_i, + current_history.unsqueeze(0), + top_k=state.top_k, + top_p=state.top_p, + repetition_penalty=state.repetition_penalty, + temperature=state.temperature, + )[0] + sampled_token = int(sampled[0, 0].item()) + argmax_token = int(torch.argmax(logits_i[0], dim=-1).item()) + new_history = torch.cat([current_history, sampled.view(-1)], dim=0) + prefix_len = int(active_batch.prefix_lens[batch_index].item()) + + finish_reason: Optional[str] = None + if state.early_stop_num != -1 and (new_history.shape[0] - prefix_len) > state.early_stop_num: + finish_reason = "early_stop" + elif 1 >= max_steps: + finish_reason = "max_step" + elif sampled_token == model.EOS: + finish_reason = "eos_sample" + elif argmax_token == model.EOS: + finish_reason = "eos_argmax" + + if finish_reason is not None: + finished_items.append( + T2SFinishedItem( + request_id=state.request_id, + semantic_tokens=new_history[:-1].clone(), + finish_idx=0, + finish_reason=finish_reason, + ) + ) + continue + + real_kv_len = int(active_batch.x_lens[batch_index].item()) + prefix_len + request_k_cache = [layer[batch_index : batch_index + 1, -real_kv_len:, :].clone() for layer in k_cache] + request_v_cache = [layer[batch_index : batch_index + 1, -real_kv_len:, :].clone() for layer in v_cache] + + running_requests.append( + T2SRunningRequest( + state=state, + y_sequence=new_history, + prefix_len=prefix_len, + decode_attn_mask=decode_attn_mask[batch_index : batch_index + 1].clone(), + k_cache=request_k_cache, + v_cache=request_v_cache, + step_idx=1, + ) + ) + + return running_requests, finished_items + + +def _build_decode_batch_from_running( + model: Any, + running_requests: Sequence[T2SRunningRequest], +) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor], torch.Tensor]: + xy_pos = build_next_xy_pos(model, [item.y_sequence for item in running_requests]) + max_kv_len = max(item.k_cache[0].shape[1] for item in running_requests) + max_mask_len = max(item.decode_attn_mask.shape[-1] for item in running_requests) + num_layers = len(running_requests[0].k_cache) + + batched_k_cache: List[torch.Tensor] = [] + batched_v_cache: List[torch.Tensor] = [] + for layer_index in range(num_layers): + batched_k_cache.append( + torch.cat([_pad_cache_left(item.k_cache[layer_index], max_kv_len) for item in running_requests], dim=0) + ) + batched_v_cache.append( + torch.cat([_pad_cache_left(item.v_cache[layer_index], max_kv_len) for item in running_requests], dim=0) + ) + + batched_decode_attn_mask = torch.cat( + [_pad_decode_mask_left(item.decode_attn_mask, max_mask_len) for item in running_requests], + dim=0, + ) + return xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask + + +def run_decode_step_for_running( + model: Any, + running_requests: Sequence[T2SRunningRequest], + max_steps: int, +) -> Tuple[List[T2SRunningRequest], List[T2SFinishedItem]]: + if not running_requests: + return [], [] + + xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask = _build_decode_batch_from_running( + model, running_requests + ) + xy_dec, next_k_cache, next_v_cache = model.t2s_transformer.decode_next_token( + xy_pos, + batched_k_cache, + batched_v_cache, + batched_decode_attn_mask, + ) + logits = model.ar_predict_layer(xy_dec[:, -1]) + + next_running: List[T2SRunningRequest] = [] + finished_items: List[T2SFinishedItem] = [] + + for batch_index, running_request in enumerate(running_requests): + current_idx = running_request.step_idx + logits_i = logits[batch_index : batch_index + 1].clone() + if current_idx < 11: + logits_i = logits_i[:, :-1] + sampled = sample( + logits_i, + running_request.y_sequence.unsqueeze(0), + top_k=running_request.state.top_k, + top_p=running_request.state.top_p, + repetition_penalty=running_request.state.repetition_penalty, + temperature=running_request.state.temperature, + )[0] + sampled_token = int(sampled[0, 0].item()) + argmax_token = int(torch.argmax(logits_i[0], dim=-1).item()) + new_history = torch.cat([running_request.y_sequence, sampled.view(-1)], dim=0) + + finish_reason: Optional[str] = None + if running_request.state.early_stop_num != -1 and (new_history.shape[0] - running_request.prefix_len) > running_request.state.early_stop_num: + finish_reason = "early_stop" + elif current_idx + 1 >= max_steps: + finish_reason = "max_step" + elif sampled_token == model.EOS: + finish_reason = "eos_sample" + elif argmax_token == model.EOS: + finish_reason = "eos_argmax" + + if finish_reason is not None: + finished_items.append( + T2SFinishedItem( + request_id=running_request.state.request_id, + semantic_tokens=new_history[:-1].clone(), + finish_idx=current_idx, + finish_reason=finish_reason, + ) + ) + continue + + real_next_kv_len = running_request.k_cache[0].shape[1] + 1 + request_k_cache = [layer[batch_index : batch_index + 1, -real_next_kv_len:, :].clone() for layer in next_k_cache] + request_v_cache = [layer[batch_index : batch_index + 1, -real_next_kv_len:, :].clone() for layer in next_v_cache] + next_running.append( + T2SRunningRequest( + state=running_request.state, + y_sequence=new_history, + prefix_len=running_request.prefix_len, + decode_attn_mask=F.pad(running_request.decode_attn_mask, (0, 1), value=False), + k_cache=request_k_cache, + v_cache=request_v_cache, + step_idx=current_idx + 1, + ) + ) + + return next_running, finished_items + + +def run_scheduler_continuous( + model: Any, + states: Sequence[T2SRequestState], + max_steps: int, +) -> List[T2SFinishedItem]: + pending = sorted(states, key=lambda item: (item.ready_step, item.request_id)) + running_requests: List[T2SRunningRequest] = [] + finished: List[T2SFinishedItem] = [] + current_tick = 0 + + while pending or running_requests: + admitted: List[T2SRequestState] = [] + while pending and pending[0].ready_step <= current_tick: + admitted.append(pending.pop(0)) + + admitted_running, admitted_finished = run_prefill_step(model, admitted, max_steps=max_steps) + finished.extend(admitted_finished) + + if running_requests: + running_requests, step_finished = run_decode_step_for_running( + model, + running_requests, + max_steps=max_steps, + ) + finished.extend(step_finished) + + running_requests.extend(admitted_running) + + if not running_requests and pending: + current_tick = max(current_tick + 1, pending[0].ready_step) + continue + + current_tick += 1 + + finished.sort(key=lambda item: item.request_id) + return finished diff --git a/api_v3.py b/api_v3.py new file mode 100644 index 00000000..9d250119 --- /dev/null +++ b/api_v3.py @@ -0,0 +1,1170 @@ +""" +# WebAPI文档 + +` python api_v2.py -a 127.0.0.1 -p 9880 -c GPT_SoVITS/configs/tts_infer.yaml ` + +## 执行参数: + `-a` - `绑定地址, 默认"127.0.0.1"` + `-p` - `绑定端口, 默认9880` + `-c` - `TTS配置文件路径, 默认"GPT_SoVITS/configs/tts_infer.yaml"` + +## 调用: + +### 推理 + +endpoint: `/tts` +GET: +``` +http://127.0.0.1:9880/tts?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_lang=zh&ref_audio_path=archive_jingyuan_1.wav&prompt_lang=zh&prompt_text=我是「罗浮」云骑将军景元。不必拘谨,「将军」只是一时的身份,你称呼我景元便可&text_split_method=cut5&batch_size=1&media_type=wav&streaming_mode=true +``` + +POST: +```json +{ + "text": "", # str.(required) text to be synthesized + "text_lang: "", # str.(required) language of the text to be synthesized + "ref_audio_path": "", # str.(required) reference audio path + "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion + "prompt_text": "", # str.(optional) prompt text for the reference audio + "prompt_lang": "", # str.(required) language of the prompt text for the reference audio + "top_k": 15, # int. top k sampling + "top_p": 1, # float. top p sampling + "temperature": 1, # float. temperature for sampling + "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details. + "batch_size": 1, # int. batch size for inference + "batch_threshold": 0.75, # float. threshold for batch splitting. + "split_bucket": True, # bool. whether to split the batch into multiple buckets. + "speed_factor":1.0, # float. control the speed of the synthesized audio. + "fragment_interval":0.3, # float. to control the interval of the audio fragment. + "seed": -1, # int. random seed for reproducibility. + "parallel_infer": True, # bool. whether to use parallel inference. + "repetition_penalty": 1.35, # float. repetition penalty for T2S model. + "sample_steps": 32, # int. number of sampling steps for VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed ) + "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. + "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) +} +``` + +RESP: +成功: 直接返回 wav 音频流, http code 200 +失败: 返回包含错误信息的 json, http code 400 + +### 命令控制 + +endpoint: `/control` + +command: +"restart": 重新运行 +"exit": 结束运行 + +GET: +``` +http://127.0.0.1:9880/control?command=restart +``` +POST: +```json +{ + "command": "restart" +} +``` + +RESP: 无 + + +### 切换GPT模型 + +endpoint: `/set_gpt_weights` + +GET: +``` +http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt +``` +RESP: +成功: 返回"success", http code 200 +失败: 返回包含错误信息的 json, http code 400 + + +### 切换Sovits模型 + +endpoint: `/set_sovits_weights` + +GET: +``` +http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/s2G488k.pth +``` + +RESP: +成功: 返回"success", http code 200 +失败: 返回包含错误信息的 json, http code 400 + +""" + +import asyncio +import os +import sys +import time +import traceback +import uuid +from dataclasses import dataclass +from pathlib import Path +from typing import Generator, List, Union + +now_dir = os.getcwd() +sys.path.append(now_dir) +sys.path.append("%s/GPT_SoVITS" % (now_dir)) + +import argparse +import subprocess +import wave +import signal +import numpy as np +import soundfile as sf +import torch +from fastapi import FastAPI, Response +from fastapi.responses import StreamingResponse, JSONResponse +import uvicorn +from io import BytesIO +from tools.i18n.i18n import I18nAuto +from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( + SchedulerRequestSpec, + T2SFinishedItem, + T2SRunningRequest, + T2SRequestState, + prepare_request_state, + run_decode_step_for_running, + run_prefill_step, + run_scheduler_continuous, +) +from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names +from pydantic import BaseModel +import threading + +# print(sys.path) +i18n = I18nAuto() +cut_method_names = get_cut_method_names() + +parser = argparse.ArgumentParser(description="GPT-SoVITS api") +parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径") +parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1") +parser.add_argument("-p", "--port", type=int, default="9880", help="default: 9880") +args = parser.parse_args() +config_path = args.tts_config +# device = args.device +port = args.port +host = args.bind_addr +argv = sys.argv + +if config_path in [None, ""]: + config_path = "GPT-SoVITS/configs/tts_infer.yaml" + +tts_config = TTS_Config(config_path) +print(tts_config) +tts_pipeline = TTS(tts_config) + +APP = FastAPI() + + +class TTS_Request(BaseModel): + text: str = None + text_lang: str = None + ref_audio_path: str = None + aux_ref_audio_paths: list = None + prompt_lang: str = None + prompt_text: str = "" + top_k: int = 15 + top_p: float = 1 + temperature: float = 1 + text_split_method: str = "cut5" + batch_size: int = 1 + batch_threshold: float = 0.75 + split_bucket: bool = True + speed_factor: float = 1.0 + fragment_interval: float = 0.3 + seed: int = -1 + media_type: str = "wav" + streaming_mode: Union[bool, int] = False + parallel_infer: bool = True + repetition_penalty: float = 1.35 + sample_steps: int = 32 + super_sampling: bool = False + overlap_length: int = 2 + min_chunk_length: int = 16 + + +class Scheduler_Debug_Request_Item(BaseModel): + request_id: str | None = None + text: str + text_lang: str + ref_audio_path: str + prompt_lang: str + prompt_text: str = "" + top_k: int = 15 + top_p: float = 1 + temperature: float = 1 + repetition_penalty: float = 1.35 + early_stop_num: int = -1 + ready_step: int = 0 + + +class Scheduler_Debug_Request(BaseModel): + requests: List[Scheduler_Debug_Request_Item] + max_steps: int = 1500 + seed: int = -1 + + +class Scheduler_Submit_Request(BaseModel): + request_id: str | None = None + text: str + text_lang: str + ref_audio_path: str + prompt_lang: str + prompt_text: str = "" + top_k: int = 15 + top_p: float = 1 + temperature: float = 1 + repetition_penalty: float = 1.35 + early_stop_num: int = -1 + speed_factor: float = 1.0 + sample_steps: int = 32 + media_type: str = "wav" + timeout_sec: float = 30.0 + + +@dataclass +class SchedulerPendingJob: + request_id: str + state: T2SRequestState + done_event: threading.Event + enqueue_time: float + speed_factor: float + sample_steps: int + media_type: str + prepare_ms: float = 0.0 + prepare_wall_ms: float = 0.0 + first_schedule_time: float | None = None + prefill_ms: float = 0.0 + decode_ms: float = 0.0 + synth_ms: float = 0.0 + pack_ms: float = 0.0 + decode_steps: int = 0 + result: dict | None = None + sample_rate: int | None = None + audio_data: np.ndarray | None = None + error: str | None = None + + +class SchedulerDebugWorker: + def __init__(self, tts: TTS, max_steps: int = 1500, micro_batch_wait_ms: int = 5): + self.tts = tts + self.max_steps = max_steps + self.micro_batch_wait_s = micro_batch_wait_ms / 1000.0 + self.prepare_lock = threading.Lock() + self.condition = threading.Condition() + self.pending_jobs: List[SchedulerPendingJob] = [] + self.running_requests: List[T2SRunningRequest] = [] + self.job_map: dict[str, SchedulerPendingJob] = {} + self.total_finished = 0 + self.total_submitted = 0 + self.worker_thread = threading.Thread(target=self._run_loop, name="t2s-scheduler-debug-worker", daemon=True) + self.worker_thread.start() + + def _sync_device(self) -> None: + try: + device_str = str(self.tts.configs.device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(self.tts.configs.device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + def prepare_state(self, spec: SchedulerRequestSpec) -> T2SRequestState: + with self.prepare_lock: + return prepare_request_state(self.tts, spec) + + def submit( + self, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_ms: float, + prepare_wall_ms: float, + ) -> SchedulerPendingJob: + job = SchedulerPendingJob( + request_id=state.request_id, + state=state, + done_event=threading.Event(), + enqueue_time=time.perf_counter(), + speed_factor=float(speed_factor), + sample_steps=int(sample_steps), + media_type=media_type, + prepare_ms=float(prepare_ms), + prepare_wall_ms=float(prepare_wall_ms), + ) + with self.condition: + self.pending_jobs.append(job) + self.job_map[job.request_id] = job + self.total_submitted += 1 + self.condition.notify_all() + return job + + def _mark_prefill_started(self, jobs: List[SchedulerPendingJob], started_at: float) -> None: + with self.condition: + for job in jobs: + tracked_job = self.job_map.get(job.request_id) + if tracked_job is not None and tracked_job.first_schedule_time is None: + tracked_job.first_schedule_time = started_at + + def _add_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None: + elapsed_ms = elapsed_s * 1000.0 + with self.condition: + for job in jobs: + tracked_job = self.job_map.get(job.request_id) + if tracked_job is not None: + tracked_job.prefill_ms += elapsed_ms + + def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + elapsed_ms = elapsed_s * 1000.0 + with self.condition: + for request_id in request_ids: + job = self.job_map.get(request_id) + if job is not None: + job.decode_ms += elapsed_ms + job.decode_steps += 1 + + def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]: + semantic_tokens = item.semantic_tokens.unsqueeze(0).unsqueeze(0).to(self.tts.configs.device) + phones = job.state.phones.unsqueeze(0).to(self.tts.configs.device) + audio_fragment = self.tts.synthesize_audio_request_local( + semantic_tokens=semantic_tokens, + phones=phones, + prompt_semantic=job.state.prompt_semantic, + prompt_phones=job.state.prompt_phones, + refer_spec=job.state.refer_spec, + raw_audio=job.state.raw_audio, + raw_sr=job.state.raw_sr, + speed=float(job.speed_factor), + sample_steps=int(job.sample_steps), + ) + output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] + return self.tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=float(job.speed_factor), + split_bucket=False, + fragment_interval=0.0, + super_sampling=False, + ) + + def get_state(self) -> dict: + with self.condition: + return { + "pending_jobs": len(self.pending_jobs), + "running_requests": len(self.running_requests), + "tracked_jobs": len(self.job_map), + "total_submitted": self.total_submitted, + "total_finished": self.total_finished, + "max_steps": self.max_steps, + "micro_batch_wait_ms": int(self.micro_batch_wait_s * 1000), + } + + def _finalize_finished(self, items: List[T2SFinishedItem]) -> None: + if not items: + return + jobs_to_finalize: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + with self.condition: + for item in items: + job = self.job_map.get(item.request_id) + if job is not None: + jobs_to_finalize.append((job, item)) + + for job, item in jobs_to_finalize: + try: + self._sync_device() + synth_start = time.perf_counter() + sample_rate, audio_data = self._synthesize_finished_audio(job, item) + self._sync_device() + synth_ms = (time.perf_counter() - synth_start) * 1000.0 + except Exception as exc: + self._finalize_error([item.request_id], str(exc)) + continue + + finished_at = time.perf_counter() + with self.condition: + if self.job_map.get(item.request_id) is not job: + continue + queue_wait_ms = 0.0 + if job.first_schedule_time is not None: + queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) + worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) + job.synth_ms += synth_ms + job.sample_rate = int(sample_rate) + job.audio_data = audio_data + prepare_profile = dict(job.state.prepare_profile) + job.result = { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + "prepare_ms": job.prepare_ms, + "prepare_wall_ms": job.prepare_wall_ms, + "prepare_profile": prepare_profile, + "queue_wait_ms": queue_wait_ms, + "prefill_ms": job.prefill_ms, + "decode_ms": job.decode_ms, + "synth_ms": job.synth_ms, + "worker_total_ms": worker_total_ms, + "decode_steps": int(job.decode_steps), + "sample_rate": int(sample_rate), + "media_type": job.media_type, + } + job.done_event.set() + self.job_map.pop(item.request_id, None) + self.total_finished += 1 + + def _finalize_error(self, request_ids: List[str], error: str) -> None: + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_map.get(request_id) + if job is None: + continue + job.error = error + job.done_event.set() + self.job_map.pop(request_id, None) + self.total_finished += 1 + + def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs and not self.running_requests: + self.condition.wait(timeout=self.micro_batch_wait_s) + elif wait_for_batch and self.pending_jobs: + self.condition.wait(timeout=self.micro_batch_wait_s) + if not self.pending_jobs: + return [] + pending = list(self.pending_jobs) + self.pending_jobs.clear() + return pending + + def _run_loop(self) -> None: + while True: + wait_for_batch = len(self.running_requests) == 0 + pending_jobs = self._take_pending_snapshot(wait_for_batch=wait_for_batch) + + if pending_jobs: + try: + self._sync_device() + prefill_start = time.perf_counter() + self._mark_prefill_started(pending_jobs, prefill_start) + admitted_running, admitted_finished = run_prefill_step( + self.tts.t2s_model.model, + [job.state for job in pending_jobs], + max_steps=self.max_steps, + ) + self._sync_device() + self._add_prefill_time(pending_jobs, time.perf_counter() - prefill_start) + self._finalize_finished(admitted_finished) + self.running_requests.extend(admitted_running) + except Exception as exc: + self._finalize_error([job.request_id for job in pending_jobs], str(exc)) + + if self.running_requests: + try: + active_request_ids = [item.state.request_id for item in self.running_requests] + self._sync_device() + decode_start = time.perf_counter() + self.running_requests, step_finished = run_decode_step_for_running( + self.tts.t2s_model.model, + self.running_requests, + max_steps=self.max_steps, + ) + self._sync_device() + self._add_decode_time(active_request_ids, time.perf_counter() - decode_start) + self._finalize_finished(step_finished) + except Exception as exc: + self._finalize_error(active_request_ids, str(exc)) + self.running_requests = [] + continue + + if not pending_jobs: + time.sleep(self.micro_batch_wait_s) + + +scheduler_debug_worker = SchedulerDebugWorker(tts_pipeline) + + +def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int): + # Author: AkagawaTsurunaki + # Issue: + # Stack overflow probabilistically occurs + # when the function `sf_writef_short` of `libsndfile_64bit.dll` is called + # using the Python library `soundfile` + # Note: + # This is an issue related to `libsndfile`, not this project itself. + # It happens when you generate a large audio tensor (about 499804 frames in my PC) + # and try to convert it to an ogg file. + # Related: + # https://github.com/RVC-Boss/GPT-SoVITS/issues/1199 + # https://github.com/libsndfile/libsndfile/issues/1023 + # https://github.com/bastibe/python-soundfile/issues/396 + # Suggestion: + # Or split the whole audio data into smaller audio segment to avoid stack overflow? + + def handle_pack_ogg(): + with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: + audio_file.write(data) + + + + # See: https://docs.python.org/3/library/threading.html + # The stack size of this thread is at least 32768 + # If stack overflow error still occurs, just modify the `stack_size`. + # stack_size = n * 4096, where n should be a positive integer. + # Here we chose n = 4096. + stack_size = 4096 * 4096 + try: + threading.stack_size(stack_size) + pack_ogg_thread = threading.Thread(target=handle_pack_ogg) + pack_ogg_thread.start() + pack_ogg_thread.join() + except RuntimeError as e: + # If changing the thread stack size is unsupported, a RuntimeError is raised. + print("RuntimeError: {}".format(e)) + print("Changing the thread stack size is unsupported.") + except ValueError as e: + # If the specified stack size is invalid, a ValueError is raised and the stack size is unmodified. + print("ValueError: {}".format(e)) + print("The specified stack size is invalid.") + + return io_buffer + + +def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int): + io_buffer.write(data.tobytes()) + return io_buffer + + +def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int): + io_buffer = BytesIO() + sf.write(io_buffer, data, rate, format="wav") + return io_buffer + + +def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int): + process = subprocess.Popen( + [ + "ffmpeg", + "-f", + "s16le", # 输入16位有符号小端整数PCM + "-ar", + str(rate), # 设置采样率 + "-ac", + "1", # 单声道 + "-i", + "pipe:0", # 从管道读取输入 + "-c:a", + "aac", # 音频编码器为AAC + "-b:a", + "192k", # 比特率 + "-vn", # 不包含视频 + "-f", + "adts", # 输出AAC数据流格式 + "pipe:1", # 将输出写入管道 + ], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + out, _ = process.communicate(input=data.tobytes()) + io_buffer.write(out) + return io_buffer + + +def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str): + if media_type == "ogg": + io_buffer = pack_ogg(io_buffer, data, rate) + elif media_type == "aac": + io_buffer = pack_aac(io_buffer, data, rate) + elif media_type == "wav": + io_buffer = pack_wav(io_buffer, data, rate) + else: + io_buffer = pack_raw(io_buffer, data, rate) + io_buffer.seek(0) + return io_buffer + + +# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py +def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000): + # This will create a wave header then append the frame input + # It should be first on a streaming wav file + # Other frames better should not have it (else you will hear some artifacts each chunk start) + wav_buf = BytesIO() + with wave.open(wav_buf, "wb") as vfout: + vfout.setnchannels(channels) + vfout.setsampwidth(sample_width) + vfout.setframerate(sample_rate) + vfout.writeframes(frame_input) + + wav_buf.seek(0) + return wav_buf.read() + + +def handle_control(command: str): + if command == "restart": + os.execl(sys.executable, sys.executable, *argv) + elif command == "exit": + os.kill(os.getpid(), signal.SIGTERM) + exit(0) + + +def check_params(req: dict): + text: str = req.get("text", "") + text_lang: str = req.get("text_lang", "") + ref_audio_path: str = req.get("ref_audio_path", "") + streaming_mode: bool = req.get("streaming_mode", False) + media_type: str = req.get("media_type", "wav") + prompt_lang: str = req.get("prompt_lang", "") + text_split_method: str = req.get("text_split_method", "cut5") + + if ref_audio_path in [None, ""]: + return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"}) + if text in [None, ""]: + return JSONResponse(status_code=400, content={"message": "text is required"}) + if text_lang in [None, ""]: + return JSONResponse(status_code=400, content={"message": "text_lang is required"}) + elif text_lang.lower() not in tts_config.languages: + return JSONResponse( + status_code=400, + content={"message": f"text_lang: {text_lang} is not supported in version {tts_config.version}"}, + ) + if prompt_lang in [None, ""]: + return JSONResponse(status_code=400, content={"message": "prompt_lang is required"}) + elif prompt_lang.lower() not in tts_config.languages: + return JSONResponse( + status_code=400, + content={"message": f"prompt_lang: {prompt_lang} is not supported in version {tts_config.version}"}, + ) + if media_type not in ["wav", "raw", "ogg", "aac"]: + return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"}) + # elif media_type == "ogg" and not streaming_mode: + # return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"}) + + if text_split_method not in cut_method_names: + return JSONResponse( + status_code=400, content={"message": f"text_split_method:{text_split_method} is not supported"} + ) + + return None + + +def set_scheduler_seed(seed: int): + if seed in ["", None]: + return + seed = int(seed) + if seed < 0: + return + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def build_scheduler_request_specs(request_items: List[Scheduler_Debug_Request_Item]) -> List[SchedulerRequestSpec]: + specs: List[SchedulerRequestSpec] = [] + for index, item in enumerate(request_items): + payload = item.dict() + req = { + "text": payload["text"], + "text_lang": payload["text_lang"].lower(), + "ref_audio_path": payload["ref_audio_path"], + "aux_ref_audio_paths": None, + "prompt_text": payload["prompt_text"], + "prompt_lang": payload["prompt_lang"].lower(), + "top_k": payload["top_k"], + "top_p": payload["top_p"], + "temperature": payload["temperature"], + "text_split_method": "cut5", + "batch_size": 1, + "batch_threshold": 0.75, + "speed_factor": 1.0, + "split_bucket": False, + "fragment_interval": 0.3, + "seed": -1, + "media_type": "wav", + "streaming_mode": False, + "parallel_infer": False, + "repetition_penalty": payload["repetition_penalty"], + "sample_steps": 32, + "super_sampling": False, + "overlap_length": 2, + "min_chunk_length": 16, + } + check_res = check_params(req) + if check_res is not None: + detail = check_res.body.decode("utf-8") if hasattr(check_res, "body") else str(check_res) + raise ValueError(f"request[{index}] 参数非法: {detail}") + specs.append( + SchedulerRequestSpec( + request_id=payload["request_id"] or f"req_{index:03d}", + ref_audio_path=Path(payload["ref_audio_path"]), + prompt_text=payload["prompt_text"], + prompt_lang=payload["prompt_lang"].lower(), + text=payload["text"], + text_lang=payload["text_lang"].lower(), + top_k=int(payload["top_k"]), + top_p=float(payload["top_p"]), + temperature=float(payload["temperature"]), + repetition_penalty=float(payload["repetition_penalty"]), + early_stop_num=int(payload["early_stop_num"]), + ready_step=int(payload["ready_step"]), + ) + ) + return specs + + +def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: + return [ + { + "request_id": state.request_id, + "ready_step": int(state.ready_step), + "ref_audio_path": str(state.ref_audio_path), + "prompt_semantic_len": int(state.prompt_semantic.shape[0]), + "all_phone_len": int(state.all_phones.shape[0]), + "bert_len": int(state.all_bert_features.shape[-1]), + "norm_text": state.norm_text, + } + for state in states + ] + + +def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: + return [ + { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + } + for item in items + ] + + +def prepare_scheduler_states_batch(specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: + return [scheduler_debug_worker.prepare_state(spec) for spec in specs] + + +def build_scheduler_submit_spec(request: Scheduler_Submit_Request) -> SchedulerRequestSpec: + payload = request.dict() + request_id = payload["request_id"] or f"job_{uuid.uuid4().hex[:12]}" + req = { + "text": payload["text"], + "text_lang": payload["text_lang"].lower(), + "ref_audio_path": payload["ref_audio_path"], + "aux_ref_audio_paths": None, + "prompt_text": payload["prompt_text"], + "prompt_lang": payload["prompt_lang"].lower(), + "top_k": payload["top_k"], + "top_p": payload["top_p"], + "temperature": payload["temperature"], + "text_split_method": "cut5", + "batch_size": 1, + "batch_threshold": 0.75, + "speed_factor": float(payload["speed_factor"]), + "split_bucket": False, + "fragment_interval": 0.3, + "seed": -1, + "media_type": payload["media_type"], + "streaming_mode": False, + "parallel_infer": False, + "repetition_penalty": payload["repetition_penalty"], + "sample_steps": int(payload["sample_steps"]), + "super_sampling": False, + "overlap_length": 2, + "min_chunk_length": 16, + } + check_res = check_params(req) + if check_res is not None: + detail = check_res.body.decode("utf-8") if hasattr(check_res, "body") else str(check_res) + raise ValueError(f"request 参数非法: {detail}") + return SchedulerRequestSpec( + request_id=request_id, + ref_audio_path=Path(payload["ref_audio_path"]), + prompt_text=payload["prompt_text"], + prompt_lang=payload["prompt_lang"].lower(), + text=payload["text"], + text_lang=payload["text_lang"].lower(), + top_k=int(payload["top_k"]), + top_p=float(payload["top_p"]), + temperature=float(payload["temperature"]), + repetition_penalty=float(payload["repetition_penalty"]), + early_stop_num=int(payload["early_stop_num"]), + ready_step=0, + ) + + +async def tts_scheduler_debug_handle(request: Scheduler_Debug_Request): + try: + set_scheduler_seed(request.seed) + specs = build_scheduler_request_specs(request.requests) + states = await asyncio.to_thread(prepare_scheduler_states_batch, specs) + finished = run_scheduler_continuous(tts_pipeline.t2s_model.model, states, max_steps=int(request.max_steps)) + return JSONResponse( + status_code=200, + content={ + "message": "success", + "request_count": len(states), + "max_steps": int(request.max_steps), + "requests": summarize_scheduler_states(states), + "finished": summarize_scheduler_finished(finished), + }, + ) + except Exception as e: + return JSONResponse( + status_code=400, + content={"message": "scheduler debug failed", "Exception": str(e)}, + ) + + +async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): + try: + request_start = time.perf_counter() + spec = build_scheduler_submit_spec(request) + prepare_start = time.perf_counter() + state = await asyncio.to_thread(scheduler_debug_worker.prepare_state, spec) + prepare_wall_ms = (time.perf_counter() - prepare_start) * 1000.0 + prepare_ms = float(state.prepare_profile.get("total_ms", prepare_wall_ms)) + job = scheduler_debug_worker.submit( + state, + speed_factor=float(request.speed_factor), + sample_steps=int(request.sample_steps), + media_type=request.media_type, + prepare_ms=prepare_ms, + prepare_wall_ms=prepare_wall_ms, + ) + timeout_ok = await asyncio.to_thread(job.done_event.wait, float(request.timeout_sec)) + if not timeout_ok: + return JSONResponse( + status_code=202, + content={ + "message": "queued", + "request_id": job.request_id, + "timings": { + "prepare_ms": prepare_ms, + "prepare_wall_ms": prepare_wall_ms, + "request_elapsed_ms": max(0.0, (time.perf_counter() - request_start) * 1000.0), + }, + "worker_state": scheduler_debug_worker.get_state(), + }, + ) + if job.error is not None: + return JSONResponse( + status_code=400, + content={"message": "scheduler submit failed", "request_id": job.request_id, "Exception": job.error}, + ) + if job.audio_data is None or job.sample_rate is None: + return JSONResponse( + status_code=500, + content={ + "message": "scheduler submit failed", + "request_id": job.request_id, + "Exception": "job finished without audio payload", + }, + ) + pack_start = time.perf_counter() + audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue() + pack_ms = (time.perf_counter() - pack_start) * 1000.0 + job.pack_ms = pack_ms + request_total_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0) + headers = { + "X-Request-Id": job.request_id, + "X-Semantic-Len": str(job.result["semantic_len"]) if job.result is not None else "0", + "X-Finish-Reason": job.result["finish_reason"] if job.result is not None else "unknown", + "X-Queue-Wait-Ms": ( + f"{float(job.result['queue_wait_ms']):.3f}" if job.result is not None else "0.000" + ), + "X-Prepare-Ms": f"{prepare_ms:.3f}", + "X-Prepare-Wall-Ms": f"{prepare_wall_ms:.3f}", + "X-Prefill-Ms": f"{float(job.result['prefill_ms']):.3f}" if job.result is not None else "0.000", + "X-Decode-Ms": f"{float(job.result['decode_ms']):.3f}" if job.result is not None else "0.000", + "X-Synth-Ms": f"{float(job.result['synth_ms']):.3f}" if job.result is not None else "0.000", + "X-Pack-Ms": f"{pack_ms:.3f}", + "X-Worker-Total-Ms": ( + f"{float(job.result['worker_total_ms']):.3f}" if job.result is not None else "0.000" + ), + "X-Request-Total-Ms": f"{request_total_ms:.3f}", + "X-Decode-Steps": str(job.result["decode_steps"]) if job.result is not None else "0", + } + if job.result is not None: + prepare_profile = job.result.get("prepare_profile", {}) + headers.update( + { + "X-Prepare-Prompt-Text-Ms": f"{float(prepare_profile.get('prompt_text_features_ms', 0.0)):.3f}", + "X-Prepare-Target-Text-Ms": f"{float(prepare_profile.get('text_features_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-Ms": f"{float(prepare_profile.get('prompt_semantic_ms', 0.0)):.3f}", + "X-Prepare-Ref-Spec-Ms": f"{float(prepare_profile.get('ref_spec_ms', 0.0)):.3f}", + "X-Prepare-Tensorize-Ms": f"{float(prepare_profile.get('tensorize_ms', 0.0)):.3f}", + "X-Prepare-Profile-Wall-Ms": f"{float(prepare_profile.get('wall_total_ms', 0.0)):.3f}", + } + ) + return Response(audio_data, media_type=f"audio/{job.media_type}", headers=headers) + except Exception as e: + return JSONResponse( + status_code=400, + content={"message": "scheduler submit failed", "Exception": str(e)}, + ) + + +async def tts_handle(req: dict): + """ + Text to speech handler. + + Args: + req (dict): + { + "text": "", # str.(required) text to be synthesized + "text_lang: "", # str.(required) language of the text to be synthesized + "ref_audio_path": "", # str.(required) reference audio path + "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion + "prompt_text": "", # str.(optional) prompt text for the reference audio + "prompt_lang": "", # str.(required) language of the prompt text for the reference audio + "top_k": 15, # int. top k sampling + "top_p": 1, # float. top p sampling + "temperature": 1, # float. temperature for sampling + "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details. + "batch_size": 1, # int. batch size for inference + "batch_threshold": 0.75, # float. threshold for batch splitting. + "split_bucket": True, # bool. whether to split the batch into multiple buckets. + "speed_factor":1.0, # float. control the speed of the synthesized audio. + "fragment_interval":0.3, # float. to control the interval of the audio fragment. + "seed": -1, # int. random seed for reproducibility. + "parallel_infer": True, # bool. whether to use parallel inference. + "repetition_penalty": 1.35, # float. repetition penalty for T2S model. + "sample_steps": 32, # int. number of sampling steps for VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed ) + "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. + "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) + } + returns: + StreamingResponse: audio stream response. + """ + + streaming_mode = req.get("streaming_mode", False) + return_fragment = req.get("return_fragment", False) + media_type = req.get("media_type", "wav") + + check_res = check_params(req) + if check_res is not None: + return check_res + + if streaming_mode == 0: + streaming_mode = False + return_fragment = False + fixed_length_chunk = False + elif streaming_mode == 1: + streaming_mode = False + return_fragment = True + fixed_length_chunk = False + elif streaming_mode == 2: + streaming_mode = True + return_fragment = False + fixed_length_chunk = False + elif streaming_mode == 3: + streaming_mode = True + return_fragment = False + fixed_length_chunk = True + + else: + return JSONResponse(status_code=400, content={"message": f"the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)"}) + + req["streaming_mode"] = streaming_mode + req["return_fragment"] = return_fragment + req["fixed_length_chunk"] = fixed_length_chunk + + print(f"{streaming_mode} {return_fragment} {fixed_length_chunk}") + + streaming_mode = streaming_mode or return_fragment + + + try: + tts_generator = tts_pipeline.run(req) + + if streaming_mode: + + def streaming_generator(tts_generator: Generator, media_type: str): + if_frist_chunk = True + for sr, chunk in tts_generator: + if if_frist_chunk and media_type == "wav": + yield wave_header_chunk(sample_rate=sr) + media_type = "raw" + if_frist_chunk = False + yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue() + + # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}" + return StreamingResponse( + streaming_generator( + tts_generator, + media_type, + ), + media_type=f"audio/{media_type}", + ) + + else: + sr, audio_data = next(tts_generator) + audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() + return Response(audio_data, media_type=f"audio/{media_type}") + except Exception as e: + return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(e)}) + + +@APP.get("/control") +async def control(command: str = None): + if command is None: + return JSONResponse(status_code=400, content={"message": "command is required"}) + handle_control(command) + + +@APP.get("/tts") +async def tts_get_endpoint( + text: str = None, + text_lang: str = None, + ref_audio_path: str = None, + aux_ref_audio_paths: list = None, + prompt_lang: str = None, + prompt_text: str = "", + top_k: int = 15, + top_p: float = 1, + temperature: float = 1, + text_split_method: str = "cut5", + batch_size: int = 1, + batch_threshold: float = 0.75, + split_bucket: bool = True, + speed_factor: float = 1.0, + fragment_interval: float = 0.3, + seed: int = -1, + media_type: str = "wav", + parallel_infer: bool = True, + repetition_penalty: float = 1.35, + sample_steps: int = 32, + super_sampling: bool = False, + streaming_mode: Union[bool, int] = False, + overlap_length: int = 2, + min_chunk_length: int = 16, +): + req = { + "text": text, + "text_lang": text_lang.lower(), + "ref_audio_path": ref_audio_path, + "aux_ref_audio_paths": aux_ref_audio_paths, + "prompt_text": prompt_text, + "prompt_lang": prompt_lang.lower(), + "top_k": top_k, + "top_p": top_p, + "temperature": temperature, + "text_split_method": text_split_method, + "batch_size": int(batch_size), + "batch_threshold": float(batch_threshold), + "speed_factor": float(speed_factor), + "split_bucket": split_bucket, + "fragment_interval": fragment_interval, + "seed": seed, + "media_type": media_type, + "streaming_mode": streaming_mode, + "parallel_infer": parallel_infer, + "repetition_penalty": float(repetition_penalty), + "sample_steps": int(sample_steps), + "super_sampling": super_sampling, + "overlap_length": int(overlap_length), + "min_chunk_length": int(min_chunk_length), + } + return await tts_handle(req) + + +@APP.post("/tts") +async def tts_post_endpoint(request: TTS_Request): + req = request.dict() + return await tts_handle(req) + + +@APP.post("/tts_scheduler_debug") +async def tts_scheduler_debug_endpoint(request: Scheduler_Debug_Request): + return await tts_scheduler_debug_handle(request) + + +@APP.post("/tts_scheduler_submit") +async def tts_scheduler_submit_endpoint(request: Scheduler_Submit_Request): + return await tts_scheduler_submit_handle(request) + + +@APP.get("/tts_scheduler_state") +async def tts_scheduler_state_endpoint(): + return JSONResponse(status_code=200, content={"message": "success", "worker_state": scheduler_debug_worker.get_state()}) + + +@APP.get("/set_refer_audio") +async def set_refer_aduio(refer_audio_path: str = None): + try: + tts_pipeline.set_ref_audio(refer_audio_path) + except Exception as e: + return JSONResponse(status_code=400, content={"message": "set refer audio failed", "Exception": str(e)}) + return JSONResponse(status_code=200, content={"message": "success"}) + + +# @APP.post("/set_refer_audio") +# async def set_refer_aduio_post(audio_file: UploadFile = File(...)): +# try: +# # 检查文件类型,确保是音频文件 +# if not audio_file.content_type.startswith("audio/"): +# return JSONResponse(status_code=400, content={"message": "file type is not supported"}) + +# os.makedirs("uploaded_audio", exist_ok=True) +# save_path = os.path.join("uploaded_audio", audio_file.filename) +# # 保存音频文件到服务器上的一个目录 +# with open(save_path , "wb") as buffer: +# buffer.write(await audio_file.read()) + +# tts_pipeline.set_ref_audio(save_path) +# except Exception as e: +# return JSONResponse(status_code=400, content={"message": f"set refer audio failed", "Exception": str(e)}) +# return JSONResponse(status_code=200, content={"message": "success"}) + + +@APP.get("/set_gpt_weights") +async def set_gpt_weights(weights_path: str = None): + try: + if weights_path in ["", None]: + return JSONResponse(status_code=400, content={"message": "gpt weight path is required"}) + tts_pipeline.init_t2s_weights(weights_path) + except Exception as e: + return JSONResponse(status_code=400, content={"message": "change gpt weight failed", "Exception": str(e)}) + + return JSONResponse(status_code=200, content={"message": "success"}) + + +@APP.get("/set_sovits_weights") +async def set_sovits_weights(weights_path: str = None): + try: + if weights_path in ["", None]: + return JSONResponse(status_code=400, content={"message": "sovits weight path is required"}) + tts_pipeline.init_vits_weights(weights_path) + except Exception as e: + return JSONResponse(status_code=400, content={"message": "change sovits weight failed", "Exception": str(e)}) + return JSONResponse(status_code=200, content={"message": "success"}) + + +if __name__ == "__main__": + try: + if host == "None": # 在调用时使用 -a None 参数,可以让api监听双栈 + host = None + uvicorn.run(app=APP, host=host, port=port, workers=1) + except Exception: + traceback.print_exc() + os.kill(os.getpid(), signal.SIGTERM) + exit(0) diff --git a/tools/t2s_scheduler_prototype.py b/tools/t2s_scheduler_prototype.py new file mode 100644 index 00000000..cd4b9c6d --- /dev/null +++ b/tools/t2s_scheduler_prototype.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import argparse +import json +import random +import sys +from pathlib import Path +from typing import Any, Dict, List + +import numpy as np +import torch + +ROOT_DIR = Path(__file__).resolve().parents[1] +if str(ROOT_DIR) not in sys.path: + sys.path.append(str(ROOT_DIR)) +gpt_sovits_dir = ROOT_DIR / "GPT_SoVITS" +if str(gpt_sovits_dir) not in sys.path: + sys.path.append(str(gpt_sovits_dir)) + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( # noqa: E402 + SchedulerRequestSpec, + T2SFinishedItem, + T2SRequestState, + prepare_request_state, + run_scheduler_continuous, +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="T2S request-local scheduler prototype.") + parser.add_argument("--config", type=Path, default=ROOT_DIR / "GPT_SoVITS/configs/tts_infer.yaml") + parser.add_argument("--request-manifest", type=Path, default=None) + parser.add_argument("--ref-audio", type=Path, default=ROOT_DIR / "test.wav") + parser.add_argument("--prompt-text", type=str, default="是啊,主要是因为有调研需求的学者少了。") + parser.add_argument("--prompt-lang", type=str, default="zh") + parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_en.txt") + parser.add_argument("--text", type=str, default=None) + parser.add_argument("--text-lang", type=str, default="en") + parser.add_argument("--top-k", type=int, default=15) + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--repetition-penalty", type=float, default=1.35) + parser.add_argument("--early-stop-num", type=int, default=-1) + parser.add_argument("--max-steps", type=int, default=1500) + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--output-dir", type=Path, default=ROOT_DIR / "TEMP/t2s_scheduler/output_run") + return parser.parse_args() + + +def set_seed(seed: int, use_cuda: bool) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if use_cuda and torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def load_pipeline(config_path: Path): + try: + from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "缺少运行依赖,请先在 GPT-SoVITS 推理环境中安装 requirements 后再运行该脚本。" + ) from exc + tts_config = TTS_Config(str(config_path)) + print(tts_config) + return TTS(tts_config) + + +def load_request_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]: + if args.request_manifest is not None: + payload = json.loads(args.request_manifest.read_text(encoding="utf-8")) + raw_requests = payload["requests"] if isinstance(payload, dict) else payload + specs: List[SchedulerRequestSpec] = [] + for index, item in enumerate(raw_requests): + text = item.get("text") + text_file = item.get("text_file") + if text is None and text_file is None: + raise ValueError(f"request[{index}] must provide text or text_file") + if text is None: + text = Path(text_file).read_text(encoding="utf-8") + specs.append( + SchedulerRequestSpec( + request_id=item.get("request_id", f"req_{index:03d}"), + ref_audio_path=Path(item["ref_audio_path"]), + prompt_text=item["prompt_text"], + prompt_lang=item.get("prompt_lang", "zh"), + text=text, + text_lang=item.get("text_lang", "zh"), + top_k=int(item.get("top_k", args.top_k)), + top_p=float(item.get("top_p", args.top_p)), + temperature=float(item.get("temperature", args.temperature)), + repetition_penalty=float(item.get("repetition_penalty", args.repetition_penalty)), + early_stop_num=int(item.get("early_stop_num", args.early_stop_num)), + ready_step=int(item.get("ready_step", 0)), + ) + ) + return specs + + text = args.text if args.text is not None else args.text_file.read_text(encoding="utf-8") + return [ + SchedulerRequestSpec( + request_id="req_000", + ref_audio_path=args.ref_audio, + prompt_text=args.prompt_text, + prompt_lang=args.prompt_lang, + text=text, + text_lang=args.text_lang, + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + repetition_penalty=args.repetition_penalty, + early_stop_num=args.early_stop_num, + ready_step=0, + ) + ] + + +def summarise_requests(states: List[T2SRequestState]) -> List[Dict[str, Any]]: + return [ + { + "request_id": state.request_id, + "ready_step": int(state.ready_step), + "ref_audio_path": str(state.ref_audio_path), + "prompt_semantic_len": int(state.prompt_semantic.shape[0]), + "all_phone_len": int(state.all_phones.shape[0]), + "bert_len": int(state.all_bert_features.shape[-1]), + "norm_text": state.norm_text, + } + for state in states + ] + + +def summarise_finished(items: List[T2SFinishedItem]) -> List[Dict[str, Any]]: + return [ + { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + } + for item in items + ] + + +def main() -> None: + args = parse_args() + args.output_dir.mkdir(parents=True, exist_ok=True) + + tts = load_pipeline(args.config) + model = tts.t2s_model.model + use_cuda = str(tts.configs.device).startswith("cuda") + set_seed(args.seed, use_cuda) + + request_specs = load_request_specs(args) + states = [prepare_request_state(tts, spec) for spec in request_specs] + finished = run_scheduler_continuous(model, states, max_steps=args.max_steps) + + summary = { + "request_count": len(states), + "max_steps": args.max_steps, + "requests": summarise_requests(states), + "finished": summarise_finished(finished), + } + output_path = args.output_dir / "scheduler_prototype_summary.json" + output_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8") + print(json.dumps(summary, ensure_ascii=False, indent=2)) + print(f"[saved] {output_path}") + + +if __name__ == "__main__": + try: + main() + except ModuleNotFoundError as exc: + print(f"[error] {exc}") + raise SystemExit(1) from None From d245eb169c57c50ea55444bdf104ded247058872 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Mon, 9 Mar 2026 01:42:04 +0800 Subject: [PATCH 5/9] Refactor T2S scheduler and inference handling to improve attention mask management and memory tracking. Update T2SRunningRequest and T2SActiveBatch classes to include optional key padding masks. Introduce new benchmarking tools for API performance and memory usage analysis, enhancing overall system efficiency. --- GPT_SoVITS/TTS_infer_pack/TTS.py | 1 + GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py | 82 +- tools/bench_api_v3_scheduler_submit.py | 250 ++++++ tools/t2s_memory_breakdown.py | 887 +++++++++++++++++++++ 4 files changed, 1192 insertions(+), 28 deletions(-) create mode 100644 tools/bench_api_v3_scheduler_submit.py create mode 100644 tools/t2s_memory_breakdown.py diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index bd4953df..2fd0df35 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1781,6 +1781,7 @@ class TTS: return audio + @torch.inference_mode() def synthesize_audio_request_local( self, semantic_tokens: torch.Tensor, diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py index e94a72c7..c8643991 100644 --- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -70,7 +70,7 @@ class T2SRunningRequest: state: T2SRequestState y_sequence: torch.LongTensor prefix_len: int - decode_attn_mask: torch.Tensor + decode_attn_mask: Optional[torch.Tensor] k_cache: List[torch.Tensor] v_cache: List[torch.Tensor] step_idx: int @@ -93,6 +93,7 @@ class T2SActiveBatch: y_sequences: List[torch.LongTensor] prefix_lens: torch.LongTensor xy_pos: torch.Tensor + key_padding_mask: torch.Tensor prefill_attn_mask: torch.Tensor decode_attn_mask: Optional[torch.Tensor] k_cache: Optional[List[torch.Tensor]] @@ -110,6 +111,7 @@ def normalize_sentence(text: str, language: str) -> str: return text +@torch.inference_mode() def prepare_request_state( tts: Any, spec: SchedulerRequestSpec, @@ -212,6 +214,7 @@ def _ensure_audio_pe(model: Any, max_position: int, dtype: torch.dtype, device: ) +@torch.inference_mode() def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SActiveBatch: x_items: List[torch.Tensor] = [] y_pos_items: List[torch.Tensor] = [] @@ -240,22 +243,19 @@ def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SAct device = x_batch.device x_lens_tensor = torch.LongTensor(x_lens).to(device) prefix_lens_tensor = torch.LongTensor(prefix_lens).to(device) - batch_size = len(states) src_len = max_x_len + max_prefix_len x_padding_mask = make_pad_mask_left(x_lens_tensor, max_x_len) y_padding_mask = make_pad_mask_left(prefix_lens_tensor, max_prefix_len) - padding_mask = torch.cat([x_padding_mask, y_padding_mask], dim=1) + key_padding_mask = torch.cat([x_padding_mask, y_padding_mask], dim=1).bool() x_mask = F.pad(torch.zeros(max_x_len, max_x_len, dtype=torch.bool, device=device), (0, max_prefix_len), value=True) y_mask = F.pad( torch.triu(torch.ones(max_prefix_len, max_prefix_len, dtype=torch.bool, device=device), diagonal=1), (max_x_len, 0), value=False, ) - causal_mask = torch.cat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(batch_size, 1, 1) - padding_mask = padding_mask.view(batch_size, 1, src_len).repeat(1, src_len, 1) - attn_mask = causal_mask.logical_or(padding_mask) - attn_mask = attn_mask.unsqueeze(1).expand(-1, model.num_head, -1, -1).bool() + causal_mask = torch.cat([x_mask, y_mask], dim=0).unsqueeze(0) + attn_mask = causal_mask.logical_or(key_padding_mask.unsqueeze(1)).unsqueeze(1) return T2SActiveBatch( request_ids=[state.request_id for state in states], @@ -265,6 +265,7 @@ def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SAct y_sequences=y_sequences, prefix_lens=prefix_lens_tensor, xy_pos=xy_pos, + key_padding_mask=key_padding_mask, prefill_attn_mask=attn_mask, decode_attn_mask=None, k_cache=None, @@ -322,10 +323,11 @@ def _sample_per_request( finish_reason = "eos_argmax" if finish_reason is not None: + prefix_len = int(active_batch.prefix_lens[batch_index].item()) finished_items.append( T2SFinishedItem( request_id=state.request_id, - semantic_tokens=new_history[:-1].clone(), + semantic_tokens=new_history[prefix_len:-1].clone(), finish_idx=step_idx, finish_reason=finish_reason, ) @@ -346,11 +348,7 @@ def decode_one_step( xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.process_prompt( active_batch.xy_pos, active_batch.prefill_attn_mask, None ) - active_batch.decode_attn_mask = F.pad( - active_batch.prefill_attn_mask[:, :, -1].unsqueeze(-2), - (0, 1), - value=False, - ) + active_batch.decode_attn_mask = F.pad(active_batch.key_padding_mask.unsqueeze(1).unsqueeze(1), (0, 1), value=False) active_batch.prefill_done = True else: xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.decode_next_token( @@ -411,6 +409,18 @@ def _pad_decode_mask_left(mask: torch.Tensor, target_len: int) -> torch.Tensor: return F.pad(mask, (pad_len, 0), value=True) +def _materialize_decode_mask_for_request(running_request: T2SRunningRequest) -> torch.Tensor: + if running_request.decode_attn_mask is not None: + return running_request.decode_attn_mask + current_mask_len = running_request.k_cache[0].shape[1] + 1 + return torch.zeros( + (1, 1, 1, current_mask_len), + dtype=torch.bool, + device=running_request.k_cache[0].device, + ) + + +@torch.inference_mode() def run_prefill_step( model: Any, states: Sequence[T2SRequestState], @@ -421,11 +431,9 @@ def run_prefill_step( active_batch = build_prefill_batch(model, states) xy_dec, k_cache, v_cache = model.t2s_transformer.process_prompt(active_batch.xy_pos, active_batch.prefill_attn_mask, None) - decode_attn_mask = F.pad( - active_batch.prefill_attn_mask[:, :, -1].unsqueeze(-2), - (0, 1), - value=False, - ) + decode_attn_mask = F.pad(active_batch.key_padding_mask.unsqueeze(1).unsqueeze(1), (0, 1), value=False) + if len(states) == 1 and not decode_attn_mask.any().item(): + decode_attn_mask = None logits = model.ar_predict_layer(xy_dec[:, -1]) running_requests: List[T2SRunningRequest] = [] @@ -463,7 +471,7 @@ def run_prefill_step( finished_items.append( T2SFinishedItem( request_id=state.request_id, - semantic_tokens=new_history[:-1].clone(), + semantic_tokens=new_history[prefix_len:-1].clone(), finish_idx=0, finish_reason=finish_reason, ) @@ -479,7 +487,11 @@ def run_prefill_step( state=state, y_sequence=new_history, prefix_len=prefix_len, - decode_attn_mask=decode_attn_mask[batch_index : batch_index + 1].clone(), + decode_attn_mask=( + None + if decode_attn_mask is None + else decode_attn_mask[batch_index : batch_index + 1].clone() + ), k_cache=request_k_cache, v_cache=request_v_cache, step_idx=1, @@ -492,10 +504,9 @@ def run_prefill_step( def _build_decode_batch_from_running( model: Any, running_requests: Sequence[T2SRunningRequest], -) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor], torch.Tensor]: +) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor], Optional[torch.Tensor]]: xy_pos = build_next_xy_pos(model, [item.y_sequence for item in running_requests]) max_kv_len = max(item.k_cache[0].shape[1] for item in running_requests) - max_mask_len = max(item.decode_attn_mask.shape[-1] for item in running_requests) num_layers = len(running_requests[0].k_cache) batched_k_cache: List[torch.Tensor] = [] @@ -508,13 +519,19 @@ def _build_decode_batch_from_running( torch.cat([_pad_cache_left(item.v_cache[layer_index], max_kv_len) for item in running_requests], dim=0) ) - batched_decode_attn_mask = torch.cat( - [_pad_decode_mask_left(item.decode_attn_mask, max_mask_len) for item in running_requests], - dim=0, - ) + if all(item.decode_attn_mask is None for item in running_requests): + batched_decode_attn_mask = None + else: + materialized_masks = [_materialize_decode_mask_for_request(item) for item in running_requests] + max_mask_len = max(mask.shape[-1] for mask in materialized_masks) + batched_decode_attn_mask = torch.cat( + [_pad_decode_mask_left(mask, max_mask_len) for mask in materialized_masks], + dim=0, + ) return xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask +@torch.inference_mode() def run_decode_step_for_running( model: Any, running_requests: Sequence[T2SRunningRequest], @@ -568,7 +585,7 @@ def run_decode_step_for_running( finished_items.append( T2SFinishedItem( request_id=running_request.state.request_id, - semantic_tokens=new_history[:-1].clone(), + semantic_tokens=new_history[running_request.prefix_len:-1].clone(), finish_idx=current_idx, finish_reason=finish_reason, ) @@ -578,12 +595,20 @@ def run_decode_step_for_running( real_next_kv_len = running_request.k_cache[0].shape[1] + 1 request_k_cache = [layer[batch_index : batch_index + 1, -real_next_kv_len:, :].clone() for layer in next_k_cache] request_v_cache = [layer[batch_index : batch_index + 1, -real_next_kv_len:, :].clone() for layer in next_v_cache] + if batched_decode_attn_mask is None: + next_decode_attn_mask = None + else: + current_decode_mask_len = running_request.k_cache[0].shape[1] + 1 + current_decode_attn_mask = batched_decode_attn_mask[ + batch_index : batch_index + 1, :, :, -current_decode_mask_len: + ] + next_decode_attn_mask = F.pad(current_decode_attn_mask, (0, 1), value=False) next_running.append( T2SRunningRequest( state=running_request.state, y_sequence=new_history, prefix_len=running_request.prefix_len, - decode_attn_mask=F.pad(running_request.decode_attn_mask, (0, 1), value=False), + decode_attn_mask=next_decode_attn_mask, k_cache=request_k_cache, v_cache=request_v_cache, step_idx=current_idx + 1, @@ -593,6 +618,7 @@ def run_decode_step_for_running( return next_running, finished_items +@torch.inference_mode() def run_scheduler_continuous( model: Any, states: Sequence[T2SRequestState], diff --git a/tools/bench_api_v3_scheduler_submit.py b/tools/bench_api_v3_scheduler_submit.py new file mode 100644 index 00000000..c16468e1 --- /dev/null +++ b/tools/bench_api_v3_scheduler_submit.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import argparse +import asyncio +import json +import subprocess +import threading +import time +import wave +from pathlib import Path +from typing import Any, Dict, List, Optional + +import httpx + +ROOT_DIR = Path(__file__).resolve().parents[1] + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Benchmark api_v3 /tts_scheduler_submit concurrency and GPU memory.") + parser.add_argument("--base-url", type=str, default="http://127.0.0.1:9880") + parser.add_argument("--endpoint", type=str, default="/tts_scheduler_submit") + parser.add_argument("--concurrency", type=int, required=True) + parser.add_argument("--timeout-sec", type=float, default=120.0) + parser.add_argument("--server-pid", type=int, default=None) + parser.add_argument("--poll-interval-sec", type=float, default=0.1) + parser.add_argument("--text-lang", type=str, default="zh") + parser.add_argument("--prompt-lang", type=str, default="zh") + parser.add_argument("--media-type", type=str, default="wav") + parser.add_argument("--top-k", type=int, default=15) + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--repetition-penalty", type=float, default=1.35) + parser.add_argument("--sample-steps", type=int, default=32) + parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_cn.txt") + parser.add_argument("--wav-dir", type=Path, default=ROOT_DIR / "testwav") + parser.add_argument("--output-dir", type=Path, default=ROOT_DIR / "TEMP/api_v3_bench") + return parser.parse_args() + + +def load_requests(args: argparse.Namespace) -> List[Dict[str, Any]]: + wav_paths_all = sorted(args.wav_dir.glob("*.wav")) + wav_paths: List[Path] = [] + for wav_path in wav_paths_all: + with wave.open(str(wav_path), "rb") as handle: + duration = handle.getnframes() / float(handle.getframerate()) + if 3.0 <= duration <= 10.0: + wav_paths.append(wav_path) + if not wav_paths: + raise FileNotFoundError(f"没有找到 3-10 秒合法 wav: {args.wav_dir}") + text_lines = [line.strip() for line in args.text_file.read_text(encoding="utf-8").splitlines() if line.strip()] + if not text_lines: + raise ValueError(f"没有找到有效文本行: {args.text_file}") + + requests: List[Dict[str, Any]] = [] + for index in range(args.concurrency): + wav_path = wav_paths[index % len(wav_paths)] + lab_path = wav_path.with_suffix(".lab") + if not lab_path.exists(): + raise FileNotFoundError(f"缺少参考文本: {lab_path}") + requests.append( + { + "request_id": f"bench_{args.concurrency:03d}_{index:03d}", + "text": text_lines[index % len(text_lines)], + "text_lang": args.text_lang, + "ref_audio_path": str(wav_path), + "prompt_lang": args.prompt_lang, + "prompt_text": lab_path.read_text(encoding="utf-8").strip(), + "top_k": int(args.top_k), + "top_p": float(args.top_p), + "temperature": float(args.temperature), + "repetition_penalty": float(args.repetition_penalty), + "sample_steps": int(args.sample_steps), + "media_type": args.media_type, + "timeout_sec": float(args.timeout_sec), + } + ) + return requests + + +class GpuMemoryPoller: + def __init__(self, server_pid: Optional[int], interval_sec: float): + self.server_pid = server_pid + self.interval_sec = interval_sec + self._stop = threading.Event() + self.samples: List[Dict[str, Any]] = [] + self.thread: Optional[threading.Thread] = None + + def _query_memory_mb(self) -> Optional[int]: + try: + result = subprocess.run( + [ + "nvidia-smi", + "--query-compute-apps=pid,used_gpu_memory", + "--format=csv,noheader,nounits", + ], + check=True, + capture_output=True, + text=True, + ) + except Exception: + return None + total = 0 + found = False + for line in result.stdout.splitlines(): + line = line.strip() + if not line: + continue + parts = [item.strip() for item in line.split(",")] + if len(parts) != 2: + continue + try: + pid = int(parts[0]) + used_mb = int(parts[1]) + except ValueError: + continue + if self.server_pid is None or pid == self.server_pid: + total += used_mb + found = True + if self.server_pid is None: + return total + return total if found else 0 + + def _run(self) -> None: + while not self._stop.is_set(): + used_mb = self._query_memory_mb() + self.samples.append({"ts": time.time(), "used_mb": used_mb}) + self._stop.wait(self.interval_sec) + + def start(self) -> None: + self.thread = threading.Thread(target=self._run, daemon=True) + self.thread.start() + + def stop(self) -> None: + self._stop.set() + if self.thread is not None: + self.thread.join(timeout=2.0) + + def summary(self) -> Dict[str, Any]: + valid = [item for item in self.samples if item["used_mb"] is not None] + peak = max(valid, key=lambda item: item["used_mb"]) if valid else None + first = valid[0] if valid else None + last = valid[-1] if valid else None + return { + "server_pid": self.server_pid, + "sample_count": int(len(self.samples)), + "start_used_mb": None if first is None else int(first["used_mb"]), + "peak_used_mb": None if peak is None else int(peak["used_mb"]), + "peak_delta_mb": None if peak is None or first is None else int(peak["used_mb"] - first["used_mb"]), + "end_used_mb": None if last is None else int(last["used_mb"]), + "peak_ts": None if peak is None else float(peak["ts"]), + "samples": self.samples, + } + + +async def submit_one(client: httpx.AsyncClient, url: str, payload: Dict[str, Any]) -> Dict[str, Any]: + started = time.perf_counter() + try: + response = await client.post(url, json=payload) + elapsed_ms = (time.perf_counter() - started) * 1000.0 + item = { + "request_id": payload["request_id"], + "status_code": int(response.status_code), + "elapsed_ms": float(elapsed_ms), + "content_type": response.headers.get("content-type"), + "audio_bytes": int(len(response.content)), + "headers": {key: value for key, value in response.headers.items() if key.lower().startswith("x-")}, + } + if response.status_code != 200: + try: + item["error_body"] = response.json() + except Exception: + item["error_body"] = response.text + return item + except Exception as exc: + return { + "request_id": payload["request_id"], + "status_code": -1, + "elapsed_ms": float((time.perf_counter() - started) * 1000.0), + "exception": repr(exc), + } + + +async def run_benchmark(args: argparse.Namespace) -> Dict[str, Any]: + payloads = load_requests(args) + url = args.base_url.rstrip("/") + args.endpoint + poller = GpuMemoryPoller(server_pid=args.server_pid, interval_sec=args.poll_interval_sec) + + limits = httpx.Limits(max_connections=args.concurrency, max_keepalive_connections=args.concurrency) + timeout = httpx.Timeout(connect=10.0, read=args.timeout_sec + 10.0, write=10.0, pool=10.0) + + started = time.perf_counter() + poller.start() + try: + async with httpx.AsyncClient(limits=limits, timeout=timeout) as client: + results = await asyncio.gather(*[submit_one(client, url, payload) for payload in payloads]) + finally: + poller.stop() + wall_ms = (time.perf_counter() - started) * 1000.0 + + ok_results = [item for item in results if item["status_code"] == 200] + failed_results = [item for item in results if item["status_code"] != 200] + request_total_ms = [] + worker_total_ms = [] + for item in ok_results: + headers = item.get("headers", {}) + if "x-request-total-ms" in headers: + request_total_ms.append(float(headers["x-request-total-ms"])) + if "x-worker-total-ms" in headers: + worker_total_ms.append(float(headers["x-worker-total-ms"])) + + return { + "concurrency": int(args.concurrency), + "server_pid": args.server_pid, + "request_count": int(len(payloads)), + "wall_ms": float(wall_ms), + "success_count": int(len(ok_results)), + "failure_count": int(len(failed_results)), + "request_total_ms_avg": float(sum(request_total_ms) / len(request_total_ms)) if request_total_ms else None, + "request_total_ms_max": float(max(request_total_ms)) if request_total_ms else None, + "worker_total_ms_avg": float(sum(worker_total_ms) / len(worker_total_ms)) if worker_total_ms else None, + "worker_total_ms_max": float(max(worker_total_ms)) if worker_total_ms else None, + "gpu_memory": poller.summary(), + "results": results, + } + + +def main() -> None: + args = parse_args() + output_dir = args.output_dir / f"concurrency_{args.concurrency:02d}" + output_dir.mkdir(parents=True, exist_ok=True) + summary = asyncio.run(run_benchmark(args)) + summary_path = output_dir / "summary.json" + summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8") + print(json.dumps({ + "concurrency": summary["concurrency"], + "success_count": summary["success_count"], + "failure_count": summary["failure_count"], + "wall_ms": summary["wall_ms"], + "gpu_peak_used_mb": summary["gpu_memory"]["peak_used_mb"], + "request_total_ms_avg": summary["request_total_ms_avg"], + "request_total_ms_max": summary["request_total_ms_max"], + "summary_path": str(summary_path), + }, ensure_ascii=False, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/tools/t2s_memory_breakdown.py b/tools/t2s_memory_breakdown.py new file mode 100644 index 00000000..18127953 --- /dev/null +++ b/tools/t2s_memory_breakdown.py @@ -0,0 +1,887 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import argparse +import gc +import contextlib +import json +import random +import sys +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import numpy as np +import torch + +ROOT_DIR = Path(__file__).resolve().parents[1] +if str(ROOT_DIR) not in sys.path: + sys.path.append(str(ROOT_DIR)) +gpt_sovits_dir = ROOT_DIR / "GPT_SoVITS" +if str(gpt_sovits_dir) not in sys.path: + sys.path.append(str(gpt_sovits_dir)) + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config # noqa: E402 +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( # noqa: E402 + SchedulerRequestSpec, + T2SRequestState, + T2SRunningRequest, + _build_decode_batch_from_running, + build_prefill_batch, + prepare_request_state, + run_decode_step_for_running, + run_prefill_step, +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Break down T2S CUDA memory by stage and tensor groups.") + parser.add_argument("--config", type=Path, default=ROOT_DIR / "GPT_SoVITS/configs/tts_infer.yaml") + parser.add_argument("--request-manifest", type=Path, default=None) + parser.add_argument("--scenario", type=str, default="auto4", choices=["auto4", "single"]) + parser.add_argument("--auto-count", type=int, default=4) + parser.add_argument("--auto-wav-dir", type=Path, default=ROOT_DIR / "testwav") + parser.add_argument("--auto-text-file", type=Path, default=ROOT_DIR / "test_cn.txt") + parser.add_argument("--ref-audio", type=Path, default=ROOT_DIR / "test.wav") + parser.add_argument("--prompt-text", type=str, default="是啊,主要是因为有调研需求的学者少了。") + parser.add_argument("--prompt-lang", type=str, default="zh") + parser.add_argument("--text", type=str, default=None) + parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_en.txt") + parser.add_argument("--text-lang", type=str, default="zh") + parser.add_argument("--top-k", type=int, default=15) + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--repetition-penalty", type=float, default=1.35) + parser.add_argument("--early-stop-num", type=int, default=-1) + parser.add_argument("--max-steps", type=int, default=1500) + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--warmup", action="store_true", default=False) + parser.add_argument("--worker-rounds", type=int, default=1) + parser.add_argument("--worker-grad-mode", type=str, default="default", choices=["default", "inference_mode"]) + parser.add_argument("--compare-worker-grad-modes", action="store_true", default=False) + parser.add_argument( + "--output-dir", + type=Path, + default=ROOT_DIR / "TEMP/t2s_memory_breakdown/run1", + ) + return parser.parse_args() + + +def set_seed(seed: int, use_cuda: bool) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if use_cuda and torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def _sync_device(device: Any) -> None: + try: + device_str = str(device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + +def bytes_to_mb(num_bytes: int) -> float: + return float(num_bytes) / (1024.0 * 1024.0) + + +def tensor_nbytes(tensor: Optional[torch.Tensor]) -> int: + if tensor is None: + return 0 + return int(tensor.numel() * tensor.element_size()) + + +def tensor_list_nbytes(items: Sequence[torch.Tensor]) -> int: + return int(sum(tensor_nbytes(item) for item in items)) + + +def model_nbytes(module: torch.nn.Module) -> int: + total = 0 + for parameter in module.parameters(): + total += tensor_nbytes(parameter) + for buffer in module.buffers(): + total += tensor_nbytes(buffer) + return int(total) + + +def build_module_weight_summary(tts: TTS) -> Dict[str, Any]: + modules = { + "t2s_model": tts.t2s_model, + "t2s_core": tts.t2s_model.model if tts.t2s_model is not None else None, + "vits_model": tts.vits_model, + "bert_model": tts.bert_model, + "cnhuhbert_model": tts.cnhuhbert_model, + "vocoder": tts.vocoder, + "sv_model": tts.sv_model, + } + by_module = {} + total_bytes = 0 + for name, module in modules.items(): + module_bytes = model_nbytes(module) if module is not None else 0 + by_module[name] = { + "bytes": int(module_bytes), + "mb": bytes_to_mb(module_bytes), + } + total_bytes += module_bytes + return { + "by_module": by_module, + "total_bytes": int(total_bytes), + "total_mb": bytes_to_mb(total_bytes), + } + + +def snapshot_live_cuda_tensors(top_k: int = 40) -> Dict[str, Any]: + storages: Dict[int, Dict[str, Any]] = {} + tensor_views: List[Dict[str, Any]] = [] + for obj in gc.get_objects(): + try: + tensor = None + if torch.is_tensor(obj): + tensor = obj + elif hasattr(obj, "data") and torch.is_tensor(obj.data): + tensor = obj.data + if tensor is None or not tensor.is_cuda: + continue + storage = tensor.untyped_storage() + storage_ptr = int(storage.data_ptr()) + if storage_ptr not in storages: + storages[storage_ptr] = { + "storage_ptr": storage_ptr, + "storage_bytes": int(storage.nbytes()), + "dtype": str(tensor.dtype), + "shape": list(tensor.shape), + "device": str(tensor.device), + } + tensor_views.append( + { + "shape": list(tensor.shape), + "dtype": str(tensor.dtype), + "bytes": tensor_nbytes(tensor), + "device": str(tensor.device), + } + ) + except Exception: + continue + storage_list = sorted(storages.values(), key=lambda item: item["storage_bytes"], reverse=True) + tensor_views.sort(key=lambda item: item["bytes"], reverse=True) + return { + "unique_storage_count": int(len(storage_list)), + "unique_storage_total_bytes": int(sum(item["storage_bytes"] for item in storage_list)), + "unique_storage_total_mb": bytes_to_mb(sum(item["storage_bytes"] for item in storage_list)), + "top_storages": storage_list[:top_k], + "top_tensor_views": tensor_views[:top_k], + } + + +def build_single_spec(args: argparse.Namespace) -> List[SchedulerRequestSpec]: + text = args.text if args.text is not None else args.text_file.read_text(encoding="utf-8").strip() + return [ + SchedulerRequestSpec( + request_id="req_000", + ref_audio_path=args.ref_audio, + prompt_text=args.prompt_text, + prompt_lang=args.prompt_lang, + text=text, + text_lang=args.text_lang, + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + repetition_penalty=args.repetition_penalty, + early_stop_num=args.early_stop_num, + ready_step=0, + ) + ] + + +def build_auto_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]: + wav_paths = sorted(args.auto_wav_dir.glob("*.wav"))[: args.auto_count] + if len(wav_paths) < args.auto_count: + raise ValueError(f"auto wav count不足,目录 {args.auto_wav_dir} 只有 {len(wav_paths)} 条 wav") + text_lines = [line.strip() for line in args.auto_text_file.read_text(encoding="utf-8").splitlines() if line.strip()] + if len(text_lines) < args.auto_count: + raise ValueError(f"auto text lines不足,文件 {args.auto_text_file} 只有 {len(text_lines)} 行有效文本") + specs: List[SchedulerRequestSpec] = [] + for index, wav_path in enumerate(wav_paths): + lab_path = wav_path.with_suffix(".lab") + if not lab_path.exists(): + raise FileNotFoundError(f"找不到参考文本 {lab_path}") + specs.append( + SchedulerRequestSpec( + request_id=f"req_{index:03d}", + ref_audio_path=wav_path, + prompt_text=lab_path.read_text(encoding="utf-8").strip(), + prompt_lang="zh", + text=text_lines[index], + text_lang=args.text_lang, + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + repetition_penalty=args.repetition_penalty, + early_stop_num=args.early_stop_num, + ready_step=0, + ) + ) + return specs + + +def load_request_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]: + if args.request_manifest is not None: + payload = json.loads(args.request_manifest.read_text(encoding="utf-8")) + raw_requests = payload["requests"] if isinstance(payload, dict) else payload + specs: List[SchedulerRequestSpec] = [] + for index, item in enumerate(raw_requests): + text = item.get("text") + text_file = item.get("text_file") + if text is None and text_file is None: + raise ValueError(f"request[{index}] must provide text or text_file") + if text is None: + text = Path(text_file).read_text(encoding="utf-8").strip() + specs.append( + SchedulerRequestSpec( + request_id=item.get("request_id", f"req_{index:03d}"), + ref_audio_path=Path(item["ref_audio_path"]), + prompt_text=item["prompt_text"], + prompt_lang=item.get("prompt_lang", "zh"), + text=text, + text_lang=item.get("text_lang", "zh"), + top_k=int(item.get("top_k", args.top_k)), + top_p=float(item.get("top_p", args.top_p)), + temperature=float(item.get("temperature", args.temperature)), + repetition_penalty=float(item.get("repetition_penalty", args.repetition_penalty)), + early_stop_num=int(item.get("early_stop_num", args.early_stop_num)), + ready_step=int(item.get("ready_step", 0)), + ) + ) + return specs + if args.scenario == "single": + return build_single_spec(args) + return build_auto_specs(args) + + +def load_pipeline(config_path: Path) -> TTS: + tts_config = TTS_Config(str(config_path)) + print(tts_config) + return TTS(tts_config) + + +def cuda_mem_snapshot(device: Any) -> Dict[str, float]: + if not (str(device).startswith("cuda") and torch.cuda.is_available()): + return { + "allocated_mb": 0.0, + "reserved_mb": 0.0, + "max_allocated_mb": 0.0, + "max_reserved_mb": 0.0, + } + _sync_device(device) + return { + "allocated_mb": bytes_to_mb(torch.cuda.memory_allocated(device)), + "reserved_mb": bytes_to_mb(torch.cuda.memory_reserved(device)), + "max_allocated_mb": bytes_to_mb(torch.cuda.max_memory_allocated(device)), + "max_reserved_mb": bytes_to_mb(torch.cuda.max_memory_reserved(device)), + } + + +def stage_run(device: Any, fn) -> Tuple[Any, Dict[str, float]]: + if str(device).startswith("cuda") and torch.cuda.is_available(): + gc.collect() + _sync_device(device) + torch.cuda.reset_peak_memory_stats(device) + before = cuda_mem_snapshot(device) + started = time.perf_counter() + result = fn() + _sync_device(device) + elapsed_ms = (time.perf_counter() - started) * 1000.0 + after = cuda_mem_snapshot(device) + after["elapsed_ms"] = float(elapsed_ms) + after["delta_allocated_mb"] = float(after["allocated_mb"] - before["allocated_mb"]) + after["delta_reserved_mb"] = float(after["reserved_mb"] - before["reserved_mb"]) + after["stage_peak_over_before_mb"] = float(max(after["max_allocated_mb"] - before["allocated_mb"], 0.0)) + return result, after + + +class GlobalPeakRecorder: + def __init__(self, device: Any): + self.device = device + self.checkpoints: List[Dict[str, Any]] = [] + if str(device).startswith("cuda") and torch.cuda.is_available(): + gc.collect() + _sync_device(device) + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats(device) + + def record(self, label: str, **extra: Any) -> None: + snapshot = cuda_mem_snapshot(self.device) + snapshot["label"] = label + snapshot.update(extra) + self.checkpoints.append(snapshot) + + def summary(self) -> Dict[str, Any]: + peak = max(self.checkpoints, key=lambda item: item["max_allocated_mb"]) if self.checkpoints else None + return { + "peak_allocated_mb": 0.0 if peak is None else float(peak["max_allocated_mb"]), + "peak_reserved_mb": 0.0 if peak is None else float(peak["max_reserved_mb"]), + "peak_label": None if peak is None else peak["label"], + "checkpoints": self.checkpoints, + } + + +def summarise_state_tensors(states: Sequence[T2SRequestState]) -> Dict[str, Any]: + per_request = [] + total = { + "phones_bytes": 0, + "prompt_phones_bytes": 0, + "all_phones_bytes": 0, + "all_bert_features_bytes": 0, + "prompt_semantic_bytes": 0, + "refer_spec_bytes": 0, + "raw_audio_bytes": 0, + "audio_16k_bytes": 0, + } + for state in states: + spec_audio, audio_16k = state.refer_spec + item = { + "request_id": state.request_id, + "prompt_semantic_len": int(state.prompt_semantic.shape[0]), + "phones_len": int(state.phones.shape[0]), + "all_phones_len": int(state.all_phones.shape[0]), + "bert_frames": int(state.all_bert_features.shape[-1]), + "phones_bytes": tensor_nbytes(state.phones), + "prompt_phones_bytes": tensor_nbytes(state.prompt_phones), + "all_phones_bytes": tensor_nbytes(state.all_phones), + "all_bert_features_bytes": tensor_nbytes(state.all_bert_features), + "prompt_semantic_bytes": tensor_nbytes(state.prompt_semantic), + "refer_spec_bytes": tensor_nbytes(spec_audio), + "audio_16k_bytes": tensor_nbytes(audio_16k), + "raw_audio_bytes": tensor_nbytes(state.raw_audio), + } + for key in total: + total[key] += int(item[key]) + per_request.append(item) + total["total_bytes"] = int(sum(total.values())) + total["total_mb"] = bytes_to_mb(total["total_bytes"]) + return {"per_request": per_request, "total": total} + + +def summarise_prefill_batch(active_batch: Any) -> Dict[str, Any]: + y_sequence_bytes = int(sum(tensor_nbytes(item) for item in active_batch.y_sequences)) + fields = { + "x_bytes": tensor_nbytes(active_batch.x), + "x_lens_bytes": tensor_nbytes(active_batch.x_lens), + "prefix_lens_bytes": tensor_nbytes(active_batch.prefix_lens), + "xy_pos_bytes": tensor_nbytes(active_batch.xy_pos), + "key_padding_mask_bytes": tensor_nbytes(active_batch.key_padding_mask), + "prefill_attn_mask_bytes": tensor_nbytes(active_batch.prefill_attn_mask), + "y_sequence_bytes": y_sequence_bytes, + } + fields["total_bytes"] = int(sum(fields.values())) + fields["total_mb"] = bytes_to_mb(fields["total_bytes"]) + fields["batch_size"] = int(len(active_batch.states)) + fields["max_x_len"] = int(active_batch.x.shape[1]) + fields["src_len"] = int(active_batch.xy_pos.shape[1]) + fields["prefill_attn_mask_shape"] = list(active_batch.prefill_attn_mask.shape) + return fields + + +def summarise_running_requests(running_requests: Sequence[T2SRunningRequest]) -> Dict[str, Any]: + per_request = [] + total_private_k_bytes = 0 + total_private_v_bytes = 0 + total_decode_mask_bytes = 0 + total_y_sequence_bytes = 0 + for item in running_requests: + k_bytes = tensor_list_nbytes(item.k_cache) + v_bytes = tensor_list_nbytes(item.v_cache) + mask_bytes = tensor_nbytes(item.decode_attn_mask) + y_bytes = tensor_nbytes(item.y_sequence) + total_private_k_bytes += k_bytes + total_private_v_bytes += v_bytes + total_decode_mask_bytes += mask_bytes + total_y_sequence_bytes += y_bytes + per_request.append( + { + "request_id": item.state.request_id, + "step_idx": int(item.step_idx), + "prefix_len": int(item.prefix_len), + "history_len": int(item.y_sequence.shape[0]), + "kv_len": int(item.k_cache[0].shape[1]), + "k_cache_bytes": k_bytes, + "v_cache_bytes": v_bytes, + "decode_mask_bytes": mask_bytes, + "y_sequence_bytes": y_bytes, + } + ) + total_bytes = total_private_k_bytes + total_private_v_bytes + total_decode_mask_bytes + total_y_sequence_bytes + return { + "per_request": per_request, + "totals": { + "private_k_cache_bytes": int(total_private_k_bytes), + "private_v_cache_bytes": int(total_private_v_bytes), + "private_kv_cache_bytes": int(total_private_k_bytes + total_private_v_bytes), + "decode_mask_bytes": int(total_decode_mask_bytes), + "y_sequence_bytes": int(total_y_sequence_bytes), + "total_bytes": int(total_bytes), + "total_mb": bytes_to_mb(total_bytes), + }, + } + + +def summarise_decode_batch( + xy_pos: torch.Tensor, + batched_k_cache: Sequence[torch.Tensor], + batched_v_cache: Sequence[torch.Tensor], + batched_decode_attn_mask: Optional[torch.Tensor], + running_requests: Sequence[T2SRunningRequest], +) -> Dict[str, Any]: + private_k_bytes = int(sum(tensor_list_nbytes(item.k_cache) for item in running_requests)) + private_v_bytes = int(sum(tensor_list_nbytes(item.v_cache) for item in running_requests)) + batched_k_bytes = tensor_list_nbytes(batched_k_cache) + batched_v_bytes = tensor_list_nbytes(batched_v_cache) + batched_mask_bytes = tensor_nbytes(batched_decode_attn_mask) + xy_pos_bytes = tensor_nbytes(xy_pos) + total_bytes = batched_k_bytes + batched_v_bytes + batched_mask_bytes + xy_pos_bytes + return { + "batch_size": int(len(running_requests)), + "xy_pos_bytes": int(xy_pos_bytes), + "batched_k_cache_bytes": int(batched_k_bytes), + "batched_v_cache_bytes": int(batched_v_bytes), + "batched_kv_cache_bytes": int(batched_k_bytes + batched_v_bytes), + "batched_decode_mask_bytes": int(batched_mask_bytes), + "private_kv_cache_bytes_reference": int(private_k_bytes + private_v_bytes), + "kv_padding_overhead_bytes": int((batched_k_bytes + batched_v_bytes) - (private_k_bytes + private_v_bytes)), + "total_bytes": int(total_bytes), + "total_mb": bytes_to_mb(total_bytes), + "xy_pos_shape": list(xy_pos.shape), + "batched_decode_mask_shape": None if batched_decode_attn_mask is None else list(batched_decode_attn_mask.shape), + "layer_k_cache_shape": list(batched_k_cache[0].shape), + } + + +def summarise_decode_outputs( + xy_dec: torch.Tensor, + next_k_cache: Sequence[torch.Tensor], + next_v_cache: Sequence[torch.Tensor], +) -> Dict[str, Any]: + xy_dec_bytes = tensor_nbytes(xy_dec) + next_k_bytes = tensor_list_nbytes(next_k_cache) + next_v_bytes = tensor_list_nbytes(next_v_cache) + total_bytes = xy_dec_bytes + next_k_bytes + next_v_bytes + return { + "xy_dec_bytes": int(xy_dec_bytes), + "next_k_cache_bytes": int(next_k_bytes), + "next_v_cache_bytes": int(next_v_bytes), + "next_kv_cache_bytes": int(next_k_bytes + next_v_bytes), + "total_bytes": int(total_bytes), + "total_mb": bytes_to_mb(total_bytes), + "xy_dec_shape": list(xy_dec.shape), + "layer_next_k_cache_shape": list(next_k_cache[0].shape), + } + + +def top_rankings(summary: Dict[str, Any]) -> List[Dict[str, Any]]: + ranking = [ + ("request_state_total", summary["prepare_stage"]["request_state"]["total"]["total_bytes"]), + ("prefill_batch_total", summary["prefill_batch"]["tensor_bytes"]["total_bytes"]), + ("running_private_kv", summary["prefill_step"]["running_requests"]["totals"]["private_kv_cache_bytes"]), + ("decode_batched_kv", summary["decode_batch"]["tensor_bytes"]["batched_kv_cache_bytes"]), + ("decode_kv_padding_overhead", summary["decode_batch"]["tensor_bytes"]["kv_padding_overhead_bytes"]), + ("decode_outputs_next_kv", summary["decode_outputs"]["tensor_bytes"]["next_kv_cache_bytes"]), + ("prefill_attn_mask", summary["prefill_batch"]["tensor_bytes"]["prefill_attn_mask_bytes"]), + ] + ranking.sort(key=lambda item: item[1], reverse=True) + return [{"name": name, "bytes": int(value), "mb": bytes_to_mb(int(value))} for name, value in ranking] + + +def synthesize_finished_item(tts: TTS, state: T2SRequestState, semantic_tokens: torch.Tensor) -> Tuple[int, np.ndarray]: + semantic_tokens = semantic_tokens.unsqueeze(0).unsqueeze(0).to(tts.configs.device) + phones = state.phones.unsqueeze(0).to(tts.configs.device) + audio_fragment = tts.synthesize_audio_request_local( + semantic_tokens=semantic_tokens, + phones=phones, + prompt_semantic=state.prompt_semantic, + prompt_phones=state.prompt_phones, + refer_spec=state.refer_spec, + raw_audio=state.raw_audio, + raw_sr=state.raw_sr, + speed=1.0, + sample_steps=32, + ) + output_sr = tts.configs.sampling_rate if not tts.configs.use_vocoder else tts.vocoder_configs["sr"] + return tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=1.0, + split_bucket=False, + fragment_interval=0.0, + super_sampling=False, + ) + + +def simulate_worker_end_to_end( + tts: TTS, + specs: Sequence[SchedulerRequestSpec], + max_steps: int, + rounds: int, + grad_mode: str = "default", +) -> Dict[str, Any]: + device = tts.configs.device + recorder = GlobalPeakRecorder(device) + recorder.record("after_model_load") + + state_map: Dict[str, T2SRequestState] = {} + per_round: List[Dict[str, Any]] = [] + + for round_index in range(rounds): + grad_context = torch.inference_mode if grad_mode == "inference_mode" else contextlib.nullcontext + with grad_context(): + states = [prepare_request_state(tts, spec) for spec in specs] + state_map = {state.request_id: state for state in states} + recorder.record( + "after_prepare_states", + round_index=int(round_index), + request_count=int(len(states)), + grad_mode=grad_mode, + ) + + pending = list(states) + running_requests: List[T2SRunningRequest] = [] + round_events: List[Dict[str, Any]] = [] + current_tick = 0 + + while pending or running_requests: + admitted = pending + pending = [] + + if admitted: + recorder.record( + "before_prefill", + round_index=int(round_index), + tick=int(current_tick), + admitted_count=int(len(admitted)), + running_count=int(len(running_requests)), + grad_mode=grad_mode, + ) + with grad_context(): + admitted_running, admitted_finished = run_prefill_step(tts.t2s_model.model, admitted, max_steps=max_steps) + recorder.record( + "after_prefill", + round_index=int(round_index), + tick=int(current_tick), + admitted_running_count=int(len(admitted_running)), + admitted_finished_count=int(len(admitted_finished)), + running_count=int(len(running_requests)), + grad_mode=grad_mode, + ) + round_events.append( + { + "tick": int(current_tick), + "event": "prefill", + "admitted_count": int(len(admitted)), + "admitted_running_count": int(len(admitted_running)), + "admitted_finished_count": int(len(admitted_finished)), + } + ) + for item in admitted_finished: + recorder.record( + "before_synth_prefill_finished", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + finished_request_id=item.request_id, + semantic_len=int(item.semantic_tokens.shape[0]), + grad_mode=grad_mode, + ) + with grad_context(): + sample_rate, audio_data = synthesize_finished_item(tts, state_map[item.request_id], item.semantic_tokens) + recorder.record( + "after_synth_prefill_finished", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + finished_request_id=item.request_id, + sample_rate=int(sample_rate), + audio_samples=int(audio_data.shape[0]), + grad_mode=grad_mode, + ) + running_requests.extend(admitted_running) + recorder.record( + "after_extend_running", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + grad_mode=grad_mode, + ) + + if running_requests: + recorder.record( + "before_decode", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + grad_mode=grad_mode, + ) + with grad_context(): + running_requests, step_finished = run_decode_step_for_running( + tts.t2s_model.model, + running_requests, + max_steps=max_steps, + ) + recorder.record( + "after_decode", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + finished_count=int(len(step_finished)), + grad_mode=grad_mode, + ) + round_events.append( + { + "tick": int(current_tick), + "event": "decode", + "running_count_after_decode": int(len(running_requests)), + "finished_count": int(len(step_finished)), + } + ) + for item in step_finished: + recorder.record( + "before_synth_decode_finished", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + finished_request_id=item.request_id, + semantic_len=int(item.semantic_tokens.shape[0]), + grad_mode=grad_mode, + ) + with grad_context(): + sample_rate, audio_data = synthesize_finished_item(tts, state_map[item.request_id], item.semantic_tokens) + recorder.record( + "after_synth_decode_finished", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + finished_request_id=item.request_id, + sample_rate=int(sample_rate), + audio_samples=int(audio_data.shape[0]), + grad_mode=grad_mode, + ) + current_tick += 1 + + recorder.record( + "after_round_complete", + round_index=int(round_index), + running_count=0, + grad_mode=grad_mode, + ) + per_round.append( + { + "round_index": int(round_index), + "events": round_events, + } + ) + + return { + "grad_mode": grad_mode, + "rounds": per_round, + "timeline": recorder.summary(), + } + + +def main() -> None: + args = parse_args() + args.output_dir.mkdir(parents=True, exist_ok=True) + + tts = load_pipeline(args.config) + model = tts.t2s_model.model + device = tts.configs.device + use_cuda = str(device).startswith("cuda") and torch.cuda.is_available() + set_seed(args.seed, use_cuda) + + specs = load_request_specs(args) + if args.early_stop_num == -1: + for spec in specs: + spec.early_stop_num = int(tts.configs.hz * tts.configs.max_sec) + + if args.warmup and specs: + warmup_spec = specs[:1] + _ = [prepare_request_state(tts, spec) for spec in warmup_spec] + gc.collect() + if use_cuda: + torch.cuda.empty_cache() + _sync_device(device) + + states, prepare_mem = stage_run(device, lambda: [prepare_request_state(tts, spec) for spec in specs]) + request_state_summary = summarise_state_tensors(states) + + active_batch, prefill_batch_mem = stage_run(device, lambda: build_prefill_batch(model, states)) + prefill_batch_tensor_summary = summarise_prefill_batch(active_batch) + + prefill_result, prefill_step_mem = stage_run(device, lambda: run_prefill_step(model, states, max_steps=args.max_steps)) + running_requests, finished_items = prefill_result + running_requests_summary = summarise_running_requests(running_requests) + finished_after_prefill_summary = [ + { + "request_id": item.request_id, + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + } + for item in finished_items + ] + + if not running_requests: + raise RuntimeError(f"prefill 后没有 running requests,全部在首步结束: {[item.request_id for item in finished_items]}") + + decode_batch_result, decode_batch_mem = stage_run( + device, + lambda: _build_decode_batch_from_running(model, running_requests), + ) + xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask = decode_batch_result + decode_batch_tensor_summary = summarise_decode_batch( + xy_pos, + batched_k_cache, + batched_v_cache, + batched_decode_attn_mask, + running_requests, + ) + + decode_out_result, decode_step_mem = stage_run( + device, + lambda: model.t2s_transformer.decode_next_token( + xy_pos, + batched_k_cache, + batched_v_cache, + batched_decode_attn_mask, + ), + ) + xy_dec, next_k_cache, next_v_cache = decode_out_result + decode_output_tensor_summary = summarise_decode_outputs(xy_dec, next_k_cache, next_v_cache) + del active_batch + del running_requests + del finished_items + del xy_pos + del batched_k_cache + del batched_v_cache + del batched_decode_attn_mask + del xy_dec + del next_k_cache + del next_v_cache + gc.collect() + if use_cuda: + _sync_device(device) + torch.cuda.empty_cache() + end_to_end_worker = simulate_worker_end_to_end( + tts=tts, + specs=specs, + max_steps=args.max_steps, + rounds=args.worker_rounds, + grad_mode=args.worker_grad_mode, + ) + live_cuda_tensors_after_worker = snapshot_live_cuda_tensors() + worker_inference_mode = None + if args.compare_worker_grad_modes: + gc.collect() + if use_cuda: + _sync_device(device) + torch.cuda.empty_cache() + worker_inference_mode = simulate_worker_end_to_end( + tts=tts, + specs=specs, + max_steps=args.max_steps, + rounds=args.worker_rounds, + grad_mode="inference_mode", + ) + + summary = { + "meta": { + "scenario": args.scenario if args.request_manifest is None else "manifest", + "seed": int(args.seed), + "device": str(device), + "dtype": str(next(model.parameters()).dtype), + "request_count": int(len(specs)), + "num_layers": int(model.num_layers), + "num_heads": int(model.num_head), + "model_dim": int(model.model_dim), + "model_weights_mb": bytes_to_mb(model_nbytes(model)), + }, + "loaded_module_weights": build_module_weight_summary(tts), + "requests": [ + { + "request_id": spec.request_id, + "ref_audio_path": str(spec.ref_audio_path), + "prompt_text": spec.prompt_text, + "text": spec.text, + } + for spec in specs + ], + "prepare_stage": { + "memory": prepare_mem, + "request_state": request_state_summary, + }, + "prefill_batch": { + "memory": prefill_batch_mem, + "tensor_bytes": prefill_batch_tensor_summary, + }, + "prefill_step": { + "memory": prefill_step_mem, + "running_requests": running_requests_summary, + "finished_after_prefill": finished_after_prefill_summary, + }, + "decode_batch": { + "memory": decode_batch_mem, + "tensor_bytes": decode_batch_tensor_summary, + }, + "decode_outputs": { + "memory": decode_step_mem, + "tensor_bytes": decode_output_tensor_summary, + }, + "end_to_end_worker": end_to_end_worker, + "live_cuda_tensors_after_worker": live_cuda_tensors_after_worker, + "end_to_end_worker_inference_mode": worker_inference_mode, + } + summary["top_rankings"] = top_rankings(summary) + + summary_path = args.output_dir / "t2s_memory_breakdown_summary.json" + summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8") + + print(json.dumps(summary["meta"], ensure_ascii=False, indent=2)) + print("[top_rankings]") + for item in summary["top_rankings"]: + print(f"- {item['name']}: {item['mb']:.3f} MB") + print("[worker_peak]") + print( + json.dumps( + { + "peak_label": summary["end_to_end_worker"]["timeline"]["peak_label"], + "peak_allocated_mb": summary["end_to_end_worker"]["timeline"]["peak_allocated_mb"], + "peak_reserved_mb": summary["end_to_end_worker"]["timeline"]["peak_reserved_mb"], + }, + ensure_ascii=False, + indent=2, + ) + ) + if worker_inference_mode is not None: + print("[worker_peak_inference_mode]") + print( + json.dumps( + { + "peak_label": worker_inference_mode["timeline"]["peak_label"], + "peak_allocated_mb": worker_inference_mode["timeline"]["peak_allocated_mb"], + "peak_reserved_mb": worker_inference_mode["timeline"]["peak_reserved_mb"], + }, + ensure_ascii=False, + indent=2, + ) + ) + print(f"[summary] {summary_path}") + + +if __name__ == "__main__": + main() From 845b181360b5c5f4a5ed827ca01cfd5dfd8f58b1 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Mon, 9 Mar 2026 05:19:28 +0800 Subject: [PATCH 6/9] Implement batch processing for BERT and reference semantic tasks in TTS. Introduce StageLimiter for managing concurrent processing and enhance the TTS class with new methods for handling audio and semantic extraction. Update profiling metrics for better performance tracking during inference. --- GPT_SoVITS/TTS_infer_pack/TTS.py | 261 ++++++++++++++--- GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py | 259 +++++++++++------ .../prepare_bert_batch_worker.py | 197 +++++++++++++ .../prepare_ref_semantic_batch_worker.py | 262 ++++++++++++++++++ GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py | 121 ++++++-- api_v3.py | 64 ++++- 6 files changed, 1024 insertions(+), 140 deletions(-) create mode 100644 GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py create mode 100644 GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 2fd0df35..9c259662 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -5,6 +5,7 @@ import random import sys import time import traceback +from concurrent.futures import ThreadPoolExecutor from copy import deepcopy import torchaudio @@ -33,7 +34,12 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer from tools.audio_sr import AP_BWE from tools.i18n.i18n import I18nAuto, scan_language_list from TTS_infer_pack.text_segmentation_method import splits -from TTS_infer_pack.TextPreprocessor import TextPreprocessor +from TTS_infer_pack.TextPreprocessor import TextPreprocessor, StageLimiter +from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker +from TTS_infer_pack.prepare_ref_semantic_batch_worker import ( + PrepareRefSemanticBatchWorker, + prepare_prompt_semantic_wav16k, +) from sv import SV resample_transform_dict = {} @@ -442,11 +448,56 @@ class TTS: "upsample_rate": None, "overlapped_len": None, } + self.prepare_bert_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_BERT_SLOTS", "1"))) + self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "2"))) + self.prepare_bert_batch_worker = None + self.prepare_ref_semantic_batch_worker = None + default_text_cpu_workers = 16 + self.prepare_text_cpu_workers = max( + 0, + int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_WORKERS", str(default_text_cpu_workers))), + ) + self.prepare_text_cpu_executor = None + if self.prepare_text_cpu_workers > 0: + self.prepare_text_cpu_executor = ThreadPoolExecutor( + max_workers=self.prepare_text_cpu_workers, + thread_name_prefix="prepare-text-cpu", + ) self._init_models() + if os.environ.get("GPTSOVITS_PREPARE_BERT_BATCHING", "1") != "0": + self.prepare_bert_batch_worker = PrepareBertBatchWorker( + bert_model=self.bert_model, + tokenizer=self.bert_tokenizer, + device=self.configs.device, + stage_limiter=self.prepare_bert_stage_limiter, + batch_window_ms=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_WINDOW_MS", "5")), + max_batch_items=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_ITEMS", "16")), + max_batch_tokens=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_TOKENS", "4096")), + ) + if os.environ.get("GPTSOVITS_PREPARE_REF_BATCHING", "0") != "0": + ref_max_batch_samples = os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_SAMPLES") + if ref_max_batch_samples is None: + ref_max_batch_samples = os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_FRAMES", "960000") + self.prepare_ref_semantic_batch_worker = PrepareRefSemanticBatchWorker( + ssl_model=self.cnhuhbert_model, + vits_model=self.vits_model, + device=self.configs.device, + is_half=self.configs.is_half, + zero_wav_samples=int(self.configs.sampling_rate * 0.3), + stage_limiter=self.prepare_ref_audio_stage_limiter, + batch_window_ms=int(os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_WINDOW_MS", "5")), + max_batch_items=int(os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_ITEMS", "8")), + max_batch_samples=int(ref_max_batch_samples), + ) + self.text_preprocessor: TextPreprocessor = TextPreprocessor( - self.bert_model, self.bert_tokenizer, self.configs.device + self.bert_model, + self.bert_tokenizer, + self.configs.device, + bert_stage_limiter=self.prepare_bert_stage_limiter, + bert_batch_worker=self.prepare_bert_batch_worker, ) self.prompt_cache: dict = { @@ -755,47 +806,52 @@ class TTS: Args: ref_audio_path: str, the path of the reference audio. """ - self._set_prompt_semantic(ref_audio_path) - self._set_ref_spec(ref_audio_path) + bundle = self.extract_ref_audio_bundle(ref_audio_path) + if self.prompt_cache["refer_spec"] in [[], None]: + self.prompt_cache["refer_spec"] = [bundle["refer_spec"]] + else: + self.prompt_cache["refer_spec"][0] = bundle["refer_spec"] + self.prompt_cache["prompt_semantic"] = bundle["prompt_semantic"] + self.prompt_cache["raw_audio"] = bundle["raw_audio"] + self.prompt_cache["raw_sr"] = bundle["raw_sr"] self._set_ref_audio_path(ref_audio_path) - def extract_prompt_semantic(self, ref_wav_path: str): - zero_wav = np.zeros( - int(self.configs.sampling_rate * 0.3), - dtype=np.float16 if self.configs.is_half else np.float32, - ) - with torch.no_grad(): - wav16k, sr = librosa.load(ref_wav_path, sr=16000) - if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000: - raise OSError(i18n("参考音频在3~10秒范围外,请更换!")) - wav16k = torch.from_numpy(wav16k) - zero_wav_torch = torch.from_numpy(zero_wav) - wav16k = wav16k.to(self.configs.device) - zero_wav_torch = zero_wav_torch.to(self.configs.device) - if self.configs.is_half: - wav16k = wav16k.half() - zero_wav_torch = zero_wav_torch.half() - - wav16k = torch.cat([wav16k, zero_wav_torch]) - hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose( - 1, 2 - ) # .float() - codes = self.vits_model.extract_latent(hubert_feature) - - prompt_semantic = codes[0, 0].to(self.configs.device) - return prompt_semantic - - def extract_ref_spec(self, ref_audio_path: str): + def _load_ref_audio_raw(self, ref_audio_path: str): raw_audio, raw_sr = torchaudio.load(ref_audio_path) - raw_audio = raw_audio.to(self.configs.device).float() + return raw_audio.float(), int(raw_sr) + + @torch.inference_mode() + def _extract_prompt_semantic_from_prepared_wav16k(self, wav16k: torch.Tensor): + wav16k = wav16k.to(self.configs.device) + if self.configs.is_half: + wav16k = wav16k.half() + hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) + codes = self.vits_model.extract_latent(hubert_feature) + return codes[0, 0].to(self.configs.device) + + @torch.inference_mode() + def _extract_prompt_semantic_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + wav16k = prepare_prompt_semantic_wav16k( + raw_audio=raw_audio, + raw_sr=raw_sr, + zero_wav_samples=int(self.configs.sampling_rate * 0.3), + ) + return self._extract_prompt_semantic_from_prepared_wav16k(wav16k) + + def extract_prompt_semantic(self, ref_wav_path: str): + raw_audio, raw_sr = self._load_ref_audio_raw(ref_wav_path) + return self._extract_prompt_semantic_from_raw(raw_audio, raw_sr) + + def _extract_ref_spec_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + raw_audio_device = raw_audio.to(self.configs.device).float() if raw_sr != self.configs.sampling_rate: - audio = raw_audio.to(self.configs.device) + audio = raw_audio_device if audio.shape[0] == 2: audio = audio.mean(0).unsqueeze(0) audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device) else: - audio = raw_audio.to(self.configs.device) + audio = raw_audio_device if audio.shape[0] == 2: audio = audio.mean(0).unsqueeze(0) @@ -820,8 +876,141 @@ class TTS: audio = None return spec, audio, raw_audio, raw_sr - def extract_text_features(self, text: str, language: str): - return self.text_preprocessor.segment_and_extract_feature_for_text(text, language, self.configs.version) + def extract_ref_spec(self, ref_audio_path: str): + raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path) + return self._extract_ref_spec_from_raw(raw_audio, raw_sr) + + def extract_ref_audio_bundle(self, ref_audio_path: str): + load_start = time.perf_counter() + raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path) + load_ms = (time.perf_counter() - load_start) * 1000.0 + if self.prepare_ref_semantic_batch_worker is None: + with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats: + prompt_semantic_start = time.perf_counter() + prompt_semantic = self._extract_prompt_semantic_from_raw(raw_audio, raw_sr) + prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0 + ref_spec_start = time.perf_counter() + refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2] + ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0 + audio_stage_wait_ms = float(limiter_stats["wait_ms"]) + audio_stage_slots = float(limiter_stats["slots"]) + audio_stage_inflight_peak = float(limiter_stats["peak_inflight"]) + prompt_semantic_profile = { + "prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]), + "prompt_semantic_cpu_prepare_ms": 0.0, + "prompt_semantic_forward_ms": prompt_semantic_ms, + "prompt_semantic_scatter_ms": 0.0, + "prompt_semantic_stage_slots": float(limiter_stats["slots"]), + "prompt_semantic_stage_inflight_peak": float(limiter_stats["peak_inflight"]), + "prompt_semantic_batch_size": 1.0, + "prompt_semantic_batch_samples": 0.0, + } + ref_spec_wait_ms = 0.0 + return { + "prompt_semantic": prompt_semantic, + "refer_spec": refer_spec, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + "audio_load_ms": load_ms, + "audio_stage_wait_ms": audio_stage_wait_ms, + "audio_stage_slots": audio_stage_slots, + "audio_stage_inflight_peak": audio_stage_inflight_peak, + "prompt_semantic_ms": prompt_semantic_ms, + "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_ms": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) + ), + "prompt_semantic_forward_ms": float( + prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0) + ), + "prompt_semantic_scatter_ms": float( + prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0) + ), + "prompt_semantic_stage_slots": float( + prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0) + ), + "prompt_semantic_stage_inflight_peak": float( + prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0) + ), + "prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)), + "prompt_semantic_batch_samples": float( + prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0) + ), + "ref_spec_wait_ms": ref_spec_wait_ms, + "ref_spec_ms": ref_spec_ms, + "bundle_total_ms": load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_ms, + }, + } + + prompt_semantic_profile = { + "prompt_semantic_wait_ms": 0.0, + "prompt_semantic_cpu_prepare_ms": 0.0, + "prompt_semantic_forward_ms": 0.0, + "prompt_semantic_scatter_ms": 0.0, + "prompt_semantic_stage_slots": 0.0, + "prompt_semantic_stage_inflight_peak": 0.0, + "prompt_semantic_batch_size": 1.0, + "prompt_semantic_batch_samples": 0.0, + } + if self.prepare_ref_semantic_batch_worker is not None: + prompt_semantic, worker_profile = self.prepare_ref_semantic_batch_worker.submit(raw_audio, raw_sr) + prompt_semantic_profile.update(worker_profile) + prompt_semantic_ms = ( + float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)) + ) + with self.prepare_ref_audio_stage_limiter.enter() as ref_spec_limiter_stats: + ref_spec_start = time.perf_counter() + refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2] + ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0 + audio_stage_wait_ms = float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)) + float( + ref_spec_limiter_stats["wait_ms"] + ) + audio_stage_slots = max( + float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + float(ref_spec_limiter_stats["slots"]), + ) + audio_stage_inflight_peak = max( + float(prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)), + float(ref_spec_limiter_stats["peak_inflight"]), + ) + return { + "prompt_semantic": prompt_semantic, + "refer_spec": refer_spec, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + "audio_load_ms": load_ms, + "audio_stage_wait_ms": audio_stage_wait_ms, + "audio_stage_slots": audio_stage_slots, + "audio_stage_inflight_peak": audio_stage_inflight_peak, + "prompt_semantic_ms": prompt_semantic_ms, + "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_ms": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) + ), + "prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)), + "prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)), + "prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + "prompt_semantic_stage_inflight_peak": float( + prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0) + ), + "prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)), + "prompt_semantic_batch_samples": float( + prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0) + ), + "ref_spec_wait_ms": float(ref_spec_limiter_stats["wait_ms"]), + "ref_spec_ms": ref_spec_ms, + "bundle_total_ms": load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_ms, + }, + } + + def extract_text_features(self, text: str, language: str, profile: dict | None = None): + return self.text_preprocessor.segment_and_extract_feature_for_text( + text, language, self.configs.version, profile=profile + ) def _set_ref_audio_path(self, ref_audio_path): self.prompt_cache["ref_audio_path"] = ref_audio_path diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py index 283e91c3..15b3c322 100644 --- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py +++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py @@ -1,6 +1,8 @@ import os import sys import threading +import time +from contextlib import contextmanager from tqdm import tqdm @@ -16,6 +18,7 @@ from text.cleaner import clean_text from text import cleaned_text_to_sequence from transformers import AutoModelForMaskedLM, AutoTokenizer from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method +from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker from tools.i18n.i18n import I18nAuto, scan_language_list @@ -49,12 +52,60 @@ def merge_short_text_in_array(texts: str, threshold: int) -> list: return result +class StageLimiter: + def __init__(self, slots: int): + self.slots = max(1, int(slots)) + self.semaphore = threading.BoundedSemaphore(self.slots) + self.lock = threading.Lock() + self.inflight = 0 + self.peak_inflight = 0 + + @contextmanager + def enter(self): + wait_start = time.perf_counter() + self.semaphore.acquire() + wait_ms = (time.perf_counter() - wait_start) * 1000.0 + with self.lock: + self.inflight += 1 + current_inflight = self.inflight + if current_inflight > self.peak_inflight: + self.peak_inflight = current_inflight + peak_inflight = self.peak_inflight + try: + yield { + "wait_ms": wait_ms, + "inflight": current_inflight, + "peak_inflight": peak_inflight, + "slots": self.slots, + } + finally: + with self.lock: + self.inflight = max(0, self.inflight - 1) + self.semaphore.release() + + def snapshot(self) -> Dict[str, int]: + with self.lock: + return { + "slots": self.slots, + "inflight": self.inflight, + "peak_inflight": self.peak_inflight, + } + + class TextPreprocessor: - def __init__(self, bert_model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, device: torch.device): + def __init__( + self, + bert_model: AutoModelForMaskedLM, + tokenizer: AutoTokenizer, + device: torch.device, + bert_stage_limiter: StageLimiter | None = None, + bert_batch_worker: PrepareBertBatchWorker | None = None, + ): self.bert_model = bert_model self.tokenizer = tokenizer self.device = device - self.bert_lock = threading.RLock() + self.bert_stage_limiter = bert_stage_limiter + self.bert_batch_worker = bert_batch_worker def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]: print(f"############ {i18n('切分文本')} ############") @@ -115,86 +166,136 @@ class TextPreprocessor: return texts def segment_and_extract_feature_for_text( - self, text: str, language: str, version: str = "v1" + self, text: str, language: str, version: str = "v1", profile: Dict | None = None ) -> Tuple[list, torch.Tensor, str]: - return self.get_phones_and_bert(text, language, version) + return self.get_phones_and_bert(text, language, version, profile=profile) - def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False): - with self.bert_lock: - text = re.sub(r' {2,}', ' ', text) - textlist = [] - langlist = [] - if language == "all_zh": - for tmp in LangSegmenter.getTexts(text,"zh"): + def _split_text_by_language(self, text: str, language: str) -> Tuple[List[str], List[str]]: + textlist = [] + langlist = [] + if language == "all_zh": + for tmp in LangSegmenter.getTexts(text, "zh"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_yue": + for tmp in LangSegmenter.getTexts(text, "zh"): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_ja": + for tmp in LangSegmenter.getTexts(text, "ja"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_ko": + for tmp in LangSegmenter.getTexts(text, "ko"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "en": + langlist.append("en") + textlist.append(text) + elif language == "auto": + for tmp in LangSegmenter.getTexts(text): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "auto_yue": + for tmp in LangSegmenter.getTexts(text): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + else: + for tmp in LangSegmenter.getTexts(text): + if langlist: + same_group = (tmp["lang"] == "en" and langlist[-1] == "en") or ( + tmp["lang"] != "en" and langlist[-1] != "en" + ) + if same_group: + textlist[-1] += tmp["text"] + continue + if tmp["lang"] == "en": langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "all_yue": - for tmp in LangSegmenter.getTexts(text,"zh"): - if tmp["lang"] == "zh": - tmp["lang"] = "yue" - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "all_ja": - for tmp in LangSegmenter.getTexts(text,"ja"): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "all_ko": - for tmp in LangSegmenter.getTexts(text,"ko"): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "en": - langlist.append("en") - textlist.append(text) - elif language == "auto": - for tmp in LangSegmenter.getTexts(text): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "auto_yue": - for tmp in LangSegmenter.getTexts(text): - if tmp["lang"] == "zh": - tmp["lang"] = "yue" - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - else: - for tmp in LangSegmenter.getTexts(text): - if langlist: - if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"): - textlist[-1] += tmp["text"] - continue - if tmp["lang"] == "en": - langlist.append(tmp["lang"]) - else: - # 因无法区别中日韩文汉字,以用户输入为准 - langlist.append(language) - textlist.append(tmp["text"]) - # print(textlist) - # print(langlist) - phones_list = [] - bert_list = [] - norm_text_list = [] - for i in range(len(textlist)): - lang = langlist[i] - phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version) - bert = self.get_bert_inf(phones, word2ph, norm_text, lang) - phones_list.append(phones) - norm_text_list.append(norm_text) - bert_list.append(bert) - bert = torch.cat(bert_list, dim=1) - phones = sum(phones_list, []) - norm_text = "".join(norm_text_list) + else: + langlist.append(language) + textlist.append(tmp["text"]) + return textlist, langlist - if not final and len(phones) < 6: - return self.get_phones_and_bert("." + text, language, version, final=True) + def get_phones_and_bert( + self, text: str, language: str, version: str, final: bool = False, profile: Dict | None = None + ): + text = re.sub(r' {2,}', ' ', text) + textlist, langlist = self._split_text_by_language(text, language) + phones_list = [] + bert_list = [] + norm_text_list = [] + for segment_text, segment_lang in zip(textlist, langlist): + phones, word2ph, norm_text = self.clean_text_inf(segment_text, segment_lang, version) + bert = self.get_bert_inf(phones, word2ph, norm_text, segment_lang, profile=profile) + phones_list.append(phones) + norm_text_list.append(norm_text) + bert_list.append(bert) + bert = torch.cat(bert_list, dim=1) + phones = sum(phones_list, []) + norm_text = "".join(norm_text_list) - return phones, bert, norm_text + if not final and len(phones) < 6: + return self.get_phones_and_bert("." + text, language, version, final=True, profile=profile) - def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor: - with torch.no_grad(): - inputs = self.tokenizer(text, return_tensors="pt") - for i in inputs: - inputs[i] = inputs[i].to(self.device) - res = self.bert_model(**inputs, output_hidden_states=True) - res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + return phones, bert, norm_text + + def _accumulate_profile(self, profile: Dict | None, key: str, value: float) -> None: + if profile is None: + return + profile[key] = float(profile.get(key, 0.0)) + float(value) + + def _update_profile_peak(self, profile: Dict | None, key: str, value: float) -> None: + if profile is None: + return + profile[key] = float(max(float(profile.get(key, 0.0)), float(value))) + + def get_bert_feature(self, text: str, word2ph: list, profile: Dict | None = None) -> torch.Tensor: + if self.bert_batch_worker is not None: + feature, worker_profile = self.bert_batch_worker.submit(text, word2ph) + self._accumulate_profile(profile, "bert_wait_ms", worker_profile.get("bert_wait_ms", 0.0)) + self._accumulate_profile(profile, "bert_forward_ms", worker_profile.get("bert_forward_ms", 0.0)) + self._accumulate_profile(profile, "bert_tokenize_ms", worker_profile.get("bert_tokenize_ms", 0.0)) + self._accumulate_profile(profile, "bert_scatter_ms", worker_profile.get("bert_scatter_ms", 0.0)) + self._accumulate_profile(profile, "bert_calls", worker_profile.get("bert_calls", 1.0)) + self._update_profile_peak( + profile, "bert_stage_inflight_peak", worker_profile.get("bert_stage_inflight_peak", 0.0) + ) + self._update_profile_peak(profile, "bert_batch_size_peak", worker_profile.get("bert_batch_size", 0.0)) + self._update_profile_peak(profile, "bert_batch_tokens_peak", worker_profile.get("bert_batch_tokens", 0.0)) + if profile is not None: + profile["bert_stage_slots"] = float(worker_profile.get("bert_stage_slots", 0.0)) + return feature + + limiter_stats = {"wait_ms": 0.0, "inflight": 1, "peak_inflight": 1, "slots": 0} + if self.bert_stage_limiter is None: + forward_start = time.perf_counter() + with torch.no_grad(): + inputs = self.tokenizer(text, return_tensors="pt") + for i in inputs: + inputs[i] = inputs[i].to(self.device) + res = self.bert_model(**inputs, output_hidden_states=True) + res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + else: + with self.bert_stage_limiter.enter() as limiter_stats: + forward_start = time.perf_counter() + with torch.no_grad(): + inputs = self.tokenizer(text, return_tensors="pt") + for i in inputs: + inputs[i] = inputs[i].to(self.device) + res = self.bert_model(**inputs, output_hidden_states=True) + res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + self._accumulate_profile(profile, "bert_wait_ms", limiter_stats["wait_ms"]) + self._accumulate_profile(profile, "bert_forward_ms", forward_ms) + self._accumulate_profile(profile, "bert_calls", 1.0) + self._update_profile_peak(profile, "bert_stage_inflight_peak", limiter_stats["peak_inflight"]) + if profile is not None: + profile["bert_stage_slots"] = float(limiter_stats["slots"]) assert len(word2ph) == len(text) phone_level_feature = [] for i in range(len(word2ph)): @@ -209,10 +310,10 @@ class TextPreprocessor: phones = cleaned_text_to_sequence(phones, version) return phones, word2ph, norm_text - def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str): + def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str, profile: Dict | None = None): language = language.replace("all_", "") if language == "zh": - feature = self.get_bert_feature(norm_text, word2ph).to(self.device) + feature = self.get_bert_feature(norm_text, word2ph, profile=profile).to(self.device) else: feature = torch.zeros( (1024, len(phones)), @@ -236,4 +337,4 @@ class TextPreprocessor: punctuations = "".join(re.escape(p) for p in punctuation) pattern = f"([{punctuations}])([{punctuations}])+" result = re.sub(pattern, r"\1", text) - return result \ No newline at end of file + return result diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py new file mode 100644 index 00000000..b1ede3d8 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py @@ -0,0 +1,197 @@ +import threading +import time +import uuid +from collections import deque +from dataclasses import dataclass, field +from typing import Deque, Dict, List, Tuple + +import torch + + +@dataclass +class BertFeatureTask: + norm_text: str + word2ph: List[int] + task_id: str = field(default_factory=lambda: uuid.uuid4().hex) + created_at: float = field(default_factory=time.perf_counter) + done_event: threading.Event = field(default_factory=threading.Event) + result_feature: torch.Tensor | None = None + error: Exception | None = None + profile: Dict[str, float] = field(default_factory=dict) + + +class PrepareBertBatchWorker: + def __init__( + self, + bert_model, + tokenizer, + device, + stage_limiter=None, + batch_window_ms: int = 5, + max_batch_items: int = 16, + max_batch_tokens: int = 4096, + ): + self.bert_model = bert_model + self.tokenizer = tokenizer + self.device = device + self.stage_limiter = stage_limiter + self.batch_window_s = max(0.0, float(batch_window_ms) / 1000.0) + self.max_batch_items = max(1, int(max_batch_items)) + self.max_batch_tokens = max(16, int(max_batch_tokens)) + + self.condition = threading.Condition() + self.pending_tasks: Deque[BertFeatureTask] = deque() + self.pending_peak = 0 + self.total_submitted = 0 + self.total_finished = 0 + self.total_batches = 0 + self.active_batch_size = 0 + self.active_batch_peak = 0 + self.worker_thread = threading.Thread(target=self._run_loop, name="prepare-bert-batch-worker", daemon=True) + self.worker_thread.start() + + def _estimate_task_tokens(self, task: BertFeatureTask) -> int: + return max(1, len(task.norm_text) + 2) + + def submit(self, norm_text: str, word2ph: List[int]) -> Tuple[torch.Tensor, Dict[str, float]]: + task = BertFeatureTask(norm_text=str(norm_text), word2ph=list(word2ph)) + with self.condition: + self.pending_tasks.append(task) + self.total_submitted += 1 + if len(self.pending_tasks) > self.pending_peak: + self.pending_peak = len(self.pending_tasks) + self.condition.notify_all() + task.done_event.wait() + if task.error is not None: + raise task.error + assert task.result_feature is not None + return task.result_feature, dict(task.profile) + + def snapshot(self) -> Dict[str, int]: + with self.condition: + return { + "pending": len(self.pending_tasks), + "pending_peak": self.pending_peak, + "total_submitted": self.total_submitted, + "total_finished": self.total_finished, + "total_batches": self.total_batches, + "active_batch_size": self.active_batch_size, + "active_batch_peak": self.active_batch_peak, + "batch_window_ms": int(self.batch_window_s * 1000.0), + "max_batch_items": self.max_batch_items, + "max_batch_tokens": self.max_batch_tokens, + } + + def _collect_batch(self) -> List[BertFeatureTask]: + with self.condition: + while not self.pending_tasks: + self.condition.wait() + + batch: List[BertFeatureTask] = [self.pending_tasks.popleft()] + batch_tokens = self._estimate_task_tokens(batch[0]) + deadline = time.perf_counter() + self.batch_window_s + + while len(batch) < self.max_batch_items: + remaining = deadline - time.perf_counter() + if remaining <= 0: + break + if not self.pending_tasks: + self.condition.wait(timeout=remaining) + continue + next_task = self.pending_tasks[0] + next_tokens = self._estimate_task_tokens(next_task) + if len(batch) >= self.max_batch_items or (batch_tokens + next_tokens) > self.max_batch_tokens: + break + batch.append(self.pending_tasks.popleft()) + batch_tokens += next_tokens + + self.active_batch_size = len(batch) + if self.active_batch_size > self.active_batch_peak: + self.active_batch_peak = self.active_batch_size + return batch + + def _finalize_batch(self, batch: List[BertFeatureTask]) -> None: + with self.condition: + self.active_batch_size = 0 + self.total_batches += 1 + self.total_finished += len(batch) + + def _run_batch(self, batch: List[BertFeatureTask]) -> None: + batch_started = time.perf_counter() + texts = [task.norm_text for task in batch] + batch_tokens = sum(self._estimate_task_tokens(task) for task in batch) + + limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0} + if self.stage_limiter is None: + tokenize_start = time.perf_counter() + inputs = self.tokenizer(texts, return_tensors="pt", padding=True) + tokenize_ms = (time.perf_counter() - tokenize_start) * 1000.0 + attention_mask_cpu = inputs["attention_mask"].cpu() + for key in inputs: + inputs[key] = inputs[key].to(self.device) + forward_start = time.perf_counter() + with torch.no_grad(): + outputs = self.bert_model(**inputs, output_hidden_states=True) + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + else: + with self.stage_limiter.enter() as limiter_stats: + tokenize_start = time.perf_counter() + inputs = self.tokenizer(texts, return_tensors="pt", padding=True) + tokenize_ms = (time.perf_counter() - tokenize_start) * 1000.0 + attention_mask_cpu = inputs["attention_mask"].cpu() + for key in inputs: + inputs[key] = inputs[key].to(self.device) + forward_start = time.perf_counter() + with torch.no_grad(): + outputs = self.bert_model(**inputs, output_hidden_states=True) + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + + hidden = outputs["hidden_states"][-3].detach().cpu() + scatter_start = time.perf_counter() + for batch_index, task in enumerate(batch): + try: + text_len = len(task.word2ph) + if text_len != len(task.norm_text): + raise AssertionError( + f"word2ph/text length mismatch: task={task.task_id} word2ph={text_len} text={len(task.norm_text)}" + ) + seq_len = int(attention_mask_cpu[batch_index].sum().item()) + char_features = hidden[batch_index, 1 : seq_len - 1] + if char_features.shape[0] != text_len: + raise AssertionError( + f"bert token length mismatch: task={task.task_id} token_len={char_features.shape[0]} text_len={text_len}" + ) + phone_level_feature = [] + for char_index, repeat_count in enumerate(task.word2ph): + phone_level_feature.append(char_features[char_index].repeat(repeat_count, 1)) + task.result_feature = torch.cat(phone_level_feature, dim=0).T + task.profile = { + "bert_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]), + "bert_forward_ms": float(forward_ms), + "bert_tokenize_ms": float(tokenize_ms), + "bert_scatter_ms": 0.0, + "bert_calls": 1.0, + "bert_stage_slots": float(limiter_stats["slots"]), + "bert_stage_inflight_peak": float(limiter_stats["peak_inflight"]), + "bert_batch_size": float(len(batch)), + "bert_batch_tokens": float(batch_tokens), + } + except Exception as exc: # noqa: PERF203 + task.error = exc + scatter_ms = (time.perf_counter() - scatter_start) * 1000.0 + for task in batch: + if task.result_feature is not None: + task.profile["bert_scatter_ms"] = float(scatter_ms) + task.done_event.set() + + def _run_loop(self) -> None: + while True: + batch = self._collect_batch() + try: + self._run_batch(batch) + except Exception as exc: # noqa: PERF203 + for task in batch: + task.error = exc + task.done_event.set() + finally: + self._finalize_batch(batch) diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py new file mode 100644 index 00000000..7a1f9a53 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py @@ -0,0 +1,262 @@ +import threading +import time +import uuid +from collections import deque +from dataclasses import dataclass, field +from typing import Deque, Dict, List, Tuple + +import librosa +import numpy as np +import torch + + +REF_AUDIO_MIN_SAMPLES_16K = 48000 +REF_AUDIO_MAX_SAMPLES_16K = 160000 + + +def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wav_samples: int) -> torch.Tensor: + wav_mono = raw_audio + if wav_mono.dim() == 2 and wav_mono.shape[0] != 1: + wav_mono = wav_mono.mean(0, keepdim=True) + wav16k = wav_mono.squeeze(0).cpu().numpy() + if raw_sr != 16000: + wav16k = librosa.resample(wav16k, orig_sr=raw_sr, target_sr=16000) + if wav16k.shape[0] > REF_AUDIO_MAX_SAMPLES_16K or wav16k.shape[0] < REF_AUDIO_MIN_SAMPLES_16K: + raise OSError("参考音频在3~10秒范围外,请更换!") + wav16k = np.ascontiguousarray(wav16k, dtype=np.float32) + if zero_wav_samples > 0: + wav16k = np.concatenate([wav16k, np.zeros(int(zero_wav_samples), dtype=np.float32)], axis=0) + return torch.from_numpy(wav16k) + + +def conv1d_output_lengths(input_lengths: torch.Tensor, conv1d: torch.nn.Conv1d | None) -> torch.Tensor: + if conv1d is None: + return input_lengths.to(dtype=torch.long) + kernel_size = int(conv1d.kernel_size[0]) + stride = int(conv1d.stride[0]) + padding = int(conv1d.padding[0]) + dilation = int(conv1d.dilation[0]) + output_lengths = torch.div( + input_lengths + 2 * padding - dilation * (kernel_size - 1) - 1, + stride, + rounding_mode="floor", + ) + 1 + return torch.clamp(output_lengths, min=0).to(dtype=torch.long) + + +@dataclass +class RefSemanticTask: + raw_audio: torch.Tensor + raw_sr: int + task_id: str = field(default_factory=lambda: uuid.uuid4().hex) + created_at: float = field(default_factory=time.perf_counter) + done_event: threading.Event = field(default_factory=threading.Event) + result_prompt_semantic: torch.Tensor | None = None + error: Exception | None = None + profile: Dict[str, float] = field(default_factory=dict) + + +class PrepareRefSemanticBatchWorker: + def __init__( + self, + ssl_model, + vits_model, + device, + is_half: bool, + zero_wav_samples: int, + stage_limiter=None, + batch_window_ms: int = 5, + max_batch_items: int = 8, + max_batch_samples: int = 960000, + ): + self.ssl_model = ssl_model + self.vits_model = vits_model + self.device = device + self.is_half = bool(is_half) + self.zero_wav_samples = max(0, int(zero_wav_samples)) + self.stage_limiter = stage_limiter + self.batch_window_s = max(0.0, float(batch_window_ms) / 1000.0) + self.max_batch_items = max(1, int(max_batch_items)) + self.max_batch_samples = max(REF_AUDIO_MIN_SAMPLES_16K + self.zero_wav_samples, int(max_batch_samples)) + + self.condition = threading.Condition() + self.pending_tasks: Deque[RefSemanticTask] = deque() + self.pending_peak = 0 + self.total_submitted = 0 + self.total_finished = 0 + self.total_batches = 0 + self.active_batch_size = 0 + self.active_batch_peak = 0 + self.active_batch_samples = 0 + self.active_batch_samples_peak = 0 + self.worker_thread = threading.Thread( + target=self._run_loop, + name="prepare-ref-semantic-batch-worker", + daemon=True, + ) + self.worker_thread.start() + + def _estimate_task_samples(self, task: RefSemanticTask) -> int: + raw_len = int(task.raw_audio.shape[-1]) if task.raw_audio.dim() > 0 else 0 + base = int(round(raw_len * 16000.0 / max(1, int(task.raw_sr)))) + return max(REF_AUDIO_MIN_SAMPLES_16K, base) + self.zero_wav_samples + + def submit(self, raw_audio: torch.Tensor, raw_sr: int) -> Tuple[torch.Tensor, Dict[str, float]]: + task = RefSemanticTask(raw_audio=raw_audio, raw_sr=int(raw_sr)) + with self.condition: + self.pending_tasks.append(task) + self.total_submitted += 1 + if len(self.pending_tasks) > self.pending_peak: + self.pending_peak = len(self.pending_tasks) + self.condition.notify_all() + task.done_event.wait() + if task.error is not None: + raise task.error + assert task.result_prompt_semantic is not None + return task.result_prompt_semantic, dict(task.profile) + + def snapshot(self) -> Dict[str, int]: + with self.condition: + return { + "pending": len(self.pending_tasks), + "pending_peak": self.pending_peak, + "total_submitted": self.total_submitted, + "total_finished": self.total_finished, + "total_batches": self.total_batches, + "active_batch_size": self.active_batch_size, + "active_batch_peak": self.active_batch_peak, + "active_batch_samples": self.active_batch_samples, + "active_batch_samples_peak": self.active_batch_samples_peak, + "batch_window_ms": int(self.batch_window_s * 1000.0), + "max_batch_items": self.max_batch_items, + "max_batch_samples": self.max_batch_samples, + } + + def _collect_batch(self) -> List[RefSemanticTask]: + with self.condition: + while not self.pending_tasks: + self.condition.wait() + + batch: List[RefSemanticTask] = [self.pending_tasks.popleft()] + batch_samples = self._estimate_task_samples(batch[0]) + deadline = time.perf_counter() + self.batch_window_s + + while len(batch) < self.max_batch_items: + remaining = deadline - time.perf_counter() + if remaining <= 0: + break + if not self.pending_tasks: + self.condition.wait(timeout=remaining) + continue + next_task = self.pending_tasks[0] + next_samples = self._estimate_task_samples(next_task) + if len(batch) >= self.max_batch_items or (batch_samples + next_samples) > self.max_batch_samples: + break + batch.append(self.pending_tasks.popleft()) + batch_samples += next_samples + + self.active_batch_size = len(batch) + self.active_batch_samples = batch_samples + if self.active_batch_size > self.active_batch_peak: + self.active_batch_peak = self.active_batch_size + if self.active_batch_samples > self.active_batch_samples_peak: + self.active_batch_samples_peak = self.active_batch_samples + return batch + + def _finalize_batch(self, batch: List[RefSemanticTask]) -> None: + with self.condition: + self.active_batch_size = 0 + self.active_batch_samples = 0 + self.total_batches += 1 + self.total_finished += len(batch) + + def _get_hidden_lengths(self, attention_mask: torch.Tensor, hidden_length: int) -> torch.Tensor: + model = self.ssl_model.model + if hasattr(model, "_get_feature_vector_attention_mask"): + feature_mask = model._get_feature_vector_attention_mask(hidden_length, attention_mask) + return feature_mask.to(dtype=torch.long).sum(dim=1) + raw_lengths = attention_mask.to(dtype=torch.long).sum(dim=1) + if hasattr(model, "_get_feat_extract_output_lengths"): + return model._get_feat_extract_output_lengths(raw_lengths).to(dtype=torch.long) + return torch.full((attention_mask.shape[0],), int(hidden_length), dtype=torch.long, device=attention_mask.device) + + @torch.inference_mode() + def _run_batch(self, batch: List[RefSemanticTask]) -> None: + batch_started = time.perf_counter() + prepared_start = time.perf_counter() + prepared_wavs = [ + prepare_prompt_semantic_wav16k(task.raw_audio, int(task.raw_sr), self.zero_wav_samples) for task in batch + ] + cpu_prepare_ms = (time.perf_counter() - prepared_start) * 1000.0 + wav_lengths = torch.tensor([int(wav.shape[0]) for wav in prepared_wavs], dtype=torch.long) + batch_samples = int(wav_lengths.sum().item()) + max_wav_len = int(wav_lengths.max().item()) + + input_values_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.float32) + attention_mask_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.long) + for batch_index, wav in enumerate(prepared_wavs): + wav_len = int(wav.shape[0]) + input_values_cpu[batch_index, :wav_len] = wav + attention_mask_cpu[batch_index, :wav_len] = 1 + + limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0} + if self.stage_limiter is None: + input_values = input_values_cpu.to(self.device) + attention_mask = attention_mask_cpu.to(self.device) + if self.is_half: + input_values = input_values.half() + forward_start = time.perf_counter() + outputs = self.ssl_model.model(input_values, attention_mask=attention_mask) + hubert_feature = outputs["last_hidden_state"].transpose(1, 2) + hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1])) + codes = self.vits_model.extract_latent(hubert_feature) + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + else: + with self.stage_limiter.enter() as limiter_stats: + input_values = input_values_cpu.to(self.device) + attention_mask = attention_mask_cpu.to(self.device) + if self.is_half: + input_values = input_values.half() + forward_start = time.perf_counter() + outputs = self.ssl_model.model(input_values, attention_mask=attention_mask) + hubert_feature = outputs["last_hidden_state"].transpose(1, 2) + hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1])) + codes = self.vits_model.extract_latent(hubert_feature) + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + + code_lengths = conv1d_output_lengths(hidden_lengths.detach().cpu(), getattr(self.vits_model, "ssl_proj", None)) + scatter_start = time.perf_counter() + for batch_index, task in enumerate(batch): + try: + code_len = int(code_lengths[batch_index].item()) + task.result_prompt_semantic = codes[batch_index, 0, :code_len].detach().clone() + task.profile = { + "prompt_semantic_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]), + "prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms), + "prompt_semantic_forward_ms": float(forward_ms), + "prompt_semantic_scatter_ms": 0.0, + "prompt_semantic_calls": 1.0, + "prompt_semantic_stage_slots": float(limiter_stats["slots"]), + "prompt_semantic_stage_inflight_peak": float(limiter_stats["peak_inflight"]), + "prompt_semantic_batch_size": float(len(batch)), + "prompt_semantic_batch_samples": float(batch_samples), + } + except Exception as exc: # noqa: PERF203 + task.error = exc + scatter_ms = (time.perf_counter() - scatter_start) * 1000.0 + for task in batch: + if task.result_prompt_semantic is not None: + task.profile["prompt_semantic_scatter_ms"] = float(scatter_ms) + task.done_event.set() + + def _run_loop(self) -> None: + while True: + batch = self._collect_batch() + try: + self._run_batch(batch) + except Exception as exc: # noqa: PERF203 + for task in batch: + task.error = exc + task.done_event.set() + finally: + self._finalize_batch(batch) diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py index c8643991..de498573 100644 --- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -1,5 +1,6 @@ from __future__ import annotations +from concurrent.futures import Future from dataclasses import dataclass from pathlib import Path import time @@ -123,31 +124,58 @@ def prepare_request_state( prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) text = spec.text.strip("\n") - _sync_device(device) - prompt_text_features_start = time.perf_counter() - prompt_phones, prompt_bert_features, prompt_norm_text = tts.extract_text_features(prompt_text, spec.prompt_lang) - _sync_device(device) - prompt_text_features_ms = (time.perf_counter() - prompt_text_features_start) * 1000.0 + prompt_text_profile: Dict[str, float] = {} + text_features_profile: Dict[str, float] = {} + text_feature_pair_start = time.perf_counter() + prompt_future: Future | None = None + + def _extract_prompt_features(): + _sync_device(device) + prompt_start = time.perf_counter() + result = tts.extract_text_features(prompt_text, spec.prompt_lang, profile=prompt_text_profile) + _sync_device(device) + return result, (time.perf_counter() - prompt_start) * 1000.0 + + if getattr(tts, "prepare_text_cpu_executor", None) is not None: + prompt_future = tts.prepare_text_cpu_executor.submit(_extract_prompt_features) _sync_device(device) text_features_start = time.perf_counter() - phones, bert_features, norm_text = tts.extract_text_features(text, spec.text_lang) + phones, bert_features, norm_text = tts.extract_text_features(text, spec.text_lang, profile=text_features_profile) _sync_device(device) text_features_ms = (time.perf_counter() - text_features_start) * 1000.0 + + if prompt_future is None: + _sync_device(device) + prompt_text_features_start = time.perf_counter() + prompt_phones, prompt_bert_features, prompt_norm_text = tts.extract_text_features( + prompt_text, spec.prompt_lang, profile=prompt_text_profile + ) + _sync_device(device) + prompt_text_features_ms = (time.perf_counter() - prompt_text_features_start) * 1000.0 + prompt_text_profile["parallel_future_wait_ms"] = 0.0 + else: + prompt_wait_start = time.perf_counter() + (prompt_phones, prompt_bert_features, prompt_norm_text), prompt_text_features_ms = prompt_future.result() + prompt_text_profile["parallel_future_wait_ms"] = (time.perf_counter() - prompt_wait_start) * 1000.0 + + text_feature_pair_ms = (time.perf_counter() - text_feature_pair_start) * 1000.0 if phones is None: raise ValueError(f"{spec.request_id} text preprocessing returned no phones") _sync_device(device) - prompt_semantic_start = time.perf_counter() - prompt_semantic = tts.extract_prompt_semantic(str(spec.ref_audio_path)).long() + ref_audio_bundle_start = time.perf_counter() + ref_audio_bundle = tts.extract_ref_audio_bundle(str(spec.ref_audio_path)) + prompt_semantic = ref_audio_bundle["prompt_semantic"].long() + spec_audio, audio_16k = ref_audio_bundle["refer_spec"] + raw_audio = ref_audio_bundle["raw_audio"] + raw_sr = int(ref_audio_bundle["raw_sr"]) _sync_device(device) - prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0 - - _sync_device(device) - ref_spec_start = time.perf_counter() - spec_audio, audio_16k, raw_audio, raw_sr = tts.extract_ref_spec(str(spec.ref_audio_path)) - _sync_device(device) - ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0 + ref_audio_bundle_ms = (time.perf_counter() - ref_audio_bundle_start) * 1000.0 + bundle_profile = ref_audio_bundle.get("profile", {}) + prompt_semantic_ms = float(bundle_profile.get("prompt_semantic_ms", ref_audio_bundle_ms)) + ref_spec_ms = float(bundle_profile.get("ref_spec_ms", 0.0)) + audio_load_ms = float(bundle_profile.get("audio_load_ms", 0.0)) _sync_device(device) tensorize_start = time.perf_counter() @@ -164,8 +192,43 @@ def prepare_request_state( prepare_profile = { "prompt_text_features_ms": prompt_text_features_ms, "text_features_ms": text_features_ms, + "prompt_text_bert_wait_ms": float(prompt_text_profile.get("bert_wait_ms", 0.0)), + "prompt_text_bert_forward_ms": float(prompt_text_profile.get("bert_forward_ms", 0.0)), + "prompt_text_bert_tokenize_ms": float(prompt_text_profile.get("bert_tokenize_ms", 0.0)), + "prompt_text_bert_scatter_ms": float(prompt_text_profile.get("bert_scatter_ms", 0.0)), + "prompt_text_bert_calls": float(prompt_text_profile.get("bert_calls", 0.0)), + "prompt_text_bert_stage_slots": float(prompt_text_profile.get("bert_stage_slots", 0.0)), + "prompt_text_bert_stage_inflight_peak": float(prompt_text_profile.get("bert_stage_inflight_peak", 0.0)), + "prompt_text_bert_batch_size_peak": float(prompt_text_profile.get("bert_batch_size_peak", 0.0)), + "prompt_text_bert_batch_tokens_peak": float(prompt_text_profile.get("bert_batch_tokens_peak", 0.0)), + "prompt_text_parallel_future_wait_ms": float(prompt_text_profile.get("parallel_future_wait_ms", 0.0)), + "text_bert_wait_ms": float(text_features_profile.get("bert_wait_ms", 0.0)), + "text_bert_forward_ms": float(text_features_profile.get("bert_forward_ms", 0.0)), + "text_bert_tokenize_ms": float(text_features_profile.get("bert_tokenize_ms", 0.0)), + "text_bert_scatter_ms": float(text_features_profile.get("bert_scatter_ms", 0.0)), + "text_bert_calls": float(text_features_profile.get("bert_calls", 0.0)), + "text_bert_stage_slots": float(text_features_profile.get("bert_stage_slots", 0.0)), + "text_bert_stage_inflight_peak": float(text_features_profile.get("bert_stage_inflight_peak", 0.0)), + "text_bert_batch_size_peak": float(text_features_profile.get("bert_batch_size_peak", 0.0)), + "text_bert_batch_tokens_peak": float(text_features_profile.get("bert_batch_tokens_peak", 0.0)), + "text_feature_pair_ms": text_feature_pair_ms, + "text_cpu_parallel_workers": float(getattr(tts, "prepare_text_cpu_workers", 0)), + "audio_load_ms": audio_load_ms, + "audio_stage_wait_ms": float(bundle_profile.get("audio_stage_wait_ms", 0.0)), + "audio_stage_slots": float(bundle_profile.get("audio_stage_slots", 0.0)), + "audio_stage_inflight_peak": float(bundle_profile.get("audio_stage_inflight_peak", 0.0)), "prompt_semantic_ms": prompt_semantic_ms, + "prompt_semantic_wait_ms": float(bundle_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), + "prompt_semantic_forward_ms": float(bundle_profile.get("prompt_semantic_forward_ms", 0.0)), + "prompt_semantic_scatter_ms": float(bundle_profile.get("prompt_semantic_scatter_ms", 0.0)), + "prompt_semantic_stage_slots": float(bundle_profile.get("prompt_semantic_stage_slots", 0.0)), + "prompt_semantic_stage_inflight_peak": float(bundle_profile.get("prompt_semantic_stage_inflight_peak", 0.0)), + "prompt_semantic_batch_size": float(bundle_profile.get("prompt_semantic_batch_size", 0.0)), + "prompt_semantic_batch_samples": float(bundle_profile.get("prompt_semantic_batch_samples", 0.0)), + "ref_spec_wait_ms": float(bundle_profile.get("ref_spec_wait_ms", 0.0)), "ref_spec_ms": ref_spec_ms, + "ref_audio_bundle_ms": ref_audio_bundle_ms, "tensorize_ms": tensorize_ms, "total_ms": (time.perf_counter() - prepare_sync_start) * 1000.0, "wall_total_ms": (time.perf_counter() - prepare_start) * 1000.0, @@ -186,7 +249,7 @@ def prepare_request_state( prompt_semantic=prompt_semantic, refer_spec=(spec_audio, audio_16k), raw_audio=raw_audio, - raw_sr=int(raw_sr), + raw_sr=raw_sr, top_k=spec.top_k, top_p=spec.top_p, temperature=spec.temperature, @@ -409,9 +472,18 @@ def _pad_decode_mask_left(mask: torch.Tensor, target_len: int) -> torch.Tensor: return F.pad(mask, (pad_len, 0), value=True) +def _fit_decode_mask_length(mask: torch.Tensor, target_len: int) -> torch.Tensor: + if mask.shape[-1] > target_len: + return mask[:, :, :, -target_len:] + if mask.shape[-1] < target_len: + return _pad_decode_mask_left(mask, target_len) + return mask + + def _materialize_decode_mask_for_request(running_request: T2SRunningRequest) -> torch.Tensor: + expected_mask_len = running_request.k_cache[0].shape[1] + 1 if running_request.decode_attn_mask is not None: - return running_request.decode_attn_mask + return _fit_decode_mask_length(running_request.decode_attn_mask, expected_mask_len) current_mask_len = running_request.k_cache[0].shape[1] + 1 return torch.zeros( (1, 1, 1, current_mask_len), @@ -481,17 +553,19 @@ def run_prefill_step( real_kv_len = int(active_batch.x_lens[batch_index].item()) + prefix_len request_k_cache = [layer[batch_index : batch_index + 1, -real_kv_len:, :].clone() for layer in k_cache] request_v_cache = [layer[batch_index : batch_index + 1, -real_kv_len:, :].clone() for layer in v_cache] + request_decode_attn_mask = None + if decode_attn_mask is not None: + request_decode_attn_mask = decode_attn_mask[batch_index : batch_index + 1].clone() + request_decode_attn_mask = _fit_decode_mask_length(request_decode_attn_mask, real_kv_len + 1) + if not request_decode_attn_mask.any().item(): + request_decode_attn_mask = None running_requests.append( T2SRunningRequest( state=state, y_sequence=new_history, prefix_len=prefix_len, - decode_attn_mask=( - None - if decode_attn_mask is None - else decode_attn_mask[batch_index : batch_index + 1].clone() - ), + decode_attn_mask=request_decode_attn_mask, k_cache=request_k_cache, v_cache=request_v_cache, step_idx=1, @@ -603,6 +677,9 @@ def run_decode_step_for_running( batch_index : batch_index + 1, :, :, -current_decode_mask_len: ] next_decode_attn_mask = F.pad(current_decode_attn_mask, (0, 1), value=False) + next_decode_attn_mask = _fit_decode_mask_length(next_decode_attn_mask, real_next_kv_len + 1) + if not next_decode_attn_mask.any().item(): + next_decode_attn_mask = None next_running.append( T2SRunningRequest( state=running_request.state, diff --git a/api_v3.py b/api_v3.py index 9d250119..92f9a3b9 100644 --- a/api_v3.py +++ b/api_v3.py @@ -261,8 +261,9 @@ class SchedulerDebugWorker: self.tts = tts self.max_steps = max_steps self.micro_batch_wait_s = micro_batch_wait_ms / 1000.0 - self.prepare_lock = threading.Lock() self.condition = threading.Condition() + self.prepare_inflight = 0 + self.prepare_peak_inflight = 0 self.pending_jobs: List[SchedulerPendingJob] = [] self.running_requests: List[T2SRunningRequest] = [] self.job_map: dict[str, SchedulerPendingJob] = {} @@ -282,8 +283,20 @@ class SchedulerDebugWorker: pass def prepare_state(self, spec: SchedulerRequestSpec) -> T2SRequestState: - with self.prepare_lock: - return prepare_request_state(self.tts, spec) + with self.condition: + self.prepare_inflight += 1 + prepare_inflight_on_enter = self.prepare_inflight + if self.prepare_inflight > self.prepare_peak_inflight: + self.prepare_peak_inflight = self.prepare_inflight + prepare_peak_inflight = self.prepare_peak_inflight + try: + state = prepare_request_state(self.tts, spec) + state.prepare_profile["worker_prepare_inflight_on_enter"] = float(prepare_inflight_on_enter) + state.prepare_profile["worker_prepare_peak_inflight"] = float(prepare_peak_inflight) + return state + finally: + with self.condition: + self.prepare_inflight = max(0, self.prepare_inflight - 1) def submit( self, @@ -363,9 +376,28 @@ class SchedulerDebugWorker: def get_state(self) -> dict: with self.condition: + bert_stage = self.tts.prepare_bert_stage_limiter.snapshot() + ref_audio_stage = self.tts.prepare_ref_audio_stage_limiter.snapshot() + bert_batch_worker = ( + None + if self.tts.prepare_bert_batch_worker is None + else self.tts.prepare_bert_batch_worker.snapshot() + ) + ref_semantic_batch_worker = ( + None + if self.tts.prepare_ref_semantic_batch_worker is None + else self.tts.prepare_ref_semantic_batch_worker.snapshot() + ) return { "pending_jobs": len(self.pending_jobs), "running_requests": len(self.running_requests), + "prepare_inflight": self.prepare_inflight, + "prepare_peak_inflight": self.prepare_peak_inflight, + "prepare_text_cpu_workers": int(getattr(self.tts, "prepare_text_cpu_workers", 0)), + "prepare_bert_stage": bert_stage, + "prepare_bert_batch_worker": bert_batch_worker, + "prepare_ref_audio_stage": ref_audio_stage, + "prepare_ref_semantic_batch_worker": ref_semantic_batch_worker, "tracked_jobs": len(self.job_map), "total_submitted": self.total_submitted, "total_finished": self.total_finished, @@ -907,10 +939,36 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): { "X-Prepare-Prompt-Text-Ms": f"{float(prepare_profile.get('prompt_text_features_ms', 0.0)):.3f}", "X-Prepare-Target-Text-Ms": f"{float(prepare_profile.get('text_features_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_wait_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Wait-Ms": f"{float(prepare_profile.get('text_bert_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Forward-Ms": f"{float(prepare_profile.get('prompt_text_bert_forward_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Forward-Ms": f"{float(prepare_profile.get('text_bert_forward_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Batch-Size-Peak": str( + int(prepare_profile.get("prompt_text_bert_batch_size_peak", 0.0)) + ), + "X-Prepare-Target-Bert-Batch-Size-Peak": str( + int(prepare_profile.get("text_bert_batch_size_peak", 0.0)) + ), + "X-Prepare-Text-Pair-Wall-Ms": f"{float(prepare_profile.get('text_feature_pair_ms', 0.0)):.3f}", + "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))), + "X-Prepare-Audio-Load-Ms": f"{float(prepare_profile.get('audio_load_ms', 0.0)):.3f}", + "X-Prepare-Audio-Stage-Wait-Ms": f"{float(prepare_profile.get('audio_stage_wait_ms', 0.0)):.3f}", "X-Prepare-Prompt-Semantic-Ms": f"{float(prepare_profile.get('prompt_semantic_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-Wait-Ms": f"{float(prepare_profile.get('prompt_semantic_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-CPU-Ms": f"{float(prepare_profile.get('prompt_semantic_cpu_prepare_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-Forward-Ms": f"{float(prepare_profile.get('prompt_semantic_forward_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-Batch-Size": str( + int(prepare_profile.get("prompt_semantic_batch_size", 0.0)) + ), "X-Prepare-Ref-Spec-Ms": f"{float(prepare_profile.get('ref_spec_ms', 0.0)):.3f}", + "X-Prepare-Ref-Spec-Wait-Ms": f"{float(prepare_profile.get('ref_spec_wait_ms', 0.0)):.3f}", + "X-Prepare-Ref-Bundle-Ms": f"{float(prepare_profile.get('ref_audio_bundle_ms', 0.0)):.3f}", "X-Prepare-Tensorize-Ms": f"{float(prepare_profile.get('tensorize_ms', 0.0)):.3f}", "X-Prepare-Profile-Wall-Ms": f"{float(prepare_profile.get('wall_total_ms', 0.0)):.3f}", + "X-Prepare-Inflight-On-Enter": str( + int(prepare_profile.get("worker_prepare_inflight_on_enter", 0.0)) + ), + "X-Prepare-Inflight-Peak": str(int(prepare_profile.get("worker_prepare_peak_inflight", 0.0))), } ) return Response(audio_data, media_type=f"audio/{job.media_type}", headers=headers) From a45e171ff50a372f1ac1ce567fd95659050b453b Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Mon, 9 Mar 2026 21:24:16 +0800 Subject: [PATCH 7/9] Enhance sampling functions in TTS by adding support for previous token masks in logits_to_probs. Implement batch processing for sampling with padded token sequences and contiguous sampling groups. Refactor sampling logic in T2S scheduler to utilize new functionalities, improving efficiency and flexibility in token generation. --- GPT_SoVITS/AR/models/utils.py | 37 ++++- GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py | 170 ++++++++++++++++----- 2 files changed, 165 insertions(+), 42 deletions(-) diff --git a/GPT_SoVITS/AR/models/utils.py b/GPT_SoVITS/AR/models/utils.py index cc4f24d8..4b564ed8 100644 --- a/GPT_SoVITS/AR/models/utils.py +++ b/GPT_SoVITS/AR/models/utils.py @@ -147,6 +147,7 @@ def multinomial_sample_one_no_sync( def logits_to_probs( logits, previous_tokens: Optional[torch.Tensor] = None, + previous_token_mask: Optional[torch.Tensor] = None, temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[int] = None, @@ -158,13 +159,27 @@ def logits_to_probs( # pdb.set_trace() if previous_tokens is not None and repetition_penalty != 1.0: previous_tokens = previous_tokens.long() - score = torch.gather(logits, dim=1, index=previous_tokens) - score = torch.where( - score < 0, - score * repetition_penalty, - score / repetition_penalty, - ) - logits.scatter_(dim=1, index=previous_tokens, src=score) + if previous_token_mask is None: + score = torch.gather(logits, dim=1, index=previous_tokens) + score = torch.where( + score < 0, + score * repetition_penalty, + score / repetition_penalty, + ) + logits.scatter_(dim=1, index=previous_tokens, src=score) + else: + previous_token_mask = previous_token_mask.to(dtype=torch.bool, device=logits.device) + if previous_token_mask.any(): + batch_index = torch.arange(logits.size(0), device=logits.device).unsqueeze(1).expand_as(previous_tokens) + valid_batch_index = batch_index[previous_token_mask] + valid_token_index = previous_tokens[previous_token_mask] + score = logits[valid_batch_index, valid_token_index] + score = torch.where( + score < 0, + score * repetition_penalty, + score / repetition_penalty, + ) + logits[valid_batch_index, valid_token_index] = score if top_p is not None and top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=True) @@ -192,9 +207,15 @@ def logits_to_probs( def sample( logits, previous_tokens: Optional[torch.Tensor] = None, + previous_token_mask: Optional[torch.Tensor] = None, **sampling_kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: - probs = logits_to_probs(logits=logits, previous_tokens=previous_tokens, **sampling_kwargs) + probs = logits_to_probs( + logits=logits, + previous_tokens=previous_tokens, + previous_token_mask=previous_token_mask, + **sampling_kwargs, + ) idx_next = multinomial_sample_one_no_sync(probs) return idx_next, probs diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py index de498573..b7118a72 100644 --- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple import torch import torch.nn.functional as F -from AR.models.utils import make_pad_mask_left, sample +from AR.models.utils import logits_to_probs, make_pad_mask_left, multinomial_sample_one_no_sync, sample def _sync_device(device: Any) -> None: @@ -277,6 +277,90 @@ def _ensure_audio_pe(model: Any, max_position: int, dtype: torch.dtype, device: ) +def _pad_token_sequences( + token_sequences: Sequence[torch.LongTensor], +) -> Tuple[torch.LongTensor, torch.BoolTensor]: + if not token_sequences: + raise ValueError("token_sequences 不能为空") + device = token_sequences[0].device + max_len = max(int(sequence.shape[0]) for sequence in token_sequences) + padded = torch.zeros((len(token_sequences), max_len), dtype=token_sequences[0].dtype, device=device) + mask = torch.zeros((len(token_sequences), max_len), dtype=torch.bool, device=device) + for row_index, sequence in enumerate(token_sequences): + seq_len = int(sequence.shape[0]) + padded[row_index, :seq_len] = sequence + mask[row_index, :seq_len] = True + return padded, mask + + +def _sampling_group_key( + top_k: int, + top_p: float, + temperature: float, + repetition_penalty: float, + trim_eos: bool, +) -> Tuple[int, float, float, float, bool]: + return ( + int(top_k), + float(top_p), + float(temperature), + float(repetition_penalty), + bool(trim_eos), + ) + + +def _iter_contiguous_sampling_groups( + sampling_keys: Sequence[Tuple[int, float, float, float, bool]], +) -> List[Tuple[Tuple[int, float, float, float, bool], List[int]]]: + groups: List[Tuple[Tuple[int, float, float, float, bool], List[int]]] = [] + if not sampling_keys: + return groups + current_key = sampling_keys[0] + current_indices: List[int] = [0] + for index in range(1, len(sampling_keys)): + key = sampling_keys[index] + if key == current_key: + current_indices.append(index) + continue + groups.append((current_key, current_indices)) + current_key = key + current_indices = [index] + groups.append((current_key, current_indices)) + return groups + + +def _batched_sample_by_group( + logits: torch.Tensor, + histories: Sequence[torch.LongTensor], + sampling_keys: Sequence[Tuple[int, float, float, float, bool]], +) -> Tuple[List[torch.Tensor], List[int]]: + sampled_list: List[Optional[torch.Tensor]] = [None] * len(histories) + argmax_list: List[Optional[int]] = [None] * len(histories) + for group_key, group_indices in _iter_contiguous_sampling_groups(sampling_keys): + top_k, top_p, temperature, repetition_penalty, trim_eos = group_key + index_tensor = torch.tensor(group_indices, dtype=torch.long, device=logits.device) + group_logits = torch.index_select(logits, dim=0, index=index_tensor) + if trim_eos: + group_logits = group_logits[:, :-1] + group_histories = [histories[index] for index in group_indices] + padded_histories, history_mask = _pad_token_sequences(group_histories) + probs = logits_to_probs( + logits=group_logits, + previous_tokens=padded_histories, + previous_token_mask=history_mask, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + ) + argmax_tokens = torch.argmax(group_logits, dim=-1) + for local_index, global_index in enumerate(group_indices): + sampled_list[global_index] = multinomial_sample_one_no_sync(probs[local_index : local_index + 1]) + argmax_list[global_index] = int(argmax_tokens[local_index].item()) + + return [item for item in sampled_list if item is not None], [int(item) for item in argmax_list if item is not None] + + @torch.inference_mode() def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SActiveBatch: x_items: List[torch.Tensor] = [] @@ -360,19 +444,26 @@ def _sample_per_request( updated_sequences: List[torch.LongTensor] = [] step_idx = active_batch.step_idx - for batch_index, state in enumerate(active_batch.states): - logits_i = logits[batch_index : batch_index + 1].clone() - current_history = active_batch.y_sequences[batch_index] - sampled = sample( - logits_i, - current_history.unsqueeze(0), + sampling_keys = [ + _sampling_group_key( top_k=state.top_k, top_p=state.top_p, - repetition_penalty=state.repetition_penalty, temperature=state.temperature, - )[0] + repetition_penalty=state.repetition_penalty, + trim_eos=False, + ) + for state in active_batch.states + ] + sampled_items, argmax_tokens = _batched_sample_by_group( + logits=logits, + histories=active_batch.y_sequences, + sampling_keys=sampling_keys, + ) + for batch_index, state in enumerate(active_batch.states): + current_history = active_batch.y_sequences[batch_index] + sampled = sampled_items[batch_index] sampled_token = int(sampled[0, 0].item()) - argmax_token = int(torch.argmax(logits[batch_index], dim=-1).item()) + argmax_token = argmax_tokens[batch_index] new_history = torch.cat([current_history, sampled.view(-1)], dim=0) finish_reason: Optional[str] = None @@ -507,25 +598,30 @@ def run_prefill_step( if len(states) == 1 and not decode_attn_mask.any().item(): decode_attn_mask = None logits = model.ar_predict_layer(xy_dec[:, -1]) + sampling_keys = [ + _sampling_group_key( + top_k=state.top_k, + top_p=state.top_p, + temperature=state.temperature, + repetition_penalty=state.repetition_penalty, + trim_eos=True, + ) + for state in states + ] + sampled_items, argmax_tokens = _batched_sample_by_group( + logits=logits, + histories=active_batch.y_sequences, + sampling_keys=sampling_keys, + ) running_requests: List[T2SRunningRequest] = [] finished_items: List[T2SFinishedItem] = [] for batch_index, state in enumerate(states): - logits_i = logits[batch_index : batch_index + 1].clone() - if 0 < 11: - logits_i = logits_i[:, :-1] current_history = active_batch.y_sequences[batch_index] - sampled = sample( - logits_i, - current_history.unsqueeze(0), - top_k=state.top_k, - top_p=state.top_p, - repetition_penalty=state.repetition_penalty, - temperature=state.temperature, - )[0] + sampled = sampled_items[batch_index] sampled_token = int(sampled[0, 0].item()) - argmax_token = int(torch.argmax(logits_i[0], dim=-1).item()) + argmax_token = argmax_tokens[batch_index] new_history = torch.cat([current_history, sampled.view(-1)], dim=0) prefix_len = int(active_batch.prefix_lens[batch_index].item()) @@ -624,25 +720,31 @@ def run_decode_step_for_running( batched_decode_attn_mask, ) logits = model.ar_predict_layer(xy_dec[:, -1]) + sampling_keys = [ + _sampling_group_key( + top_k=running_request.state.top_k, + top_p=running_request.state.top_p, + temperature=running_request.state.temperature, + repetition_penalty=running_request.state.repetition_penalty, + trim_eos=running_request.step_idx < 11, + ) + for running_request in running_requests + ] + histories = [running_request.y_sequence for running_request in running_requests] + sampled_items, argmax_tokens = _batched_sample_by_group( + logits=logits, + histories=histories, + sampling_keys=sampling_keys, + ) next_running: List[T2SRunningRequest] = [] finished_items: List[T2SFinishedItem] = [] for batch_index, running_request in enumerate(running_requests): current_idx = running_request.step_idx - logits_i = logits[batch_index : batch_index + 1].clone() - if current_idx < 11: - logits_i = logits_i[:, :-1] - sampled = sample( - logits_i, - running_request.y_sequence.unsqueeze(0), - top_k=running_request.state.top_k, - top_p=running_request.state.top_p, - repetition_penalty=running_request.state.repetition_penalty, - temperature=running_request.state.temperature, - )[0] + sampled = sampled_items[batch_index] sampled_token = int(sampled[0, 0].item()) - argmax_token = int(torch.argmax(logits_i[0], dim=-1).item()) + argmax_token = argmax_tokens[batch_index] new_history = torch.cat([running_request.y_sequence, sampled.view(-1)], dim=0) finish_reason: Optional[str] = None From 827d6ea47c5d03aac006597f2f8e41761b284b13 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Tue, 10 Mar 2026 06:58:53 +0800 Subject: [PATCH 8/9] Refactor TTS and scheduler components to enhance text processing and batching capabilities. Introduce PrepareCoordinator for managing text feature preparation asynchronously, and update SchedulerDebugWorker to support new finalize task management. Implement batch processing in PrepareBertBatchWorker with improved admission control and profiling metrics. Add text CPU preprocessing utilities for better text segmentation and normalization. --- GPT_SoVITS/TTS_infer_pack/TTS.py | 160 ++++- GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py | 232 ++++++- .../prepare_bert_batch_worker.py | 181 +++++- .../TTS_infer_pack/prepare_coordinator.py | 294 +++++++++ GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py | 489 +++++++++++---- .../TTS_infer_pack/text_cpu_preprocess.py | 100 ++++ GPT_SoVITS/module/models.py | 62 ++ api_v3.py | 565 +++++++++++++++--- 8 files changed, 1811 insertions(+), 272 deletions(-) create mode 100644 GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py create mode 100644 GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 9c259662..d475b804 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1,26 +1,27 @@ import gc +import concurrent.futures import math import os import random import sys import time import traceback -from concurrent.futures import ThreadPoolExecutor from copy import deepcopy -import torchaudio -from tqdm import tqdm - now_dir = os.getcwd() sys.path.append(now_dir) -import os from typing import List, Tuple, Union +from runtime_preload import preload_text_runtime_deps + +preload_text_runtime_deps() + import ffmpeg import librosa import numpy as np import torch import torch.nn.functional as F +import torchaudio import yaml from AR.models.t2s_lightning_module import Text2SemanticLightningModule from BigVGAN.bigvgan import BigVGAN @@ -30,6 +31,7 @@ from module.models import SynthesizerTrn, SynthesizerTrnV3, Generator from peft import LoraConfig, get_peft_model from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new from transformers import AutoModelForMaskedLM, AutoTokenizer +from tqdm import tqdm from tools.audio_sr import AP_BWE from tools.i18n.i18n import I18nAuto, scan_language_list @@ -449,20 +451,21 @@ class TTS: "overlapped_len": None, } self.prepare_bert_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_BERT_SLOTS", "1"))) - self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "2"))) + self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4"))) self.prepare_bert_batch_worker = None self.prepare_ref_semantic_batch_worker = None - default_text_cpu_workers = 16 self.prepare_text_cpu_workers = max( 0, - int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_WORKERS", str(default_text_cpu_workers))), + int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_WORKERS", "0")), ) - self.prepare_text_cpu_executor = None - if self.prepare_text_cpu_workers > 0: - self.prepare_text_cpu_executor = ThreadPoolExecutor( + self.prepare_text_cpu_executor = ( + concurrent.futures.ThreadPoolExecutor( max_workers=self.prepare_text_cpu_workers, thread_name_prefix="prepare-text-cpu", ) + if self.prepare_text_cpu_workers > 0 + else None + ) self._init_models() @@ -475,6 +478,20 @@ class TTS: batch_window_ms=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_WINDOW_MS", "5")), max_batch_items=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_ITEMS", "16")), max_batch_tokens=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_TOKENS", "4096")), + max_pending_tasks=int(os.environ.get("GPTSOVITS_PREPARE_BERT_MAX_PENDING_TASKS", "0")), + admission_poll_ms=int(os.environ.get("GPTSOVITS_PREPARE_BERT_ADMISSION_POLL_MS", "1")), + high_pressure_pending_threshold=int( + os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_PENDING_THRESHOLD", "0") + ), + high_pressure_batch_window_ms=int( + os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_BATCH_WINDOW_MS", "1") + ), + high_pressure_max_batch_items=int( + os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_MAX_ITEMS", "32") + ), + high_pressure_max_batch_tokens=int( + os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_MAX_TOKENS", "8192") + ), ) if os.environ.get("GPTSOVITS_PREPARE_REF_BATCHING", "0") != "0": ref_max_batch_samples = os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_SAMPLES") @@ -830,13 +847,23 @@ class TTS: return codes[0, 0].to(self.configs.device) @torch.inference_mode() - def _extract_prompt_semantic_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + def _extract_prompt_semantic_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + cpu_prepare_start = time.perf_counter() wav16k = prepare_prompt_semantic_wav16k( raw_audio=raw_audio, raw_sr=raw_sr, zero_wav_samples=int(self.configs.sampling_rate * 0.3), ) - return self._extract_prompt_semantic_from_prepared_wav16k(wav16k) + cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0 + forward_start = time.perf_counter() + prompt_semantic = self._extract_prompt_semantic_from_prepared_wav16k(wav16k) + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + return prompt_semantic, cpu_prepare_ms, forward_ms + + @torch.inference_mode() + def _extract_prompt_semantic_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + prompt_semantic, _, _ = self._extract_prompt_semantic_profile_from_raw(raw_audio, raw_sr) + return prompt_semantic def extract_prompt_semantic(self, ref_wav_path: str): raw_audio, raw_sr = self._load_ref_audio_raw(ref_wav_path) @@ -887,7 +914,9 @@ class TTS: if self.prepare_ref_semantic_batch_worker is None: with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats: prompt_semantic_start = time.perf_counter() - prompt_semantic = self._extract_prompt_semantic_from_raw(raw_audio, raw_sr) + prompt_semantic, prompt_semantic_cpu_prepare_ms, prompt_semantic_forward_ms = ( + self._extract_prompt_semantic_profile_from_raw(raw_audio, raw_sr) + ) prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0 ref_spec_start = time.perf_counter() refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2] @@ -897,8 +926,8 @@ class TTS: audio_stage_inflight_peak = float(limiter_stats["peak_inflight"]) prompt_semantic_profile = { "prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]), - "prompt_semantic_cpu_prepare_ms": 0.0, - "prompt_semantic_forward_ms": prompt_semantic_ms, + "prompt_semantic_cpu_prepare_ms": float(prompt_semantic_cpu_prepare_ms), + "prompt_semantic_forward_ms": float(prompt_semantic_forward_ms), "prompt_semantic_scatter_ms": 0.0, "prompt_semantic_stage_slots": float(limiter_stats["slots"]), "prompt_semantic_stage_inflight_peak": float(limiter_stats["peak_inflight"]), @@ -1012,6 +1041,32 @@ class TTS: text, language, self.configs.version, profile=profile ) + def prepare_text_segments(self, text: str, language: str): + return self.text_preprocessor.preprocess_text_segments(text, language, self.configs.version) + + def build_text_features_from_segments(self, prepared_segments, profile: dict | None = None): + return self.text_preprocessor.build_phones_and_bert_from_segments(prepared_segments, profile=profile) + + async def build_text_features_from_segments_async(self, prepared_segments, profile: dict | None = None): + return await self.text_preprocessor.build_phones_and_bert_from_segments_async( + prepared_segments, + profile=profile, + ) + + async def build_text_feature_pair_from_segments_async( + self, + prompt_segments, + target_segments, + prompt_profile: dict | None = None, + target_profile: dict | None = None, + ): + return await self.text_preprocessor.build_phones_and_bert_pair_from_segments_async( + prompt_segments, + target_segments, + prompt_profile=prompt_profile, + target_profile=target_profile, + ) + def _set_ref_audio_path(self, ref_audio_path): self.prompt_cache["ref_audio_path"] = ref_audio_path @@ -2011,6 +2066,79 @@ class TTS: sample_steps=sample_steps, ) + @torch.inference_mode() + def synthesize_audio_requests_local_batched( + self, + semantic_tokens_list: List[torch.Tensor], + phones_list: List[torch.Tensor], + refer_specs: List[tuple], + speeds: List[float] | None = None, + sample_steps_list: List[int] | None = None, + ) -> List[torch.Tensor]: + batch_size = len(semantic_tokens_list) + if batch_size == 0: + return [] + if len(phones_list) != batch_size or len(refer_specs) != batch_size: + raise ValueError("batched request-local synthesis 输入长度不一致") + if speeds is None: + speeds = [1.0] * batch_size + if sample_steps_list is None: + sample_steps_list = [32] * batch_size + if len(speeds) != batch_size or len(sample_steps_list) != batch_size: + raise ValueError("batched request-local synthesis 参数长度不一致") + first_speed = float(speeds[0]) + first_sample_steps = int(sample_steps_list[0]) + if any(abs(float(item) - first_speed) > 1e-6 for item in speeds): + raise ValueError("batched request-local synthesis 目前要求 speed 一致") + if any(int(item) != first_sample_steps for item in sample_steps_list): + raise ValueError("batched request-local synthesis 目前要求 sample_steps 一致") + if self.configs.use_vocoder: + raise NotImplementedError("request-local batched VITS synthesis 暂不支持 vocoder 模型") + + device = self.configs.device + max_semantic_len = max(int(item.shape[-1]) for item in semantic_tokens_list) + max_phone_len = max(int(item.shape[-1]) for item in phones_list) + semantic_batch = torch.zeros((1, batch_size, max_semantic_len), dtype=torch.long, device=device) + phone_batch = torch.zeros((batch_size, max_phone_len), dtype=torch.long, device=device) + semantic_lengths = [] + phone_lengths = [] + refer_audio_specs: List[torch.Tensor] = [] + sv_emb_batch = None + sv_emb_list: List[torch.Tensor] = [] + + for batch_index, semantic_tokens in enumerate(semantic_tokens_list): + semantic_len = int(semantic_tokens.shape[-1]) + phone_len = int(phones_list[batch_index].shape[-1]) + semantic_batch[0, batch_index, :semantic_len] = semantic_tokens.to(device=device, dtype=torch.long) + phone_batch[batch_index, :phone_len] = phones_list[batch_index].to(device=device, dtype=torch.long) + semantic_lengths.append(semantic_len) + phone_lengths.append(phone_len) + + refer_audio_spec, audio_tensor = refer_specs[batch_index] + refer_audio_specs.append(refer_audio_spec.to(dtype=self.precision, device=device)) + if self.is_v2pro: + if audio_tensor is None: + raise ValueError(i18n("v2Pro request-local batched synthesis 缺少 16k 参考音频")) + sv_emb_list.append(self.sv_model.compute_embedding3(audio_tensor).to(device)) + + if self.is_v2pro: + sv_emb_batch = torch.cat(sv_emb_list, dim=0) + + audio_batch, audio_lengths = self.vits_model.decode_batched_request_local( + codes=semantic_batch, + code_lengths=torch.LongTensor(semantic_lengths).to(device), + text=phone_batch, + text_lengths=torch.LongTensor(phone_lengths).to(device), + refer_list=refer_audio_specs, + speed=first_speed, + sv_emb=sv_emb_batch, + ) + audios: List[torch.Tensor] = [] + for batch_index in range(batch_size): + audio_len = int(audio_lengths[batch_index].item()) + audios.append(audio_batch[batch_index, 0, :audio_len].detach()) + return audios + def using_vocoder_synthesis_batched_infer( self, idx_list: List[int], diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py index 15b3c322..6bee49be 100644 --- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py +++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py @@ -1,8 +1,10 @@ +import asyncio import os import sys import threading import time from contextlib import contextmanager +from dataclasses import dataclass from tqdm import tqdm @@ -13,12 +15,13 @@ import re import torch from text.LangSegmenter import LangSegmenter from text import chinese -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple from text.cleaner import clean_text from text import cleaned_text_to_sequence from transformers import AutoModelForMaskedLM, AutoTokenizer from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker +from TTS_infer_pack.text_cpu_preprocess import preprocess_text_segments_payload from tools.i18n.i18n import I18nAuto, scan_language_list @@ -92,6 +95,14 @@ class StageLimiter: } +@dataclass +class PreparedTextSegment: + language: str + phones: List[int] + word2ph: Optional[List[int]] + norm_text: str + + class TextPreprocessor: def __init__( self, @@ -149,7 +160,7 @@ class TextPreprocessor: # 解决输入目标文本的空行导致报错的问题 if len(text.strip()) == 0: continue - if not re.sub("\W+", "", text): + if not re.sub(r"\W+", "", text): # 检测一下,如果是纯符号,就跳过。 continue if text[-1] not in splits: @@ -168,7 +179,8 @@ class TextPreprocessor: def segment_and_extract_feature_for_text( self, text: str, language: str, version: str = "v1", profile: Dict | None = None ) -> Tuple[list, torch.Tensor, str]: - return self.get_phones_and_bert(text, language, version, profile=profile) + prepared_segments = self.preprocess_text_segments(text, language, version) + return self.build_phones_and_bert_from_segments(prepared_segments, profile=profile) def _split_text_by_language(self, text: str, language: str) -> Tuple[List[str], List[str]]: textlist = [] @@ -223,24 +235,49 @@ class TextPreprocessor: def get_phones_and_bert( self, text: str, language: str, version: str, final: bool = False, profile: Dict | None = None ): - text = re.sub(r' {2,}', ' ', text) - textlist, langlist = self._split_text_by_language(text, language) - phones_list = [] - bert_list = [] - norm_text_list = [] - for segment_text, segment_lang in zip(textlist, langlist): - phones, word2ph, norm_text = self.clean_text_inf(segment_text, segment_lang, version) - bert = self.get_bert_inf(phones, word2ph, norm_text, segment_lang, profile=profile) - phones_list.append(phones) - norm_text_list.append(norm_text) + prepared_segments = self.preprocess_text_segments(text, language, version, final=final) + return self.build_phones_and_bert_from_segments(prepared_segments, profile=profile) + + def preprocess_text_segments( + self, + text: str, + language: str, + version: str, + final: bool = False, + ) -> List[PreparedTextSegment]: + payloads = preprocess_text_segments_payload(text, language, version, final=final) + return [ + PreparedTextSegment( + language=str(payload["language"]), + phones=list(payload["phones"]), + word2ph=None if payload["word2ph"] is None else list(payload["word2ph"]), + norm_text=str(payload["norm_text"]), + ) + for payload in payloads + ] + + def build_phones_and_bert_from_segments( + self, + prepared_segments: List[PreparedTextSegment], + profile: Dict | None = None, + ) -> Tuple[list, torch.Tensor, str]: + phones_list: List[List[int]] = [] + bert_list: List[torch.Tensor] = [] + norm_text_list: List[str] = [] + for segment in prepared_segments: + bert = self.get_bert_inf( + segment.phones, + segment.word2ph, + segment.norm_text, + segment.language, + profile=profile, + ) + phones_list.append(segment.phones) + norm_text_list.append(segment.norm_text) bert_list.append(bert) bert = torch.cat(bert_list, dim=1) phones = sum(phones_list, []) norm_text = "".join(norm_text_list) - - if not final and len(phones) < 6: - return self.get_phones_and_bert("." + text, language, version, final=True, profile=profile) - return phones, bert, norm_text def _accumulate_profile(self, profile: Dict | None, key: str, value: float) -> None: @@ -253,21 +290,41 @@ class TextPreprocessor: return profile[key] = float(max(float(profile.get(key, 0.0)), float(value))) + def _merge_bert_worker_profile(self, profile: Dict | None, worker_profile: Dict[str, float]) -> None: + self._accumulate_profile(profile, "bert_wait_ms", worker_profile.get("bert_wait_ms", 0.0)) + self._accumulate_profile(profile, "bert_admission_wait_ms", worker_profile.get("bert_admission_wait_ms", 0.0)) + self._accumulate_profile(profile, "bert_queue_wait_ms", worker_profile.get("bert_queue_wait_ms", 0.0)) + self._accumulate_profile( + profile, + "bert_batch_collect_wait_ms", + worker_profile.get("bert_batch_collect_wait_ms", 0.0), + ) + self._accumulate_profile(profile, "bert_forward_ms", worker_profile.get("bert_forward_ms", 0.0)) + self._accumulate_profile(profile, "bert_tokenize_ms", worker_profile.get("bert_tokenize_ms", 0.0)) + self._accumulate_profile(profile, "bert_scatter_ms", worker_profile.get("bert_scatter_ms", 0.0)) + self._accumulate_profile(profile, "bert_calls", worker_profile.get("bert_calls", 1.0)) + self._update_profile_peak(profile, "bert_stage_inflight_peak", worker_profile.get("bert_stage_inflight_peak", 0.0)) + self._update_profile_peak(profile, "bert_batch_size_peak", worker_profile.get("bert_batch_size", 0.0)) + self._update_profile_peak(profile, "bert_batch_tokens_peak", worker_profile.get("bert_batch_tokens", 0.0)) + self._update_profile_peak( + profile, + "bert_pending_depth_on_enqueue_peak", + worker_profile.get("bert_pending_depth_on_enqueue", 0.0), + ) + self._update_profile_peak( + profile, + "bert_pending_depth_on_collect_peak", + worker_profile.get("bert_pending_depth_on_collect", 0.0), + ) + self._update_profile_peak(profile, "bert_high_pressure_mode_peak", worker_profile.get("bert_high_pressure_mode", 0.0)) + if profile is not None: + profile["bert_stage_slots"] = float(worker_profile.get("bert_stage_slots", 0.0)) + profile["bert_batch_window_ms"] = float(worker_profile.get("bert_batch_window_ms", 0.0)) + def get_bert_feature(self, text: str, word2ph: list, profile: Dict | None = None) -> torch.Tensor: if self.bert_batch_worker is not None: feature, worker_profile = self.bert_batch_worker.submit(text, word2ph) - self._accumulate_profile(profile, "bert_wait_ms", worker_profile.get("bert_wait_ms", 0.0)) - self._accumulate_profile(profile, "bert_forward_ms", worker_profile.get("bert_forward_ms", 0.0)) - self._accumulate_profile(profile, "bert_tokenize_ms", worker_profile.get("bert_tokenize_ms", 0.0)) - self._accumulate_profile(profile, "bert_scatter_ms", worker_profile.get("bert_scatter_ms", 0.0)) - self._accumulate_profile(profile, "bert_calls", worker_profile.get("bert_calls", 1.0)) - self._update_profile_peak( - profile, "bert_stage_inflight_peak", worker_profile.get("bert_stage_inflight_peak", 0.0) - ) - self._update_profile_peak(profile, "bert_batch_size_peak", worker_profile.get("bert_batch_size", 0.0)) - self._update_profile_peak(profile, "bert_batch_tokens_peak", worker_profile.get("bert_batch_tokens", 0.0)) - if profile is not None: - profile["bert_stage_slots"] = float(worker_profile.get("bert_stage_slots", 0.0)) + self._merge_bert_worker_profile(profile, worker_profile) return feature limiter_stats = {"wait_ms": 0.0, "inflight": 1, "peak_inflight": 1, "slots": 0} @@ -310,9 +367,18 @@ class TextPreprocessor: phones = cleaned_text_to_sequence(phones, version) return phones, word2ph, norm_text - def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str, profile: Dict | None = None): + def get_bert_inf( + self, + phones: list, + word2ph: Optional[list], + norm_text: str, + language: str, + profile: Dict | None = None, + ): language = language.replace("all_", "") if language == "zh": + if word2ph is None: + raise ValueError("中文文本缺少 word2ph,无法提取 BERT 特征") feature = self.get_bert_feature(norm_text, word2ph, profile=profile).to(self.device) else: feature = torch.zeros( @@ -322,6 +388,112 @@ class TextPreprocessor: return feature + async def build_phones_and_bert_from_segments_async( + self, + prepared_segments: List[PreparedTextSegment], + profile: Dict | None = None, + ) -> Tuple[list, torch.Tensor, str]: + segment_jobs = self._build_async_segment_jobs(prepared_segments, profile) + pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = [] + for segment_index, segment in enumerate(prepared_segments): + if segment.language.replace("all_", "") != "zh" or self.bert_batch_worker is None: + continue + if segment.word2ph is None: + raise ValueError("中文文本缺少 word2ph,无法提取 BERT 特征") + pending_items.append( + ( + segment_jobs["bert_list"], + segment_index, + profile, + self.bert_batch_worker.submit_async(segment.norm_text, segment.word2ph), + ) + ) + + if pending_items: + pending_results = await asyncio.gather(*[future for _, _, _, future in pending_items]) + for (bert_list, bert_index, item_profile, _), (feature, worker_profile) in zip(pending_items, pending_results): + self._merge_bert_worker_profile(item_profile, worker_profile) + bert_list[bert_index] = feature.to(self.device) + + return self._finalize_async_segment_jobs(segment_jobs) + + def _build_async_segment_jobs( + self, + prepared_segments: List[PreparedTextSegment], + profile: Dict | None, + ) -> Dict[str, List]: + phones_list: List[List[int]] = [] + bert_list: List[torch.Tensor | None] = [] + norm_text_list: List[str] = [] + + for segment in prepared_segments: + phones_list.append(segment.phones) + norm_text_list.append(segment.norm_text) + segment_language = segment.language.replace("all_", "") + if segment_language == "zh" and self.bert_batch_worker is not None: + if segment.word2ph is None: + raise ValueError("中文文本缺少 word2ph,无法提取 BERT 特征") + bert_list.append(None) + continue + bert_list.append( + self.get_bert_inf( + segment.phones, + segment.word2ph, + segment.norm_text, + segment.language, + profile=profile, + ) + ) + return { + "phones_list": phones_list, + "bert_list": bert_list, + "norm_text_list": norm_text_list, + } + + @staticmethod + def _finalize_async_segment_jobs(segment_jobs: Dict[str, List]) -> Tuple[list, torch.Tensor, str]: + bert = torch.cat([feature for feature in segment_jobs["bert_list"] if feature is not None], dim=1) + phones = sum(segment_jobs["phones_list"], []) + norm_text = "".join(segment_jobs["norm_text_list"]) + return phones, bert, norm_text + + async def build_phones_and_bert_pair_from_segments_async( + self, + prompt_segments: List[PreparedTextSegment], + target_segments: List[PreparedTextSegment], + prompt_profile: Dict | None = None, + target_profile: Dict | None = None, + ) -> Tuple[Tuple[list, torch.Tensor, str], Tuple[list, torch.Tensor, str]]: + prompt_jobs = self._build_async_segment_jobs(prompt_segments, prompt_profile) + target_jobs = self._build_async_segment_jobs(target_segments, target_profile) + pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = [] + + for segment_jobs, prepared_segments, profile in ( + (prompt_jobs, prompt_segments, prompt_profile), + (target_jobs, target_segments, target_profile), + ): + for segment_index, segment in enumerate(prepared_segments): + if segment.language.replace("all_", "") != "zh" or self.bert_batch_worker is None: + continue + if segment.word2ph is None: + raise ValueError("中文文本缺少 word2ph,无法提取 BERT 特征") + pending_items.append( + ( + segment_jobs["bert_list"], + segment_index, + profile, + self.bert_batch_worker.submit_async(segment.norm_text, segment.word2ph), + ) + ) + + if pending_items: + pending_results = await asyncio.gather(*[future for _, _, _, future in pending_items]) + for (bert_list, bert_index, profile, _), (feature, worker_profile) in zip(pending_items, pending_results): + self._merge_bert_worker_profile(profile, worker_profile) + bert_list[bert_index] = feature.to(self.device) + + return self._finalize_async_segment_jobs(prompt_jobs), self._finalize_async_segment_jobs(target_jobs) + def filter_text(self, texts): _text = [] if all(text in [None, " ", "\n", ""] for text in texts): diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py index b1ede3d8..1ac77faa 100644 --- a/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py +++ b/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py @@ -1,3 +1,4 @@ +import asyncio import threading import time import uuid @@ -14,7 +15,12 @@ class BertFeatureTask: word2ph: List[int] task_id: str = field(default_factory=lambda: uuid.uuid4().hex) created_at: float = field(default_factory=time.perf_counter) + enqueued_at: float = 0.0 + admission_wait_ms: float = 0.0 + pending_depth_on_enqueue: int = 0 done_event: threading.Event = field(default_factory=threading.Event) + done_loop: asyncio.AbstractEventLoop | None = None + done_future: asyncio.Future | None = None result_feature: torch.Tensor | None = None error: Exception | None = None profile: Dict[str, float] = field(default_factory=dict) @@ -30,14 +36,37 @@ class PrepareBertBatchWorker: batch_window_ms: int = 5, max_batch_items: int = 16, max_batch_tokens: int = 4096, + max_pending_tasks: int = 0, + admission_poll_ms: int = 1, + high_pressure_pending_threshold: int = 0, + high_pressure_batch_window_ms: int | None = None, + high_pressure_max_batch_items: int | None = None, + high_pressure_max_batch_tokens: int | None = None, ): self.bert_model = bert_model self.tokenizer = tokenizer self.device = device self.stage_limiter = stage_limiter - self.batch_window_s = max(0.0, float(batch_window_ms) / 1000.0) + self.batch_window_ms = max(0, int(batch_window_ms)) + self.batch_window_s = float(self.batch_window_ms) / 1000.0 self.max_batch_items = max(1, int(max_batch_items)) self.max_batch_tokens = max(16, int(max_batch_tokens)) + self.max_pending_tasks = max(0, int(max_pending_tasks)) + self.admission_poll_s = max(0.0005, float(max(1, int(admission_poll_ms))) / 1000.0) + + self.high_pressure_pending_threshold = max( + 0, + int(high_pressure_pending_threshold) + if int(high_pressure_pending_threshold) > 0 + else max(self.max_batch_items * 2, 32), + ) + hp_window_ms = self.batch_window_ms if high_pressure_batch_window_ms is None else int(high_pressure_batch_window_ms) + hp_items = self.max_batch_items if high_pressure_max_batch_items is None else int(high_pressure_max_batch_items) + hp_tokens = self.max_batch_tokens if high_pressure_max_batch_tokens is None else int(high_pressure_max_batch_tokens) + self.high_pressure_batch_window_ms = max(0, hp_window_ms) + self.high_pressure_batch_window_s = float(self.high_pressure_batch_window_ms) / 1000.0 + self.high_pressure_max_batch_items = max(self.max_batch_items, hp_items) + self.high_pressure_max_batch_tokens = max(self.max_batch_tokens, hp_tokens) self.condition = threading.Condition() self.pending_tasks: Deque[BertFeatureTask] = deque() @@ -47,26 +76,70 @@ class PrepareBertBatchWorker: self.total_batches = 0 self.active_batch_size = 0 self.active_batch_peak = 0 + self.active_batch_tokens = 0 + self.active_batch_tokens_peak = 0 + self.high_pressure_batches = 0 + self.admission_wait_total_ms = 0.0 + self.admission_wait_peak_ms = 0.0 self.worker_thread = threading.Thread(target=self._run_loop, name="prepare-bert-batch-worker", daemon=True) self.worker_thread.start() def _estimate_task_tokens(self, task: BertFeatureTask) -> int: return max(1, len(task.norm_text) + 2) + def _can_enqueue_locked(self) -> bool: + if self.max_pending_tasks <= 0: + return True + return (len(self.pending_tasks) + self.active_batch_size) < self.max_pending_tasks + + def _record_enqueue_locked(self, task: BertFeatureTask, admission_wait_ms: float) -> None: + task.admission_wait_ms = float(max(0.0, admission_wait_ms)) + task.enqueued_at = time.perf_counter() + task.pending_depth_on_enqueue = int(len(self.pending_tasks)) + self.pending_tasks.append(task) + self.total_submitted += 1 + self.admission_wait_total_ms += task.admission_wait_ms + self.admission_wait_peak_ms = max(self.admission_wait_peak_ms, task.admission_wait_ms) + if len(self.pending_tasks) > self.pending_peak: + self.pending_peak = len(self.pending_tasks) + self.condition.notify_all() + + def _enqueue_task(self, task: BertFeatureTask) -> None: + admission_started = time.perf_counter() + with self.condition: + while not self._can_enqueue_locked(): + self.condition.wait(timeout=self.admission_poll_s) + self._record_enqueue_locked(task, (time.perf_counter() - admission_started) * 1000.0) + + async def _enqueue_task_async(self, task: BertFeatureTask) -> None: + admission_started = time.perf_counter() + while True: + with self.condition: + if self._can_enqueue_locked(): + self._record_enqueue_locked(task, (time.perf_counter() - admission_started) * 1000.0) + return + await asyncio.sleep(self.admission_poll_s) + def submit(self, norm_text: str, word2ph: List[int]) -> Tuple[torch.Tensor, Dict[str, float]]: task = BertFeatureTask(norm_text=str(norm_text), word2ph=list(word2ph)) - with self.condition: - self.pending_tasks.append(task) - self.total_submitted += 1 - if len(self.pending_tasks) > self.pending_peak: - self.pending_peak = len(self.pending_tasks) - self.condition.notify_all() + self._enqueue_task(task) task.done_event.wait() if task.error is not None: raise task.error assert task.result_feature is not None return task.result_feature, dict(task.profile) + async def submit_async(self, norm_text: str, word2ph: List[int]) -> Tuple[torch.Tensor, Dict[str, float]]: + loop = asyncio.get_running_loop() + task = BertFeatureTask( + norm_text=str(norm_text), + word2ph=list(word2ph), + done_loop=loop, + done_future=loop.create_future(), + ) + await self._enqueue_task_async(task) + return await task.done_future + def snapshot(self) -> Dict[str, int]: with self.condition: return { @@ -77,21 +150,57 @@ class PrepareBertBatchWorker: "total_batches": self.total_batches, "active_batch_size": self.active_batch_size, "active_batch_peak": self.active_batch_peak, + "active_batch_tokens": self.active_batch_tokens, + "active_batch_tokens_peak": self.active_batch_tokens_peak, "batch_window_ms": int(self.batch_window_s * 1000.0), "max_batch_items": self.max_batch_items, "max_batch_tokens": self.max_batch_tokens, + "max_pending_tasks": self.max_pending_tasks, + "high_pressure_pending_threshold": self.high_pressure_pending_threshold, + "high_pressure_batch_window_ms": self.high_pressure_batch_window_ms, + "high_pressure_max_batch_items": self.high_pressure_max_batch_items, + "high_pressure_max_batch_tokens": self.high_pressure_max_batch_tokens, + "high_pressure_batches": self.high_pressure_batches, + "admission_wait_total_ms": self.admission_wait_total_ms, + "admission_wait_peak_ms": self.admission_wait_peak_ms, } - def _collect_batch(self) -> List[BertFeatureTask]: + def _select_batch_policy_locked(self) -> Tuple[float, int, int, bool, int]: + pending_depth = len(self.pending_tasks) + use_high_pressure = ( + self.high_pressure_pending_threshold > 0 + and pending_depth >= self.high_pressure_pending_threshold + ) + if use_high_pressure: + return ( + self.high_pressure_batch_window_s, + self.high_pressure_max_batch_items, + self.high_pressure_max_batch_tokens, + True, + pending_depth, + ) + return ( + self.batch_window_s, + self.max_batch_items, + self.max_batch_tokens, + False, + pending_depth, + ) + + def _collect_batch(self) -> Tuple[List[BertFeatureTask], Dict[str, float]]: with self.condition: while not self.pending_tasks: self.condition.wait() + collect_started = time.perf_counter() + batch_window_s, max_batch_items, max_batch_tokens, use_high_pressure, pending_depth_on_collect = ( + self._select_batch_policy_locked() + ) batch: List[BertFeatureTask] = [self.pending_tasks.popleft()] batch_tokens = self._estimate_task_tokens(batch[0]) - deadline = time.perf_counter() + self.batch_window_s + deadline = time.perf_counter() + batch_window_s - while len(batch) < self.max_batch_items: + while len(batch) < max_batch_items: remaining = deadline - time.perf_counter() if remaining <= 0: break @@ -100,26 +209,39 @@ class PrepareBertBatchWorker: continue next_task = self.pending_tasks[0] next_tokens = self._estimate_task_tokens(next_task) - if len(batch) >= self.max_batch_items or (batch_tokens + next_tokens) > self.max_batch_tokens: + if len(batch) >= max_batch_items or (batch_tokens + next_tokens) > max_batch_tokens: break batch.append(self.pending_tasks.popleft()) batch_tokens += next_tokens self.active_batch_size = len(batch) + self.active_batch_tokens = batch_tokens if self.active_batch_size > self.active_batch_peak: self.active_batch_peak = self.active_batch_size - return batch + if self.active_batch_tokens > self.active_batch_tokens_peak: + self.active_batch_tokens_peak = self.active_batch_tokens + if use_high_pressure: + self.high_pressure_batches += 1 + return batch, { + "collect_wait_ms": (time.perf_counter() - collect_started) * 1000.0, + "batch_tokens": float(batch_tokens), + "pending_depth_on_collect": float(pending_depth_on_collect), + "high_pressure_mode": 1.0 if use_high_pressure else 0.0, + "batch_window_ms": float(self.high_pressure_batch_window_ms if use_high_pressure else self.batch_window_ms), + } def _finalize_batch(self, batch: List[BertFeatureTask]) -> None: with self.condition: self.active_batch_size = 0 + self.active_batch_tokens = 0 self.total_batches += 1 self.total_finished += len(batch) + self.condition.notify_all() - def _run_batch(self, batch: List[BertFeatureTask]) -> None: + def _run_batch(self, batch: List[BertFeatureTask], batch_meta: Dict[str, float]) -> None: batch_started = time.perf_counter() texts = [task.norm_text for task in batch] - batch_tokens = sum(self._estimate_task_tokens(task) for task in batch) + batch_tokens = int(batch_meta["batch_tokens"]) limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0} if self.stage_limiter is None: @@ -167,6 +289,9 @@ class PrepareBertBatchWorker: task.result_feature = torch.cat(phone_level_feature, dim=0).T task.profile = { "bert_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]), + "bert_admission_wait_ms": float(task.admission_wait_ms), + "bert_queue_wait_ms": max(0.0, (batch_started - task.enqueued_at) * 1000.0), + "bert_batch_collect_wait_ms": float(batch_meta["collect_wait_ms"]), "bert_forward_ms": float(forward_ms), "bert_tokenize_ms": float(tokenize_ms), "bert_scatter_ms": 0.0, @@ -175,6 +300,10 @@ class PrepareBertBatchWorker: "bert_stage_inflight_peak": float(limiter_stats["peak_inflight"]), "bert_batch_size": float(len(batch)), "bert_batch_tokens": float(batch_tokens), + "bert_pending_depth_on_enqueue": float(task.pending_depth_on_enqueue), + "bert_pending_depth_on_collect": float(batch_meta["pending_depth_on_collect"]), + "bert_high_pressure_mode": float(batch_meta["high_pressure_mode"]), + "bert_batch_window_ms": float(batch_meta["batch_window_ms"]), } except Exception as exc: # noqa: PERF203 task.error = exc @@ -183,15 +312,35 @@ class PrepareBertBatchWorker: if task.result_feature is not None: task.profile["bert_scatter_ms"] = float(scatter_ms) task.done_event.set() + self._notify_done_future(task) + + @staticmethod + def _resolve_done_future(task: BertFeatureTask) -> None: + if task.done_future is None or task.done_future.done(): + return + if task.error is not None: + task.done_future.set_exception(task.error) + return + assert task.result_feature is not None + task.done_future.set_result((task.result_feature, dict(task.profile))) + + def _notify_done_future(self, task: BertFeatureTask) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_done_future, task) + except RuntimeError: + pass def _run_loop(self) -> None: while True: - batch = self._collect_batch() + batch, batch_meta = self._collect_batch() try: - self._run_batch(batch) + self._run_batch(batch, batch_meta) except Exception as exc: # noqa: PERF203 for task in batch: task.error = exc task.done_event.set() + self._notify_done_future(task) finally: self._finalize_batch(batch) diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py new file mode 100644 index 00000000..1fdf95c5 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py @@ -0,0 +1,294 @@ +from __future__ import annotations + +import asyncio +import concurrent.futures +import os +import threading +import time +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( + PreparedTextFeatures, + SchedulerRequestSpec, + T2SRequestState, + build_request_state_from_parts, + normalize_sentence, +) + + +@dataclass +class ProfiledResult: + result: Any + submit_at: float + started_at: float + finished_at: float + + @property + def queue_ms(self) -> float: + return max(0.0, (self.started_at - self.submit_at) * 1000.0) + + @property + def run_ms(self) -> float: + return max(0.0, (self.finished_at - self.started_at) * 1000.0) + + +class PrepareCoordinator: + def __init__(self, tts: Any): + self.tts = tts + self.lock = threading.Lock() + self.inflight = 0 + self.peak_inflight = 0 + self.use_async_text_feature_path = bool( + getattr(tts, "prepare_bert_batch_worker", None) is not None + and os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_DIRECT", "0") != "0" + ) + self.max_inflight = max(0, int(os.environ.get("GPTSOVITS_PREPARE_MAX_INFLIGHT", "0"))) + self._inflight_semaphore = asyncio.Semaphore(self.max_inflight) if self.max_inflight > 0 else None + self.text_feature_workers = 0 + self.text_feature_executor = None + if not self.use_async_text_feature_path: + text_feature_default_workers = max(1, int(getattr(tts, "prepare_text_cpu_workers", 16) or 16)) + self.text_feature_workers = max( + 1, + int(os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_WORKERS", str(text_feature_default_workers))), + ) + self.text_feature_executor = concurrent.futures.ThreadPoolExecutor( + max_workers=self.text_feature_workers, + thread_name_prefix="prepare-text-feature", + ) + ref_audio_default_workers = max(1, int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4"))) + self.ref_audio_workers = max( + 1, + int(os.environ.get("GPTSOVITS_PREPARE_REF_ASYNC_WORKERS", str(ref_audio_default_workers))), + ) + self.ref_audio_executor = concurrent.futures.ThreadPoolExecutor( + max_workers=self.ref_audio_workers, + thread_name_prefix="prepare-ref-audio", + ) + + def _mark_enter(self) -> Tuple[int, int]: + with self.lock: + self.inflight += 1 + current_inflight = self.inflight + if current_inflight > self.peak_inflight: + self.peak_inflight = current_inflight + return current_inflight, self.peak_inflight + + def _mark_leave(self) -> None: + with self.lock: + self.inflight = max(0, self.inflight - 1) + + def snapshot(self) -> Dict[str, int]: + with self.lock: + return { + "inflight": int(self.inflight), + "peak_inflight": int(self.peak_inflight), + "max_inflight": int(self.max_inflight), + "text_feature_workers": int(self.text_feature_workers), + "ref_audio_workers": int(self.ref_audio_workers), + } + + @staticmethod + def _run_profiled(fn, submit_at: float, *args) -> ProfiledResult: + started_at = time.perf_counter() + result = fn(*args) + finished_at = time.perf_counter() + return ProfiledResult( + result=result, + submit_at=float(submit_at), + started_at=float(started_at), + finished_at=float(finished_at), + ) + + def _prepare_text_cpu(self, text: str, language: str): + return self.tts.prepare_text_segments(text, language) + + def _build_text_features(self, prepared_segments, language: str, cpu_run_ms: float) -> PreparedTextFeatures: + profile: Dict[str, float] = {"cpu_preprocess_ms": float(cpu_run_ms)} + branch_start = time.perf_counter() + phones, bert_features, norm_text = self.tts.build_text_features_from_segments(prepared_segments, profile=profile) + total_ms = float(cpu_run_ms + (time.perf_counter() - branch_start) * 1000.0) + profile["bert_total_ms"] = max(0.0, total_ms - float(cpu_run_ms)) + return PreparedTextFeatures( + phones=phones, + bert_features=bert_features, + norm_text=norm_text, + profile=profile, + total_ms=total_ms, + cpu_preprocess_ms=float(cpu_run_ms), + ) + + async def _run_on_executor(self, executor, fn, *args) -> ProfiledResult: + loop = asyncio.get_running_loop() + submit_at = time.perf_counter() + return await loop.run_in_executor(executor, self._run_profiled, fn, float(submit_at), *args) + + async def _run_text_cpu_stage(self, text: str, language: str) -> ProfiledResult: + executor = getattr(self.tts, "prepare_text_cpu_executor", None) + if executor is None: + submit_at = time.perf_counter() + return self._run_profiled(self._prepare_text_cpu, submit_at, text, language) + return await self._run_on_executor(executor, self._prepare_text_cpu, text, language) + + async def _run_text_feature_stage(self, prepared_segments, language: str, cpu_run_ms: float) -> ProfiledResult: + return await self._run_on_executor(self.text_feature_executor, self._build_text_features, prepared_segments, language, cpu_run_ms) + + @staticmethod + def _estimate_text_feature_run_ms(profile: Dict[str, float]) -> float: + return float( + profile.get("bert_wait_ms", 0.0) + + profile.get("bert_tokenize_ms", 0.0) + + profile.get("bert_forward_ms", 0.0) + + profile.get("bert_scatter_ms", 0.0) + ) + + async def _run_text_feature_pair_stage( + self, + prompt_segments, + target_segments, + prompt_cpu_run_ms: float, + target_cpu_run_ms: float, + ) -> tuple[ProfiledResult, ProfiledResult]: + if self.text_feature_executor is not None: + prompt_feature_task = asyncio.create_task( + self._run_text_feature_stage(prompt_segments, None, prompt_cpu_run_ms) + ) + target_feature_task = asyncio.create_task( + self._run_text_feature_stage(target_segments, None, target_cpu_run_ms) + ) + return await asyncio.gather(prompt_feature_task, target_feature_task) + + prompt_profile: Dict[str, float] = {"cpu_preprocess_ms": float(prompt_cpu_run_ms)} + target_profile: Dict[str, float] = {"cpu_preprocess_ms": float(target_cpu_run_ms)} + submit_at = time.perf_counter() + started_at = float(submit_at) + prompt_result_raw, target_result_raw = await self.tts.build_text_feature_pair_from_segments_async( + prompt_segments, + target_segments, + prompt_profile=prompt_profile, + target_profile=target_profile, + ) + finished_at = time.perf_counter() + + prompt_result = PreparedTextFeatures( + phones=prompt_result_raw[0], + bert_features=prompt_result_raw[1], + norm_text=prompt_result_raw[2], + profile=prompt_profile, + total_ms=float(prompt_cpu_run_ms + self._estimate_text_feature_run_ms(prompt_profile)), + cpu_preprocess_ms=float(prompt_cpu_run_ms), + ) + target_result = PreparedTextFeatures( + phones=target_result_raw[0], + bert_features=target_result_raw[1], + norm_text=target_result_raw[2], + profile=target_profile, + total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)), + cpu_preprocess_ms=float(target_cpu_run_ms), + ) + prompt_profiled = ProfiledResult( + result=prompt_result, + submit_at=float(submit_at), + started_at=started_at, + finished_at=float(submit_at + self._estimate_text_feature_run_ms(prompt_profile) / 1000.0), + ) + target_profiled = ProfiledResult( + result=target_result, + submit_at=float(submit_at), + started_at=started_at, + finished_at=float(submit_at + self._estimate_text_feature_run_ms(target_profile) / 1000.0), + ) + if finished_at > prompt_profiled.finished_at: + prompt_result.profile["bert_total_ms"] = max( + self._estimate_text_feature_run_ms(prompt_profile), + (finished_at - submit_at) * 1000.0, + ) + target_result.profile["bert_total_ms"] = max( + self._estimate_text_feature_run_ms(target_profile), + (finished_at - submit_at) * 1000.0, + ) + else: + prompt_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(prompt_profile) + target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile) + return prompt_profiled, target_profiled + + async def _run_ref_audio_stage(self, ref_audio_path: str) -> ProfiledResult: + return await self._run_on_executor(self.ref_audio_executor, self.tts.extract_ref_audio_bundle, ref_audio_path) + + async def prepare_state_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> tuple[T2SRequestState, float, float]: + admission_start = time.perf_counter() + if self._inflight_semaphore is not None: + await self._inflight_semaphore.acquire() + prepare_admission_wait_ms = max(0.0, (time.perf_counter() - admission_start) * 1000.0) + current_inflight, peak_inflight = self._mark_enter() + prepare_start = time.perf_counter() + prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) + text = spec.text.strip("\n") + try: + text_pair_start = time.perf_counter() + prompt_cpu_task = asyncio.create_task(self._run_text_cpu_stage(prompt_text, spec.prompt_lang)) + target_cpu_task = asyncio.create_task(self._run_text_cpu_stage(text, spec.text_lang)) + ref_audio_task = asyncio.create_task(self._run_ref_audio_stage(str(spec.ref_audio_path))) + prompt_cpu_profiled, target_cpu_profiled = await asyncio.gather(prompt_cpu_task, target_cpu_task) + text_feature_pair_task = asyncio.create_task( + self._run_text_feature_pair_stage( + prompt_cpu_profiled.result, + target_cpu_profiled.result, + prompt_cpu_profiled.run_ms, + target_cpu_profiled.run_ms, + ) + ) + (prompt_feature_profiled, target_feature_profiled), ref_audio_profiled = await asyncio.gather( + text_feature_pair_task, + ref_audio_task, + ) + text_pair_end = time.perf_counter() + state = build_request_state_from_parts( + tts=self.tts, + spec=spec, + prompt_text=prompt_text, + text=text, + prompt_result=prompt_feature_profiled.result, + target_result=target_feature_profiled.result, + ref_audio_bundle=ref_audio_profiled.result, + prepare_start=prepare_start, + prepare_sync_start=prepare_start, + profile_overrides={ + "executor_queue_ms": max(0.0, (prepare_start - prepare_submit_at) * 1000.0), + "prepare_admission_wait_ms": prepare_admission_wait_ms, + "executor_run_wall_ms": max(0.0, (time.perf_counter() - prepare_start) * 1000.0), + "text_feature_pair_ms": max(0.0, (text_pair_end - text_pair_start) * 1000.0), + "prompt_text_parallel_future_wait_ms": 0.0, + "prompt_text_parallel_future_executor_queue_ms": 0.0, + "prompt_text_parallel_future_run_ms": 0.0, + "prompt_text_parallel_future_finish_after_submit_ms": 0.0, + "prompt_text_parallel_future_queue_tail_after_target_ms": 0.0, + "prompt_text_parallel_future_run_tail_after_target_ms": 0.0, + "prompt_text_cpu_queue_ms": prompt_cpu_profiled.queue_ms, + "prompt_text_cpu_run_ms": prompt_cpu_profiled.run_ms, + "prompt_text_feature_queue_ms": prompt_feature_profiled.queue_ms, + "prompt_text_feature_run_ms": prompt_feature_profiled.run_ms, + "text_cpu_queue_ms": target_cpu_profiled.queue_ms, + "text_cpu_run_ms": target_cpu_profiled.run_ms, + "text_feature_queue_ms": target_feature_profiled.queue_ms, + "text_feature_run_ms": target_feature_profiled.run_ms, + "ref_audio_task_queue_ms": ref_audio_profiled.queue_ms, + "ref_audio_task_run_ms": ref_audio_profiled.run_ms, + "worker_prepare_inflight_on_enter": float(current_inflight), + "worker_prepare_peak_inflight": float(peak_inflight), + }, + ) + prepare_exec_finished_at = time.perf_counter() + state.prepare_profile["executor_run_wall_ms"] = max( + 0.0, (prepare_exec_finished_at - prepare_start) * 1000.0 + ) + return state, prepare_start, prepare_exec_finished_at + finally: + self._mark_leave() + if self._inflight_semaphore is not None: + self._inflight_semaphore.release() diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py index b7118a72..8aabd286 100644 --- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -1,6 +1,5 @@ from __future__ import annotations -from concurrent.futures import Future from dataclasses import dataclass from pathlib import Path import time @@ -89,20 +88,31 @@ class T2SFinishedItem: class T2SActiveBatch: request_ids: List[str] states: List[T2SRequestState] - x: torch.Tensor - x_lens: torch.LongTensor + x: Optional[torch.Tensor] + x_lens: Optional[torch.LongTensor] y_sequences: List[torch.LongTensor] prefix_lens: torch.LongTensor xy_pos: torch.Tensor - key_padding_mask: torch.Tensor - prefill_attn_mask: torch.Tensor + key_padding_mask: Optional[torch.Tensor] + prefill_attn_mask: Optional[torch.Tensor] decode_attn_mask: Optional[torch.Tensor] k_cache: Optional[List[torch.Tensor]] v_cache: Optional[List[torch.Tensor]] - step_idx: int + kv_lens: Optional[torch.LongTensor] + step_indices: torch.LongTensor prefill_done: bool +@dataclass +class PreparedTextFeatures: + phones: List[int] + bert_features: torch.Tensor + norm_text: str + profile: Dict[str, float] + total_ms: float + cpu_preprocess_ms: float + + def normalize_sentence(text: str, language: str) -> str: text = text.strip("\n").strip() if not text: @@ -113,105 +123,125 @@ def normalize_sentence(text: str, language: str) -> str: @torch.inference_mode() -def prepare_request_state( +def prepare_text_features( + tts: Any, + text: str, + language: str, +) -> PreparedTextFeatures: + device = tts.configs.device + profile: Dict[str, float] = {} + branch_start = time.perf_counter() + _sync_device(device) + cpu_start = time.perf_counter() + prepared_segments = tts.prepare_text_segments(text, language) + _sync_device(device) + cpu_preprocess_ms = (time.perf_counter() - cpu_start) * 1000.0 + profile["cpu_preprocess_ms"] = float(cpu_preprocess_ms) + bert_start = time.perf_counter() + phones, bert_features, norm_text = tts.build_text_features_from_segments(prepared_segments, profile=profile) + _sync_device(device) + profile["bert_total_ms"] = (time.perf_counter() - bert_start) * 1000.0 + total_ms = (time.perf_counter() - branch_start) * 1000.0 + return PreparedTextFeatures( + phones=phones, + bert_features=bert_features, + norm_text=norm_text, + profile=profile, + total_ms=float(total_ms), + cpu_preprocess_ms=float(cpu_preprocess_ms), + ) + + +@torch.inference_mode() +def build_request_state_from_parts( tts: Any, spec: SchedulerRequestSpec, + prompt_text: str, + text: str, + prompt_result: PreparedTextFeatures, + target_result: PreparedTextFeatures, + ref_audio_bundle: Dict[str, Any], + prepare_start: float, + prepare_sync_start: float, + profile_overrides: Optional[Dict[str, float]] = None, ) -> T2SRequestState: device = tts.configs.device - prepare_start = time.perf_counter() _sync_device(device) - prepare_sync_start = time.perf_counter() - prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) - text = spec.text.strip("\n") - - prompt_text_profile: Dict[str, float] = {} - text_features_profile: Dict[str, float] = {} - text_feature_pair_start = time.perf_counter() - prompt_future: Future | None = None - - def _extract_prompt_features(): - _sync_device(device) - prompt_start = time.perf_counter() - result = tts.extract_text_features(prompt_text, spec.prompt_lang, profile=prompt_text_profile) - _sync_device(device) - return result, (time.perf_counter() - prompt_start) * 1000.0 - - if getattr(tts, "prepare_text_cpu_executor", None) is not None: - prompt_future = tts.prepare_text_cpu_executor.submit(_extract_prompt_features) - - _sync_device(device) - text_features_start = time.perf_counter() - phones, bert_features, norm_text = tts.extract_text_features(text, spec.text_lang, profile=text_features_profile) - _sync_device(device) - text_features_ms = (time.perf_counter() - text_features_start) * 1000.0 - - if prompt_future is None: - _sync_device(device) - prompt_text_features_start = time.perf_counter() - prompt_phones, prompt_bert_features, prompt_norm_text = tts.extract_text_features( - prompt_text, spec.prompt_lang, profile=prompt_text_profile - ) - _sync_device(device) - prompt_text_features_ms = (time.perf_counter() - prompt_text_features_start) * 1000.0 - prompt_text_profile["parallel_future_wait_ms"] = 0.0 - else: - prompt_wait_start = time.perf_counter() - (prompt_phones, prompt_bert_features, prompt_norm_text), prompt_text_features_ms = prompt_future.result() - prompt_text_profile["parallel_future_wait_ms"] = (time.perf_counter() - prompt_wait_start) * 1000.0 - - text_feature_pair_ms = (time.perf_counter() - text_feature_pair_start) * 1000.0 - if phones is None: - raise ValueError(f"{spec.request_id} text preprocessing returned no phones") - - _sync_device(device) - ref_audio_bundle_start = time.perf_counter() - ref_audio_bundle = tts.extract_ref_audio_bundle(str(spec.ref_audio_path)) + ref_audio_bundle_ms = float(ref_audio_bundle.get("profile", {}).get("bundle_total_ms", 0.0)) + bundle_profile = ref_audio_bundle.get("profile", {}) prompt_semantic = ref_audio_bundle["prompt_semantic"].long() spec_audio, audio_16k = ref_audio_bundle["refer_spec"] raw_audio = ref_audio_bundle["raw_audio"] raw_sr = int(ref_audio_bundle["raw_sr"]) - _sync_device(device) - ref_audio_bundle_ms = (time.perf_counter() - ref_audio_bundle_start) * 1000.0 - bundle_profile = ref_audio_bundle.get("profile", {}) prompt_semantic_ms = float(bundle_profile.get("prompt_semantic_ms", ref_audio_bundle_ms)) ref_spec_ms = float(bundle_profile.get("ref_spec_ms", 0.0)) audio_load_ms = float(bundle_profile.get("audio_load_ms", 0.0)) _sync_device(device) tensorize_start = time.perf_counter() - phones_tensor = torch.LongTensor(phones).to(tts.configs.device) - prompt_phones_tensor = torch.LongTensor(prompt_phones).to(tts.configs.device) - all_phones = torch.LongTensor(prompt_phones + phones).to(tts.configs.device) - all_bert_features = torch.cat([prompt_bert_features, bert_features], dim=1).to( + phones_tensor = torch.LongTensor(target_result.phones).to(tts.configs.device) + prompt_phones_tensor = torch.LongTensor(prompt_result.phones).to(tts.configs.device) + all_phones = torch.LongTensor(prompt_result.phones + target_result.phones).to(tts.configs.device) + all_bert_features = torch.cat([prompt_result.bert_features, target_result.bert_features], dim=1).to( dtype=tts.precision, device=tts.configs.device ) _sync_device(device) tensorize_ms = (time.perf_counter() - tensorize_start) * 1000.0 - _sync_device(device) prepare_profile = { - "prompt_text_features_ms": prompt_text_features_ms, - "text_features_ms": text_features_ms, - "prompt_text_bert_wait_ms": float(prompt_text_profile.get("bert_wait_ms", 0.0)), - "prompt_text_bert_forward_ms": float(prompt_text_profile.get("bert_forward_ms", 0.0)), - "prompt_text_bert_tokenize_ms": float(prompt_text_profile.get("bert_tokenize_ms", 0.0)), - "prompt_text_bert_scatter_ms": float(prompt_text_profile.get("bert_scatter_ms", 0.0)), - "prompt_text_bert_calls": float(prompt_text_profile.get("bert_calls", 0.0)), - "prompt_text_bert_stage_slots": float(prompt_text_profile.get("bert_stage_slots", 0.0)), - "prompt_text_bert_stage_inflight_peak": float(prompt_text_profile.get("bert_stage_inflight_peak", 0.0)), - "prompt_text_bert_batch_size_peak": float(prompt_text_profile.get("bert_batch_size_peak", 0.0)), - "prompt_text_bert_batch_tokens_peak": float(prompt_text_profile.get("bert_batch_tokens_peak", 0.0)), - "prompt_text_parallel_future_wait_ms": float(prompt_text_profile.get("parallel_future_wait_ms", 0.0)), - "text_bert_wait_ms": float(text_features_profile.get("bert_wait_ms", 0.0)), - "text_bert_forward_ms": float(text_features_profile.get("bert_forward_ms", 0.0)), - "text_bert_tokenize_ms": float(text_features_profile.get("bert_tokenize_ms", 0.0)), - "text_bert_scatter_ms": float(text_features_profile.get("bert_scatter_ms", 0.0)), - "text_bert_calls": float(text_features_profile.get("bert_calls", 0.0)), - "text_bert_stage_slots": float(text_features_profile.get("bert_stage_slots", 0.0)), - "text_bert_stage_inflight_peak": float(text_features_profile.get("bert_stage_inflight_peak", 0.0)), - "text_bert_batch_size_peak": float(text_features_profile.get("bert_batch_size_peak", 0.0)), - "text_bert_batch_tokens_peak": float(text_features_profile.get("bert_batch_tokens_peak", 0.0)), - "text_feature_pair_ms": text_feature_pair_ms, + "prompt_text_features_ms": float(prompt_result.total_ms), + "text_features_ms": float(target_result.total_ms), + "prompt_text_cpu_preprocess_ms": float(prompt_result.cpu_preprocess_ms), + "text_cpu_preprocess_ms": float(target_result.cpu_preprocess_ms), + "prompt_text_bert_wait_ms": float(prompt_result.profile.get("bert_wait_ms", 0.0)), + "prompt_text_bert_admission_wait_ms": float(prompt_result.profile.get("bert_admission_wait_ms", 0.0)), + "prompt_text_bert_queue_wait_ms": float(prompt_result.profile.get("bert_queue_wait_ms", 0.0)), + "prompt_text_bert_batch_collect_wait_ms": float(prompt_result.profile.get("bert_batch_collect_wait_ms", 0.0)), + "prompt_text_bert_forward_ms": float(prompt_result.profile.get("bert_forward_ms", 0.0)), + "prompt_text_bert_tokenize_ms": float(prompt_result.profile.get("bert_tokenize_ms", 0.0)), + "prompt_text_bert_scatter_ms": float(prompt_result.profile.get("bert_scatter_ms", 0.0)), + "prompt_text_bert_calls": float(prompt_result.profile.get("bert_calls", 0.0)), + "prompt_text_bert_stage_slots": float(prompt_result.profile.get("bert_stage_slots", 0.0)), + "prompt_text_bert_stage_inflight_peak": float(prompt_result.profile.get("bert_stage_inflight_peak", 0.0)), + "prompt_text_bert_batch_size_peak": float(prompt_result.profile.get("bert_batch_size_peak", 0.0)), + "prompt_text_bert_batch_tokens_peak": float(prompt_result.profile.get("bert_batch_tokens_peak", 0.0)), + "prompt_text_bert_pending_depth_on_enqueue_peak": float( + prompt_result.profile.get("bert_pending_depth_on_enqueue_peak", 0.0) + ), + "prompt_text_bert_pending_depth_on_collect_peak": float( + prompt_result.profile.get("bert_pending_depth_on_collect_peak", 0.0) + ), + "prompt_text_bert_high_pressure_mode_peak": float( + prompt_result.profile.get("bert_high_pressure_mode_peak", 0.0) + ), + "prompt_text_bert_batch_window_ms": float(prompt_result.profile.get("bert_batch_window_ms", 0.0)), + "prompt_text_parallel_future_wait_ms": 0.0, + "prompt_text_parallel_future_executor_queue_ms": 0.0, + "prompt_text_parallel_future_run_ms": float(prompt_result.total_ms), + "prompt_text_parallel_future_finish_after_submit_ms": float(prompt_result.total_ms), + "prompt_text_parallel_future_queue_tail_after_target_ms": 0.0, + "prompt_text_parallel_future_run_tail_after_target_ms": 0.0, + "text_bert_wait_ms": float(target_result.profile.get("bert_wait_ms", 0.0)), + "text_bert_admission_wait_ms": float(target_result.profile.get("bert_admission_wait_ms", 0.0)), + "text_bert_queue_wait_ms": float(target_result.profile.get("bert_queue_wait_ms", 0.0)), + "text_bert_batch_collect_wait_ms": float(target_result.profile.get("bert_batch_collect_wait_ms", 0.0)), + "text_bert_forward_ms": float(target_result.profile.get("bert_forward_ms", 0.0)), + "text_bert_tokenize_ms": float(target_result.profile.get("bert_tokenize_ms", 0.0)), + "text_bert_scatter_ms": float(target_result.profile.get("bert_scatter_ms", 0.0)), + "text_bert_calls": float(target_result.profile.get("bert_calls", 0.0)), + "text_bert_stage_slots": float(target_result.profile.get("bert_stage_slots", 0.0)), + "text_bert_stage_inflight_peak": float(target_result.profile.get("bert_stage_inflight_peak", 0.0)), + "text_bert_batch_size_peak": float(target_result.profile.get("bert_batch_size_peak", 0.0)), + "text_bert_batch_tokens_peak": float(target_result.profile.get("bert_batch_tokens_peak", 0.0)), + "text_bert_pending_depth_on_enqueue_peak": float( + target_result.profile.get("bert_pending_depth_on_enqueue_peak", 0.0) + ), + "text_bert_pending_depth_on_collect_peak": float( + target_result.profile.get("bert_pending_depth_on_collect_peak", 0.0) + ), + "text_bert_high_pressure_mode_peak": float(target_result.profile.get("bert_high_pressure_mode_peak", 0.0)), + "text_bert_batch_window_ms": float(target_result.profile.get("bert_batch_window_ms", 0.0)), + "text_feature_pair_ms": float(max(prompt_result.total_ms, target_result.total_ms)), "text_cpu_parallel_workers": float(getattr(tts, "prepare_text_cpu_workers", 0)), "audio_load_ms": audio_load_ms, "audio_stage_wait_ms": float(bundle_profile.get("audio_stage_wait_ms", 0.0)), @@ -233,6 +263,8 @@ def prepare_request_state( "total_ms": (time.perf_counter() - prepare_sync_start) * 1000.0, "wall_total_ms": (time.perf_counter() - prepare_start) * 1000.0, } + if profile_overrides: + prepare_profile.update({key: float(value) for key, value in profile_overrides.items()}) return T2SRequestState( request_id=spec.request_id, ref_audio_path=spec.ref_audio_path, @@ -240,8 +272,8 @@ def prepare_request_state( prompt_lang=spec.prompt_lang, text=text, text_lang=spec.text_lang, - norm_prompt_text=prompt_norm_text, - norm_text=norm_text, + norm_prompt_text=prompt_result.norm_text, + norm_text=target_result.norm_text, phones=phones_tensor, prompt_phones=prompt_phones_tensor, all_phones=all_phones, @@ -260,6 +292,33 @@ def prepare_request_state( ) +@torch.inference_mode() +def prepare_request_state( + tts: Any, + spec: SchedulerRequestSpec, +) -> T2SRequestState: + prepare_start = time.perf_counter() + prepare_sync_start = time.perf_counter() + prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) + text = spec.text.strip("\n") + prompt_result = prepare_text_features(tts, prompt_text, spec.prompt_lang) + target_result = prepare_text_features(tts, text, spec.text_lang) + if target_result.phones is None: + raise ValueError(f"{spec.request_id} text preprocessing returned no phones") + ref_audio_bundle = tts.extract_ref_audio_bundle(str(spec.ref_audio_path)) + return build_request_state_from_parts( + tts=tts, + spec=spec, + prompt_text=prompt_text, + text=text, + prompt_result=prompt_result, + target_result=target_result, + ref_audio_bundle=ref_audio_bundle, + prepare_start=prepare_start, + prepare_sync_start=prepare_sync_start, + ) + + def _left_pad_hidden(hidden: torch.Tensor, target_len: int) -> torch.Tensor: if hidden.shape[0] >= target_len: return hidden @@ -417,7 +476,8 @@ def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SAct decode_attn_mask=None, k_cache=None, v_cache=None, - step_idx=0, + kv_lens=None, + step_indices=torch.zeros((len(states),), dtype=torch.long, device=device), prefill_done=False, ) @@ -433,6 +493,64 @@ def build_next_xy_pos(model: Any, y_sequences: Sequence[torch.LongTensor]) -> to ) +def _compact_cache_to_kv_lens( + cache: torch.Tensor, + kv_lens: torch.LongTensor, +) -> torch.Tensor: + target_len = int(kv_lens.max().item()) + if cache.shape[1] == target_len and torch.all(kv_lens == target_len).item(): + return cache + compacted = cache.new_zeros((cache.shape[0], target_len, cache.shape[2])) + for batch_index, kv_len in enumerate(kv_lens.tolist()): + if kv_len <= 0: + continue + compacted[batch_index, -kv_len:, :] = cache[batch_index, -kv_len:, :] + return compacted + + +def _compact_decode_mask_to_kv_lens( + decode_attn_mask: Optional[torch.Tensor], + kv_lens: torch.LongTensor, +) -> Optional[torch.Tensor]: + target_len = int(kv_lens.max().item()) + 1 + if decode_attn_mask is None: + return None + if decode_attn_mask.shape[-1] == target_len and torch.all(kv_lens + 1 == target_len).item(): + return decode_attn_mask + compacted = torch.ones( + (decode_attn_mask.shape[0], 1, 1, target_len), + dtype=decode_attn_mask.dtype, + device=decode_attn_mask.device, + ) + for batch_index, kv_len in enumerate(kv_lens.tolist()): + current_len = kv_len + 1 + compacted[batch_index, :, :, -current_len:] = decode_attn_mask[batch_index, :, :, -current_len:] + if not compacted.any().item(): + return None + return compacted + + +def _advance_decode_mask( + decode_attn_mask: Optional[torch.Tensor], + kv_lens: torch.LongTensor, +) -> Optional[torch.Tensor]: + if decode_attn_mask is None: + return None + target_len = int(kv_lens.max().item()) + 2 + advanced = torch.zeros( + (decode_attn_mask.shape[0], 1, 1, target_len), + dtype=decode_attn_mask.dtype, + device=decode_attn_mask.device, + ) + for batch_index, kv_len in enumerate(kv_lens.tolist()): + current_len = kv_len + 1 + next_mask = F.pad(decode_attn_mask[batch_index : batch_index + 1, :, :, -current_len:], (0, 1), value=False) + advanced[batch_index : batch_index + 1, :, :, -next_mask.shape[-1] :] = next_mask + if not advanced.any().item(): + return None + return advanced + + def _sample_per_request( model: Any, active_batch: T2SActiveBatch, @@ -443,16 +561,15 @@ def _sample_per_request( keep_indices: List[int] = [] updated_sequences: List[torch.LongTensor] = [] - step_idx = active_batch.step_idx sampling_keys = [ _sampling_group_key( top_k=state.top_k, top_p=state.top_p, temperature=state.temperature, repetition_penalty=state.repetition_penalty, - trim_eos=False, + trim_eos=int(active_batch.step_indices[batch_index].item()) < 11, ) - for state in active_batch.states + for batch_index, state in enumerate(active_batch.states) ] sampled_items, argmax_tokens = _batched_sample_by_group( logits=logits, @@ -460,6 +577,7 @@ def _sample_per_request( sampling_keys=sampling_keys, ) for batch_index, state in enumerate(active_batch.states): + step_index = int(active_batch.step_indices[batch_index].item()) current_history = active_batch.y_sequences[batch_index] sampled = sampled_items[batch_index] sampled_token = int(sampled[0, 0].item()) @@ -469,7 +587,7 @@ def _sample_per_request( finish_reason: Optional[str] = None if state.early_stop_num != -1 and (new_history.shape[0] - int(active_batch.prefix_lens[batch_index].item())) > state.early_stop_num: finish_reason = "early_stop" - elif step_idx + 1 >= max_steps: + elif step_index + 1 >= max_steps: finish_reason = "max_step" elif sampled_token == model.EOS: finish_reason = "eos_sample" @@ -482,7 +600,7 @@ def _sample_per_request( T2SFinishedItem( request_id=state.request_id, semantic_tokens=new_history[prefix_len:-1].clone(), - finish_idx=step_idx, + finish_idx=step_index, finish_reason=finish_reason, ) ) @@ -493,30 +611,48 @@ def _sample_per_request( return finished_items, keep_indices, updated_sequences +@torch.inference_mode() def decode_one_step( model: Any, active_batch: T2SActiveBatch, max_steps: int, ) -> Tuple[Optional[T2SActiveBatch], List[T2SFinishedItem]]: - if not active_batch.prefill_done: + was_prefill = not active_batch.prefill_done + if was_prefill: + if active_batch.prefill_attn_mask is None or active_batch.key_padding_mask is None: + raise ValueError("prefill 阶段缺少必要 mask") xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.process_prompt( active_batch.xy_pos, active_batch.prefill_attn_mask, None ) + active_batch.kv_lens = active_batch.x_lens + active_batch.prefix_lens active_batch.decode_attn_mask = F.pad(active_batch.key_padding_mask.unsqueeze(1).unsqueeze(1), (0, 1), value=False) + if active_batch.k_cache is None or active_batch.v_cache is None or active_batch.kv_lens is None: + raise ValueError("prefill 阶段未生成完整 KV cache") + active_batch.k_cache = [_compact_cache_to_kv_lens(layer, active_batch.kv_lens) for layer in active_batch.k_cache] + active_batch.v_cache = [_compact_cache_to_kv_lens(layer, active_batch.kv_lens) for layer in active_batch.v_cache] + active_batch.decode_attn_mask = _compact_decode_mask_to_kv_lens(active_batch.decode_attn_mask, active_batch.kv_lens) + active_batch.x = None + active_batch.x_lens = None + active_batch.key_padding_mask = None + active_batch.prefill_attn_mask = None active_batch.prefill_done = True else: + if active_batch.k_cache is None or active_batch.v_cache is None or active_batch.kv_lens is None: + raise ValueError("decode 阶段缺少 KV cache") + batched_decode_attn_mask = None + if active_batch.decode_attn_mask is not None: + batched_decode_attn_mask = _materialize_decode_mask_for_active_batch(active_batch) + if not batched_decode_attn_mask.any().item(): + batched_decode_attn_mask = None xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.decode_next_token( active_batch.xy_pos, active_batch.k_cache, active_batch.v_cache, - active_batch.decode_attn_mask, + batched_decode_attn_mask, ) - if active_batch.decode_attn_mask is not None: - active_batch.decode_attn_mask = F.pad(active_batch.decode_attn_mask, (0, 1), value=False) + active_batch.decode_attn_mask = _advance_decode_mask(active_batch.decode_attn_mask, active_batch.kv_lens) logits = model.ar_predict_layer(xy_dec[:, -1]) - if active_batch.step_idx < 11: - logits = logits[:, :-1] finished_items, keep_indices, updated_sequences = _sample_per_request(model, active_batch, logits, max_steps=max_steps) if len(keep_indices) == 0: @@ -528,16 +664,32 @@ def decode_one_step( active_batch.states = [active_batch.states[i] for i in keep_indices] active_batch.y_sequences = updated_sequences active_batch.prefix_lens = torch.index_select(active_batch.prefix_lens, dim=0, index=keep_tensor) + next_step_indices = torch.index_select(active_batch.step_indices, dim=0, index=keep_tensor) + next_kv_lens = None if active_batch.kv_lens is None else torch.index_select(active_batch.kv_lens, dim=0, index=keep_tensor) + active_batch.step_indices = next_step_indices + 1 + if not was_prefill: + if next_kv_lens is not None: + active_batch.kv_lens = next_kv_lens + 1 + else: + active_batch.kv_lens = next_kv_lens if active_batch.decode_attn_mask is not None: active_batch.decode_attn_mask = torch.index_select(active_batch.decode_attn_mask, dim=0, index=keep_tensor) + if not active_batch.decode_attn_mask.any().item(): + active_batch.decode_attn_mask = None if active_batch.k_cache is not None and active_batch.v_cache is not None: for cache_index in range(len(active_batch.k_cache)): active_batch.k_cache[cache_index] = torch.index_select(active_batch.k_cache[cache_index], dim=0, index=keep_tensor) active_batch.v_cache[cache_index] = torch.index_select(active_batch.v_cache[cache_index], dim=0, index=keep_tensor) + if active_batch.kv_lens is not None: + active_batch.k_cache = [_compact_cache_to_kv_lens(layer, active_batch.kv_lens) for layer in active_batch.k_cache] + active_batch.v_cache = [_compact_cache_to_kv_lens(layer, active_batch.kv_lens) for layer in active_batch.v_cache] + active_batch.decode_attn_mask = _compact_decode_mask_to_kv_lens( + active_batch.decode_attn_mask, + active_batch.kv_lens, + ) active_batch.xy_pos = build_next_xy_pos(model, active_batch.y_sequences) - active_batch.step_idx += 1 return active_batch, finished_items @@ -583,6 +735,126 @@ def _materialize_decode_mask_for_request(running_request: T2SRunningRequest) -> ) +def _materialize_decode_mask_for_active_batch( + active_batch: T2SActiveBatch, + target_mask_len: Optional[int] = None, +) -> torch.Tensor: + if active_batch.k_cache is None or active_batch.kv_lens is None: + raise ValueError("active batch 缺少 KV cache 或 kv_lens") + current_mask_len = active_batch.k_cache[0].shape[1] + 1 + if target_mask_len is None: + target_mask_len = current_mask_len + if active_batch.decode_attn_mask is None: + mask = torch.zeros( + (len(active_batch.request_ids), 1, 1, current_mask_len), + dtype=torch.bool, + device=active_batch.k_cache[0].device, + ) + else: + rows: List[torch.Tensor] = [] + for batch_index, kv_len in enumerate(active_batch.kv_lens.tolist()): + row_len = kv_len + 1 + row_mask = _fit_decode_mask_length( + active_batch.decode_attn_mask[batch_index : batch_index + 1], + row_len, + ) + rows.append(_pad_decode_mask_left(row_mask, target_mask_len)) + mask = torch.cat(rows, dim=0) + if target_mask_len != current_mask_len and active_batch.decode_attn_mask is None: + mask = _pad_decode_mask_left(mask, target_mask_len) + return mask + + +@torch.inference_mode() +def run_prefill_active_batch( + model: Any, + states: Sequence[T2SRequestState], + max_steps: int, +) -> Tuple[Optional[T2SActiveBatch], List[T2SFinishedItem]]: + if not states: + return None, [] + active_batch = build_prefill_batch(model, states) + return decode_one_step(model, active_batch, max_steps=max_steps) + + +@torch.inference_mode() +def merge_active_batches( + model: Any, + left_batch: Optional[T2SActiveBatch], + right_batch: Optional[T2SActiveBatch], +) -> Optional[T2SActiveBatch]: + if left_batch is None: + return right_batch + if right_batch is None: + return left_batch + if not left_batch.prefill_done or not right_batch.prefill_done: + raise ValueError("只有 prefill 完成后的 active batch 才能 merge") + if left_batch.k_cache is None or left_batch.v_cache is None or right_batch.k_cache is None or right_batch.v_cache is None: + raise ValueError("merge active batch 时缺少 KV cache") + + left_kv_len = int(left_batch.k_cache[0].shape[1]) + right_kv_len = int(right_batch.k_cache[0].shape[1]) + merged_kv_len = max(left_kv_len, right_kv_len) + merged_mask_len = merged_kv_len + 1 + + merged_k_cache: List[torch.Tensor] = [] + merged_v_cache: List[torch.Tensor] = [] + for layer_index in range(len(left_batch.k_cache)): + merged_k_cache.append( + torch.cat( + [ + _pad_cache_left(left_batch.k_cache[layer_index], merged_kv_len), + _pad_cache_left(right_batch.k_cache[layer_index], merged_kv_len), + ], + dim=0, + ) + ) + merged_v_cache.append( + torch.cat( + [ + _pad_cache_left(left_batch.v_cache[layer_index], merged_kv_len), + _pad_cache_left(right_batch.v_cache[layer_index], merged_kv_len), + ], + dim=0, + ) + ) + + merged_decode_attn_mask = torch.cat( + [ + _materialize_decode_mask_for_active_batch(left_batch, merged_mask_len), + _materialize_decode_mask_for_active_batch(right_batch, merged_mask_len), + ], + dim=0, + ) + merged_request_ids = list(left_batch.request_ids) + list(right_batch.request_ids) + merged_states = list(left_batch.states) + list(right_batch.states) + merged_y_sequences = list(left_batch.y_sequences) + list(right_batch.y_sequences) + merged_prefix_lens = torch.cat([left_batch.prefix_lens, right_batch.prefix_lens], dim=0) + if left_batch.kv_lens is None or right_batch.kv_lens is None: + raise ValueError("merge active batch 时缺少 kv_lens") + merged_kv_lens = torch.cat([left_batch.kv_lens, right_batch.kv_lens], dim=0) + merged_decode_attn_mask = _compact_decode_mask_to_kv_lens(merged_decode_attn_mask, merged_kv_lens) + merged_step_indices = torch.cat([left_batch.step_indices, right_batch.step_indices], dim=0) + + return T2SActiveBatch( + request_ids=merged_request_ids, + states=merged_states, + x=None, + x_lens=None, + y_sequences=merged_y_sequences, + prefix_lens=merged_prefix_lens, + xy_pos=build_next_xy_pos(model, merged_y_sequences), + key_padding_mask=None, + prefill_attn_mask=None, + decode_attn_mask=merged_decode_attn_mask, + k_cache=merged_k_cache, + v_cache=merged_v_cache, + kv_lens=merged_kv_lens, + step_indices=merged_step_indices, + prefill_done=True, + ) + + @torch.inference_mode() def run_prefill_step( model: Any, @@ -804,29 +1076,24 @@ def run_scheduler_continuous( max_steps: int, ) -> List[T2SFinishedItem]: pending = sorted(states, key=lambda item: (item.ready_step, item.request_id)) - running_requests: List[T2SRunningRequest] = [] + active_batch: Optional[T2SActiveBatch] = None finished: List[T2SFinishedItem] = [] current_tick = 0 - while pending or running_requests: + while pending or active_batch is not None: admitted: List[T2SRequestState] = [] while pending and pending[0].ready_step <= current_tick: admitted.append(pending.pop(0)) - admitted_running, admitted_finished = run_prefill_step(model, admitted, max_steps=max_steps) + admitted_active_batch, admitted_finished = run_prefill_active_batch(model, admitted, max_steps=max_steps) finished.extend(admitted_finished) + active_batch = merge_active_batches(model, active_batch, admitted_active_batch) - if running_requests: - running_requests, step_finished = run_decode_step_for_running( - model, - running_requests, - max_steps=max_steps, - ) + if active_batch is not None: + active_batch, step_finished = decode_one_step(model, active_batch, max_steps=max_steps) finished.extend(step_finished) - running_requests.extend(admitted_running) - - if not running_requests and pending: + if active_batch is None and pending: current_tick = max(current_tick + 1, pending[0].ready_step) continue diff --git a/GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py b/GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py new file mode 100644 index 00000000..e2398251 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py @@ -0,0 +1,100 @@ +import os +import re +import sys +from typing import Dict, List, Optional, Tuple + +now_dir = os.getcwd() +sys.path.append(now_dir) + +from text.LangSegmenter import LangSegmenter +from text import cleaned_text_to_sequence +from text.cleaner import clean_text + + +PreparedTextSegmentPayload = Dict[str, object] + + +def split_text_by_language(text: str, language: str) -> Tuple[List[str], List[str]]: + textlist: List[str] = [] + langlist: List[str] = [] + if language == "all_zh": + for tmp in LangSegmenter.getTexts(text, "zh"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_yue": + for tmp in LangSegmenter.getTexts(text, "zh"): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_ja": + for tmp in LangSegmenter.getTexts(text, "ja"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_ko": + for tmp in LangSegmenter.getTexts(text, "ko"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "en": + langlist.append("en") + textlist.append(text) + elif language == "auto": + for tmp in LangSegmenter.getTexts(text): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "auto_yue": + for tmp in LangSegmenter.getTexts(text): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + else: + for tmp in LangSegmenter.getTexts(text): + if langlist: + same_group = (tmp["lang"] == "en" and langlist[-1] == "en") or ( + tmp["lang"] != "en" and langlist[-1] != "en" + ) + if same_group: + textlist[-1] += tmp["text"] + continue + if tmp["lang"] == "en": + langlist.append(tmp["lang"]) + else: + langlist.append(language) + textlist.append(tmp["text"]) + return textlist, langlist + + +def clean_text_segment(text: str, language: str, version: str) -> Tuple[List[int], Optional[List[int]], str]: + normalized_language = language.replace("all_", "") + phones, word2ph, norm_text = clean_text(text, normalized_language, version) + phones = cleaned_text_to_sequence(phones, version) + return list(phones), None if word2ph is None else list(word2ph), str(norm_text) + + +def preprocess_text_segments_payload( + text: str, + language: str, + version: str, + final: bool = False, +) -> List[PreparedTextSegmentPayload]: + text = re.sub(r" {2,}", " ", text) + textlist, langlist = split_text_by_language(text, language) + payloads: List[PreparedTextSegmentPayload] = [] + total_phones_len = 0 + for segment_text, segment_lang in zip(textlist, langlist): + phones, word2ph, norm_text = clean_text_segment(segment_text, segment_lang, version) + payloads.append( + { + "language": segment_lang.replace("all_", ""), + "phones": phones, + "word2ph": word2ph, + "norm_text": norm_text, + } + ) + total_phones_len += len(phones) + + if not final and total_phones_len < 6: + return preprocess_text_segments_payload("." + text, language, version, final=True) + + return payloads diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py index 348ddb3f..c6d147cf 100644 --- a/GPT_SoVITS/module/models.py +++ b/GPT_SoVITS/module/models.py @@ -2,6 +2,7 @@ import warnings warnings.filterwarnings("ignore") import math +from typing import List import torch from torch import nn @@ -1038,6 +1039,67 @@ class SynthesizerTrn(nn.Module): o = self.dec((z * y_mask)[:, :, :], g=ge) return o + @torch.no_grad() + def decode_batched_request_local( + self, + codes: torch.Tensor, + code_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + refer_list: List[torch.Tensor], + noise_scale: float = 0.5, + speed: float = 1, + sv_emb: torch.Tensor | None = None, + ): + batch_size = int(codes.size(1)) + if batch_size <= 0: + raise ValueError("decode_batched_request_local 收到空 batch") + if len(refer_list) != batch_size: + raise ValueError("refer_list 数量与 batch size 不一致") + + refer_lengths = torch.LongTensor([int(item.size(2)) for item in refer_list]).to(codes.device) + max_refer_len = int(refer_lengths.max().item()) + refer_batch = torch.zeros( + (batch_size, int(refer_list[0].size(1)), max_refer_len), + dtype=refer_list[0].dtype, + device=codes.device, + ) + for batch_index, refer in enumerate(refer_list): + refer_batch[batch_index, :, : int(refer.size(2))] = refer.squeeze(0) + refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, max_refer_len), 1).to(refer_batch.dtype) + if self.version == "v1": + ge = self.ref_enc(refer_batch * refer_mask, refer_mask) + else: + ge = self.ref_enc(refer_batch[:, :704] * refer_mask, refer_mask) + if self.is_v2pro: + if sv_emb is None: + raise ValueError("v2Pro batched request-local synthesis 缺少 sv_emb") + ge = ge + self.sv_emb(sv_emb).unsqueeze(-1) + ge = self.prelu(ge) + + quantized = self.quantizer.decode(codes) + if self.semantic_frame_rate == "25hz": + quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") + y_lengths = code_lengths.to(device=codes.device, dtype=torch.long) * 2 + text_lengths = text_lengths.to(device=text.device, dtype=torch.long) + x, m_p, logs_p, y_mask, _, _ = self.enc_p( + quantized, + y_lengths, + text, + text_lengths, + self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge, + speed, + ) + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + z = self.flow(z_p, y_mask, g=ge, reverse=True) + audio = self.dec((z * y_mask)[:, :, :], g=ge) + upsample_factor = 1 + for up_layer in self.dec.ups: + stride = up_layer.stride[0] if isinstance(up_layer.stride, tuple) else int(up_layer.stride) + upsample_factor *= int(stride) + audio_lengths = y_mask.squeeze(1).sum(dim=1).to(dtype=torch.long) * int(upsample_factor) + return audio, audio_lengths + @torch.no_grad() def decode_streaming(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None): diff --git a/api_v3.py b/api_v3.py index 92f9a3b9..74bc7ac8 100644 --- a/api_v3.py +++ b/api_v3.py @@ -107,6 +107,7 @@ import sys import time import traceback import uuid +from collections import deque from dataclasses import dataclass from pathlib import Path from typing import Generator, List, Union @@ -115,6 +116,10 @@ now_dir = os.getcwd() sys.path.append(now_dir) sys.path.append("%s/GPT_SoVITS" % (now_dir)) +from runtime_preload import preload_text_runtime_deps + +preload_text_runtime_deps() + import argparse import subprocess import wave @@ -128,14 +133,15 @@ import uvicorn from io import BytesIO from tools.i18n.i18n import I18nAuto from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config +from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( SchedulerRequestSpec, + T2SActiveBatch, T2SFinishedItem, - T2SRunningRequest, T2SRequestState, - prepare_request_state, - run_decode_step_for_running, - run_prefill_step, + merge_active_batches, + decode_one_step, + run_prefill_active_batch, run_scheduler_continuous, ) from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names @@ -238,39 +244,71 @@ class SchedulerPendingJob: request_id: str state: T2SRequestState done_event: threading.Event + done_loop: asyncio.AbstractEventLoop | None + done_future: asyncio.Future | None enqueue_time: float speed_factor: float sample_steps: int media_type: str - prepare_ms: float = 0.0 prepare_wall_ms: float = 0.0 + prepare_profile_total_ms: float = 0.0 first_schedule_time: float | None = None prefill_ms: float = 0.0 + merge_ms: float = 0.0 decode_ms: float = 0.0 + finalize_wait_ms: float = 0.0 synth_ms: float = 0.0 pack_ms: float = 0.0 decode_steps: int = 0 + result_ready_time: float | None = None result: dict | None = None sample_rate: int | None = None audio_data: np.ndarray | None = None error: str | None = None +@dataclass +class SchedulerFinalizeTask: + request_id: str + item: T2SFinishedItem + enqueued_time: float + + class SchedulerDebugWorker: def __init__(self, tts: TTS, max_steps: int = 1500, micro_batch_wait_ms: int = 5): self.tts = tts self.max_steps = max_steps self.micro_batch_wait_s = micro_batch_wait_ms / 1000.0 + self.prepare_coordinator = PrepareCoordinator(tts) self.condition = threading.Condition() self.prepare_inflight = 0 self.prepare_peak_inflight = 0 + self.finalize_condition = threading.Condition() + self.finalize_pending_tasks: deque[SchedulerFinalizeTask] = deque() + self.finalize_pending_peak = 0 + self.finalize_inflight = 0 + self.finalize_inflight_peak = 0 + self.finalize_workers = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1))) + self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower() + self.finalize_batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16))) + self.finalize_batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0) self.pending_jobs: List[SchedulerPendingJob] = [] - self.running_requests: List[T2SRunningRequest] = [] + self.active_batch: T2SActiveBatch | None = None self.job_map: dict[str, SchedulerPendingJob] = {} self.total_finished = 0 self.total_submitted = 0 self.worker_thread = threading.Thread(target=self._run_loop, name="t2s-scheduler-debug-worker", daemon=True) self.worker_thread.start() + self.finalize_threads = [ + threading.Thread( + target=self._run_finalize_loop, + name=f"t2s-scheduler-finalize-{worker_index}", + daemon=True, + ) + for worker_index in range(self.finalize_workers) + ] + for finalize_thread in self.finalize_threads: + finalize_thread.start() def _sync_device(self) -> None: try: @@ -283,20 +321,7 @@ class SchedulerDebugWorker: pass def prepare_state(self, spec: SchedulerRequestSpec) -> T2SRequestState: - with self.condition: - self.prepare_inflight += 1 - prepare_inflight_on_enter = self.prepare_inflight - if self.prepare_inflight > self.prepare_peak_inflight: - self.prepare_peak_inflight = self.prepare_inflight - prepare_peak_inflight = self.prepare_peak_inflight - try: - state = prepare_request_state(self.tts, spec) - state.prepare_profile["worker_prepare_inflight_on_enter"] = float(prepare_inflight_on_enter) - state.prepare_profile["worker_prepare_peak_inflight"] = float(prepare_peak_inflight) - return state - finally: - with self.condition: - self.prepare_inflight = max(0, self.prepare_inflight - 1) + raise RuntimeError("prepare_state sync path has been replaced by PrepareCoordinator") def submit( self, @@ -304,27 +329,47 @@ class SchedulerDebugWorker: speed_factor: float, sample_steps: int, media_type: str, - prepare_ms: float, prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None = None, + done_future: asyncio.Future | None = None, ) -> SchedulerPendingJob: job = SchedulerPendingJob( request_id=state.request_id, state=state, done_event=threading.Event(), + done_loop=done_loop, + done_future=done_future, enqueue_time=time.perf_counter(), speed_factor=float(speed_factor), sample_steps=int(sample_steps), media_type=media_type, - prepare_ms=float(prepare_ms), prepare_wall_ms=float(prepare_wall_ms), + prepare_profile_total_ms=float(prepare_profile_total_ms), ) with self.condition: self.pending_jobs.append(job) self.job_map[job.request_id] = job self.total_submitted += 1 self.condition.notify_all() + with self.finalize_condition: + self.finalize_condition.notify_all() return job + async def prepare_state_async(self, spec: SchedulerRequestSpec) -> T2SRequestState: + state, _, _ = await self.prepare_coordinator.prepare_state_profiled_async(spec, time.perf_counter()) + return state + + async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: + return await asyncio.gather(*[self.prepare_state_async(spec) for spec in specs]) + + async def prepare_state_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> tuple[T2SRequestState, float, float]: + return await self.prepare_coordinator.prepare_state_profiled_async(spec, prepare_submit_at) + def _mark_prefill_started(self, jobs: List[SchedulerPendingJob], started_at: float) -> None: with self.condition: for job in jobs: @@ -340,6 +385,14 @@ class SchedulerDebugWorker: if tracked_job is not None: tracked_job.prefill_ms += elapsed_ms + def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: + elapsed_ms = elapsed_s * 1000.0 + with self.condition: + for request_id in request_ids: + job = self.job_map.get(request_id) + if job is not None: + job.merge_ms += elapsed_ms + def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: elapsed_ms = elapsed_s * 1000.0 with self.condition: @@ -349,16 +402,30 @@ class SchedulerDebugWorker: job.decode_ms += elapsed_ms job.decode_steps += 1 + def _add_finalize_wait_ms(self, request_ids: List[str], elapsed_ms: float) -> None: + with self.condition: + for request_id in request_ids: + job = self.job_map.get(request_id) + if job is not None: + job.finalize_wait_ms += elapsed_ms + def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]: - semantic_tokens = item.semantic_tokens.unsqueeze(0).unsqueeze(0).to(self.tts.configs.device) - phones = job.state.phones.unsqueeze(0).to(self.tts.configs.device) + semantic_tokens = item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0).to(self.tts.configs.device) + phones = job.state.phones.detach().clone().unsqueeze(0).to(self.tts.configs.device) + prompt_semantic = job.state.prompt_semantic.detach().clone() + prompt_phones = job.state.prompt_phones.detach().clone() + refer_spec = ( + job.state.refer_spec[0].detach().clone(), + None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), + ) + raw_audio = job.state.raw_audio.detach().clone() audio_fragment = self.tts.synthesize_audio_request_local( semantic_tokens=semantic_tokens, phones=phones, - prompt_semantic=job.state.prompt_semantic, - prompt_phones=job.state.prompt_phones, - refer_spec=job.state.refer_spec, - raw_audio=job.state.raw_audio, + prompt_semantic=prompt_semantic, + prompt_phones=prompt_phones, + refer_spec=refer_spec, + raw_audio=raw_audio, raw_sr=job.state.raw_sr, speed=float(job.speed_factor), sample_steps=int(job.sample_steps), @@ -375,6 +442,11 @@ class SchedulerDebugWorker: ) def get_state(self) -> dict: + with self.finalize_condition: + finalize_pending = len(self.finalize_pending_tasks) + finalize_pending_peak = self.finalize_pending_peak + finalize_inflight = self.finalize_inflight + finalize_inflight_peak = self.finalize_inflight_peak with self.condition: bert_stage = self.tts.prepare_bert_stage_limiter.snapshot() ref_audio_stage = self.tts.prepare_ref_audio_stage_limiter.snapshot() @@ -388,12 +460,24 @@ class SchedulerDebugWorker: if self.tts.prepare_ref_semantic_batch_worker is None else self.tts.prepare_ref_semantic_batch_worker.snapshot() ) + prepare_coordinator_state = self.prepare_coordinator.snapshot() return { "pending_jobs": len(self.pending_jobs), - "running_requests": len(self.running_requests), - "prepare_inflight": self.prepare_inflight, - "prepare_peak_inflight": self.prepare_peak_inflight, + "running_requests": 0 if self.active_batch is None else len(self.active_batch.request_ids), + "prepare_inflight": prepare_coordinator_state["inflight"], + "prepare_peak_inflight": prepare_coordinator_state["peak_inflight"], + "finalize_pending": finalize_pending, + "finalize_pending_peak": finalize_pending_peak, + "finalize_inflight": finalize_inflight, + "finalize_inflight_peak": finalize_inflight_peak, + "finalize_workers": self.finalize_workers, + "finalize_mode": self.finalize_mode, + "finalize_batch_max_items": self.finalize_batch_max_items, + "finalize_batch_wait_ms": self.finalize_batch_wait_s * 1000.0, + "prepare_request_executor_workers": 0, "prepare_text_cpu_workers": int(getattr(self.tts, "prepare_text_cpu_workers", 0)), + "prepare_text_feature_workers": int(prepare_coordinator_state["text_feature_workers"]), + "prepare_ref_audio_workers": int(prepare_coordinator_state["ref_audio_workers"]), "prepare_bert_stage": bert_stage, "prepare_bert_batch_worker": bert_batch_worker, "prepare_ref_audio_stage": ref_audio_stage, @@ -405,59 +489,217 @@ class SchedulerDebugWorker: "micro_batch_wait_ms": int(self.micro_batch_wait_s * 1000), } - def _finalize_finished(self, items: List[T2SFinishedItem]) -> None: + def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None: if not items: return - jobs_to_finalize: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + tasks: List[SchedulerFinalizeTask] = [] + enqueued_time = time.perf_counter() with self.condition: for item in items: job = self.job_map.get(item.request_id) if job is not None: - jobs_to_finalize.append((job, item)) + tasks.append( + SchedulerFinalizeTask( + request_id=item.request_id, + item=item, + enqueued_time=enqueued_time, + ) + ) + if not tasks: + return + with self.finalize_condition: + self.finalize_pending_tasks.extend(tasks) + if len(self.finalize_pending_tasks) > self.finalize_pending_peak: + self.finalize_pending_peak = len(self.finalize_pending_tasks) + self.finalize_condition.notify_all() - for job, item in jobs_to_finalize: + @staticmethod + def _finalize_batch_key(job: SchedulerPendingJob) -> tuple[float, int]: + return (round(float(job.speed_factor), 6), int(job.sample_steps)) + + def _take_finalize_task_batch(self) -> List[SchedulerFinalizeTask]: + with self.finalize_condition: + while not self.finalize_pending_tasks: + self.finalize_condition.wait() + if self.finalize_mode == "after_t2s_drain": + while not self._is_t2s_drained(): + self.finalize_condition.wait(timeout=self.micro_batch_wait_s) + task = self.finalize_pending_tasks.popleft() + selected_tasks = [task] + batch_key = None + with self.condition: + first_job = self.job_map.get(task.request_id) + if first_job is not None: + batch_key = self._finalize_batch_key(first_job) + batch_deadline = time.perf_counter() + self.finalize_batch_wait_s + while len(selected_tasks) < self.finalize_batch_max_items: + if batch_key is None: + break + matched_index = None + for pending_index, pending_task in enumerate(self.finalize_pending_tasks): + with self.condition: + pending_job = self.job_map.get(pending_task.request_id) + if pending_job is None: + matched_index = pending_index + break + if self._finalize_batch_key(pending_job) == batch_key: + matched_index = pending_index + break + if matched_index is not None: + selected_tasks.append(self.finalize_pending_tasks[matched_index]) + del self.finalize_pending_tasks[matched_index] + continue + remaining = batch_deadline - time.perf_counter() + if remaining <= 0: + break + self.finalize_condition.wait(timeout=remaining) + self.finalize_inflight += len(selected_tasks) + if self.finalize_inflight > self.finalize_inflight_peak: + self.finalize_inflight_peak = self.finalize_inflight + return selected_tasks + + def _finalize_task_done(self, count: int) -> None: + with self.finalize_condition: + self.finalize_inflight = max(0, self.finalize_inflight - count) + + def _is_t2s_drained(self) -> bool: + with self.condition: + return ( + self.active_batch is None + and not self.pending_jobs + and self.prepare_inflight <= 0 + ) + + def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None: + finished_at = time.perf_counter() + with self.condition: + if self.job_map.get(item.request_id) is not job: + return + queue_wait_ms = 0.0 + if job.first_schedule_time is not None: + queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) + worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) + worker_residual_ms = max( + 0.0, + worker_total_ms + - queue_wait_ms + - job.prefill_ms + - job.merge_ms + - job.decode_ms + - job.finalize_wait_ms + - job.synth_ms, + ) + worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms) + job.sample_rate = int(sample_rate) + job.audio_data = audio_data + job.result_ready_time = finished_at + prepare_profile = dict(job.state.prepare_profile) + job.result = { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + "prepare_ms": job.prepare_wall_ms, + "prepare_wall_ms": job.prepare_wall_ms, + "prepare_profile_total_ms": job.prepare_profile_total_ms, + "prepare_profile": prepare_profile, + "queue_wait_ms": queue_wait_ms, + "prefill_ms": job.prefill_ms, + "merge_ms": job.merge_ms, + "decode_ms": job.decode_ms, + "finalize_wait_ms": job.finalize_wait_ms, + "synth_ms": job.synth_ms, + "worker_residual_ms": worker_residual_ms, + "worker_other_ms": worker_other_ms, + "worker_total_ms": worker_total_ms, + "decode_steps": int(job.decode_steps), + "sample_rate": int(sample_rate), + "media_type": job.media_type, + } + job.done_event.set() + self._notify_done_future(job) + self.job_map.pop(item.request_id, None) + self.total_finished += 1 + + def _synthesize_finished_audio_batch( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> List[tuple[int, np.ndarray]]: + semantic_tokens_list = [item.semantic_tokens.detach().clone() for _, item in jobs_and_items] + phones_list = [job.state.phones.detach().clone() for job, _ in jobs_and_items] + refer_specs = [] + speeds = [] + sample_steps_list = [] + for job, _ in jobs_and_items: + refer_specs.append( + ( + job.state.refer_spec[0].detach().clone(), + None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), + ) + ) + speeds.append(float(job.speed_factor)) + sample_steps_list.append(int(job.sample_steps)) + audio_fragments = self.tts.synthesize_audio_requests_local_batched( + semantic_tokens_list=semantic_tokens_list, + phones_list=phones_list, + refer_specs=refer_specs, + speeds=speeds, + sample_steps_list=sample_steps_list, + ) + output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] + results: List[tuple[int, np.ndarray]] = [] + for (job, _), audio_fragment in zip(jobs_and_items, audio_fragments): + results.append( + self.tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=float(job.speed_factor), + split_bucket=False, + fragment_interval=0.0, + super_sampling=False, + ) + ) + return results + + def _run_finalize_loop(self) -> None: + while True: + tasks = self._take_finalize_task_batch() try: + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + finalize_wait_request_ids: List[str] = [] + with self.condition: + for task in tasks: + job = self.job_map.get(task.request_id) + if job is None: + continue + jobs_and_items.append((job, task.item)) + finalize_wait_request_ids.append(task.request_id) + if not jobs_and_items: + continue + now = time.perf_counter() + for task in tasks: + self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0)) self._sync_device() synth_start = time.perf_counter() - sample_rate, audio_data = self._synthesize_finished_audio(job, item) + if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder: + job, item = jobs_and_items[0] + batch_results = [self._synthesize_finished_audio(job, item)] + else: + batch_results = self._synthesize_finished_audio_batch(jobs_and_items) self._sync_device() synth_ms = (time.perf_counter() - synth_start) * 1000.0 + with self.condition: + for job, _ in jobs_and_items: + tracked_job = self.job_map.get(job.request_id) + if tracked_job is not None: + tracked_job.synth_ms += synth_ms + for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): + self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data) except Exception as exc: - self._finalize_error([item.request_id], str(exc)) - continue - - finished_at = time.perf_counter() - with self.condition: - if self.job_map.get(item.request_id) is not job: - continue - queue_wait_ms = 0.0 - if job.first_schedule_time is not None: - queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) - worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) - job.synth_ms += synth_ms - job.sample_rate = int(sample_rate) - job.audio_data = audio_data - prepare_profile = dict(job.state.prepare_profile) - job.result = { - "request_id": item.request_id, - "semantic_len": int(item.semantic_tokens.shape[0]), - "finish_idx": int(item.finish_idx), - "finish_reason": item.finish_reason, - "prepare_ms": job.prepare_ms, - "prepare_wall_ms": job.prepare_wall_ms, - "prepare_profile": prepare_profile, - "queue_wait_ms": queue_wait_ms, - "prefill_ms": job.prefill_ms, - "decode_ms": job.decode_ms, - "synth_ms": job.synth_ms, - "worker_total_ms": worker_total_ms, - "decode_steps": int(job.decode_steps), - "sample_rate": int(sample_rate), - "media_type": job.media_type, - } - job.done_event.set() - self.job_map.pop(item.request_id, None) - self.total_finished += 1 + self._finalize_error([task.request_id for task in tasks], str(exc)) + finally: + self._finalize_task_done(len(tasks)) def _finalize_error(self, request_ids: List[str], error: str) -> None: if not request_ids: @@ -469,12 +711,28 @@ class SchedulerDebugWorker: continue job.error = error job.done_event.set() + self._notify_done_future(job) self.job_map.pop(request_id, None) self.total_finished += 1 + @staticmethod + def _resolve_done_future(job: SchedulerPendingJob) -> None: + future = job.done_future + if future is None or future.done(): + return + future.set_result(True) + + def _notify_done_future(self, job: SchedulerPendingJob) -> None: + if job.done_loop is None or job.done_future is None: + return + try: + job.done_loop.call_soon_threadsafe(self._resolve_done_future, job) + except RuntimeError: + pass + def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: with self.condition: - if not self.pending_jobs and not self.running_requests: + if not self.pending_jobs and self.active_batch is None: self.condition.wait(timeout=self.micro_batch_wait_s) elif wait_for_batch and self.pending_jobs: self.condition.wait(timeout=self.micro_batch_wait_s) @@ -482,11 +740,13 @@ class SchedulerDebugWorker: return [] pending = list(self.pending_jobs) self.pending_jobs.clear() + with self.finalize_condition: + self.finalize_condition.notify_all() return pending def _run_loop(self) -> None: while True: - wait_for_batch = len(self.running_requests) == 0 + wait_for_batch = self.active_batch is None pending_jobs = self._take_pending_snapshot(wait_for_batch=wait_for_batch) if pending_jobs: @@ -494,37 +754,54 @@ class SchedulerDebugWorker: self._sync_device() prefill_start = time.perf_counter() self._mark_prefill_started(pending_jobs, prefill_start) - admitted_running, admitted_finished = run_prefill_step( + admitted_active_batch, admitted_finished = run_prefill_active_batch( self.tts.t2s_model.model, [job.state for job in pending_jobs], max_steps=self.max_steps, ) self._sync_device() self._add_prefill_time(pending_jobs, time.perf_counter() - prefill_start) - self._finalize_finished(admitted_finished) - self.running_requests.extend(admitted_running) + self._enqueue_finalize_finished(admitted_finished) + merge_start = time.perf_counter() + self.active_batch = merge_active_batches( + self.tts.t2s_model.model, + self.active_batch, + admitted_active_batch, + ) + self._add_merge_time( + [] if self.active_batch is None else list(self.active_batch.request_ids), + time.perf_counter() - merge_start, + ) + with self.finalize_condition: + self.finalize_condition.notify_all() except Exception as exc: self._finalize_error([job.request_id for job in pending_jobs], str(exc)) - if self.running_requests: + if self.active_batch is not None: try: - active_request_ids = [item.state.request_id for item in self.running_requests] + active_request_ids = [state.request_id for state in self.active_batch.states] self._sync_device() decode_start = time.perf_counter() - self.running_requests, step_finished = run_decode_step_for_running( + self.active_batch, step_finished = decode_one_step( self.tts.t2s_model.model, - self.running_requests, + self.active_batch, max_steps=self.max_steps, ) self._sync_device() self._add_decode_time(active_request_ids, time.perf_counter() - decode_start) - self._finalize_finished(step_finished) + self._enqueue_finalize_finished(step_finished) + with self.finalize_condition: + self.finalize_condition.notify_all() except Exception as exc: self._finalize_error(active_request_ids, str(exc)) - self.running_requests = [] + self.active_batch = None + with self.finalize_condition: + self.finalize_condition.notify_all() continue if not pending_jobs: + with self.finalize_condition: + self.finalize_condition.notify_all() time.sleep(self.micro_batch_wait_s) @@ -788,10 +1065,6 @@ def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: ] -def prepare_scheduler_states_batch(specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: - return [scheduler_debug_worker.prepare_state(spec) for spec in specs] - - def build_scheduler_submit_spec(request: Scheduler_Submit_Request) -> SchedulerRequestSpec: payload = request.dict() request_id = payload["request_id"] or f"job_{uuid.uuid4().hex[:12]}" @@ -845,7 +1118,7 @@ async def tts_scheduler_debug_handle(request: Scheduler_Debug_Request): try: set_scheduler_seed(request.seed) specs = build_scheduler_request_specs(request.requests) - states = await asyncio.to_thread(prepare_scheduler_states_batch, specs) + states = await scheduler_debug_worker.prepare_states_batch_async(specs) finished = run_scheduler_continuous(tts_pipeline.t2s_model.model, states, max_steps=int(request.max_steps)) return JSONResponse( status_code=200, @@ -867,20 +1140,51 @@ async def tts_scheduler_debug_handle(request: Scheduler_Debug_Request): async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): try: request_start = time.perf_counter() + prepare_start = request_start spec = build_scheduler_submit_spec(request) - prepare_start = time.perf_counter() - state = await asyncio.to_thread(scheduler_debug_worker.prepare_state, spec) - prepare_wall_ms = (time.perf_counter() - prepare_start) * 1000.0 - prepare_ms = float(state.prepare_profile.get("total_ms", prepare_wall_ms)) + spec_ready_at = time.perf_counter() + prepare_spec_build_ms = max(0.0, (spec_ready_at - prepare_start) * 1000.0) + state, prepare_exec_started_at, prepare_exec_finished_at = await scheduler_debug_worker.prepare_state_profiled_async( + spec, + spec_ready_at, + ) + prepare_end = time.perf_counter() + prepare_wall_ms = (prepare_end - prepare_start) * 1000.0 + prepare_profile_total_ms = float(state.prepare_profile.get("total_ms", prepare_wall_ms)) + prepare_profile_wall_ms = float(state.prepare_profile.get("wall_total_ms", prepare_profile_total_ms)) + prepare_executor_queue_ms = float( + state.prepare_profile.get("executor_queue_ms", max(0.0, (prepare_exec_started_at - spec_ready_at) * 1000.0)) + ) + prepare_executor_run_ms = float( + state.prepare_profile.get( + "executor_run_wall_ms", + max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0), + ) + ) + prepare_other_ms = max( + 0.0, + prepare_wall_ms - prepare_spec_build_ms - prepare_executor_queue_ms - prepare_profile_wall_ms, + ) + loop = asyncio.get_running_loop() + done_future = loop.create_future() job = scheduler_debug_worker.submit( state, speed_factor=float(request.speed_factor), sample_steps=int(request.sample_steps), media_type=request.media_type, - prepare_ms=prepare_ms, prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=loop, + done_future=done_future, ) - timeout_ok = await asyncio.to_thread(job.done_event.wait, float(request.timeout_sec)) + api_after_prepare_ms = max(0.0, (job.enqueue_time - prepare_end) * 1000.0) + timeout_ok = False + try: + await asyncio.wait_for(asyncio.shield(done_future), timeout=float(request.timeout_sec)) + timeout_ok = True + except asyncio.TimeoutError: + timeout_ok = False + wait_return_at = time.perf_counter() if not timeout_ok: return JSONResponse( status_code=202, @@ -888,8 +1192,10 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): "message": "queued", "request_id": job.request_id, "timings": { - "prepare_ms": prepare_ms, + "prepare_ms": prepare_wall_ms, "prepare_wall_ms": prepare_wall_ms, + "prepare_profile_total_ms": prepare_profile_total_ms, + "api_after_prepare_ms": api_after_prepare_ms, "request_elapsed_ms": max(0.0, (time.perf_counter() - request_start) * 1000.0), }, "worker_state": scheduler_debug_worker.get_state(), @@ -911,9 +1217,13 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): ) pack_start = time.perf_counter() audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue() - pack_ms = (time.perf_counter() - pack_start) * 1000.0 + pack_end = time.perf_counter() + pack_ms = (pack_end - pack_start) * 1000.0 job.pack_ms = pack_ms - request_total_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0) + api_wait_result_ms = 0.0 + if job.result_ready_time is not None: + api_wait_result_ms = max(0.0, (wait_return_at - job.result_ready_time) * 1000.0) + worker_total_ms = float(job.result["worker_total_ms"]) if job.result is not None else 0.0 headers = { "X-Request-Id": job.request_id, "X-Semantic-Len": str(job.result["semantic_len"]) if job.result is not None else "0", @@ -921,16 +1231,32 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): "X-Queue-Wait-Ms": ( f"{float(job.result['queue_wait_ms']):.3f}" if job.result is not None else "0.000" ), - "X-Prepare-Ms": f"{prepare_ms:.3f}", + "X-Prepare-Ms": f"{prepare_wall_ms:.3f}", "X-Prepare-Wall-Ms": f"{prepare_wall_ms:.3f}", + "X-Prepare-Spec-Build-Ms": f"{prepare_spec_build_ms:.3f}", + "X-Prepare-Executor-Queue-Ms": f"{prepare_executor_queue_ms:.3f}", + "X-Prepare-Admission-Wait-Ms": ( + f"{float(job.result['prepare_profile'].get('prepare_admission_wait_ms', 0.0)):.3f}" + if job.result is not None + else "0.000" + ), + "X-Prepare-Executor-Run-Ms": f"{prepare_executor_run_ms:.3f}", + "X-Prepare-Profile-Total-Ms": f"{prepare_profile_total_ms:.3f}", + "X-Prepare-Profile-Wall-Ms": f"{prepare_profile_wall_ms:.3f}", + "X-Prepare-Other-Ms": f"{prepare_other_ms:.3f}", + "X-Api-After-Prepare-Ms": f"{api_after_prepare_ms:.3f}", "X-Prefill-Ms": f"{float(job.result['prefill_ms']):.3f}" if job.result is not None else "0.000", + "X-Merge-Ms": f"{float(job.result['merge_ms']):.3f}" if job.result is not None else "0.000", "X-Decode-Ms": f"{float(job.result['decode_ms']):.3f}" if job.result is not None else "0.000", + "X-Finalize-Wait-Ms": f"{float(job.result['finalize_wait_ms']):.3f}" if job.result is not None else "0.000", "X-Synth-Ms": f"{float(job.result['synth_ms']):.3f}" if job.result is not None else "0.000", + "X-Worker-Residual-Ms": f"{float(job.result['worker_residual_ms']):.3f}" if job.result is not None else "0.000", + "X-Worker-Other-Ms": f"{float(job.result['worker_other_ms']):.3f}" if job.result is not None else "0.000", "X-Pack-Ms": f"{pack_ms:.3f}", "X-Worker-Total-Ms": ( f"{float(job.result['worker_total_ms']):.3f}" if job.result is not None else "0.000" ), - "X-Request-Total-Ms": f"{request_total_ms:.3f}", + "X-Api-Wait-Result-Ms": f"{api_wait_result_ms:.3f}", "X-Decode-Steps": str(job.result["decode_steps"]) if job.result is not None else "0", } if job.result is not None: @@ -939,16 +1265,48 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): { "X-Prepare-Prompt-Text-Ms": f"{float(prepare_profile.get('prompt_text_features_ms', 0.0)):.3f}", "X-Prepare-Target-Text-Ms": f"{float(prepare_profile.get('text_features_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Text-CPU-Preprocess-Ms": f"{float(prepare_profile.get('prompt_text_cpu_preprocess_ms', 0.0)):.3f}", + "X-Prepare-Target-Text-CPU-Preprocess-Ms": f"{float(prepare_profile.get('text_cpu_preprocess_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Text-CPU-Queue-Ms": f"{float(prepare_profile.get('prompt_text_cpu_queue_ms', 0.0)):.3f}", + "X-Prepare-Target-Text-CPU-Queue-Ms": f"{float(prepare_profile.get('text_cpu_queue_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Text-Feature-Queue-Ms": f"{float(prepare_profile.get('prompt_text_feature_queue_ms', 0.0)):.3f}", + "X-Prepare-Target-Text-Feature-Queue-Ms": f"{float(prepare_profile.get('text_feature_queue_ms', 0.0)):.3f}", "X-Prepare-Prompt-Bert-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_wait_ms', 0.0)):.3f}", "X-Prepare-Target-Bert-Wait-Ms": f"{float(prepare_profile.get('text_bert_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Admission-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_admission_wait_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Admission-Wait-Ms": f"{float(prepare_profile.get('text_bert_admission_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Queue-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_queue_wait_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Queue-Wait-Ms": f"{float(prepare_profile.get('text_bert_queue_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Batch-Collect-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_batch_collect_wait_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Batch-Collect-Wait-Ms": f"{float(prepare_profile.get('text_bert_batch_collect_wait_ms', 0.0)):.3f}", "X-Prepare-Prompt-Bert-Forward-Ms": f"{float(prepare_profile.get('prompt_text_bert_forward_ms', 0.0)):.3f}", "X-Prepare-Target-Bert-Forward-Ms": f"{float(prepare_profile.get('text_bert_forward_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Pending-On-Enqueue-Peak": str( + int(prepare_profile.get("prompt_text_bert_pending_depth_on_enqueue_peak", 0.0)) + ), + "X-Prepare-Target-Bert-Pending-On-Enqueue-Peak": str( + int(prepare_profile.get("text_bert_pending_depth_on_enqueue_peak", 0.0)) + ), + "X-Prepare-Prompt-Bert-Pending-On-Collect-Peak": str( + int(prepare_profile.get("prompt_text_bert_pending_depth_on_collect_peak", 0.0)) + ), + "X-Prepare-Target-Bert-Pending-On-Collect-Peak": str( + int(prepare_profile.get("text_bert_pending_depth_on_collect_peak", 0.0)) + ), + "X-Prepare-Prompt-Bert-High-Pressure-Peak": str( + int(prepare_profile.get("prompt_text_bert_high_pressure_mode_peak", 0.0)) + ), + "X-Prepare-Target-Bert-High-Pressure-Peak": str( + int(prepare_profile.get("text_bert_high_pressure_mode_peak", 0.0)) + ), "X-Prepare-Prompt-Bert-Batch-Size-Peak": str( int(prepare_profile.get("prompt_text_bert_batch_size_peak", 0.0)) ), "X-Prepare-Target-Bert-Batch-Size-Peak": str( int(prepare_profile.get("text_bert_batch_size_peak", 0.0)) ), + "X-Prepare-Prompt-Bert-Batch-Window-Ms": f"{float(prepare_profile.get('prompt_text_bert_batch_window_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Batch-Window-Ms": f"{float(prepare_profile.get('text_bert_batch_window_ms', 0.0)):.3f}", "X-Prepare-Text-Pair-Wall-Ms": f"{float(prepare_profile.get('text_feature_pair_ms', 0.0)):.3f}", "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))), "X-Prepare-Audio-Load-Ms": f"{float(prepare_profile.get('audio_load_ms', 0.0)):.3f}", @@ -964,13 +1322,22 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): "X-Prepare-Ref-Spec-Wait-Ms": f"{float(prepare_profile.get('ref_spec_wait_ms', 0.0)):.3f}", "X-Prepare-Ref-Bundle-Ms": f"{float(prepare_profile.get('ref_audio_bundle_ms', 0.0)):.3f}", "X-Prepare-Tensorize-Ms": f"{float(prepare_profile.get('tensorize_ms', 0.0)):.3f}", - "X-Prepare-Profile-Wall-Ms": f"{float(prepare_profile.get('wall_total_ms', 0.0)):.3f}", "X-Prepare-Inflight-On-Enter": str( int(prepare_profile.get("worker_prepare_inflight_on_enter", 0.0)) ), "X-Prepare-Inflight-Peak": str(int(prepare_profile.get("worker_prepare_peak_inflight", 0.0))), } ) + response_ready_at = time.perf_counter() + response_overhead_ms = max(0.0, (response_ready_at - pack_end) * 1000.0) + request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) + request_other_ms = max( + 0.0, + request_total_ms - prepare_wall_ms - api_after_prepare_ms - worker_total_ms - api_wait_result_ms - pack_ms, + ) + headers["X-Response-Overhead-Ms"] = f"{response_overhead_ms:.3f}" + headers["X-Request-Other-Ms"] = f"{request_other_ms:.3f}" + headers["X-Request-Total-Ms"] = f"{request_total_ms:.3f}" return Response(audio_data, media_type=f"audio/{job.media_type}", headers=headers) except Exception as e: return JSONResponse( From 69ac7f90271bc3099fe608cabbf9979077a68a71 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Tue, 10 Mar 2026 06:59:28 +0800 Subject: [PATCH 9/9] Integrate UnifiedTTSEngine into TTS API for improved audio processing and control. Refactor tts_handle and control endpoints to utilize the new engine, enhancing error handling and response management. Update set_refer_audio and set_gpt_weights endpoints to return payloads from the engine, streamlining audio configuration processes. --- GPT_SoVITS/TTS_infer_pack/TTS.py | 36 +- GPT_SoVITS/TTS_infer_pack/unified_engine.py | 1255 +++++++++++++++++++ api_v2.py | 99 +- api_v3.py | 323 +---- 4 files changed, 1332 insertions(+), 381 deletions(-) create mode 100644 GPT_SoVITS/TTS_infer_pack/unified_engine.py diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index d475b804..c7ae465c 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -468,7 +468,26 @@ class TTS: ) self._init_models() + self.refresh_runtime_components() + self.prompt_cache: dict = { + "ref_audio_path": None, + "prompt_semantic": None, + "refer_spec": [], + "prompt_text": None, + "prompt_lang": None, + "phones": None, + "bert_features": None, + "norm_text": None, + "aux_ref_audio_paths": [], + } + + self.stop_flag: bool = False + self.precision: torch.dtype = torch.float16 if self.configs.is_half else torch.float32 + + def refresh_runtime_components(self): + self.prepare_bert_batch_worker = None + self.prepare_ref_semantic_batch_worker = None if os.environ.get("GPTSOVITS_PREPARE_BERT_BATCHING", "1") != "0": self.prepare_bert_batch_worker = PrepareBertBatchWorker( bert_model=self.bert_model, @@ -509,7 +528,7 @@ class TTS: max_batch_samples=int(ref_max_batch_samples), ) - self.text_preprocessor: TextPreprocessor = TextPreprocessor( + self.text_preprocessor = TextPreprocessor( self.bert_model, self.bert_tokenizer, self.configs.device, @@ -517,21 +536,6 @@ class TTS: bert_batch_worker=self.prepare_bert_batch_worker, ) - self.prompt_cache: dict = { - "ref_audio_path": None, - "prompt_semantic": None, - "refer_spec": [], - "prompt_text": None, - "prompt_lang": None, - "phones": None, - "bert_features": None, - "norm_text": None, - "aux_ref_audio_paths": [], - } - - self.stop_flag: bool = False - self.precision: torch.dtype = torch.float16 if self.configs.is_half else torch.float32 - def _init_models( self, ): diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine.py b/GPT_SoVITS/TTS_infer_pack/unified_engine.py new file mode 100644 index 00000000..0a95015f --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine.py @@ -0,0 +1,1255 @@ +from __future__ import annotations + +import asyncio +import os +import signal +import subprocess +import sys +import threading +import time +import uuid +import wave +from collections import deque +from dataclasses import dataclass, field +from io import BytesIO +from pathlib import Path +from typing import Any, Callable, Deque, Dict, Generator, List, Optional, Sequence, Tuple, Union + +import numpy as np +import soundfile as sf +import torch + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( + SchedulerRequestSpec, + T2SActiveBatch, + T2SFinishedItem, + T2SRequestState, + decode_one_step, + merge_active_batches, + run_prefill_active_batch, + run_scheduler_continuous, +) + + +@dataclass +class RuntimeControlCallbacks: + restart: Callable[[], None] | None = None + exit: Callable[[], None] | None = None + + +@dataclass +class DefaultReferenceState: + ref_audio_path: str | None = None + updated_at: float = 0.0 + + +class ReferenceRegistry: + def __init__(self) -> None: + self._lock = threading.Lock() + self._state = DefaultReferenceState() + + def set_default(self, ref_audio_path: str) -> DefaultReferenceState: + with self._lock: + self._state = DefaultReferenceState(ref_audio_path=str(ref_audio_path), updated_at=time.time()) + return self._state + + def clear(self) -> DefaultReferenceState: + with self._lock: + self._state = DefaultReferenceState() + return self._state + + def get_default(self) -> DefaultReferenceState: + with self._lock: + return DefaultReferenceState( + ref_audio_path=self._state.ref_audio_path, + updated_at=self._state.updated_at, + ) + + +@dataclass +class ModelRegistryState: + t2s_weights_path: str + vits_weights_path: str + generation: int = 0 + t2s_generation: int = 0 + vits_generation: int = 0 + updated_at: float = field(default_factory=time.time) + + +class ModelRegistry: + def __init__(self, t2s_weights_path: str, vits_weights_path: str) -> None: + self._lock = threading.Lock() + self._state = ModelRegistryState( + t2s_weights_path=str(t2s_weights_path), + vits_weights_path=str(vits_weights_path), + ) + + def snapshot(self) -> ModelRegistryState: + with self._lock: + return ModelRegistryState( + t2s_weights_path=self._state.t2s_weights_path, + vits_weights_path=self._state.vits_weights_path, + generation=self._state.generation, + t2s_generation=self._state.t2s_generation, + vits_generation=self._state.vits_generation, + updated_at=self._state.updated_at, + ) + + def mark_t2s_reload(self, weights_path: str) -> ModelRegistryState: + with self._lock: + self._state.t2s_weights_path = str(weights_path) + self._state.generation += 1 + self._state.t2s_generation += 1 + self._state.updated_at = time.time() + return ModelRegistryState( + t2s_weights_path=self._state.t2s_weights_path, + vits_weights_path=self._state.vits_weights_path, + generation=self._state.generation, + t2s_generation=self._state.t2s_generation, + vits_generation=self._state.vits_generation, + updated_at=self._state.updated_at, + ) + + def mark_vits_reload(self, weights_path: str) -> ModelRegistryState: + with self._lock: + self._state.vits_weights_path = str(weights_path) + self._state.generation += 1 + self._state.vits_generation += 1 + self._state.updated_at = time.time() + return ModelRegistryState( + t2s_weights_path=self._state.t2s_weights_path, + vits_weights_path=self._state.vits_weights_path, + generation=self._state.generation, + t2s_generation=self._state.t2s_generation, + vits_generation=self._state.vits_generation, + updated_at=self._state.updated_at, + ) + + +@dataclass +class DirectTTSExecution: + media_type: str + streaming: bool + audio_generator: Optional[Generator[bytes, None, None]] = None + audio_bytes: Optional[bytes] = None + + +@dataclass +class SchedulerDebugExecution: + payload: Dict[str, Any] + + +@dataclass +class SchedulerSubmitExecution: + audio_bytes: bytes + media_type: str + headers: Dict[str, str] + + +@dataclass +class SchedulerPendingJob: + request_id: str + state: T2SRequestState + done_event: threading.Event + done_loop: asyncio.AbstractEventLoop | None + done_future: asyncio.Future | None + enqueue_time: float + speed_factor: float + sample_steps: int + media_type: str + prepare_wall_ms: float = 0.0 + prepare_profile_total_ms: float = 0.0 + first_schedule_time: float | None = None + prefill_ms: float = 0.0 + merge_ms: float = 0.0 + decode_ms: float = 0.0 + finalize_wait_ms: float = 0.0 + synth_ms: float = 0.0 + pack_ms: float = 0.0 + decode_steps: int = 0 + result_ready_time: float | None = None + result: dict | None = None + sample_rate: int | None = None + audio_data: np.ndarray | None = None + error: str | None = None + + +@dataclass +class SchedulerFinalizeTask: + request_id: str + item: T2SFinishedItem + enqueued_time: float + + +class UnifiedSchedulerWorker: + def __init__(self, tts: TTS, max_steps: int = 1500, micro_batch_wait_ms: int = 5): + self.tts = tts + self.max_steps = int(max_steps) + self.micro_batch_wait_s = float(micro_batch_wait_ms) / 1000.0 + self.prepare_coordinator = PrepareCoordinator(tts) + self.condition = threading.Condition() + self.prepare_inflight = 0 + self.prepare_peak_inflight = 0 + self.finalize_condition = threading.Condition() + self.finalize_pending_tasks: Deque[SchedulerFinalizeTask] = deque() + self.finalize_pending_peak = 0 + self.finalize_inflight = 0 + self.finalize_inflight_peak = 0 + self.finalize_workers = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1))) + self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower() + self.finalize_batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16))) + self.finalize_batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0) + self.pending_jobs: List[SchedulerPendingJob] = [] + self.active_batch: T2SActiveBatch | None = None + self.job_map: Dict[str, SchedulerPendingJob] = {} + self.total_finished = 0 + self.total_submitted = 0 + self.worker_thread = threading.Thread(target=self._run_loop, name="unified-t2s-scheduler-worker", daemon=True) + self.worker_thread.start() + self.finalize_threads = [ + threading.Thread( + target=self._run_finalize_loop, + name=f"unified-t2s-finalize-{worker_index}", + daemon=True, + ) + for worker_index in range(self.finalize_workers) + ] + for finalize_thread in self.finalize_threads: + finalize_thread.start() + + def snapshot(self) -> dict: + with self.condition: + finalize_pending = len(self.finalize_pending_tasks) + prepare_state = self.prepare_coordinator.snapshot() + return { + "pending_jobs": len(self.pending_jobs), + "running_requests": 0 if self.active_batch is None else len(self.active_batch.request_ids), + "prepare_inflight": prepare_state["inflight"], + "prepare_peak_inflight": prepare_state["peak_inflight"], + "prepare_max_inflight": prepare_state.get("max_inflight", 0), + "finalize_pending": finalize_pending, + "finalize_pending_peak": self.finalize_pending_peak, + "finalize_inflight": self.finalize_inflight, + "finalize_inflight_peak": self.finalize_inflight_peak, + "finalize_workers": self.finalize_workers, + "finalize_mode": self.finalize_mode, + "finalize_batch_max_items": self.finalize_batch_max_items, + "finalize_batch_wait_ms": self.finalize_batch_wait_s * 1000.0, + "total_submitted": self.total_submitted, + "total_finished": self.total_finished, + "drained": self.is_drained(), + } + + def is_drained(self) -> bool: + with self.condition: + with self.finalize_condition: + return ( + self.active_batch is None + and not self.pending_jobs + and not self.job_map + and self.prepare_coordinator.snapshot()["inflight"] <= 0 + and self.finalize_inflight <= 0 + and not self.finalize_pending_tasks + ) + + def wait_until_idle(self, timeout_sec: float = 60.0, poll_interval_sec: float = 0.01) -> bool: + deadline = time.perf_counter() + max(0.0, timeout_sec) + while time.perf_counter() < deadline: + if self.is_drained(): + return True + time.sleep(poll_interval_sec) + return self.is_drained() + + def _sync_device(self) -> None: + try: + device_str = str(self.tts.configs.device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(self.tts.configs.device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + def submit( + self, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None = None, + done_future: asyncio.Future | None = None, + ) -> SchedulerPendingJob: + job = SchedulerPendingJob( + request_id=state.request_id, + state=state, + done_event=threading.Event(), + done_loop=done_loop, + done_future=done_future, + enqueue_time=time.perf_counter(), + speed_factor=float(speed_factor), + sample_steps=int(sample_steps), + media_type=media_type, + prepare_wall_ms=float(prepare_wall_ms), + prepare_profile_total_ms=float(prepare_profile_total_ms), + ) + with self.condition: + self.pending_jobs.append(job) + self.job_map[job.request_id] = job + self.total_submitted += 1 + self.condition.notify_all() + with self.finalize_condition: + self.finalize_condition.notify_all() + return job + + async def prepare_state_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> tuple[T2SRequestState, float, float]: + with self.condition: + self.prepare_inflight += 1 + self.prepare_peak_inflight = max(self.prepare_peak_inflight, self.prepare_inflight) + try: + return await self.prepare_coordinator.prepare_state_profiled_async(spec, prepare_submit_at) + finally: + with self.condition: + self.prepare_inflight = max(0, self.prepare_inflight - 1) + self.condition.notify_all() + with self.finalize_condition: + self.finalize_condition.notify_all() + + async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: + results = await asyncio.gather( + *[self.prepare_state_profiled_async(spec, time.perf_counter()) for spec in specs] + ) + return [state for state, _, _ in results] + + def _mark_prefill_started(self, pending_jobs: List[SchedulerPendingJob], started_at: float) -> None: + with self.condition: + for job in pending_jobs: + tracked_job = self.job_map.get(job.request_id) + if tracked_job is None: + continue + tracked_job.first_schedule_time = float(started_at) + + def _add_prefill_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_map.get(request_id) + if job is not None: + job.prefill_ms += delta_ms + + def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_map.get(request_id) + if job is not None: + job.merge_ms += delta_ms + + def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_map.get(request_id) + if job is not None: + job.decode_ms += delta_ms + job.decode_steps += 1 + + def _add_finalize_wait_ms(self, request_ids: List[str], delta_ms: float) -> None: + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_map.get(request_id) + if job is not None: + job.finalize_wait_ms += float(delta_ms) + + def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None: + if not items: + return + enqueued_at = time.perf_counter() + with self.finalize_condition: + for item in items: + self.finalize_pending_tasks.append( + SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at) + ) + self.finalize_pending_peak = max(self.finalize_pending_peak, len(self.finalize_pending_tasks)) + self.finalize_condition.notify_all() + + def _take_finalize_task_batch(self) -> List[SchedulerFinalizeTask]: + with self.finalize_condition: + while not self.finalize_pending_tasks: + self.finalize_condition.wait() + selected_tasks = [self.finalize_pending_tasks.popleft()] + if self.finalize_mode == "sync" or self.tts.configs.use_vocoder: + self.finalize_inflight += len(selected_tasks) + self.finalize_inflight_peak = max(self.finalize_inflight_peak, self.finalize_inflight) + return selected_tasks + batch_deadline = time.perf_counter() + self.finalize_batch_wait_s + while len(selected_tasks) < self.finalize_batch_max_items: + if not self.finalize_pending_tasks: + remaining = batch_deadline - time.perf_counter() + if remaining <= 0: + break + self.finalize_condition.wait(timeout=remaining) + continue + first_task = selected_tasks[0] + matched_index = None + for index, task in enumerate(self.finalize_pending_tasks): + if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: + matched_index = index + break + if matched_index is not None: + selected_tasks.append(self.finalize_pending_tasks[matched_index]) + del self.finalize_pending_tasks[matched_index] + continue + remaining = batch_deadline - time.perf_counter() + if remaining <= 0: + break + self.finalize_condition.wait(timeout=remaining) + self.finalize_inflight += len(selected_tasks) + self.finalize_inflight_peak = max(self.finalize_inflight_peak, self.finalize_inflight) + return selected_tasks + + def _finalize_task_done(self, count: int) -> None: + with self.finalize_condition: + self.finalize_inflight = max(0, self.finalize_inflight - count) + self.finalize_condition.notify_all() + + def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]: + audio_fragment = self.tts.synthesize_audio_request_local( + semantic_tokens=item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0), + phones=job.state.phones.detach().clone().unsqueeze(0), + prompt_semantic=job.state.prompt_semantic.detach().clone(), + prompt_phones=job.state.prompt_phones.detach().clone(), + refer_spec=( + job.state.refer_spec[0].detach().clone(), + None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), + ), + raw_audio=job.state.raw_audio.detach().clone(), + raw_sr=int(job.state.raw_sr), + speed=float(job.speed_factor), + sample_steps=int(job.sample_steps), + ) + output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] + return self.tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=float(job.speed_factor), + split_bucket=False, + fragment_interval=0.0, + super_sampling=False, + ) + + def _synthesize_finished_audio_batch( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> List[tuple[int, np.ndarray]]: + semantic_tokens_list = [item.semantic_tokens.detach().clone() for _, item in jobs_and_items] + phones_list = [job.state.phones.detach().clone() for job, _ in jobs_and_items] + refer_specs = [] + speeds = [] + sample_steps_list = [] + for job, _ in jobs_and_items: + refer_specs.append( + ( + job.state.refer_spec[0].detach().clone(), + None if job.state.refer_spec[1] is None else job.state.refer_spec[1].detach().clone(), + ) + ) + speeds.append(float(job.speed_factor)) + sample_steps_list.append(int(job.sample_steps)) + audio_fragments = self.tts.synthesize_audio_requests_local_batched( + semantic_tokens_list=semantic_tokens_list, + phones_list=phones_list, + refer_specs=refer_specs, + speeds=speeds, + sample_steps_list=sample_steps_list, + ) + output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] + results: List[tuple[int, np.ndarray]] = [] + for (job, _), audio_fragment in zip(jobs_and_items, audio_fragments): + results.append( + self.tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=float(job.speed_factor), + split_bucket=False, + fragment_interval=0.0, + super_sampling=False, + ) + ) + return results + + def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None: + finished_at = time.perf_counter() + with self.condition: + if self.job_map.get(item.request_id) is not job: + return + queue_wait_ms = 0.0 + if job.first_schedule_time is not None: + queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) + worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) + worker_residual_ms = max( + 0.0, + worker_total_ms + - queue_wait_ms + - job.prefill_ms + - job.merge_ms + - job.decode_ms + - job.finalize_wait_ms + - job.synth_ms, + ) + worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms) + job.sample_rate = int(sample_rate) + job.audio_data = audio_data + job.result_ready_time = finished_at + prepare_profile = dict(job.state.prepare_profile) + job.result = { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + "prepare_ms": job.prepare_wall_ms, + "prepare_wall_ms": job.prepare_wall_ms, + "prepare_profile_total_ms": job.prepare_profile_total_ms, + "prepare_profile": prepare_profile, + "queue_wait_ms": queue_wait_ms, + "prefill_ms": job.prefill_ms, + "merge_ms": job.merge_ms, + "decode_ms": job.decode_ms, + "finalize_wait_ms": job.finalize_wait_ms, + "synth_ms": job.synth_ms, + "worker_residual_ms": worker_residual_ms, + "worker_other_ms": worker_other_ms, + "worker_total_ms": worker_total_ms, + "decode_steps": int(job.decode_steps), + "sample_rate": int(sample_rate), + "media_type": job.media_type, + } + job.done_event.set() + self._notify_done_future(job) + self.job_map.pop(item.request_id, None) + self.total_finished += 1 + self.condition.notify_all() + + def _finalize_error(self, request_ids: List[str], error: str) -> None: + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_map.get(request_id) + if job is None: + continue + job.error = error + job.done_event.set() + self._notify_done_future(job) + self.job_map.pop(request_id, None) + self.total_finished += 1 + self.condition.notify_all() + + @staticmethod + def _resolve_done_future(job: SchedulerPendingJob) -> None: + future = job.done_future + if future is None or future.done(): + return + future.set_result(True) + + def _notify_done_future(self, job: SchedulerPendingJob) -> None: + if job.done_loop is None or job.done_future is None: + return + try: + job.done_loop.call_soon_threadsafe(self._resolve_done_future, job) + except RuntimeError: + pass + + def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs and self.active_batch is None: + self.condition.wait(timeout=self.micro_batch_wait_s) + elif wait_for_batch and self.pending_jobs: + self.condition.wait(timeout=self.micro_batch_wait_s) + if not self.pending_jobs: + return [] + pending = list(self.pending_jobs) + self.pending_jobs.clear() + return pending + + def _run_finalize_loop(self) -> None: + while True: + tasks = self._take_finalize_task_batch() + try: + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + with self.condition: + for task in tasks: + job = self.job_map.get(task.request_id) + if job is None: + continue + jobs_and_items.append((job, task.item)) + if not jobs_and_items: + continue + now = time.perf_counter() + for task in tasks: + self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0)) + self._sync_device() + synth_start = time.perf_counter() + if len(jobs_and_items) == 1 or self.tts.configs.use_vocoder: + job, item = jobs_and_items[0] + batch_results = [self._synthesize_finished_audio(job, item)] + else: + batch_results = self._synthesize_finished_audio_batch(jobs_and_items) + self._sync_device() + synth_ms = (time.perf_counter() - synth_start) * 1000.0 + with self.condition: + for job, _ in jobs_and_items: + tracked_job = self.job_map.get(job.request_id) + if tracked_job is not None: + tracked_job.synth_ms += synth_ms + for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): + self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data) + except Exception as exc: + self._finalize_error([task.request_id for task in tasks], str(exc)) + finally: + self._finalize_task_done(len(tasks)) + + def _run_loop(self) -> None: + while True: + wait_for_batch = self.active_batch is None + pending_jobs = self._take_pending_snapshot(wait_for_batch=wait_for_batch) + + if pending_jobs: + try: + self._sync_device() + prefill_start = time.perf_counter() + self._mark_prefill_started(pending_jobs, prefill_start) + admitted_active_batch, admitted_finished = run_prefill_active_batch( + self.tts.t2s_model.model, + [job.state for job in pending_jobs], + max_steps=self.max_steps, + ) + self._sync_device() + self._add_prefill_time([job.request_id for job in pending_jobs], time.perf_counter() - prefill_start) + self._enqueue_finalize_finished(admitted_finished) + merge_start = time.perf_counter() + self.active_batch = merge_active_batches( + self.tts.t2s_model.model, + self.active_batch, + admitted_active_batch, + ) + self._add_merge_time( + [] if self.active_batch is None else list(self.active_batch.request_ids), + time.perf_counter() - merge_start, + ) + except Exception as exc: + self._finalize_error([job.request_id for job in pending_jobs], str(exc)) + + if self.active_batch is not None: + active_request_ids: List[str] = [] + try: + active_request_ids = [state.request_id for state in self.active_batch.states] + self._sync_device() + decode_start = time.perf_counter() + self.active_batch, step_finished = decode_one_step( + self.tts.t2s_model.model, + self.active_batch, + max_steps=self.max_steps, + ) + self._sync_device() + self._add_decode_time(active_request_ids, time.perf_counter() - decode_start) + self._enqueue_finalize_finished(step_finished) + except Exception as exc: + self._finalize_error(active_request_ids, str(exc)) + self.active_batch = None + continue + + if not pending_jobs: + time.sleep(self.micro_batch_wait_s) + + +def set_scheduler_seed(seed: int): + if seed in ["", None]: + return + seed = int(seed) + if seed < 0: + return + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int): + def handle_pack_ogg(): + with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: + audio_file.write(data) + + stack_size = 4096 * 4096 + try: + threading.stack_size(stack_size) + pack_ogg_thread = threading.Thread(target=handle_pack_ogg) + pack_ogg_thread.start() + pack_ogg_thread.join() + except (RuntimeError, ValueError): + handle_pack_ogg() + return io_buffer + + +def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int): + io_buffer.write(data.tobytes()) + return io_buffer + + +def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int): + io_buffer = BytesIO() + sf.write(io_buffer, data, rate, format="wav") + return io_buffer + + +def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int): + process = subprocess.Popen( + [ + "ffmpeg", + "-f", + "s16le", + "-ar", + str(rate), + "-ac", + "1", + "-i", + "pipe:0", + "-c:a", + "aac", + "-b:a", + "192k", + "-vn", + "-f", + "adts", + "pipe:1", + ], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + out, _ = process.communicate(input=data.tobytes()) + io_buffer.write(out) + return io_buffer + + +def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str): + if media_type == "ogg": + io_buffer = pack_ogg(io_buffer, data, rate) + elif media_type == "aac": + io_buffer = pack_aac(io_buffer, data, rate) + elif media_type == "wav": + io_buffer = pack_wav(io_buffer, data, rate) + else: + io_buffer = pack_raw(io_buffer, data, rate) + io_buffer.seek(0) + return io_buffer + + +def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000): + wav_buf = BytesIO() + with wave.open(wav_buf, "wb") as vfout: + vfout.setnchannels(channels) + vfout.setsampwidth(sample_width) + vfout.setframerate(sample_rate) + vfout.writeframes(frame_input) + wav_buf.seek(0) + return wav_buf.read() + + +class UnifiedTTSEngine: + def __init__( + self, + tts: TTS, + cut_method_names: Sequence[str], + control_callbacks: RuntimeControlCallbacks | None = None, + max_steps: int = 1500, + micro_batch_wait_ms: int = 5, + ) -> None: + self.tts = tts + self.cut_method_names = set(cut_method_names) + self.control_callbacks = control_callbacks or RuntimeControlCallbacks() + self.reference_registry = ReferenceRegistry() + self.model_registry = ModelRegistry( + t2s_weights_path=str(self.tts.configs.t2s_weights_path), + vits_weights_path=str(self.tts.configs.vits_weights_path), + ) + self.scheduler_worker = UnifiedSchedulerWorker(tts, max_steps=max_steps, micro_batch_wait_ms=micro_batch_wait_ms) + self.direct_tts_lock = threading.RLock() + self.management_lock = threading.RLock() + + def _normalize_lang(self, value: str | None) -> str | None: + if value in [None, ""]: + return value + return str(value).lower() + + def _apply_default_reference(self, req: dict) -> dict: + normalized = dict(req) + default_ref = self.reference_registry.get_default() + if normalized.get("ref_audio_path") in [None, ""] and default_ref.ref_audio_path not in [None, ""]: + normalized["ref_audio_path"] = default_ref.ref_audio_path + if "text_lang" in normalized: + normalized["text_lang"] = self._normalize_lang(normalized.get("text_lang")) + if "prompt_lang" in normalized: + normalized["prompt_lang"] = self._normalize_lang(normalized.get("prompt_lang")) + return normalized + + def check_params(self, req: dict) -> Optional[str]: + text = req.get("text", "") + text_lang = req.get("text_lang", "") + ref_audio_path = req.get("ref_audio_path", "") + media_type = req.get("media_type", "wav") + prompt_lang = req.get("prompt_lang", "") + text_split_method = req.get("text_split_method", "cut5") + + if ref_audio_path in [None, ""]: + return "ref_audio_path is required" + if text in [None, ""]: + return "text is required" + if text_lang in [None, ""]: + return "text_lang is required" + if text_lang.lower() not in self.tts.configs.languages: + return f"text_lang: {text_lang} is not supported in version {self.tts.configs.version}" + if prompt_lang in [None, ""]: + return "prompt_lang is required" + if prompt_lang.lower() not in self.tts.configs.languages: + return f"prompt_lang: {prompt_lang} is not supported in version {self.tts.configs.version}" + if media_type not in ["wav", "raw", "ogg", "aac"]: + return f"media_type: {media_type} is not supported" + if text_split_method not in self.cut_method_names: + return f"text_split_method:{text_split_method} is not supported" + return None + + @staticmethod + def _normalize_streaming_mode(req: dict) -> dict: + normalized = dict(req) + streaming_mode = normalized.get("streaming_mode", False) + return_fragment = normalized.get("return_fragment", False) + if streaming_mode is False: + normalized["streaming_mode"] = False + normalized["return_fragment"] = False + normalized["fixed_length_chunk"] = False + elif streaming_mode == 0: + normalized["streaming_mode"] = False + normalized["return_fragment"] = False + normalized["fixed_length_chunk"] = False + elif streaming_mode == 1 or streaming_mode is True: + normalized["streaming_mode"] = False + normalized["return_fragment"] = True + normalized["fixed_length_chunk"] = False + elif streaming_mode == 2: + normalized["streaming_mode"] = True + normalized["return_fragment"] = False + normalized["fixed_length_chunk"] = False + elif streaming_mode == 3: + normalized["streaming_mode"] = True + normalized["return_fragment"] = False + normalized["fixed_length_chunk"] = True + else: + raise ValueError("the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)") + normalized["response_streaming"] = bool(normalized["streaming_mode"] or normalized["return_fragment"] or return_fragment) + return normalized + + def _iter_direct_tts_bytes(self, req: dict) -> Generator[bytes, None, None]: + media_type = req["media_type"] + with self.direct_tts_lock: + tts_generator = self.tts.run(req) + first_chunk = True + current_media_type = media_type + for sr, chunk in tts_generator: + if first_chunk and media_type == "wav": + yield wave_header_chunk(sample_rate=sr) + current_media_type = "raw" + first_chunk = False + yield pack_audio(BytesIO(), chunk, sr, current_media_type).getvalue() + + def run_direct_tts(self, req: dict) -> DirectTTSExecution: + normalized = self._normalize_streaming_mode(self._apply_default_reference(req)) + error = self.check_params(normalized) + if error is not None: + raise ValueError(error) + media_type = normalized.get("media_type", "wav") + if normalized["response_streaming"]: + return DirectTTSExecution( + media_type=media_type, + streaming=True, + audio_generator=self._iter_direct_tts_bytes(normalized), + ) + with self.direct_tts_lock: + tts_generator = self.tts.run(normalized) + sr, audio_data = next(tts_generator) + return DirectTTSExecution( + media_type=media_type, + streaming=False, + audio_bytes=pack_audio(BytesIO(), audio_data, sr, media_type).getvalue(), + ) + + def build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]: + specs: List[SchedulerRequestSpec] = [] + for index, payload in enumerate(request_items): + req = self._apply_default_reference( + { + "text": payload["text"], + "text_lang": self._normalize_lang(payload["text_lang"]), + "ref_audio_path": payload["ref_audio_path"], + "aux_ref_audio_paths": None, + "prompt_text": payload["prompt_text"], + "prompt_lang": self._normalize_lang(payload["prompt_lang"]), + "top_k": payload["top_k"], + "top_p": payload["top_p"], + "temperature": payload["temperature"], + "text_split_method": "cut5", + "batch_size": 1, + "batch_threshold": 0.75, + "speed_factor": 1.0, + "split_bucket": False, + "fragment_interval": 0.3, + "seed": -1, + "media_type": "wav", + "streaming_mode": False, + "parallel_infer": False, + "repetition_penalty": payload["repetition_penalty"], + "sample_steps": 32, + "super_sampling": False, + "overlap_length": 2, + "min_chunk_length": 16, + } + ) + error = self.check_params(req) + if error is not None: + raise ValueError(f"request[{index}] 参数非法: {error}") + specs.append( + SchedulerRequestSpec( + request_id=payload.get("request_id") or f"req_{index:03d}", + ref_audio_path=Path(req["ref_audio_path"]), + prompt_text=payload["prompt_text"], + prompt_lang=req["prompt_lang"], + text=payload["text"], + text_lang=req["text_lang"], + top_k=int(payload["top_k"]), + top_p=float(payload["top_p"]), + temperature=float(payload["temperature"]), + repetition_penalty=float(payload["repetition_penalty"]), + early_stop_num=int(payload.get("early_stop_num", -1)), + ready_step=int(payload.get("ready_step", 0)), + ) + ) + return specs + + def build_scheduler_submit_spec(self, payload: dict) -> SchedulerRequestSpec: + request_id = payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}" + req = self._apply_default_reference( + { + "text": payload["text"], + "text_lang": self._normalize_lang(payload["text_lang"]), + "ref_audio_path": payload["ref_audio_path"], + "aux_ref_audio_paths": None, + "prompt_text": payload["prompt_text"], + "prompt_lang": self._normalize_lang(payload["prompt_lang"]), + "top_k": payload["top_k"], + "top_p": payload["top_p"], + "temperature": payload["temperature"], + "text_split_method": "cut5", + "batch_size": 1, + "batch_threshold": 0.75, + "speed_factor": float(payload["speed_factor"]), + "split_bucket": False, + "fragment_interval": 0.3, + "seed": -1, + "media_type": payload["media_type"], + "streaming_mode": False, + "parallel_infer": False, + "repetition_penalty": payload["repetition_penalty"], + "sample_steps": int(payload["sample_steps"]), + "super_sampling": False, + "overlap_length": 2, + "min_chunk_length": 16, + } + ) + error = self.check_params(req) + if error is not None: + raise ValueError(f"request 参数非法: {error}") + return SchedulerRequestSpec( + request_id=request_id, + ref_audio_path=Path(req["ref_audio_path"]), + prompt_text=payload["prompt_text"], + prompt_lang=req["prompt_lang"], + text=payload["text"], + text_lang=req["text_lang"], + top_k=int(payload["top_k"]), + top_p=float(payload["top_p"]), + temperature=float(payload["temperature"]), + repetition_penalty=float(payload["repetition_penalty"]), + early_stop_num=int(payload.get("early_stop_num", -1)), + ready_step=0, + ) + + @staticmethod + def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: + return [ + { + "request_id": state.request_id, + "ready_step": int(state.ready_step), + "ref_audio_path": str(state.ref_audio_path), + "prompt_semantic_len": int(state.prompt_semantic.shape[0]), + "all_phone_len": int(state.all_phones.shape[0]), + "bert_len": int(state.all_bert_features.shape[-1]), + "norm_text": state.norm_text, + } + for state in states + ] + + @staticmethod + def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: + return [ + { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + } + for item in items + ] + + async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution: + set_scheduler_seed(seed) + specs = self.build_scheduler_request_specs(request_items) + states = await self.scheduler_worker.prepare_states_batch_async(specs) + finished = run_scheduler_continuous(self.tts.t2s_model.model, states, max_steps=int(max_steps)) + return SchedulerDebugExecution( + payload={ + "message": "success", + "request_count": len(states), + "max_steps": int(max_steps), + "requests": self.summarize_scheduler_states(states), + "finished": self.summarize_scheduler_finished(finished), + } + ) + + async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution: + request_start = time.perf_counter() + prepare_start = request_start + spec = self.build_scheduler_submit_spec(payload) + spec_ready_at = time.perf_counter() + prepare_spec_build_ms = max(0.0, (spec_ready_at - prepare_start) * 1000.0) + state, prepare_exec_started_at, prepare_exec_finished_at = await self.scheduler_worker.prepare_state_profiled_async( + spec, + spec_ready_at, + ) + prepare_wall_ms = max(0.0, (prepare_exec_finished_at - spec_ready_at) * 1000.0) + prepare_executor_queue_ms = max(0.0, (prepare_exec_started_at - spec_ready_at) * 1000.0) + prepare_executor_run_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0) + prepare_profile = dict(state.prepare_profile) + prepare_profile_total_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms)) + prepare_profile_wall_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms)) + prepare_other_ms = max(0.0, prepare_wall_ms - prepare_spec_build_ms - prepare_executor_queue_ms - prepare_executor_run_ms) + api_after_prepare_start = time.perf_counter() + loop = asyncio.get_running_loop() + done_future = loop.create_future() + job = self.scheduler_worker.submit( + state=state, + speed_factor=float(payload["speed_factor"]), + sample_steps=int(payload["sample_steps"]), + media_type=str(payload["media_type"]), + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=loop, + done_future=done_future, + ) + api_after_prepare_ms = max(0.0, (time.perf_counter() - api_after_prepare_start) * 1000.0) + await asyncio.wait_for(done_future, timeout=float(payload.get("timeout_sec", 30.0))) + wait_return_at = time.perf_counter() + if job.error is not None: + raise RuntimeError(job.error) + if job.audio_data is None or job.sample_rate is None or job.result is None: + raise RuntimeError(f"{job.request_id} finished without audio result") + pack_start = time.perf_counter() + audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue() + pack_end = time.perf_counter() + pack_ms = (pack_end - pack_start) * 1000.0 + api_wait_result_ms = 0.0 + if job.result_ready_time is not None: + api_wait_result_ms = max(0.0, (wait_return_at - job.result_ready_time) * 1000.0) + worker_total_ms = float(job.result["worker_total_ms"]) if job.result is not None else 0.0 + headers = { + "X-Request-Id": job.request_id, + "X-Semantic-Len": str(job.result["semantic_len"]) if job.result is not None else "0", + "X-Finish-Reason": job.result["finish_reason"] if job.result is not None else "unknown", + "X-Queue-Wait-Ms": f"{float(job.result['queue_wait_ms']):.3f}" if job.result is not None else "0.000", + "X-Prepare-Ms": f"{prepare_wall_ms:.3f}", + "X-Prepare-Wall-Ms": f"{prepare_wall_ms:.3f}", + "X-Prepare-Spec-Build-Ms": f"{prepare_spec_build_ms:.3f}", + "X-Prepare-Executor-Queue-Ms": f"{prepare_executor_queue_ms:.3f}", + "X-Prepare-Admission-Wait-Ms": ( + f"{float(job.result['prepare_profile'].get('prepare_admission_wait_ms', 0.0)):.3f}" + if job.result is not None + else "0.000" + ), + "X-Prepare-Executor-Run-Ms": f"{prepare_executor_run_ms:.3f}", + "X-Prepare-Profile-Total-Ms": f"{prepare_profile_total_ms:.3f}", + "X-Prepare-Profile-Wall-Ms": f"{prepare_profile_wall_ms:.3f}", + "X-Prepare-Other-Ms": f"{prepare_other_ms:.3f}", + "X-Api-After-Prepare-Ms": f"{api_after_prepare_ms:.3f}", + "X-Prefill-Ms": f"{float(job.result['prefill_ms']):.3f}" if job.result is not None else "0.000", + "X-Merge-Ms": f"{float(job.result['merge_ms']):.3f}" if job.result is not None else "0.000", + "X-Decode-Ms": f"{float(job.result['decode_ms']):.3f}" if job.result is not None else "0.000", + "X-Finalize-Wait-Ms": f"{float(job.result['finalize_wait_ms']):.3f}" if job.result is not None else "0.000", + "X-Synth-Ms": f"{float(job.result['synth_ms']):.3f}" if job.result is not None else "0.000", + "X-Worker-Residual-Ms": f"{float(job.result['worker_residual_ms']):.3f}" if job.result is not None else "0.000", + "X-Worker-Other-Ms": f"{float(job.result['worker_other_ms']):.3f}" if job.result is not None else "0.000", + "X-Pack-Ms": f"{pack_ms:.3f}", + "X-Worker-Total-Ms": f"{float(job.result['worker_total_ms']):.3f}" if job.result is not None else "0.000", + "X-Api-Wait-Result-Ms": f"{api_wait_result_ms:.3f}", + "X-Decode-Steps": str(int(job.result["decode_steps"])) if job.result is not None else "0", + "X-Sample-Rate": str(int(job.sample_rate)), + } + prepare_profile = job.result.get("prepare_profile", {}) if job.result is not None else {} + if job.result is not None: + headers.update( + { + "X-Prepare-Prompt-Text-Ms": f"{float(prepare_profile.get('prompt_text_features_ms', 0.0)):.3f}", + "X-Prepare-Target-Text-Ms": f"{float(prepare_profile.get('text_features_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Text-CPU-Preprocess-Ms": f"{float(prepare_profile.get('prompt_text_cpu_preprocess_ms', 0.0)):.3f}", + "X-Prepare-Target-Text-CPU-Preprocess-Ms": f"{float(prepare_profile.get('text_cpu_preprocess_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Text-CPU-Queue-Ms": f"{float(prepare_profile.get('prompt_text_cpu_queue_ms', 0.0)):.3f}", + "X-Prepare-Target-Text-CPU-Queue-Ms": f"{float(prepare_profile.get('text_cpu_queue_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Text-Feature-Queue-Ms": f"{float(prepare_profile.get('prompt_text_feature_queue_ms', 0.0)):.3f}", + "X-Prepare-Target-Text-Feature-Queue-Ms": f"{float(prepare_profile.get('text_feature_queue_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_wait_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Wait-Ms": f"{float(prepare_profile.get('text_bert_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Admission-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_admission_wait_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Admission-Wait-Ms": f"{float(prepare_profile.get('text_bert_admission_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Queue-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_queue_wait_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Queue-Wait-Ms": f"{float(prepare_profile.get('text_bert_queue_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Batch-Collect-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_batch_collect_wait_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Batch-Collect-Wait-Ms": f"{float(prepare_profile.get('text_bert_batch_collect_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Forward-Ms": f"{float(prepare_profile.get('prompt_text_bert_forward_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Forward-Ms": f"{float(prepare_profile.get('text_bert_forward_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_enqueue_peak", 0.0))), + "X-Prepare-Target-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_enqueue_peak", 0.0))), + "X-Prepare-Prompt-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_collect_peak", 0.0))), + "X-Prepare-Target-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_collect_peak", 0.0))), + "X-Prepare-Prompt-Bert-High-Pressure-Peak": str(int(prepare_profile.get("prompt_text_bert_high_pressure_mode_peak", 0.0))), + "X-Prepare-Target-Bert-High-Pressure-Peak": str(int(prepare_profile.get("text_bert_high_pressure_mode_peak", 0.0))), + "X-Prepare-Prompt-Bert-Batch-Window-Ms": f"{float(prepare_profile.get('prompt_text_bert_batch_window_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Batch-Window-Ms": f"{float(prepare_profile.get('text_bert_batch_window_ms', 0.0)):.3f}", + "X-Prepare-Text-Pair-Wall-Ms": f"{float(prepare_profile.get('text_feature_pair_ms', 0.0)):.3f}", + "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get('text_cpu_parallel_workers', 0.0))), + "X-Prepare-Audio-Load-Ms": f"{float(prepare_profile.get('audio_load_ms', 0.0)):.3f}", + "X-Prepare-Audio-Stage-Wait-Ms": f"{float(prepare_profile.get('audio_stage_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-Ms": f"{float(prepare_profile.get('prompt_semantic_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-Wait-Ms": f"{float(prepare_profile.get('prompt_semantic_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-CPU-Ms": f"{float(prepare_profile.get('prompt_semantic_cpu_prepare_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-Forward-Ms": f"{float(prepare_profile.get('prompt_semantic_forward_ms', 0.0)):.3f}", + "X-Prepare-Ref-Spec-Ms": f"{float(prepare_profile.get('ref_spec_ms', 0.0)):.3f}", + "X-Prepare-Ref-Spec-Wait-Ms": f"{float(prepare_profile.get('ref_spec_wait_ms', 0.0)):.3f}", + "X-Prepare-Ref-Bundle-Ms": f"{float(prepare_profile.get('ref_audio_bundle_ms', 0.0)):.3f}", + "X-Prepare-Tensorize-Ms": f"{float(prepare_profile.get('tensorize_ms', 0.0)):.3f}", + "X-Prepare-Inflight-On-Enter": str(int(prepare_profile.get('worker_prepare_inflight_on_enter', 0.0))), + "X-Prepare-Inflight-Peak": str(int(prepare_profile.get('worker_prepare_peak_inflight', 0.0))), + } + ) + response_ready_at = time.perf_counter() + response_overhead_ms = max(0.0, (response_ready_at - pack_end) * 1000.0) + request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) + request_other_ms = max( + 0.0, + request_total_ms - prepare_wall_ms - api_after_prepare_ms - worker_total_ms - api_wait_result_ms - pack_ms, + ) + headers["X-Response-Overhead-Ms"] = f"{response_overhead_ms:.3f}" + headers["X-Request-Other-Ms"] = f"{request_other_ms:.3f}" + headers["X-Request-Total-Ms"] = f"{request_total_ms:.3f}" + return SchedulerSubmitExecution(audio_bytes=audio_data, media_type=f"audio/{job.media_type}", headers=headers) + + def get_scheduler_state(self) -> dict: + return self.scheduler_worker.snapshot() + + def get_runtime_state(self) -> dict: + model_state = self.model_registry.snapshot() + default_ref = self.reference_registry.get_default() + scheduler_state = self.get_scheduler_state() + return { + "message": "success", + "default_reference": { + "ref_audio_path": default_ref.ref_audio_path, + "updated_at": default_ref.updated_at, + }, + "model_registry": { + "generation": model_state.generation, + "t2s_generation": model_state.t2s_generation, + "vits_generation": model_state.vits_generation, + "t2s_weights_path": model_state.t2s_weights_path, + "vits_weights_path": model_state.vits_weights_path, + "updated_at": model_state.updated_at, + }, + "worker_state": scheduler_state, + } + + def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None: + if not self.scheduler_worker.wait_until_idle(timeout_sec=timeout_sec): + raise TimeoutError("scheduler worker did not drain before model reload") + + def set_refer_audio(self, refer_audio_path: str | None) -> dict: + if refer_audio_path in [None, ""]: + state = self.reference_registry.clear() + return {"message": "success", "default_ref_audio_path": state.ref_audio_path} + if not os.path.exists(str(refer_audio_path)): + raise FileNotFoundError(f"{refer_audio_path} not exists") + with self.management_lock: + with self.direct_tts_lock: + self.tts.set_ref_audio(str(refer_audio_path)) + state = self.reference_registry.set_default(str(refer_audio_path)) + return {"message": "success", "default_ref_audio_path": state.ref_audio_path} + + def set_gpt_weights(self, weights_path: str) -> dict: + if weights_path in ["", None]: + raise ValueError("gpt weight path is required") + with self.management_lock: + self._wait_for_safe_reload() + with self.direct_tts_lock: + self.tts.init_t2s_weights(weights_path) + self.tts.refresh_runtime_components() + state = self.model_registry.mark_t2s_reload(str(weights_path)) + return {"message": "success", "t2s_generation": state.t2s_generation, "generation": state.generation} + + def set_sovits_weights(self, weights_path: str) -> dict: + if weights_path in ["", None]: + raise ValueError("sovits weight path is required") + with self.management_lock: + self._wait_for_safe_reload() + with self.direct_tts_lock: + self.tts.init_vits_weights(weights_path) + self.tts.refresh_runtime_components() + state = self.model_registry.mark_vits_reload(str(weights_path)) + return {"message": "success", "vits_generation": state.vits_generation, "generation": state.generation} + + def handle_control(self, command: str) -> None: + if command == "restart": + if self.control_callbacks.restart is None: + os.execl(sys.executable, sys.executable, *sys.argv) + self.control_callbacks.restart() + return + if command == "exit": + if self.control_callbacks.exit is None: + os.kill(os.getpid(), signal.SIGTERM) + return + self.control_callbacks.exit() + return + raise ValueError(f"unsupported command: {command}") diff --git a/api_v2.py b/api_v2.py index 21511db3..fb17dfd6 100644 --- a/api_v2.py +++ b/api_v2.py @@ -123,6 +123,7 @@ from io import BytesIO from tools.i18n.i18n import I18nAuto from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names +from GPT_SoVITS.TTS_infer_pack.unified_engine import RuntimeControlCallbacks, UnifiedTTSEngine from pydantic import BaseModel import threading @@ -147,6 +148,14 @@ if config_path in [None, ""]: tts_config = TTS_Config(config_path) print(tts_config) tts_pipeline = TTS(tts_config) +tts_engine = UnifiedTTSEngine( + tts_pipeline, + cut_method_names=cut_method_names, + control_callbacks=RuntimeControlCallbacks( + restart=lambda: os.execl(sys.executable, sys.executable, *argv), + exit=lambda: os.kill(os.getpid(), signal.SIGTERM), + ), +) APP = FastAPI() @@ -377,70 +386,11 @@ async def tts_handle(req: dict): StreamingResponse: audio stream response. """ - streaming_mode = req.get("streaming_mode", False) - return_fragment = req.get("return_fragment", False) - media_type = req.get("media_type", "wav") - - check_res = check_params(req) - if check_res is not None: - return check_res - - if streaming_mode == 0: - streaming_mode = False - return_fragment = False - fixed_length_chunk = False - elif streaming_mode == 1: - streaming_mode = False - return_fragment = True - fixed_length_chunk = False - elif streaming_mode == 2: - streaming_mode = True - return_fragment = False - fixed_length_chunk = False - elif streaming_mode == 3: - streaming_mode = True - return_fragment = False - fixed_length_chunk = True - - else: - return JSONResponse(status_code=400, content={"message": f"the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)"}) - - req["streaming_mode"] = streaming_mode - req["return_fragment"] = return_fragment - req["fixed_length_chunk"] = fixed_length_chunk - - print(f"{streaming_mode} {return_fragment} {fixed_length_chunk}") - - streaming_mode = streaming_mode or return_fragment - - try: - tts_generator = tts_pipeline.run(req) - - if streaming_mode: - - def streaming_generator(tts_generator: Generator, media_type: str): - if_frist_chunk = True - for sr, chunk in tts_generator: - if if_frist_chunk and media_type == "wav": - yield wave_header_chunk(sample_rate=sr) - media_type = "raw" - if_frist_chunk = False - yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue() - - # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}" - return StreamingResponse( - streaming_generator( - tts_generator, - media_type, - ), - media_type=f"audio/{media_type}", - ) - - else: - sr, audio_data = next(tts_generator) - audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() - return Response(audio_data, media_type=f"audio/{media_type}") + result = tts_engine.run_direct_tts(req) + if result.streaming: + return StreamingResponse(result.audio_generator, media_type=f"audio/{result.media_type}") + return Response(result.audio_bytes, media_type=f"audio/{result.media_type}") except Exception as e: return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(e)}) @@ -449,7 +399,11 @@ async def tts_handle(req: dict): async def control(command: str = None): if command is None: return JSONResponse(status_code=400, content={"message": "command is required"}) - handle_control(command) + try: + tts_engine.handle_control(command) + return JSONResponse(status_code=200, content={"message": "success"}) + except Exception as e: + return JSONResponse(status_code=400, content={"message": "control failed", "Exception": str(e)}) @APP.get("/tts") @@ -517,10 +471,10 @@ async def tts_post_endpoint(request: TTS_Request): @APP.get("/set_refer_audio") async def set_refer_aduio(refer_audio_path: str = None): try: - tts_pipeline.set_ref_audio(refer_audio_path) + payload = tts_engine.set_refer_audio(refer_audio_path) except Exception as e: return JSONResponse(status_code=400, content={"message": "set refer audio failed", "Exception": str(e)}) - return JSONResponse(status_code=200, content={"message": "success"}) + return JSONResponse(status_code=200, content=payload) # @APP.post("/set_refer_audio") @@ -545,24 +499,19 @@ async def set_refer_aduio(refer_audio_path: str = None): @APP.get("/set_gpt_weights") async def set_gpt_weights(weights_path: str = None): try: - if weights_path in ["", None]: - return JSONResponse(status_code=400, content={"message": "gpt weight path is required"}) - tts_pipeline.init_t2s_weights(weights_path) + payload = tts_engine.set_gpt_weights(weights_path) except Exception as e: return JSONResponse(status_code=400, content={"message": "change gpt weight failed", "Exception": str(e)}) - - return JSONResponse(status_code=200, content={"message": "success"}) + return JSONResponse(status_code=200, content=payload) @APP.get("/set_sovits_weights") async def set_sovits_weights(weights_path: str = None): try: - if weights_path in ["", None]: - return JSONResponse(status_code=400, content={"message": "sovits weight path is required"}) - tts_pipeline.init_vits_weights(weights_path) + payload = tts_engine.set_sovits_weights(weights_path) except Exception as e: return JSONResponse(status_code=400, content={"message": "change sovits weight failed", "Exception": str(e)}) - return JSONResponse(status_code=200, content={"message": "success"}) + return JSONResponse(status_code=200, content=payload) if __name__ == "__main__": diff --git a/api_v3.py b/api_v3.py index 74bc7ac8..37d66977 100644 --- a/api_v3.py +++ b/api_v3.py @@ -144,6 +144,7 @@ from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( run_prefill_active_batch, run_scheduler_continuous, ) +from GPT_SoVITS.TTS_infer_pack.unified_engine import RuntimeControlCallbacks, UnifiedTTSEngine from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names from pydantic import BaseModel import threading @@ -169,6 +170,14 @@ if config_path in [None, ""]: tts_config = TTS_Config(config_path) print(tts_config) tts_pipeline = TTS(tts_config) +tts_engine = UnifiedTTSEngine( + tts_pipeline, + cut_method_names=cut_method_names, + control_callbacks=RuntimeControlCallbacks( + restart=lambda: os.execl(sys.executable, sys.executable, *argv), + exit=lambda: os.kill(os.getpid(), signal.SIGTERM), + ), +) APP = FastAPI() @@ -805,7 +814,7 @@ class SchedulerDebugWorker: time.sleep(self.micro_batch_wait_s) -scheduler_debug_worker = SchedulerDebugWorker(tts_pipeline) +scheduler_debug_worker = tts_engine.scheduler_worker def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int): @@ -1116,20 +1125,12 @@ def build_scheduler_submit_spec(request: Scheduler_Submit_Request) -> SchedulerR async def tts_scheduler_debug_handle(request: Scheduler_Debug_Request): try: - set_scheduler_seed(request.seed) - specs = build_scheduler_request_specs(request.requests) - states = await scheduler_debug_worker.prepare_states_batch_async(specs) - finished = run_scheduler_continuous(tts_pipeline.t2s_model.model, states, max_steps=int(request.max_steps)) - return JSONResponse( - status_code=200, - content={ - "message": "success", - "request_count": len(states), - "max_steps": int(request.max_steps), - "requests": summarize_scheduler_states(states), - "finished": summarize_scheduler_finished(finished), - }, + result = await tts_engine.run_scheduler_debug( + request_items=[item.dict() for item in request.requests], + max_steps=int(request.max_steps), + seed=int(request.seed), ) + return JSONResponse(status_code=200, content=result.payload) except Exception as e: return JSONResponse( status_code=400, @@ -1139,206 +1140,8 @@ async def tts_scheduler_debug_handle(request: Scheduler_Debug_Request): async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): try: - request_start = time.perf_counter() - prepare_start = request_start - spec = build_scheduler_submit_spec(request) - spec_ready_at = time.perf_counter() - prepare_spec_build_ms = max(0.0, (spec_ready_at - prepare_start) * 1000.0) - state, prepare_exec_started_at, prepare_exec_finished_at = await scheduler_debug_worker.prepare_state_profiled_async( - spec, - spec_ready_at, - ) - prepare_end = time.perf_counter() - prepare_wall_ms = (prepare_end - prepare_start) * 1000.0 - prepare_profile_total_ms = float(state.prepare_profile.get("total_ms", prepare_wall_ms)) - prepare_profile_wall_ms = float(state.prepare_profile.get("wall_total_ms", prepare_profile_total_ms)) - prepare_executor_queue_ms = float( - state.prepare_profile.get("executor_queue_ms", max(0.0, (prepare_exec_started_at - spec_ready_at) * 1000.0)) - ) - prepare_executor_run_ms = float( - state.prepare_profile.get( - "executor_run_wall_ms", - max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0), - ) - ) - prepare_other_ms = max( - 0.0, - prepare_wall_ms - prepare_spec_build_ms - prepare_executor_queue_ms - prepare_profile_wall_ms, - ) - loop = asyncio.get_running_loop() - done_future = loop.create_future() - job = scheduler_debug_worker.submit( - state, - speed_factor=float(request.speed_factor), - sample_steps=int(request.sample_steps), - media_type=request.media_type, - prepare_wall_ms=prepare_wall_ms, - prepare_profile_total_ms=prepare_profile_total_ms, - done_loop=loop, - done_future=done_future, - ) - api_after_prepare_ms = max(0.0, (job.enqueue_time - prepare_end) * 1000.0) - timeout_ok = False - try: - await asyncio.wait_for(asyncio.shield(done_future), timeout=float(request.timeout_sec)) - timeout_ok = True - except asyncio.TimeoutError: - timeout_ok = False - wait_return_at = time.perf_counter() - if not timeout_ok: - return JSONResponse( - status_code=202, - content={ - "message": "queued", - "request_id": job.request_id, - "timings": { - "prepare_ms": prepare_wall_ms, - "prepare_wall_ms": prepare_wall_ms, - "prepare_profile_total_ms": prepare_profile_total_ms, - "api_after_prepare_ms": api_after_prepare_ms, - "request_elapsed_ms": max(0.0, (time.perf_counter() - request_start) * 1000.0), - }, - "worker_state": scheduler_debug_worker.get_state(), - }, - ) - if job.error is not None: - return JSONResponse( - status_code=400, - content={"message": "scheduler submit failed", "request_id": job.request_id, "Exception": job.error}, - ) - if job.audio_data is None or job.sample_rate is None: - return JSONResponse( - status_code=500, - content={ - "message": "scheduler submit failed", - "request_id": job.request_id, - "Exception": "job finished without audio payload", - }, - ) - pack_start = time.perf_counter() - audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue() - pack_end = time.perf_counter() - pack_ms = (pack_end - pack_start) * 1000.0 - job.pack_ms = pack_ms - api_wait_result_ms = 0.0 - if job.result_ready_time is not None: - api_wait_result_ms = max(0.0, (wait_return_at - job.result_ready_time) * 1000.0) - worker_total_ms = float(job.result["worker_total_ms"]) if job.result is not None else 0.0 - headers = { - "X-Request-Id": job.request_id, - "X-Semantic-Len": str(job.result["semantic_len"]) if job.result is not None else "0", - "X-Finish-Reason": job.result["finish_reason"] if job.result is not None else "unknown", - "X-Queue-Wait-Ms": ( - f"{float(job.result['queue_wait_ms']):.3f}" if job.result is not None else "0.000" - ), - "X-Prepare-Ms": f"{prepare_wall_ms:.3f}", - "X-Prepare-Wall-Ms": f"{prepare_wall_ms:.3f}", - "X-Prepare-Spec-Build-Ms": f"{prepare_spec_build_ms:.3f}", - "X-Prepare-Executor-Queue-Ms": f"{prepare_executor_queue_ms:.3f}", - "X-Prepare-Admission-Wait-Ms": ( - f"{float(job.result['prepare_profile'].get('prepare_admission_wait_ms', 0.0)):.3f}" - if job.result is not None - else "0.000" - ), - "X-Prepare-Executor-Run-Ms": f"{prepare_executor_run_ms:.3f}", - "X-Prepare-Profile-Total-Ms": f"{prepare_profile_total_ms:.3f}", - "X-Prepare-Profile-Wall-Ms": f"{prepare_profile_wall_ms:.3f}", - "X-Prepare-Other-Ms": f"{prepare_other_ms:.3f}", - "X-Api-After-Prepare-Ms": f"{api_after_prepare_ms:.3f}", - "X-Prefill-Ms": f"{float(job.result['prefill_ms']):.3f}" if job.result is not None else "0.000", - "X-Merge-Ms": f"{float(job.result['merge_ms']):.3f}" if job.result is not None else "0.000", - "X-Decode-Ms": f"{float(job.result['decode_ms']):.3f}" if job.result is not None else "0.000", - "X-Finalize-Wait-Ms": f"{float(job.result['finalize_wait_ms']):.3f}" if job.result is not None else "0.000", - "X-Synth-Ms": f"{float(job.result['synth_ms']):.3f}" if job.result is not None else "0.000", - "X-Worker-Residual-Ms": f"{float(job.result['worker_residual_ms']):.3f}" if job.result is not None else "0.000", - "X-Worker-Other-Ms": f"{float(job.result['worker_other_ms']):.3f}" if job.result is not None else "0.000", - "X-Pack-Ms": f"{pack_ms:.3f}", - "X-Worker-Total-Ms": ( - f"{float(job.result['worker_total_ms']):.3f}" if job.result is not None else "0.000" - ), - "X-Api-Wait-Result-Ms": f"{api_wait_result_ms:.3f}", - "X-Decode-Steps": str(job.result["decode_steps"]) if job.result is not None else "0", - } - if job.result is not None: - prepare_profile = job.result.get("prepare_profile", {}) - headers.update( - { - "X-Prepare-Prompt-Text-Ms": f"{float(prepare_profile.get('prompt_text_features_ms', 0.0)):.3f}", - "X-Prepare-Target-Text-Ms": f"{float(prepare_profile.get('text_features_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Text-CPU-Preprocess-Ms": f"{float(prepare_profile.get('prompt_text_cpu_preprocess_ms', 0.0)):.3f}", - "X-Prepare-Target-Text-CPU-Preprocess-Ms": f"{float(prepare_profile.get('text_cpu_preprocess_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Text-CPU-Queue-Ms": f"{float(prepare_profile.get('prompt_text_cpu_queue_ms', 0.0)):.3f}", - "X-Prepare-Target-Text-CPU-Queue-Ms": f"{float(prepare_profile.get('text_cpu_queue_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Text-Feature-Queue-Ms": f"{float(prepare_profile.get('prompt_text_feature_queue_ms', 0.0)):.3f}", - "X-Prepare-Target-Text-Feature-Queue-Ms": f"{float(prepare_profile.get('text_feature_queue_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Bert-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_wait_ms', 0.0)):.3f}", - "X-Prepare-Target-Bert-Wait-Ms": f"{float(prepare_profile.get('text_bert_wait_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Bert-Admission-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_admission_wait_ms', 0.0)):.3f}", - "X-Prepare-Target-Bert-Admission-Wait-Ms": f"{float(prepare_profile.get('text_bert_admission_wait_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Bert-Queue-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_queue_wait_ms', 0.0)):.3f}", - "X-Prepare-Target-Bert-Queue-Wait-Ms": f"{float(prepare_profile.get('text_bert_queue_wait_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Bert-Batch-Collect-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_batch_collect_wait_ms', 0.0)):.3f}", - "X-Prepare-Target-Bert-Batch-Collect-Wait-Ms": f"{float(prepare_profile.get('text_bert_batch_collect_wait_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Bert-Forward-Ms": f"{float(prepare_profile.get('prompt_text_bert_forward_ms', 0.0)):.3f}", - "X-Prepare-Target-Bert-Forward-Ms": f"{float(prepare_profile.get('text_bert_forward_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Bert-Pending-On-Enqueue-Peak": str( - int(prepare_profile.get("prompt_text_bert_pending_depth_on_enqueue_peak", 0.0)) - ), - "X-Prepare-Target-Bert-Pending-On-Enqueue-Peak": str( - int(prepare_profile.get("text_bert_pending_depth_on_enqueue_peak", 0.0)) - ), - "X-Prepare-Prompt-Bert-Pending-On-Collect-Peak": str( - int(prepare_profile.get("prompt_text_bert_pending_depth_on_collect_peak", 0.0)) - ), - "X-Prepare-Target-Bert-Pending-On-Collect-Peak": str( - int(prepare_profile.get("text_bert_pending_depth_on_collect_peak", 0.0)) - ), - "X-Prepare-Prompt-Bert-High-Pressure-Peak": str( - int(prepare_profile.get("prompt_text_bert_high_pressure_mode_peak", 0.0)) - ), - "X-Prepare-Target-Bert-High-Pressure-Peak": str( - int(prepare_profile.get("text_bert_high_pressure_mode_peak", 0.0)) - ), - "X-Prepare-Prompt-Bert-Batch-Size-Peak": str( - int(prepare_profile.get("prompt_text_bert_batch_size_peak", 0.0)) - ), - "X-Prepare-Target-Bert-Batch-Size-Peak": str( - int(prepare_profile.get("text_bert_batch_size_peak", 0.0)) - ), - "X-Prepare-Prompt-Bert-Batch-Window-Ms": f"{float(prepare_profile.get('prompt_text_bert_batch_window_ms', 0.0)):.3f}", - "X-Prepare-Target-Bert-Batch-Window-Ms": f"{float(prepare_profile.get('text_bert_batch_window_ms', 0.0)):.3f}", - "X-Prepare-Text-Pair-Wall-Ms": f"{float(prepare_profile.get('text_feature_pair_ms', 0.0)):.3f}", - "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))), - "X-Prepare-Audio-Load-Ms": f"{float(prepare_profile.get('audio_load_ms', 0.0)):.3f}", - "X-Prepare-Audio-Stage-Wait-Ms": f"{float(prepare_profile.get('audio_stage_wait_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Semantic-Ms": f"{float(prepare_profile.get('prompt_semantic_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Semantic-Wait-Ms": f"{float(prepare_profile.get('prompt_semantic_wait_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Semantic-CPU-Ms": f"{float(prepare_profile.get('prompt_semantic_cpu_prepare_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Semantic-Forward-Ms": f"{float(prepare_profile.get('prompt_semantic_forward_ms', 0.0)):.3f}", - "X-Prepare-Prompt-Semantic-Batch-Size": str( - int(prepare_profile.get("prompt_semantic_batch_size", 0.0)) - ), - "X-Prepare-Ref-Spec-Ms": f"{float(prepare_profile.get('ref_spec_ms', 0.0)):.3f}", - "X-Prepare-Ref-Spec-Wait-Ms": f"{float(prepare_profile.get('ref_spec_wait_ms', 0.0)):.3f}", - "X-Prepare-Ref-Bundle-Ms": f"{float(prepare_profile.get('ref_audio_bundle_ms', 0.0)):.3f}", - "X-Prepare-Tensorize-Ms": f"{float(prepare_profile.get('tensorize_ms', 0.0)):.3f}", - "X-Prepare-Inflight-On-Enter": str( - int(prepare_profile.get("worker_prepare_inflight_on_enter", 0.0)) - ), - "X-Prepare-Inflight-Peak": str(int(prepare_profile.get("worker_prepare_peak_inflight", 0.0))), - } - ) - response_ready_at = time.perf_counter() - response_overhead_ms = max(0.0, (response_ready_at - pack_end) * 1000.0) - request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) - request_other_ms = max( - 0.0, - request_total_ms - prepare_wall_ms - api_after_prepare_ms - worker_total_ms - api_wait_result_ms - pack_ms, - ) - headers["X-Response-Overhead-Ms"] = f"{response_overhead_ms:.3f}" - headers["X-Request-Other-Ms"] = f"{request_other_ms:.3f}" - headers["X-Request-Total-Ms"] = f"{request_total_ms:.3f}" - return Response(audio_data, media_type=f"audio/{job.media_type}", headers=headers) + result = await tts_engine.run_scheduler_submit(request.dict()) + return Response(result.audio_bytes, media_type=result.media_type, headers=result.headers) except Exception as e: return JSONResponse( status_code=400, @@ -1381,70 +1184,11 @@ async def tts_handle(req: dict): StreamingResponse: audio stream response. """ - streaming_mode = req.get("streaming_mode", False) - return_fragment = req.get("return_fragment", False) - media_type = req.get("media_type", "wav") - - check_res = check_params(req) - if check_res is not None: - return check_res - - if streaming_mode == 0: - streaming_mode = False - return_fragment = False - fixed_length_chunk = False - elif streaming_mode == 1: - streaming_mode = False - return_fragment = True - fixed_length_chunk = False - elif streaming_mode == 2: - streaming_mode = True - return_fragment = False - fixed_length_chunk = False - elif streaming_mode == 3: - streaming_mode = True - return_fragment = False - fixed_length_chunk = True - - else: - return JSONResponse(status_code=400, content={"message": f"the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)"}) - - req["streaming_mode"] = streaming_mode - req["return_fragment"] = return_fragment - req["fixed_length_chunk"] = fixed_length_chunk - - print(f"{streaming_mode} {return_fragment} {fixed_length_chunk}") - - streaming_mode = streaming_mode or return_fragment - - try: - tts_generator = tts_pipeline.run(req) - - if streaming_mode: - - def streaming_generator(tts_generator: Generator, media_type: str): - if_frist_chunk = True - for sr, chunk in tts_generator: - if if_frist_chunk and media_type == "wav": - yield wave_header_chunk(sample_rate=sr) - media_type = "raw" - if_frist_chunk = False - yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue() - - # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}" - return StreamingResponse( - streaming_generator( - tts_generator, - media_type, - ), - media_type=f"audio/{media_type}", - ) - - else: - sr, audio_data = next(tts_generator) - audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() - return Response(audio_data, media_type=f"audio/{media_type}") + result = tts_engine.run_direct_tts(req) + if result.streaming: + return StreamingResponse(result.audio_generator, media_type=f"audio/{result.media_type}") + return Response(result.audio_bytes, media_type=f"audio/{result.media_type}") except Exception as e: return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(e)}) @@ -1453,7 +1197,11 @@ async def tts_handle(req: dict): async def control(command: str = None): if command is None: return JSONResponse(status_code=400, content={"message": "command is required"}) - handle_control(command) + try: + tts_engine.handle_control(command) + return JSONResponse(status_code=200, content={"message": "success"}) + except Exception as e: + return JSONResponse(status_code=400, content={"message": "control failed", "Exception": str(e)}) @APP.get("/tts") @@ -1530,16 +1278,16 @@ async def tts_scheduler_submit_endpoint(request: Scheduler_Submit_Request): @APP.get("/tts_scheduler_state") async def tts_scheduler_state_endpoint(): - return JSONResponse(status_code=200, content={"message": "success", "worker_state": scheduler_debug_worker.get_state()}) + return JSONResponse(status_code=200, content=tts_engine.get_runtime_state()) @APP.get("/set_refer_audio") async def set_refer_aduio(refer_audio_path: str = None): try: - tts_pipeline.set_ref_audio(refer_audio_path) + payload = tts_engine.set_refer_audio(refer_audio_path) except Exception as e: return JSONResponse(status_code=400, content={"message": "set refer audio failed", "Exception": str(e)}) - return JSONResponse(status_code=200, content={"message": "success"}) + return JSONResponse(status_code=200, content=payload) # @APP.post("/set_refer_audio") @@ -1564,24 +1312,19 @@ async def set_refer_aduio(refer_audio_path: str = None): @APP.get("/set_gpt_weights") async def set_gpt_weights(weights_path: str = None): try: - if weights_path in ["", None]: - return JSONResponse(status_code=400, content={"message": "gpt weight path is required"}) - tts_pipeline.init_t2s_weights(weights_path) + payload = tts_engine.set_gpt_weights(weights_path) except Exception as e: return JSONResponse(status_code=400, content={"message": "change gpt weight failed", "Exception": str(e)}) - - return JSONResponse(status_code=200, content={"message": "success"}) + return JSONResponse(status_code=200, content=payload) @APP.get("/set_sovits_weights") async def set_sovits_weights(weights_path: str = None): try: - if weights_path in ["", None]: - return JSONResponse(status_code=400, content={"message": "sovits weight path is required"}) - tts_pipeline.init_vits_weights(weights_path) + payload = tts_engine.set_sovits_weights(weights_path) except Exception as e: return JSONResponse(status_code=400, content={"message": "change sovits weight failed", "Exception": str(e)}) - return JSONResponse(status_code=200, content={"message": "success"}) + return JSONResponse(status_code=200, content=payload) if __name__ == "__main__":