From 800acd45ff554dd63f8de365e111c625c8923be6 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Sat, 7 Mar 2026 05:47:22 +0800 Subject: [PATCH 1/6] Enhance G2P processing by implementing batch input handling in _g2p function, improving efficiency. Update prepare_onnx_input to utilize caching for tokenization and add optional parameters for character ID mapping and phoneme masks. Refactor G2PWOnnxConverter to streamline model loading and configuration management. --- GPT_SoVITS/text/chinese2.py | 17 ++- GPT_SoVITS/text/g2pw/dataset.py | 107 ++++++++++++++---- GPT_SoVITS/text/g2pw/onnx_api.py | 181 +++++++++++++++++++++++-------- 3 files changed, 233 insertions(+), 72 deletions(-) diff --git a/GPT_SoVITS/text/chinese2.py b/GPT_SoVITS/text/chinese2.py index dcce0d96..acfebfe2 100644 --- a/GPT_SoVITS/text/chinese2.py +++ b/GPT_SoVITS/text/chinese2.py @@ -180,10 +180,15 @@ def _merge_erhua(initials: list[str], finals: list[str], word: str, pos: str) -> def _g2p(segments): phones_list = [] word2ph = [] - for seg in segments: + g2pw_batch_results = [] + g2pw_batch_cursor = 0 + processed_segments = [re.sub("[a-zA-Z]+", "", seg) for seg in segments] + if is_g2pw: + batch_inputs = [seg for seg in processed_segments if seg] + g2pw_batch_results = g2pw._g2pw(batch_inputs) if batch_inputs else [] + + for seg in processed_segments: pinyins = [] - # Replace all English words in the sentence - seg = re.sub("[a-zA-Z]+", "", seg) seg_cut = psg.lcut(seg) seg_cut = tone_modifier.pre_merge_for_modify(seg_cut) initials = [] @@ -204,8 +209,10 @@ def _g2p(segments): finals = sum(finals, []) print("pypinyin结果", initials, finals) else: - # g2pw采用整句推理 - pinyins = g2pw.lazy_pinyin(seg, neutral_tone_with_five=True, style=Style.TONE3) + # g2pw采用整句推理(批量推理,逐句取结果) + if seg: + pinyins = g2pw_batch_results[g2pw_batch_cursor] + g2pw_batch_cursor += 1 pre_word_length = 0 for word, pos in seg_cut: diff --git a/GPT_SoVITS/text/g2pw/dataset.py b/GPT_SoVITS/text/g2pw/dataset.py index ff09cbc2..e464c29a 100644 --- a/GPT_SoVITS/text/g2pw/dataset.py +++ b/GPT_SoVITS/text/g2pw/dataset.py @@ -18,6 +18,7 @@ Credits from typing import Dict from typing import List +from typing import Optional from typing import Tuple import numpy as np @@ -37,6 +38,8 @@ def prepare_onnx_input( use_mask: bool = False, window_size: int = None, max_len: int = 512, + char2id: Optional[Dict[str, int]] = None, + char_phoneme_masks: Optional[Dict[str, List[int]]] = None, ) -> Dict[str, np.array]: if window_size is not None: truncated_texts, truncated_query_ids = _truncate_texts( @@ -48,33 +51,88 @@ def prepare_onnx_input( phoneme_masks = [] char_ids = [] position_ids = [] + tokenized_cache = {} + + if char2id is None: + char2id = {char: idx for idx, char in enumerate(chars)} + if use_mask: + if char_phoneme_masks is None: + char_phoneme_masks = { + char: [1 if i in char2phonemes[char] else 0 for i in range(len(labels))] + for char in char2phonemes + } + else: + full_phoneme_mask = [1] * len(labels) for idx in range(len(texts)): text = (truncated_texts if window_size else texts)[idx].lower() query_id = (truncated_query_ids if window_size else query_ids)[idx] - try: - tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text) - except Exception: - print(f'warning: text "{text}" is invalid') - return {} + cached = tokenized_cache.get(text) + if cached is None: + try: + tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text) + except Exception: + print(f'warning: text "{text}" is invalid') + return {} - text, query_id, tokens, text2token, token2text = _truncate( - max_len=max_len, text=text, query_id=query_id, tokens=tokens, text2token=text2token, token2text=token2text - ) + if len(tokens) <= max_len - 2: + processed_tokens = ["[CLS]"] + tokens + ["[SEP]"] + shared_input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens))) + shared_token_type_id = list(np.zeros((len(processed_tokens),), dtype=int)) + shared_attention_mask = list(np.ones((len(processed_tokens),), dtype=int)) + cached = { + "is_short": True, + "tokens": tokens, + "text2token": text2token, + "token2text": token2text, + "input_id": shared_input_id, + "token_type_id": shared_token_type_id, + "attention_mask": shared_attention_mask, + } + else: + cached = { + "is_short": False, + "tokens": tokens, + "text2token": text2token, + "token2text": token2text, + } + tokenized_cache[text] = cached - processed_tokens = ["[CLS]"] + tokens + ["[SEP]"] + if cached["is_short"]: + text_for_query = text + query_id_for_query = query_id + text2token_for_query = cached["text2token"] + input_id = cached["input_id"] + token_type_id = cached["token_type_id"] + attention_mask = cached["attention_mask"] + else: + ( + text_for_query, + query_id_for_query, + tokens_for_query, + text2token_for_query, + _token2text_for_query, + ) = _truncate( + max_len=max_len, + text=text, + query_id=query_id, + tokens=cached["tokens"], + text2token=cached["text2token"], + token2text=cached["token2text"], + ) + processed_tokens = ["[CLS]"] + tokens_for_query + ["[SEP]"] + input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens))) + token_type_id = list(np.zeros((len(processed_tokens),), dtype=int)) + attention_mask = list(np.ones((len(processed_tokens),), dtype=int)) - input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens))) - token_type_id = list(np.zeros((len(processed_tokens),), dtype=int)) - attention_mask = list(np.ones((len(processed_tokens),), dtype=int)) - - query_char = text[query_id] - phoneme_mask = ( - [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] if use_mask else [1] * len(labels) - ) - char_id = chars.index(query_char) - position_id = text2token[query_id] + 1 # [CLS] token locate at first place + query_char = text_for_query[query_id_for_query] + if use_mask: + phoneme_mask = char_phoneme_masks[query_char] + else: + phoneme_mask = full_phoneme_mask + char_id = char2id[query_char] + position_id = text2token_for_query[query_id_for_query] + 1 # [CLS] token locate at first place input_ids.append(input_id) token_type_ids.append(token_type_id) @@ -83,10 +141,15 @@ def prepare_onnx_input( char_ids.append(char_id) position_ids.append(position_id) + max_token_length = max(len(seq) for seq in input_ids) + + def _pad_sequences(sequences, pad_value=0): + return [seq + [pad_value] * (max_token_length - len(seq)) for seq in sequences] + outputs = { - "input_ids": np.array(input_ids).astype(np.int64), - "token_type_ids": np.array(token_type_ids).astype(np.int64), - "attention_masks": np.array(attention_masks).astype(np.int64), + "input_ids": np.array(_pad_sequences(input_ids, pad_value=0)).astype(np.int64), + "token_type_ids": np.array(_pad_sequences(token_type_ids, pad_value=0)).astype(np.int64), + "attention_masks": np.array(_pad_sequences(attention_masks, pad_value=0)).astype(np.int64), "phoneme_masks": np.array(phoneme_masks).astype(np.float32), "char_ids": np.array(char_ids).astype(np.int64), "position_ids": np.array(position_ids).astype(np.int64), diff --git a/GPT_SoVITS/text/g2pw/onnx_api.py b/GPT_SoVITS/text/g2pw/onnx_api.py index 1d5e4231..5fcf1ae2 100644 --- a/GPT_SoVITS/text/g2pw/onnx_api.py +++ b/GPT_SoVITS/text/g2pw/onnx_api.py @@ -10,7 +10,6 @@ from typing import Any, Dict, List, Tuple import numpy as np import onnxruntime import requests -import torch from opencc import OpenCC from pypinyin import Style, pinyin from transformers.models.auto.tokenization_auto import AutoTokenizer @@ -22,9 +21,8 @@ from .utils import load_config onnxruntime.set_default_logger_severity(3) try: onnxruntime.preload_dlls() -except: +except Exception: pass - # traceback.print_exc() warnings.filterwarnings("ignore") model_version = "1.1" @@ -55,6 +53,24 @@ def predict(session, onnx_input: Dict[str, Any], labels: List[str]) -> Tuple[Lis return all_preds, all_confidences +def _load_json_from_candidates(filename: str, candidate_dirs: List[str]) -> Dict[str, Any]: + for candidate_dir in candidate_dirs: + if not candidate_dir: + continue + json_path = os.path.join(candidate_dir, filename) + if os.path.exists(json_path): + with open(json_path, "r", encoding="utf-8") as fr: + return json.load(fr) + raise FileNotFoundError(f"Cannot locate {filename} in candidate dirs: {candidate_dirs}") + + +def _find_first_existing_file(*paths: str) -> str: + for path in paths: + if path and os.path.exists(path): + return path + raise FileNotFoundError(f"Files not found: {paths}") + + def download_and_decompress(model_dir: str = "G2PWModel/"): if not os.path.exists(model_dir): parent_directory = os.path.dirname(model_dir) @@ -62,7 +78,7 @@ def download_and_decompress(model_dir: str = "G2PWModel/"): extract_dir = os.path.join(parent_directory, "G2PWModel_1.1") extract_dir_new = os.path.join(parent_directory, "G2PWModel") print("Downloading g2pw model...") - modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip" # "https://paddlespeech.cdn.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip" + modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip" with requests.get(modelscope_url, stream=True) as r: r.raise_for_status() with open(zip_dir, "wb") as f: @@ -79,7 +95,7 @@ def download_and_decompress(model_dir: str = "G2PWModel/"): return model_dir -class G2PWOnnxConverter: +class _G2PWBaseOnnxConverter: def __init__( self, model_dir: str = "G2PWModel/", @@ -87,33 +103,16 @@ class G2PWOnnxConverter: model_source: str = None, enable_non_tradional_chinese: bool = False, ): - uncompress_path = download_and_decompress(model_dir) - - sess_options = onnxruntime.SessionOptions() - sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL - sess_options.intra_op_num_threads = 2 if torch.cuda.is_available() else 0 - if "CUDAExecutionProvider" in onnxruntime.get_available_providers(): - self.session_g2pW = onnxruntime.InferenceSession( - os.path.join(uncompress_path, "g2pW.onnx"), - sess_options=sess_options, - providers=["CUDAExecutionProvider", "CPUExecutionProvider"], - ) - else: - self.session_g2pW = onnxruntime.InferenceSession( - os.path.join(uncompress_path, "g2pW.onnx"), - sess_options=sess_options, - providers=["CPUExecutionProvider"], - ) - self.config = load_config(config_path=os.path.join(uncompress_path, "config.py"), use_default=True) + self.model_dir = download_and_decompress(model_dir) + self.config = load_config(config_path=os.path.join(self.model_dir, "config.py"), use_default=True) self.model_source = model_source if model_source else self.config.model_source self.enable_opencc = enable_non_tradional_chinese - self.tokenizer = AutoTokenizer.from_pretrained(self.model_source) - polyphonic_chars_path = os.path.join(uncompress_path, "POLYPHONIC_CHARS.txt") - monophonic_chars_path = os.path.join(uncompress_path, "MONOPHONIC_CHARS.txt") + polyphonic_chars_path = os.path.join(self.model_dir, "POLYPHONIC_CHARS.txt") + monophonic_chars_path = os.path.join(self.model_dir, "MONOPHONIC_CHARS.txt") + self.polyphonic_chars = [ line.split("\t") for line in open(polyphonic_chars_path, encoding="utf-8").read().strip().split("\n") ] @@ -149,31 +148,45 @@ class G2PWOnnxConverter: ) self.chars = sorted(list(self.char2phonemes.keys())) + self.char2id = {char: idx for idx, char in enumerate(self.chars)} + self.char_phoneme_masks = ( + { + char: [1 if i in self.char2phonemes[char] else 0 for i in range(len(self.labels))] + for char in self.char2phonemes + } + if self.config.use_mask + else None + ) self.polyphonic_chars_new = set(self.chars) for char in self.non_polyphonic: - if char in self.polyphonic_chars_new: - self.polyphonic_chars_new.remove(char) + self.polyphonic_chars_new.discard(char) self.monophonic_chars_dict = {char: phoneme for char, phoneme in self.monophonic_chars} for char in self.non_monophonic: - if char in self.monophonic_chars_dict: - self.monophonic_chars_dict.pop(char) + self.monophonic_chars_dict.pop(char, None) - self.pos_tags = ["UNK", "A", "C", "D", "I", "N", "P", "T", "V", "DE", "SHI"] + default_asset_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "G2PWModel")) + candidate_asset_dirs = [self.model_dir, default_asset_dir] + self.bopomofo_convert_dict = _load_json_from_candidates( + "bopomofo_to_pinyin_wo_tune_dict.json", candidate_asset_dirs + ) + self.char_bopomofo_dict = _load_json_from_candidates("char_bopomofo_dict.json", candidate_asset_dirs) - with open(os.path.join(uncompress_path, "bopomofo_to_pinyin_wo_tune_dict.json"), "r", encoding="utf-8") as fr: - self.bopomofo_convert_dict = json.load(fr) self.style_convert_func = { "bopomofo": lambda x: x, "pinyin": self._convert_bopomofo_to_pinyin, }[style] - with open(os.path.join(uncompress_path, "char_bopomofo_dict.json"), "r", encoding="utf-8") as fr: - self.char_bopomofo_dict = json.load(fr) - if self.enable_opencc: self.cc = OpenCC("s2tw") + self.enable_sentence_dedup = os.getenv("g2pw_sentence_dedup", "true").strip().lower() in { + "1", + "true", + "yes", + "y", + "on", + } def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str: tone = bopomofo[-1] @@ -181,9 +194,8 @@ class G2PWOnnxConverter: component = self.bopomofo_convert_dict.get(bopomofo[:-1]) if component: return component + tone - else: - print(f'Warning: "{bopomofo}" cannot convert to pinyin') - return None + print(f'Warning: "{bopomofo}" cannot convert to pinyin') + return None def __call__(self, sentences: List[str]) -> List[List[str]]: if isinstance(sentences, str): @@ -199,10 +211,9 @@ class G2PWOnnxConverter: texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences) if len(texts) == 0: - # sentences no polyphonic words return partial_results - onnx_input = prepare_onnx_input( + model_input = prepare_onnx_input( tokenizer=self.tokenizer, labels=self.labels, char2phonemes=self.char2phonemes, @@ -211,9 +222,18 @@ class G2PWOnnxConverter: query_ids=query_ids, use_mask=self.config.use_mask, window_size=None, + char2id=self.char2id, + char_phoneme_masks=self.char_phoneme_masks, ) - preds, confidences = predict(session=self.session_g2pW, onnx_input=onnx_input, labels=self.labels) + if not model_input: + return partial_results + + if self.enable_sentence_dedup: + preds, _confidences = self._predict_with_sentence_dedup(model_input=model_input, texts=texts) + else: + preds, _confidences = self._predict(model_input=model_input) + if self.config.use_char_phoneme: preds = [pred.split(" ")[1] for pred in preds] @@ -226,7 +246,6 @@ class G2PWOnnxConverter: def _prepare_data(self, sentences: List[str]) -> Tuple[List[str], List[int], List[int], List[List[str]]]: texts, query_ids, sent_ids, partial_results = [], [], [], [] for sent_id, sent in enumerate(sentences): - # pypinyin works well for Simplified Chinese than Traditional Chinese sent_s = tranditional_to_simplified(sent) pypinyin_result = pinyin(sent_s, neutral_tone_with_five=True, style=Style.TONE3) partial_result = [None] * len(sent) @@ -239,9 +258,81 @@ class G2PWOnnxConverter: partial_result[i] = self.style_convert_func(self.monophonic_chars_dict[char]) elif char in self.char_bopomofo_dict: partial_result[i] = pypinyin_result[i][0] - # partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0]) else: partial_result[i] = pypinyin_result[i][0] partial_results.append(partial_result) return texts, query_ids, sent_ids, partial_results + + def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]: + raise NotImplementedError + + def _predict_with_sentence_dedup( + self, model_input: Dict[str, Any], texts: List[str] + ) -> Tuple[List[str], List[float]]: + if len(texts) <= 1: + return self._predict(model_input=model_input) + + grouped_indices: Dict[str, List[int]] = {} + for idx, text in enumerate(texts): + grouped_indices.setdefault(text, []).append(idx) + + if all(len(indices) == 1 for indices in grouped_indices.values()): + return self._predict(model_input=model_input) + + preds: List[str] = [""] * len(texts) + confidences: List[float] = [0.0] * len(texts) + for indices in grouped_indices.values(): + group_input = {name: value[indices] for name, value in model_input.items()} + if len(indices) > 1: + for name in ("input_ids", "token_type_ids", "attention_masks"): + group_input[name] = group_input[name][:1] + + group_preds, group_confidences = self._predict(model_input=group_input) + for output_idx, pred, confidence in zip(indices, group_preds, group_confidences): + preds[output_idx] = pred + confidences[output_idx] = confidence + + return preds, confidences + + +class G2PWOnnxConverter(_G2PWBaseOnnxConverter): + def __init__( + self, + model_dir: str = "G2PWModel/", + style: str = "bopomofo", + model_source: str = None, + enable_non_tradional_chinese: bool = False, + ): + super().__init__( + model_dir=model_dir, + style=style, + model_source=model_source, + enable_non_tradional_chinese=enable_non_tradional_chinese, + ) + + sess_options = onnxruntime.SessionOptions() + sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL + sess_options.intra_op_num_threads = 2 + + onnx_path = _find_first_existing_file( + os.path.join(self.model_dir, "g2pW.onnx"), + os.path.join(self.model_dir, "g2pw.onnx"), + ) + + if "CUDAExecutionProvider" in onnxruntime.get_available_providers(): + self.session_g2pw = onnxruntime.InferenceSession( + onnx_path, + sess_options=sess_options, + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], + ) + else: + self.session_g2pw = onnxruntime.InferenceSession( + onnx_path, + sess_options=sess_options, + providers=["CPUExecutionProvider"], + ) + + def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]: + return predict(session=self.session_g2pw, onnx_input=model_input, labels=self.labels) From b250e62402893b4f355d58599cd946150880ef7f Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Sun, 8 Mar 2026 03:01:20 +0800 Subject: [PATCH 2/6] Enhance G2PW model input handling by introducing polyphonic context character support and updating the data preparation method to return additional query IDs. This improves the processing of polyphonic characters in sentences. --- GPT_SoVITS/text/g2pw/onnx_api.py | 37 ++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/GPT_SoVITS/text/g2pw/onnx_api.py b/GPT_SoVITS/text/g2pw/onnx_api.py index 5fcf1ae2..3c2b0169 100644 --- a/GPT_SoVITS/text/g2pw/onnx_api.py +++ b/GPT_SoVITS/text/g2pw/onnx_api.py @@ -187,6 +187,8 @@ class _G2PWBaseOnnxConverter: "y", "on", } + # 聚焦到多音字附近上下文,默认左右各16字;设为0表示关闭裁剪(整句)。 + self.polyphonic_context_chars = max(0, int(os.getenv("g2pw_polyphonic_context_chars", "16"))) def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str: tone = bopomofo[-1] @@ -209,7 +211,7 @@ class _G2PWBaseOnnxConverter: translated_sentences.append(translated_sent) sentences = translated_sentences - texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences) + texts, model_query_ids, result_query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences) if len(texts) == 0: return partial_results @@ -219,7 +221,7 @@ class _G2PWBaseOnnxConverter: char2phonemes=self.char2phonemes, chars=self.chars, texts=texts, - query_ids=query_ids, + query_ids=model_query_ids, use_mask=self.config.use_mask, window_size=None, char2id=self.char2id, @@ -238,22 +240,23 @@ class _G2PWBaseOnnxConverter: preds = [pred.split(" ")[1] for pred in preds] results = partial_results - for sent_id, query_id, pred in zip(sent_ids, query_ids, preds): + for sent_id, query_id, pred in zip(sent_ids, result_query_ids, preds): results[sent_id][query_id] = self.style_convert_func(pred) return results - def _prepare_data(self, sentences: List[str]) -> Tuple[List[str], List[int], List[int], List[List[str]]]: - texts, query_ids, sent_ids, partial_results = [], [], [], [] + def _prepare_data( + self, sentences: List[str] + ) -> Tuple[List[str], List[int], List[int], List[int], List[List[str]]]: + texts, model_query_ids, result_query_ids, sent_ids, partial_results = [], [], [], [], [] for sent_id, sent in enumerate(sentences): sent_s = tranditional_to_simplified(sent) pypinyin_result = pinyin(sent_s, neutral_tone_with_five=True, style=Style.TONE3) partial_result = [None] * len(sent) + polyphonic_indices: List[int] = [] for i, char in enumerate(sent): if char in self.polyphonic_chars_new: - texts.append(sent) - query_ids.append(i) - sent_ids.append(sent_id) + polyphonic_indices.append(i) elif char in self.monophonic_chars_dict: partial_result[i] = self.style_convert_func(self.monophonic_chars_dict[char]) elif char in self.char_bopomofo_dict: @@ -261,8 +264,24 @@ class _G2PWBaseOnnxConverter: else: partial_result[i] = pypinyin_result[i][0] + if polyphonic_indices: + if self.polyphonic_context_chars > 0: + left = max(0, polyphonic_indices[0] - self.polyphonic_context_chars) + right = min(len(sent), polyphonic_indices[-1] + self.polyphonic_context_chars + 1) + sent_for_predict = sent[left:right] + query_offset = left + else: + sent_for_predict = sent + query_offset = 0 + + for index in polyphonic_indices: + texts.append(sent_for_predict) + model_query_ids.append(index - query_offset) + result_query_ids.append(index) + sent_ids.append(sent_id) + partial_results.append(partial_result) - return texts, query_ids, sent_ids, partial_results + return texts, model_query_ids, result_query_ids, sent_ids, partial_results def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]: raise NotImplementedError From 30a4557d8db2ebbac91dc264a063cc7175172932 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Sun, 8 Mar 2026 23:08:27 +0800 Subject: [PATCH 3/6] Implement last inference statistics tracking in Text2SemanticDecoder and enhance TTS class with prompt semantic extraction. This includes methods for setting and retrieving inference stats, as well as improvements to audio processing and feature extraction in TTS. --- GPT_SoVITS/AR/models/t2s_model.py | 76 +++++++++++++++++++++++++++++-- GPT_SoVITS/TTS_infer_pack/TTS.py | 38 ++++++++++++++++ 2 files changed, 111 insertions(+), 3 deletions(-) diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index ac905f4b..f55b7508 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -351,6 +351,13 @@ class Text2SemanticDecoder(nn.Module): blocks.append(block) self.t2s_transformer = T2STransformer(self.num_layers, blocks) + self.last_infer_stats = {} + + def _set_last_infer_stats(self, stats): + self.last_infer_stats = stats + + def get_last_infer_stats(self): + return dict(self.last_infer_stats) def make_input_data(self, x, x_lens, y, y_lens, bert_feature): x = self.ar_text_embedding(x) @@ -593,7 +600,19 @@ class Text2SemanticDecoder(nn.Module): repetition_penalty: float = 1.35, **kwargs, ): + requested_enable_mask_free_fastpath = bool(kwargs.get("enable_mask_free_fastpath", True)) if prompts is None: + self._set_last_infer_stats( + { + "infer_mode": "batch_infer_prompt_free_fallback", + "requested_enable_mask_free_fastpath": requested_enable_mask_free_fastpath, + "batch_size": int(len(x)), + "prefill_after_mask_all_visible": None, + "fastpath_hit": False, + "generated_token_count": 0, + "generated_token_count_list": [], + } + ) print("Warning: Prompt free is not supported batch_infer! switch to naive_infer") return self.infer_panel_naive_batched( x, @@ -608,6 +627,7 @@ class Text2SemanticDecoder(nn.Module): ) max_len = kwargs.get("max_len", x_lens.max()) + enable_mask_free_fastpath = requested_enable_mask_free_fastpath x_list = [] for x_item, bert_item in zip(x, bert_feature): # max_len = max(max_len, x_item.shape[0], bert_item.shape[1]) @@ -698,17 +718,30 @@ class Text2SemanticDecoder(nn.Module): y_list = [None] * y.shape[0] batch_idx_map = list(range(y.shape[0])) idx_list = [None] * y.shape[0] + decode_attn_mask = attn_mask + prefill_after_mask_all_visible = None + fastpath_hit = False for idx in tqdm(range(1500)): if idx == 0: xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None) else: - xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask) + xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token( + xy_pos, k_cache, v_cache, decode_attn_mask + ) logits = self.ar_predict_layer(xy_dec[:, -1]) if idx == 0: attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False) + prefill_after_mask_all_visible = not attn_mask.any().item() + if enable_mask_free_fastpath and y.shape[0] == 1 and prefill_after_mask_all_visible: + decode_attn_mask = None + fastpath_hit = True + else: + decode_attn_mask = attn_mask else: - attn_mask = F.pad(attn_mask, (0, 1), value=False) + if decode_attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1), value=False) + decode_attn_mask = attn_mask if idx < 11: ###至少预测出10个token不然不给停止(0.4s) logits = logits[:, :-1] @@ -740,7 +773,9 @@ class Text2SemanticDecoder(nn.Module): if reserved_idx_of_batch_for_y is not None: # index = torch.LongTensor(batch_idx_map).to(y.device) y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y) - attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y) + if decode_attn_mask is not None: + attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y) + decode_attn_mask = attn_mask if k_cache is not None: for i in range(len(k_cache)): k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y) @@ -775,6 +810,18 @@ class Text2SemanticDecoder(nn.Module): if idx_list[i] is None: idx_list[i] = 1500 - 1 ###如果没有生成到EOS,就用最大长度代替 + self._set_last_infer_stats( + { + "infer_mode": "batch_infer", + "requested_enable_mask_free_fastpath": enable_mask_free_fastpath, + "batch_size": int(len(x)), + "prefill_after_mask_all_visible": prefill_after_mask_all_visible, + "fastpath_hit": fastpath_hit, + "generated_token_count": int(sum(idx_list)), + "generated_token_count_list": [int(item) for item in idx_list], + "max_len": int(max_len), + } + ) if ref_free: return y_list, [0] * x.shape[0] # print(idx_list) @@ -811,6 +858,17 @@ class Text2SemanticDecoder(nn.Module): y_list.append(y[0]) idx_list.append(idx) + self._set_last_infer_stats( + { + "infer_mode": "naive_batched", + "requested_enable_mask_free_fastpath": bool(kwargs.get("enable_mask_free_fastpath", True)), + "batch_size": int(len(x)), + "prefill_after_mask_all_visible": None, + "fastpath_hit": False, + "generated_token_count": int(sum(idx_list)), + "generated_token_count_list": [int(item) for item in idx_list], + } + ) return y_list, idx_list def infer_panel_naive( @@ -957,6 +1015,18 @@ class Text2SemanticDecoder(nn.Module): if not streaming_mode: + generated_token_count = max(int(y.shape[1] - prefix_len), 0) + self._set_last_infer_stats( + { + "infer_mode": "naive", + "requested_enable_mask_free_fastpath": bool(kwargs.get("enable_mask_free_fastpath", True)), + "batch_size": int(x.shape[0]), + "prefill_after_mask_all_visible": True if prompts is not None else None, + "fastpath_hit": True if prompts is not None else False, + "generated_token_count": generated_token_count, + "generated_token_count_list": [generated_token_count], + } + ) if ref_free: yield y, 0 yield y, idx diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 9c8344b0..e1efd973 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1227,6 +1227,9 @@ class TTS: ###### inference ###### t_34 = 0.0 t_45 = 0.0 + t2s_observe_batch_count = 0 + t2s_observe_fastpath_hits = 0 + t2s_observe_generated_tokens = 0 audio = [] is_first_package = True output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"] @@ -1280,6 +1283,29 @@ class TTS: ) t4 = time.perf_counter() t_34 += t4 - t3 + if hasattr(self.t2s_model.model, "get_last_infer_stats"): + t2s_stats = self.t2s_model.model.get_last_infer_stats() + if t2s_stats: + generated_token_count = int(t2s_stats.get("generated_token_count", 0)) + t2s_total_ms = (t4 - t3) * 1000.0 + avg_decode_ms_per_token = ( + t2s_total_ms / generated_token_count if generated_token_count > 0 else 0.0 + ) + t2s_observe_batch_count += 1 + t2s_observe_generated_tokens += generated_token_count + if bool(t2s_stats.get("fastpath_hit", False)): + t2s_observe_fastpath_hits += 1 + print( + "[t2s_observe] " + f"mode={t2s_stats.get('infer_mode')} " + f"batch_size={t2s_stats.get('batch_size')} " + f"tokens={generated_token_count} " + f"t2s_ms={t2s_total_ms:.3f} " + f"avg_decode_ms_per_token={avg_decode_ms_per_token:.3f} " + f"requested_fastpath={t2s_stats.get('requested_enable_mask_free_fastpath')} " + f"prefill_all_visible={t2s_stats.get('prefill_after_mask_all_visible')} " + f"fastpath_hit={t2s_stats.get('fastpath_hit')}" + ) batch_audio_fragment = [] @@ -1500,6 +1526,18 @@ class TTS: if not (return_fragment or streaming_mode): print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45)) + if t2s_observe_batch_count > 0: + request_avg_decode_ms_per_token = ( + (t_34 * 1000.0) / t2s_observe_generated_tokens if t2s_observe_generated_tokens > 0 else 0.0 + ) + print( + "[t2s_request_observe] " + f"batches={t2s_observe_batch_count} " + f"fastpath_hits={t2s_observe_fastpath_hits} " + f"generated_tokens={t2s_observe_generated_tokens} " + f"t2s_total_ms={t_34 * 1000.0:.3f} " + f"avg_decode_ms_per_token={request_avg_decode_ms_per_token:.3f}" + ) if len(audio) == 0: yield output_sr, np.zeros(int(output_sr), dtype=np.int16) return From dc37b0b9ef9f29b852fa37e6cd661369b60d6f86 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Mon, 9 Mar 2026 00:22:59 +0800 Subject: [PATCH 4/6] Add WebAPI documentation and implement TTS API with endpoints for text-to-speech inference, control commands, and model switching. Enhance TTS class with methods for extracting prompt semantics and reference audio specifications. Introduce a scheduler prototype for managing T2S requests. --- GPT_SoVITS/TTS_infer_pack/TTS.py | 190 +++- GPT_SoVITS/TTS_infer_pack/__init__.py | 12 +- GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py | 631 +++++++++++ api_v3.py | 1170 ++++++++++++++++++++ tools/t2s_scheduler_prototype.py | 180 +++ 5 files changed, 2147 insertions(+), 36 deletions(-) create mode 100644 GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py create mode 100644 api_v3.py create mode 100644 tools/t2s_scheduler_prototype.py diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index e1efd973..bd4953df 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -759,21 +759,35 @@ class TTS: self._set_ref_spec(ref_audio_path) self._set_ref_audio_path(ref_audio_path) - def _set_ref_audio_path(self, ref_audio_path): - self.prompt_cache["ref_audio_path"] = ref_audio_path + def extract_prompt_semantic(self, ref_wav_path: str): + zero_wav = np.zeros( + int(self.configs.sampling_rate * 0.3), + dtype=np.float16 if self.configs.is_half else np.float32, + ) + with torch.no_grad(): + wav16k, sr = librosa.load(ref_wav_path, sr=16000) + if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000: + raise OSError(i18n("参考音频在3~10秒范围外,请更换!")) + wav16k = torch.from_numpy(wav16k) + zero_wav_torch = torch.from_numpy(zero_wav) + wav16k = wav16k.to(self.configs.device) + zero_wav_torch = zero_wav_torch.to(self.configs.device) + if self.configs.is_half: + wav16k = wav16k.half() + zero_wav_torch = zero_wav_torch.half() - def _set_ref_spec(self, ref_audio_path): - spec_audio = self._get_ref_spec(ref_audio_path) - if self.prompt_cache["refer_spec"] in [[], None]: - self.prompt_cache["refer_spec"] = [spec_audio] - else: - self.prompt_cache["refer_spec"][0] = spec_audio + wav16k = torch.cat([wav16k, zero_wav_torch]) + hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose( + 1, 2 + ) # .float() + codes = self.vits_model.extract_latent(hubert_feature) - def _get_ref_spec(self, ref_audio_path): + prompt_semantic = codes[0, 0].to(self.configs.device) + return prompt_semantic + + def extract_ref_spec(self, ref_audio_path: str): raw_audio, raw_sr = torchaudio.load(ref_audio_path) raw_audio = raw_audio.to(self.configs.device).float() - self.prompt_cache["raw_audio"] = raw_audio - self.prompt_cache["raw_sr"] = raw_sr if raw_sr != self.configs.sampling_rate: audio = raw_audio.to(self.configs.device) @@ -804,33 +818,30 @@ class TTS: audio = audio.half() else: audio = None + return spec, audio, raw_audio, raw_sr + + def extract_text_features(self, text: str, language: str): + return self.text_preprocessor.segment_and_extract_feature_for_text(text, language, self.configs.version) + + def _set_ref_audio_path(self, ref_audio_path): + self.prompt_cache["ref_audio_path"] = ref_audio_path + + def _set_ref_spec(self, ref_audio_path): + spec_audio = self._get_ref_spec(ref_audio_path) + if self.prompt_cache["refer_spec"] in [[], None]: + self.prompt_cache["refer_spec"] = [spec_audio] + else: + self.prompt_cache["refer_spec"][0] = spec_audio + + def _get_ref_spec(self, ref_audio_path): + spec, audio, raw_audio, raw_sr = self.extract_ref_spec(ref_audio_path) + self.prompt_cache["raw_audio"] = raw_audio + self.prompt_cache["raw_sr"] = raw_sr return spec, audio def _set_prompt_semantic(self, ref_wav_path: str): - zero_wav = np.zeros( - int(self.configs.sampling_rate * 0.3), - dtype=np.float16 if self.configs.is_half else np.float32, - ) - with torch.no_grad(): - wav16k, sr = librosa.load(ref_wav_path, sr=16000) - if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000: - raise OSError(i18n("参考音频在3~10秒范围外,请更换!")) - wav16k = torch.from_numpy(wav16k) - zero_wav_torch = torch.from_numpy(zero_wav) - wav16k = wav16k.to(self.configs.device) - zero_wav_torch = zero_wav_torch.to(self.configs.device) - if self.configs.is_half: - wav16k = wav16k.half() - zero_wav_torch = zero_wav_torch.half() - - wav16k = torch.cat([wav16k, zero_wav_torch]) - hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose( - 1, 2 - ) # .float() - codes = self.vits_model.extract_latent(hubert_feature) - - prompt_semantic = codes[0, 0].to(self.configs.device) - self.prompt_cache["prompt_semantic"] = prompt_semantic + prompt_semantic = self.extract_prompt_semantic(ref_wav_path) + self.prompt_cache["prompt_semantic"] = prompt_semantic def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length: int = None): seq = sequences[0] @@ -1701,6 +1712,115 @@ class TTS: return audio + def using_vocoder_synthesis_request_local( + self, + semantic_tokens: torch.Tensor, + phones: torch.Tensor, + prompt_semantic: torch.Tensor, + prompt_phones: torch.Tensor, + refer_audio_spec: torch.Tensor, + raw_audio: torch.Tensor, + raw_sr: int, + speed: float = 1.0, + sample_steps: int = 32, + ): + prompt_semantic_tokens = prompt_semantic.unsqueeze(0).unsqueeze(0).to(self.configs.device) + prompt_phones = prompt_phones.unsqueeze(0).to(self.configs.device) + refer_audio_spec = refer_audio_spec.to(dtype=self.precision, device=self.configs.device) + + fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec) + ref_audio = raw_audio.to(self.configs.device).float() + if ref_audio.shape[0] == 2: + ref_audio = ref_audio.mean(0).unsqueeze(0) + + tgt_sr = 24000 if self.configs.version == "v3" else 32000 + if raw_sr != tgt_sr: + ref_audio = resample(ref_audio, raw_sr, tgt_sr, self.configs.device) + + mel_spec_fn = mel_fn if self.configs.version == "v3" else mel_fn_v4 + mel2 = mel_spec_fn(ref_audio) + mel2 = norm_spec(mel2) + T_min = min(mel2.shape[2], fea_ref.shape[2]) + mel2 = mel2[:, :, :T_min] + fea_ref = fea_ref[:, :, :T_min] + T_ref = self.vocoder_configs["T_ref"] + T_chunk = self.vocoder_configs["T_chunk"] + if T_min > T_ref: + mel2 = mel2[:, :, -T_ref:] + fea_ref = fea_ref[:, :, -T_ref:] + T_min = T_ref + chunk_len = T_chunk - T_min + + mel2 = mel2.to(self.precision) + fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed) + + cfm_resss = [] + idx = 0 + while 1: + fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len] + if fea_todo_chunk.shape[-1] == 0: + break + idx += chunk_len + fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1) + + cfm_res = self.vits_model.cfm.inference( + fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0 + ) + cfm_res = cfm_res[:, :, mel2.shape[2] :] + + mel2 = cfm_res[:, :, -T_min:] + fea_ref = fea_todo_chunk[:, :, -T_min:] + + cfm_resss.append(cfm_res) + cfm_res = torch.cat(cfm_resss, 2) + cfm_res = denorm_spec(cfm_res) + + with torch.inference_mode(): + wav_gen = self.vocoder(cfm_res) + audio = wav_gen[0][0] + + return audio + + def synthesize_audio_request_local( + self, + semantic_tokens: torch.Tensor, + phones: torch.Tensor, + prompt_semantic: torch.Tensor, + prompt_phones: torch.Tensor, + refer_spec: tuple, + raw_audio: torch.Tensor, + raw_sr: int, + speed: float = 1.0, + sample_steps: int = 32, + ): + refer_audio_spec, audio_tensor = refer_spec + if not self.configs.use_vocoder: + refer_audio_spec_list = [refer_audio_spec.to(dtype=self.precision, device=self.configs.device)] + sv_emb = None + if self.is_v2pro: + if audio_tensor is None: + raise ValueError(i18n("v2Pro request-local synthesis 缺少 16k 参考音频")) + sv_emb = self.sv_model.compute_embedding3(audio_tensor).to(self.configs.device) + return self.vits_model.decode( + semantic_tokens, + phones, + refer_audio_spec_list, + speed=speed, + sv_emb=sv_emb, + ).detach()[0, 0, :] + + return self.using_vocoder_synthesis_request_local( + semantic_tokens=semantic_tokens, + phones=phones, + prompt_semantic=prompt_semantic, + prompt_phones=prompt_phones, + refer_audio_spec=refer_audio_spec, + raw_audio=raw_audio, + raw_sr=raw_sr, + speed=speed, + sample_steps=sample_steps, + ) + def using_vocoder_synthesis_batched_infer( self, idx_list: List[int], diff --git a/GPT_SoVITS/TTS_infer_pack/__init__.py b/GPT_SoVITS/TTS_infer_pack/__init__.py index 8579a632..09a257b2 100644 --- a/GPT_SoVITS/TTS_infer_pack/__init__.py +++ b/GPT_SoVITS/TTS_infer_pack/__init__.py @@ -1 +1,11 @@ -from . import TTS, text_segmentation_method +from __future__ import annotations + +import importlib + +__all__ = ["TTS", "TextPreprocessor", "text_segmentation_method", "t2s_scheduler"] + + +def __getattr__(name: str): + if name in __all__: + return importlib.import_module(f"{__name__}.{name}") + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py new file mode 100644 index 00000000..e94a72c7 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -0,0 +1,631 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +import time +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import torch +import torch.nn.functional as F + +from AR.models.utils import make_pad_mask_left, sample + + +def _sync_device(device: Any) -> None: + try: + device_str = str(device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + +@dataclass +class SchedulerRequestSpec: + request_id: str + ref_audio_path: Path + prompt_text: str + prompt_lang: str + text: str + text_lang: str + top_k: int + top_p: float + temperature: float + repetition_penalty: float + early_stop_num: int + ready_step: int = 0 + + +@dataclass +class T2SRequestState: + request_id: str + ref_audio_path: Path + prompt_text: str + prompt_lang: str + text: str + text_lang: str + norm_prompt_text: str + norm_text: str + phones: torch.LongTensor + prompt_phones: torch.LongTensor + all_phones: torch.LongTensor + all_bert_features: torch.Tensor + prompt_semantic: torch.LongTensor + refer_spec: Tuple[torch.Tensor, Optional[torch.Tensor]] + raw_audio: torch.Tensor + raw_sr: int + top_k: int + top_p: float + temperature: float + repetition_penalty: float + early_stop_num: int + ready_step: int + prepare_profile: Dict[str, float] + + +@dataclass +class T2SRunningRequest: + state: T2SRequestState + y_sequence: torch.LongTensor + prefix_len: int + decode_attn_mask: torch.Tensor + k_cache: List[torch.Tensor] + v_cache: List[torch.Tensor] + step_idx: int + + +@dataclass +class T2SFinishedItem: + request_id: str + semantic_tokens: torch.LongTensor + finish_idx: int + finish_reason: str + + +@dataclass +class T2SActiveBatch: + request_ids: List[str] + states: List[T2SRequestState] + x: torch.Tensor + x_lens: torch.LongTensor + y_sequences: List[torch.LongTensor] + prefix_lens: torch.LongTensor + xy_pos: torch.Tensor + prefill_attn_mask: torch.Tensor + decode_attn_mask: Optional[torch.Tensor] + k_cache: Optional[List[torch.Tensor]] + v_cache: Optional[List[torch.Tensor]] + step_idx: int + prefill_done: bool + + +def normalize_sentence(text: str, language: str) -> str: + text = text.strip("\n").strip() + if not text: + return text + if text[-1] not in {",", ".", "?", "!", ",", "。", "?", "!", "…", ";", ";", ":"}: + text += "。" if language != "en" else "." + return text + + +def prepare_request_state( + tts: Any, + spec: SchedulerRequestSpec, +) -> T2SRequestState: + device = tts.configs.device + prepare_start = time.perf_counter() + _sync_device(device) + prepare_sync_start = time.perf_counter() + prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) + text = spec.text.strip("\n") + + _sync_device(device) + prompt_text_features_start = time.perf_counter() + prompt_phones, prompt_bert_features, prompt_norm_text = tts.extract_text_features(prompt_text, spec.prompt_lang) + _sync_device(device) + prompt_text_features_ms = (time.perf_counter() - prompt_text_features_start) * 1000.0 + + _sync_device(device) + text_features_start = time.perf_counter() + phones, bert_features, norm_text = tts.extract_text_features(text, spec.text_lang) + _sync_device(device) + text_features_ms = (time.perf_counter() - text_features_start) * 1000.0 + if phones is None: + raise ValueError(f"{spec.request_id} text preprocessing returned no phones") + + _sync_device(device) + prompt_semantic_start = time.perf_counter() + prompt_semantic = tts.extract_prompt_semantic(str(spec.ref_audio_path)).long() + _sync_device(device) + prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0 + + _sync_device(device) + ref_spec_start = time.perf_counter() + spec_audio, audio_16k, raw_audio, raw_sr = tts.extract_ref_spec(str(spec.ref_audio_path)) + _sync_device(device) + ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0 + + _sync_device(device) + tensorize_start = time.perf_counter() + phones_tensor = torch.LongTensor(phones).to(tts.configs.device) + prompt_phones_tensor = torch.LongTensor(prompt_phones).to(tts.configs.device) + all_phones = torch.LongTensor(prompt_phones + phones).to(tts.configs.device) + all_bert_features = torch.cat([prompt_bert_features, bert_features], dim=1).to( + dtype=tts.precision, device=tts.configs.device + ) + _sync_device(device) + tensorize_ms = (time.perf_counter() - tensorize_start) * 1000.0 + + _sync_device(device) + prepare_profile = { + "prompt_text_features_ms": prompt_text_features_ms, + "text_features_ms": text_features_ms, + "prompt_semantic_ms": prompt_semantic_ms, + "ref_spec_ms": ref_spec_ms, + "tensorize_ms": tensorize_ms, + "total_ms": (time.perf_counter() - prepare_sync_start) * 1000.0, + "wall_total_ms": (time.perf_counter() - prepare_start) * 1000.0, + } + return T2SRequestState( + request_id=spec.request_id, + ref_audio_path=spec.ref_audio_path, + prompt_text=prompt_text, + prompt_lang=spec.prompt_lang, + text=text, + text_lang=spec.text_lang, + norm_prompt_text=prompt_norm_text, + norm_text=norm_text, + phones=phones_tensor, + prompt_phones=prompt_phones_tensor, + all_phones=all_phones, + all_bert_features=all_bert_features, + prompt_semantic=prompt_semantic, + refer_spec=(spec_audio, audio_16k), + raw_audio=raw_audio, + raw_sr=int(raw_sr), + top_k=spec.top_k, + top_p=spec.top_p, + temperature=spec.temperature, + repetition_penalty=spec.repetition_penalty, + early_stop_num=spec.early_stop_num, + ready_step=spec.ready_step, + prepare_profile=prepare_profile, + ) + + +def _left_pad_hidden(hidden: torch.Tensor, target_len: int) -> torch.Tensor: + if hidden.shape[0] >= target_len: + return hidden + return F.pad(hidden, (0, 0, target_len - hidden.shape[0], 0), value=0) + + +def _ensure_audio_pe(model: Any, max_position: int, dtype: torch.dtype, device: torch.device) -> None: + required_len = max_position + 1 + if model.ar_audio_position.pe is not None and model.ar_audio_position.pe.size(1) >= required_len: + if model.ar_audio_position.pe.dtype != dtype or model.ar_audio_position.pe.device != device: + model.ar_audio_position.pe = model.ar_audio_position.pe.to(dtype=dtype, device=device) + return + model.ar_audio_position.extend_pe( + torch.zeros(1, required_len, model.ar_audio_position.embedding_dim, device=device, dtype=dtype) + ) + + +def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SActiveBatch: + x_items: List[torch.Tensor] = [] + y_pos_items: List[torch.Tensor] = [] + x_lens: List[int] = [] + prefix_lens: List[int] = [] + y_sequences: List[torch.LongTensor] = [] + + for state in states: + text_emb = model.ar_text_embedding(state.all_phones.unsqueeze(0)) + bert_proj = model.bert_proj(state.all_bert_features.transpose(0, 1).unsqueeze(0)) + x_pos = model.ar_text_position(text_emb + bert_proj).squeeze(0) + y_emb = model.ar_audio_embedding(state.prompt_semantic.unsqueeze(0)) + y_pos = model.ar_audio_position(y_emb).squeeze(0) + x_items.append(x_pos) + y_pos_items.append(y_pos) + x_lens.append(x_pos.shape[0]) + prefix_lens.append(y_pos.shape[0]) + y_sequences.append(state.prompt_semantic.clone()) + + max_x_len = max(x_lens) + max_prefix_len = max(prefix_lens) + x_batch = torch.stack([_left_pad_hidden(item, max_x_len) for item in x_items], dim=0) + y_pos_batch = torch.stack([_left_pad_hidden(item, max_prefix_len) for item in y_pos_items], dim=0) + xy_pos = torch.cat([x_batch, y_pos_batch], dim=1) + + device = x_batch.device + x_lens_tensor = torch.LongTensor(x_lens).to(device) + prefix_lens_tensor = torch.LongTensor(prefix_lens).to(device) + batch_size = len(states) + src_len = max_x_len + max_prefix_len + + x_padding_mask = make_pad_mask_left(x_lens_tensor, max_x_len) + y_padding_mask = make_pad_mask_left(prefix_lens_tensor, max_prefix_len) + padding_mask = torch.cat([x_padding_mask, y_padding_mask], dim=1) + x_mask = F.pad(torch.zeros(max_x_len, max_x_len, dtype=torch.bool, device=device), (0, max_prefix_len), value=True) + y_mask = F.pad( + torch.triu(torch.ones(max_prefix_len, max_prefix_len, dtype=torch.bool, device=device), diagonal=1), + (max_x_len, 0), + value=False, + ) + causal_mask = torch.cat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(batch_size, 1, 1) + padding_mask = padding_mask.view(batch_size, 1, src_len).repeat(1, src_len, 1) + attn_mask = causal_mask.logical_or(padding_mask) + attn_mask = attn_mask.unsqueeze(1).expand(-1, model.num_head, -1, -1).bool() + + return T2SActiveBatch( + request_ids=[state.request_id for state in states], + states=list(states), + x=x_batch, + x_lens=x_lens_tensor, + y_sequences=y_sequences, + prefix_lens=prefix_lens_tensor, + xy_pos=xy_pos, + prefill_attn_mask=attn_mask, + decode_attn_mask=None, + k_cache=None, + v_cache=None, + step_idx=0, + prefill_done=False, + ) + + +def build_next_xy_pos(model: Any, y_sequences: Sequence[torch.LongTensor]) -> torch.Tensor: + last_tokens = torch.stack([seq[-1:] for seq in y_sequences], dim=0) + y_emb = model.ar_audio_embedding(last_tokens) + position_ids = torch.LongTensor([int(seq.shape[0] - 1) for seq in y_sequences]).to(y_emb.device) + _ensure_audio_pe(model, int(position_ids.max().item()), y_emb.dtype, y_emb.device) + pos_emb = model.ar_audio_position.pe[0].index_select(0, position_ids).unsqueeze(1) + return y_emb * model.ar_audio_position.x_scale + model.ar_audio_position.alpha * pos_emb.to( + dtype=y_emb.dtype, device=y_emb.device + ) + + +def _sample_per_request( + model: Any, + active_batch: T2SActiveBatch, + logits: torch.Tensor, + max_steps: int, +) -> Tuple[List[T2SFinishedItem], List[int], List[torch.LongTensor]]: + finished_items: List[T2SFinishedItem] = [] + keep_indices: List[int] = [] + updated_sequences: List[torch.LongTensor] = [] + + step_idx = active_batch.step_idx + for batch_index, state in enumerate(active_batch.states): + logits_i = logits[batch_index : batch_index + 1].clone() + current_history = active_batch.y_sequences[batch_index] + sampled = sample( + logits_i, + current_history.unsqueeze(0), + top_k=state.top_k, + top_p=state.top_p, + repetition_penalty=state.repetition_penalty, + temperature=state.temperature, + )[0] + sampled_token = int(sampled[0, 0].item()) + argmax_token = int(torch.argmax(logits[batch_index], dim=-1).item()) + new_history = torch.cat([current_history, sampled.view(-1)], dim=0) + + finish_reason: Optional[str] = None + if state.early_stop_num != -1 and (new_history.shape[0] - int(active_batch.prefix_lens[batch_index].item())) > state.early_stop_num: + finish_reason = "early_stop" + elif step_idx + 1 >= max_steps: + finish_reason = "max_step" + elif sampled_token == model.EOS: + finish_reason = "eos_sample" + elif argmax_token == model.EOS: + finish_reason = "eos_argmax" + + if finish_reason is not None: + finished_items.append( + T2SFinishedItem( + request_id=state.request_id, + semantic_tokens=new_history[:-1].clone(), + finish_idx=step_idx, + finish_reason=finish_reason, + ) + ) + else: + keep_indices.append(batch_index) + updated_sequences.append(new_history) + + return finished_items, keep_indices, updated_sequences + + +def decode_one_step( + model: Any, + active_batch: T2SActiveBatch, + max_steps: int, +) -> Tuple[Optional[T2SActiveBatch], List[T2SFinishedItem]]: + if not active_batch.prefill_done: + xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.process_prompt( + active_batch.xy_pos, active_batch.prefill_attn_mask, None + ) + active_batch.decode_attn_mask = F.pad( + active_batch.prefill_attn_mask[:, :, -1].unsqueeze(-2), + (0, 1), + value=False, + ) + active_batch.prefill_done = True + else: + xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.decode_next_token( + active_batch.xy_pos, + active_batch.k_cache, + active_batch.v_cache, + active_batch.decode_attn_mask, + ) + if active_batch.decode_attn_mask is not None: + active_batch.decode_attn_mask = F.pad(active_batch.decode_attn_mask, (0, 1), value=False) + + logits = model.ar_predict_layer(xy_dec[:, -1]) + if active_batch.step_idx < 11: + logits = logits[:, :-1] + + finished_items, keep_indices, updated_sequences = _sample_per_request(model, active_batch, logits, max_steps=max_steps) + if len(keep_indices) == 0: + return None, finished_items + + device = logits.device + keep_tensor = torch.LongTensor(keep_indices).to(device) + active_batch.request_ids = [active_batch.request_ids[i] for i in keep_indices] + active_batch.states = [active_batch.states[i] for i in keep_indices] + active_batch.y_sequences = updated_sequences + active_batch.prefix_lens = torch.index_select(active_batch.prefix_lens, dim=0, index=keep_tensor) + + if active_batch.decode_attn_mask is not None: + active_batch.decode_attn_mask = torch.index_select(active_batch.decode_attn_mask, dim=0, index=keep_tensor) + if active_batch.k_cache is not None and active_batch.v_cache is not None: + for cache_index in range(len(active_batch.k_cache)): + active_batch.k_cache[cache_index] = torch.index_select(active_batch.k_cache[cache_index], dim=0, index=keep_tensor) + active_batch.v_cache[cache_index] = torch.index_select(active_batch.v_cache[cache_index], dim=0, index=keep_tensor) + + active_batch.xy_pos = build_next_xy_pos(model, active_batch.y_sequences) + active_batch.step_idx += 1 + return active_batch, finished_items + + +def run_scheduler_batch( + model: Any, + states: Sequence[T2SRequestState], + max_steps: int, +) -> List[T2SFinishedItem]: + return run_scheduler_continuous(model, states, max_steps=max_steps) + + +def _pad_cache_left(cache: torch.Tensor, target_len: int) -> torch.Tensor: + pad_len = target_len - cache.shape[1] + if pad_len <= 0: + return cache + return F.pad(cache, (0, 0, pad_len, 0), value=0) + + +def _pad_decode_mask_left(mask: torch.Tensor, target_len: int) -> torch.Tensor: + pad_len = target_len - mask.shape[-1] + if pad_len <= 0: + return mask + return F.pad(mask, (pad_len, 0), value=True) + + +def run_prefill_step( + model: Any, + states: Sequence[T2SRequestState], + max_steps: int, +) -> Tuple[List[T2SRunningRequest], List[T2SFinishedItem]]: + if not states: + return [], [] + + active_batch = build_prefill_batch(model, states) + xy_dec, k_cache, v_cache = model.t2s_transformer.process_prompt(active_batch.xy_pos, active_batch.prefill_attn_mask, None) + decode_attn_mask = F.pad( + active_batch.prefill_attn_mask[:, :, -1].unsqueeze(-2), + (0, 1), + value=False, + ) + logits = model.ar_predict_layer(xy_dec[:, -1]) + + running_requests: List[T2SRunningRequest] = [] + finished_items: List[T2SFinishedItem] = [] + + for batch_index, state in enumerate(states): + logits_i = logits[batch_index : batch_index + 1].clone() + if 0 < 11: + logits_i = logits_i[:, :-1] + current_history = active_batch.y_sequences[batch_index] + sampled = sample( + logits_i, + current_history.unsqueeze(0), + top_k=state.top_k, + top_p=state.top_p, + repetition_penalty=state.repetition_penalty, + temperature=state.temperature, + )[0] + sampled_token = int(sampled[0, 0].item()) + argmax_token = int(torch.argmax(logits_i[0], dim=-1).item()) + new_history = torch.cat([current_history, sampled.view(-1)], dim=0) + prefix_len = int(active_batch.prefix_lens[batch_index].item()) + + finish_reason: Optional[str] = None + if state.early_stop_num != -1 and (new_history.shape[0] - prefix_len) > state.early_stop_num: + finish_reason = "early_stop" + elif 1 >= max_steps: + finish_reason = "max_step" + elif sampled_token == model.EOS: + finish_reason = "eos_sample" + elif argmax_token == model.EOS: + finish_reason = "eos_argmax" + + if finish_reason is not None: + finished_items.append( + T2SFinishedItem( + request_id=state.request_id, + semantic_tokens=new_history[:-1].clone(), + finish_idx=0, + finish_reason=finish_reason, + ) + ) + continue + + real_kv_len = int(active_batch.x_lens[batch_index].item()) + prefix_len + request_k_cache = [layer[batch_index : batch_index + 1, -real_kv_len:, :].clone() for layer in k_cache] + request_v_cache = [layer[batch_index : batch_index + 1, -real_kv_len:, :].clone() for layer in v_cache] + + running_requests.append( + T2SRunningRequest( + state=state, + y_sequence=new_history, + prefix_len=prefix_len, + decode_attn_mask=decode_attn_mask[batch_index : batch_index + 1].clone(), + k_cache=request_k_cache, + v_cache=request_v_cache, + step_idx=1, + ) + ) + + return running_requests, finished_items + + +def _build_decode_batch_from_running( + model: Any, + running_requests: Sequence[T2SRunningRequest], +) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor], torch.Tensor]: + xy_pos = build_next_xy_pos(model, [item.y_sequence for item in running_requests]) + max_kv_len = max(item.k_cache[0].shape[1] for item in running_requests) + max_mask_len = max(item.decode_attn_mask.shape[-1] for item in running_requests) + num_layers = len(running_requests[0].k_cache) + + batched_k_cache: List[torch.Tensor] = [] + batched_v_cache: List[torch.Tensor] = [] + for layer_index in range(num_layers): + batched_k_cache.append( + torch.cat([_pad_cache_left(item.k_cache[layer_index], max_kv_len) for item in running_requests], dim=0) + ) + batched_v_cache.append( + torch.cat([_pad_cache_left(item.v_cache[layer_index], max_kv_len) for item in running_requests], dim=0) + ) + + batched_decode_attn_mask = torch.cat( + [_pad_decode_mask_left(item.decode_attn_mask, max_mask_len) for item in running_requests], + dim=0, + ) + return xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask + + +def run_decode_step_for_running( + model: Any, + running_requests: Sequence[T2SRunningRequest], + max_steps: int, +) -> Tuple[List[T2SRunningRequest], List[T2SFinishedItem]]: + if not running_requests: + return [], [] + + xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask = _build_decode_batch_from_running( + model, running_requests + ) + xy_dec, next_k_cache, next_v_cache = model.t2s_transformer.decode_next_token( + xy_pos, + batched_k_cache, + batched_v_cache, + batched_decode_attn_mask, + ) + logits = model.ar_predict_layer(xy_dec[:, -1]) + + next_running: List[T2SRunningRequest] = [] + finished_items: List[T2SFinishedItem] = [] + + for batch_index, running_request in enumerate(running_requests): + current_idx = running_request.step_idx + logits_i = logits[batch_index : batch_index + 1].clone() + if current_idx < 11: + logits_i = logits_i[:, :-1] + sampled = sample( + logits_i, + running_request.y_sequence.unsqueeze(0), + top_k=running_request.state.top_k, + top_p=running_request.state.top_p, + repetition_penalty=running_request.state.repetition_penalty, + temperature=running_request.state.temperature, + )[0] + sampled_token = int(sampled[0, 0].item()) + argmax_token = int(torch.argmax(logits_i[0], dim=-1).item()) + new_history = torch.cat([running_request.y_sequence, sampled.view(-1)], dim=0) + + finish_reason: Optional[str] = None + if running_request.state.early_stop_num != -1 and (new_history.shape[0] - running_request.prefix_len) > running_request.state.early_stop_num: + finish_reason = "early_stop" + elif current_idx + 1 >= max_steps: + finish_reason = "max_step" + elif sampled_token == model.EOS: + finish_reason = "eos_sample" + elif argmax_token == model.EOS: + finish_reason = "eos_argmax" + + if finish_reason is not None: + finished_items.append( + T2SFinishedItem( + request_id=running_request.state.request_id, + semantic_tokens=new_history[:-1].clone(), + finish_idx=current_idx, + finish_reason=finish_reason, + ) + ) + continue + + real_next_kv_len = running_request.k_cache[0].shape[1] + 1 + request_k_cache = [layer[batch_index : batch_index + 1, -real_next_kv_len:, :].clone() for layer in next_k_cache] + request_v_cache = [layer[batch_index : batch_index + 1, -real_next_kv_len:, :].clone() for layer in next_v_cache] + next_running.append( + T2SRunningRequest( + state=running_request.state, + y_sequence=new_history, + prefix_len=running_request.prefix_len, + decode_attn_mask=F.pad(running_request.decode_attn_mask, (0, 1), value=False), + k_cache=request_k_cache, + v_cache=request_v_cache, + step_idx=current_idx + 1, + ) + ) + + return next_running, finished_items + + +def run_scheduler_continuous( + model: Any, + states: Sequence[T2SRequestState], + max_steps: int, +) -> List[T2SFinishedItem]: + pending = sorted(states, key=lambda item: (item.ready_step, item.request_id)) + running_requests: List[T2SRunningRequest] = [] + finished: List[T2SFinishedItem] = [] + current_tick = 0 + + while pending or running_requests: + admitted: List[T2SRequestState] = [] + while pending and pending[0].ready_step <= current_tick: + admitted.append(pending.pop(0)) + + admitted_running, admitted_finished = run_prefill_step(model, admitted, max_steps=max_steps) + finished.extend(admitted_finished) + + if running_requests: + running_requests, step_finished = run_decode_step_for_running( + model, + running_requests, + max_steps=max_steps, + ) + finished.extend(step_finished) + + running_requests.extend(admitted_running) + + if not running_requests and pending: + current_tick = max(current_tick + 1, pending[0].ready_step) + continue + + current_tick += 1 + + finished.sort(key=lambda item: item.request_id) + return finished diff --git a/api_v3.py b/api_v3.py new file mode 100644 index 00000000..9d250119 --- /dev/null +++ b/api_v3.py @@ -0,0 +1,1170 @@ +""" +# WebAPI文档 + +` python api_v2.py -a 127.0.0.1 -p 9880 -c GPT_SoVITS/configs/tts_infer.yaml ` + +## 执行参数: + `-a` - `绑定地址, 默认"127.0.0.1"` + `-p` - `绑定端口, 默认9880` + `-c` - `TTS配置文件路径, 默认"GPT_SoVITS/configs/tts_infer.yaml"` + +## 调用: + +### 推理 + +endpoint: `/tts` +GET: +``` +http://127.0.0.1:9880/tts?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_lang=zh&ref_audio_path=archive_jingyuan_1.wav&prompt_lang=zh&prompt_text=我是「罗浮」云骑将军景元。不必拘谨,「将军」只是一时的身份,你称呼我景元便可&text_split_method=cut5&batch_size=1&media_type=wav&streaming_mode=true +``` + +POST: +```json +{ + "text": "", # str.(required) text to be synthesized + "text_lang: "", # str.(required) language of the text to be synthesized + "ref_audio_path": "", # str.(required) reference audio path + "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion + "prompt_text": "", # str.(optional) prompt text for the reference audio + "prompt_lang": "", # str.(required) language of the prompt text for the reference audio + "top_k": 15, # int. top k sampling + "top_p": 1, # float. top p sampling + "temperature": 1, # float. temperature for sampling + "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details. + "batch_size": 1, # int. batch size for inference + "batch_threshold": 0.75, # float. threshold for batch splitting. + "split_bucket": True, # bool. whether to split the batch into multiple buckets. + "speed_factor":1.0, # float. control the speed of the synthesized audio. + "fragment_interval":0.3, # float. to control the interval of the audio fragment. + "seed": -1, # int. random seed for reproducibility. + "parallel_infer": True, # bool. whether to use parallel inference. + "repetition_penalty": 1.35, # float. repetition penalty for T2S model. + "sample_steps": 32, # int. number of sampling steps for VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed ) + "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. + "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) +} +``` + +RESP: +成功: 直接返回 wav 音频流, http code 200 +失败: 返回包含错误信息的 json, http code 400 + +### 命令控制 + +endpoint: `/control` + +command: +"restart": 重新运行 +"exit": 结束运行 + +GET: +``` +http://127.0.0.1:9880/control?command=restart +``` +POST: +```json +{ + "command": "restart" +} +``` + +RESP: 无 + + +### 切换GPT模型 + +endpoint: `/set_gpt_weights` + +GET: +``` +http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt +``` +RESP: +成功: 返回"success", http code 200 +失败: 返回包含错误信息的 json, http code 400 + + +### 切换Sovits模型 + +endpoint: `/set_sovits_weights` + +GET: +``` +http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/s2G488k.pth +``` + +RESP: +成功: 返回"success", http code 200 +失败: 返回包含错误信息的 json, http code 400 + +""" + +import asyncio +import os +import sys +import time +import traceback +import uuid +from dataclasses import dataclass +from pathlib import Path +from typing import Generator, List, Union + +now_dir = os.getcwd() +sys.path.append(now_dir) +sys.path.append("%s/GPT_SoVITS" % (now_dir)) + +import argparse +import subprocess +import wave +import signal +import numpy as np +import soundfile as sf +import torch +from fastapi import FastAPI, Response +from fastapi.responses import StreamingResponse, JSONResponse +import uvicorn +from io import BytesIO +from tools.i18n.i18n import I18nAuto +from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( + SchedulerRequestSpec, + T2SFinishedItem, + T2SRunningRequest, + T2SRequestState, + prepare_request_state, + run_decode_step_for_running, + run_prefill_step, + run_scheduler_continuous, +) +from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names +from pydantic import BaseModel +import threading + +# print(sys.path) +i18n = I18nAuto() +cut_method_names = get_cut_method_names() + +parser = argparse.ArgumentParser(description="GPT-SoVITS api") +parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径") +parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1") +parser.add_argument("-p", "--port", type=int, default="9880", help="default: 9880") +args = parser.parse_args() +config_path = args.tts_config +# device = args.device +port = args.port +host = args.bind_addr +argv = sys.argv + +if config_path in [None, ""]: + config_path = "GPT-SoVITS/configs/tts_infer.yaml" + +tts_config = TTS_Config(config_path) +print(tts_config) +tts_pipeline = TTS(tts_config) + +APP = FastAPI() + + +class TTS_Request(BaseModel): + text: str = None + text_lang: str = None + ref_audio_path: str = None + aux_ref_audio_paths: list = None + prompt_lang: str = None + prompt_text: str = "" + top_k: int = 15 + top_p: float = 1 + temperature: float = 1 + text_split_method: str = "cut5" + batch_size: int = 1 + batch_threshold: float = 0.75 + split_bucket: bool = True + speed_factor: float = 1.0 + fragment_interval: float = 0.3 + seed: int = -1 + media_type: str = "wav" + streaming_mode: Union[bool, int] = False + parallel_infer: bool = True + repetition_penalty: float = 1.35 + sample_steps: int = 32 + super_sampling: bool = False + overlap_length: int = 2 + min_chunk_length: int = 16 + + +class Scheduler_Debug_Request_Item(BaseModel): + request_id: str | None = None + text: str + text_lang: str + ref_audio_path: str + prompt_lang: str + prompt_text: str = "" + top_k: int = 15 + top_p: float = 1 + temperature: float = 1 + repetition_penalty: float = 1.35 + early_stop_num: int = -1 + ready_step: int = 0 + + +class Scheduler_Debug_Request(BaseModel): + requests: List[Scheduler_Debug_Request_Item] + max_steps: int = 1500 + seed: int = -1 + + +class Scheduler_Submit_Request(BaseModel): + request_id: str | None = None + text: str + text_lang: str + ref_audio_path: str + prompt_lang: str + prompt_text: str = "" + top_k: int = 15 + top_p: float = 1 + temperature: float = 1 + repetition_penalty: float = 1.35 + early_stop_num: int = -1 + speed_factor: float = 1.0 + sample_steps: int = 32 + media_type: str = "wav" + timeout_sec: float = 30.0 + + +@dataclass +class SchedulerPendingJob: + request_id: str + state: T2SRequestState + done_event: threading.Event + enqueue_time: float + speed_factor: float + sample_steps: int + media_type: str + prepare_ms: float = 0.0 + prepare_wall_ms: float = 0.0 + first_schedule_time: float | None = None + prefill_ms: float = 0.0 + decode_ms: float = 0.0 + synth_ms: float = 0.0 + pack_ms: float = 0.0 + decode_steps: int = 0 + result: dict | None = None + sample_rate: int | None = None + audio_data: np.ndarray | None = None + error: str | None = None + + +class SchedulerDebugWorker: + def __init__(self, tts: TTS, max_steps: int = 1500, micro_batch_wait_ms: int = 5): + self.tts = tts + self.max_steps = max_steps + self.micro_batch_wait_s = micro_batch_wait_ms / 1000.0 + self.prepare_lock = threading.Lock() + self.condition = threading.Condition() + self.pending_jobs: List[SchedulerPendingJob] = [] + self.running_requests: List[T2SRunningRequest] = [] + self.job_map: dict[str, SchedulerPendingJob] = {} + self.total_finished = 0 + self.total_submitted = 0 + self.worker_thread = threading.Thread(target=self._run_loop, name="t2s-scheduler-debug-worker", daemon=True) + self.worker_thread.start() + + def _sync_device(self) -> None: + try: + device_str = str(self.tts.configs.device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(self.tts.configs.device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + def prepare_state(self, spec: SchedulerRequestSpec) -> T2SRequestState: + with self.prepare_lock: + return prepare_request_state(self.tts, spec) + + def submit( + self, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + prepare_ms: float, + prepare_wall_ms: float, + ) -> SchedulerPendingJob: + job = SchedulerPendingJob( + request_id=state.request_id, + state=state, + done_event=threading.Event(), + enqueue_time=time.perf_counter(), + speed_factor=float(speed_factor), + sample_steps=int(sample_steps), + media_type=media_type, + prepare_ms=float(prepare_ms), + prepare_wall_ms=float(prepare_wall_ms), + ) + with self.condition: + self.pending_jobs.append(job) + self.job_map[job.request_id] = job + self.total_submitted += 1 + self.condition.notify_all() + return job + + def _mark_prefill_started(self, jobs: List[SchedulerPendingJob], started_at: float) -> None: + with self.condition: + for job in jobs: + tracked_job = self.job_map.get(job.request_id) + if tracked_job is not None and tracked_job.first_schedule_time is None: + tracked_job.first_schedule_time = started_at + + def _add_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None: + elapsed_ms = elapsed_s * 1000.0 + with self.condition: + for job in jobs: + tracked_job = self.job_map.get(job.request_id) + if tracked_job is not None: + tracked_job.prefill_ms += elapsed_ms + + def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + elapsed_ms = elapsed_s * 1000.0 + with self.condition: + for request_id in request_ids: + job = self.job_map.get(request_id) + if job is not None: + job.decode_ms += elapsed_ms + job.decode_steps += 1 + + def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]: + semantic_tokens = item.semantic_tokens.unsqueeze(0).unsqueeze(0).to(self.tts.configs.device) + phones = job.state.phones.unsqueeze(0).to(self.tts.configs.device) + audio_fragment = self.tts.synthesize_audio_request_local( + semantic_tokens=semantic_tokens, + phones=phones, + prompt_semantic=job.state.prompt_semantic, + prompt_phones=job.state.prompt_phones, + refer_spec=job.state.refer_spec, + raw_audio=job.state.raw_audio, + raw_sr=job.state.raw_sr, + speed=float(job.speed_factor), + sample_steps=int(job.sample_steps), + ) + output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] + return self.tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=float(job.speed_factor), + split_bucket=False, + fragment_interval=0.0, + super_sampling=False, + ) + + def get_state(self) -> dict: + with self.condition: + return { + "pending_jobs": len(self.pending_jobs), + "running_requests": len(self.running_requests), + "tracked_jobs": len(self.job_map), + "total_submitted": self.total_submitted, + "total_finished": self.total_finished, + "max_steps": self.max_steps, + "micro_batch_wait_ms": int(self.micro_batch_wait_s * 1000), + } + + def _finalize_finished(self, items: List[T2SFinishedItem]) -> None: + if not items: + return + jobs_to_finalize: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + with self.condition: + for item in items: + job = self.job_map.get(item.request_id) + if job is not None: + jobs_to_finalize.append((job, item)) + + for job, item in jobs_to_finalize: + try: + self._sync_device() + synth_start = time.perf_counter() + sample_rate, audio_data = self._synthesize_finished_audio(job, item) + self._sync_device() + synth_ms = (time.perf_counter() - synth_start) * 1000.0 + except Exception as exc: + self._finalize_error([item.request_id], str(exc)) + continue + + finished_at = time.perf_counter() + with self.condition: + if self.job_map.get(item.request_id) is not job: + continue + queue_wait_ms = 0.0 + if job.first_schedule_time is not None: + queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) + worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) + job.synth_ms += synth_ms + job.sample_rate = int(sample_rate) + job.audio_data = audio_data + prepare_profile = dict(job.state.prepare_profile) + job.result = { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + "prepare_ms": job.prepare_ms, + "prepare_wall_ms": job.prepare_wall_ms, + "prepare_profile": prepare_profile, + "queue_wait_ms": queue_wait_ms, + "prefill_ms": job.prefill_ms, + "decode_ms": job.decode_ms, + "synth_ms": job.synth_ms, + "worker_total_ms": worker_total_ms, + "decode_steps": int(job.decode_steps), + "sample_rate": int(sample_rate), + "media_type": job.media_type, + } + job.done_event.set() + self.job_map.pop(item.request_id, None) + self.total_finished += 1 + + def _finalize_error(self, request_ids: List[str], error: str) -> None: + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_map.get(request_id) + if job is None: + continue + job.error = error + job.done_event.set() + self.job_map.pop(request_id, None) + self.total_finished += 1 + + def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs and not self.running_requests: + self.condition.wait(timeout=self.micro_batch_wait_s) + elif wait_for_batch and self.pending_jobs: + self.condition.wait(timeout=self.micro_batch_wait_s) + if not self.pending_jobs: + return [] + pending = list(self.pending_jobs) + self.pending_jobs.clear() + return pending + + def _run_loop(self) -> None: + while True: + wait_for_batch = len(self.running_requests) == 0 + pending_jobs = self._take_pending_snapshot(wait_for_batch=wait_for_batch) + + if pending_jobs: + try: + self._sync_device() + prefill_start = time.perf_counter() + self._mark_prefill_started(pending_jobs, prefill_start) + admitted_running, admitted_finished = run_prefill_step( + self.tts.t2s_model.model, + [job.state for job in pending_jobs], + max_steps=self.max_steps, + ) + self._sync_device() + self._add_prefill_time(pending_jobs, time.perf_counter() - prefill_start) + self._finalize_finished(admitted_finished) + self.running_requests.extend(admitted_running) + except Exception as exc: + self._finalize_error([job.request_id for job in pending_jobs], str(exc)) + + if self.running_requests: + try: + active_request_ids = [item.state.request_id for item in self.running_requests] + self._sync_device() + decode_start = time.perf_counter() + self.running_requests, step_finished = run_decode_step_for_running( + self.tts.t2s_model.model, + self.running_requests, + max_steps=self.max_steps, + ) + self._sync_device() + self._add_decode_time(active_request_ids, time.perf_counter() - decode_start) + self._finalize_finished(step_finished) + except Exception as exc: + self._finalize_error(active_request_ids, str(exc)) + self.running_requests = [] + continue + + if not pending_jobs: + time.sleep(self.micro_batch_wait_s) + + +scheduler_debug_worker = SchedulerDebugWorker(tts_pipeline) + + +def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int): + # Author: AkagawaTsurunaki + # Issue: + # Stack overflow probabilistically occurs + # when the function `sf_writef_short` of `libsndfile_64bit.dll` is called + # using the Python library `soundfile` + # Note: + # This is an issue related to `libsndfile`, not this project itself. + # It happens when you generate a large audio tensor (about 499804 frames in my PC) + # and try to convert it to an ogg file. + # Related: + # https://github.com/RVC-Boss/GPT-SoVITS/issues/1199 + # https://github.com/libsndfile/libsndfile/issues/1023 + # https://github.com/bastibe/python-soundfile/issues/396 + # Suggestion: + # Or split the whole audio data into smaller audio segment to avoid stack overflow? + + def handle_pack_ogg(): + with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: + audio_file.write(data) + + + + # See: https://docs.python.org/3/library/threading.html + # The stack size of this thread is at least 32768 + # If stack overflow error still occurs, just modify the `stack_size`. + # stack_size = n * 4096, where n should be a positive integer. + # Here we chose n = 4096. + stack_size = 4096 * 4096 + try: + threading.stack_size(stack_size) + pack_ogg_thread = threading.Thread(target=handle_pack_ogg) + pack_ogg_thread.start() + pack_ogg_thread.join() + except RuntimeError as e: + # If changing the thread stack size is unsupported, a RuntimeError is raised. + print("RuntimeError: {}".format(e)) + print("Changing the thread stack size is unsupported.") + except ValueError as e: + # If the specified stack size is invalid, a ValueError is raised and the stack size is unmodified. + print("ValueError: {}".format(e)) + print("The specified stack size is invalid.") + + return io_buffer + + +def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int): + io_buffer.write(data.tobytes()) + return io_buffer + + +def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int): + io_buffer = BytesIO() + sf.write(io_buffer, data, rate, format="wav") + return io_buffer + + +def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int): + process = subprocess.Popen( + [ + "ffmpeg", + "-f", + "s16le", # 输入16位有符号小端整数PCM + "-ar", + str(rate), # 设置采样率 + "-ac", + "1", # 单声道 + "-i", + "pipe:0", # 从管道读取输入 + "-c:a", + "aac", # 音频编码器为AAC + "-b:a", + "192k", # 比特率 + "-vn", # 不包含视频 + "-f", + "adts", # 输出AAC数据流格式 + "pipe:1", # 将输出写入管道 + ], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + out, _ = process.communicate(input=data.tobytes()) + io_buffer.write(out) + return io_buffer + + +def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str): + if media_type == "ogg": + io_buffer = pack_ogg(io_buffer, data, rate) + elif media_type == "aac": + io_buffer = pack_aac(io_buffer, data, rate) + elif media_type == "wav": + io_buffer = pack_wav(io_buffer, data, rate) + else: + io_buffer = pack_raw(io_buffer, data, rate) + io_buffer.seek(0) + return io_buffer + + +# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py +def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000): + # This will create a wave header then append the frame input + # It should be first on a streaming wav file + # Other frames better should not have it (else you will hear some artifacts each chunk start) + wav_buf = BytesIO() + with wave.open(wav_buf, "wb") as vfout: + vfout.setnchannels(channels) + vfout.setsampwidth(sample_width) + vfout.setframerate(sample_rate) + vfout.writeframes(frame_input) + + wav_buf.seek(0) + return wav_buf.read() + + +def handle_control(command: str): + if command == "restart": + os.execl(sys.executable, sys.executable, *argv) + elif command == "exit": + os.kill(os.getpid(), signal.SIGTERM) + exit(0) + + +def check_params(req: dict): + text: str = req.get("text", "") + text_lang: str = req.get("text_lang", "") + ref_audio_path: str = req.get("ref_audio_path", "") + streaming_mode: bool = req.get("streaming_mode", False) + media_type: str = req.get("media_type", "wav") + prompt_lang: str = req.get("prompt_lang", "") + text_split_method: str = req.get("text_split_method", "cut5") + + if ref_audio_path in [None, ""]: + return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"}) + if text in [None, ""]: + return JSONResponse(status_code=400, content={"message": "text is required"}) + if text_lang in [None, ""]: + return JSONResponse(status_code=400, content={"message": "text_lang is required"}) + elif text_lang.lower() not in tts_config.languages: + return JSONResponse( + status_code=400, + content={"message": f"text_lang: {text_lang} is not supported in version {tts_config.version}"}, + ) + if prompt_lang in [None, ""]: + return JSONResponse(status_code=400, content={"message": "prompt_lang is required"}) + elif prompt_lang.lower() not in tts_config.languages: + return JSONResponse( + status_code=400, + content={"message": f"prompt_lang: {prompt_lang} is not supported in version {tts_config.version}"}, + ) + if media_type not in ["wav", "raw", "ogg", "aac"]: + return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"}) + # elif media_type == "ogg" and not streaming_mode: + # return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"}) + + if text_split_method not in cut_method_names: + return JSONResponse( + status_code=400, content={"message": f"text_split_method:{text_split_method} is not supported"} + ) + + return None + + +def set_scheduler_seed(seed: int): + if seed in ["", None]: + return + seed = int(seed) + if seed < 0: + return + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def build_scheduler_request_specs(request_items: List[Scheduler_Debug_Request_Item]) -> List[SchedulerRequestSpec]: + specs: List[SchedulerRequestSpec] = [] + for index, item in enumerate(request_items): + payload = item.dict() + req = { + "text": payload["text"], + "text_lang": payload["text_lang"].lower(), + "ref_audio_path": payload["ref_audio_path"], + "aux_ref_audio_paths": None, + "prompt_text": payload["prompt_text"], + "prompt_lang": payload["prompt_lang"].lower(), + "top_k": payload["top_k"], + "top_p": payload["top_p"], + "temperature": payload["temperature"], + "text_split_method": "cut5", + "batch_size": 1, + "batch_threshold": 0.75, + "speed_factor": 1.0, + "split_bucket": False, + "fragment_interval": 0.3, + "seed": -1, + "media_type": "wav", + "streaming_mode": False, + "parallel_infer": False, + "repetition_penalty": payload["repetition_penalty"], + "sample_steps": 32, + "super_sampling": False, + "overlap_length": 2, + "min_chunk_length": 16, + } + check_res = check_params(req) + if check_res is not None: + detail = check_res.body.decode("utf-8") if hasattr(check_res, "body") else str(check_res) + raise ValueError(f"request[{index}] 参数非法: {detail}") + specs.append( + SchedulerRequestSpec( + request_id=payload["request_id"] or f"req_{index:03d}", + ref_audio_path=Path(payload["ref_audio_path"]), + prompt_text=payload["prompt_text"], + prompt_lang=payload["prompt_lang"].lower(), + text=payload["text"], + text_lang=payload["text_lang"].lower(), + top_k=int(payload["top_k"]), + top_p=float(payload["top_p"]), + temperature=float(payload["temperature"]), + repetition_penalty=float(payload["repetition_penalty"]), + early_stop_num=int(payload["early_stop_num"]), + ready_step=int(payload["ready_step"]), + ) + ) + return specs + + +def summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: + return [ + { + "request_id": state.request_id, + "ready_step": int(state.ready_step), + "ref_audio_path": str(state.ref_audio_path), + "prompt_semantic_len": int(state.prompt_semantic.shape[0]), + "all_phone_len": int(state.all_phones.shape[0]), + "bert_len": int(state.all_bert_features.shape[-1]), + "norm_text": state.norm_text, + } + for state in states + ] + + +def summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: + return [ + { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + } + for item in items + ] + + +def prepare_scheduler_states_batch(specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: + return [scheduler_debug_worker.prepare_state(spec) for spec in specs] + + +def build_scheduler_submit_spec(request: Scheduler_Submit_Request) -> SchedulerRequestSpec: + payload = request.dict() + request_id = payload["request_id"] or f"job_{uuid.uuid4().hex[:12]}" + req = { + "text": payload["text"], + "text_lang": payload["text_lang"].lower(), + "ref_audio_path": payload["ref_audio_path"], + "aux_ref_audio_paths": None, + "prompt_text": payload["prompt_text"], + "prompt_lang": payload["prompt_lang"].lower(), + "top_k": payload["top_k"], + "top_p": payload["top_p"], + "temperature": payload["temperature"], + "text_split_method": "cut5", + "batch_size": 1, + "batch_threshold": 0.75, + "speed_factor": float(payload["speed_factor"]), + "split_bucket": False, + "fragment_interval": 0.3, + "seed": -1, + "media_type": payload["media_type"], + "streaming_mode": False, + "parallel_infer": False, + "repetition_penalty": payload["repetition_penalty"], + "sample_steps": int(payload["sample_steps"]), + "super_sampling": False, + "overlap_length": 2, + "min_chunk_length": 16, + } + check_res = check_params(req) + if check_res is not None: + detail = check_res.body.decode("utf-8") if hasattr(check_res, "body") else str(check_res) + raise ValueError(f"request 参数非法: {detail}") + return SchedulerRequestSpec( + request_id=request_id, + ref_audio_path=Path(payload["ref_audio_path"]), + prompt_text=payload["prompt_text"], + prompt_lang=payload["prompt_lang"].lower(), + text=payload["text"], + text_lang=payload["text_lang"].lower(), + top_k=int(payload["top_k"]), + top_p=float(payload["top_p"]), + temperature=float(payload["temperature"]), + repetition_penalty=float(payload["repetition_penalty"]), + early_stop_num=int(payload["early_stop_num"]), + ready_step=0, + ) + + +async def tts_scheduler_debug_handle(request: Scheduler_Debug_Request): + try: + set_scheduler_seed(request.seed) + specs = build_scheduler_request_specs(request.requests) + states = await asyncio.to_thread(prepare_scheduler_states_batch, specs) + finished = run_scheduler_continuous(tts_pipeline.t2s_model.model, states, max_steps=int(request.max_steps)) + return JSONResponse( + status_code=200, + content={ + "message": "success", + "request_count": len(states), + "max_steps": int(request.max_steps), + "requests": summarize_scheduler_states(states), + "finished": summarize_scheduler_finished(finished), + }, + ) + except Exception as e: + return JSONResponse( + status_code=400, + content={"message": "scheduler debug failed", "Exception": str(e)}, + ) + + +async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): + try: + request_start = time.perf_counter() + spec = build_scheduler_submit_spec(request) + prepare_start = time.perf_counter() + state = await asyncio.to_thread(scheduler_debug_worker.prepare_state, spec) + prepare_wall_ms = (time.perf_counter() - prepare_start) * 1000.0 + prepare_ms = float(state.prepare_profile.get("total_ms", prepare_wall_ms)) + job = scheduler_debug_worker.submit( + state, + speed_factor=float(request.speed_factor), + sample_steps=int(request.sample_steps), + media_type=request.media_type, + prepare_ms=prepare_ms, + prepare_wall_ms=prepare_wall_ms, + ) + timeout_ok = await asyncio.to_thread(job.done_event.wait, float(request.timeout_sec)) + if not timeout_ok: + return JSONResponse( + status_code=202, + content={ + "message": "queued", + "request_id": job.request_id, + "timings": { + "prepare_ms": prepare_ms, + "prepare_wall_ms": prepare_wall_ms, + "request_elapsed_ms": max(0.0, (time.perf_counter() - request_start) * 1000.0), + }, + "worker_state": scheduler_debug_worker.get_state(), + }, + ) + if job.error is not None: + return JSONResponse( + status_code=400, + content={"message": "scheduler submit failed", "request_id": job.request_id, "Exception": job.error}, + ) + if job.audio_data is None or job.sample_rate is None: + return JSONResponse( + status_code=500, + content={ + "message": "scheduler submit failed", + "request_id": job.request_id, + "Exception": "job finished without audio payload", + }, + ) + pack_start = time.perf_counter() + audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue() + pack_ms = (time.perf_counter() - pack_start) * 1000.0 + job.pack_ms = pack_ms + request_total_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0) + headers = { + "X-Request-Id": job.request_id, + "X-Semantic-Len": str(job.result["semantic_len"]) if job.result is not None else "0", + "X-Finish-Reason": job.result["finish_reason"] if job.result is not None else "unknown", + "X-Queue-Wait-Ms": ( + f"{float(job.result['queue_wait_ms']):.3f}" if job.result is not None else "0.000" + ), + "X-Prepare-Ms": f"{prepare_ms:.3f}", + "X-Prepare-Wall-Ms": f"{prepare_wall_ms:.3f}", + "X-Prefill-Ms": f"{float(job.result['prefill_ms']):.3f}" if job.result is not None else "0.000", + "X-Decode-Ms": f"{float(job.result['decode_ms']):.3f}" if job.result is not None else "0.000", + "X-Synth-Ms": f"{float(job.result['synth_ms']):.3f}" if job.result is not None else "0.000", + "X-Pack-Ms": f"{pack_ms:.3f}", + "X-Worker-Total-Ms": ( + f"{float(job.result['worker_total_ms']):.3f}" if job.result is not None else "0.000" + ), + "X-Request-Total-Ms": f"{request_total_ms:.3f}", + "X-Decode-Steps": str(job.result["decode_steps"]) if job.result is not None else "0", + } + if job.result is not None: + prepare_profile = job.result.get("prepare_profile", {}) + headers.update( + { + "X-Prepare-Prompt-Text-Ms": f"{float(prepare_profile.get('prompt_text_features_ms', 0.0)):.3f}", + "X-Prepare-Target-Text-Ms": f"{float(prepare_profile.get('text_features_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-Ms": f"{float(prepare_profile.get('prompt_semantic_ms', 0.0)):.3f}", + "X-Prepare-Ref-Spec-Ms": f"{float(prepare_profile.get('ref_spec_ms', 0.0)):.3f}", + "X-Prepare-Tensorize-Ms": f"{float(prepare_profile.get('tensorize_ms', 0.0)):.3f}", + "X-Prepare-Profile-Wall-Ms": f"{float(prepare_profile.get('wall_total_ms', 0.0)):.3f}", + } + ) + return Response(audio_data, media_type=f"audio/{job.media_type}", headers=headers) + except Exception as e: + return JSONResponse( + status_code=400, + content={"message": "scheduler submit failed", "Exception": str(e)}, + ) + + +async def tts_handle(req: dict): + """ + Text to speech handler. + + Args: + req (dict): + { + "text": "", # str.(required) text to be synthesized + "text_lang: "", # str.(required) language of the text to be synthesized + "ref_audio_path": "", # str.(required) reference audio path + "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion + "prompt_text": "", # str.(optional) prompt text for the reference audio + "prompt_lang": "", # str.(required) language of the prompt text for the reference audio + "top_k": 15, # int. top k sampling + "top_p": 1, # float. top p sampling + "temperature": 1, # float. temperature for sampling + "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details. + "batch_size": 1, # int. batch size for inference + "batch_threshold": 0.75, # float. threshold for batch splitting. + "split_bucket": True, # bool. whether to split the batch into multiple buckets. + "speed_factor":1.0, # float. control the speed of the synthesized audio. + "fragment_interval":0.3, # float. to control the interval of the audio fragment. + "seed": -1, # int. random seed for reproducibility. + "parallel_infer": True, # bool. whether to use parallel inference. + "repetition_penalty": 1.35, # float. repetition penalty for T2S model. + "sample_steps": 32, # int. number of sampling steps for VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed ) + "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. + "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) + } + returns: + StreamingResponse: audio stream response. + """ + + streaming_mode = req.get("streaming_mode", False) + return_fragment = req.get("return_fragment", False) + media_type = req.get("media_type", "wav") + + check_res = check_params(req) + if check_res is not None: + return check_res + + if streaming_mode == 0: + streaming_mode = False + return_fragment = False + fixed_length_chunk = False + elif streaming_mode == 1: + streaming_mode = False + return_fragment = True + fixed_length_chunk = False + elif streaming_mode == 2: + streaming_mode = True + return_fragment = False + fixed_length_chunk = False + elif streaming_mode == 3: + streaming_mode = True + return_fragment = False + fixed_length_chunk = True + + else: + return JSONResponse(status_code=400, content={"message": f"the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)"}) + + req["streaming_mode"] = streaming_mode + req["return_fragment"] = return_fragment + req["fixed_length_chunk"] = fixed_length_chunk + + print(f"{streaming_mode} {return_fragment} {fixed_length_chunk}") + + streaming_mode = streaming_mode or return_fragment + + + try: + tts_generator = tts_pipeline.run(req) + + if streaming_mode: + + def streaming_generator(tts_generator: Generator, media_type: str): + if_frist_chunk = True + for sr, chunk in tts_generator: + if if_frist_chunk and media_type == "wav": + yield wave_header_chunk(sample_rate=sr) + media_type = "raw" + if_frist_chunk = False + yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue() + + # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}" + return StreamingResponse( + streaming_generator( + tts_generator, + media_type, + ), + media_type=f"audio/{media_type}", + ) + + else: + sr, audio_data = next(tts_generator) + audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() + return Response(audio_data, media_type=f"audio/{media_type}") + except Exception as e: + return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(e)}) + + +@APP.get("/control") +async def control(command: str = None): + if command is None: + return JSONResponse(status_code=400, content={"message": "command is required"}) + handle_control(command) + + +@APP.get("/tts") +async def tts_get_endpoint( + text: str = None, + text_lang: str = None, + ref_audio_path: str = None, + aux_ref_audio_paths: list = None, + prompt_lang: str = None, + prompt_text: str = "", + top_k: int = 15, + top_p: float = 1, + temperature: float = 1, + text_split_method: str = "cut5", + batch_size: int = 1, + batch_threshold: float = 0.75, + split_bucket: bool = True, + speed_factor: float = 1.0, + fragment_interval: float = 0.3, + seed: int = -1, + media_type: str = "wav", + parallel_infer: bool = True, + repetition_penalty: float = 1.35, + sample_steps: int = 32, + super_sampling: bool = False, + streaming_mode: Union[bool, int] = False, + overlap_length: int = 2, + min_chunk_length: int = 16, +): + req = { + "text": text, + "text_lang": text_lang.lower(), + "ref_audio_path": ref_audio_path, + "aux_ref_audio_paths": aux_ref_audio_paths, + "prompt_text": prompt_text, + "prompt_lang": prompt_lang.lower(), + "top_k": top_k, + "top_p": top_p, + "temperature": temperature, + "text_split_method": text_split_method, + "batch_size": int(batch_size), + "batch_threshold": float(batch_threshold), + "speed_factor": float(speed_factor), + "split_bucket": split_bucket, + "fragment_interval": fragment_interval, + "seed": seed, + "media_type": media_type, + "streaming_mode": streaming_mode, + "parallel_infer": parallel_infer, + "repetition_penalty": float(repetition_penalty), + "sample_steps": int(sample_steps), + "super_sampling": super_sampling, + "overlap_length": int(overlap_length), + "min_chunk_length": int(min_chunk_length), + } + return await tts_handle(req) + + +@APP.post("/tts") +async def tts_post_endpoint(request: TTS_Request): + req = request.dict() + return await tts_handle(req) + + +@APP.post("/tts_scheduler_debug") +async def tts_scheduler_debug_endpoint(request: Scheduler_Debug_Request): + return await tts_scheduler_debug_handle(request) + + +@APP.post("/tts_scheduler_submit") +async def tts_scheduler_submit_endpoint(request: Scheduler_Submit_Request): + return await tts_scheduler_submit_handle(request) + + +@APP.get("/tts_scheduler_state") +async def tts_scheduler_state_endpoint(): + return JSONResponse(status_code=200, content={"message": "success", "worker_state": scheduler_debug_worker.get_state()}) + + +@APP.get("/set_refer_audio") +async def set_refer_aduio(refer_audio_path: str = None): + try: + tts_pipeline.set_ref_audio(refer_audio_path) + except Exception as e: + return JSONResponse(status_code=400, content={"message": "set refer audio failed", "Exception": str(e)}) + return JSONResponse(status_code=200, content={"message": "success"}) + + +# @APP.post("/set_refer_audio") +# async def set_refer_aduio_post(audio_file: UploadFile = File(...)): +# try: +# # 检查文件类型,确保是音频文件 +# if not audio_file.content_type.startswith("audio/"): +# return JSONResponse(status_code=400, content={"message": "file type is not supported"}) + +# os.makedirs("uploaded_audio", exist_ok=True) +# save_path = os.path.join("uploaded_audio", audio_file.filename) +# # 保存音频文件到服务器上的一个目录 +# with open(save_path , "wb") as buffer: +# buffer.write(await audio_file.read()) + +# tts_pipeline.set_ref_audio(save_path) +# except Exception as e: +# return JSONResponse(status_code=400, content={"message": f"set refer audio failed", "Exception": str(e)}) +# return JSONResponse(status_code=200, content={"message": "success"}) + + +@APP.get("/set_gpt_weights") +async def set_gpt_weights(weights_path: str = None): + try: + if weights_path in ["", None]: + return JSONResponse(status_code=400, content={"message": "gpt weight path is required"}) + tts_pipeline.init_t2s_weights(weights_path) + except Exception as e: + return JSONResponse(status_code=400, content={"message": "change gpt weight failed", "Exception": str(e)}) + + return JSONResponse(status_code=200, content={"message": "success"}) + + +@APP.get("/set_sovits_weights") +async def set_sovits_weights(weights_path: str = None): + try: + if weights_path in ["", None]: + return JSONResponse(status_code=400, content={"message": "sovits weight path is required"}) + tts_pipeline.init_vits_weights(weights_path) + except Exception as e: + return JSONResponse(status_code=400, content={"message": "change sovits weight failed", "Exception": str(e)}) + return JSONResponse(status_code=200, content={"message": "success"}) + + +if __name__ == "__main__": + try: + if host == "None": # 在调用时使用 -a None 参数,可以让api监听双栈 + host = None + uvicorn.run(app=APP, host=host, port=port, workers=1) + except Exception: + traceback.print_exc() + os.kill(os.getpid(), signal.SIGTERM) + exit(0) diff --git a/tools/t2s_scheduler_prototype.py b/tools/t2s_scheduler_prototype.py new file mode 100644 index 00000000..cd4b9c6d --- /dev/null +++ b/tools/t2s_scheduler_prototype.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import argparse +import json +import random +import sys +from pathlib import Path +from typing import Any, Dict, List + +import numpy as np +import torch + +ROOT_DIR = Path(__file__).resolve().parents[1] +if str(ROOT_DIR) not in sys.path: + sys.path.append(str(ROOT_DIR)) +gpt_sovits_dir = ROOT_DIR / "GPT_SoVITS" +if str(gpt_sovits_dir) not in sys.path: + sys.path.append(str(gpt_sovits_dir)) + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( # noqa: E402 + SchedulerRequestSpec, + T2SFinishedItem, + T2SRequestState, + prepare_request_state, + run_scheduler_continuous, +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="T2S request-local scheduler prototype.") + parser.add_argument("--config", type=Path, default=ROOT_DIR / "GPT_SoVITS/configs/tts_infer.yaml") + parser.add_argument("--request-manifest", type=Path, default=None) + parser.add_argument("--ref-audio", type=Path, default=ROOT_DIR / "test.wav") + parser.add_argument("--prompt-text", type=str, default="是啊,主要是因为有调研需求的学者少了。") + parser.add_argument("--prompt-lang", type=str, default="zh") + parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_en.txt") + parser.add_argument("--text", type=str, default=None) + parser.add_argument("--text-lang", type=str, default="en") + parser.add_argument("--top-k", type=int, default=15) + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--repetition-penalty", type=float, default=1.35) + parser.add_argument("--early-stop-num", type=int, default=-1) + parser.add_argument("--max-steps", type=int, default=1500) + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--output-dir", type=Path, default=ROOT_DIR / "TEMP/t2s_scheduler/output_run") + return parser.parse_args() + + +def set_seed(seed: int, use_cuda: bool) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if use_cuda and torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def load_pipeline(config_path: Path): + try: + from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "缺少运行依赖,请先在 GPT-SoVITS 推理环境中安装 requirements 后再运行该脚本。" + ) from exc + tts_config = TTS_Config(str(config_path)) + print(tts_config) + return TTS(tts_config) + + +def load_request_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]: + if args.request_manifest is not None: + payload = json.loads(args.request_manifest.read_text(encoding="utf-8")) + raw_requests = payload["requests"] if isinstance(payload, dict) else payload + specs: List[SchedulerRequestSpec] = [] + for index, item in enumerate(raw_requests): + text = item.get("text") + text_file = item.get("text_file") + if text is None and text_file is None: + raise ValueError(f"request[{index}] must provide text or text_file") + if text is None: + text = Path(text_file).read_text(encoding="utf-8") + specs.append( + SchedulerRequestSpec( + request_id=item.get("request_id", f"req_{index:03d}"), + ref_audio_path=Path(item["ref_audio_path"]), + prompt_text=item["prompt_text"], + prompt_lang=item.get("prompt_lang", "zh"), + text=text, + text_lang=item.get("text_lang", "zh"), + top_k=int(item.get("top_k", args.top_k)), + top_p=float(item.get("top_p", args.top_p)), + temperature=float(item.get("temperature", args.temperature)), + repetition_penalty=float(item.get("repetition_penalty", args.repetition_penalty)), + early_stop_num=int(item.get("early_stop_num", args.early_stop_num)), + ready_step=int(item.get("ready_step", 0)), + ) + ) + return specs + + text = args.text if args.text is not None else args.text_file.read_text(encoding="utf-8") + return [ + SchedulerRequestSpec( + request_id="req_000", + ref_audio_path=args.ref_audio, + prompt_text=args.prompt_text, + prompt_lang=args.prompt_lang, + text=text, + text_lang=args.text_lang, + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + repetition_penalty=args.repetition_penalty, + early_stop_num=args.early_stop_num, + ready_step=0, + ) + ] + + +def summarise_requests(states: List[T2SRequestState]) -> List[Dict[str, Any]]: + return [ + { + "request_id": state.request_id, + "ready_step": int(state.ready_step), + "ref_audio_path": str(state.ref_audio_path), + "prompt_semantic_len": int(state.prompt_semantic.shape[0]), + "all_phone_len": int(state.all_phones.shape[0]), + "bert_len": int(state.all_bert_features.shape[-1]), + "norm_text": state.norm_text, + } + for state in states + ] + + +def summarise_finished(items: List[T2SFinishedItem]) -> List[Dict[str, Any]]: + return [ + { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + } + for item in items + ] + + +def main() -> None: + args = parse_args() + args.output_dir.mkdir(parents=True, exist_ok=True) + + tts = load_pipeline(args.config) + model = tts.t2s_model.model + use_cuda = str(tts.configs.device).startswith("cuda") + set_seed(args.seed, use_cuda) + + request_specs = load_request_specs(args) + states = [prepare_request_state(tts, spec) for spec in request_specs] + finished = run_scheduler_continuous(model, states, max_steps=args.max_steps) + + summary = { + "request_count": len(states), + "max_steps": args.max_steps, + "requests": summarise_requests(states), + "finished": summarise_finished(finished), + } + output_path = args.output_dir / "scheduler_prototype_summary.json" + output_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8") + print(json.dumps(summary, ensure_ascii=False, indent=2)) + print(f"[saved] {output_path}") + + +if __name__ == "__main__": + try: + main() + except ModuleNotFoundError as exc: + print(f"[error] {exc}") + raise SystemExit(1) from None From d245eb169c57c50ea55444bdf104ded247058872 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Mon, 9 Mar 2026 01:42:04 +0800 Subject: [PATCH 5/6] Refactor T2S scheduler and inference handling to improve attention mask management and memory tracking. Update T2SRunningRequest and T2SActiveBatch classes to include optional key padding masks. Introduce new benchmarking tools for API performance and memory usage analysis, enhancing overall system efficiency. --- GPT_SoVITS/TTS_infer_pack/TTS.py | 1 + GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py | 82 +- tools/bench_api_v3_scheduler_submit.py | 250 ++++++ tools/t2s_memory_breakdown.py | 887 +++++++++++++++++++++ 4 files changed, 1192 insertions(+), 28 deletions(-) create mode 100644 tools/bench_api_v3_scheduler_submit.py create mode 100644 tools/t2s_memory_breakdown.py diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index bd4953df..2fd0df35 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1781,6 +1781,7 @@ class TTS: return audio + @torch.inference_mode() def synthesize_audio_request_local( self, semantic_tokens: torch.Tensor, diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py index e94a72c7..c8643991 100644 --- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -70,7 +70,7 @@ class T2SRunningRequest: state: T2SRequestState y_sequence: torch.LongTensor prefix_len: int - decode_attn_mask: torch.Tensor + decode_attn_mask: Optional[torch.Tensor] k_cache: List[torch.Tensor] v_cache: List[torch.Tensor] step_idx: int @@ -93,6 +93,7 @@ class T2SActiveBatch: y_sequences: List[torch.LongTensor] prefix_lens: torch.LongTensor xy_pos: torch.Tensor + key_padding_mask: torch.Tensor prefill_attn_mask: torch.Tensor decode_attn_mask: Optional[torch.Tensor] k_cache: Optional[List[torch.Tensor]] @@ -110,6 +111,7 @@ def normalize_sentence(text: str, language: str) -> str: return text +@torch.inference_mode() def prepare_request_state( tts: Any, spec: SchedulerRequestSpec, @@ -212,6 +214,7 @@ def _ensure_audio_pe(model: Any, max_position: int, dtype: torch.dtype, device: ) +@torch.inference_mode() def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SActiveBatch: x_items: List[torch.Tensor] = [] y_pos_items: List[torch.Tensor] = [] @@ -240,22 +243,19 @@ def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SAct device = x_batch.device x_lens_tensor = torch.LongTensor(x_lens).to(device) prefix_lens_tensor = torch.LongTensor(prefix_lens).to(device) - batch_size = len(states) src_len = max_x_len + max_prefix_len x_padding_mask = make_pad_mask_left(x_lens_tensor, max_x_len) y_padding_mask = make_pad_mask_left(prefix_lens_tensor, max_prefix_len) - padding_mask = torch.cat([x_padding_mask, y_padding_mask], dim=1) + key_padding_mask = torch.cat([x_padding_mask, y_padding_mask], dim=1).bool() x_mask = F.pad(torch.zeros(max_x_len, max_x_len, dtype=torch.bool, device=device), (0, max_prefix_len), value=True) y_mask = F.pad( torch.triu(torch.ones(max_prefix_len, max_prefix_len, dtype=torch.bool, device=device), diagonal=1), (max_x_len, 0), value=False, ) - causal_mask = torch.cat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(batch_size, 1, 1) - padding_mask = padding_mask.view(batch_size, 1, src_len).repeat(1, src_len, 1) - attn_mask = causal_mask.logical_or(padding_mask) - attn_mask = attn_mask.unsqueeze(1).expand(-1, model.num_head, -1, -1).bool() + causal_mask = torch.cat([x_mask, y_mask], dim=0).unsqueeze(0) + attn_mask = causal_mask.logical_or(key_padding_mask.unsqueeze(1)).unsqueeze(1) return T2SActiveBatch( request_ids=[state.request_id for state in states], @@ -265,6 +265,7 @@ def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SAct y_sequences=y_sequences, prefix_lens=prefix_lens_tensor, xy_pos=xy_pos, + key_padding_mask=key_padding_mask, prefill_attn_mask=attn_mask, decode_attn_mask=None, k_cache=None, @@ -322,10 +323,11 @@ def _sample_per_request( finish_reason = "eos_argmax" if finish_reason is not None: + prefix_len = int(active_batch.prefix_lens[batch_index].item()) finished_items.append( T2SFinishedItem( request_id=state.request_id, - semantic_tokens=new_history[:-1].clone(), + semantic_tokens=new_history[prefix_len:-1].clone(), finish_idx=step_idx, finish_reason=finish_reason, ) @@ -346,11 +348,7 @@ def decode_one_step( xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.process_prompt( active_batch.xy_pos, active_batch.prefill_attn_mask, None ) - active_batch.decode_attn_mask = F.pad( - active_batch.prefill_attn_mask[:, :, -1].unsqueeze(-2), - (0, 1), - value=False, - ) + active_batch.decode_attn_mask = F.pad(active_batch.key_padding_mask.unsqueeze(1).unsqueeze(1), (0, 1), value=False) active_batch.prefill_done = True else: xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.decode_next_token( @@ -411,6 +409,18 @@ def _pad_decode_mask_left(mask: torch.Tensor, target_len: int) -> torch.Tensor: return F.pad(mask, (pad_len, 0), value=True) +def _materialize_decode_mask_for_request(running_request: T2SRunningRequest) -> torch.Tensor: + if running_request.decode_attn_mask is not None: + return running_request.decode_attn_mask + current_mask_len = running_request.k_cache[0].shape[1] + 1 + return torch.zeros( + (1, 1, 1, current_mask_len), + dtype=torch.bool, + device=running_request.k_cache[0].device, + ) + + +@torch.inference_mode() def run_prefill_step( model: Any, states: Sequence[T2SRequestState], @@ -421,11 +431,9 @@ def run_prefill_step( active_batch = build_prefill_batch(model, states) xy_dec, k_cache, v_cache = model.t2s_transformer.process_prompt(active_batch.xy_pos, active_batch.prefill_attn_mask, None) - decode_attn_mask = F.pad( - active_batch.prefill_attn_mask[:, :, -1].unsqueeze(-2), - (0, 1), - value=False, - ) + decode_attn_mask = F.pad(active_batch.key_padding_mask.unsqueeze(1).unsqueeze(1), (0, 1), value=False) + if len(states) == 1 and not decode_attn_mask.any().item(): + decode_attn_mask = None logits = model.ar_predict_layer(xy_dec[:, -1]) running_requests: List[T2SRunningRequest] = [] @@ -463,7 +471,7 @@ def run_prefill_step( finished_items.append( T2SFinishedItem( request_id=state.request_id, - semantic_tokens=new_history[:-1].clone(), + semantic_tokens=new_history[prefix_len:-1].clone(), finish_idx=0, finish_reason=finish_reason, ) @@ -479,7 +487,11 @@ def run_prefill_step( state=state, y_sequence=new_history, prefix_len=prefix_len, - decode_attn_mask=decode_attn_mask[batch_index : batch_index + 1].clone(), + decode_attn_mask=( + None + if decode_attn_mask is None + else decode_attn_mask[batch_index : batch_index + 1].clone() + ), k_cache=request_k_cache, v_cache=request_v_cache, step_idx=1, @@ -492,10 +504,9 @@ def run_prefill_step( def _build_decode_batch_from_running( model: Any, running_requests: Sequence[T2SRunningRequest], -) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor], torch.Tensor]: +) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor], Optional[torch.Tensor]]: xy_pos = build_next_xy_pos(model, [item.y_sequence for item in running_requests]) max_kv_len = max(item.k_cache[0].shape[1] for item in running_requests) - max_mask_len = max(item.decode_attn_mask.shape[-1] for item in running_requests) num_layers = len(running_requests[0].k_cache) batched_k_cache: List[torch.Tensor] = [] @@ -508,13 +519,19 @@ def _build_decode_batch_from_running( torch.cat([_pad_cache_left(item.v_cache[layer_index], max_kv_len) for item in running_requests], dim=0) ) - batched_decode_attn_mask = torch.cat( - [_pad_decode_mask_left(item.decode_attn_mask, max_mask_len) for item in running_requests], - dim=0, - ) + if all(item.decode_attn_mask is None for item in running_requests): + batched_decode_attn_mask = None + else: + materialized_masks = [_materialize_decode_mask_for_request(item) for item in running_requests] + max_mask_len = max(mask.shape[-1] for mask in materialized_masks) + batched_decode_attn_mask = torch.cat( + [_pad_decode_mask_left(mask, max_mask_len) for mask in materialized_masks], + dim=0, + ) return xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask +@torch.inference_mode() def run_decode_step_for_running( model: Any, running_requests: Sequence[T2SRunningRequest], @@ -568,7 +585,7 @@ def run_decode_step_for_running( finished_items.append( T2SFinishedItem( request_id=running_request.state.request_id, - semantic_tokens=new_history[:-1].clone(), + semantic_tokens=new_history[running_request.prefix_len:-1].clone(), finish_idx=current_idx, finish_reason=finish_reason, ) @@ -578,12 +595,20 @@ def run_decode_step_for_running( real_next_kv_len = running_request.k_cache[0].shape[1] + 1 request_k_cache = [layer[batch_index : batch_index + 1, -real_next_kv_len:, :].clone() for layer in next_k_cache] request_v_cache = [layer[batch_index : batch_index + 1, -real_next_kv_len:, :].clone() for layer in next_v_cache] + if batched_decode_attn_mask is None: + next_decode_attn_mask = None + else: + current_decode_mask_len = running_request.k_cache[0].shape[1] + 1 + current_decode_attn_mask = batched_decode_attn_mask[ + batch_index : batch_index + 1, :, :, -current_decode_mask_len: + ] + next_decode_attn_mask = F.pad(current_decode_attn_mask, (0, 1), value=False) next_running.append( T2SRunningRequest( state=running_request.state, y_sequence=new_history, prefix_len=running_request.prefix_len, - decode_attn_mask=F.pad(running_request.decode_attn_mask, (0, 1), value=False), + decode_attn_mask=next_decode_attn_mask, k_cache=request_k_cache, v_cache=request_v_cache, step_idx=current_idx + 1, @@ -593,6 +618,7 @@ def run_decode_step_for_running( return next_running, finished_items +@torch.inference_mode() def run_scheduler_continuous( model: Any, states: Sequence[T2SRequestState], diff --git a/tools/bench_api_v3_scheduler_submit.py b/tools/bench_api_v3_scheduler_submit.py new file mode 100644 index 00000000..c16468e1 --- /dev/null +++ b/tools/bench_api_v3_scheduler_submit.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import argparse +import asyncio +import json +import subprocess +import threading +import time +import wave +from pathlib import Path +from typing import Any, Dict, List, Optional + +import httpx + +ROOT_DIR = Path(__file__).resolve().parents[1] + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Benchmark api_v3 /tts_scheduler_submit concurrency and GPU memory.") + parser.add_argument("--base-url", type=str, default="http://127.0.0.1:9880") + parser.add_argument("--endpoint", type=str, default="/tts_scheduler_submit") + parser.add_argument("--concurrency", type=int, required=True) + parser.add_argument("--timeout-sec", type=float, default=120.0) + parser.add_argument("--server-pid", type=int, default=None) + parser.add_argument("--poll-interval-sec", type=float, default=0.1) + parser.add_argument("--text-lang", type=str, default="zh") + parser.add_argument("--prompt-lang", type=str, default="zh") + parser.add_argument("--media-type", type=str, default="wav") + parser.add_argument("--top-k", type=int, default=15) + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--repetition-penalty", type=float, default=1.35) + parser.add_argument("--sample-steps", type=int, default=32) + parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_cn.txt") + parser.add_argument("--wav-dir", type=Path, default=ROOT_DIR / "testwav") + parser.add_argument("--output-dir", type=Path, default=ROOT_DIR / "TEMP/api_v3_bench") + return parser.parse_args() + + +def load_requests(args: argparse.Namespace) -> List[Dict[str, Any]]: + wav_paths_all = sorted(args.wav_dir.glob("*.wav")) + wav_paths: List[Path] = [] + for wav_path in wav_paths_all: + with wave.open(str(wav_path), "rb") as handle: + duration = handle.getnframes() / float(handle.getframerate()) + if 3.0 <= duration <= 10.0: + wav_paths.append(wav_path) + if not wav_paths: + raise FileNotFoundError(f"没有找到 3-10 秒合法 wav: {args.wav_dir}") + text_lines = [line.strip() for line in args.text_file.read_text(encoding="utf-8").splitlines() if line.strip()] + if not text_lines: + raise ValueError(f"没有找到有效文本行: {args.text_file}") + + requests: List[Dict[str, Any]] = [] + for index in range(args.concurrency): + wav_path = wav_paths[index % len(wav_paths)] + lab_path = wav_path.with_suffix(".lab") + if not lab_path.exists(): + raise FileNotFoundError(f"缺少参考文本: {lab_path}") + requests.append( + { + "request_id": f"bench_{args.concurrency:03d}_{index:03d}", + "text": text_lines[index % len(text_lines)], + "text_lang": args.text_lang, + "ref_audio_path": str(wav_path), + "prompt_lang": args.prompt_lang, + "prompt_text": lab_path.read_text(encoding="utf-8").strip(), + "top_k": int(args.top_k), + "top_p": float(args.top_p), + "temperature": float(args.temperature), + "repetition_penalty": float(args.repetition_penalty), + "sample_steps": int(args.sample_steps), + "media_type": args.media_type, + "timeout_sec": float(args.timeout_sec), + } + ) + return requests + + +class GpuMemoryPoller: + def __init__(self, server_pid: Optional[int], interval_sec: float): + self.server_pid = server_pid + self.interval_sec = interval_sec + self._stop = threading.Event() + self.samples: List[Dict[str, Any]] = [] + self.thread: Optional[threading.Thread] = None + + def _query_memory_mb(self) -> Optional[int]: + try: + result = subprocess.run( + [ + "nvidia-smi", + "--query-compute-apps=pid,used_gpu_memory", + "--format=csv,noheader,nounits", + ], + check=True, + capture_output=True, + text=True, + ) + except Exception: + return None + total = 0 + found = False + for line in result.stdout.splitlines(): + line = line.strip() + if not line: + continue + parts = [item.strip() for item in line.split(",")] + if len(parts) != 2: + continue + try: + pid = int(parts[0]) + used_mb = int(parts[1]) + except ValueError: + continue + if self.server_pid is None or pid == self.server_pid: + total += used_mb + found = True + if self.server_pid is None: + return total + return total if found else 0 + + def _run(self) -> None: + while not self._stop.is_set(): + used_mb = self._query_memory_mb() + self.samples.append({"ts": time.time(), "used_mb": used_mb}) + self._stop.wait(self.interval_sec) + + def start(self) -> None: + self.thread = threading.Thread(target=self._run, daemon=True) + self.thread.start() + + def stop(self) -> None: + self._stop.set() + if self.thread is not None: + self.thread.join(timeout=2.0) + + def summary(self) -> Dict[str, Any]: + valid = [item for item in self.samples if item["used_mb"] is not None] + peak = max(valid, key=lambda item: item["used_mb"]) if valid else None + first = valid[0] if valid else None + last = valid[-1] if valid else None + return { + "server_pid": self.server_pid, + "sample_count": int(len(self.samples)), + "start_used_mb": None if first is None else int(first["used_mb"]), + "peak_used_mb": None if peak is None else int(peak["used_mb"]), + "peak_delta_mb": None if peak is None or first is None else int(peak["used_mb"] - first["used_mb"]), + "end_used_mb": None if last is None else int(last["used_mb"]), + "peak_ts": None if peak is None else float(peak["ts"]), + "samples": self.samples, + } + + +async def submit_one(client: httpx.AsyncClient, url: str, payload: Dict[str, Any]) -> Dict[str, Any]: + started = time.perf_counter() + try: + response = await client.post(url, json=payload) + elapsed_ms = (time.perf_counter() - started) * 1000.0 + item = { + "request_id": payload["request_id"], + "status_code": int(response.status_code), + "elapsed_ms": float(elapsed_ms), + "content_type": response.headers.get("content-type"), + "audio_bytes": int(len(response.content)), + "headers": {key: value for key, value in response.headers.items() if key.lower().startswith("x-")}, + } + if response.status_code != 200: + try: + item["error_body"] = response.json() + except Exception: + item["error_body"] = response.text + return item + except Exception as exc: + return { + "request_id": payload["request_id"], + "status_code": -1, + "elapsed_ms": float((time.perf_counter() - started) * 1000.0), + "exception": repr(exc), + } + + +async def run_benchmark(args: argparse.Namespace) -> Dict[str, Any]: + payloads = load_requests(args) + url = args.base_url.rstrip("/") + args.endpoint + poller = GpuMemoryPoller(server_pid=args.server_pid, interval_sec=args.poll_interval_sec) + + limits = httpx.Limits(max_connections=args.concurrency, max_keepalive_connections=args.concurrency) + timeout = httpx.Timeout(connect=10.0, read=args.timeout_sec + 10.0, write=10.0, pool=10.0) + + started = time.perf_counter() + poller.start() + try: + async with httpx.AsyncClient(limits=limits, timeout=timeout) as client: + results = await asyncio.gather(*[submit_one(client, url, payload) for payload in payloads]) + finally: + poller.stop() + wall_ms = (time.perf_counter() - started) * 1000.0 + + ok_results = [item for item in results if item["status_code"] == 200] + failed_results = [item for item in results if item["status_code"] != 200] + request_total_ms = [] + worker_total_ms = [] + for item in ok_results: + headers = item.get("headers", {}) + if "x-request-total-ms" in headers: + request_total_ms.append(float(headers["x-request-total-ms"])) + if "x-worker-total-ms" in headers: + worker_total_ms.append(float(headers["x-worker-total-ms"])) + + return { + "concurrency": int(args.concurrency), + "server_pid": args.server_pid, + "request_count": int(len(payloads)), + "wall_ms": float(wall_ms), + "success_count": int(len(ok_results)), + "failure_count": int(len(failed_results)), + "request_total_ms_avg": float(sum(request_total_ms) / len(request_total_ms)) if request_total_ms else None, + "request_total_ms_max": float(max(request_total_ms)) if request_total_ms else None, + "worker_total_ms_avg": float(sum(worker_total_ms) / len(worker_total_ms)) if worker_total_ms else None, + "worker_total_ms_max": float(max(worker_total_ms)) if worker_total_ms else None, + "gpu_memory": poller.summary(), + "results": results, + } + + +def main() -> None: + args = parse_args() + output_dir = args.output_dir / f"concurrency_{args.concurrency:02d}" + output_dir.mkdir(parents=True, exist_ok=True) + summary = asyncio.run(run_benchmark(args)) + summary_path = output_dir / "summary.json" + summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8") + print(json.dumps({ + "concurrency": summary["concurrency"], + "success_count": summary["success_count"], + "failure_count": summary["failure_count"], + "wall_ms": summary["wall_ms"], + "gpu_peak_used_mb": summary["gpu_memory"]["peak_used_mb"], + "request_total_ms_avg": summary["request_total_ms_avg"], + "request_total_ms_max": summary["request_total_ms_max"], + "summary_path": str(summary_path), + }, ensure_ascii=False, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/tools/t2s_memory_breakdown.py b/tools/t2s_memory_breakdown.py new file mode 100644 index 00000000..18127953 --- /dev/null +++ b/tools/t2s_memory_breakdown.py @@ -0,0 +1,887 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import argparse +import gc +import contextlib +import json +import random +import sys +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import numpy as np +import torch + +ROOT_DIR = Path(__file__).resolve().parents[1] +if str(ROOT_DIR) not in sys.path: + sys.path.append(str(ROOT_DIR)) +gpt_sovits_dir = ROOT_DIR / "GPT_SoVITS" +if str(gpt_sovits_dir) not in sys.path: + sys.path.append(str(gpt_sovits_dir)) + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config # noqa: E402 +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( # noqa: E402 + SchedulerRequestSpec, + T2SRequestState, + T2SRunningRequest, + _build_decode_batch_from_running, + build_prefill_batch, + prepare_request_state, + run_decode_step_for_running, + run_prefill_step, +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Break down T2S CUDA memory by stage and tensor groups.") + parser.add_argument("--config", type=Path, default=ROOT_DIR / "GPT_SoVITS/configs/tts_infer.yaml") + parser.add_argument("--request-manifest", type=Path, default=None) + parser.add_argument("--scenario", type=str, default="auto4", choices=["auto4", "single"]) + parser.add_argument("--auto-count", type=int, default=4) + parser.add_argument("--auto-wav-dir", type=Path, default=ROOT_DIR / "testwav") + parser.add_argument("--auto-text-file", type=Path, default=ROOT_DIR / "test_cn.txt") + parser.add_argument("--ref-audio", type=Path, default=ROOT_DIR / "test.wav") + parser.add_argument("--prompt-text", type=str, default="是啊,主要是因为有调研需求的学者少了。") + parser.add_argument("--prompt-lang", type=str, default="zh") + parser.add_argument("--text", type=str, default=None) + parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_en.txt") + parser.add_argument("--text-lang", type=str, default="zh") + parser.add_argument("--top-k", type=int, default=15) + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--repetition-penalty", type=float, default=1.35) + parser.add_argument("--early-stop-num", type=int, default=-1) + parser.add_argument("--max-steps", type=int, default=1500) + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--warmup", action="store_true", default=False) + parser.add_argument("--worker-rounds", type=int, default=1) + parser.add_argument("--worker-grad-mode", type=str, default="default", choices=["default", "inference_mode"]) + parser.add_argument("--compare-worker-grad-modes", action="store_true", default=False) + parser.add_argument( + "--output-dir", + type=Path, + default=ROOT_DIR / "TEMP/t2s_memory_breakdown/run1", + ) + return parser.parse_args() + + +def set_seed(seed: int, use_cuda: bool) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if use_cuda and torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def _sync_device(device: Any) -> None: + try: + device_str = str(device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + +def bytes_to_mb(num_bytes: int) -> float: + return float(num_bytes) / (1024.0 * 1024.0) + + +def tensor_nbytes(tensor: Optional[torch.Tensor]) -> int: + if tensor is None: + return 0 + return int(tensor.numel() * tensor.element_size()) + + +def tensor_list_nbytes(items: Sequence[torch.Tensor]) -> int: + return int(sum(tensor_nbytes(item) for item in items)) + + +def model_nbytes(module: torch.nn.Module) -> int: + total = 0 + for parameter in module.parameters(): + total += tensor_nbytes(parameter) + for buffer in module.buffers(): + total += tensor_nbytes(buffer) + return int(total) + + +def build_module_weight_summary(tts: TTS) -> Dict[str, Any]: + modules = { + "t2s_model": tts.t2s_model, + "t2s_core": tts.t2s_model.model if tts.t2s_model is not None else None, + "vits_model": tts.vits_model, + "bert_model": tts.bert_model, + "cnhuhbert_model": tts.cnhuhbert_model, + "vocoder": tts.vocoder, + "sv_model": tts.sv_model, + } + by_module = {} + total_bytes = 0 + for name, module in modules.items(): + module_bytes = model_nbytes(module) if module is not None else 0 + by_module[name] = { + "bytes": int(module_bytes), + "mb": bytes_to_mb(module_bytes), + } + total_bytes += module_bytes + return { + "by_module": by_module, + "total_bytes": int(total_bytes), + "total_mb": bytes_to_mb(total_bytes), + } + + +def snapshot_live_cuda_tensors(top_k: int = 40) -> Dict[str, Any]: + storages: Dict[int, Dict[str, Any]] = {} + tensor_views: List[Dict[str, Any]] = [] + for obj in gc.get_objects(): + try: + tensor = None + if torch.is_tensor(obj): + tensor = obj + elif hasattr(obj, "data") and torch.is_tensor(obj.data): + tensor = obj.data + if tensor is None or not tensor.is_cuda: + continue + storage = tensor.untyped_storage() + storage_ptr = int(storage.data_ptr()) + if storage_ptr not in storages: + storages[storage_ptr] = { + "storage_ptr": storage_ptr, + "storage_bytes": int(storage.nbytes()), + "dtype": str(tensor.dtype), + "shape": list(tensor.shape), + "device": str(tensor.device), + } + tensor_views.append( + { + "shape": list(tensor.shape), + "dtype": str(tensor.dtype), + "bytes": tensor_nbytes(tensor), + "device": str(tensor.device), + } + ) + except Exception: + continue + storage_list = sorted(storages.values(), key=lambda item: item["storage_bytes"], reverse=True) + tensor_views.sort(key=lambda item: item["bytes"], reverse=True) + return { + "unique_storage_count": int(len(storage_list)), + "unique_storage_total_bytes": int(sum(item["storage_bytes"] for item in storage_list)), + "unique_storage_total_mb": bytes_to_mb(sum(item["storage_bytes"] for item in storage_list)), + "top_storages": storage_list[:top_k], + "top_tensor_views": tensor_views[:top_k], + } + + +def build_single_spec(args: argparse.Namespace) -> List[SchedulerRequestSpec]: + text = args.text if args.text is not None else args.text_file.read_text(encoding="utf-8").strip() + return [ + SchedulerRequestSpec( + request_id="req_000", + ref_audio_path=args.ref_audio, + prompt_text=args.prompt_text, + prompt_lang=args.prompt_lang, + text=text, + text_lang=args.text_lang, + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + repetition_penalty=args.repetition_penalty, + early_stop_num=args.early_stop_num, + ready_step=0, + ) + ] + + +def build_auto_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]: + wav_paths = sorted(args.auto_wav_dir.glob("*.wav"))[: args.auto_count] + if len(wav_paths) < args.auto_count: + raise ValueError(f"auto wav count不足,目录 {args.auto_wav_dir} 只有 {len(wav_paths)} 条 wav") + text_lines = [line.strip() for line in args.auto_text_file.read_text(encoding="utf-8").splitlines() if line.strip()] + if len(text_lines) < args.auto_count: + raise ValueError(f"auto text lines不足,文件 {args.auto_text_file} 只有 {len(text_lines)} 行有效文本") + specs: List[SchedulerRequestSpec] = [] + for index, wav_path in enumerate(wav_paths): + lab_path = wav_path.with_suffix(".lab") + if not lab_path.exists(): + raise FileNotFoundError(f"找不到参考文本 {lab_path}") + specs.append( + SchedulerRequestSpec( + request_id=f"req_{index:03d}", + ref_audio_path=wav_path, + prompt_text=lab_path.read_text(encoding="utf-8").strip(), + prompt_lang="zh", + text=text_lines[index], + text_lang=args.text_lang, + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + repetition_penalty=args.repetition_penalty, + early_stop_num=args.early_stop_num, + ready_step=0, + ) + ) + return specs + + +def load_request_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]: + if args.request_manifest is not None: + payload = json.loads(args.request_manifest.read_text(encoding="utf-8")) + raw_requests = payload["requests"] if isinstance(payload, dict) else payload + specs: List[SchedulerRequestSpec] = [] + for index, item in enumerate(raw_requests): + text = item.get("text") + text_file = item.get("text_file") + if text is None and text_file is None: + raise ValueError(f"request[{index}] must provide text or text_file") + if text is None: + text = Path(text_file).read_text(encoding="utf-8").strip() + specs.append( + SchedulerRequestSpec( + request_id=item.get("request_id", f"req_{index:03d}"), + ref_audio_path=Path(item["ref_audio_path"]), + prompt_text=item["prompt_text"], + prompt_lang=item.get("prompt_lang", "zh"), + text=text, + text_lang=item.get("text_lang", "zh"), + top_k=int(item.get("top_k", args.top_k)), + top_p=float(item.get("top_p", args.top_p)), + temperature=float(item.get("temperature", args.temperature)), + repetition_penalty=float(item.get("repetition_penalty", args.repetition_penalty)), + early_stop_num=int(item.get("early_stop_num", args.early_stop_num)), + ready_step=int(item.get("ready_step", 0)), + ) + ) + return specs + if args.scenario == "single": + return build_single_spec(args) + return build_auto_specs(args) + + +def load_pipeline(config_path: Path) -> TTS: + tts_config = TTS_Config(str(config_path)) + print(tts_config) + return TTS(tts_config) + + +def cuda_mem_snapshot(device: Any) -> Dict[str, float]: + if not (str(device).startswith("cuda") and torch.cuda.is_available()): + return { + "allocated_mb": 0.0, + "reserved_mb": 0.0, + "max_allocated_mb": 0.0, + "max_reserved_mb": 0.0, + } + _sync_device(device) + return { + "allocated_mb": bytes_to_mb(torch.cuda.memory_allocated(device)), + "reserved_mb": bytes_to_mb(torch.cuda.memory_reserved(device)), + "max_allocated_mb": bytes_to_mb(torch.cuda.max_memory_allocated(device)), + "max_reserved_mb": bytes_to_mb(torch.cuda.max_memory_reserved(device)), + } + + +def stage_run(device: Any, fn) -> Tuple[Any, Dict[str, float]]: + if str(device).startswith("cuda") and torch.cuda.is_available(): + gc.collect() + _sync_device(device) + torch.cuda.reset_peak_memory_stats(device) + before = cuda_mem_snapshot(device) + started = time.perf_counter() + result = fn() + _sync_device(device) + elapsed_ms = (time.perf_counter() - started) * 1000.0 + after = cuda_mem_snapshot(device) + after["elapsed_ms"] = float(elapsed_ms) + after["delta_allocated_mb"] = float(after["allocated_mb"] - before["allocated_mb"]) + after["delta_reserved_mb"] = float(after["reserved_mb"] - before["reserved_mb"]) + after["stage_peak_over_before_mb"] = float(max(after["max_allocated_mb"] - before["allocated_mb"], 0.0)) + return result, after + + +class GlobalPeakRecorder: + def __init__(self, device: Any): + self.device = device + self.checkpoints: List[Dict[str, Any]] = [] + if str(device).startswith("cuda") and torch.cuda.is_available(): + gc.collect() + _sync_device(device) + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats(device) + + def record(self, label: str, **extra: Any) -> None: + snapshot = cuda_mem_snapshot(self.device) + snapshot["label"] = label + snapshot.update(extra) + self.checkpoints.append(snapshot) + + def summary(self) -> Dict[str, Any]: + peak = max(self.checkpoints, key=lambda item: item["max_allocated_mb"]) if self.checkpoints else None + return { + "peak_allocated_mb": 0.0 if peak is None else float(peak["max_allocated_mb"]), + "peak_reserved_mb": 0.0 if peak is None else float(peak["max_reserved_mb"]), + "peak_label": None if peak is None else peak["label"], + "checkpoints": self.checkpoints, + } + + +def summarise_state_tensors(states: Sequence[T2SRequestState]) -> Dict[str, Any]: + per_request = [] + total = { + "phones_bytes": 0, + "prompt_phones_bytes": 0, + "all_phones_bytes": 0, + "all_bert_features_bytes": 0, + "prompt_semantic_bytes": 0, + "refer_spec_bytes": 0, + "raw_audio_bytes": 0, + "audio_16k_bytes": 0, + } + for state in states: + spec_audio, audio_16k = state.refer_spec + item = { + "request_id": state.request_id, + "prompt_semantic_len": int(state.prompt_semantic.shape[0]), + "phones_len": int(state.phones.shape[0]), + "all_phones_len": int(state.all_phones.shape[0]), + "bert_frames": int(state.all_bert_features.shape[-1]), + "phones_bytes": tensor_nbytes(state.phones), + "prompt_phones_bytes": tensor_nbytes(state.prompt_phones), + "all_phones_bytes": tensor_nbytes(state.all_phones), + "all_bert_features_bytes": tensor_nbytes(state.all_bert_features), + "prompt_semantic_bytes": tensor_nbytes(state.prompt_semantic), + "refer_spec_bytes": tensor_nbytes(spec_audio), + "audio_16k_bytes": tensor_nbytes(audio_16k), + "raw_audio_bytes": tensor_nbytes(state.raw_audio), + } + for key in total: + total[key] += int(item[key]) + per_request.append(item) + total["total_bytes"] = int(sum(total.values())) + total["total_mb"] = bytes_to_mb(total["total_bytes"]) + return {"per_request": per_request, "total": total} + + +def summarise_prefill_batch(active_batch: Any) -> Dict[str, Any]: + y_sequence_bytes = int(sum(tensor_nbytes(item) for item in active_batch.y_sequences)) + fields = { + "x_bytes": tensor_nbytes(active_batch.x), + "x_lens_bytes": tensor_nbytes(active_batch.x_lens), + "prefix_lens_bytes": tensor_nbytes(active_batch.prefix_lens), + "xy_pos_bytes": tensor_nbytes(active_batch.xy_pos), + "key_padding_mask_bytes": tensor_nbytes(active_batch.key_padding_mask), + "prefill_attn_mask_bytes": tensor_nbytes(active_batch.prefill_attn_mask), + "y_sequence_bytes": y_sequence_bytes, + } + fields["total_bytes"] = int(sum(fields.values())) + fields["total_mb"] = bytes_to_mb(fields["total_bytes"]) + fields["batch_size"] = int(len(active_batch.states)) + fields["max_x_len"] = int(active_batch.x.shape[1]) + fields["src_len"] = int(active_batch.xy_pos.shape[1]) + fields["prefill_attn_mask_shape"] = list(active_batch.prefill_attn_mask.shape) + return fields + + +def summarise_running_requests(running_requests: Sequence[T2SRunningRequest]) -> Dict[str, Any]: + per_request = [] + total_private_k_bytes = 0 + total_private_v_bytes = 0 + total_decode_mask_bytes = 0 + total_y_sequence_bytes = 0 + for item in running_requests: + k_bytes = tensor_list_nbytes(item.k_cache) + v_bytes = tensor_list_nbytes(item.v_cache) + mask_bytes = tensor_nbytes(item.decode_attn_mask) + y_bytes = tensor_nbytes(item.y_sequence) + total_private_k_bytes += k_bytes + total_private_v_bytes += v_bytes + total_decode_mask_bytes += mask_bytes + total_y_sequence_bytes += y_bytes + per_request.append( + { + "request_id": item.state.request_id, + "step_idx": int(item.step_idx), + "prefix_len": int(item.prefix_len), + "history_len": int(item.y_sequence.shape[0]), + "kv_len": int(item.k_cache[0].shape[1]), + "k_cache_bytes": k_bytes, + "v_cache_bytes": v_bytes, + "decode_mask_bytes": mask_bytes, + "y_sequence_bytes": y_bytes, + } + ) + total_bytes = total_private_k_bytes + total_private_v_bytes + total_decode_mask_bytes + total_y_sequence_bytes + return { + "per_request": per_request, + "totals": { + "private_k_cache_bytes": int(total_private_k_bytes), + "private_v_cache_bytes": int(total_private_v_bytes), + "private_kv_cache_bytes": int(total_private_k_bytes + total_private_v_bytes), + "decode_mask_bytes": int(total_decode_mask_bytes), + "y_sequence_bytes": int(total_y_sequence_bytes), + "total_bytes": int(total_bytes), + "total_mb": bytes_to_mb(total_bytes), + }, + } + + +def summarise_decode_batch( + xy_pos: torch.Tensor, + batched_k_cache: Sequence[torch.Tensor], + batched_v_cache: Sequence[torch.Tensor], + batched_decode_attn_mask: Optional[torch.Tensor], + running_requests: Sequence[T2SRunningRequest], +) -> Dict[str, Any]: + private_k_bytes = int(sum(tensor_list_nbytes(item.k_cache) for item in running_requests)) + private_v_bytes = int(sum(tensor_list_nbytes(item.v_cache) for item in running_requests)) + batched_k_bytes = tensor_list_nbytes(batched_k_cache) + batched_v_bytes = tensor_list_nbytes(batched_v_cache) + batched_mask_bytes = tensor_nbytes(batched_decode_attn_mask) + xy_pos_bytes = tensor_nbytes(xy_pos) + total_bytes = batched_k_bytes + batched_v_bytes + batched_mask_bytes + xy_pos_bytes + return { + "batch_size": int(len(running_requests)), + "xy_pos_bytes": int(xy_pos_bytes), + "batched_k_cache_bytes": int(batched_k_bytes), + "batched_v_cache_bytes": int(batched_v_bytes), + "batched_kv_cache_bytes": int(batched_k_bytes + batched_v_bytes), + "batched_decode_mask_bytes": int(batched_mask_bytes), + "private_kv_cache_bytes_reference": int(private_k_bytes + private_v_bytes), + "kv_padding_overhead_bytes": int((batched_k_bytes + batched_v_bytes) - (private_k_bytes + private_v_bytes)), + "total_bytes": int(total_bytes), + "total_mb": bytes_to_mb(total_bytes), + "xy_pos_shape": list(xy_pos.shape), + "batched_decode_mask_shape": None if batched_decode_attn_mask is None else list(batched_decode_attn_mask.shape), + "layer_k_cache_shape": list(batched_k_cache[0].shape), + } + + +def summarise_decode_outputs( + xy_dec: torch.Tensor, + next_k_cache: Sequence[torch.Tensor], + next_v_cache: Sequence[torch.Tensor], +) -> Dict[str, Any]: + xy_dec_bytes = tensor_nbytes(xy_dec) + next_k_bytes = tensor_list_nbytes(next_k_cache) + next_v_bytes = tensor_list_nbytes(next_v_cache) + total_bytes = xy_dec_bytes + next_k_bytes + next_v_bytes + return { + "xy_dec_bytes": int(xy_dec_bytes), + "next_k_cache_bytes": int(next_k_bytes), + "next_v_cache_bytes": int(next_v_bytes), + "next_kv_cache_bytes": int(next_k_bytes + next_v_bytes), + "total_bytes": int(total_bytes), + "total_mb": bytes_to_mb(total_bytes), + "xy_dec_shape": list(xy_dec.shape), + "layer_next_k_cache_shape": list(next_k_cache[0].shape), + } + + +def top_rankings(summary: Dict[str, Any]) -> List[Dict[str, Any]]: + ranking = [ + ("request_state_total", summary["prepare_stage"]["request_state"]["total"]["total_bytes"]), + ("prefill_batch_total", summary["prefill_batch"]["tensor_bytes"]["total_bytes"]), + ("running_private_kv", summary["prefill_step"]["running_requests"]["totals"]["private_kv_cache_bytes"]), + ("decode_batched_kv", summary["decode_batch"]["tensor_bytes"]["batched_kv_cache_bytes"]), + ("decode_kv_padding_overhead", summary["decode_batch"]["tensor_bytes"]["kv_padding_overhead_bytes"]), + ("decode_outputs_next_kv", summary["decode_outputs"]["tensor_bytes"]["next_kv_cache_bytes"]), + ("prefill_attn_mask", summary["prefill_batch"]["tensor_bytes"]["prefill_attn_mask_bytes"]), + ] + ranking.sort(key=lambda item: item[1], reverse=True) + return [{"name": name, "bytes": int(value), "mb": bytes_to_mb(int(value))} for name, value in ranking] + + +def synthesize_finished_item(tts: TTS, state: T2SRequestState, semantic_tokens: torch.Tensor) -> Tuple[int, np.ndarray]: + semantic_tokens = semantic_tokens.unsqueeze(0).unsqueeze(0).to(tts.configs.device) + phones = state.phones.unsqueeze(0).to(tts.configs.device) + audio_fragment = tts.synthesize_audio_request_local( + semantic_tokens=semantic_tokens, + phones=phones, + prompt_semantic=state.prompt_semantic, + prompt_phones=state.prompt_phones, + refer_spec=state.refer_spec, + raw_audio=state.raw_audio, + raw_sr=state.raw_sr, + speed=1.0, + sample_steps=32, + ) + output_sr = tts.configs.sampling_rate if not tts.configs.use_vocoder else tts.vocoder_configs["sr"] + return tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=1.0, + split_bucket=False, + fragment_interval=0.0, + super_sampling=False, + ) + + +def simulate_worker_end_to_end( + tts: TTS, + specs: Sequence[SchedulerRequestSpec], + max_steps: int, + rounds: int, + grad_mode: str = "default", +) -> Dict[str, Any]: + device = tts.configs.device + recorder = GlobalPeakRecorder(device) + recorder.record("after_model_load") + + state_map: Dict[str, T2SRequestState] = {} + per_round: List[Dict[str, Any]] = [] + + for round_index in range(rounds): + grad_context = torch.inference_mode if grad_mode == "inference_mode" else contextlib.nullcontext + with grad_context(): + states = [prepare_request_state(tts, spec) for spec in specs] + state_map = {state.request_id: state for state in states} + recorder.record( + "after_prepare_states", + round_index=int(round_index), + request_count=int(len(states)), + grad_mode=grad_mode, + ) + + pending = list(states) + running_requests: List[T2SRunningRequest] = [] + round_events: List[Dict[str, Any]] = [] + current_tick = 0 + + while pending or running_requests: + admitted = pending + pending = [] + + if admitted: + recorder.record( + "before_prefill", + round_index=int(round_index), + tick=int(current_tick), + admitted_count=int(len(admitted)), + running_count=int(len(running_requests)), + grad_mode=grad_mode, + ) + with grad_context(): + admitted_running, admitted_finished = run_prefill_step(tts.t2s_model.model, admitted, max_steps=max_steps) + recorder.record( + "after_prefill", + round_index=int(round_index), + tick=int(current_tick), + admitted_running_count=int(len(admitted_running)), + admitted_finished_count=int(len(admitted_finished)), + running_count=int(len(running_requests)), + grad_mode=grad_mode, + ) + round_events.append( + { + "tick": int(current_tick), + "event": "prefill", + "admitted_count": int(len(admitted)), + "admitted_running_count": int(len(admitted_running)), + "admitted_finished_count": int(len(admitted_finished)), + } + ) + for item in admitted_finished: + recorder.record( + "before_synth_prefill_finished", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + finished_request_id=item.request_id, + semantic_len=int(item.semantic_tokens.shape[0]), + grad_mode=grad_mode, + ) + with grad_context(): + sample_rate, audio_data = synthesize_finished_item(tts, state_map[item.request_id], item.semantic_tokens) + recorder.record( + "after_synth_prefill_finished", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + finished_request_id=item.request_id, + sample_rate=int(sample_rate), + audio_samples=int(audio_data.shape[0]), + grad_mode=grad_mode, + ) + running_requests.extend(admitted_running) + recorder.record( + "after_extend_running", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + grad_mode=grad_mode, + ) + + if running_requests: + recorder.record( + "before_decode", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + grad_mode=grad_mode, + ) + with grad_context(): + running_requests, step_finished = run_decode_step_for_running( + tts.t2s_model.model, + running_requests, + max_steps=max_steps, + ) + recorder.record( + "after_decode", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + finished_count=int(len(step_finished)), + grad_mode=grad_mode, + ) + round_events.append( + { + "tick": int(current_tick), + "event": "decode", + "running_count_after_decode": int(len(running_requests)), + "finished_count": int(len(step_finished)), + } + ) + for item in step_finished: + recorder.record( + "before_synth_decode_finished", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + finished_request_id=item.request_id, + semantic_len=int(item.semantic_tokens.shape[0]), + grad_mode=grad_mode, + ) + with grad_context(): + sample_rate, audio_data = synthesize_finished_item(tts, state_map[item.request_id], item.semantic_tokens) + recorder.record( + "after_synth_decode_finished", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + finished_request_id=item.request_id, + sample_rate=int(sample_rate), + audio_samples=int(audio_data.shape[0]), + grad_mode=grad_mode, + ) + current_tick += 1 + + recorder.record( + "after_round_complete", + round_index=int(round_index), + running_count=0, + grad_mode=grad_mode, + ) + per_round.append( + { + "round_index": int(round_index), + "events": round_events, + } + ) + + return { + "grad_mode": grad_mode, + "rounds": per_round, + "timeline": recorder.summary(), + } + + +def main() -> None: + args = parse_args() + args.output_dir.mkdir(parents=True, exist_ok=True) + + tts = load_pipeline(args.config) + model = tts.t2s_model.model + device = tts.configs.device + use_cuda = str(device).startswith("cuda") and torch.cuda.is_available() + set_seed(args.seed, use_cuda) + + specs = load_request_specs(args) + if args.early_stop_num == -1: + for spec in specs: + spec.early_stop_num = int(tts.configs.hz * tts.configs.max_sec) + + if args.warmup and specs: + warmup_spec = specs[:1] + _ = [prepare_request_state(tts, spec) for spec in warmup_spec] + gc.collect() + if use_cuda: + torch.cuda.empty_cache() + _sync_device(device) + + states, prepare_mem = stage_run(device, lambda: [prepare_request_state(tts, spec) for spec in specs]) + request_state_summary = summarise_state_tensors(states) + + active_batch, prefill_batch_mem = stage_run(device, lambda: build_prefill_batch(model, states)) + prefill_batch_tensor_summary = summarise_prefill_batch(active_batch) + + prefill_result, prefill_step_mem = stage_run(device, lambda: run_prefill_step(model, states, max_steps=args.max_steps)) + running_requests, finished_items = prefill_result + running_requests_summary = summarise_running_requests(running_requests) + finished_after_prefill_summary = [ + { + "request_id": item.request_id, + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + } + for item in finished_items + ] + + if not running_requests: + raise RuntimeError(f"prefill 后没有 running requests,全部在首步结束: {[item.request_id for item in finished_items]}") + + decode_batch_result, decode_batch_mem = stage_run( + device, + lambda: _build_decode_batch_from_running(model, running_requests), + ) + xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask = decode_batch_result + decode_batch_tensor_summary = summarise_decode_batch( + xy_pos, + batched_k_cache, + batched_v_cache, + batched_decode_attn_mask, + running_requests, + ) + + decode_out_result, decode_step_mem = stage_run( + device, + lambda: model.t2s_transformer.decode_next_token( + xy_pos, + batched_k_cache, + batched_v_cache, + batched_decode_attn_mask, + ), + ) + xy_dec, next_k_cache, next_v_cache = decode_out_result + decode_output_tensor_summary = summarise_decode_outputs(xy_dec, next_k_cache, next_v_cache) + del active_batch + del running_requests + del finished_items + del xy_pos + del batched_k_cache + del batched_v_cache + del batched_decode_attn_mask + del xy_dec + del next_k_cache + del next_v_cache + gc.collect() + if use_cuda: + _sync_device(device) + torch.cuda.empty_cache() + end_to_end_worker = simulate_worker_end_to_end( + tts=tts, + specs=specs, + max_steps=args.max_steps, + rounds=args.worker_rounds, + grad_mode=args.worker_grad_mode, + ) + live_cuda_tensors_after_worker = snapshot_live_cuda_tensors() + worker_inference_mode = None + if args.compare_worker_grad_modes: + gc.collect() + if use_cuda: + _sync_device(device) + torch.cuda.empty_cache() + worker_inference_mode = simulate_worker_end_to_end( + tts=tts, + specs=specs, + max_steps=args.max_steps, + rounds=args.worker_rounds, + grad_mode="inference_mode", + ) + + summary = { + "meta": { + "scenario": args.scenario if args.request_manifest is None else "manifest", + "seed": int(args.seed), + "device": str(device), + "dtype": str(next(model.parameters()).dtype), + "request_count": int(len(specs)), + "num_layers": int(model.num_layers), + "num_heads": int(model.num_head), + "model_dim": int(model.model_dim), + "model_weights_mb": bytes_to_mb(model_nbytes(model)), + }, + "loaded_module_weights": build_module_weight_summary(tts), + "requests": [ + { + "request_id": spec.request_id, + "ref_audio_path": str(spec.ref_audio_path), + "prompt_text": spec.prompt_text, + "text": spec.text, + } + for spec in specs + ], + "prepare_stage": { + "memory": prepare_mem, + "request_state": request_state_summary, + }, + "prefill_batch": { + "memory": prefill_batch_mem, + "tensor_bytes": prefill_batch_tensor_summary, + }, + "prefill_step": { + "memory": prefill_step_mem, + "running_requests": running_requests_summary, + "finished_after_prefill": finished_after_prefill_summary, + }, + "decode_batch": { + "memory": decode_batch_mem, + "tensor_bytes": decode_batch_tensor_summary, + }, + "decode_outputs": { + "memory": decode_step_mem, + "tensor_bytes": decode_output_tensor_summary, + }, + "end_to_end_worker": end_to_end_worker, + "live_cuda_tensors_after_worker": live_cuda_tensors_after_worker, + "end_to_end_worker_inference_mode": worker_inference_mode, + } + summary["top_rankings"] = top_rankings(summary) + + summary_path = args.output_dir / "t2s_memory_breakdown_summary.json" + summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8") + + print(json.dumps(summary["meta"], ensure_ascii=False, indent=2)) + print("[top_rankings]") + for item in summary["top_rankings"]: + print(f"- {item['name']}: {item['mb']:.3f} MB") + print("[worker_peak]") + print( + json.dumps( + { + "peak_label": summary["end_to_end_worker"]["timeline"]["peak_label"], + "peak_allocated_mb": summary["end_to_end_worker"]["timeline"]["peak_allocated_mb"], + "peak_reserved_mb": summary["end_to_end_worker"]["timeline"]["peak_reserved_mb"], + }, + ensure_ascii=False, + indent=2, + ) + ) + if worker_inference_mode is not None: + print("[worker_peak_inference_mode]") + print( + json.dumps( + { + "peak_label": worker_inference_mode["timeline"]["peak_label"], + "peak_allocated_mb": worker_inference_mode["timeline"]["peak_allocated_mb"], + "peak_reserved_mb": worker_inference_mode["timeline"]["peak_reserved_mb"], + }, + ensure_ascii=False, + indent=2, + ) + ) + print(f"[summary] {summary_path}") + + +if __name__ == "__main__": + main() From 845b181360b5c5f4a5ed827ca01cfd5dfd8f58b1 Mon Sep 17 00:00:00 2001 From: baicai-1145 <3423714059@qq.com> Date: Mon, 9 Mar 2026 05:19:28 +0800 Subject: [PATCH 6/6] Implement batch processing for BERT and reference semantic tasks in TTS. Introduce StageLimiter for managing concurrent processing and enhance the TTS class with new methods for handling audio and semantic extraction. Update profiling metrics for better performance tracking during inference. --- GPT_SoVITS/TTS_infer_pack/TTS.py | 261 ++++++++++++++--- GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py | 259 +++++++++++------ .../prepare_bert_batch_worker.py | 197 +++++++++++++ .../prepare_ref_semantic_batch_worker.py | 262 ++++++++++++++++++ GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py | 121 ++++++-- api_v3.py | 64 ++++- 6 files changed, 1024 insertions(+), 140 deletions(-) create mode 100644 GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py create mode 100644 GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 2fd0df35..9c259662 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -5,6 +5,7 @@ import random import sys import time import traceback +from concurrent.futures import ThreadPoolExecutor from copy import deepcopy import torchaudio @@ -33,7 +34,12 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer from tools.audio_sr import AP_BWE from tools.i18n.i18n import I18nAuto, scan_language_list from TTS_infer_pack.text_segmentation_method import splits -from TTS_infer_pack.TextPreprocessor import TextPreprocessor +from TTS_infer_pack.TextPreprocessor import TextPreprocessor, StageLimiter +from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker +from TTS_infer_pack.prepare_ref_semantic_batch_worker import ( + PrepareRefSemanticBatchWorker, + prepare_prompt_semantic_wav16k, +) from sv import SV resample_transform_dict = {} @@ -442,11 +448,56 @@ class TTS: "upsample_rate": None, "overlapped_len": None, } + self.prepare_bert_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_BERT_SLOTS", "1"))) + self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "2"))) + self.prepare_bert_batch_worker = None + self.prepare_ref_semantic_batch_worker = None + default_text_cpu_workers = 16 + self.prepare_text_cpu_workers = max( + 0, + int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_WORKERS", str(default_text_cpu_workers))), + ) + self.prepare_text_cpu_executor = None + if self.prepare_text_cpu_workers > 0: + self.prepare_text_cpu_executor = ThreadPoolExecutor( + max_workers=self.prepare_text_cpu_workers, + thread_name_prefix="prepare-text-cpu", + ) self._init_models() + if os.environ.get("GPTSOVITS_PREPARE_BERT_BATCHING", "1") != "0": + self.prepare_bert_batch_worker = PrepareBertBatchWorker( + bert_model=self.bert_model, + tokenizer=self.bert_tokenizer, + device=self.configs.device, + stage_limiter=self.prepare_bert_stage_limiter, + batch_window_ms=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_WINDOW_MS", "5")), + max_batch_items=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_ITEMS", "16")), + max_batch_tokens=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_TOKENS", "4096")), + ) + if os.environ.get("GPTSOVITS_PREPARE_REF_BATCHING", "0") != "0": + ref_max_batch_samples = os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_SAMPLES") + if ref_max_batch_samples is None: + ref_max_batch_samples = os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_FRAMES", "960000") + self.prepare_ref_semantic_batch_worker = PrepareRefSemanticBatchWorker( + ssl_model=self.cnhuhbert_model, + vits_model=self.vits_model, + device=self.configs.device, + is_half=self.configs.is_half, + zero_wav_samples=int(self.configs.sampling_rate * 0.3), + stage_limiter=self.prepare_ref_audio_stage_limiter, + batch_window_ms=int(os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_WINDOW_MS", "5")), + max_batch_items=int(os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_ITEMS", "8")), + max_batch_samples=int(ref_max_batch_samples), + ) + self.text_preprocessor: TextPreprocessor = TextPreprocessor( - self.bert_model, self.bert_tokenizer, self.configs.device + self.bert_model, + self.bert_tokenizer, + self.configs.device, + bert_stage_limiter=self.prepare_bert_stage_limiter, + bert_batch_worker=self.prepare_bert_batch_worker, ) self.prompt_cache: dict = { @@ -755,47 +806,52 @@ class TTS: Args: ref_audio_path: str, the path of the reference audio. """ - self._set_prompt_semantic(ref_audio_path) - self._set_ref_spec(ref_audio_path) + bundle = self.extract_ref_audio_bundle(ref_audio_path) + if self.prompt_cache["refer_spec"] in [[], None]: + self.prompt_cache["refer_spec"] = [bundle["refer_spec"]] + else: + self.prompt_cache["refer_spec"][0] = bundle["refer_spec"] + self.prompt_cache["prompt_semantic"] = bundle["prompt_semantic"] + self.prompt_cache["raw_audio"] = bundle["raw_audio"] + self.prompt_cache["raw_sr"] = bundle["raw_sr"] self._set_ref_audio_path(ref_audio_path) - def extract_prompt_semantic(self, ref_wav_path: str): - zero_wav = np.zeros( - int(self.configs.sampling_rate * 0.3), - dtype=np.float16 if self.configs.is_half else np.float32, - ) - with torch.no_grad(): - wav16k, sr = librosa.load(ref_wav_path, sr=16000) - if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000: - raise OSError(i18n("参考音频在3~10秒范围外,请更换!")) - wav16k = torch.from_numpy(wav16k) - zero_wav_torch = torch.from_numpy(zero_wav) - wav16k = wav16k.to(self.configs.device) - zero_wav_torch = zero_wav_torch.to(self.configs.device) - if self.configs.is_half: - wav16k = wav16k.half() - zero_wav_torch = zero_wav_torch.half() - - wav16k = torch.cat([wav16k, zero_wav_torch]) - hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose( - 1, 2 - ) # .float() - codes = self.vits_model.extract_latent(hubert_feature) - - prompt_semantic = codes[0, 0].to(self.configs.device) - return prompt_semantic - - def extract_ref_spec(self, ref_audio_path: str): + def _load_ref_audio_raw(self, ref_audio_path: str): raw_audio, raw_sr = torchaudio.load(ref_audio_path) - raw_audio = raw_audio.to(self.configs.device).float() + return raw_audio.float(), int(raw_sr) + + @torch.inference_mode() + def _extract_prompt_semantic_from_prepared_wav16k(self, wav16k: torch.Tensor): + wav16k = wav16k.to(self.configs.device) + if self.configs.is_half: + wav16k = wav16k.half() + hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) + codes = self.vits_model.extract_latent(hubert_feature) + return codes[0, 0].to(self.configs.device) + + @torch.inference_mode() + def _extract_prompt_semantic_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + wav16k = prepare_prompt_semantic_wav16k( + raw_audio=raw_audio, + raw_sr=raw_sr, + zero_wav_samples=int(self.configs.sampling_rate * 0.3), + ) + return self._extract_prompt_semantic_from_prepared_wav16k(wav16k) + + def extract_prompt_semantic(self, ref_wav_path: str): + raw_audio, raw_sr = self._load_ref_audio_raw(ref_wav_path) + return self._extract_prompt_semantic_from_raw(raw_audio, raw_sr) + + def _extract_ref_spec_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + raw_audio_device = raw_audio.to(self.configs.device).float() if raw_sr != self.configs.sampling_rate: - audio = raw_audio.to(self.configs.device) + audio = raw_audio_device if audio.shape[0] == 2: audio = audio.mean(0).unsqueeze(0) audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device) else: - audio = raw_audio.to(self.configs.device) + audio = raw_audio_device if audio.shape[0] == 2: audio = audio.mean(0).unsqueeze(0) @@ -820,8 +876,141 @@ class TTS: audio = None return spec, audio, raw_audio, raw_sr - def extract_text_features(self, text: str, language: str): - return self.text_preprocessor.segment_and_extract_feature_for_text(text, language, self.configs.version) + def extract_ref_spec(self, ref_audio_path: str): + raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path) + return self._extract_ref_spec_from_raw(raw_audio, raw_sr) + + def extract_ref_audio_bundle(self, ref_audio_path: str): + load_start = time.perf_counter() + raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path) + load_ms = (time.perf_counter() - load_start) * 1000.0 + if self.prepare_ref_semantic_batch_worker is None: + with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats: + prompt_semantic_start = time.perf_counter() + prompt_semantic = self._extract_prompt_semantic_from_raw(raw_audio, raw_sr) + prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0 + ref_spec_start = time.perf_counter() + refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2] + ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0 + audio_stage_wait_ms = float(limiter_stats["wait_ms"]) + audio_stage_slots = float(limiter_stats["slots"]) + audio_stage_inflight_peak = float(limiter_stats["peak_inflight"]) + prompt_semantic_profile = { + "prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]), + "prompt_semantic_cpu_prepare_ms": 0.0, + "prompt_semantic_forward_ms": prompt_semantic_ms, + "prompt_semantic_scatter_ms": 0.0, + "prompt_semantic_stage_slots": float(limiter_stats["slots"]), + "prompt_semantic_stage_inflight_peak": float(limiter_stats["peak_inflight"]), + "prompt_semantic_batch_size": 1.0, + "prompt_semantic_batch_samples": 0.0, + } + ref_spec_wait_ms = 0.0 + return { + "prompt_semantic": prompt_semantic, + "refer_spec": refer_spec, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + "audio_load_ms": load_ms, + "audio_stage_wait_ms": audio_stage_wait_ms, + "audio_stage_slots": audio_stage_slots, + "audio_stage_inflight_peak": audio_stage_inflight_peak, + "prompt_semantic_ms": prompt_semantic_ms, + "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_ms": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) + ), + "prompt_semantic_forward_ms": float( + prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0) + ), + "prompt_semantic_scatter_ms": float( + prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0) + ), + "prompt_semantic_stage_slots": float( + prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0) + ), + "prompt_semantic_stage_inflight_peak": float( + prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0) + ), + "prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)), + "prompt_semantic_batch_samples": float( + prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0) + ), + "ref_spec_wait_ms": ref_spec_wait_ms, + "ref_spec_ms": ref_spec_ms, + "bundle_total_ms": load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_ms, + }, + } + + prompt_semantic_profile = { + "prompt_semantic_wait_ms": 0.0, + "prompt_semantic_cpu_prepare_ms": 0.0, + "prompt_semantic_forward_ms": 0.0, + "prompt_semantic_scatter_ms": 0.0, + "prompt_semantic_stage_slots": 0.0, + "prompt_semantic_stage_inflight_peak": 0.0, + "prompt_semantic_batch_size": 1.0, + "prompt_semantic_batch_samples": 0.0, + } + if self.prepare_ref_semantic_batch_worker is not None: + prompt_semantic, worker_profile = self.prepare_ref_semantic_batch_worker.submit(raw_audio, raw_sr) + prompt_semantic_profile.update(worker_profile) + prompt_semantic_ms = ( + float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)) + ) + with self.prepare_ref_audio_stage_limiter.enter() as ref_spec_limiter_stats: + ref_spec_start = time.perf_counter() + refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2] + ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0 + audio_stage_wait_ms = float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)) + float( + ref_spec_limiter_stats["wait_ms"] + ) + audio_stage_slots = max( + float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + float(ref_spec_limiter_stats["slots"]), + ) + audio_stage_inflight_peak = max( + float(prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)), + float(ref_spec_limiter_stats["peak_inflight"]), + ) + return { + "prompt_semantic": prompt_semantic, + "refer_spec": refer_spec, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + "audio_load_ms": load_ms, + "audio_stage_wait_ms": audio_stage_wait_ms, + "audio_stage_slots": audio_stage_slots, + "audio_stage_inflight_peak": audio_stage_inflight_peak, + "prompt_semantic_ms": prompt_semantic_ms, + "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_ms": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) + ), + "prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)), + "prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)), + "prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + "prompt_semantic_stage_inflight_peak": float( + prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0) + ), + "prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)), + "prompt_semantic_batch_samples": float( + prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0) + ), + "ref_spec_wait_ms": float(ref_spec_limiter_stats["wait_ms"]), + "ref_spec_ms": ref_spec_ms, + "bundle_total_ms": load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_ms, + }, + } + + def extract_text_features(self, text: str, language: str, profile: dict | None = None): + return self.text_preprocessor.segment_and_extract_feature_for_text( + text, language, self.configs.version, profile=profile + ) def _set_ref_audio_path(self, ref_audio_path): self.prompt_cache["ref_audio_path"] = ref_audio_path diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py index 283e91c3..15b3c322 100644 --- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py +++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py @@ -1,6 +1,8 @@ import os import sys import threading +import time +from contextlib import contextmanager from tqdm import tqdm @@ -16,6 +18,7 @@ from text.cleaner import clean_text from text import cleaned_text_to_sequence from transformers import AutoModelForMaskedLM, AutoTokenizer from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method +from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker from tools.i18n.i18n import I18nAuto, scan_language_list @@ -49,12 +52,60 @@ def merge_short_text_in_array(texts: str, threshold: int) -> list: return result +class StageLimiter: + def __init__(self, slots: int): + self.slots = max(1, int(slots)) + self.semaphore = threading.BoundedSemaphore(self.slots) + self.lock = threading.Lock() + self.inflight = 0 + self.peak_inflight = 0 + + @contextmanager + def enter(self): + wait_start = time.perf_counter() + self.semaphore.acquire() + wait_ms = (time.perf_counter() - wait_start) * 1000.0 + with self.lock: + self.inflight += 1 + current_inflight = self.inflight + if current_inflight > self.peak_inflight: + self.peak_inflight = current_inflight + peak_inflight = self.peak_inflight + try: + yield { + "wait_ms": wait_ms, + "inflight": current_inflight, + "peak_inflight": peak_inflight, + "slots": self.slots, + } + finally: + with self.lock: + self.inflight = max(0, self.inflight - 1) + self.semaphore.release() + + def snapshot(self) -> Dict[str, int]: + with self.lock: + return { + "slots": self.slots, + "inflight": self.inflight, + "peak_inflight": self.peak_inflight, + } + + class TextPreprocessor: - def __init__(self, bert_model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, device: torch.device): + def __init__( + self, + bert_model: AutoModelForMaskedLM, + tokenizer: AutoTokenizer, + device: torch.device, + bert_stage_limiter: StageLimiter | None = None, + bert_batch_worker: PrepareBertBatchWorker | None = None, + ): self.bert_model = bert_model self.tokenizer = tokenizer self.device = device - self.bert_lock = threading.RLock() + self.bert_stage_limiter = bert_stage_limiter + self.bert_batch_worker = bert_batch_worker def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]: print(f"############ {i18n('切分文本')} ############") @@ -115,86 +166,136 @@ class TextPreprocessor: return texts def segment_and_extract_feature_for_text( - self, text: str, language: str, version: str = "v1" + self, text: str, language: str, version: str = "v1", profile: Dict | None = None ) -> Tuple[list, torch.Tensor, str]: - return self.get_phones_and_bert(text, language, version) + return self.get_phones_and_bert(text, language, version, profile=profile) - def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False): - with self.bert_lock: - text = re.sub(r' {2,}', ' ', text) - textlist = [] - langlist = [] - if language == "all_zh": - for tmp in LangSegmenter.getTexts(text,"zh"): + def _split_text_by_language(self, text: str, language: str) -> Tuple[List[str], List[str]]: + textlist = [] + langlist = [] + if language == "all_zh": + for tmp in LangSegmenter.getTexts(text, "zh"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_yue": + for tmp in LangSegmenter.getTexts(text, "zh"): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_ja": + for tmp in LangSegmenter.getTexts(text, "ja"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_ko": + for tmp in LangSegmenter.getTexts(text, "ko"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "en": + langlist.append("en") + textlist.append(text) + elif language == "auto": + for tmp in LangSegmenter.getTexts(text): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "auto_yue": + for tmp in LangSegmenter.getTexts(text): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + else: + for tmp in LangSegmenter.getTexts(text): + if langlist: + same_group = (tmp["lang"] == "en" and langlist[-1] == "en") or ( + tmp["lang"] != "en" and langlist[-1] != "en" + ) + if same_group: + textlist[-1] += tmp["text"] + continue + if tmp["lang"] == "en": langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "all_yue": - for tmp in LangSegmenter.getTexts(text,"zh"): - if tmp["lang"] == "zh": - tmp["lang"] = "yue" - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "all_ja": - for tmp in LangSegmenter.getTexts(text,"ja"): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "all_ko": - for tmp in LangSegmenter.getTexts(text,"ko"): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "en": - langlist.append("en") - textlist.append(text) - elif language == "auto": - for tmp in LangSegmenter.getTexts(text): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "auto_yue": - for tmp in LangSegmenter.getTexts(text): - if tmp["lang"] == "zh": - tmp["lang"] = "yue" - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - else: - for tmp in LangSegmenter.getTexts(text): - if langlist: - if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"): - textlist[-1] += tmp["text"] - continue - if tmp["lang"] == "en": - langlist.append(tmp["lang"]) - else: - # 因无法区别中日韩文汉字,以用户输入为准 - langlist.append(language) - textlist.append(tmp["text"]) - # print(textlist) - # print(langlist) - phones_list = [] - bert_list = [] - norm_text_list = [] - for i in range(len(textlist)): - lang = langlist[i] - phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version) - bert = self.get_bert_inf(phones, word2ph, norm_text, lang) - phones_list.append(phones) - norm_text_list.append(norm_text) - bert_list.append(bert) - bert = torch.cat(bert_list, dim=1) - phones = sum(phones_list, []) - norm_text = "".join(norm_text_list) + else: + langlist.append(language) + textlist.append(tmp["text"]) + return textlist, langlist - if not final and len(phones) < 6: - return self.get_phones_and_bert("." + text, language, version, final=True) + def get_phones_and_bert( + self, text: str, language: str, version: str, final: bool = False, profile: Dict | None = None + ): + text = re.sub(r' {2,}', ' ', text) + textlist, langlist = self._split_text_by_language(text, language) + phones_list = [] + bert_list = [] + norm_text_list = [] + for segment_text, segment_lang in zip(textlist, langlist): + phones, word2ph, norm_text = self.clean_text_inf(segment_text, segment_lang, version) + bert = self.get_bert_inf(phones, word2ph, norm_text, segment_lang, profile=profile) + phones_list.append(phones) + norm_text_list.append(norm_text) + bert_list.append(bert) + bert = torch.cat(bert_list, dim=1) + phones = sum(phones_list, []) + norm_text = "".join(norm_text_list) - return phones, bert, norm_text + if not final and len(phones) < 6: + return self.get_phones_and_bert("." + text, language, version, final=True, profile=profile) - def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor: - with torch.no_grad(): - inputs = self.tokenizer(text, return_tensors="pt") - for i in inputs: - inputs[i] = inputs[i].to(self.device) - res = self.bert_model(**inputs, output_hidden_states=True) - res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + return phones, bert, norm_text + + def _accumulate_profile(self, profile: Dict | None, key: str, value: float) -> None: + if profile is None: + return + profile[key] = float(profile.get(key, 0.0)) + float(value) + + def _update_profile_peak(self, profile: Dict | None, key: str, value: float) -> None: + if profile is None: + return + profile[key] = float(max(float(profile.get(key, 0.0)), float(value))) + + def get_bert_feature(self, text: str, word2ph: list, profile: Dict | None = None) -> torch.Tensor: + if self.bert_batch_worker is not None: + feature, worker_profile = self.bert_batch_worker.submit(text, word2ph) + self._accumulate_profile(profile, "bert_wait_ms", worker_profile.get("bert_wait_ms", 0.0)) + self._accumulate_profile(profile, "bert_forward_ms", worker_profile.get("bert_forward_ms", 0.0)) + self._accumulate_profile(profile, "bert_tokenize_ms", worker_profile.get("bert_tokenize_ms", 0.0)) + self._accumulate_profile(profile, "bert_scatter_ms", worker_profile.get("bert_scatter_ms", 0.0)) + self._accumulate_profile(profile, "bert_calls", worker_profile.get("bert_calls", 1.0)) + self._update_profile_peak( + profile, "bert_stage_inflight_peak", worker_profile.get("bert_stage_inflight_peak", 0.0) + ) + self._update_profile_peak(profile, "bert_batch_size_peak", worker_profile.get("bert_batch_size", 0.0)) + self._update_profile_peak(profile, "bert_batch_tokens_peak", worker_profile.get("bert_batch_tokens", 0.0)) + if profile is not None: + profile["bert_stage_slots"] = float(worker_profile.get("bert_stage_slots", 0.0)) + return feature + + limiter_stats = {"wait_ms": 0.0, "inflight": 1, "peak_inflight": 1, "slots": 0} + if self.bert_stage_limiter is None: + forward_start = time.perf_counter() + with torch.no_grad(): + inputs = self.tokenizer(text, return_tensors="pt") + for i in inputs: + inputs[i] = inputs[i].to(self.device) + res = self.bert_model(**inputs, output_hidden_states=True) + res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + else: + with self.bert_stage_limiter.enter() as limiter_stats: + forward_start = time.perf_counter() + with torch.no_grad(): + inputs = self.tokenizer(text, return_tensors="pt") + for i in inputs: + inputs[i] = inputs[i].to(self.device) + res = self.bert_model(**inputs, output_hidden_states=True) + res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + self._accumulate_profile(profile, "bert_wait_ms", limiter_stats["wait_ms"]) + self._accumulate_profile(profile, "bert_forward_ms", forward_ms) + self._accumulate_profile(profile, "bert_calls", 1.0) + self._update_profile_peak(profile, "bert_stage_inflight_peak", limiter_stats["peak_inflight"]) + if profile is not None: + profile["bert_stage_slots"] = float(limiter_stats["slots"]) assert len(word2ph) == len(text) phone_level_feature = [] for i in range(len(word2ph)): @@ -209,10 +310,10 @@ class TextPreprocessor: phones = cleaned_text_to_sequence(phones, version) return phones, word2ph, norm_text - def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str): + def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str, profile: Dict | None = None): language = language.replace("all_", "") if language == "zh": - feature = self.get_bert_feature(norm_text, word2ph).to(self.device) + feature = self.get_bert_feature(norm_text, word2ph, profile=profile).to(self.device) else: feature = torch.zeros( (1024, len(phones)), @@ -236,4 +337,4 @@ class TextPreprocessor: punctuations = "".join(re.escape(p) for p in punctuation) pattern = f"([{punctuations}])([{punctuations}])+" result = re.sub(pattern, r"\1", text) - return result \ No newline at end of file + return result diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py new file mode 100644 index 00000000..b1ede3d8 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py @@ -0,0 +1,197 @@ +import threading +import time +import uuid +from collections import deque +from dataclasses import dataclass, field +from typing import Deque, Dict, List, Tuple + +import torch + + +@dataclass +class BertFeatureTask: + norm_text: str + word2ph: List[int] + task_id: str = field(default_factory=lambda: uuid.uuid4().hex) + created_at: float = field(default_factory=time.perf_counter) + done_event: threading.Event = field(default_factory=threading.Event) + result_feature: torch.Tensor | None = None + error: Exception | None = None + profile: Dict[str, float] = field(default_factory=dict) + + +class PrepareBertBatchWorker: + def __init__( + self, + bert_model, + tokenizer, + device, + stage_limiter=None, + batch_window_ms: int = 5, + max_batch_items: int = 16, + max_batch_tokens: int = 4096, + ): + self.bert_model = bert_model + self.tokenizer = tokenizer + self.device = device + self.stage_limiter = stage_limiter + self.batch_window_s = max(0.0, float(batch_window_ms) / 1000.0) + self.max_batch_items = max(1, int(max_batch_items)) + self.max_batch_tokens = max(16, int(max_batch_tokens)) + + self.condition = threading.Condition() + self.pending_tasks: Deque[BertFeatureTask] = deque() + self.pending_peak = 0 + self.total_submitted = 0 + self.total_finished = 0 + self.total_batches = 0 + self.active_batch_size = 0 + self.active_batch_peak = 0 + self.worker_thread = threading.Thread(target=self._run_loop, name="prepare-bert-batch-worker", daemon=True) + self.worker_thread.start() + + def _estimate_task_tokens(self, task: BertFeatureTask) -> int: + return max(1, len(task.norm_text) + 2) + + def submit(self, norm_text: str, word2ph: List[int]) -> Tuple[torch.Tensor, Dict[str, float]]: + task = BertFeatureTask(norm_text=str(norm_text), word2ph=list(word2ph)) + with self.condition: + self.pending_tasks.append(task) + self.total_submitted += 1 + if len(self.pending_tasks) > self.pending_peak: + self.pending_peak = len(self.pending_tasks) + self.condition.notify_all() + task.done_event.wait() + if task.error is not None: + raise task.error + assert task.result_feature is not None + return task.result_feature, dict(task.profile) + + def snapshot(self) -> Dict[str, int]: + with self.condition: + return { + "pending": len(self.pending_tasks), + "pending_peak": self.pending_peak, + "total_submitted": self.total_submitted, + "total_finished": self.total_finished, + "total_batches": self.total_batches, + "active_batch_size": self.active_batch_size, + "active_batch_peak": self.active_batch_peak, + "batch_window_ms": int(self.batch_window_s * 1000.0), + "max_batch_items": self.max_batch_items, + "max_batch_tokens": self.max_batch_tokens, + } + + def _collect_batch(self) -> List[BertFeatureTask]: + with self.condition: + while not self.pending_tasks: + self.condition.wait() + + batch: List[BertFeatureTask] = [self.pending_tasks.popleft()] + batch_tokens = self._estimate_task_tokens(batch[0]) + deadline = time.perf_counter() + self.batch_window_s + + while len(batch) < self.max_batch_items: + remaining = deadline - time.perf_counter() + if remaining <= 0: + break + if not self.pending_tasks: + self.condition.wait(timeout=remaining) + continue + next_task = self.pending_tasks[0] + next_tokens = self._estimate_task_tokens(next_task) + if len(batch) >= self.max_batch_items or (batch_tokens + next_tokens) > self.max_batch_tokens: + break + batch.append(self.pending_tasks.popleft()) + batch_tokens += next_tokens + + self.active_batch_size = len(batch) + if self.active_batch_size > self.active_batch_peak: + self.active_batch_peak = self.active_batch_size + return batch + + def _finalize_batch(self, batch: List[BertFeatureTask]) -> None: + with self.condition: + self.active_batch_size = 0 + self.total_batches += 1 + self.total_finished += len(batch) + + def _run_batch(self, batch: List[BertFeatureTask]) -> None: + batch_started = time.perf_counter() + texts = [task.norm_text for task in batch] + batch_tokens = sum(self._estimate_task_tokens(task) for task in batch) + + limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0} + if self.stage_limiter is None: + tokenize_start = time.perf_counter() + inputs = self.tokenizer(texts, return_tensors="pt", padding=True) + tokenize_ms = (time.perf_counter() - tokenize_start) * 1000.0 + attention_mask_cpu = inputs["attention_mask"].cpu() + for key in inputs: + inputs[key] = inputs[key].to(self.device) + forward_start = time.perf_counter() + with torch.no_grad(): + outputs = self.bert_model(**inputs, output_hidden_states=True) + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + else: + with self.stage_limiter.enter() as limiter_stats: + tokenize_start = time.perf_counter() + inputs = self.tokenizer(texts, return_tensors="pt", padding=True) + tokenize_ms = (time.perf_counter() - tokenize_start) * 1000.0 + attention_mask_cpu = inputs["attention_mask"].cpu() + for key in inputs: + inputs[key] = inputs[key].to(self.device) + forward_start = time.perf_counter() + with torch.no_grad(): + outputs = self.bert_model(**inputs, output_hidden_states=True) + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + + hidden = outputs["hidden_states"][-3].detach().cpu() + scatter_start = time.perf_counter() + for batch_index, task in enumerate(batch): + try: + text_len = len(task.word2ph) + if text_len != len(task.norm_text): + raise AssertionError( + f"word2ph/text length mismatch: task={task.task_id} word2ph={text_len} text={len(task.norm_text)}" + ) + seq_len = int(attention_mask_cpu[batch_index].sum().item()) + char_features = hidden[batch_index, 1 : seq_len - 1] + if char_features.shape[0] != text_len: + raise AssertionError( + f"bert token length mismatch: task={task.task_id} token_len={char_features.shape[0]} text_len={text_len}" + ) + phone_level_feature = [] + for char_index, repeat_count in enumerate(task.word2ph): + phone_level_feature.append(char_features[char_index].repeat(repeat_count, 1)) + task.result_feature = torch.cat(phone_level_feature, dim=0).T + task.profile = { + "bert_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]), + "bert_forward_ms": float(forward_ms), + "bert_tokenize_ms": float(tokenize_ms), + "bert_scatter_ms": 0.0, + "bert_calls": 1.0, + "bert_stage_slots": float(limiter_stats["slots"]), + "bert_stage_inflight_peak": float(limiter_stats["peak_inflight"]), + "bert_batch_size": float(len(batch)), + "bert_batch_tokens": float(batch_tokens), + } + except Exception as exc: # noqa: PERF203 + task.error = exc + scatter_ms = (time.perf_counter() - scatter_start) * 1000.0 + for task in batch: + if task.result_feature is not None: + task.profile["bert_scatter_ms"] = float(scatter_ms) + task.done_event.set() + + def _run_loop(self) -> None: + while True: + batch = self._collect_batch() + try: + self._run_batch(batch) + except Exception as exc: # noqa: PERF203 + for task in batch: + task.error = exc + task.done_event.set() + finally: + self._finalize_batch(batch) diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py new file mode 100644 index 00000000..7a1f9a53 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py @@ -0,0 +1,262 @@ +import threading +import time +import uuid +from collections import deque +from dataclasses import dataclass, field +from typing import Deque, Dict, List, Tuple + +import librosa +import numpy as np +import torch + + +REF_AUDIO_MIN_SAMPLES_16K = 48000 +REF_AUDIO_MAX_SAMPLES_16K = 160000 + + +def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wav_samples: int) -> torch.Tensor: + wav_mono = raw_audio + if wav_mono.dim() == 2 and wav_mono.shape[0] != 1: + wav_mono = wav_mono.mean(0, keepdim=True) + wav16k = wav_mono.squeeze(0).cpu().numpy() + if raw_sr != 16000: + wav16k = librosa.resample(wav16k, orig_sr=raw_sr, target_sr=16000) + if wav16k.shape[0] > REF_AUDIO_MAX_SAMPLES_16K or wav16k.shape[0] < REF_AUDIO_MIN_SAMPLES_16K: + raise OSError("参考音频在3~10秒范围外,请更换!") + wav16k = np.ascontiguousarray(wav16k, dtype=np.float32) + if zero_wav_samples > 0: + wav16k = np.concatenate([wav16k, np.zeros(int(zero_wav_samples), dtype=np.float32)], axis=0) + return torch.from_numpy(wav16k) + + +def conv1d_output_lengths(input_lengths: torch.Tensor, conv1d: torch.nn.Conv1d | None) -> torch.Tensor: + if conv1d is None: + return input_lengths.to(dtype=torch.long) + kernel_size = int(conv1d.kernel_size[0]) + stride = int(conv1d.stride[0]) + padding = int(conv1d.padding[0]) + dilation = int(conv1d.dilation[0]) + output_lengths = torch.div( + input_lengths + 2 * padding - dilation * (kernel_size - 1) - 1, + stride, + rounding_mode="floor", + ) + 1 + return torch.clamp(output_lengths, min=0).to(dtype=torch.long) + + +@dataclass +class RefSemanticTask: + raw_audio: torch.Tensor + raw_sr: int + task_id: str = field(default_factory=lambda: uuid.uuid4().hex) + created_at: float = field(default_factory=time.perf_counter) + done_event: threading.Event = field(default_factory=threading.Event) + result_prompt_semantic: torch.Tensor | None = None + error: Exception | None = None + profile: Dict[str, float] = field(default_factory=dict) + + +class PrepareRefSemanticBatchWorker: + def __init__( + self, + ssl_model, + vits_model, + device, + is_half: bool, + zero_wav_samples: int, + stage_limiter=None, + batch_window_ms: int = 5, + max_batch_items: int = 8, + max_batch_samples: int = 960000, + ): + self.ssl_model = ssl_model + self.vits_model = vits_model + self.device = device + self.is_half = bool(is_half) + self.zero_wav_samples = max(0, int(zero_wav_samples)) + self.stage_limiter = stage_limiter + self.batch_window_s = max(0.0, float(batch_window_ms) / 1000.0) + self.max_batch_items = max(1, int(max_batch_items)) + self.max_batch_samples = max(REF_AUDIO_MIN_SAMPLES_16K + self.zero_wav_samples, int(max_batch_samples)) + + self.condition = threading.Condition() + self.pending_tasks: Deque[RefSemanticTask] = deque() + self.pending_peak = 0 + self.total_submitted = 0 + self.total_finished = 0 + self.total_batches = 0 + self.active_batch_size = 0 + self.active_batch_peak = 0 + self.active_batch_samples = 0 + self.active_batch_samples_peak = 0 + self.worker_thread = threading.Thread( + target=self._run_loop, + name="prepare-ref-semantic-batch-worker", + daemon=True, + ) + self.worker_thread.start() + + def _estimate_task_samples(self, task: RefSemanticTask) -> int: + raw_len = int(task.raw_audio.shape[-1]) if task.raw_audio.dim() > 0 else 0 + base = int(round(raw_len * 16000.0 / max(1, int(task.raw_sr)))) + return max(REF_AUDIO_MIN_SAMPLES_16K, base) + self.zero_wav_samples + + def submit(self, raw_audio: torch.Tensor, raw_sr: int) -> Tuple[torch.Tensor, Dict[str, float]]: + task = RefSemanticTask(raw_audio=raw_audio, raw_sr=int(raw_sr)) + with self.condition: + self.pending_tasks.append(task) + self.total_submitted += 1 + if len(self.pending_tasks) > self.pending_peak: + self.pending_peak = len(self.pending_tasks) + self.condition.notify_all() + task.done_event.wait() + if task.error is not None: + raise task.error + assert task.result_prompt_semantic is not None + return task.result_prompt_semantic, dict(task.profile) + + def snapshot(self) -> Dict[str, int]: + with self.condition: + return { + "pending": len(self.pending_tasks), + "pending_peak": self.pending_peak, + "total_submitted": self.total_submitted, + "total_finished": self.total_finished, + "total_batches": self.total_batches, + "active_batch_size": self.active_batch_size, + "active_batch_peak": self.active_batch_peak, + "active_batch_samples": self.active_batch_samples, + "active_batch_samples_peak": self.active_batch_samples_peak, + "batch_window_ms": int(self.batch_window_s * 1000.0), + "max_batch_items": self.max_batch_items, + "max_batch_samples": self.max_batch_samples, + } + + def _collect_batch(self) -> List[RefSemanticTask]: + with self.condition: + while not self.pending_tasks: + self.condition.wait() + + batch: List[RefSemanticTask] = [self.pending_tasks.popleft()] + batch_samples = self._estimate_task_samples(batch[0]) + deadline = time.perf_counter() + self.batch_window_s + + while len(batch) < self.max_batch_items: + remaining = deadline - time.perf_counter() + if remaining <= 0: + break + if not self.pending_tasks: + self.condition.wait(timeout=remaining) + continue + next_task = self.pending_tasks[0] + next_samples = self._estimate_task_samples(next_task) + if len(batch) >= self.max_batch_items or (batch_samples + next_samples) > self.max_batch_samples: + break + batch.append(self.pending_tasks.popleft()) + batch_samples += next_samples + + self.active_batch_size = len(batch) + self.active_batch_samples = batch_samples + if self.active_batch_size > self.active_batch_peak: + self.active_batch_peak = self.active_batch_size + if self.active_batch_samples > self.active_batch_samples_peak: + self.active_batch_samples_peak = self.active_batch_samples + return batch + + def _finalize_batch(self, batch: List[RefSemanticTask]) -> None: + with self.condition: + self.active_batch_size = 0 + self.active_batch_samples = 0 + self.total_batches += 1 + self.total_finished += len(batch) + + def _get_hidden_lengths(self, attention_mask: torch.Tensor, hidden_length: int) -> torch.Tensor: + model = self.ssl_model.model + if hasattr(model, "_get_feature_vector_attention_mask"): + feature_mask = model._get_feature_vector_attention_mask(hidden_length, attention_mask) + return feature_mask.to(dtype=torch.long).sum(dim=1) + raw_lengths = attention_mask.to(dtype=torch.long).sum(dim=1) + if hasattr(model, "_get_feat_extract_output_lengths"): + return model._get_feat_extract_output_lengths(raw_lengths).to(dtype=torch.long) + return torch.full((attention_mask.shape[0],), int(hidden_length), dtype=torch.long, device=attention_mask.device) + + @torch.inference_mode() + def _run_batch(self, batch: List[RefSemanticTask]) -> None: + batch_started = time.perf_counter() + prepared_start = time.perf_counter() + prepared_wavs = [ + prepare_prompt_semantic_wav16k(task.raw_audio, int(task.raw_sr), self.zero_wav_samples) for task in batch + ] + cpu_prepare_ms = (time.perf_counter() - prepared_start) * 1000.0 + wav_lengths = torch.tensor([int(wav.shape[0]) for wav in prepared_wavs], dtype=torch.long) + batch_samples = int(wav_lengths.sum().item()) + max_wav_len = int(wav_lengths.max().item()) + + input_values_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.float32) + attention_mask_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.long) + for batch_index, wav in enumerate(prepared_wavs): + wav_len = int(wav.shape[0]) + input_values_cpu[batch_index, :wav_len] = wav + attention_mask_cpu[batch_index, :wav_len] = 1 + + limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0} + if self.stage_limiter is None: + input_values = input_values_cpu.to(self.device) + attention_mask = attention_mask_cpu.to(self.device) + if self.is_half: + input_values = input_values.half() + forward_start = time.perf_counter() + outputs = self.ssl_model.model(input_values, attention_mask=attention_mask) + hubert_feature = outputs["last_hidden_state"].transpose(1, 2) + hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1])) + codes = self.vits_model.extract_latent(hubert_feature) + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + else: + with self.stage_limiter.enter() as limiter_stats: + input_values = input_values_cpu.to(self.device) + attention_mask = attention_mask_cpu.to(self.device) + if self.is_half: + input_values = input_values.half() + forward_start = time.perf_counter() + outputs = self.ssl_model.model(input_values, attention_mask=attention_mask) + hubert_feature = outputs["last_hidden_state"].transpose(1, 2) + hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1])) + codes = self.vits_model.extract_latent(hubert_feature) + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + + code_lengths = conv1d_output_lengths(hidden_lengths.detach().cpu(), getattr(self.vits_model, "ssl_proj", None)) + scatter_start = time.perf_counter() + for batch_index, task in enumerate(batch): + try: + code_len = int(code_lengths[batch_index].item()) + task.result_prompt_semantic = codes[batch_index, 0, :code_len].detach().clone() + task.profile = { + "prompt_semantic_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]), + "prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms), + "prompt_semantic_forward_ms": float(forward_ms), + "prompt_semantic_scatter_ms": 0.0, + "prompt_semantic_calls": 1.0, + "prompt_semantic_stage_slots": float(limiter_stats["slots"]), + "prompt_semantic_stage_inflight_peak": float(limiter_stats["peak_inflight"]), + "prompt_semantic_batch_size": float(len(batch)), + "prompt_semantic_batch_samples": float(batch_samples), + } + except Exception as exc: # noqa: PERF203 + task.error = exc + scatter_ms = (time.perf_counter() - scatter_start) * 1000.0 + for task in batch: + if task.result_prompt_semantic is not None: + task.profile["prompt_semantic_scatter_ms"] = float(scatter_ms) + task.done_event.set() + + def _run_loop(self) -> None: + while True: + batch = self._collect_batch() + try: + self._run_batch(batch) + except Exception as exc: # noqa: PERF203 + for task in batch: + task.error = exc + task.done_event.set() + finally: + self._finalize_batch(batch) diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py index c8643991..de498573 100644 --- a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -1,5 +1,6 @@ from __future__ import annotations +from concurrent.futures import Future from dataclasses import dataclass from pathlib import Path import time @@ -123,31 +124,58 @@ def prepare_request_state( prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) text = spec.text.strip("\n") - _sync_device(device) - prompt_text_features_start = time.perf_counter() - prompt_phones, prompt_bert_features, prompt_norm_text = tts.extract_text_features(prompt_text, spec.prompt_lang) - _sync_device(device) - prompt_text_features_ms = (time.perf_counter() - prompt_text_features_start) * 1000.0 + prompt_text_profile: Dict[str, float] = {} + text_features_profile: Dict[str, float] = {} + text_feature_pair_start = time.perf_counter() + prompt_future: Future | None = None + + def _extract_prompt_features(): + _sync_device(device) + prompt_start = time.perf_counter() + result = tts.extract_text_features(prompt_text, spec.prompt_lang, profile=prompt_text_profile) + _sync_device(device) + return result, (time.perf_counter() - prompt_start) * 1000.0 + + if getattr(tts, "prepare_text_cpu_executor", None) is not None: + prompt_future = tts.prepare_text_cpu_executor.submit(_extract_prompt_features) _sync_device(device) text_features_start = time.perf_counter() - phones, bert_features, norm_text = tts.extract_text_features(text, spec.text_lang) + phones, bert_features, norm_text = tts.extract_text_features(text, spec.text_lang, profile=text_features_profile) _sync_device(device) text_features_ms = (time.perf_counter() - text_features_start) * 1000.0 + + if prompt_future is None: + _sync_device(device) + prompt_text_features_start = time.perf_counter() + prompt_phones, prompt_bert_features, prompt_norm_text = tts.extract_text_features( + prompt_text, spec.prompt_lang, profile=prompt_text_profile + ) + _sync_device(device) + prompt_text_features_ms = (time.perf_counter() - prompt_text_features_start) * 1000.0 + prompt_text_profile["parallel_future_wait_ms"] = 0.0 + else: + prompt_wait_start = time.perf_counter() + (prompt_phones, prompt_bert_features, prompt_norm_text), prompt_text_features_ms = prompt_future.result() + prompt_text_profile["parallel_future_wait_ms"] = (time.perf_counter() - prompt_wait_start) * 1000.0 + + text_feature_pair_ms = (time.perf_counter() - text_feature_pair_start) * 1000.0 if phones is None: raise ValueError(f"{spec.request_id} text preprocessing returned no phones") _sync_device(device) - prompt_semantic_start = time.perf_counter() - prompt_semantic = tts.extract_prompt_semantic(str(spec.ref_audio_path)).long() + ref_audio_bundle_start = time.perf_counter() + ref_audio_bundle = tts.extract_ref_audio_bundle(str(spec.ref_audio_path)) + prompt_semantic = ref_audio_bundle["prompt_semantic"].long() + spec_audio, audio_16k = ref_audio_bundle["refer_spec"] + raw_audio = ref_audio_bundle["raw_audio"] + raw_sr = int(ref_audio_bundle["raw_sr"]) _sync_device(device) - prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0 - - _sync_device(device) - ref_spec_start = time.perf_counter() - spec_audio, audio_16k, raw_audio, raw_sr = tts.extract_ref_spec(str(spec.ref_audio_path)) - _sync_device(device) - ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0 + ref_audio_bundle_ms = (time.perf_counter() - ref_audio_bundle_start) * 1000.0 + bundle_profile = ref_audio_bundle.get("profile", {}) + prompt_semantic_ms = float(bundle_profile.get("prompt_semantic_ms", ref_audio_bundle_ms)) + ref_spec_ms = float(bundle_profile.get("ref_spec_ms", 0.0)) + audio_load_ms = float(bundle_profile.get("audio_load_ms", 0.0)) _sync_device(device) tensorize_start = time.perf_counter() @@ -164,8 +192,43 @@ def prepare_request_state( prepare_profile = { "prompt_text_features_ms": prompt_text_features_ms, "text_features_ms": text_features_ms, + "prompt_text_bert_wait_ms": float(prompt_text_profile.get("bert_wait_ms", 0.0)), + "prompt_text_bert_forward_ms": float(prompt_text_profile.get("bert_forward_ms", 0.0)), + "prompt_text_bert_tokenize_ms": float(prompt_text_profile.get("bert_tokenize_ms", 0.0)), + "prompt_text_bert_scatter_ms": float(prompt_text_profile.get("bert_scatter_ms", 0.0)), + "prompt_text_bert_calls": float(prompt_text_profile.get("bert_calls", 0.0)), + "prompt_text_bert_stage_slots": float(prompt_text_profile.get("bert_stage_slots", 0.0)), + "prompt_text_bert_stage_inflight_peak": float(prompt_text_profile.get("bert_stage_inflight_peak", 0.0)), + "prompt_text_bert_batch_size_peak": float(prompt_text_profile.get("bert_batch_size_peak", 0.0)), + "prompt_text_bert_batch_tokens_peak": float(prompt_text_profile.get("bert_batch_tokens_peak", 0.0)), + "prompt_text_parallel_future_wait_ms": float(prompt_text_profile.get("parallel_future_wait_ms", 0.0)), + "text_bert_wait_ms": float(text_features_profile.get("bert_wait_ms", 0.0)), + "text_bert_forward_ms": float(text_features_profile.get("bert_forward_ms", 0.0)), + "text_bert_tokenize_ms": float(text_features_profile.get("bert_tokenize_ms", 0.0)), + "text_bert_scatter_ms": float(text_features_profile.get("bert_scatter_ms", 0.0)), + "text_bert_calls": float(text_features_profile.get("bert_calls", 0.0)), + "text_bert_stage_slots": float(text_features_profile.get("bert_stage_slots", 0.0)), + "text_bert_stage_inflight_peak": float(text_features_profile.get("bert_stage_inflight_peak", 0.0)), + "text_bert_batch_size_peak": float(text_features_profile.get("bert_batch_size_peak", 0.0)), + "text_bert_batch_tokens_peak": float(text_features_profile.get("bert_batch_tokens_peak", 0.0)), + "text_feature_pair_ms": text_feature_pair_ms, + "text_cpu_parallel_workers": float(getattr(tts, "prepare_text_cpu_workers", 0)), + "audio_load_ms": audio_load_ms, + "audio_stage_wait_ms": float(bundle_profile.get("audio_stage_wait_ms", 0.0)), + "audio_stage_slots": float(bundle_profile.get("audio_stage_slots", 0.0)), + "audio_stage_inflight_peak": float(bundle_profile.get("audio_stage_inflight_peak", 0.0)), "prompt_semantic_ms": prompt_semantic_ms, + "prompt_semantic_wait_ms": float(bundle_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), + "prompt_semantic_forward_ms": float(bundle_profile.get("prompt_semantic_forward_ms", 0.0)), + "prompt_semantic_scatter_ms": float(bundle_profile.get("prompt_semantic_scatter_ms", 0.0)), + "prompt_semantic_stage_slots": float(bundle_profile.get("prompt_semantic_stage_slots", 0.0)), + "prompt_semantic_stage_inflight_peak": float(bundle_profile.get("prompt_semantic_stage_inflight_peak", 0.0)), + "prompt_semantic_batch_size": float(bundle_profile.get("prompt_semantic_batch_size", 0.0)), + "prompt_semantic_batch_samples": float(bundle_profile.get("prompt_semantic_batch_samples", 0.0)), + "ref_spec_wait_ms": float(bundle_profile.get("ref_spec_wait_ms", 0.0)), "ref_spec_ms": ref_spec_ms, + "ref_audio_bundle_ms": ref_audio_bundle_ms, "tensorize_ms": tensorize_ms, "total_ms": (time.perf_counter() - prepare_sync_start) * 1000.0, "wall_total_ms": (time.perf_counter() - prepare_start) * 1000.0, @@ -186,7 +249,7 @@ def prepare_request_state( prompt_semantic=prompt_semantic, refer_spec=(spec_audio, audio_16k), raw_audio=raw_audio, - raw_sr=int(raw_sr), + raw_sr=raw_sr, top_k=spec.top_k, top_p=spec.top_p, temperature=spec.temperature, @@ -409,9 +472,18 @@ def _pad_decode_mask_left(mask: torch.Tensor, target_len: int) -> torch.Tensor: return F.pad(mask, (pad_len, 0), value=True) +def _fit_decode_mask_length(mask: torch.Tensor, target_len: int) -> torch.Tensor: + if mask.shape[-1] > target_len: + return mask[:, :, :, -target_len:] + if mask.shape[-1] < target_len: + return _pad_decode_mask_left(mask, target_len) + return mask + + def _materialize_decode_mask_for_request(running_request: T2SRunningRequest) -> torch.Tensor: + expected_mask_len = running_request.k_cache[0].shape[1] + 1 if running_request.decode_attn_mask is not None: - return running_request.decode_attn_mask + return _fit_decode_mask_length(running_request.decode_attn_mask, expected_mask_len) current_mask_len = running_request.k_cache[0].shape[1] + 1 return torch.zeros( (1, 1, 1, current_mask_len), @@ -481,17 +553,19 @@ def run_prefill_step( real_kv_len = int(active_batch.x_lens[batch_index].item()) + prefix_len request_k_cache = [layer[batch_index : batch_index + 1, -real_kv_len:, :].clone() for layer in k_cache] request_v_cache = [layer[batch_index : batch_index + 1, -real_kv_len:, :].clone() for layer in v_cache] + request_decode_attn_mask = None + if decode_attn_mask is not None: + request_decode_attn_mask = decode_attn_mask[batch_index : batch_index + 1].clone() + request_decode_attn_mask = _fit_decode_mask_length(request_decode_attn_mask, real_kv_len + 1) + if not request_decode_attn_mask.any().item(): + request_decode_attn_mask = None running_requests.append( T2SRunningRequest( state=state, y_sequence=new_history, prefix_len=prefix_len, - decode_attn_mask=( - None - if decode_attn_mask is None - else decode_attn_mask[batch_index : batch_index + 1].clone() - ), + decode_attn_mask=request_decode_attn_mask, k_cache=request_k_cache, v_cache=request_v_cache, step_idx=1, @@ -603,6 +677,9 @@ def run_decode_step_for_running( batch_index : batch_index + 1, :, :, -current_decode_mask_len: ] next_decode_attn_mask = F.pad(current_decode_attn_mask, (0, 1), value=False) + next_decode_attn_mask = _fit_decode_mask_length(next_decode_attn_mask, real_next_kv_len + 1) + if not next_decode_attn_mask.any().item(): + next_decode_attn_mask = None next_running.append( T2SRunningRequest( state=running_request.state, diff --git a/api_v3.py b/api_v3.py index 9d250119..92f9a3b9 100644 --- a/api_v3.py +++ b/api_v3.py @@ -261,8 +261,9 @@ class SchedulerDebugWorker: self.tts = tts self.max_steps = max_steps self.micro_batch_wait_s = micro_batch_wait_ms / 1000.0 - self.prepare_lock = threading.Lock() self.condition = threading.Condition() + self.prepare_inflight = 0 + self.prepare_peak_inflight = 0 self.pending_jobs: List[SchedulerPendingJob] = [] self.running_requests: List[T2SRunningRequest] = [] self.job_map: dict[str, SchedulerPendingJob] = {} @@ -282,8 +283,20 @@ class SchedulerDebugWorker: pass def prepare_state(self, spec: SchedulerRequestSpec) -> T2SRequestState: - with self.prepare_lock: - return prepare_request_state(self.tts, spec) + with self.condition: + self.prepare_inflight += 1 + prepare_inflight_on_enter = self.prepare_inflight + if self.prepare_inflight > self.prepare_peak_inflight: + self.prepare_peak_inflight = self.prepare_inflight + prepare_peak_inflight = self.prepare_peak_inflight + try: + state = prepare_request_state(self.tts, spec) + state.prepare_profile["worker_prepare_inflight_on_enter"] = float(prepare_inflight_on_enter) + state.prepare_profile["worker_prepare_peak_inflight"] = float(prepare_peak_inflight) + return state + finally: + with self.condition: + self.prepare_inflight = max(0, self.prepare_inflight - 1) def submit( self, @@ -363,9 +376,28 @@ class SchedulerDebugWorker: def get_state(self) -> dict: with self.condition: + bert_stage = self.tts.prepare_bert_stage_limiter.snapshot() + ref_audio_stage = self.tts.prepare_ref_audio_stage_limiter.snapshot() + bert_batch_worker = ( + None + if self.tts.prepare_bert_batch_worker is None + else self.tts.prepare_bert_batch_worker.snapshot() + ) + ref_semantic_batch_worker = ( + None + if self.tts.prepare_ref_semantic_batch_worker is None + else self.tts.prepare_ref_semantic_batch_worker.snapshot() + ) return { "pending_jobs": len(self.pending_jobs), "running_requests": len(self.running_requests), + "prepare_inflight": self.prepare_inflight, + "prepare_peak_inflight": self.prepare_peak_inflight, + "prepare_text_cpu_workers": int(getattr(self.tts, "prepare_text_cpu_workers", 0)), + "prepare_bert_stage": bert_stage, + "prepare_bert_batch_worker": bert_batch_worker, + "prepare_ref_audio_stage": ref_audio_stage, + "prepare_ref_semantic_batch_worker": ref_semantic_batch_worker, "tracked_jobs": len(self.job_map), "total_submitted": self.total_submitted, "total_finished": self.total_finished, @@ -907,10 +939,36 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): { "X-Prepare-Prompt-Text-Ms": f"{float(prepare_profile.get('prompt_text_features_ms', 0.0)):.3f}", "X-Prepare-Target-Text-Ms": f"{float(prepare_profile.get('text_features_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_wait_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Wait-Ms": f"{float(prepare_profile.get('text_bert_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Forward-Ms": f"{float(prepare_profile.get('prompt_text_bert_forward_ms', 0.0)):.3f}", + "X-Prepare-Target-Bert-Forward-Ms": f"{float(prepare_profile.get('text_bert_forward_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Bert-Batch-Size-Peak": str( + int(prepare_profile.get("prompt_text_bert_batch_size_peak", 0.0)) + ), + "X-Prepare-Target-Bert-Batch-Size-Peak": str( + int(prepare_profile.get("text_bert_batch_size_peak", 0.0)) + ), + "X-Prepare-Text-Pair-Wall-Ms": f"{float(prepare_profile.get('text_feature_pair_ms', 0.0)):.3f}", + "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))), + "X-Prepare-Audio-Load-Ms": f"{float(prepare_profile.get('audio_load_ms', 0.0)):.3f}", + "X-Prepare-Audio-Stage-Wait-Ms": f"{float(prepare_profile.get('audio_stage_wait_ms', 0.0)):.3f}", "X-Prepare-Prompt-Semantic-Ms": f"{float(prepare_profile.get('prompt_semantic_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-Wait-Ms": f"{float(prepare_profile.get('prompt_semantic_wait_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-CPU-Ms": f"{float(prepare_profile.get('prompt_semantic_cpu_prepare_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-Forward-Ms": f"{float(prepare_profile.get('prompt_semantic_forward_ms', 0.0)):.3f}", + "X-Prepare-Prompt-Semantic-Batch-Size": str( + int(prepare_profile.get("prompt_semantic_batch_size", 0.0)) + ), "X-Prepare-Ref-Spec-Ms": f"{float(prepare_profile.get('ref_spec_ms', 0.0)):.3f}", + "X-Prepare-Ref-Spec-Wait-Ms": f"{float(prepare_profile.get('ref_spec_wait_ms', 0.0)):.3f}", + "X-Prepare-Ref-Bundle-Ms": f"{float(prepare_profile.get('ref_audio_bundle_ms', 0.0)):.3f}", "X-Prepare-Tensorize-Ms": f"{float(prepare_profile.get('tensorize_ms', 0.0)):.3f}", "X-Prepare-Profile-Wall-Ms": f"{float(prepare_profile.get('wall_total_ms', 0.0)):.3f}", + "X-Prepare-Inflight-On-Enter": str( + int(prepare_profile.get("worker_prepare_inflight_on_enter", 0.0)) + ), + "X-Prepare-Inflight-Peak": str(int(prepare_profile.get("worker_prepare_peak_inflight", 0.0))), } ) return Response(audio_data, media_type=f"audio/{job.media_type}", headers=headers)