diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..570e9d7b --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/g2pw-cu"] + path = third_party/g2pw-cu + url = https://github.com/baicai-1145/g2pw-cu.git diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index ac905f4b..f55b7508 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -351,6 +351,13 @@ class Text2SemanticDecoder(nn.Module): blocks.append(block) self.t2s_transformer = T2STransformer(self.num_layers, blocks) + self.last_infer_stats = {} + + def _set_last_infer_stats(self, stats): + self.last_infer_stats = stats + + def get_last_infer_stats(self): + return dict(self.last_infer_stats) def make_input_data(self, x, x_lens, y, y_lens, bert_feature): x = self.ar_text_embedding(x) @@ -593,7 +600,19 @@ class Text2SemanticDecoder(nn.Module): repetition_penalty: float = 1.35, **kwargs, ): + requested_enable_mask_free_fastpath = bool(kwargs.get("enable_mask_free_fastpath", True)) if prompts is None: + self._set_last_infer_stats( + { + "infer_mode": "batch_infer_prompt_free_fallback", + "requested_enable_mask_free_fastpath": requested_enable_mask_free_fastpath, + "batch_size": int(len(x)), + "prefill_after_mask_all_visible": None, + "fastpath_hit": False, + "generated_token_count": 0, + "generated_token_count_list": [], + } + ) print("Warning: Prompt free is not supported batch_infer! switch to naive_infer") return self.infer_panel_naive_batched( x, @@ -608,6 +627,7 @@ class Text2SemanticDecoder(nn.Module): ) max_len = kwargs.get("max_len", x_lens.max()) + enable_mask_free_fastpath = requested_enable_mask_free_fastpath x_list = [] for x_item, bert_item in zip(x, bert_feature): # max_len = max(max_len, x_item.shape[0], bert_item.shape[1]) @@ -698,17 +718,30 @@ class Text2SemanticDecoder(nn.Module): y_list = [None] * y.shape[0] batch_idx_map = list(range(y.shape[0])) idx_list = [None] * y.shape[0] + decode_attn_mask = attn_mask + prefill_after_mask_all_visible = None + fastpath_hit = False for idx in tqdm(range(1500)): if idx == 0: xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None) else: - xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask) + xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token( + xy_pos, k_cache, v_cache, decode_attn_mask + ) logits = self.ar_predict_layer(xy_dec[:, -1]) if idx == 0: attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False) + prefill_after_mask_all_visible = not attn_mask.any().item() + if enable_mask_free_fastpath and y.shape[0] == 1 and prefill_after_mask_all_visible: + decode_attn_mask = None + fastpath_hit = True + else: + decode_attn_mask = attn_mask else: - attn_mask = F.pad(attn_mask, (0, 1), value=False) + if decode_attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1), value=False) + decode_attn_mask = attn_mask if idx < 11: ###至少预测出10个token不然不给停止(0.4s) logits = logits[:, :-1] @@ -740,7 +773,9 @@ class Text2SemanticDecoder(nn.Module): if reserved_idx_of_batch_for_y is not None: # index = torch.LongTensor(batch_idx_map).to(y.device) y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y) - attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y) + if decode_attn_mask is not None: + attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y) + decode_attn_mask = attn_mask if k_cache is not None: for i in range(len(k_cache)): k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y) @@ -775,6 +810,18 @@ class Text2SemanticDecoder(nn.Module): if idx_list[i] is None: idx_list[i] = 1500 - 1 ###如果没有生成到EOS,就用最大长度代替 + self._set_last_infer_stats( + { + "infer_mode": "batch_infer", + "requested_enable_mask_free_fastpath": enable_mask_free_fastpath, + "batch_size": int(len(x)), + "prefill_after_mask_all_visible": prefill_after_mask_all_visible, + "fastpath_hit": fastpath_hit, + "generated_token_count": int(sum(idx_list)), + "generated_token_count_list": [int(item) for item in idx_list], + "max_len": int(max_len), + } + ) if ref_free: return y_list, [0] * x.shape[0] # print(idx_list) @@ -811,6 +858,17 @@ class Text2SemanticDecoder(nn.Module): y_list.append(y[0]) idx_list.append(idx) + self._set_last_infer_stats( + { + "infer_mode": "naive_batched", + "requested_enable_mask_free_fastpath": bool(kwargs.get("enable_mask_free_fastpath", True)), + "batch_size": int(len(x)), + "prefill_after_mask_all_visible": None, + "fastpath_hit": False, + "generated_token_count": int(sum(idx_list)), + "generated_token_count_list": [int(item) for item in idx_list], + } + ) return y_list, idx_list def infer_panel_naive( @@ -957,6 +1015,18 @@ class Text2SemanticDecoder(nn.Module): if not streaming_mode: + generated_token_count = max(int(y.shape[1] - prefix_len), 0) + self._set_last_infer_stats( + { + "infer_mode": "naive", + "requested_enable_mask_free_fastpath": bool(kwargs.get("enable_mask_free_fastpath", True)), + "batch_size": int(x.shape[0]), + "prefill_after_mask_all_visible": True if prompts is not None else None, + "fastpath_hit": True if prompts is not None else False, + "generated_token_count": generated_token_count, + "generated_token_count_list": [generated_token_count], + } + ) if ref_free: yield y, 0 yield y, idx diff --git a/GPT_SoVITS/AR/models/utils.py b/GPT_SoVITS/AR/models/utils.py index cc4f24d8..4b564ed8 100644 --- a/GPT_SoVITS/AR/models/utils.py +++ b/GPT_SoVITS/AR/models/utils.py @@ -147,6 +147,7 @@ def multinomial_sample_one_no_sync( def logits_to_probs( logits, previous_tokens: Optional[torch.Tensor] = None, + previous_token_mask: Optional[torch.Tensor] = None, temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[int] = None, @@ -158,13 +159,27 @@ def logits_to_probs( # pdb.set_trace() if previous_tokens is not None and repetition_penalty != 1.0: previous_tokens = previous_tokens.long() - score = torch.gather(logits, dim=1, index=previous_tokens) - score = torch.where( - score < 0, - score * repetition_penalty, - score / repetition_penalty, - ) - logits.scatter_(dim=1, index=previous_tokens, src=score) + if previous_token_mask is None: + score = torch.gather(logits, dim=1, index=previous_tokens) + score = torch.where( + score < 0, + score * repetition_penalty, + score / repetition_penalty, + ) + logits.scatter_(dim=1, index=previous_tokens, src=score) + else: + previous_token_mask = previous_token_mask.to(dtype=torch.bool, device=logits.device) + if previous_token_mask.any(): + batch_index = torch.arange(logits.size(0), device=logits.device).unsqueeze(1).expand_as(previous_tokens) + valid_batch_index = batch_index[previous_token_mask] + valid_token_index = previous_tokens[previous_token_mask] + score = logits[valid_batch_index, valid_token_index] + score = torch.where( + score < 0, + score * repetition_penalty, + score / repetition_penalty, + ) + logits[valid_batch_index, valid_token_index] = score if top_p is not None and top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=True) @@ -192,9 +207,15 @@ def logits_to_probs( def sample( logits, previous_tokens: Optional[torch.Tensor] = None, + previous_token_mask: Optional[torch.Tensor] = None, **sampling_kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: - probs = logits_to_probs(logits=logits, previous_tokens=previous_tokens, **sampling_kwargs) + probs = logits_to_probs( + logits=logits, + previous_tokens=previous_tokens, + previous_token_mask=previous_token_mask, + **sampling_kwargs, + ) idx_next = multinomial_sample_one_no_sync(probs) return idx_next, probs diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 9c8344b0..16bc8db8 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1,4 +1,6 @@ import gc +import asyncio +import concurrent.futures import math import os import random @@ -7,19 +9,20 @@ import time import traceback from copy import deepcopy -import torchaudio -from tqdm import tqdm - now_dir = os.getcwd() sys.path.append(now_dir) -import os from typing import List, Tuple, Union +from runtime_preload import preload_text_runtime_deps + +preload_text_runtime_deps() + import ffmpeg import librosa import numpy as np import torch import torch.nn.functional as F +import torchaudio import yaml from AR.models.t2s_lightning_module import Text2SemanticLightningModule from BigVGAN.bigvgan import BigVGAN @@ -29,11 +32,18 @@ from module.models import SynthesizerTrn, SynthesizerTrnV3, Generator from peft import LoraConfig, get_peft_model from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new from transformers import AutoModelForMaskedLM, AutoTokenizer +from tqdm import tqdm from tools.audio_sr import AP_BWE from tools.i18n.i18n import I18nAuto, scan_language_list from TTS_infer_pack.text_segmentation_method import splits -from TTS_infer_pack.TextPreprocessor import TextPreprocessor +from TTS_infer_pack.TextPreprocessor import TextPreprocessor, StageLimiter +from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker +from TTS_infer_pack.prepare_ref_semantic_batch_worker import ( + PrepareRefSemanticBatchWorker, + prepare_prompt_semantic_wav16k, +) +from TTS_infer_pack.prepare_text_cpu_worker import PrepareTextCpuWorker from sv import SV resample_transform_dict = {} @@ -315,7 +325,7 @@ class TTS_Config: assert isinstance(configs, dict) configs_ = deepcopy(self.default_configs) configs_.update(configs) - self.configs: dict = configs_.get("custom", configs_["v2"]) + self.configs: dict = configs_.get("custom", configs_["v2ProPlus"]) self.default_configs = deepcopy(configs_) self.device = self.configs.get("device", torch.device("cpu")) @@ -442,12 +452,20 @@ class TTS: "upsample_rate": None, "overlapped_len": None, } + self.prepare_bert_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_BERT_SLOTS", "1"))) + self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4"))) + self.prepare_ref_audio_cpu_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_CPU_SLOTS", "8"))) + self.prepare_bert_batch_worker = None + self.prepare_ref_semantic_batch_worker = None + self.prepare_text_cpu_worker = None + self.prepare_text_cpu_workers = max( + 0, + int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_WORKERS", "0")), + ) + self.prepare_text_cpu_executor = None self._init_models() - - self.text_preprocessor: TextPreprocessor = TextPreprocessor( - self.bert_model, self.bert_tokenizer, self.configs.device - ) + self.refresh_runtime_components() self.prompt_cache: dict = { "ref_audio_path": None, @@ -464,6 +482,156 @@ class TTS: self.stop_flag: bool = False self.precision: torch.dtype = torch.float16 if self.configs.is_half else torch.float32 + def refresh_runtime_components(self): + self.prepare_bert_batch_worker = None + self.prepare_ref_semantic_batch_worker = None + self.prepare_text_cpu_worker = None + if os.environ.get("GPTSOVITS_PREPARE_BERT_BATCHING", "1") != "0": + self.prepare_bert_batch_worker = PrepareBertBatchWorker( + bert_model=self.bert_model, + tokenizer=self.bert_tokenizer, + device=self.configs.device, + stage_limiter=self.prepare_bert_stage_limiter, + batch_window_ms=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_WINDOW_MS", "5")), + max_batch_items=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_ITEMS", "16")), + max_batch_tokens=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_TOKENS", "4096")), + max_pending_tasks=int(os.environ.get("GPTSOVITS_PREPARE_BERT_MAX_PENDING_TASKS", "0")), + admission_poll_ms=int(os.environ.get("GPTSOVITS_PREPARE_BERT_ADMISSION_POLL_MS", "1")), + high_pressure_pending_threshold=int( + os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_PENDING_THRESHOLD", "0") + ), + high_pressure_batch_window_ms=int( + os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_BATCH_WINDOW_MS", "1") + ), + high_pressure_max_batch_items=int( + os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_MAX_ITEMS", "32") + ), + high_pressure_max_batch_tokens=int( + os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_MAX_TOKENS", "8192") + ), + ) + if os.environ.get("GPTSOVITS_PREPARE_REF_BATCHING", "0") != "0": + ref_max_batch_samples = os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_SAMPLES") + if ref_max_batch_samples is None: + ref_max_batch_samples = os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_FRAMES", "960000") + self.prepare_ref_semantic_batch_worker = PrepareRefSemanticBatchWorker( + ssl_model=self.cnhuhbert_model, + vits_model=self.vits_model, + device=self.configs.device, + is_half=self.configs.is_half, + zero_wav_samples=int(self.configs.sampling_rate * 0.3), + stage_limiter=self.prepare_ref_audio_stage_limiter, + batch_window_ms=int(os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_WINDOW_MS", "5")), + max_batch_items=int(os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_ITEMS", "8")), + max_batch_samples=int(ref_max_batch_samples), + ) + + self.text_preprocessor = TextPreprocessor( + self.bert_model, + self.bert_tokenizer, + self.configs.device, + version=self.configs.version, + bert_stage_limiter=self.prepare_bert_stage_limiter, + bert_batch_worker=self.prepare_bert_batch_worker, + ) + if self.prepare_text_cpu_workers > 0: + self.prepare_text_cpu_worker = PrepareTextCpuWorker( + process_fn=lambda text, language: self.text_preprocessor.preprocess_text_segments( + text, + language, + self.configs.version, + ), + worker_count=self.prepare_text_cpu_workers, + max_pending_tasks=int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_MAX_PENDING_TASKS", "0")), + admission_poll_ms=int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_ADMISSION_POLL_MS", "1")), + admission_controller=self._build_text_cpu_admission_state, + ) + + @staticmethod + def _safe_queue_qsize(executor) -> int | None: + if executor is None: + return None + queue = getattr(executor, "_work_queue", None) + if queue is None or not hasattr(queue, "qsize"): + return None + try: + return int(queue.qsize()) + except Exception: + return None + + def snapshot_prepare_runtime_components(self) -> dict: + g2pw_runtime = None + try: + from text import chinese2 + + g2pw_instance = getattr(chinese2, "g2pw", None) + g2pw_backend = None if g2pw_instance is None else getattr(g2pw_instance, "_g2pw", None) + if g2pw_backend is not None and hasattr(g2pw_backend, "snapshot"): + g2pw_runtime = dict(g2pw_backend.snapshot()) + except Exception: + g2pw_runtime = None + return { + "text_cpu": { + "workers": int(self.prepare_text_cpu_workers), + "queue_size": self._safe_queue_qsize(self.prepare_text_cpu_executor), + "enabled": bool(self.prepare_text_cpu_worker is not None or self.prepare_text_cpu_executor is not None), + "worker": ( + None if self.prepare_text_cpu_worker is None else dict(self.prepare_text_cpu_worker.snapshot()) + ), + "admission": self._build_text_cpu_admission_state(), + }, + "bert": { + "stage_limiter": dict(self.prepare_bert_stage_limiter.snapshot()), + "batch_worker": ( + None if self.prepare_bert_batch_worker is None else dict(self.prepare_bert_batch_worker.snapshot()) + ), + "batching_enabled": bool(self.prepare_bert_batch_worker is not None), + }, + "ref_semantic": { + "stage_limiter": dict(self.prepare_ref_audio_stage_limiter.snapshot()), + "batch_worker": ( + None + if self.prepare_ref_semantic_batch_worker is None + else dict(self.prepare_ref_semantic_batch_worker.snapshot()) + ), + "batching_enabled": bool(self.prepare_ref_semantic_batch_worker is not None), + }, + "text_preprocessor": ( + None if self.text_preprocessor is None or not hasattr(self.text_preprocessor, "snapshot") else self.text_preprocessor.snapshot() + ), + "g2pw": g2pw_runtime, + } + + def _build_text_cpu_admission_state(self) -> dict: + bert_pending_soft_max = max( + 0, + int( + os.environ.get( + "GPTSOVITS_PREPARE_TEXT_CPU_BERT_PENDING_SOFT_MAX", + os.environ.get("GPTSOVITS_PREPARE_BERT_HIGH_PRESSURE_PENDING_THRESHOLD", "32"), + ) + ), + ) + if self.prepare_bert_batch_worker is None or bert_pending_soft_max <= 0: + return { + "blocked": False, + "reason": "", + "bert_pending": 0, + "bert_active_batch_size": 0, + "bert_pending_soft_max": int(bert_pending_soft_max), + } + bert_state = dict(self.prepare_bert_batch_worker.snapshot()) + bert_pending = int(bert_state.get("pending", 0)) + bert_active_batch_size = int(bert_state.get("active_batch_size", 0)) + blocked = bert_pending >= bert_pending_soft_max + return { + "blocked": bool(blocked), + "reason": ("bert_pending" if blocked else ""), + "bert_pending": int(bert_pending), + "bert_active_batch_size": int(bert_active_batch_size), + "bert_pending_soft_max": int(bert_pending_soft_max), + } + def _init_models( self, ): @@ -755,10 +923,446 @@ class TTS: Args: ref_audio_path: str, the path of the reference audio. """ - self._set_prompt_semantic(ref_audio_path) - self._set_ref_spec(ref_audio_path) + bundle = self.extract_ref_audio_bundle(ref_audio_path) + if self.prompt_cache["refer_spec"] in [[], None]: + self.prompt_cache["refer_spec"] = [bundle["refer_spec"]] + else: + self.prompt_cache["refer_spec"][0] = bundle["refer_spec"] + self.prompt_cache["prompt_semantic"] = bundle["prompt_semantic"] + self.prompt_cache["raw_audio"] = bundle["raw_audio"] + self.prompt_cache["raw_sr"] = bundle["raw_sr"] self._set_ref_audio_path(ref_audio_path) + def _load_ref_audio_raw(self, ref_audio_path: str): + raw_audio, raw_sr = torchaudio.load(ref_audio_path) + return raw_audio.float(), int(raw_sr) + + @torch.inference_mode() + def _extract_prompt_semantic_from_prepared_wav16k(self, wav16k: torch.Tensor): + wav16k = wav16k.to(self.configs.device) + if self.configs.is_half: + wav16k = wav16k.half() + hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) + codes = self.vits_model.extract_latent(hubert_feature) + return codes[0, 0].to(self.configs.device) + + @torch.inference_mode() + def _extract_prompt_semantic_profile_from_prepared_wav16k(self, wav16k: torch.Tensor): + forward_start = time.perf_counter() + prompt_semantic = self._extract_prompt_semantic_from_prepared_wav16k(wav16k) + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + return prompt_semantic, forward_ms + + @torch.inference_mode() + def _prepare_prompt_semantic_wav16k_profile(self, raw_audio: torch.Tensor, raw_sr: int): + limiter = getattr(self, "prepare_ref_audio_cpu_limiter", None) + if limiter is None: + cpu_prepare_start = time.perf_counter() + wav16k = prepare_prompt_semantic_wav16k( + raw_audio=raw_audio, + raw_sr=raw_sr, + zero_wav_samples=int(self.configs.sampling_rate * 0.3), + ) + cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0 + return wav16k, cpu_prepare_ms, {"wait_ms": 0.0, "slots": 0.0, "peak_inflight": 0.0} + + with limiter.enter() as limiter_stats: + cpu_prepare_start = time.perf_counter() + wav16k = prepare_prompt_semantic_wav16k( + raw_audio=raw_audio, + raw_sr=raw_sr, + zero_wav_samples=int(self.configs.sampling_rate * 0.3), + ) + cpu_prepare_ms = (time.perf_counter() - cpu_prepare_start) * 1000.0 + return wav16k, cpu_prepare_ms, { + "wait_ms": float(limiter_stats.get("wait_ms", 0.0)), + "slots": float(limiter_stats.get("slots", 0.0)), + "peak_inflight": float(limiter_stats.get("peak_inflight", 0.0)), + } + + @torch.inference_mode() + def _extract_prompt_semantic_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + wav16k, cpu_prepare_ms, _ = self._prepare_prompt_semantic_wav16k_profile(raw_audio, raw_sr) + prompt_semantic, forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k(wav16k) + return prompt_semantic, cpu_prepare_ms, forward_ms + + @torch.inference_mode() + def _extract_prompt_semantic_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + prompt_semantic, _, _ = self._extract_prompt_semantic_profile_from_raw(raw_audio, raw_sr) + return prompt_semantic + + def extract_prompt_semantic(self, ref_wav_path: str): + raw_audio, raw_sr = self._load_ref_audio_raw(ref_wav_path) + return self._extract_prompt_semantic_from_raw(raw_audio, raw_sr) + + def _extract_ref_spec_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + spec, audio, _, _, _ = self._extract_ref_spec_profile_from_raw(raw_audio, raw_sr) + return spec, audio, raw_audio, raw_sr + + def _extract_ref_spec_profile_from_raw(self, raw_audio: torch.Tensor, raw_sr: int): + profile = { + "ref_spec_to_device_ms": 0.0, + "ref_spec_main_resample_ms": 0.0, + "ref_spec_norm_ms": 0.0, + "ref_spec_spectrogram_ms": 0.0, + "ref_spec_post_resample_ms": 0.0, + } + to_device_start = time.perf_counter() + raw_audio_device = raw_audio.to(self.configs.device).float() + profile["ref_spec_to_device_ms"] = (time.perf_counter() - to_device_start) * 1000.0 + + if raw_sr != self.configs.sampling_rate: + resample_start = time.perf_counter() + audio = raw_audio_device + if audio.shape[0] == 2: + audio = audio.mean(0).unsqueeze(0) + audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device) + profile["ref_spec_main_resample_ms"] = (time.perf_counter() - resample_start) * 1000.0 + else: + audio = raw_audio_device + if audio.shape[0] == 2: + audio = audio.mean(0).unsqueeze(0) + + norm_start = time.perf_counter() + maxx = audio.abs().max() + if maxx > 1: + audio /= min(2, maxx) + profile["ref_spec_norm_ms"] = (time.perf_counter() - norm_start) * 1000.0 + spec_start = time.perf_counter() + spec = spectrogram_torch( + audio, + self.configs.filter_length, + self.configs.sampling_rate, + self.configs.hop_length, + self.configs.win_length, + center=False, + ) + profile["ref_spec_spectrogram_ms"] = (time.perf_counter() - spec_start) * 1000.0 + if self.configs.is_half: + spec = spec.half() + if self.is_v2pro == True: + post_resample_start = time.perf_counter() + audio = resample(audio, self.configs.sampling_rate, 16000, self.configs.device) + profile["ref_spec_post_resample_ms"] = (time.perf_counter() - post_resample_start) * 1000.0 + if self.configs.is_half: + audio = audio.half() + else: + audio = None + return spec, audio, raw_audio, raw_sr, profile + + def extract_ref_spec(self, ref_audio_path: str): + raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path) + return self._extract_ref_spec_from_raw(raw_audio, raw_sr) + + def extract_ref_audio_bundle(self, ref_audio_path: str): + load_start = time.perf_counter() + raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path) + load_ms = (time.perf_counter() - load_start) * 1000.0 + if self.prepare_ref_semantic_batch_worker is None: + wav16k, prompt_semantic_cpu_prepare_ms, prompt_semantic_cpu_limiter_stats = ( + self._prepare_prompt_semantic_wav16k_profile(raw_audio, raw_sr) + ) + with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats: + prompt_semantic_start = time.perf_counter() + prompt_semantic, prompt_semantic_forward_ms = self._extract_prompt_semantic_profile_from_prepared_wav16k( + wav16k + ) + prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0 + ref_spec_start = time.perf_counter() + refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2] + ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0 + audio_stage_wait_ms = float(limiter_stats["wait_ms"]) + audio_stage_slots = float(limiter_stats["slots"]) + audio_stage_inflight_peak = float(limiter_stats["peak_inflight"]) + prompt_semantic_profile = { + "prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]), + "prompt_semantic_cpu_prepare_wait_ms": float(prompt_semantic_cpu_limiter_stats.get("wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_slots": float(prompt_semantic_cpu_limiter_stats.get("slots", 0.0)), + "prompt_semantic_cpu_prepare_inflight_peak": float( + prompt_semantic_cpu_limiter_stats.get("peak_inflight", 0.0) + ), + "prompt_semantic_worker_queue_wait_ms": 0.0, + "prompt_semantic_batch_collect_wait_ms": 0.0, + "prompt_semantic_stage_limiter_wait_ms": float(limiter_stats["wait_ms"]), + "prompt_semantic_batch_dispatch_delay_ms": 0.0, + "prompt_semantic_cpu_prepare_ms": float(prompt_semantic_cpu_prepare_ms), + "prompt_semantic_forward_ms": float(prompt_semantic_forward_ms), + "prompt_semantic_scatter_ms": 0.0, + "prompt_semantic_stage_slots": float(limiter_stats["slots"]), + "prompt_semantic_stage_inflight_peak": float(limiter_stats["peak_inflight"]), + "prompt_semantic_batch_size": 1.0, + "prompt_semantic_batch_samples": 0.0, + } + ref_spec_wait_ms = 0.0 + return { + "prompt_semantic": prompt_semantic, + "refer_spec": refer_spec, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + "audio_load_ms": load_ms, + "audio_stage_wait_ms": audio_stage_wait_ms, + "audio_stage_slots": audio_stage_slots, + "audio_stage_inflight_peak": audio_stage_inflight_peak, + "prompt_semantic_ms": prompt_semantic_ms, + "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_wait_ms", 0.0) + ), + "prompt_semantic_cpu_prepare_slots": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_slots", 0.0) + ), + "prompt_semantic_cpu_prepare_inflight_peak": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_inflight_peak", 0.0) + ), + "prompt_semantic_worker_queue_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0) + ), + "prompt_semantic_batch_collect_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0) + ), + "prompt_semantic_stage_limiter_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0) + ), + "prompt_semantic_batch_dispatch_delay_ms": float( + prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0) + ), + "prompt_semantic_cpu_prepare_ms": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) + ), + "prompt_semantic_forward_ms": float( + prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0) + ), + "prompt_semantic_scatter_ms": float( + prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0) + ), + "prompt_semantic_stage_slots": float( + prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0) + ), + "prompt_semantic_stage_inflight_peak": float( + prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0) + ), + "prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)), + "prompt_semantic_batch_samples": float( + prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0) + ), + "ref_spec_wait_ms": ref_spec_wait_ms, + "ref_spec_ms": ref_spec_ms, + "bundle_total_ms": load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_ms, + }, + } + + prompt_semantic_profile = { + "prompt_semantic_wait_ms": 0.0, + "prompt_semantic_cpu_prepare_wait_ms": 0.0, + "prompt_semantic_cpu_prepare_slots": float(getattr(self.prepare_ref_audio_cpu_limiter, "slots", 0.0)), + "prompt_semantic_cpu_prepare_inflight_peak": 0.0, + "prompt_semantic_worker_queue_wait_ms": 0.0, + "prompt_semantic_batch_collect_wait_ms": 0.0, + "prompt_semantic_stage_limiter_wait_ms": 0.0, + "prompt_semantic_batch_dispatch_delay_ms": 0.0, + "prompt_semantic_cpu_prepare_ms": 0.0, + "prompt_semantic_forward_ms": 0.0, + "prompt_semantic_scatter_ms": 0.0, + "prompt_semantic_stage_slots": 0.0, + "prompt_semantic_stage_inflight_peak": 0.0, + "prompt_semantic_batch_size": 1.0, + "prompt_semantic_batch_samples": 0.0, + } + if self.prepare_ref_semantic_batch_worker is not None: + prompt_semantic, worker_profile = self.prepare_ref_semantic_batch_worker.submit(raw_audio, raw_sr) + prompt_semantic_profile.update(worker_profile) + prompt_semantic_ms = ( + float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)) + ) + with self.prepare_ref_audio_stage_limiter.enter() as ref_spec_limiter_stats: + ref_spec_start = time.perf_counter() + refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2] + ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0 + audio_stage_wait_ms = float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)) + float( + ref_spec_limiter_stats["wait_ms"] + ) + audio_stage_slots = max( + float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + float(ref_spec_limiter_stats["slots"]), + ) + audio_stage_inflight_peak = max( + float(prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)), + float(ref_spec_limiter_stats["peak_inflight"]), + ) + return { + "prompt_semantic": prompt_semantic, + "refer_spec": refer_spec, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + "audio_load_ms": load_ms, + "audio_stage_wait_ms": audio_stage_wait_ms, + "audio_stage_slots": audio_stage_slots, + "audio_stage_inflight_peak": audio_stage_inflight_peak, + "prompt_semantic_ms": prompt_semantic_ms, + "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_wait_ms", 0.0) + ), + "prompt_semantic_cpu_prepare_slots": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_slots", 0.0) + ), + "prompt_semantic_cpu_prepare_inflight_peak": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_inflight_peak", 0.0) + ), + "prompt_semantic_worker_queue_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0) + ), + "prompt_semantic_batch_collect_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0) + ), + "prompt_semantic_stage_limiter_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0) + ), + "prompt_semantic_batch_dispatch_delay_ms": float( + prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0) + ), + "prompt_semantic_cpu_prepare_ms": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) + ), + "prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)), + "prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)), + "prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + "prompt_semantic_stage_inflight_peak": float( + prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0) + ), + "prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)), + "prompt_semantic_batch_samples": float( + prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0) + ), + "ref_spec_wait_ms": float(ref_spec_limiter_stats["wait_ms"]), + "ref_spec_ms": ref_spec_ms, + "bundle_total_ms": load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_ms, + }, + } + + async def extract_ref_audio_bundle_async(self, ref_audio_path: str): + if self.prepare_ref_semantic_batch_worker is None: + return await asyncio.to_thread(self.extract_ref_audio_bundle, ref_audio_path) + + load_start = time.perf_counter() + raw_audio, raw_sr = await asyncio.to_thread(self._load_ref_audio_raw, ref_audio_path) + load_ms = (time.perf_counter() - load_start) * 1000.0 + + prompt_semantic_task = asyncio.create_task( + self.prepare_ref_semantic_batch_worker.submit_async(raw_audio, raw_sr) + ) + + def _build_ref_spec_profile(): + with self.prepare_ref_audio_stage_limiter.enter() as ref_spec_limiter_stats: + ref_spec_start = time.perf_counter() + refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2] + ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0 + return refer_spec, { + "ref_spec_wait_ms": float(ref_spec_limiter_stats["wait_ms"]), + "ref_spec_ms": float(ref_spec_ms), + "audio_stage_slots": float(ref_spec_limiter_stats["slots"]), + "audio_stage_inflight_peak": float(ref_spec_limiter_stats["peak_inflight"]), + } + + ref_spec_task = asyncio.create_task(asyncio.to_thread(_build_ref_spec_profile)) + (prompt_semantic, prompt_semantic_profile), (refer_spec, ref_spec_profile) = await asyncio.gather( + prompt_semantic_task, + ref_spec_task, + ) + + prompt_semantic_ms = ( + float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)) + ) + audio_stage_wait_ms = float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)) + float( + ref_spec_profile.get("ref_spec_wait_ms", 0.0) + ) + audio_stage_slots = max( + float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + float(ref_spec_profile.get("audio_stage_slots", 0.0)), + ) + audio_stage_inflight_peak = max( + float(prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)), + float(ref_spec_profile.get("audio_stage_inflight_peak", 0.0)), + ) + return { + "prompt_semantic": prompt_semantic, + "refer_spec": refer_spec, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + "audio_load_ms": float(load_ms), + "audio_stage_wait_ms": float(audio_stage_wait_ms), + "audio_stage_slots": float(audio_stage_slots), + "audio_stage_inflight_peak": float(audio_stage_inflight_peak), + "prompt_semantic_ms": float(prompt_semantic_ms), + "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_worker_queue_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0) + ), + "prompt_semantic_batch_collect_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0) + ), + "prompt_semantic_stage_limiter_wait_ms": float( + prompt_semantic_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0) + ), + "prompt_semantic_batch_dispatch_delay_ms": float( + prompt_semantic_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0) + ), + "prompt_semantic_cpu_prepare_ms": float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), + "prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)), + "prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)), + "prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + "prompt_semantic_stage_inflight_peak": float( + prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0) + ), + "prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)), + "prompt_semantic_batch_samples": float(prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0)), + "ref_spec_wait_ms": float(ref_spec_profile.get("ref_spec_wait_ms", 0.0)), + "ref_spec_ms": float(ref_spec_profile.get("ref_spec_ms", 0.0)), + "bundle_total_ms": float(load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_profile.get("ref_spec_ms", 0.0)), + }, + } + + def extract_text_features(self, text: str, language: str, profile: dict | None = None): + return self.text_preprocessor.segment_and_extract_feature_for_text( + text, language, self.configs.version, profile=profile + ) + + def prepare_text_segments(self, text: str, language: str): + return self.text_preprocessor.preprocess_text_segments(text, language, self.configs.version) + + def resolve_g2pw_segments(self, prepared_segments, profile: dict | None = None): + return self.text_preprocessor.resolve_g2pw_segments(prepared_segments, profile=profile) + + def build_text_features_from_segments(self, prepared_segments, profile: dict | None = None): + return self.text_preprocessor.build_phones_and_bert_from_segments(prepared_segments, profile=profile) + + async def build_text_features_from_segments_async(self, prepared_segments, profile: dict | None = None): + return await self.text_preprocessor.build_phones_and_bert_from_segments_async( + prepared_segments, + profile=profile, + ) + + async def build_text_feature_pair_from_segments_async( + self, + prompt_segments, + target_segments, + prompt_profile: dict | None = None, + target_profile: dict | None = None, + ): + return await self.text_preprocessor.build_phones_and_bert_pair_from_segments_async( + prompt_segments, + target_segments, + prompt_profile=prompt_profile, + target_profile=target_profile, + ) + def _set_ref_audio_path(self, ref_audio_path): self.prompt_cache["ref_audio_path"] = ref_audio_path @@ -770,67 +1374,14 @@ class TTS: self.prompt_cache["refer_spec"][0] = spec_audio def _get_ref_spec(self, ref_audio_path): - raw_audio, raw_sr = torchaudio.load(ref_audio_path) - raw_audio = raw_audio.to(self.configs.device).float() + spec, audio, raw_audio, raw_sr = self.extract_ref_spec(ref_audio_path) self.prompt_cache["raw_audio"] = raw_audio self.prompt_cache["raw_sr"] = raw_sr - - if raw_sr != self.configs.sampling_rate: - audio = raw_audio.to(self.configs.device) - if audio.shape[0] == 2: - audio = audio.mean(0).unsqueeze(0) - audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device) - else: - audio = raw_audio.to(self.configs.device) - if audio.shape[0] == 2: - audio = audio.mean(0).unsqueeze(0) - - maxx = audio.abs().max() - if maxx > 1: - audio /= min(2, maxx) - spec = spectrogram_torch( - audio, - self.configs.filter_length, - self.configs.sampling_rate, - self.configs.hop_length, - self.configs.win_length, - center=False, - ) - if self.configs.is_half: - spec = spec.half() - if self.is_v2pro == True: - audio = resample(audio, self.configs.sampling_rate, 16000, self.configs.device) - if self.configs.is_half: - audio = audio.half() - else: - audio = None return spec, audio def _set_prompt_semantic(self, ref_wav_path: str): - zero_wav = np.zeros( - int(self.configs.sampling_rate * 0.3), - dtype=np.float16 if self.configs.is_half else np.float32, - ) - with torch.no_grad(): - wav16k, sr = librosa.load(ref_wav_path, sr=16000) - if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000: - raise OSError(i18n("参考音频在3~10秒范围外,请更换!")) - wav16k = torch.from_numpy(wav16k) - zero_wav_torch = torch.from_numpy(zero_wav) - wav16k = wav16k.to(self.configs.device) - zero_wav_torch = zero_wav_torch.to(self.configs.device) - if self.configs.is_half: - wav16k = wav16k.half() - zero_wav_torch = zero_wav_torch.half() - - wav16k = torch.cat([wav16k, zero_wav_torch]) - hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose( - 1, 2 - ) # .float() - codes = self.vits_model.extract_latent(hubert_feature) - - prompt_semantic = codes[0, 0].to(self.configs.device) - self.prompt_cache["prompt_semantic"] = prompt_semantic + prompt_semantic = self.extract_prompt_semantic(ref_wav_path) + self.prompt_cache["prompt_semantic"] = prompt_semantic def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length: int = None): seq = sequences[0] @@ -1227,6 +1778,9 @@ class TTS: ###### inference ###### t_34 = 0.0 t_45 = 0.0 + t2s_observe_batch_count = 0 + t2s_observe_fastpath_hits = 0 + t2s_observe_generated_tokens = 0 audio = [] is_first_package = True output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"] @@ -1280,6 +1834,29 @@ class TTS: ) t4 = time.perf_counter() t_34 += t4 - t3 + if hasattr(self.t2s_model.model, "get_last_infer_stats"): + t2s_stats = self.t2s_model.model.get_last_infer_stats() + if t2s_stats: + generated_token_count = int(t2s_stats.get("generated_token_count", 0)) + t2s_total_ms = (t4 - t3) * 1000.0 + avg_decode_ms_per_token = ( + t2s_total_ms / generated_token_count if generated_token_count > 0 else 0.0 + ) + t2s_observe_batch_count += 1 + t2s_observe_generated_tokens += generated_token_count + if bool(t2s_stats.get("fastpath_hit", False)): + t2s_observe_fastpath_hits += 1 + print( + "[t2s_observe] " + f"mode={t2s_stats.get('infer_mode')} " + f"batch_size={t2s_stats.get('batch_size')} " + f"tokens={generated_token_count} " + f"t2s_ms={t2s_total_ms:.3f} " + f"avg_decode_ms_per_token={avg_decode_ms_per_token:.3f} " + f"requested_fastpath={t2s_stats.get('requested_enable_mask_free_fastpath')} " + f"prefill_all_visible={t2s_stats.get('prefill_after_mask_all_visible')} " + f"fastpath_hit={t2s_stats.get('fastpath_hit')}" + ) batch_audio_fragment = [] @@ -1500,6 +2077,18 @@ class TTS: if not (return_fragment or streaming_mode): print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45)) + if t2s_observe_batch_count > 0: + request_avg_decode_ms_per_token = ( + (t_34 * 1000.0) / t2s_observe_generated_tokens if t2s_observe_generated_tokens > 0 else 0.0 + ) + print( + "[t2s_request_observe] " + f"batches={t2s_observe_batch_count} " + f"fastpath_hits={t2s_observe_fastpath_hits} " + f"generated_tokens={t2s_observe_generated_tokens} " + f"t2s_total_ms={t_34 * 1000.0:.3f} " + f"avg_decode_ms_per_token={request_avg_decode_ms_per_token:.3f}" + ) if len(audio) == 0: yield output_sr, np.zeros(int(output_sr), dtype=np.int16) return @@ -1575,19 +2164,19 @@ class TTS: self.init_sr_model() if not self.sr_model_not_exist: audio, sr = self.sr_model(audio.unsqueeze(0), sr) - max_audio = np.abs(audio).max() + if isinstance(audio, torch.Tensor): + max_audio = float(torch.abs(audio).max().item()) + else: + max_audio = float(np.abs(audio).max()) if max_audio > 1: audio /= max_audio - audio = (audio * 32768).astype(np.int16) t2 = time.perf_counter() print(f"超采样用时:{t2 - t1:.3f}s") + if isinstance(audio, torch.Tensor): + audio = audio.detach().float().cpu().numpy() else: - # audio = audio.float() * 32768 - # audio = audio.to(dtype=torch.int16).clamp(-32768, 32767).cpu().numpy() - - audio = audio.cpu().numpy() - - audio = (audio * 32768).astype(np.int16) + audio = np.asarray(audio) + audio = (audio.reshape(-1) * 32768).astype(np.int16) # try: @@ -1663,6 +2252,196 @@ class TTS: return audio + def using_vocoder_synthesis_request_local( + self, + semantic_tokens: torch.Tensor, + phones: torch.Tensor, + prompt_semantic: torch.Tensor, + prompt_phones: torch.Tensor, + refer_audio_spec: torch.Tensor, + raw_audio: torch.Tensor, + raw_sr: int, + speed: float = 1.0, + sample_steps: int = 32, + ): + prompt_semantic_tokens = prompt_semantic.unsqueeze(0).unsqueeze(0).to(self.configs.device) + prompt_phones = prompt_phones.unsqueeze(0).to(self.configs.device) + refer_audio_spec = refer_audio_spec.to(dtype=self.precision, device=self.configs.device) + + fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec) + ref_audio = raw_audio.to(self.configs.device).float() + if ref_audio.shape[0] == 2: + ref_audio = ref_audio.mean(0).unsqueeze(0) + + tgt_sr = 24000 if self.configs.version == "v3" else 32000 + if raw_sr != tgt_sr: + ref_audio = resample(ref_audio, raw_sr, tgt_sr, self.configs.device) + + mel_spec_fn = mel_fn if self.configs.version == "v3" else mel_fn_v4 + mel2 = mel_spec_fn(ref_audio) + mel2 = norm_spec(mel2) + T_min = min(mel2.shape[2], fea_ref.shape[2]) + mel2 = mel2[:, :, :T_min] + fea_ref = fea_ref[:, :, :T_min] + T_ref = self.vocoder_configs["T_ref"] + T_chunk = self.vocoder_configs["T_chunk"] + if T_min > T_ref: + mel2 = mel2[:, :, -T_ref:] + fea_ref = fea_ref[:, :, -T_ref:] + T_min = T_ref + chunk_len = T_chunk - T_min + + mel2 = mel2.to(self.precision) + fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed) + + cfm_resss = [] + idx = 0 + while 1: + fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len] + if fea_todo_chunk.shape[-1] == 0: + break + idx += chunk_len + fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1) + + cfm_res = self.vits_model.cfm.inference( + fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0 + ) + cfm_res = cfm_res[:, :, mel2.shape[2] :] + + mel2 = cfm_res[:, :, -T_min:] + fea_ref = fea_todo_chunk[:, :, -T_min:] + + cfm_resss.append(cfm_res) + cfm_res = torch.cat(cfm_resss, 2) + cfm_res = denorm_spec(cfm_res) + + with torch.inference_mode(): + wav_gen = self.vocoder(cfm_res) + audio = wav_gen[0][0] + + return audio + + @torch.inference_mode() + def synthesize_audio_request_local( + self, + semantic_tokens: torch.Tensor, + phones: torch.Tensor, + prompt_semantic: torch.Tensor, + prompt_phones: torch.Tensor, + refer_spec: tuple | List[tuple], + raw_audio: torch.Tensor, + raw_sr: int, + speed: float = 1.0, + sample_steps: int = 32, + ): + refer_specs = list(refer_spec) if isinstance(refer_spec, list) else [refer_spec] + refer_audio_spec, audio_tensor = refer_specs[0] + if not self.configs.use_vocoder: + refer_audio_spec_list = [item[0].to(dtype=self.precision, device=self.configs.device) for item in refer_specs] + sv_emb = None + if self.is_v2pro: + sv_emb = [] + for _, audio_tensor_item in refer_specs: + if audio_tensor_item is None: + raise ValueError(i18n("v2Pro request-local synthesis 缺少 16k 参考音频")) + sv_emb.append(self.sv_model.compute_embedding3(audio_tensor_item).to(self.configs.device)) + return self.vits_model.decode( + semantic_tokens, + phones, + refer_audio_spec_list, + speed=speed, + sv_emb=sv_emb, + ).detach()[0, 0, :] + + return self.using_vocoder_synthesis_request_local( + semantic_tokens=semantic_tokens, + phones=phones, + prompt_semantic=prompt_semantic, + prompt_phones=prompt_phones, + refer_audio_spec=refer_audio_spec, + raw_audio=raw_audio, + raw_sr=raw_sr, + speed=speed, + sample_steps=sample_steps, + ) + + @torch.inference_mode() + def synthesize_audio_requests_local_batched( + self, + semantic_tokens_list: List[torch.Tensor], + phones_list: List[torch.Tensor], + refer_specs: List[tuple | List[tuple]], + speeds: List[float] | None = None, + sample_steps_list: List[int] | None = None, + ) -> List[torch.Tensor]: + batch_size = len(semantic_tokens_list) + if batch_size == 0: + return [] + if len(phones_list) != batch_size or len(refer_specs) != batch_size: + raise ValueError("batched request-local synthesis 输入长度不一致") + if speeds is None: + speeds = [1.0] * batch_size + if sample_steps_list is None: + sample_steps_list = [32] * batch_size + if len(speeds) != batch_size or len(sample_steps_list) != batch_size: + raise ValueError("batched request-local synthesis 参数长度不一致") + first_speed = float(speeds[0]) + first_sample_steps = int(sample_steps_list[0]) + if any(abs(float(item) - first_speed) > 1e-6 for item in speeds): + raise ValueError("batched request-local synthesis 目前要求 speed 一致") + if any(int(item) != first_sample_steps for item in sample_steps_list): + raise ValueError("batched request-local synthesis 目前要求 sample_steps 一致") + if self.configs.use_vocoder: + raise NotImplementedError("request-local batched VITS synthesis 暂不支持 vocoder 模型") + + device = self.configs.device + max_semantic_len = max(int(item.shape[-1]) for item in semantic_tokens_list) + max_phone_len = max(int(item.shape[-1]) for item in phones_list) + semantic_batch = torch.zeros((1, batch_size, max_semantic_len), dtype=torch.long, device=device) + phone_batch = torch.zeros((batch_size, max_phone_len), dtype=torch.long, device=device) + semantic_lengths = [] + phone_lengths = [] + refer_audio_specs: List[torch.Tensor] = [] + sv_emb_batch = None + sv_emb_list: List[torch.Tensor] = [] + + for batch_index, semantic_tokens in enumerate(semantic_tokens_list): + semantic_len = int(semantic_tokens.shape[-1]) + phone_len = int(phones_list[batch_index].shape[-1]) + semantic_batch[0, batch_index, :semantic_len] = semantic_tokens.to(device=device, dtype=torch.long) + phone_batch[batch_index, :phone_len] = phones_list[batch_index].to(device=device, dtype=torch.long) + semantic_lengths.append(semantic_len) + phone_lengths.append(phone_len) + + refer_spec_item = refer_specs[batch_index] + refer_spec_group = list(refer_spec_item) if isinstance(refer_spec_item, list) else [refer_spec_item] + if len(refer_spec_group) != 1: + raise ValueError("batched request-local synthesis 暂不支持单请求多参考音频") + refer_audio_spec, audio_tensor = refer_spec_group[0] + refer_audio_specs.append(refer_audio_spec.to(dtype=self.precision, device=device)) + if self.is_v2pro: + if audio_tensor is None: + raise ValueError(i18n("v2Pro request-local batched synthesis 缺少 16k 参考音频")) + sv_emb_list.append(self.sv_model.compute_embedding3(audio_tensor).to(device)) + + if self.is_v2pro: + sv_emb_batch = torch.cat(sv_emb_list, dim=0) + + audio_batch, audio_lengths = self.vits_model.decode_batched_request_local( + codes=semantic_batch, + code_lengths=torch.LongTensor(semantic_lengths).to(device), + text=phone_batch, + text_lengths=torch.LongTensor(phone_lengths).to(device), + refer_list=refer_audio_specs, + speed=first_speed, + sv_emb=sv_emb_batch, + ) + audios: List[torch.Tensor] = [] + for batch_index in range(batch_size): + audio_len = int(audio_lengths[batch_index].item()) + audios.append(audio_batch[batch_index, 0, :audio_len].detach()) + return audios + def using_vocoder_synthesis_batched_infer( self, idx_list: List[int], diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py index 283e91c3..c30e3195 100644 --- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py +++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py @@ -1,6 +1,10 @@ +import asyncio import os import sys import threading +import time +from contextlib import contextmanager +from dataclasses import dataclass from tqdm import tqdm @@ -11,11 +15,13 @@ import re import torch from text.LangSegmenter import LangSegmenter from text import chinese -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple from text.cleaner import clean_text from text import cleaned_text_to_sequence from transformers import AutoModelForMaskedLM, AutoTokenizer from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method +from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker +from TTS_infer_pack.text_cpu_preprocess import preprocess_text_segments_payload from tools.i18n.i18n import I18nAuto, scan_language_list @@ -49,12 +55,80 @@ def merge_short_text_in_array(texts: str, threshold: int) -> list: return result +class StageLimiter: + def __init__(self, slots: int): + self.slots = max(1, int(slots)) + self.semaphore = threading.BoundedSemaphore(self.slots) + self.lock = threading.Lock() + self.inflight = 0 + self.peak_inflight = 0 + + @contextmanager + def enter(self): + wait_start = time.perf_counter() + self.semaphore.acquire() + wait_ms = (time.perf_counter() - wait_start) * 1000.0 + with self.lock: + self.inflight += 1 + current_inflight = self.inflight + if current_inflight > self.peak_inflight: + self.peak_inflight = current_inflight + peak_inflight = self.peak_inflight + try: + yield { + "wait_ms": wait_ms, + "inflight": current_inflight, + "peak_inflight": peak_inflight, + "slots": self.slots, + } + finally: + with self.lock: + self.inflight = max(0, self.inflight - 1) + self.semaphore.release() + + def snapshot(self) -> Dict[str, int]: + with self.lock: + return { + "slots": self.slots, + "inflight": self.inflight, + "peak_inflight": self.peak_inflight, + } + + +@dataclass +class PreparedTextSegment: + language: str + phones: List[int] + word2ph: Optional[List[int]] + norm_text: str + needs_g2pw: bool = False + + class TextPreprocessor: - def __init__(self, bert_model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, device: torch.device): + def __init__( + self, + bert_model: AutoModelForMaskedLM, + tokenizer: AutoTokenizer, + device: torch.device, + version: str = "v2", + bert_stage_limiter: StageLimiter | None = None, + bert_batch_worker: PrepareBertBatchWorker | None = None, + ): self.bert_model = bert_model self.tokenizer = tokenizer self.device = device - self.bert_lock = threading.RLock() + self.version = str(version) + self.bert_stage_limiter = bert_stage_limiter + self.bert_batch_worker = bert_batch_worker + + def snapshot(self) -> Dict[str, object]: + return { + "device": str(self.device), + "bert_stage_limiter": ( + None if self.bert_stage_limiter is None else dict(self.bert_stage_limiter.snapshot()) + ), + "bert_batch_worker": None if self.bert_batch_worker is None else dict(self.bert_batch_worker.snapshot()), + } def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]: print(f"############ {i18n('切分文本')} ############") @@ -98,7 +172,7 @@ class TextPreprocessor: # 解决输入目标文本的空行导致报错的问题 if len(text.strip()) == 0: continue - if not re.sub("\W+", "", text): + if not re.sub(r"\W+", "", text): # 检测一下,如果是纯符号,就跳过。 continue if text[-1] not in splits: @@ -115,86 +189,233 @@ class TextPreprocessor: return texts def segment_and_extract_feature_for_text( - self, text: str, language: str, version: str = "v1" + self, text: str, language: str, version: str = "v1", profile: Dict | None = None ) -> Tuple[list, torch.Tensor, str]: - return self.get_phones_and_bert(text, language, version) + prepared_segments = self.preprocess_text_segments(text, language, version) + return self.build_phones_and_bert_from_segments(prepared_segments, profile=profile) - def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False): - with self.bert_lock: - text = re.sub(r' {2,}', ' ', text) - textlist = [] - langlist = [] - if language == "all_zh": - for tmp in LangSegmenter.getTexts(text,"zh"): + def _split_text_by_language(self, text: str, language: str) -> Tuple[List[str], List[str]]: + textlist = [] + langlist = [] + if language == "all_zh": + for tmp in LangSegmenter.getTexts(text, "zh"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_yue": + for tmp in LangSegmenter.getTexts(text, "zh"): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_ja": + for tmp in LangSegmenter.getTexts(text, "ja"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_ko": + for tmp in LangSegmenter.getTexts(text, "ko"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "en": + langlist.append("en") + textlist.append(text) + elif language == "auto": + for tmp in LangSegmenter.getTexts(text): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "auto_yue": + for tmp in LangSegmenter.getTexts(text): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + else: + for tmp in LangSegmenter.getTexts(text): + if langlist: + same_group = (tmp["lang"] == "en" and langlist[-1] == "en") or ( + tmp["lang"] != "en" and langlist[-1] != "en" + ) + if same_group: + textlist[-1] += tmp["text"] + continue + if tmp["lang"] == "en": langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "all_yue": - for tmp in LangSegmenter.getTexts(text,"zh"): - if tmp["lang"] == "zh": - tmp["lang"] = "yue" - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "all_ja": - for tmp in LangSegmenter.getTexts(text,"ja"): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "all_ko": - for tmp in LangSegmenter.getTexts(text,"ko"): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "en": - langlist.append("en") - textlist.append(text) - elif language == "auto": - for tmp in LangSegmenter.getTexts(text): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "auto_yue": - for tmp in LangSegmenter.getTexts(text): - if tmp["lang"] == "zh": - tmp["lang"] = "yue" - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - else: - for tmp in LangSegmenter.getTexts(text): - if langlist: - if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"): - textlist[-1] += tmp["text"] - continue - if tmp["lang"] == "en": - langlist.append(tmp["lang"]) - else: - # 因无法区别中日韩文汉字,以用户输入为准 - langlist.append(language) - textlist.append(tmp["text"]) - # print(textlist) - # print(langlist) - phones_list = [] - bert_list = [] - norm_text_list = [] - for i in range(len(textlist)): - lang = langlist[i] - phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version) - bert = self.get_bert_inf(phones, word2ph, norm_text, lang) - phones_list.append(phones) - norm_text_list.append(norm_text) - bert_list.append(bert) - bert = torch.cat(bert_list, dim=1) - phones = sum(phones_list, []) - norm_text = "".join(norm_text_list) + else: + langlist.append(language) + textlist.append(tmp["text"]) + return textlist, langlist - if not final and len(phones) < 6: - return self.get_phones_and_bert("." + text, language, version, final=True) + def get_phones_and_bert( + self, text: str, language: str, version: str, final: bool = False, profile: Dict | None = None + ): + prepared_segments = self.preprocess_text_segments(text, language, version, final=final) + return self.build_phones_and_bert_from_segments(prepared_segments, profile=profile) - return phones, bert, norm_text + def preprocess_text_segments( + self, + text: str, + language: str, + version: str, + final: bool = False, + ) -> List[PreparedTextSegment]: + payloads = preprocess_text_segments_payload(text, language, version, final=final) + return [ + PreparedTextSegment( + language=str(payload["language"]), + phones=list(payload["phones"]), + word2ph=None if payload["word2ph"] is None else list(payload["word2ph"]), + norm_text=str(payload["norm_text"]), + needs_g2pw=bool(payload.get("needs_g2pw", False)), + ) + for payload in payloads + ] - def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor: - with torch.no_grad(): - inputs = self.tokenizer(text, return_tensors="pt") - for i in inputs: - inputs[i] = inputs[i].to(self.device) - res = self.bert_model(**inputs, output_hidden_states=True) - res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + def resolve_g2pw_segments( + self, + prepared_segments: List[PreparedTextSegment], + profile: Dict | None = None, + ) -> List[PreparedTextSegment]: + zh_indices = [index for index, segment in enumerate(prepared_segments) if bool(segment.needs_g2pw)] + if not zh_indices: + return prepared_segments + from text import chinese2 + + normalized_segments = [prepared_segments[index].norm_text for index in zh_indices] + resolved_segments, g2pw_profile = chinese2.g2p_segments(normalized_segments, return_profile=True) + self._accumulate_profile(profile, "g2pw_prepare_ms", g2pw_profile.get("g2pw_prepare_ms", 0.0)) + self._accumulate_profile(profile, "g2pw_predict_ms", g2pw_profile.get("g2pw_predict_ms", 0.0)) + self._accumulate_profile(profile, "g2pw_post_ms", g2pw_profile.get("g2pw_post_ms", 0.0)) + self._accumulate_profile(profile, "g2pw_total_ms", g2pw_profile.get("g2pw_total_ms", 0.0)) + self._accumulate_profile(profile, "g2pw_runtime_total_ms", g2pw_profile.get("g2pw_runtime_total_ms", 0.0)) + self._accumulate_profile(profile, "g2pw_runtime_queue_wait_ms", g2pw_profile.get("g2pw_runtime_queue_wait_ms", 0.0)) + self._accumulate_profile( + profile, + "g2pw_runtime_collect_wait_ms", + g2pw_profile.get("g2pw_runtime_collect_wait_ms", 0.0), + ) + self._accumulate_profile(profile, "g2pw_runtime_run_ms", g2pw_profile.get("g2pw_runtime_run_ms", 0.0)) + self._update_profile_peak( + profile, + "g2pw_runtime_batch_rows_peak", + g2pw_profile.get("g2pw_runtime_batch_rows", 0.0), + ) + self._update_profile_peak( + profile, + "g2pw_runtime_batch_requests_peak", + g2pw_profile.get("g2pw_runtime_batch_requests", 0.0), + ) + self._update_profile_peak( + profile, + "g2pw_runtime_pool_workers", + g2pw_profile.get("g2pw_runtime_pool_workers", 0.0), + ) + for index, (phones, word2ph, norm_text) in zip(zh_indices, resolved_segments): + prepared_segments[index] = PreparedTextSegment( + language=prepared_segments[index].language, + phones=list(cleaned_text_to_sequence(phones, self.version)), + word2ph=None if word2ph is None else list(word2ph), + norm_text=str(norm_text), + needs_g2pw=False, + ) + return prepared_segments + + def build_phones_and_bert_from_segments( + self, + prepared_segments: List[PreparedTextSegment], + profile: Dict | None = None, + ) -> Tuple[list, torch.Tensor, str]: + prepared_segments = self.resolve_g2pw_segments(prepared_segments, profile=profile) + phones_list: List[List[int]] = [] + bert_list: List[torch.Tensor] = [] + norm_text_list: List[str] = [] + for segment in prepared_segments: + bert = self.get_bert_inf( + segment.phones, + segment.word2ph, + segment.norm_text, + segment.language, + profile=profile, + ) + phones_list.append(segment.phones) + norm_text_list.append(segment.norm_text) + bert_list.append(bert) + bert = torch.cat(bert_list, dim=1) + phones = sum(phones_list, []) + norm_text = "".join(norm_text_list) + return phones, bert, norm_text + + def _accumulate_profile(self, profile: Dict | None, key: str, value: float) -> None: + if profile is None: + return + profile[key] = float(profile.get(key, 0.0)) + float(value) + + def _update_profile_peak(self, profile: Dict | None, key: str, value: float) -> None: + if profile is None: + return + profile[key] = float(max(float(profile.get(key, 0.0)), float(value))) + + def _merge_bert_worker_profile(self, profile: Dict | None, worker_profile: Dict[str, float]) -> None: + self._accumulate_profile(profile, "bert_wait_ms", worker_profile.get("bert_wait_ms", 0.0)) + self._accumulate_profile(profile, "bert_admission_wait_ms", worker_profile.get("bert_admission_wait_ms", 0.0)) + self._accumulate_profile(profile, "bert_queue_wait_ms", worker_profile.get("bert_queue_wait_ms", 0.0)) + self._accumulate_profile( + profile, + "bert_batch_collect_wait_ms", + worker_profile.get("bert_batch_collect_wait_ms", 0.0), + ) + self._accumulate_profile(profile, "bert_forward_ms", worker_profile.get("bert_forward_ms", 0.0)) + self._accumulate_profile(profile, "bert_tokenize_ms", worker_profile.get("bert_tokenize_ms", 0.0)) + self._accumulate_profile(profile, "bert_scatter_ms", worker_profile.get("bert_scatter_ms", 0.0)) + self._accumulate_profile(profile, "bert_calls", worker_profile.get("bert_calls", 1.0)) + self._update_profile_peak(profile, "bert_stage_inflight_peak", worker_profile.get("bert_stage_inflight_peak", 0.0)) + self._update_profile_peak(profile, "bert_batch_size_peak", worker_profile.get("bert_batch_size", 0.0)) + self._update_profile_peak(profile, "bert_batch_tokens_peak", worker_profile.get("bert_batch_tokens", 0.0)) + self._update_profile_peak( + profile, + "bert_pending_depth_on_enqueue_peak", + worker_profile.get("bert_pending_depth_on_enqueue", 0.0), + ) + self._update_profile_peak( + profile, + "bert_pending_depth_on_collect_peak", + worker_profile.get("bert_pending_depth_on_collect", 0.0), + ) + self._update_profile_peak(profile, "bert_high_pressure_mode_peak", worker_profile.get("bert_high_pressure_mode", 0.0)) + if profile is not None: + profile["bert_stage_slots"] = float(worker_profile.get("bert_stage_slots", 0.0)) + profile["bert_batch_window_ms"] = float(worker_profile.get("bert_batch_window_ms", 0.0)) + + def get_bert_feature(self, text: str, word2ph: list, profile: Dict | None = None) -> torch.Tensor: + if self.bert_batch_worker is not None: + feature, worker_profile = self.bert_batch_worker.submit(text, word2ph) + self._merge_bert_worker_profile(profile, worker_profile) + return feature + + limiter_stats = {"wait_ms": 0.0, "inflight": 1, "peak_inflight": 1, "slots": 0} + if self.bert_stage_limiter is None: + forward_start = time.perf_counter() + with torch.no_grad(): + inputs = self.tokenizer(text, return_tensors="pt") + for i in inputs: + inputs[i] = inputs[i].to(self.device) + res = self.bert_model(**inputs, output_hidden_states=True) + res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + else: + with self.bert_stage_limiter.enter() as limiter_stats: + forward_start = time.perf_counter() + with torch.no_grad(): + inputs = self.tokenizer(text, return_tensors="pt") + for i in inputs: + inputs[i] = inputs[i].to(self.device) + res = self.bert_model(**inputs, output_hidden_states=True) + res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + self._accumulate_profile(profile, "bert_wait_ms", limiter_stats["wait_ms"]) + self._accumulate_profile(profile, "bert_forward_ms", forward_ms) + self._accumulate_profile(profile, "bert_calls", 1.0) + self._update_profile_peak(profile, "bert_stage_inflight_peak", limiter_stats["peak_inflight"]) + if profile is not None: + profile["bert_stage_slots"] = float(limiter_stats["slots"]) assert len(word2ph) == len(text) phone_level_feature = [] for i in range(len(word2ph)): @@ -209,10 +430,19 @@ class TextPreprocessor: phones = cleaned_text_to_sequence(phones, version) return phones, word2ph, norm_text - def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str): + def get_bert_inf( + self, + phones: list, + word2ph: Optional[list], + norm_text: str, + language: str, + profile: Dict | None = None, + ): language = language.replace("all_", "") if language == "zh": - feature = self.get_bert_feature(norm_text, word2ph).to(self.device) + if word2ph is None: + raise ValueError("中文文本缺少 word2ph,无法提取 BERT 特征") + feature = self.get_bert_feature(norm_text, word2ph, profile=profile).to(self.device) else: feature = torch.zeros( (1024, len(phones)), @@ -221,6 +451,115 @@ class TextPreprocessor: return feature + async def build_phones_and_bert_from_segments_async( + self, + prepared_segments: List[PreparedTextSegment], + profile: Dict | None = None, + ) -> Tuple[list, torch.Tensor, str]: + prepared_segments = self.resolve_g2pw_segments(prepared_segments, profile=profile) + segment_jobs = self._build_async_segment_jobs(prepared_segments, profile) + pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = [] + for segment_index, segment in enumerate(prepared_segments): + if segment.language.replace("all_", "") != "zh" or self.bert_batch_worker is None: + continue + if segment.word2ph is None: + raise ValueError("中文文本缺少 word2ph,无法提取 BERT 特征") + pending_items.append( + ( + segment_jobs["bert_list"], + segment_index, + profile, + self.bert_batch_worker.submit_async(segment.norm_text, segment.word2ph), + ) + ) + + if pending_items: + pending_results = await asyncio.gather(*[future for _, _, _, future in pending_items]) + for (bert_list, bert_index, item_profile, _), (feature, worker_profile) in zip(pending_items, pending_results): + self._merge_bert_worker_profile(item_profile, worker_profile) + bert_list[bert_index] = feature.to(self.device) + + return self._finalize_async_segment_jobs(segment_jobs) + + def _build_async_segment_jobs( + self, + prepared_segments: List[PreparedTextSegment], + profile: Dict | None, + ) -> Dict[str, List]: + phones_list: List[List[int]] = [] + bert_list: List[torch.Tensor | None] = [] + norm_text_list: List[str] = [] + + for segment in prepared_segments: + phones_list.append(segment.phones) + norm_text_list.append(segment.norm_text) + segment_language = segment.language.replace("all_", "") + if segment_language == "zh" and self.bert_batch_worker is not None: + if segment.word2ph is None: + raise ValueError("中文文本缺少 word2ph,无法提取 BERT 特征") + bert_list.append(None) + continue + bert_list.append( + self.get_bert_inf( + segment.phones, + segment.word2ph, + segment.norm_text, + segment.language, + profile=profile, + ) + ) + return { + "phones_list": phones_list, + "bert_list": bert_list, + "norm_text_list": norm_text_list, + } + + @staticmethod + def _finalize_async_segment_jobs(segment_jobs: Dict[str, List]) -> Tuple[list, torch.Tensor, str]: + bert = torch.cat([feature for feature in segment_jobs["bert_list"] if feature is not None], dim=1) + phones = sum(segment_jobs["phones_list"], []) + norm_text = "".join(segment_jobs["norm_text_list"]) + return phones, bert, norm_text + + async def build_phones_and_bert_pair_from_segments_async( + self, + prompt_segments: List[PreparedTextSegment], + target_segments: List[PreparedTextSegment], + prompt_profile: Dict | None = None, + target_profile: Dict | None = None, + ) -> Tuple[Tuple[list, torch.Tensor, str], Tuple[list, torch.Tensor, str]]: + prompt_segments = self.resolve_g2pw_segments(prompt_segments, profile=prompt_profile) + target_segments = self.resolve_g2pw_segments(target_segments, profile=target_profile) + prompt_jobs = self._build_async_segment_jobs(prompt_segments, prompt_profile) + target_jobs = self._build_async_segment_jobs(target_segments, target_profile) + pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = [] + + for segment_jobs, prepared_segments, profile in ( + (prompt_jobs, prompt_segments, prompt_profile), + (target_jobs, target_segments, target_profile), + ): + for segment_index, segment in enumerate(prepared_segments): + if segment.language.replace("all_", "") != "zh" or self.bert_batch_worker is None: + continue + if segment.word2ph is None: + raise ValueError("中文文本缺少 word2ph,无法提取 BERT 特征") + pending_items.append( + ( + segment_jobs["bert_list"], + segment_index, + profile, + self.bert_batch_worker.submit_async(segment.norm_text, segment.word2ph), + ) + ) + + if pending_items: + pending_results = await asyncio.gather(*[future for _, _, _, future in pending_items]) + for (bert_list, bert_index, profile, _), (feature, worker_profile) in zip(pending_items, pending_results): + self._merge_bert_worker_profile(profile, worker_profile) + bert_list[bert_index] = feature.to(self.device) + + return self._finalize_async_segment_jobs(prompt_jobs), self._finalize_async_segment_jobs(target_jobs) + def filter_text(self, texts): _text = [] if all(text in [None, " ", "\n", ""] for text in texts): @@ -236,4 +575,4 @@ class TextPreprocessor: punctuations = "".join(re.escape(p) for p in punctuation) pattern = f"([{punctuations}])([{punctuations}])+" result = re.sub(pattern, r"\1", text) - return result \ No newline at end of file + return result diff --git a/GPT_SoVITS/TTS_infer_pack/__init__.py b/GPT_SoVITS/TTS_infer_pack/__init__.py index 8579a632..09a257b2 100644 --- a/GPT_SoVITS/TTS_infer_pack/__init__.py +++ b/GPT_SoVITS/TTS_infer_pack/__init__.py @@ -1 +1,11 @@ -from . import TTS, text_segmentation_method +from __future__ import annotations + +import importlib + +__all__ = ["TTS", "TextPreprocessor", "text_segmentation_method", "t2s_scheduler"] + + +def __getattr__(name: str): + if name in __all__: + return importlib.import_module(f"{__name__}.{name}") + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py new file mode 100644 index 00000000..1ac77faa --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py @@ -0,0 +1,346 @@ +import asyncio +import threading +import time +import uuid +from collections import deque +from dataclasses import dataclass, field +from typing import Deque, Dict, List, Tuple + +import torch + + +@dataclass +class BertFeatureTask: + norm_text: str + word2ph: List[int] + task_id: str = field(default_factory=lambda: uuid.uuid4().hex) + created_at: float = field(default_factory=time.perf_counter) + enqueued_at: float = 0.0 + admission_wait_ms: float = 0.0 + pending_depth_on_enqueue: int = 0 + done_event: threading.Event = field(default_factory=threading.Event) + done_loop: asyncio.AbstractEventLoop | None = None + done_future: asyncio.Future | None = None + result_feature: torch.Tensor | None = None + error: Exception | None = None + profile: Dict[str, float] = field(default_factory=dict) + + +class PrepareBertBatchWorker: + def __init__( + self, + bert_model, + tokenizer, + device, + stage_limiter=None, + batch_window_ms: int = 5, + max_batch_items: int = 16, + max_batch_tokens: int = 4096, + max_pending_tasks: int = 0, + admission_poll_ms: int = 1, + high_pressure_pending_threshold: int = 0, + high_pressure_batch_window_ms: int | None = None, + high_pressure_max_batch_items: int | None = None, + high_pressure_max_batch_tokens: int | None = None, + ): + self.bert_model = bert_model + self.tokenizer = tokenizer + self.device = device + self.stage_limiter = stage_limiter + self.batch_window_ms = max(0, int(batch_window_ms)) + self.batch_window_s = float(self.batch_window_ms) / 1000.0 + self.max_batch_items = max(1, int(max_batch_items)) + self.max_batch_tokens = max(16, int(max_batch_tokens)) + self.max_pending_tasks = max(0, int(max_pending_tasks)) + self.admission_poll_s = max(0.0005, float(max(1, int(admission_poll_ms))) / 1000.0) + + self.high_pressure_pending_threshold = max( + 0, + int(high_pressure_pending_threshold) + if int(high_pressure_pending_threshold) > 0 + else max(self.max_batch_items * 2, 32), + ) + hp_window_ms = self.batch_window_ms if high_pressure_batch_window_ms is None else int(high_pressure_batch_window_ms) + hp_items = self.max_batch_items if high_pressure_max_batch_items is None else int(high_pressure_max_batch_items) + hp_tokens = self.max_batch_tokens if high_pressure_max_batch_tokens is None else int(high_pressure_max_batch_tokens) + self.high_pressure_batch_window_ms = max(0, hp_window_ms) + self.high_pressure_batch_window_s = float(self.high_pressure_batch_window_ms) / 1000.0 + self.high_pressure_max_batch_items = max(self.max_batch_items, hp_items) + self.high_pressure_max_batch_tokens = max(self.max_batch_tokens, hp_tokens) + + self.condition = threading.Condition() + self.pending_tasks: Deque[BertFeatureTask] = deque() + self.pending_peak = 0 + self.total_submitted = 0 + self.total_finished = 0 + self.total_batches = 0 + self.active_batch_size = 0 + self.active_batch_peak = 0 + self.active_batch_tokens = 0 + self.active_batch_tokens_peak = 0 + self.high_pressure_batches = 0 + self.admission_wait_total_ms = 0.0 + self.admission_wait_peak_ms = 0.0 + self.worker_thread = threading.Thread(target=self._run_loop, name="prepare-bert-batch-worker", daemon=True) + self.worker_thread.start() + + def _estimate_task_tokens(self, task: BertFeatureTask) -> int: + return max(1, len(task.norm_text) + 2) + + def _can_enqueue_locked(self) -> bool: + if self.max_pending_tasks <= 0: + return True + return (len(self.pending_tasks) + self.active_batch_size) < self.max_pending_tasks + + def _record_enqueue_locked(self, task: BertFeatureTask, admission_wait_ms: float) -> None: + task.admission_wait_ms = float(max(0.0, admission_wait_ms)) + task.enqueued_at = time.perf_counter() + task.pending_depth_on_enqueue = int(len(self.pending_tasks)) + self.pending_tasks.append(task) + self.total_submitted += 1 + self.admission_wait_total_ms += task.admission_wait_ms + self.admission_wait_peak_ms = max(self.admission_wait_peak_ms, task.admission_wait_ms) + if len(self.pending_tasks) > self.pending_peak: + self.pending_peak = len(self.pending_tasks) + self.condition.notify_all() + + def _enqueue_task(self, task: BertFeatureTask) -> None: + admission_started = time.perf_counter() + with self.condition: + while not self._can_enqueue_locked(): + self.condition.wait(timeout=self.admission_poll_s) + self._record_enqueue_locked(task, (time.perf_counter() - admission_started) * 1000.0) + + async def _enqueue_task_async(self, task: BertFeatureTask) -> None: + admission_started = time.perf_counter() + while True: + with self.condition: + if self._can_enqueue_locked(): + self._record_enqueue_locked(task, (time.perf_counter() - admission_started) * 1000.0) + return + await asyncio.sleep(self.admission_poll_s) + + def submit(self, norm_text: str, word2ph: List[int]) -> Tuple[torch.Tensor, Dict[str, float]]: + task = BertFeatureTask(norm_text=str(norm_text), word2ph=list(word2ph)) + self._enqueue_task(task) + task.done_event.wait() + if task.error is not None: + raise task.error + assert task.result_feature is not None + return task.result_feature, dict(task.profile) + + async def submit_async(self, norm_text: str, word2ph: List[int]) -> Tuple[torch.Tensor, Dict[str, float]]: + loop = asyncio.get_running_loop() + task = BertFeatureTask( + norm_text=str(norm_text), + word2ph=list(word2ph), + done_loop=loop, + done_future=loop.create_future(), + ) + await self._enqueue_task_async(task) + return await task.done_future + + def snapshot(self) -> Dict[str, int]: + with self.condition: + return { + "pending": len(self.pending_tasks), + "pending_peak": self.pending_peak, + "total_submitted": self.total_submitted, + "total_finished": self.total_finished, + "total_batches": self.total_batches, + "active_batch_size": self.active_batch_size, + "active_batch_peak": self.active_batch_peak, + "active_batch_tokens": self.active_batch_tokens, + "active_batch_tokens_peak": self.active_batch_tokens_peak, + "batch_window_ms": int(self.batch_window_s * 1000.0), + "max_batch_items": self.max_batch_items, + "max_batch_tokens": self.max_batch_tokens, + "max_pending_tasks": self.max_pending_tasks, + "high_pressure_pending_threshold": self.high_pressure_pending_threshold, + "high_pressure_batch_window_ms": self.high_pressure_batch_window_ms, + "high_pressure_max_batch_items": self.high_pressure_max_batch_items, + "high_pressure_max_batch_tokens": self.high_pressure_max_batch_tokens, + "high_pressure_batches": self.high_pressure_batches, + "admission_wait_total_ms": self.admission_wait_total_ms, + "admission_wait_peak_ms": self.admission_wait_peak_ms, + } + + def _select_batch_policy_locked(self) -> Tuple[float, int, int, bool, int]: + pending_depth = len(self.pending_tasks) + use_high_pressure = ( + self.high_pressure_pending_threshold > 0 + and pending_depth >= self.high_pressure_pending_threshold + ) + if use_high_pressure: + return ( + self.high_pressure_batch_window_s, + self.high_pressure_max_batch_items, + self.high_pressure_max_batch_tokens, + True, + pending_depth, + ) + return ( + self.batch_window_s, + self.max_batch_items, + self.max_batch_tokens, + False, + pending_depth, + ) + + def _collect_batch(self) -> Tuple[List[BertFeatureTask], Dict[str, float]]: + with self.condition: + while not self.pending_tasks: + self.condition.wait() + + collect_started = time.perf_counter() + batch_window_s, max_batch_items, max_batch_tokens, use_high_pressure, pending_depth_on_collect = ( + self._select_batch_policy_locked() + ) + batch: List[BertFeatureTask] = [self.pending_tasks.popleft()] + batch_tokens = self._estimate_task_tokens(batch[0]) + deadline = time.perf_counter() + batch_window_s + + while len(batch) < max_batch_items: + remaining = deadline - time.perf_counter() + if remaining <= 0: + break + if not self.pending_tasks: + self.condition.wait(timeout=remaining) + continue + next_task = self.pending_tasks[0] + next_tokens = self._estimate_task_tokens(next_task) + if len(batch) >= max_batch_items or (batch_tokens + next_tokens) > max_batch_tokens: + break + batch.append(self.pending_tasks.popleft()) + batch_tokens += next_tokens + + self.active_batch_size = len(batch) + self.active_batch_tokens = batch_tokens + if self.active_batch_size > self.active_batch_peak: + self.active_batch_peak = self.active_batch_size + if self.active_batch_tokens > self.active_batch_tokens_peak: + self.active_batch_tokens_peak = self.active_batch_tokens + if use_high_pressure: + self.high_pressure_batches += 1 + return batch, { + "collect_wait_ms": (time.perf_counter() - collect_started) * 1000.0, + "batch_tokens": float(batch_tokens), + "pending_depth_on_collect": float(pending_depth_on_collect), + "high_pressure_mode": 1.0 if use_high_pressure else 0.0, + "batch_window_ms": float(self.high_pressure_batch_window_ms if use_high_pressure else self.batch_window_ms), + } + + def _finalize_batch(self, batch: List[BertFeatureTask]) -> None: + with self.condition: + self.active_batch_size = 0 + self.active_batch_tokens = 0 + self.total_batches += 1 + self.total_finished += len(batch) + self.condition.notify_all() + + def _run_batch(self, batch: List[BertFeatureTask], batch_meta: Dict[str, float]) -> None: + batch_started = time.perf_counter() + texts = [task.norm_text for task in batch] + batch_tokens = int(batch_meta["batch_tokens"]) + + limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0} + if self.stage_limiter is None: + tokenize_start = time.perf_counter() + inputs = self.tokenizer(texts, return_tensors="pt", padding=True) + tokenize_ms = (time.perf_counter() - tokenize_start) * 1000.0 + attention_mask_cpu = inputs["attention_mask"].cpu() + for key in inputs: + inputs[key] = inputs[key].to(self.device) + forward_start = time.perf_counter() + with torch.no_grad(): + outputs = self.bert_model(**inputs, output_hidden_states=True) + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + else: + with self.stage_limiter.enter() as limiter_stats: + tokenize_start = time.perf_counter() + inputs = self.tokenizer(texts, return_tensors="pt", padding=True) + tokenize_ms = (time.perf_counter() - tokenize_start) * 1000.0 + attention_mask_cpu = inputs["attention_mask"].cpu() + for key in inputs: + inputs[key] = inputs[key].to(self.device) + forward_start = time.perf_counter() + with torch.no_grad(): + outputs = self.bert_model(**inputs, output_hidden_states=True) + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + + hidden = outputs["hidden_states"][-3].detach().cpu() + scatter_start = time.perf_counter() + for batch_index, task in enumerate(batch): + try: + text_len = len(task.word2ph) + if text_len != len(task.norm_text): + raise AssertionError( + f"word2ph/text length mismatch: task={task.task_id} word2ph={text_len} text={len(task.norm_text)}" + ) + seq_len = int(attention_mask_cpu[batch_index].sum().item()) + char_features = hidden[batch_index, 1 : seq_len - 1] + if char_features.shape[0] != text_len: + raise AssertionError( + f"bert token length mismatch: task={task.task_id} token_len={char_features.shape[0]} text_len={text_len}" + ) + phone_level_feature = [] + for char_index, repeat_count in enumerate(task.word2ph): + phone_level_feature.append(char_features[char_index].repeat(repeat_count, 1)) + task.result_feature = torch.cat(phone_level_feature, dim=0).T + task.profile = { + "bert_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]), + "bert_admission_wait_ms": float(task.admission_wait_ms), + "bert_queue_wait_ms": max(0.0, (batch_started - task.enqueued_at) * 1000.0), + "bert_batch_collect_wait_ms": float(batch_meta["collect_wait_ms"]), + "bert_forward_ms": float(forward_ms), + "bert_tokenize_ms": float(tokenize_ms), + "bert_scatter_ms": 0.0, + "bert_calls": 1.0, + "bert_stage_slots": float(limiter_stats["slots"]), + "bert_stage_inflight_peak": float(limiter_stats["peak_inflight"]), + "bert_batch_size": float(len(batch)), + "bert_batch_tokens": float(batch_tokens), + "bert_pending_depth_on_enqueue": float(task.pending_depth_on_enqueue), + "bert_pending_depth_on_collect": float(batch_meta["pending_depth_on_collect"]), + "bert_high_pressure_mode": float(batch_meta["high_pressure_mode"]), + "bert_batch_window_ms": float(batch_meta["batch_window_ms"]), + } + except Exception as exc: # noqa: PERF203 + task.error = exc + scatter_ms = (time.perf_counter() - scatter_start) * 1000.0 + for task in batch: + if task.result_feature is not None: + task.profile["bert_scatter_ms"] = float(scatter_ms) + task.done_event.set() + self._notify_done_future(task) + + @staticmethod + def _resolve_done_future(task: BertFeatureTask) -> None: + if task.done_future is None or task.done_future.done(): + return + if task.error is not None: + task.done_future.set_exception(task.error) + return + assert task.result_feature is not None + task.done_future.set_result((task.result_feature, dict(task.profile))) + + def _notify_done_future(self, task: BertFeatureTask) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_done_future, task) + except RuntimeError: + pass + + def _run_loop(self) -> None: + while True: + batch, batch_meta = self._collect_batch() + try: + self._run_batch(batch, batch_meta) + except Exception as exc: # noqa: PERF203 + for task in batch: + task.error = exc + task.done_event.set() + self._notify_done_future(task) + finally: + self._finalize_batch(batch) diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py new file mode 100644 index 00000000..e74e0de4 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py @@ -0,0 +1,1066 @@ +from __future__ import annotations + +import asyncio +import concurrent.futures +import os +import threading +import time +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( + PreparedTextFeatures, + SchedulerRequestSpec, + T2SRequestState, + build_empty_text_features, + build_request_state_from_parts, + normalize_sentence, +) + + +@dataclass +class ProfiledResult: + result: Any + submit_at: float + started_at: float + finished_at: float + profile: Dict[str, float] | None = None + + @property + def queue_ms(self) -> float: + return max(0.0, (self.started_at - self.submit_at) * 1000.0) + + @property + def run_ms(self) -> float: + return max(0.0, (self.finished_at - self.started_at) * 1000.0) + + +@dataclass +class PreparedCpuStage: + spec: SchedulerRequestSpec + prepare_submit_at: float + prepare_start: float + prompt_text: str + text: str + prepare_admission_wait_ms: float + current_inflight: int + peak_inflight: int + prompt_cpu_profiled: ProfiledResult + target_cpu_profiled: ProfiledResult + + +class AsyncStageGate: + def __init__(self, max_inflight: int, poll_ms: int = 1): + self.max_inflight = max(0, int(max_inflight)) + self.lock = threading.Lock() + self.poll_s = max(0.0005, float(max(1, int(poll_ms))) / 1000.0) + self.inflight = 0 + self.peak_inflight = 0 + self.total_entered = 0 + self.total_wait_ms = 0.0 + self.wait_peak_ms = 0.0 + + async def acquire(self) -> Dict[str, float]: + wait_start = time.perf_counter() + while True: + with self.lock: + if self.max_inflight <= 0 or self.inflight < self.max_inflight: + self.inflight += 1 + self.total_entered += 1 + wait_ms = max(0.0, (time.perf_counter() - wait_start) * 1000.0) + self.total_wait_ms += float(wait_ms) + self.wait_peak_ms = max(self.wait_peak_ms, float(wait_ms)) + self.peak_inflight = max(self.peak_inflight, self.inflight) + return { + "wait_ms": float(wait_ms), + "inflight": float(self.inflight), + "peak_inflight": float(self.peak_inflight), + "max_inflight": float(self.max_inflight), + } + await asyncio.sleep(self.poll_s) + + def release(self) -> None: + with self.lock: + self.inflight = max(0, self.inflight - 1) + + def snapshot(self) -> Dict[str, float]: + with self.lock: + return { + "max_inflight": float(self.max_inflight), + "inflight": float(self.inflight), + "peak_inflight": float(self.peak_inflight), + "total_entered": float(self.total_entered), + "total_wait_ms": float(self.total_wait_ms), + "wait_peak_ms": float(self.wait_peak_ms), + } + + +class PrepareCoordinator: + def __init__(self, tts: Any): + self.tts = tts + self.lock = threading.Lock() + self.inflight = 0 + self.peak_inflight = 0 + self.use_async_text_feature_path = bool( + getattr(tts, "prepare_bert_batch_worker", None) is not None + and os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_DIRECT", "0") != "0" + ) + self.max_inflight = max(0, int(os.environ.get("GPTSOVITS_PREPARE_MAX_INFLIGHT", "0"))) + gate_poll_ms = int(os.environ.get("GPTSOVITS_PREPARE_GATE_POLL_MS", "1")) + self._inflight_gate = AsyncStageGate(self.max_inflight, poll_ms=gate_poll_ms) + self.text_feature_workers = 0 + self.text_feature_executor = None + if not self.use_async_text_feature_path: + text_feature_default_workers = max(1, int(getattr(tts, "prepare_text_cpu_workers", 16) or 16)) + self.text_feature_workers = max( + 1, + int(os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_WORKERS", str(text_feature_default_workers))), + ) + self.text_feature_executor = concurrent.futures.ThreadPoolExecutor( + max_workers=self.text_feature_workers, + thread_name_prefix="prepare-text-feature", + ) + g2pw_default_workers = max(8, int(getattr(tts, "prepare_text_cpu_workers", 8) or 8)) + self.g2pw_workers = max( + 1, + int(os.environ.get("GPTSOVITS_PREPARE_G2PW_WORKERS", str(g2pw_default_workers))), + ) + self.g2pw_executor = concurrent.futures.ThreadPoolExecutor( + max_workers=self.g2pw_workers, + thread_name_prefix="prepare-g2pw", + ) + ref_audio_default_workers = max(1, int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "4"))) + self.ref_audio_workers = max( + 1, + int(os.environ.get("GPTSOVITS_PREPARE_REF_ASYNC_WORKERS", str(ref_audio_default_workers))), + ) + self.ref_audio_executor = concurrent.futures.ThreadPoolExecutor( + max_workers=self.ref_audio_workers, + thread_name_prefix="prepare-ref-audio", + ) + text_cpu_gate_default = max(0, int(getattr(tts, "prepare_text_cpu_workers", 0) or 0)) + g2pw_gate_default = max(0, int(self.g2pw_workers)) + text_feature_gate_default = max(0, int(self.text_feature_workers)) + ref_audio_gate_default = max(0, int(self.ref_audio_workers)) + self.text_cpu_gate = AsyncStageGate( + int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_MAX_INFLIGHT", str(text_cpu_gate_default))), + poll_ms=gate_poll_ms, + ) + self.g2pw_gate = AsyncStageGate( + int(os.environ.get("GPTSOVITS_PREPARE_G2PW_MAX_INFLIGHT", str(g2pw_gate_default))), + poll_ms=gate_poll_ms, + ) + self.text_feature_gate = AsyncStageGate( + int(os.environ.get("GPTSOVITS_PREPARE_TEXT_FEATURE_MAX_INFLIGHT", str(text_feature_gate_default))), + poll_ms=gate_poll_ms, + ) + self.ref_audio_gate = AsyncStageGate( + int(os.environ.get("GPTSOVITS_PREPARE_REF_MAX_INFLIGHT", str(ref_audio_gate_default))), + poll_ms=gate_poll_ms, + ) + self.ref_load_gate = AsyncStageGate( + int(os.environ.get("GPTSOVITS_PREPARE_REF_LOAD_MAX_INFLIGHT", str(ref_audio_gate_default))), + poll_ms=gate_poll_ms, + ) + self.ref_spec_gate = AsyncStageGate( + int(os.environ.get("GPTSOVITS_PREPARE_REF_SPEC_MAX_INFLIGHT", str(ref_audio_gate_default))), + poll_ms=gate_poll_ms, + ) + + def _mark_enter(self) -> Tuple[int, int]: + with self.lock: + self.inflight += 1 + current_inflight = self.inflight + if current_inflight > self.peak_inflight: + self.peak_inflight = current_inflight + return current_inflight, self.peak_inflight + + def _mark_leave(self) -> None: + with self.lock: + self.inflight = max(0, self.inflight - 1) + + def snapshot(self) -> Dict[str, Any]: + with self.lock: + snapshot: Dict[str, Any] = { + "inflight": int(self.inflight), + "peak_inflight": int(self.peak_inflight), + "max_inflight": int(self.max_inflight), + "text_feature_workers": int(self.text_feature_workers), + "g2pw_workers": int(self.g2pw_workers), + "ref_audio_workers": int(self.ref_audio_workers), + } + runtime_snapshot_fn = getattr(self.tts, "snapshot_prepare_runtime_components", None) + if callable(runtime_snapshot_fn): + try: + snapshot["prepare_runtime_state"] = runtime_snapshot_fn() + except Exception: + snapshot["prepare_runtime_state"] = None + snapshot["prepare_stage_gates"] = { + "text_cpu": self.text_cpu_gate.snapshot(), + "g2pw": self.g2pw_gate.snapshot(), + "text_feature": self.text_feature_gate.snapshot(), + "ref_audio": self.ref_audio_gate.snapshot(), + "ref_load": self.ref_load_gate.snapshot(), + "ref_spec": self.ref_spec_gate.snapshot(), + } + return snapshot + + @staticmethod + def _run_profiled(fn, submit_at: float, *args) -> ProfiledResult: + started_at = time.perf_counter() + result = fn(*args) + finished_at = time.perf_counter() + return ProfiledResult( + result=result, + submit_at=float(submit_at), + started_at=float(started_at), + finished_at=float(finished_at), + ) + + def _prepare_text_cpu(self, text: str, language: str): + return self.tts.prepare_text_segments(text, language) + + def _resolve_g2pw_segments(self, prepared_segments): + profile: Dict[str, float] = {} + resolved_segments = self.tts.resolve_g2pw_segments(prepared_segments, profile=profile) + return resolved_segments, profile + + def _load_ref_audio_raw(self, ref_audio_path: str): + return self.tts._load_ref_audio_raw(ref_audio_path) + + def _build_ref_prompt_semantic_from_raw(self, raw_audio, raw_sr: int): + load_profile = {"audio_load_ms": 0.0} + if getattr(self.tts, "prepare_ref_semantic_batch_worker", None) is not None: + prompt_semantic, worker_profile = self.tts.prepare_ref_semantic_batch_worker.submit(raw_audio, raw_sr) + return { + "prompt_semantic": prompt_semantic, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + **load_profile, + "audio_stage_wait_ms": float(worker_profile.get("prompt_semantic_wait_ms", 0.0)), + "audio_stage_slots": float(worker_profile.get("prompt_semantic_stage_slots", 0.0)), + "audio_stage_inflight_peak": float(worker_profile.get("prompt_semantic_stage_inflight_peak", 0.0)), + "prompt_semantic_ms": float( + worker_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) + + worker_profile.get("prompt_semantic_forward_ms", 0.0) + + worker_profile.get("prompt_semantic_scatter_ms", 0.0) + ), + **{key: float(value) for key, value in worker_profile.items()}, + "ref_spec_wait_ms": 0.0, + "ref_spec_ms": 0.0, + "bundle_total_ms": float(worker_profile.get("prompt_semantic_wait_ms", 0.0)) + + float(worker_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)) + + float(worker_profile.get("prompt_semantic_forward_ms", 0.0)) + + float(worker_profile.get("prompt_semantic_scatter_ms", 0.0)), + }, + } + wav16k, cpu_prepare_ms, limiter_stats = self.tts._prepare_prompt_semantic_wav16k_profile(raw_audio, raw_sr) + with self.tts.prepare_ref_audio_stage_limiter.enter() as stage_stats: + prompt_semantic, forward_ms = self.tts._extract_prompt_semantic_profile_from_prepared_wav16k(wav16k) + return { + "prompt_semantic": prompt_semantic, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + "audio_load_ms": 0.0, + "audio_stage_wait_ms": float(stage_stats.get("wait_ms", 0.0)), + "audio_stage_slots": float(stage_stats.get("slots", 0.0)), + "audio_stage_inflight_peak": float(stage_stats.get("peak_inflight", 0.0)), + "prompt_semantic_wait_ms": float(stage_stats.get("wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_wait_ms": float(limiter_stats.get("wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_slots": float(limiter_stats.get("slots", 0.0)), + "prompt_semantic_cpu_prepare_inflight_peak": float(limiter_stats.get("peak_inflight", 0.0)), + "prompt_semantic_worker_queue_wait_ms": 0.0, + "prompt_semantic_batch_collect_wait_ms": 0.0, + "prompt_semantic_stage_limiter_wait_ms": float(stage_stats.get("wait_ms", 0.0)), + "prompt_semantic_batch_dispatch_delay_ms": 0.0, + "prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms), + "prompt_semantic_pack_ms": 0.0, + "prompt_semantic_h2d_ms": 0.0, + "prompt_semantic_ssl_forward_ms": 0.0, + "prompt_semantic_hidden_length_ms": 0.0, + "prompt_semantic_extract_latent_ms": 0.0, + "prompt_semantic_forward_ms": float(forward_ms), + "prompt_semantic_scatter_ms": 0.0, + "prompt_semantic_stage_slots": float(stage_stats.get("slots", 0.0)), + "prompt_semantic_stage_inflight_peak": float(stage_stats.get("peak_inflight", 0.0)), + "prompt_semantic_batch_size": 1.0, + "prompt_semantic_batch_samples": 0.0, + "ref_spec_wait_ms": 0.0, + "ref_spec_ms": 0.0, + "bundle_total_ms": float(cpu_prepare_ms + forward_ms + stage_stats.get("wait_ms", 0.0)), + }, + } + + def _extract_ref_spec_from_raw(self, raw_audio, raw_sr: int): + spec, audio, _, _, profile = self.tts._extract_ref_spec_profile_from_raw(raw_audio, raw_sr) + return (spec, audio), profile + + @staticmethod + def _build_empty_text_features_like(reference: PreparedTextFeatures | None = None) -> PreparedTextFeatures: + feature_dim = 1024 + dtype = None + if reference is not None: + try: + feature_dim = int(reference.bert_features.shape[0]) + dtype = reference.bert_features.dtype + except Exception: + pass + return build_empty_text_features( + feature_dim=int(feature_dim), + dtype=(dtype if dtype is not None else None) or __import__("torch").float32, + ) + + def _build_text_features( + self, + prepared_segments, + language: str, + cpu_run_ms: float, + base_profile: Dict[str, float] | None = None, + ) -> PreparedTextFeatures: + profile: Dict[str, float] = dict(base_profile or {}) + profile["cpu_preprocess_ms"] = float(cpu_run_ms) + branch_start = time.perf_counter() + phones, bert_features, norm_text = self.tts.build_text_features_from_segments(prepared_segments, profile=profile) + total_ms = float(cpu_run_ms + (time.perf_counter() - branch_start) * 1000.0) + profile["bert_total_ms"] = max(0.0, total_ms - float(cpu_run_ms)) + return PreparedTextFeatures( + phones=phones, + bert_features=bert_features, + norm_text=norm_text, + profile=profile, + total_ms=total_ms, + cpu_preprocess_ms=float(cpu_run_ms), + ) + + async def _run_on_executor(self, executor, fn, *args) -> ProfiledResult: + loop = asyncio.get_running_loop() + submit_at = time.perf_counter() + return await loop.run_in_executor(executor, self._run_profiled, fn, float(submit_at), *args) + + async def _run_text_cpu_stage(self, text: str, language: str) -> ProfiledResult: + await self.text_cpu_gate.acquire() + if text in [None, ""]: + try: + submit_at = time.perf_counter() + return ProfiledResult(result=[], submit_at=submit_at, started_at=submit_at, finished_at=submit_at) + finally: + self.text_cpu_gate.release() + text_cpu_worker = getattr(self.tts, "prepare_text_cpu_worker", None) + executor = getattr(self.tts, "prepare_text_cpu_executor", None) + try: + if text_cpu_worker is not None: + submit_at = time.perf_counter() + result, worker_profile = await text_cpu_worker.submit_async(text, language) + started_at = float( + submit_at + + ( + float(worker_profile.get("text_cpu_admission_wait_ms", 0.0)) + + float(worker_profile.get("text_cpu_queue_wait_ms", 0.0)) + ) + / 1000.0 + ) + finished_at = float(started_at + float(worker_profile.get("text_cpu_run_ms", 0.0)) / 1000.0) + return ProfiledResult( + result=result, + submit_at=float(submit_at), + started_at=started_at, + finished_at=finished_at, + profile=dict(worker_profile), + ) + if executor is None: + submit_at = time.perf_counter() + return self._run_profiled(self._prepare_text_cpu, submit_at, text, language) + return await self._run_on_executor(executor, self._prepare_text_cpu, text, language) + finally: + self.text_cpu_gate.release() + + async def _run_text_feature_stage(self, prepared_segments, language: str, cpu_run_ms: float) -> ProfiledResult: + await self.text_feature_gate.acquire() + try: + return await self._run_on_executor( + self.text_feature_executor, + self._build_text_features, + prepared_segments, + language, + cpu_run_ms, + None, + ) + finally: + self.text_feature_gate.release() + + async def _run_g2pw_stage(self, prepared_segments) -> ProfiledResult: + has_pending = any(bool(getattr(segment, "needs_g2pw", False)) for segment in (prepared_segments or [])) + if not has_pending: + submit_at = time.perf_counter() + return ProfiledResult( + result=prepared_segments, + submit_at=float(submit_at), + started_at=float(submit_at), + finished_at=float(submit_at), + profile={}, + ) + await self.g2pw_gate.acquire() + try: + profiled = await self._run_on_executor(self.g2pw_executor, self._resolve_g2pw_segments, prepared_segments) + result, stage_profile = profiled.result + return ProfiledResult( + result=result, + submit_at=float(profiled.submit_at), + started_at=float(profiled.started_at), + finished_at=float(profiled.finished_at), + profile=dict(stage_profile), + ) + finally: + self.g2pw_gate.release() + + async def _run_g2pw_pair_stage(self, prompt_segments, target_segments) -> tuple[ProfiledResult, ProfiledResult]: + prompt_is_empty = len(prompt_segments or []) == 0 + target_task = asyncio.create_task(self._run_g2pw_stage(target_segments)) + if not prompt_is_empty: + prompt_task = asyncio.create_task(self._run_g2pw_stage(prompt_segments)) + return await asyncio.gather(prompt_task, target_task) + target_profiled = await target_task + submit_at = time.perf_counter() + prompt_profiled = ProfiledResult( + result=prompt_segments, + submit_at=float(submit_at), + started_at=float(submit_at), + finished_at=float(submit_at), + profile={}, + ) + return prompt_profiled, target_profiled + + @staticmethod + def _estimate_text_feature_run_ms(profile: Dict[str, float]) -> float: + return float( + profile.get("bert_wait_ms", 0.0) + + profile.get("bert_tokenize_ms", 0.0) + + profile.get("bert_forward_ms", 0.0) + + profile.get("bert_scatter_ms", 0.0) + ) + + async def _run_text_feature_pair_stage( + self, + prompt_segments, + target_segments, + prompt_cpu_run_ms: float, + target_cpu_run_ms: float, + prompt_base_profile: Dict[str, float] | None = None, + target_base_profile: Dict[str, float] | None = None, + ) -> tuple[ProfiledResult, ProfiledResult]: + prompt_is_empty = len(prompt_segments or []) == 0 + if self.text_feature_executor is not None: + target_feature_task = asyncio.create_task( + self._run_on_executor( + self.text_feature_executor, + self._build_text_features, + target_segments, + None, + target_cpu_run_ms, + target_base_profile, + ) + ) + if not prompt_is_empty: + prompt_feature_task = asyncio.create_task( + self._run_on_executor( + self.text_feature_executor, + self._build_text_features, + prompt_segments, + None, + prompt_cpu_run_ms, + prompt_base_profile, + ) + ) + return await asyncio.gather(prompt_feature_task, target_feature_task) + target_profiled = await target_feature_task + submit_at = time.perf_counter() + prompt_profiled = ProfiledResult( + result=self._build_empty_text_features_like(target_profiled.result), + submit_at=float(submit_at), + started_at=float(submit_at), + finished_at=float(submit_at), + ) + return prompt_profiled, target_profiled + + await self.text_feature_gate.acquire() + target_profile: Dict[str, float] = dict(target_base_profile or {}) + target_profile["cpu_preprocess_ms"] = float(target_cpu_run_ms) + submit_at = time.perf_counter() + started_at = float(submit_at) + try: + if prompt_is_empty: + target_result_raw = await self.tts.build_text_features_from_segments_async( + target_segments, + profile=target_profile, + ) + prompt_result = self._build_empty_text_features_like( + PreparedTextFeatures( + phones=target_result_raw[0], + bert_features=target_result_raw[1], + norm_text=target_result_raw[2], + profile=target_profile, + total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)), + cpu_preprocess_ms=float(target_cpu_run_ms), + ) + ) + finished_at = time.perf_counter() + prompt_profiled = ProfiledResult( + result=prompt_result, + submit_at=float(submit_at), + started_at=float(submit_at), + finished_at=float(submit_at), + ) + target_result = PreparedTextFeatures( + phones=target_result_raw[0], + bert_features=target_result_raw[1], + norm_text=target_result_raw[2], + profile=target_profile, + total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)), + cpu_preprocess_ms=float(target_cpu_run_ms), + ) + target_profiled = ProfiledResult( + result=target_result, + submit_at=float(submit_at), + started_at=started_at, + finished_at=float(submit_at + self._estimate_text_feature_run_ms(target_profile) / 1000.0), + ) + if finished_at > target_profiled.finished_at: + target_result.profile["bert_total_ms"] = max( + self._estimate_text_feature_run_ms(target_profile), + (finished_at - submit_at) * 1000.0, + ) + else: + target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile) + return prompt_profiled, target_profiled + + prompt_profile: Dict[str, float] = dict(prompt_base_profile or {}) + prompt_profile["cpu_preprocess_ms"] = float(prompt_cpu_run_ms) + prompt_result_raw, target_result_raw = await self.tts.build_text_feature_pair_from_segments_async( + prompt_segments, + target_segments, + prompt_profile=prompt_profile, + target_profile=target_profile, + ) + finished_at = time.perf_counter() + + prompt_result = PreparedTextFeatures( + phones=prompt_result_raw[0], + bert_features=prompt_result_raw[1], + norm_text=prompt_result_raw[2], + profile=prompt_profile, + total_ms=float(prompt_cpu_run_ms + self._estimate_text_feature_run_ms(prompt_profile)), + cpu_preprocess_ms=float(prompt_cpu_run_ms), + ) + target_result = PreparedTextFeatures( + phones=target_result_raw[0], + bert_features=target_result_raw[1], + norm_text=target_result_raw[2], + profile=target_profile, + total_ms=float(target_cpu_run_ms + self._estimate_text_feature_run_ms(target_profile)), + cpu_preprocess_ms=float(target_cpu_run_ms), + ) + prompt_profiled = ProfiledResult( + result=prompt_result, + submit_at=float(submit_at), + started_at=started_at, + finished_at=float(submit_at + self._estimate_text_feature_run_ms(prompt_profile) / 1000.0), + ) + target_profiled = ProfiledResult( + result=target_result, + submit_at=float(submit_at), + started_at=started_at, + finished_at=float(submit_at + self._estimate_text_feature_run_ms(target_profile) / 1000.0), + ) + if finished_at > prompt_profiled.finished_at: + prompt_result.profile["bert_total_ms"] = max( + self._estimate_text_feature_run_ms(prompt_profile), + (finished_at - submit_at) * 1000.0, + ) + target_result.profile["bert_total_ms"] = max( + self._estimate_text_feature_run_ms(target_profile), + (finished_at - submit_at) * 1000.0, + ) + else: + prompt_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(prompt_profile) + target_result.profile["bert_total_ms"] = self._estimate_text_feature_run_ms(target_profile) + return prompt_profiled, target_profiled + finally: + self.text_feature_gate.release() + + async def _run_ref_prompt_semantic_stage(self, ref_audio_path: str) -> ProfiledResult: + if getattr(self.tts, "prepare_ref_semantic_batch_worker", None) is not None: + submit_at = time.perf_counter() + started_at = float(submit_at) + + await self.ref_load_gate.acquire() + try: + load_profiled = await self._run_on_executor(self.ref_audio_executor, self._load_ref_audio_raw, ref_audio_path) + finally: + self.ref_load_gate.release() + + raw_audio, raw_sr = load_profiled.result + prompt_semantic_task = asyncio.create_task( + self.tts.prepare_ref_semantic_batch_worker.submit_async(raw_audio, raw_sr) + ) + prompt_semantic, prompt_semantic_profile = await prompt_semantic_task + limiter_snapshot = ( + self.tts.prepare_ref_audio_stage_limiter.snapshot() + if getattr(self.tts, "prepare_ref_audio_stage_limiter", None) is not None + else {} + ) + prompt_semantic_ms = ( + float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)) + + float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)) + ) + finished_at = time.perf_counter() + result = { + "prompt_semantic": prompt_semantic, + "raw_audio": raw_audio, + "raw_sr": raw_sr, + "profile": { + "audio_load_queue_ms": float(load_profiled.queue_ms), + "audio_load_ms": float(load_profiled.run_ms), + "audio_stage_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "audio_stage_slots": float( + max( + float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + float(limiter_snapshot.get("slots", 0.0)), + ) + ), + "audio_stage_inflight_peak": float( + max( + float(prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)), + float(limiter_snapshot.get("peak_inflight", 0.0)), + ) + ), + "prompt_semantic_ms": float(prompt_semantic_ms), + "prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_ms": float( + prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0) + ), + "prompt_semantic_pack_ms": float(prompt_semantic_profile.get("prompt_semantic_pack_ms", 0.0)), + "prompt_semantic_h2d_ms": float(prompt_semantic_profile.get("prompt_semantic_h2d_ms", 0.0)), + "prompt_semantic_ssl_forward_ms": float( + prompt_semantic_profile.get("prompt_semantic_ssl_forward_ms", 0.0) + ), + "prompt_semantic_hidden_length_ms": float( + prompt_semantic_profile.get("prompt_semantic_hidden_length_ms", 0.0) + ), + "prompt_semantic_extract_latent_ms": float( + prompt_semantic_profile.get("prompt_semantic_extract_latent_ms", 0.0) + ), + "prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)), + "prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)), + "prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)), + "prompt_semantic_stage_inflight_peak": float( + prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0) + ), + "prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)), + "prompt_semantic_batch_samples": float( + prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0) + ), + "bundle_total_ms": float( + load_profiled.queue_ms + + load_profiled.run_ms + + prompt_semantic_ms + ), + }, + } + return ProfiledResult( + result=result, + submit_at=float(submit_at), + started_at=started_at, + finished_at=float(finished_at), + ) + + await self.ref_audio_gate.acquire() + try: + load_profiled = await self._run_on_executor(self.ref_audio_executor, self._load_ref_audio_raw, ref_audio_path) + raw_audio, raw_sr = load_profiled.result + submit_at = time.perf_counter() + started_at = time.perf_counter() + result = await asyncio.to_thread(self._build_ref_prompt_semantic_from_raw, raw_audio, raw_sr) + result.setdefault("profile", {}) + result["profile"]["audio_load_queue_ms"] = float(load_profiled.queue_ms) + result["profile"]["audio_load_ms"] = float(load_profiled.run_ms) + finished_at = time.perf_counter() + return ProfiledResult(result=result, submit_at=float(submit_at), started_at=float(started_at), finished_at=float(finished_at)) + finally: + self.ref_audio_gate.release() + + async def _run_ref_spec_stage(self, raw_audio, raw_sr: int) -> ProfiledResult: + await self.ref_spec_gate.acquire() + try: + return await self._run_on_executor(self.ref_audio_executor, self._extract_ref_spec_from_raw, raw_audio, raw_sr) + finally: + self.ref_spec_gate.release() + + def _release_split_stage_slot(self) -> None: + self._mark_leave() + self._inflight_gate.release() + + async def prepare_cpu_stage_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> PreparedCpuStage: + admission_start = time.perf_counter() + admission_stats = await self._inflight_gate.acquire() + prepare_admission_wait_ms = max( + float(admission_stats.get("wait_ms", 0.0)), + (time.perf_counter() - admission_start) * 1000.0, + ) + current_inflight, peak_inflight = self._mark_enter() + prepare_start = time.perf_counter() + prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) + text = spec.text.strip("\n") + try: + prompt_cpu_task = asyncio.create_task(self._run_text_cpu_stage(prompt_text, spec.prompt_lang)) + target_cpu_task = asyncio.create_task(self._run_text_cpu_stage(text, spec.text_lang)) + prompt_cpu_profiled, target_cpu_profiled = await asyncio.gather(prompt_cpu_task, target_cpu_task) + return PreparedCpuStage( + spec=spec, + prepare_submit_at=float(prepare_submit_at), + prepare_start=float(prepare_start), + prompt_text=prompt_text, + text=text, + prepare_admission_wait_ms=float(prepare_admission_wait_ms), + current_inflight=int(current_inflight), + peak_inflight=int(peak_inflight), + prompt_cpu_profiled=prompt_cpu_profiled, + target_cpu_profiled=target_cpu_profiled, + ) + except Exception: + self._release_split_stage_slot() + raise + + async def prepare_gpu_stage_profiled_async( + self, + cpu_stage: PreparedCpuStage, + ) -> tuple[T2SRequestState, float, float]: + try: + phase_one = await self._prepare_gpu_phase_one(cpu_stage) + phase_two = await self._prepare_gpu_phase_two(cpu_stage, phase_one) + return self._build_gpu_prepare_result( + cpu_stage, + phase_one, + phase_two, + extra_profile={ + "engine_prepare_audio_phase_mode": 0.0, + "engine_prepare_audio_phase_wall_ms": float(phase_one["phase_wall_ms"]), + "engine_prepare_audio_phase_batch_size": 1.0, + "engine_prepare_text_phase_wall_ms": float(phase_two["phase_wall_ms"]), + "engine_prepare_text_phase_batch_size": 1.0, + }, + ) + finally: + self._release_split_stage_slot() + + async def _prepare_gpu_phase_one(self, cpu_stage: PreparedCpuStage) -> Dict[str, Any]: + phase_start = time.perf_counter() + g2pw_pair_task = asyncio.create_task( + self._run_g2pw_pair_stage( + cpu_stage.prompt_cpu_profiled.result, + cpu_stage.target_cpu_profiled.result, + ) + ) + ref_audio_task = asyncio.create_task(self._run_ref_prompt_semantic_stage(str(cpu_stage.spec.ref_audio_path))) + prompt_g2pw_profiled, target_g2pw_profiled = await g2pw_pair_task + g2pw_pair_end = time.perf_counter() + ref_audio_profiled = await ref_audio_task + phase_end = time.perf_counter() + return { + "prompt_g2pw_profiled": prompt_g2pw_profiled, + "target_g2pw_profiled": target_g2pw_profiled, + "ref_audio_profiled": ref_audio_profiled, + "ref_spec_result": None, + "g2pw_pair_ms": max(0.0, (g2pw_pair_end - phase_start) * 1000.0), + "phase_wall_ms": max(0.0, (phase_end - phase_start) * 1000.0), + } + + async def _prepare_gpu_phase_two( + self, + cpu_stage: PreparedCpuStage, + phase_one: Dict[str, Any], + ) -> Dict[str, Any]: + phase_start = time.perf_counter() + prompt_g2pw_profiled = phase_one["prompt_g2pw_profiled"] + target_g2pw_profiled = phase_one["target_g2pw_profiled"] + prompt_feature_profiled, target_feature_profiled = await self._run_text_feature_pair_stage( + prompt_g2pw_profiled.result, + target_g2pw_profiled.result, + cpu_stage.prompt_cpu_profiled.run_ms, + cpu_stage.target_cpu_profiled.run_ms, + prompt_base_profile=dict(prompt_g2pw_profiled.profile or {}), + target_base_profile=dict(target_g2pw_profiled.profile or {}), + ) + phase_end = time.perf_counter() + return { + "prompt_feature_profiled": prompt_feature_profiled, + "target_feature_profiled": target_feature_profiled, + "phase_wall_ms": max(0.0, (phase_end - phase_start) * 1000.0), + } + + def _build_gpu_prepare_result( + self, + cpu_stage: PreparedCpuStage, + phase_one: Dict[str, Any], + phase_two: Dict[str, Any], + extra_profile: Dict[str, float] | None = None, + ) -> tuple[T2SRequestState, float, float]: + prompt_g2pw_profiled = phase_one["prompt_g2pw_profiled"] + target_g2pw_profiled = phase_one["target_g2pw_profiled"] + ref_audio_profiled = phase_one["ref_audio_profiled"] + ref_spec_result = phase_one.get("ref_spec_result") + prompt_feature_profiled = phase_two["prompt_feature_profiled"] + target_feature_profiled = phase_two["target_feature_profiled"] + profile_overrides = { + "executor_queue_ms": max(0.0, (cpu_stage.prepare_start - cpu_stage.prepare_submit_at) * 1000.0), + "prepare_admission_wait_ms": cpu_stage.prepare_admission_wait_ms, + "prepare_submit_ts": float(cpu_stage.prepare_submit_at), + "prepare_cpu_start_ts": float(cpu_stage.prepare_start), + "prepare_cpu_done_ts": float( + max(cpu_stage.prompt_cpu_profiled.finished_at, cpu_stage.target_cpu_profiled.finished_at) + ), + "prompt_text_cpu_start_ts": float(cpu_stage.prompt_cpu_profiled.started_at), + "prompt_text_cpu_end_ts": float(cpu_stage.prompt_cpu_profiled.finished_at), + "text_cpu_start_ts": float(cpu_stage.target_cpu_profiled.started_at), + "text_cpu_end_ts": float(cpu_stage.target_cpu_profiled.finished_at), + "executor_run_wall_ms": max(0.0, (time.perf_counter() - cpu_stage.prepare_start) * 1000.0), + "text_feature_pair_ms": float(phase_two["phase_wall_ms"]), + "g2pw_pair_ms": float(phase_one["g2pw_pair_ms"]), + "prompt_text_g2pw_queue_ms": prompt_g2pw_profiled.queue_ms, + "prompt_text_g2pw_run_ms": prompt_g2pw_profiled.run_ms, + "prompt_text_g2pw_prepare_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)), + "prompt_text_g2pw_predict_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)), + "prompt_text_g2pw_post_ms": float((prompt_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)), + "text_g2pw_queue_ms": target_g2pw_profiled.queue_ms, + "text_g2pw_run_ms": target_g2pw_profiled.run_ms, + "text_g2pw_prepare_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_prepare_ms", 0.0)), + "text_g2pw_predict_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_predict_ms", 0.0)), + "text_g2pw_post_ms": float((target_g2pw_profiled.profile or {}).get("g2pw_post_ms", 0.0)), + "prompt_text_parallel_future_wait_ms": 0.0, + "prompt_text_parallel_future_executor_queue_ms": 0.0, + "prompt_text_parallel_future_run_ms": 0.0, + "prompt_text_parallel_future_finish_after_submit_ms": 0.0, + "prompt_text_parallel_future_queue_tail_after_target_ms": 0.0, + "prompt_text_parallel_future_run_tail_after_target_ms": 0.0, + "prompt_text_cpu_queue_ms": cpu_stage.prompt_cpu_profiled.queue_ms, + "prompt_text_cpu_run_ms": cpu_stage.prompt_cpu_profiled.run_ms, + "prompt_text_cpu_admission_wait_ms": float( + (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0) + ), + "prompt_text_cpu_backpressure_wait_ms": float( + (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0) + ), + "prompt_text_cpu_capacity_wait_ms": float( + (cpu_stage.prompt_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0) + ), + "prompt_text_feature_queue_ms": prompt_feature_profiled.queue_ms, + "prompt_text_feature_run_ms": prompt_feature_profiled.run_ms, + "text_cpu_queue_ms": cpu_stage.target_cpu_profiled.queue_ms, + "text_cpu_run_ms": cpu_stage.target_cpu_profiled.run_ms, + "text_cpu_admission_wait_ms": float( + (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_admission_wait_ms", 0.0) + ), + "text_cpu_backpressure_wait_ms": float( + (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_backpressure_wait_ms", 0.0) + ), + "text_cpu_capacity_wait_ms": float( + (cpu_stage.target_cpu_profiled.profile or {}).get("text_cpu_capacity_wait_ms", 0.0) + ), + "text_feature_queue_ms": target_feature_profiled.queue_ms, + "text_feature_run_ms": target_feature_profiled.run_ms, + "ref_audio_task_queue_ms": ref_audio_profiled.queue_ms, + "ref_audio_task_run_ms": ref_audio_profiled.run_ms, + "worker_prepare_inflight_on_enter": float(cpu_stage.current_inflight), + "worker_prepare_peak_inflight": float(cpu_stage.peak_inflight), + } + if extra_profile: + profile_overrides.update({key: float(value) for key, value in extra_profile.items()}) + ref_audio_bundle = dict(ref_audio_profiled.result) + ref_audio_profile = dict(ref_audio_bundle.get("profile", {})) + if ref_spec_result is not None: + refer_spec, ref_spec_profiled = ref_spec_result + ref_audio_bundle["refer_spec"] = refer_spec + ref_audio_profile.update( + { + "ref_spec_wait_ms": float(ref_spec_profiled.get("ref_spec_wait_ms", 0.0)), + "ref_spec_ms": float(ref_spec_profiled.get("ref_spec_ms", 0.0)), + "ref_spec_to_device_ms": float(ref_spec_profiled.get("ref_spec_to_device_ms", 0.0)), + "ref_spec_main_resample_ms": float(ref_spec_profiled.get("ref_spec_main_resample_ms", 0.0)), + "ref_spec_norm_ms": float(ref_spec_profiled.get("ref_spec_norm_ms", 0.0)), + "ref_spec_spectrogram_ms": float(ref_spec_profiled.get("ref_spec_spectrogram_ms", 0.0)), + "ref_spec_post_resample_ms": float(ref_spec_profiled.get("ref_spec_post_resample_ms", 0.0)), + } + ) + else: + ref_audio_bundle["refer_spec"] = None + ref_audio_profile.setdefault("ref_spec_wait_ms", 0.0) + ref_audio_profile.setdefault("ref_spec_ms", 0.0) + ref_audio_profile.setdefault("ref_spec_to_device_ms", 0.0) + ref_audio_profile.setdefault("ref_spec_main_resample_ms", 0.0) + ref_audio_profile.setdefault("ref_spec_norm_ms", 0.0) + ref_audio_profile.setdefault("ref_spec_spectrogram_ms", 0.0) + ref_audio_profile.setdefault("ref_spec_post_resample_ms", 0.0) + ref_audio_bundle["profile"] = ref_audio_profile + state = build_request_state_from_parts( + tts=self.tts, + spec=cpu_stage.spec, + prompt_text=cpu_stage.prompt_text, + text=cpu_stage.text, + prompt_result=prompt_feature_profiled.result, + target_result=target_feature_profiled.result, + ref_audio_bundle=ref_audio_bundle, + prepare_start=cpu_stage.prepare_start, + prepare_sync_start=cpu_stage.prepare_start, + profile_overrides=profile_overrides, + ) + prepare_exec_finished_at = time.perf_counter() + state.prepare_profile["executor_run_wall_ms"] = max(0.0, (prepare_exec_finished_at - cpu_stage.prepare_start) * 1000.0) + return state, cpu_stage.prepare_start, prepare_exec_finished_at + + async def prepare_ref_spec_stages_async( + self, + phase_ones: list[Dict[str, Any]], + ) -> list[tuple[tuple[Any, Any], Dict[str, float]] | Exception]: + async def _one(phase_one: Dict[str, Any]): + ref_audio_profiled = phase_one["ref_audio_profiled"] + raw_audio = ref_audio_profiled.result["raw_audio"] + raw_sr = int(ref_audio_profiled.result["raw_sr"]) + profiled = await self._run_ref_spec_stage(raw_audio, raw_sr) + refer_spec, profile = profiled.result + merged_profile = dict(profile) + merged_profile["ref_spec_wait_ms"] = float(profiled.queue_ms) + merged_profile["ref_spec_ms"] = float(profiled.run_ms) + return refer_spec, merged_profile + + if not phase_ones: + return [] + return list(await asyncio.gather(*[_one(phase_one) for phase_one in phase_ones], return_exceptions=True)) + + def apply_ref_spec_result_to_state( + self, + state: T2SRequestState, + ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]], + ) -> None: + refer_spec, profile = ref_spec_result + state.refer_spec = refer_spec + state.prepare_profile["ref_spec_wait_ms"] = float(profile.get("ref_spec_wait_ms", 0.0)) + state.prepare_profile["ref_spec_ms"] = float(profile.get("ref_spec_ms", 0.0)) + state.prepare_profile["ref_spec_to_device_ms"] = float(profile.get("ref_spec_to_device_ms", 0.0)) + state.prepare_profile["ref_spec_main_resample_ms"] = float(profile.get("ref_spec_main_resample_ms", 0.0)) + state.prepare_profile["ref_spec_norm_ms"] = float(profile.get("ref_spec_norm_ms", 0.0)) + state.prepare_profile["ref_spec_spectrogram_ms"] = float(profile.get("ref_spec_spectrogram_ms", 0.0)) + state.prepare_profile["ref_spec_post_resample_ms"] = float(profile.get("ref_spec_post_resample_ms", 0.0)) + + async def prepare_gpu_stages_profiled_async( + self, + cpu_stages: list[PreparedCpuStage], + ) -> list[tuple[T2SRequestState, float, float] | Exception]: + if not cpu_stages: + return [] + if len(cpu_stages) == 1: + single_stage = cpu_stages[0] + try: + return [await self.prepare_gpu_stage_profiled_async(single_stage)] + except Exception as exc: # noqa: PERF203 + return [exc] + + phase_one_started_at = time.perf_counter() + phase_one_results = await asyncio.gather( + *[self._prepare_gpu_phase_one(cpu_stage) for cpu_stage in cpu_stages], + return_exceptions=True, + ) + phase_one_finished_at = time.perf_counter() + phase_one_wall_ms = max(0.0, (phase_one_finished_at - phase_one_started_at) * 1000.0) + + outputs: list[tuple[T2SRequestState, float, float] | Exception | None] = [None] * len(cpu_stages) + pending_phase_two: list[tuple[int, PreparedCpuStage, Dict[str, Any]]] = [] + for index, (cpu_stage, phase_one) in enumerate(zip(cpu_stages, phase_one_results)): + if isinstance(phase_one, Exception): + outputs[index] = phase_one + self._release_split_stage_slot() + continue + pending_phase_two.append((index, cpu_stage, phase_one)) + + phase_two_started_at = time.perf_counter() + phase_two_results = await asyncio.gather( + *[self._prepare_gpu_phase_two(cpu_stage, phase_one) for _, cpu_stage, phase_one in pending_phase_two], + return_exceptions=True, + ) + phase_two_finished_at = time.perf_counter() + phase_two_wall_ms = max(0.0, (phase_two_finished_at - phase_two_started_at) * 1000.0) + + for (index, cpu_stage, phase_one), phase_two in zip(pending_phase_two, phase_two_results): + try: + if isinstance(phase_two, Exception): + outputs[index] = phase_two + continue + outputs[index] = self._build_gpu_prepare_result( + cpu_stage, + phase_one, + phase_two, + extra_profile={ + "engine_prepare_audio_phase_mode": 1.0, + "engine_prepare_audio_phase_wall_ms": float(phase_one_wall_ms), + "engine_prepare_audio_phase_batch_size": float(len(cpu_stages)), + "engine_prepare_text_phase_wall_ms": float(phase_two_wall_ms), + "engine_prepare_text_phase_batch_size": float(len(pending_phase_two)), + }, + ) + except Exception as exc: # noqa: PERF203 + outputs[index] = exc + finally: + self._release_split_stage_slot() + + return [item if item is not None else RuntimeError("prepare batch result missing") for item in outputs] + + async def prepare_gpu_audio_phases_async( + self, + cpu_stages: list[PreparedCpuStage], + ) -> list[Dict[str, Any] | Exception]: + if not cpu_stages: + return [] + return list( + await asyncio.gather( + *[self._prepare_gpu_phase_one(cpu_stage) for cpu_stage in cpu_stages], + return_exceptions=True, + ) + ) + + async def prepare_gpu_text_phases_async( + self, + items: list[tuple[PreparedCpuStage, Dict[str, Any]]], + ) -> list[Dict[str, Any] | Exception]: + if not items: + return [] + return list( + await asyncio.gather( + *[self._prepare_gpu_phase_two(cpu_stage, phase_one) for cpu_stage, phase_one in items], + return_exceptions=True, + ) + ) + + def build_gpu_prepare_result_from_phases( + self, + cpu_stage: PreparedCpuStage, + phase_one: Dict[str, Any], + phase_two: Dict[str, Any], + extra_profile: Dict[str, float] | None = None, + ) -> tuple[T2SRequestState, float, float]: + try: + return self._build_gpu_prepare_result(cpu_stage, phase_one, phase_two, extra_profile=extra_profile) + finally: + self._release_split_stage_slot() + + async def prepare_state_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> tuple[T2SRequestState, float, float]: + cpu_stage = await self.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + return await self.prepare_gpu_stage_profiled_async(cpu_stage) diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py new file mode 100644 index 00000000..4628a2a2 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py @@ -0,0 +1,382 @@ +import asyncio +import os +import threading +import time +import uuid +from collections import deque +from dataclasses import dataclass, field +from typing import Deque, Dict, List, Tuple + +import torch +import torchaudio + + +REF_AUDIO_MIN_SAMPLES_16K = 48000 +REF_AUDIO_MAX_SAMPLES_16K = 160000 +_RESAMPLE_CACHE_LOCK = threading.Lock() +_RESAMPLE_CACHE: Dict[Tuple[int, int, str], torchaudio.transforms.Resample] = {} +_RESAMPLE_STREAM_CACHE: Dict[str, torch.cuda.Stream] = {} + + +def _get_resampler(orig_sr: int, target_sr: int, device: str) -> torchaudio.transforms.Resample: + device_key = str(device) + key = (int(orig_sr), int(target_sr), device_key) + with _RESAMPLE_CACHE_LOCK: + transform = _RESAMPLE_CACHE.get(key) + if transform is None: + transform = torchaudio.transforms.Resample(orig_freq=int(orig_sr), new_freq=int(target_sr)).to(device_key) + _RESAMPLE_CACHE[key] = transform + return transform + + +def _get_resample_stream(device: str) -> torch.cuda.Stream: + device_key = str(device) + with _RESAMPLE_CACHE_LOCK: + stream = _RESAMPLE_STREAM_CACHE.get(device_key) + if stream is None: + stream = torch.cuda.Stream(device=device_key) + _RESAMPLE_STREAM_CACHE[device_key] = stream + return stream + + +def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wav_samples: int) -> torch.Tensor: + resample_device = os.environ.get("GPTSOVITS_PREPARE_REF_RESAMPLE_DEVICE", "cpu").strip().lower() or "cpu" + if resample_device not in {"cpu", "cuda"}: + resample_device = "cpu" + if resample_device == "cuda" and not torch.cuda.is_available(): + resample_device = "cpu" + wav_mono = raw_audio + if wav_mono.dim() == 2 and wav_mono.shape[0] != 1: + wav_mono = wav_mono.mean(0, keepdim=True) + if resample_device == "cuda": + stream = _get_resample_stream(resample_device) + with torch.cuda.stream(stream): + wav16k = wav_mono.to(dtype=torch.float32, device=resample_device) + if raw_sr != 16000: + wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k) + wav16k = wav16k.squeeze(0).contiguous() + stream.synchronize() + wav16k = wav16k.detach().to(device="cpu", dtype=torch.float32).contiguous() + else: + wav16k = wav_mono.to(dtype=torch.float32, device=resample_device) + if raw_sr != 16000: + wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k) + wav16k = wav16k.squeeze(0).contiguous() + if wav16k.shape[0] > REF_AUDIO_MAX_SAMPLES_16K or wav16k.shape[0] < REF_AUDIO_MIN_SAMPLES_16K: + raise OSError("参考音频在3~10秒范围外,请更换!") + if zero_wav_samples > 0: + wav16k = torch.cat( + [wav16k, torch.zeros(int(zero_wav_samples), dtype=torch.float32, device=wav16k.device)], + dim=0, + ) + return wav16k.contiguous() + + +def conv1d_output_lengths(input_lengths: torch.Tensor, conv1d: torch.nn.Conv1d | None) -> torch.Tensor: + if conv1d is None: + return input_lengths.to(dtype=torch.long) + kernel_size = int(conv1d.kernel_size[0]) + stride = int(conv1d.stride[0]) + padding = int(conv1d.padding[0]) + dilation = int(conv1d.dilation[0]) + output_lengths = torch.div( + input_lengths + 2 * padding - dilation * (kernel_size - 1) - 1, + stride, + rounding_mode="floor", + ) + 1 + return torch.clamp(output_lengths, min=0).to(dtype=torch.long) + + +@dataclass +class RefSemanticTask: + raw_audio: torch.Tensor + raw_sr: int + task_id: str = field(default_factory=lambda: uuid.uuid4().hex) + created_at: float = field(default_factory=time.perf_counter) + batch_popped_at: float = 0.0 + done_event: threading.Event = field(default_factory=threading.Event) + done_loop: asyncio.AbstractEventLoop | None = None + done_future: asyncio.Future | None = None + result_prompt_semantic: torch.Tensor | None = None + error: Exception | None = None + profile: Dict[str, float] = field(default_factory=dict) + + +class PrepareRefSemanticBatchWorker: + def __init__( + self, + ssl_model, + vits_model, + device, + is_half: bool, + zero_wav_samples: int, + stage_limiter=None, + batch_window_ms: int = 5, + max_batch_items: int = 8, + max_batch_samples: int = 960000, + ): + self.ssl_model = ssl_model + self.vits_model = vits_model + self.device = device + self.is_half = bool(is_half) + self.zero_wav_samples = max(0, int(zero_wav_samples)) + self.stage_limiter = stage_limiter + self.batch_window_s = max(0.0, float(batch_window_ms) / 1000.0) + self.max_batch_items = max(1, int(max_batch_items)) + self.max_batch_samples = max(REF_AUDIO_MIN_SAMPLES_16K + self.zero_wav_samples, int(max_batch_samples)) + + self.condition = threading.Condition() + self.pending_tasks: Deque[RefSemanticTask] = deque() + self.pending_peak = 0 + self.total_submitted = 0 + self.total_finished = 0 + self.total_batches = 0 + self.active_batch_size = 0 + self.active_batch_peak = 0 + self.active_batch_samples = 0 + self.active_batch_samples_peak = 0 + self.worker_thread = threading.Thread( + target=self._run_loop, + name="prepare-ref-semantic-batch-worker", + daemon=True, + ) + self.worker_thread.start() + + def _estimate_task_samples(self, task: RefSemanticTask) -> int: + raw_len = int(task.raw_audio.shape[-1]) if task.raw_audio.dim() > 0 else 0 + base = int(round(raw_len * 16000.0 / max(1, int(task.raw_sr)))) + return max(REF_AUDIO_MIN_SAMPLES_16K, base) + self.zero_wav_samples + + def submit(self, raw_audio: torch.Tensor, raw_sr: int) -> Tuple[torch.Tensor, Dict[str, float]]: + task = RefSemanticTask(raw_audio=raw_audio, raw_sr=int(raw_sr)) + with self.condition: + self.pending_tasks.append(task) + self.total_submitted += 1 + if len(self.pending_tasks) > self.pending_peak: + self.pending_peak = len(self.pending_tasks) + self.condition.notify_all() + task.done_event.wait() + if task.error is not None: + raise task.error + assert task.result_prompt_semantic is not None + return task.result_prompt_semantic, dict(task.profile) + + async def submit_async(self, raw_audio: torch.Tensor, raw_sr: int) -> Tuple[torch.Tensor, Dict[str, float]]: + loop = asyncio.get_running_loop() + task = RefSemanticTask( + raw_audio=raw_audio, + raw_sr=int(raw_sr), + done_loop=loop, + done_future=loop.create_future(), + ) + with self.condition: + self.pending_tasks.append(task) + self.total_submitted += 1 + if len(self.pending_tasks) > self.pending_peak: + self.pending_peak = len(self.pending_tasks) + self.condition.notify_all() + return await task.done_future + + @staticmethod + def _resolve_done_future(task: RefSemanticTask) -> None: + if task.done_future is None or task.done_future.done(): + return + if task.error is not None: + task.done_future.set_exception(task.error) + return + assert task.result_prompt_semantic is not None + task.done_future.set_result((task.result_prompt_semantic, dict(task.profile))) + + def _notify_task_done(self, task: RefSemanticTask) -> None: + task.done_event.set() + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_done_future, task) + except RuntimeError: + pass + + def snapshot(self) -> Dict[str, int]: + with self.condition: + return { + "pending": len(self.pending_tasks), + "pending_peak": self.pending_peak, + "total_submitted": self.total_submitted, + "total_finished": self.total_finished, + "total_batches": self.total_batches, + "active_batch_size": self.active_batch_size, + "active_batch_peak": self.active_batch_peak, + "active_batch_samples": self.active_batch_samples, + "active_batch_samples_peak": self.active_batch_samples_peak, + "batch_window_ms": int(self.batch_window_s * 1000.0), + "max_batch_items": self.max_batch_items, + "max_batch_samples": self.max_batch_samples, + } + + def _collect_batch(self) -> tuple[List[RefSemanticTask], float]: + with self.condition: + while not self.pending_tasks: + self.condition.wait() + + first_task = self.pending_tasks.popleft() + first_task.batch_popped_at = time.perf_counter() + batch: List[RefSemanticTask] = [first_task] + batch_samples = self._estimate_task_samples(batch[0]) + deadline = time.perf_counter() + self.batch_window_s + + while len(batch) < self.max_batch_items: + remaining = deadline - time.perf_counter() + if remaining <= 0: + break + if not self.pending_tasks: + self.condition.wait(timeout=remaining) + continue + next_task = self.pending_tasks[0] + next_samples = self._estimate_task_samples(next_task) + if len(batch) >= self.max_batch_items or (batch_samples + next_samples) > self.max_batch_samples: + break + popped_task = self.pending_tasks.popleft() + popped_task.batch_popped_at = time.perf_counter() + batch.append(popped_task) + batch_samples += next_samples + + self.active_batch_size = len(batch) + self.active_batch_samples = batch_samples + if self.active_batch_size > self.active_batch_peak: + self.active_batch_peak = self.active_batch_size + if self.active_batch_samples > self.active_batch_samples_peak: + self.active_batch_samples_peak = self.active_batch_samples + return batch, time.perf_counter() + + def _finalize_batch(self, batch: List[RefSemanticTask]) -> None: + with self.condition: + self.active_batch_size = 0 + self.active_batch_samples = 0 + self.total_batches += 1 + self.total_finished += len(batch) + + def _get_hidden_lengths(self, attention_mask: torch.Tensor, hidden_length: int) -> torch.Tensor: + model = self.ssl_model.model + if hasattr(model, "_get_feature_vector_attention_mask"): + feature_mask = model._get_feature_vector_attention_mask(hidden_length, attention_mask) + return feature_mask.to(dtype=torch.long).sum(dim=1) + raw_lengths = attention_mask.to(dtype=torch.long).sum(dim=1) + if hasattr(model, "_get_feat_extract_output_lengths"): + return model._get_feat_extract_output_lengths(raw_lengths).to(dtype=torch.long) + return torch.full((attention_mask.shape[0],), int(hidden_length), dtype=torch.long, device=attention_mask.device) + + @torch.inference_mode() + def _run_batch(self, batch: List[RefSemanticTask], batch_collected_at: float) -> None: + batch_started = time.perf_counter() + prepared_start = time.perf_counter() + prepared_wavs = [ + prepare_prompt_semantic_wav16k(task.raw_audio, int(task.raw_sr), self.zero_wav_samples) for task in batch + ] + cpu_prepare_ms = (time.perf_counter() - prepared_start) * 1000.0 + wav_lengths = torch.tensor([int(wav.shape[0]) for wav in prepared_wavs], dtype=torch.long) + batch_samples = int(wav_lengths.sum().item()) + max_wav_len = int(wav_lengths.max().item()) + + pack_start = time.perf_counter() + input_values_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.float32) + attention_mask_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.long) + for batch_index, wav in enumerate(prepared_wavs): + wav_len = int(wav.shape[0]) + input_values_cpu[batch_index, :wav_len] = wav + attention_mask_cpu[batch_index, :wav_len] = 1 + pack_ms = (time.perf_counter() - pack_start) * 1000.0 + + limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0} + h2d_ms = 0.0 + ssl_forward_ms = 0.0 + hidden_length_ms = 0.0 + extract_latent_ms = 0.0 + if self.stage_limiter is None: + h2d_start = time.perf_counter() + input_values = input_values_cpu.to(self.device) + attention_mask = attention_mask_cpu.to(self.device) + if self.is_half: + input_values = input_values.half() + h2d_ms = (time.perf_counter() - h2d_start) * 1000.0 + ssl_start = time.perf_counter() + outputs = self.ssl_model.model(input_values, attention_mask=attention_mask) + ssl_forward_ms = (time.perf_counter() - ssl_start) * 1000.0 + hubert_feature = outputs["last_hidden_state"].transpose(1, 2) + hidden_length_start = time.perf_counter() + hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1])) + hidden_length_ms = (time.perf_counter() - hidden_length_start) * 1000.0 + latent_start = time.perf_counter() + codes = self.vits_model.extract_latent(hubert_feature) + extract_latent_ms = (time.perf_counter() - latent_start) * 1000.0 + else: + with self.stage_limiter.enter() as limiter_stats: + h2d_start = time.perf_counter() + input_values = input_values_cpu.to(self.device) + attention_mask = attention_mask_cpu.to(self.device) + if self.is_half: + input_values = input_values.half() + h2d_ms = (time.perf_counter() - h2d_start) * 1000.0 + ssl_start = time.perf_counter() + outputs = self.ssl_model.model(input_values, attention_mask=attention_mask) + ssl_forward_ms = (time.perf_counter() - ssl_start) * 1000.0 + hubert_feature = outputs["last_hidden_state"].transpose(1, 2) + hidden_length_start = time.perf_counter() + hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1])) + hidden_length_ms = (time.perf_counter() - hidden_length_start) * 1000.0 + latent_start = time.perf_counter() + codes = self.vits_model.extract_latent(hubert_feature) + extract_latent_ms = (time.perf_counter() - latent_start) * 1000.0 + forward_ms = float(h2d_ms + ssl_forward_ms + hidden_length_ms + extract_latent_ms) + + code_lengths = conv1d_output_lengths(hidden_lengths.detach().cpu(), getattr(self.vits_model, "ssl_proj", None)) + scatter_start = time.perf_counter() + for batch_index, task in enumerate(batch): + try: + code_len = int(code_lengths[batch_index].item()) + task.result_prompt_semantic = codes[batch_index, 0, :code_len].detach().clone() + worker_queue_wait_ms = max(0.0, (float(task.batch_popped_at) - float(task.created_at)) * 1000.0) + batch_collect_wait_ms = max(0.0, (float(batch_collected_at) - float(task.batch_popped_at)) * 1000.0) + stage_limiter_wait_ms = float(limiter_stats["wait_ms"]) + task.profile = { + "prompt_semantic_wait_ms": worker_queue_wait_ms + + batch_collect_wait_ms + + stage_limiter_wait_ms, + "prompt_semantic_worker_queue_wait_ms": worker_queue_wait_ms, + "prompt_semantic_batch_collect_wait_ms": batch_collect_wait_ms, + "prompt_semantic_stage_limiter_wait_ms": stage_limiter_wait_ms, + "prompt_semantic_batch_dispatch_delay_ms": max( + 0.0, (float(batch_started) - float(batch_collected_at)) * 1000.0 + ), + "prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms), + "prompt_semantic_pack_ms": float(pack_ms), + "prompt_semantic_h2d_ms": float(h2d_ms), + "prompt_semantic_ssl_forward_ms": float(ssl_forward_ms), + "prompt_semantic_hidden_length_ms": float(hidden_length_ms), + "prompt_semantic_extract_latent_ms": float(extract_latent_ms), + "prompt_semantic_forward_ms": float(forward_ms), + "prompt_semantic_scatter_ms": 0.0, + "prompt_semantic_calls": 1.0, + "prompt_semantic_stage_slots": float(limiter_stats["slots"]), + "prompt_semantic_stage_inflight_peak": float(limiter_stats["peak_inflight"]), + "prompt_semantic_batch_size": float(len(batch)), + "prompt_semantic_batch_samples": float(batch_samples), + } + except Exception as exc: # noqa: PERF203 + task.error = exc + scatter_ms = (time.perf_counter() - scatter_start) * 1000.0 + for task in batch: + if task.result_prompt_semantic is not None: + task.profile["prompt_semantic_scatter_ms"] = float(scatter_ms) + self._notify_task_done(task) + + def _run_loop(self) -> None: + while True: + batch, batch_collected_at = self._collect_batch() + try: + self._run_batch(batch, batch_collected_at) + except Exception as exc: # noqa: PERF203 + for task in batch: + task.error = exc + self._notify_task_done(task) + finally: + self._finalize_batch(batch) diff --git a/GPT_SoVITS/TTS_infer_pack/prepare_text_cpu_worker.py b/GPT_SoVITS/TTS_infer_pack/prepare_text_cpu_worker.py new file mode 100644 index 00000000..b7c985f0 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/prepare_text_cpu_worker.py @@ -0,0 +1,215 @@ +import asyncio +import threading +import time +import uuid +from collections import deque +from dataclasses import dataclass, field +from typing import Any, Callable, Deque, Dict, Tuple + + +@dataclass +class TextCpuTask: + text: str + language: str + task_id: str = field(default_factory=lambda: uuid.uuid4().hex) + created_at: float = field(default_factory=time.perf_counter) + enqueued_at: float = 0.0 + admission_wait_ms: float = 0.0 + backpressure_wait_ms: float = 0.0 + capacity_wait_ms: float = 0.0 + pending_depth_on_enqueue: int = 0 + done_event: threading.Event = field(default_factory=threading.Event) + done_loop: asyncio.AbstractEventLoop | None = None + done_future: asyncio.Future | None = None + result: Any = None + error: Exception | None = None + profile: Dict[str, float] = field(default_factory=dict) + + +class PrepareTextCpuWorker: + def __init__( + self, + process_fn: Callable[[str, str], Any], + worker_count: int, + max_pending_tasks: int = 0, + admission_poll_ms: int = 1, + admission_controller: Callable[[], Dict[str, float | int | bool]] | None = None, + ) -> None: + self.process_fn = process_fn + self.worker_count = max(1, int(worker_count)) + self.max_pending_tasks = max(0, int(max_pending_tasks)) + self.admission_poll_s = max(0.0005, float(max(1, int(admission_poll_ms))) / 1000.0) + self.admission_controller = admission_controller + + self.condition = threading.Condition() + self.pending_tasks: Deque[TextCpuTask] = deque() + self.pending_peak = 0 + self.total_submitted = 0 + self.total_finished = 0 + self.active_workers = 0 + self.active_workers_peak = 0 + self.admission_wait_total_ms = 0.0 + self.admission_wait_peak_ms = 0.0 + self.backpressure_wait_total_ms = 0.0 + self.backpressure_wait_peak_ms = 0.0 + self.capacity_wait_total_ms = 0.0 + self.capacity_wait_peak_ms = 0.0 + self.backpressure_blocked_total = 0 + + self.worker_threads = [ + threading.Thread(target=self._run_loop, name=f"prepare-text-cpu-worker-{index}", daemon=True) + for index in range(self.worker_count) + ] + for thread in self.worker_threads: + thread.start() + + def _can_enqueue_locked(self) -> bool: + if self.max_pending_tasks <= 0: + return True + return (len(self.pending_tasks) + self.active_workers) < self.max_pending_tasks + + def _get_admission_state(self) -> Dict[str, float | int | bool]: + if self.admission_controller is None: + return {"blocked": False} + try: + state = dict(self.admission_controller() or {}) + except Exception: + return {"blocked": False} + state["blocked"] = bool(state.get("blocked", False)) + return state + + def _record_enqueue_locked( + self, + task: TextCpuTask, + *, + admission_wait_ms: float, + backpressure_wait_ms: float, + capacity_wait_ms: float, + ) -> None: + task.admission_wait_ms = float(max(0.0, admission_wait_ms)) + task.backpressure_wait_ms = float(max(0.0, backpressure_wait_ms)) + task.capacity_wait_ms = float(max(0.0, capacity_wait_ms)) + task.enqueued_at = time.perf_counter() + task.pending_depth_on_enqueue = int(len(self.pending_tasks)) + self.pending_tasks.append(task) + self.total_submitted += 1 + self.admission_wait_total_ms += task.admission_wait_ms + self.admission_wait_peak_ms = max(self.admission_wait_peak_ms, task.admission_wait_ms) + self.backpressure_wait_total_ms += task.backpressure_wait_ms + self.backpressure_wait_peak_ms = max(self.backpressure_wait_peak_ms, task.backpressure_wait_ms) + self.capacity_wait_total_ms += task.capacity_wait_ms + self.capacity_wait_peak_ms = max(self.capacity_wait_peak_ms, task.capacity_wait_ms) + if task.backpressure_wait_ms > 0.0: + self.backpressure_blocked_total += 1 + if len(self.pending_tasks) > self.pending_peak: + self.pending_peak = len(self.pending_tasks) + self.condition.notify_all() + + async def _enqueue_task_async(self, task: TextCpuTask) -> None: + admission_started = time.perf_counter() + backpressure_wait_ms = 0.0 + capacity_wait_ms = 0.0 + while True: + loop_start = time.perf_counter() + admission_state = self._get_admission_state() + blocked = bool(admission_state.get("blocked", False)) + with self.condition: + if not blocked and self._can_enqueue_locked(): + self._record_enqueue_locked( + task, + admission_wait_ms=(time.perf_counter() - admission_started) * 1000.0, + backpressure_wait_ms=backpressure_wait_ms, + capacity_wait_ms=capacity_wait_ms, + ) + return + await asyncio.sleep(self.admission_poll_s) + waited_ms = (time.perf_counter() - loop_start) * 1000.0 + if blocked: + backpressure_wait_ms += waited_ms + else: + capacity_wait_ms += waited_ms + + def submit(self, text: str, language: str) -> Tuple[Any, Dict[str, float]]: + task = TextCpuTask(text=str(text), language=str(language)) + asyncio.run(self._enqueue_task_async(task)) + task.done_event.wait() + if task.error is not None: + raise task.error + return task.result, dict(task.profile) + + async def submit_async(self, text: str, language: str) -> Tuple[Any, Dict[str, float]]: + loop = asyncio.get_running_loop() + task = TextCpuTask( + text=str(text), + language=str(language), + done_loop=loop, + done_future=loop.create_future(), + ) + await self._enqueue_task_async(task) + return await task.done_future + + @staticmethod + def _resolve_done_future(task: TextCpuTask) -> None: + if task.done_future is None or task.done_future.done(): + return + if task.error is not None: + task.done_future.set_exception(task.error) + return + task.done_future.set_result((task.result, dict(task.profile))) + + def _notify_task_done(self, task: TextCpuTask) -> None: + task.done_event.set() + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_done_future, task) + except RuntimeError: + pass + + def snapshot(self) -> Dict[str, int | float]: + with self.condition: + return { + "worker_count": int(self.worker_count), + "pending": int(len(self.pending_tasks)), + "pending_peak": int(self.pending_peak), + "active_workers": int(self.active_workers), + "active_workers_peak": int(self.active_workers_peak), + "total_submitted": int(self.total_submitted), + "total_finished": int(self.total_finished), + "max_pending_tasks": int(self.max_pending_tasks), + "admission_wait_total_ms": float(self.admission_wait_total_ms), + "admission_wait_peak_ms": float(self.admission_wait_peak_ms), + "backpressure_wait_total_ms": float(self.backpressure_wait_total_ms), + "backpressure_wait_peak_ms": float(self.backpressure_wait_peak_ms), + "capacity_wait_total_ms": float(self.capacity_wait_total_ms), + "capacity_wait_peak_ms": float(self.capacity_wait_peak_ms), + "backpressure_blocked_total": int(self.backpressure_blocked_total), + } + + def _run_loop(self) -> None: + while True: + with self.condition: + while not self.pending_tasks: + self.condition.wait() + task = self.pending_tasks.popleft() + self.active_workers += 1 + self.active_workers_peak = max(self.active_workers_peak, self.active_workers) + started_at = time.perf_counter() + try: + task.result = self.process_fn(task.text, task.language) + task.profile = { + "text_cpu_admission_wait_ms": float(task.admission_wait_ms), + "text_cpu_backpressure_wait_ms": float(task.backpressure_wait_ms), + "text_cpu_capacity_wait_ms": float(task.capacity_wait_ms), + "text_cpu_queue_wait_ms": max(0.0, (started_at - task.enqueued_at) * 1000.0), + "text_cpu_pending_depth_on_enqueue": float(task.pending_depth_on_enqueue), + "text_cpu_run_ms": max(0.0, (time.perf_counter() - started_at) * 1000.0), + } + except Exception as exc: # noqa: PERF203 + task.error = exc + finally: + with self.condition: + self.active_workers = max(0, self.active_workers - 1) + self.total_finished += 1 + self.condition.notify_all() + self._notify_task_done(task) diff --git a/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py new file mode 100644 index 00000000..a4d462d0 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py @@ -0,0 +1,1285 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +import os +from pathlib import Path +import time +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import torch +import torch.nn.functional as F + +from AR.models.utils import logits_to_probs, make_pad_mask_left, multinomial_sample_one_no_sync, sample + + +def _sync_device(device: Any) -> None: + try: + device_str = str(device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + +@dataclass +class SchedulerRequestSpec: + request_id: str + ref_audio_path: Path + prompt_text: str + prompt_lang: str + text: str + text_lang: str + top_k: int + top_p: float + temperature: float + repetition_penalty: float + early_stop_num: int + aux_ref_audio_paths: List[str] = field(default_factory=list) + ready_step: int = 0 + + +@dataclass +class T2SRequestState: + request_id: str + ref_audio_path: Path + prompt_text: str + prompt_lang: str + text: str + text_lang: str + norm_prompt_text: str + norm_text: str + phones: torch.LongTensor + prompt_phones: torch.LongTensor + all_phones: torch.LongTensor + all_bert_features: torch.Tensor + prompt_semantic: torch.LongTensor + refer_spec: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] + aux_refer_specs: List[Tuple[torch.Tensor, Optional[torch.Tensor]]] + raw_audio: torch.Tensor + raw_sr: int + top_k: int + top_p: float + temperature: float + repetition_penalty: float + early_stop_num: int + ready_step: int + prepare_profile: Dict[str, float] + + +@dataclass +class T2SRunningRequest: + state: T2SRequestState + y_sequence: torch.LongTensor + prefix_len: int + decode_attn_mask: Optional[torch.Tensor] + k_cache: List[torch.Tensor] + v_cache: List[torch.Tensor] + step_idx: int + + +@dataclass +class T2SFinishedItem: + request_id: str + semantic_tokens: torch.LongTensor + finish_idx: int + finish_reason: str + + +@dataclass +class T2SActiveBatch: + request_ids: List[str] + states: List[T2SRequestState] + x: Optional[torch.Tensor] + x_lens: Optional[torch.LongTensor] + y_sequences: List[torch.LongTensor] + prefix_lens: torch.LongTensor + xy_pos: torch.Tensor + key_padding_mask: Optional[torch.Tensor] + prefill_attn_mask: Optional[torch.Tensor] + decode_attn_mask: Optional[torch.Tensor] + k_cache: Optional[List[torch.Tensor]] + v_cache: Optional[List[torch.Tensor]] + kv_lens: Optional[torch.LongTensor] + step_indices: torch.LongTensor + prefill_done: bool + + +@dataclass +class PreparedTextFeatures: + phones: List[int] + bert_features: torch.Tensor + norm_text: str + profile: Dict[str, float] + total_ms: float + cpu_preprocess_ms: float + + +def build_empty_text_features( + *, + feature_dim: int = 1024, + dtype: torch.dtype = torch.float32, +) -> PreparedTextFeatures: + return PreparedTextFeatures( + phones=[], + bert_features=torch.empty((int(feature_dim), 0), dtype=dtype), + norm_text="", + profile={"cpu_preprocess_ms": 0.0, "bert_total_ms": 0.0}, + total_ms=0.0, + cpu_preprocess_ms=0.0, + ) + + +def normalize_sentence(text: str, language: str) -> str: + text = text.strip("\n").strip() + if not text: + return text + if text[-1] not in {",", ".", "?", "!", ",", "。", "?", "!", "…", ";", ";", ":"}: + text += "。" if language != "en" else "." + return text + + +@torch.inference_mode() +def prepare_text_features( + tts: Any, + text: str, + language: str, +) -> PreparedTextFeatures: + device = tts.configs.device + profile: Dict[str, float] = {} + branch_start = time.perf_counter() + _sync_device(device) + cpu_start = time.perf_counter() + prepared_segments = tts.prepare_text_segments(text, language) + _sync_device(device) + cpu_preprocess_ms = (time.perf_counter() - cpu_start) * 1000.0 + profile["cpu_preprocess_ms"] = float(cpu_preprocess_ms) + bert_start = time.perf_counter() + phones, bert_features, norm_text = tts.build_text_features_from_segments(prepared_segments, profile=profile) + _sync_device(device) + profile["bert_total_ms"] = (time.perf_counter() - bert_start) * 1000.0 + total_ms = (time.perf_counter() - branch_start) * 1000.0 + return PreparedTextFeatures( + phones=phones, + bert_features=bert_features, + norm_text=norm_text, + profile=profile, + total_ms=float(total_ms), + cpu_preprocess_ms=float(cpu_preprocess_ms), + ) + + +@torch.inference_mode() +def build_request_state_from_parts( + tts: Any, + spec: SchedulerRequestSpec, + prompt_text: str, + text: str, + prompt_result: PreparedTextFeatures, + target_result: PreparedTextFeatures, + ref_audio_bundle: Dict[str, Any], + prepare_start: float, + prepare_sync_start: float, + profile_overrides: Optional[Dict[str, float]] = None, +) -> T2SRequestState: + device = tts.configs.device + _sync_device(device) + ref_audio_bundle_ms = float(ref_audio_bundle.get("profile", {}).get("bundle_total_ms", 0.0)) + bundle_profile = ref_audio_bundle.get("profile", {}) + prompt_semantic = ref_audio_bundle["prompt_semantic"].long() + refer_spec_value = ref_audio_bundle.get("refer_spec") + if refer_spec_value in [None, ()]: + spec_audio, audio_16k = None, None + else: + spec_audio, audio_16k = refer_spec_value + aux_refer_specs: List[Tuple[torch.Tensor, Optional[torch.Tensor]]] = [] + for aux_ref_audio_path in list(getattr(spec, "aux_ref_audio_paths", []) or []): + if aux_ref_audio_path in [None, ""]: + continue + if not os.path.exists(str(aux_ref_audio_path)): + continue + aux_spec_audio, aux_audio_16k, _, _ = tts.extract_ref_spec(str(aux_ref_audio_path)) + aux_refer_specs.append((aux_spec_audio, aux_audio_16k)) + raw_audio = ref_audio_bundle["raw_audio"] + raw_sr = int(ref_audio_bundle["raw_sr"]) + prompt_semantic_ms = float(bundle_profile.get("prompt_semantic_ms", ref_audio_bundle_ms)) + ref_spec_ms = float(bundle_profile.get("ref_spec_ms", 0.0)) + audio_load_ms = float(bundle_profile.get("audio_load_ms", 0.0)) + + _sync_device(device) + tensorize_start = time.perf_counter() + phones_tensor = torch.LongTensor(target_result.phones).to(tts.configs.device) + prompt_phones_tensor = torch.LongTensor(prompt_result.phones).to(tts.configs.device) + all_phones = torch.LongTensor(prompt_result.phones + target_result.phones).to(tts.configs.device) + prompt_bert_features = prompt_result.bert_features.to(dtype=tts.precision, device=tts.configs.device) + target_bert_features = target_result.bert_features.to(dtype=tts.precision, device=tts.configs.device) + all_bert_features = torch.cat([prompt_bert_features, target_bert_features], dim=1) + _sync_device(device) + tensorize_ms = (time.perf_counter() - tensorize_start) * 1000.0 + + prepare_profile = { + "prompt_text_features_ms": float(prompt_result.total_ms), + "text_features_ms": float(target_result.total_ms), + "prompt_text_cpu_preprocess_ms": float(prompt_result.cpu_preprocess_ms), + "text_cpu_preprocess_ms": float(target_result.cpu_preprocess_ms), + "prompt_text_bert_wait_ms": float(prompt_result.profile.get("bert_wait_ms", 0.0)), + "prompt_text_bert_admission_wait_ms": float(prompt_result.profile.get("bert_admission_wait_ms", 0.0)), + "prompt_text_bert_queue_wait_ms": float(prompt_result.profile.get("bert_queue_wait_ms", 0.0)), + "prompt_text_bert_batch_collect_wait_ms": float(prompt_result.profile.get("bert_batch_collect_wait_ms", 0.0)), + "prompt_text_bert_forward_ms": float(prompt_result.profile.get("bert_forward_ms", 0.0)), + "prompt_text_bert_tokenize_ms": float(prompt_result.profile.get("bert_tokenize_ms", 0.0)), + "prompt_text_bert_scatter_ms": float(prompt_result.profile.get("bert_scatter_ms", 0.0)), + "prompt_text_bert_calls": float(prompt_result.profile.get("bert_calls", 0.0)), + "prompt_text_bert_stage_slots": float(prompt_result.profile.get("bert_stage_slots", 0.0)), + "prompt_text_bert_stage_inflight_peak": float(prompt_result.profile.get("bert_stage_inflight_peak", 0.0)), + "prompt_text_bert_batch_size_peak": float(prompt_result.profile.get("bert_batch_size_peak", 0.0)), + "prompt_text_bert_batch_tokens_peak": float(prompt_result.profile.get("bert_batch_tokens_peak", 0.0)), + "prompt_text_bert_pending_depth_on_enqueue_peak": float( + prompt_result.profile.get("bert_pending_depth_on_enqueue_peak", 0.0) + ), + "prompt_text_bert_pending_depth_on_collect_peak": float( + prompt_result.profile.get("bert_pending_depth_on_collect_peak", 0.0) + ), + "prompt_text_bert_high_pressure_mode_peak": float( + prompt_result.profile.get("bert_high_pressure_mode_peak", 0.0) + ), + "prompt_text_bert_batch_window_ms": float(prompt_result.profile.get("bert_batch_window_ms", 0.0)), + "prompt_text_g2pw_total_ms": float(prompt_result.profile.get("g2pw_total_ms", 0.0)), + "prompt_text_g2pw_prepare_ms": float(prompt_result.profile.get("g2pw_prepare_ms", 0.0)), + "prompt_text_g2pw_predict_ms": float(prompt_result.profile.get("g2pw_predict_ms", 0.0)), + "prompt_text_g2pw_post_ms": float(prompt_result.profile.get("g2pw_post_ms", 0.0)), + "prompt_text_g2pw_runtime_total_ms": float(prompt_result.profile.get("g2pw_runtime_total_ms", 0.0)), + "prompt_text_g2pw_runtime_queue_wait_ms": float( + prompt_result.profile.get("g2pw_runtime_queue_wait_ms", 0.0) + ), + "prompt_text_g2pw_runtime_collect_wait_ms": float( + prompt_result.profile.get("g2pw_runtime_collect_wait_ms", 0.0) + ), + "prompt_text_g2pw_runtime_run_ms": float(prompt_result.profile.get("g2pw_runtime_run_ms", 0.0)), + "prompt_text_g2pw_runtime_batch_rows_peak": float( + prompt_result.profile.get("g2pw_runtime_batch_rows_peak", 0.0) + ), + "prompt_text_g2pw_runtime_batch_requests_peak": float( + prompt_result.profile.get("g2pw_runtime_batch_requests_peak", 0.0) + ), + "prompt_text_parallel_future_wait_ms": 0.0, + "prompt_text_parallel_future_executor_queue_ms": 0.0, + "prompt_text_parallel_future_run_ms": float(prompt_result.total_ms), + "prompt_text_parallel_future_finish_after_submit_ms": float(prompt_result.total_ms), + "prompt_text_parallel_future_queue_tail_after_target_ms": 0.0, + "prompt_text_parallel_future_run_tail_after_target_ms": 0.0, + "text_bert_wait_ms": float(target_result.profile.get("bert_wait_ms", 0.0)), + "text_bert_admission_wait_ms": float(target_result.profile.get("bert_admission_wait_ms", 0.0)), + "text_bert_queue_wait_ms": float(target_result.profile.get("bert_queue_wait_ms", 0.0)), + "text_bert_batch_collect_wait_ms": float(target_result.profile.get("bert_batch_collect_wait_ms", 0.0)), + "text_bert_forward_ms": float(target_result.profile.get("bert_forward_ms", 0.0)), + "text_bert_tokenize_ms": float(target_result.profile.get("bert_tokenize_ms", 0.0)), + "text_bert_scatter_ms": float(target_result.profile.get("bert_scatter_ms", 0.0)), + "text_bert_calls": float(target_result.profile.get("bert_calls", 0.0)), + "text_bert_stage_slots": float(target_result.profile.get("bert_stage_slots", 0.0)), + "text_bert_stage_inflight_peak": float(target_result.profile.get("bert_stage_inflight_peak", 0.0)), + "text_bert_batch_size_peak": float(target_result.profile.get("bert_batch_size_peak", 0.0)), + "text_bert_batch_tokens_peak": float(target_result.profile.get("bert_batch_tokens_peak", 0.0)), + "text_bert_pending_depth_on_enqueue_peak": float( + target_result.profile.get("bert_pending_depth_on_enqueue_peak", 0.0) + ), + "text_bert_pending_depth_on_collect_peak": float( + target_result.profile.get("bert_pending_depth_on_collect_peak", 0.0) + ), + "text_bert_high_pressure_mode_peak": float(target_result.profile.get("bert_high_pressure_mode_peak", 0.0)), + "text_bert_batch_window_ms": float(target_result.profile.get("bert_batch_window_ms", 0.0)), + "text_g2pw_total_ms": float(target_result.profile.get("g2pw_total_ms", 0.0)), + "text_g2pw_prepare_ms": float(target_result.profile.get("g2pw_prepare_ms", 0.0)), + "text_g2pw_predict_ms": float(target_result.profile.get("g2pw_predict_ms", 0.0)), + "text_g2pw_post_ms": float(target_result.profile.get("g2pw_post_ms", 0.0)), + "text_g2pw_runtime_total_ms": float(target_result.profile.get("g2pw_runtime_total_ms", 0.0)), + "text_g2pw_runtime_queue_wait_ms": float(target_result.profile.get("g2pw_runtime_queue_wait_ms", 0.0)), + "text_g2pw_runtime_collect_wait_ms": float(target_result.profile.get("g2pw_runtime_collect_wait_ms", 0.0)), + "text_g2pw_runtime_run_ms": float(target_result.profile.get("g2pw_runtime_run_ms", 0.0)), + "text_g2pw_runtime_batch_rows_peak": float(target_result.profile.get("g2pw_runtime_batch_rows_peak", 0.0)), + "text_g2pw_runtime_batch_requests_peak": float( + target_result.profile.get("g2pw_runtime_batch_requests_peak", 0.0) + ), + "text_feature_pair_ms": float(max(prompt_result.total_ms, target_result.total_ms)), + "text_cpu_parallel_workers": float(getattr(tts, "prepare_text_cpu_workers", 0)), + "audio_load_ms": audio_load_ms, + "audio_stage_wait_ms": float(bundle_profile.get("audio_stage_wait_ms", 0.0)), + "audio_stage_slots": float(bundle_profile.get("audio_stage_slots", 0.0)), + "audio_stage_inflight_peak": float(bundle_profile.get("audio_stage_inflight_peak", 0.0)), + "prompt_semantic_ms": prompt_semantic_ms, + "prompt_semantic_wait_ms": float(bundle_profile.get("prompt_semantic_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_wait_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_wait_ms", 0.0)), + "prompt_semantic_cpu_prepare_slots": float(bundle_profile.get("prompt_semantic_cpu_prepare_slots", 0.0)), + "prompt_semantic_cpu_prepare_inflight_peak": float( + bundle_profile.get("prompt_semantic_cpu_prepare_inflight_peak", 0.0) + ), + "prompt_semantic_worker_queue_wait_ms": float( + bundle_profile.get("prompt_semantic_worker_queue_wait_ms", 0.0) + ), + "prompt_semantic_batch_collect_wait_ms": float( + bundle_profile.get("prompt_semantic_batch_collect_wait_ms", 0.0) + ), + "prompt_semantic_stage_limiter_wait_ms": float( + bundle_profile.get("prompt_semantic_stage_limiter_wait_ms", 0.0) + ), + "prompt_semantic_batch_dispatch_delay_ms": float( + bundle_profile.get("prompt_semantic_batch_dispatch_delay_ms", 0.0) + ), + "prompt_semantic_cpu_prepare_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), + "prompt_semantic_pack_ms": float(bundle_profile.get("prompt_semantic_pack_ms", 0.0)), + "prompt_semantic_h2d_ms": float(bundle_profile.get("prompt_semantic_h2d_ms", 0.0)), + "prompt_semantic_ssl_forward_ms": float(bundle_profile.get("prompt_semantic_ssl_forward_ms", 0.0)), + "prompt_semantic_hidden_length_ms": float(bundle_profile.get("prompt_semantic_hidden_length_ms", 0.0)), + "prompt_semantic_extract_latent_ms": float(bundle_profile.get("prompt_semantic_extract_latent_ms", 0.0)), + "prompt_semantic_forward_ms": float(bundle_profile.get("prompt_semantic_forward_ms", 0.0)), + "prompt_semantic_scatter_ms": float(bundle_profile.get("prompt_semantic_scatter_ms", 0.0)), + "prompt_semantic_stage_slots": float(bundle_profile.get("prompt_semantic_stage_slots", 0.0)), + "prompt_semantic_stage_inflight_peak": float(bundle_profile.get("prompt_semantic_stage_inflight_peak", 0.0)), + "prompt_semantic_batch_size": float(bundle_profile.get("prompt_semantic_batch_size", 0.0)), + "prompt_semantic_batch_samples": float(bundle_profile.get("prompt_semantic_batch_samples", 0.0)), + "ref_spec_wait_ms": float(bundle_profile.get("ref_spec_wait_ms", 0.0)), + "ref_spec_ms": ref_spec_ms, + "ref_spec_to_device_ms": float(bundle_profile.get("ref_spec_to_device_ms", 0.0)), + "ref_spec_main_resample_ms": float(bundle_profile.get("ref_spec_main_resample_ms", 0.0)), + "ref_spec_norm_ms": float(bundle_profile.get("ref_spec_norm_ms", 0.0)), + "ref_spec_spectrogram_ms": float(bundle_profile.get("ref_spec_spectrogram_ms", 0.0)), + "ref_spec_post_resample_ms": float(bundle_profile.get("ref_spec_post_resample_ms", 0.0)), + "ref_audio_bundle_ms": ref_audio_bundle_ms, + "tensorize_ms": tensorize_ms, + "total_ms": (time.perf_counter() - prepare_sync_start) * 1000.0, + "wall_total_ms": (time.perf_counter() - prepare_start) * 1000.0, + } + if profile_overrides: + prepare_profile.update({key: float(value) for key, value in profile_overrides.items()}) + return T2SRequestState( + request_id=spec.request_id, + ref_audio_path=spec.ref_audio_path, + prompt_text=prompt_text, + prompt_lang=spec.prompt_lang, + text=text, + text_lang=spec.text_lang, + norm_prompt_text=prompt_result.norm_text, + norm_text=target_result.norm_text, + phones=phones_tensor, + prompt_phones=prompt_phones_tensor, + all_phones=all_phones, + all_bert_features=all_bert_features, + prompt_semantic=prompt_semantic, + refer_spec=(None if spec_audio is None else (spec_audio, audio_16k)), + aux_refer_specs=aux_refer_specs, + raw_audio=raw_audio, + raw_sr=raw_sr, + top_k=spec.top_k, + top_p=spec.top_p, + temperature=spec.temperature, + repetition_penalty=spec.repetition_penalty, + early_stop_num=spec.early_stop_num, + ready_step=spec.ready_step, + prepare_profile=prepare_profile, + ) + + +@torch.inference_mode() +def prepare_request_state( + tts: Any, + spec: SchedulerRequestSpec, +) -> T2SRequestState: + prepare_start = time.perf_counter() + prepare_sync_start = time.perf_counter() + prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang) + text = spec.text.strip("\n") + target_result = prepare_text_features(tts, text, spec.text_lang) + if target_result.phones is None: + raise ValueError(f"{spec.request_id} text preprocessing returned no phones") + if prompt_text in [None, ""]: + prompt_result = build_empty_text_features( + feature_dim=int(target_result.bert_features.shape[0]), + dtype=target_result.bert_features.dtype, + ) + else: + prompt_result = prepare_text_features(tts, prompt_text, spec.prompt_lang) + ref_audio_bundle = tts.extract_ref_audio_bundle(str(spec.ref_audio_path)) + return build_request_state_from_parts( + tts=tts, + spec=spec, + prompt_text=prompt_text, + text=text, + prompt_result=prompt_result, + target_result=target_result, + ref_audio_bundle=ref_audio_bundle, + prepare_start=prepare_start, + prepare_sync_start=prepare_sync_start, + ) + + +def _left_pad_hidden(hidden: torch.Tensor, target_len: int) -> torch.Tensor: + if hidden.shape[0] >= target_len: + return hidden + return F.pad(hidden, (0, 0, target_len - hidden.shape[0], 0), value=0) + + +def _ensure_audio_pe(model: Any, max_position: int, dtype: torch.dtype, device: torch.device) -> None: + required_len = max_position + 1 + if model.ar_audio_position.pe is not None and model.ar_audio_position.pe.size(1) >= required_len: + if model.ar_audio_position.pe.dtype != dtype or model.ar_audio_position.pe.device != device: + model.ar_audio_position.pe = model.ar_audio_position.pe.to(dtype=dtype, device=device) + return + model.ar_audio_position.extend_pe( + torch.zeros(1, required_len, model.ar_audio_position.embedding_dim, device=device, dtype=dtype) + ) + + +def _pad_token_sequences( + token_sequences: Sequence[torch.LongTensor], +) -> Tuple[torch.LongTensor, torch.BoolTensor]: + if not token_sequences: + raise ValueError("token_sequences 不能为空") + device = token_sequences[0].device + max_len = max(int(sequence.shape[0]) for sequence in token_sequences) + padded = torch.zeros((len(token_sequences), max_len), dtype=token_sequences[0].dtype, device=device) + mask = torch.zeros((len(token_sequences), max_len), dtype=torch.bool, device=device) + for row_index, sequence in enumerate(token_sequences): + seq_len = int(sequence.shape[0]) + padded[row_index, :seq_len] = sequence + mask[row_index, :seq_len] = True + return padded, mask + + +def _sampling_group_key( + top_k: int, + top_p: float, + temperature: float, + repetition_penalty: float, + trim_eos: bool, +) -> Tuple[int, float, float, float, bool]: + return ( + int(top_k), + float(top_p), + float(temperature), + float(repetition_penalty), + bool(trim_eos), + ) + + +def _iter_contiguous_sampling_groups( + sampling_keys: Sequence[Tuple[int, float, float, float, bool]], +) -> List[Tuple[Tuple[int, float, float, float, bool], List[int]]]: + groups: List[Tuple[Tuple[int, float, float, float, bool], List[int]]] = [] + if not sampling_keys: + return groups + current_key = sampling_keys[0] + current_indices: List[int] = [0] + for index in range(1, len(sampling_keys)): + key = sampling_keys[index] + if key == current_key: + current_indices.append(index) + continue + groups.append((current_key, current_indices)) + current_key = key + current_indices = [index] + groups.append((current_key, current_indices)) + return groups + + +def _uniform_sampling_group_key(active_batch: T2SActiveBatch) -> Optional[Tuple[int, float, float, float, bool]]: + if not active_batch.states: + return None + if active_batch.step_indices.numel() <= 0: + return None + first_step_index = int(active_batch.step_indices[0].item()) + if bool((active_batch.step_indices != first_step_index).any().item()): + return None + first_state = active_batch.states[0] + first_key = _sampling_group_key( + top_k=first_state.top_k, + top_p=first_state.top_p, + temperature=first_state.temperature, + repetition_penalty=first_state.repetition_penalty, + trim_eos=first_step_index < 11, + ) + for state in active_batch.states[1:]: + if ( + state.top_k != first_state.top_k + or state.top_p != first_state.top_p + or state.temperature != first_state.temperature + or state.repetition_penalty != first_state.repetition_penalty + ): + return None + return first_key + + +def _batched_sample_uniform( + logits: torch.Tensor, + histories: Sequence[torch.LongTensor], + sampling_key: Tuple[int, float, float, float, bool], +) -> Tuple[torch.Tensor, torch.Tensor]: + top_k, top_p, temperature, repetition_penalty, trim_eos = sampling_key + sample_logits = logits[:, :-1] if trim_eos else logits + padded_histories, history_mask = _pad_token_sequences(histories) + probs = logits_to_probs( + logits=sample_logits, + previous_tokens=padded_histories, + previous_token_mask=history_mask, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + ) + sampled = multinomial_sample_one_no_sync(probs) + argmax_tokens = torch.argmax(sample_logits, dim=-1) + return sampled, argmax_tokens + + +def _batched_sample_by_group( + logits: torch.Tensor, + histories: Sequence[torch.LongTensor], + sampling_keys: Sequence[Tuple[int, float, float, float, bool]], +) -> Tuple[List[torch.Tensor], List[int]]: + sampled_list: List[Optional[torch.Tensor]] = [None] * len(histories) + argmax_list: List[Optional[int]] = [None] * len(histories) + for group_key, group_indices in _iter_contiguous_sampling_groups(sampling_keys): + top_k, top_p, temperature, repetition_penalty, trim_eos = group_key + index_tensor = torch.tensor(group_indices, dtype=torch.long, device=logits.device) + group_logits = torch.index_select(logits, dim=0, index=index_tensor) + if trim_eos: + group_logits = group_logits[:, :-1] + group_histories = [histories[index] for index in group_indices] + padded_histories, history_mask = _pad_token_sequences(group_histories) + probs = logits_to_probs( + logits=group_logits, + previous_tokens=padded_histories, + previous_token_mask=history_mask, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + ) + argmax_tokens = torch.argmax(group_logits, dim=-1) + for local_index, global_index in enumerate(group_indices): + sampled_list[global_index] = multinomial_sample_one_no_sync(probs[local_index : local_index + 1]) + argmax_list[global_index] = int(argmax_tokens[local_index].item()) + + return [item for item in sampled_list if item is not None], [int(item) for item in argmax_list if item is not None] + + +@torch.inference_mode() +def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SActiveBatch: + x_items: List[torch.Tensor] = [] + y_pos_items: List[torch.Tensor] = [] + x_lens: List[int] = [] + prefix_lens: List[int] = [] + y_sequences: List[torch.LongTensor] = [] + + for state in states: + text_emb = model.ar_text_embedding(state.all_phones.unsqueeze(0)) + bert_proj = model.bert_proj(state.all_bert_features.transpose(0, 1).unsqueeze(0)) + x_pos = model.ar_text_position(text_emb + bert_proj).squeeze(0) + y_emb = model.ar_audio_embedding(state.prompt_semantic.unsqueeze(0)) + y_pos = model.ar_audio_position(y_emb).squeeze(0) + x_items.append(x_pos) + y_pos_items.append(y_pos) + x_lens.append(x_pos.shape[0]) + prefix_lens.append(y_pos.shape[0]) + y_sequences.append(state.prompt_semantic.clone()) + + max_x_len = max(x_lens) + max_prefix_len = max(prefix_lens) + x_batch = torch.stack([_left_pad_hidden(item, max_x_len) for item in x_items], dim=0) + y_pos_batch = torch.stack([_left_pad_hidden(item, max_prefix_len) for item in y_pos_items], dim=0) + xy_pos = torch.cat([x_batch, y_pos_batch], dim=1) + + device = x_batch.device + x_lens_tensor = torch.LongTensor(x_lens).to(device) + prefix_lens_tensor = torch.LongTensor(prefix_lens).to(device) + src_len = max_x_len + max_prefix_len + + x_padding_mask = make_pad_mask_left(x_lens_tensor, max_x_len) + y_padding_mask = make_pad_mask_left(prefix_lens_tensor, max_prefix_len) + key_padding_mask = torch.cat([x_padding_mask, y_padding_mask], dim=1).bool() + x_mask = F.pad(torch.zeros(max_x_len, max_x_len, dtype=torch.bool, device=device), (0, max_prefix_len), value=True) + y_mask = F.pad( + torch.triu(torch.ones(max_prefix_len, max_prefix_len, dtype=torch.bool, device=device), diagonal=1), + (max_x_len, 0), + value=False, + ) + causal_mask = torch.cat([x_mask, y_mask], dim=0).unsqueeze(0) + attn_mask = causal_mask.logical_or(key_padding_mask.unsqueeze(1)).unsqueeze(1) + + return T2SActiveBatch( + request_ids=[state.request_id for state in states], + states=list(states), + x=x_batch, + x_lens=x_lens_tensor, + y_sequences=y_sequences, + prefix_lens=prefix_lens_tensor, + xy_pos=xy_pos, + key_padding_mask=key_padding_mask, + prefill_attn_mask=attn_mask, + decode_attn_mask=None, + k_cache=None, + v_cache=None, + kv_lens=None, + step_indices=torch.zeros((len(states),), dtype=torch.long, device=device), + prefill_done=False, + ) + + +def build_next_xy_pos(model: Any, y_sequences: Sequence[torch.LongTensor]) -> torch.Tensor: + last_tokens = torch.stack([seq[-1:] for seq in y_sequences], dim=0) + y_emb = model.ar_audio_embedding(last_tokens) + position_ids = torch.LongTensor([int(seq.shape[0] - 1) for seq in y_sequences]).to(y_emb.device) + _ensure_audio_pe(model, int(position_ids.max().item()), y_emb.dtype, y_emb.device) + pos_emb = model.ar_audio_position.pe[0].index_select(0, position_ids).unsqueeze(1) + return y_emb * model.ar_audio_position.x_scale + model.ar_audio_position.alpha * pos_emb.to( + dtype=y_emb.dtype, device=y_emb.device + ) + + +def _compact_cache_to_kv_lens( + cache: torch.Tensor, + kv_lens: torch.LongTensor, +) -> torch.Tensor: + target_len = int(kv_lens.max().item()) + if cache.shape[1] == target_len and torch.all(kv_lens == target_len).item(): + return cache + compacted = cache.new_zeros((cache.shape[0], target_len, cache.shape[2])) + for batch_index, kv_len in enumerate(kv_lens.tolist()): + if kv_len <= 0: + continue + compacted[batch_index, -kv_len:, :] = cache[batch_index, -kv_len:, :] + return compacted + + +def _compact_decode_mask_to_kv_lens( + decode_attn_mask: Optional[torch.Tensor], + kv_lens: torch.LongTensor, +) -> Optional[torch.Tensor]: + target_len = int(kv_lens.max().item()) + 1 + if decode_attn_mask is None: + return None + if decode_attn_mask.shape[-1] == target_len and torch.all(kv_lens + 1 == target_len).item(): + return decode_attn_mask + compacted = torch.ones( + (decode_attn_mask.shape[0], 1, 1, target_len), + dtype=decode_attn_mask.dtype, + device=decode_attn_mask.device, + ) + for batch_index, kv_len in enumerate(kv_lens.tolist()): + current_len = kv_len + 1 + compacted[batch_index, :, :, -current_len:] = decode_attn_mask[batch_index, :, :, -current_len:] + if not compacted.any().item(): + return None + return compacted + + +def _advance_decode_mask( + decode_attn_mask: Optional[torch.Tensor], + kv_lens: torch.LongTensor, +) -> Optional[torch.Tensor]: + if decode_attn_mask is None: + return None + target_len = int(kv_lens.max().item()) + 2 + advanced = torch.zeros( + (decode_attn_mask.shape[0], 1, 1, target_len), + dtype=decode_attn_mask.dtype, + device=decode_attn_mask.device, + ) + for batch_index, kv_len in enumerate(kv_lens.tolist()): + current_len = kv_len + 1 + next_mask = F.pad(decode_attn_mask[batch_index : batch_index + 1, :, :, -current_len:], (0, 1), value=False) + advanced[batch_index : batch_index + 1, :, :, -next_mask.shape[-1] :] = next_mask + if not advanced.any().item(): + return None + return advanced + + +def _sample_per_request( + model: Any, + active_batch: T2SActiveBatch, + logits: torch.Tensor, + max_steps: int, +) -> Tuple[List[T2SFinishedItem], List[int], List[torch.LongTensor]]: + finished_items: List[T2SFinishedItem] = [] + keep_indices: List[int] = [] + updated_sequences: List[torch.LongTensor] = [] + + uniform_sampling_key = _uniform_sampling_group_key(active_batch) + sampled_items: List[torch.Tensor] + argmax_tokens: List[int] + sampled_token_tensor: Optional[torch.Tensor] = None + argmax_token_tensor: Optional[torch.Tensor] = None + if uniform_sampling_key is not None: + sampled_tensor, argmax_tensor = _batched_sample_uniform( + logits=logits, + histories=active_batch.y_sequences, + sampling_key=uniform_sampling_key, + ) + sampled_token_tensor = sampled_tensor.view(-1) + argmax_token_tensor = argmax_tensor.view(-1) + if ( + all(state.early_stop_num == -1 for state in active_batch.states) + and int(active_batch.step_indices[0].item()) + 1 < max_steps + and not bool(sampled_token_tensor.eq(model.EOS).any().item()) + and not bool(argmax_token_tensor.eq(model.EOS).any().item()) + ): + return ( + [], + list(range(len(active_batch.states))), + [torch.cat([history, sampled_token_tensor[index : index + 1]], dim=0) for index, history in enumerate(active_batch.y_sequences)], + ) + sampled_items = [sampled_tensor[index : index + 1] for index in range(sampled_tensor.shape[0])] + argmax_tokens = [int(item) for item in argmax_tensor.tolist()] + else: + sampling_keys = [ + _sampling_group_key( + top_k=state.top_k, + top_p=state.top_p, + temperature=state.temperature, + repetition_penalty=state.repetition_penalty, + trim_eos=int(active_batch.step_indices[batch_index].item()) < 11, + ) + for batch_index, state in enumerate(active_batch.states) + ] + sampled_items, argmax_tokens = _batched_sample_by_group( + logits=logits, + histories=active_batch.y_sequences, + sampling_keys=sampling_keys, + ) + for batch_index, state in enumerate(active_batch.states): + step_index = int(active_batch.step_indices[batch_index].item()) + current_history = active_batch.y_sequences[batch_index] + if sampled_token_tensor is not None and argmax_token_tensor is not None: + sampled = sampled_token_tensor[batch_index : batch_index + 1] + sampled_token = int(sampled_token_tensor[batch_index].item()) + argmax_token = int(argmax_token_tensor[batch_index].item()) + else: + sampled = sampled_items[batch_index] + sampled_token = int(sampled[0, 0].item()) + argmax_token = argmax_tokens[batch_index] + new_history = torch.cat([current_history, sampled.view(-1)], dim=0) + + finish_reason: Optional[str] = None + if state.early_stop_num != -1 and (new_history.shape[0] - int(active_batch.prefix_lens[batch_index].item())) > state.early_stop_num: + finish_reason = "early_stop" + elif step_index + 1 >= max_steps: + finish_reason = "max_step" + elif sampled_token == model.EOS: + finish_reason = "eos_sample" + elif argmax_token == model.EOS: + finish_reason = "eos_argmax" + + if finish_reason is not None: + prefix_len = int(active_batch.prefix_lens[batch_index].item()) + finished_items.append( + T2SFinishedItem( + request_id=state.request_id, + semantic_tokens=new_history[prefix_len:-1].clone(), + finish_idx=step_index, + finish_reason=finish_reason, + ) + ) + else: + keep_indices.append(batch_index) + updated_sequences.append(new_history) + + return finished_items, keep_indices, updated_sequences + + +@torch.inference_mode() +def decode_one_step( + model: Any, + active_batch: T2SActiveBatch, + max_steps: int, +) -> Tuple[Optional[T2SActiveBatch], List[T2SFinishedItem]]: + was_prefill = not active_batch.prefill_done + if was_prefill: + if active_batch.prefill_attn_mask is None or active_batch.key_padding_mask is None: + raise ValueError("prefill 阶段缺少必要 mask") + xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.process_prompt( + active_batch.xy_pos, active_batch.prefill_attn_mask, None + ) + active_batch.kv_lens = active_batch.x_lens + active_batch.prefix_lens + active_batch.decode_attn_mask = F.pad(active_batch.key_padding_mask.unsqueeze(1).unsqueeze(1), (0, 1), value=False) + if active_batch.k_cache is None or active_batch.v_cache is None or active_batch.kv_lens is None: + raise ValueError("prefill 阶段未生成完整 KV cache") + active_batch.k_cache = [_compact_cache_to_kv_lens(layer, active_batch.kv_lens) for layer in active_batch.k_cache] + active_batch.v_cache = [_compact_cache_to_kv_lens(layer, active_batch.kv_lens) for layer in active_batch.v_cache] + active_batch.decode_attn_mask = _compact_decode_mask_to_kv_lens(active_batch.decode_attn_mask, active_batch.kv_lens) + active_batch.x = None + active_batch.x_lens = None + active_batch.key_padding_mask = None + active_batch.prefill_attn_mask = None + active_batch.prefill_done = True + else: + if active_batch.k_cache is None or active_batch.v_cache is None or active_batch.kv_lens is None: + raise ValueError("decode 阶段缺少 KV cache") + batched_decode_attn_mask = None + if active_batch.decode_attn_mask is not None: + batched_decode_attn_mask = _materialize_decode_mask_for_active_batch(active_batch) + if not batched_decode_attn_mask.any().item(): + batched_decode_attn_mask = None + xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.decode_next_token( + active_batch.xy_pos, + active_batch.k_cache, + active_batch.v_cache, + batched_decode_attn_mask, + ) + active_batch.decode_attn_mask = _advance_decode_mask(active_batch.decode_attn_mask, active_batch.kv_lens) + + logits = model.ar_predict_layer(xy_dec[:, -1]) + + finished_items, keep_indices, updated_sequences = _sample_per_request(model, active_batch, logits, max_steps=max_steps) + if len(keep_indices) == 0: + return None, finished_items + if len(keep_indices) == len(active_batch.request_ids): + active_batch.y_sequences = updated_sequences + active_batch.step_indices = active_batch.step_indices + 1 + if not was_prefill and active_batch.kv_lens is not None: + active_batch.kv_lens = active_batch.kv_lens + 1 + active_batch.xy_pos = build_next_xy_pos(model, active_batch.y_sequences) + return active_batch, finished_items + + device = logits.device + keep_tensor = torch.LongTensor(keep_indices).to(device) + active_batch.request_ids = [active_batch.request_ids[i] for i in keep_indices] + active_batch.states = [active_batch.states[i] for i in keep_indices] + active_batch.y_sequences = updated_sequences + active_batch.prefix_lens = torch.index_select(active_batch.prefix_lens, dim=0, index=keep_tensor) + next_step_indices = torch.index_select(active_batch.step_indices, dim=0, index=keep_tensor) + next_kv_lens = None if active_batch.kv_lens is None else torch.index_select(active_batch.kv_lens, dim=0, index=keep_tensor) + active_batch.step_indices = next_step_indices + 1 + if not was_prefill: + if next_kv_lens is not None: + active_batch.kv_lens = next_kv_lens + 1 + else: + active_batch.kv_lens = next_kv_lens + + if active_batch.decode_attn_mask is not None: + active_batch.decode_attn_mask = torch.index_select(active_batch.decode_attn_mask, dim=0, index=keep_tensor) + if not active_batch.decode_attn_mask.any().item(): + active_batch.decode_attn_mask = None + if active_batch.k_cache is not None and active_batch.v_cache is not None: + for cache_index in range(len(active_batch.k_cache)): + active_batch.k_cache[cache_index] = torch.index_select(active_batch.k_cache[cache_index], dim=0, index=keep_tensor) + active_batch.v_cache[cache_index] = torch.index_select(active_batch.v_cache[cache_index], dim=0, index=keep_tensor) + if active_batch.kv_lens is not None: + active_batch.k_cache = [_compact_cache_to_kv_lens(layer, active_batch.kv_lens) for layer in active_batch.k_cache] + active_batch.v_cache = [_compact_cache_to_kv_lens(layer, active_batch.kv_lens) for layer in active_batch.v_cache] + active_batch.decode_attn_mask = _compact_decode_mask_to_kv_lens( + active_batch.decode_attn_mask, + active_batch.kv_lens, + ) + + active_batch.xy_pos = build_next_xy_pos(model, active_batch.y_sequences) + return active_batch, finished_items + + +def run_scheduler_batch( + model: Any, + states: Sequence[T2SRequestState], + max_steps: int, +) -> List[T2SFinishedItem]: + return run_scheduler_continuous(model, states, max_steps=max_steps) + + +def _pad_cache_left(cache: torch.Tensor, target_len: int) -> torch.Tensor: + pad_len = target_len - cache.shape[1] + if pad_len <= 0: + return cache + return F.pad(cache, (0, 0, pad_len, 0), value=0) + + +def _pad_decode_mask_left(mask: torch.Tensor, target_len: int) -> torch.Tensor: + pad_len = target_len - mask.shape[-1] + if pad_len <= 0: + return mask + return F.pad(mask, (pad_len, 0), value=True) + + +def _fit_decode_mask_length(mask: torch.Tensor, target_len: int) -> torch.Tensor: + if mask.shape[-1] > target_len: + return mask[:, :, :, -target_len:] + if mask.shape[-1] < target_len: + return _pad_decode_mask_left(mask, target_len) + return mask + + +def _materialize_decode_mask_for_request(running_request: T2SRunningRequest) -> torch.Tensor: + expected_mask_len = running_request.k_cache[0].shape[1] + 1 + if running_request.decode_attn_mask is not None: + return _fit_decode_mask_length(running_request.decode_attn_mask, expected_mask_len) + current_mask_len = running_request.k_cache[0].shape[1] + 1 + return torch.zeros( + (1, 1, 1, current_mask_len), + dtype=torch.bool, + device=running_request.k_cache[0].device, + ) + + +def _materialize_decode_mask_for_active_batch( + active_batch: T2SActiveBatch, + target_mask_len: Optional[int] = None, +) -> torch.Tensor: + if active_batch.k_cache is None or active_batch.kv_lens is None: + raise ValueError("active batch 缺少 KV cache 或 kv_lens") + current_mask_len = active_batch.k_cache[0].shape[1] + 1 + if target_mask_len is None: + target_mask_len = current_mask_len + if active_batch.decode_attn_mask is None: + mask = torch.zeros( + (len(active_batch.request_ids), 1, 1, current_mask_len), + dtype=torch.bool, + device=active_batch.k_cache[0].device, + ) + else: + rows: List[torch.Tensor] = [] + for batch_index, kv_len in enumerate(active_batch.kv_lens.tolist()): + row_len = kv_len + 1 + row_mask = _fit_decode_mask_length( + active_batch.decode_attn_mask[batch_index : batch_index + 1], + row_len, + ) + rows.append(_pad_decode_mask_left(row_mask, target_mask_len)) + mask = torch.cat(rows, dim=0) + if target_mask_len != current_mask_len and active_batch.decode_attn_mask is None: + mask = _pad_decode_mask_left(mask, target_mask_len) + return mask + + +@torch.inference_mode() +def run_prefill_active_batch( + model: Any, + states: Sequence[T2SRequestState], + max_steps: int, +) -> Tuple[Optional[T2SActiveBatch], List[T2SFinishedItem]]: + if not states: + return None, [] + active_batch = build_prefill_batch(model, states) + return decode_one_step(model, active_batch, max_steps=max_steps) + + +@torch.inference_mode() +def merge_active_batches( + model: Any, + left_batch: Optional[T2SActiveBatch], + right_batch: Optional[T2SActiveBatch], +) -> Optional[T2SActiveBatch]: + if left_batch is None: + return right_batch + if right_batch is None: + return left_batch + if not left_batch.prefill_done or not right_batch.prefill_done: + raise ValueError("只有 prefill 完成后的 active batch 才能 merge") + if left_batch.k_cache is None or left_batch.v_cache is None or right_batch.k_cache is None or right_batch.v_cache is None: + raise ValueError("merge active batch 时缺少 KV cache") + + left_kv_len = int(left_batch.k_cache[0].shape[1]) + right_kv_len = int(right_batch.k_cache[0].shape[1]) + merged_kv_len = max(left_kv_len, right_kv_len) + merged_mask_len = merged_kv_len + 1 + + merged_k_cache: List[torch.Tensor] = [] + merged_v_cache: List[torch.Tensor] = [] + for layer_index in range(len(left_batch.k_cache)): + merged_k_cache.append( + torch.cat( + [ + _pad_cache_left(left_batch.k_cache[layer_index], merged_kv_len), + _pad_cache_left(right_batch.k_cache[layer_index], merged_kv_len), + ], + dim=0, + ) + ) + merged_v_cache.append( + torch.cat( + [ + _pad_cache_left(left_batch.v_cache[layer_index], merged_kv_len), + _pad_cache_left(right_batch.v_cache[layer_index], merged_kv_len), + ], + dim=0, + ) + ) + + merged_decode_attn_mask = torch.cat( + [ + _materialize_decode_mask_for_active_batch(left_batch, merged_mask_len), + _materialize_decode_mask_for_active_batch(right_batch, merged_mask_len), + ], + dim=0, + ) + merged_request_ids = list(left_batch.request_ids) + list(right_batch.request_ids) + merged_states = list(left_batch.states) + list(right_batch.states) + merged_y_sequences = list(left_batch.y_sequences) + list(right_batch.y_sequences) + merged_prefix_lens = torch.cat([left_batch.prefix_lens, right_batch.prefix_lens], dim=0) + if left_batch.kv_lens is None or right_batch.kv_lens is None: + raise ValueError("merge active batch 时缺少 kv_lens") + merged_kv_lens = torch.cat([left_batch.kv_lens, right_batch.kv_lens], dim=0) + merged_decode_attn_mask = _compact_decode_mask_to_kv_lens(merged_decode_attn_mask, merged_kv_lens) + merged_step_indices = torch.cat([left_batch.step_indices, right_batch.step_indices], dim=0) + + return T2SActiveBatch( + request_ids=merged_request_ids, + states=merged_states, + x=None, + x_lens=None, + y_sequences=merged_y_sequences, + prefix_lens=merged_prefix_lens, + xy_pos=build_next_xy_pos(model, merged_y_sequences), + key_padding_mask=None, + prefill_attn_mask=None, + decode_attn_mask=merged_decode_attn_mask, + k_cache=merged_k_cache, + v_cache=merged_v_cache, + kv_lens=merged_kv_lens, + step_indices=merged_step_indices, + prefill_done=True, + ) + + +@torch.inference_mode() +def run_prefill_step( + model: Any, + states: Sequence[T2SRequestState], + max_steps: int, +) -> Tuple[List[T2SRunningRequest], List[T2SFinishedItem]]: + if not states: + return [], [] + + active_batch = build_prefill_batch(model, states) + xy_dec, k_cache, v_cache = model.t2s_transformer.process_prompt(active_batch.xy_pos, active_batch.prefill_attn_mask, None) + decode_attn_mask = F.pad(active_batch.key_padding_mask.unsqueeze(1).unsqueeze(1), (0, 1), value=False) + if len(states) == 1 and not decode_attn_mask.any().item(): + decode_attn_mask = None + logits = model.ar_predict_layer(xy_dec[:, -1]) + sampling_keys = [ + _sampling_group_key( + top_k=state.top_k, + top_p=state.top_p, + temperature=state.temperature, + repetition_penalty=state.repetition_penalty, + trim_eos=True, + ) + for state in states + ] + sampled_items, argmax_tokens = _batched_sample_by_group( + logits=logits, + histories=active_batch.y_sequences, + sampling_keys=sampling_keys, + ) + + running_requests: List[T2SRunningRequest] = [] + finished_items: List[T2SFinishedItem] = [] + + for batch_index, state in enumerate(states): + current_history = active_batch.y_sequences[batch_index] + sampled = sampled_items[batch_index] + sampled_token = int(sampled[0, 0].item()) + argmax_token = argmax_tokens[batch_index] + new_history = torch.cat([current_history, sampled.view(-1)], dim=0) + prefix_len = int(active_batch.prefix_lens[batch_index].item()) + + finish_reason: Optional[str] = None + if state.early_stop_num != -1 and (new_history.shape[0] - prefix_len) > state.early_stop_num: + finish_reason = "early_stop" + elif 1 >= max_steps: + finish_reason = "max_step" + elif sampled_token == model.EOS: + finish_reason = "eos_sample" + elif argmax_token == model.EOS: + finish_reason = "eos_argmax" + + if finish_reason is not None: + finished_items.append( + T2SFinishedItem( + request_id=state.request_id, + semantic_tokens=new_history[prefix_len:-1].clone(), + finish_idx=0, + finish_reason=finish_reason, + ) + ) + continue + + real_kv_len = int(active_batch.x_lens[batch_index].item()) + prefix_len + request_k_cache = [layer[batch_index : batch_index + 1, -real_kv_len:, :].clone() for layer in k_cache] + request_v_cache = [layer[batch_index : batch_index + 1, -real_kv_len:, :].clone() for layer in v_cache] + request_decode_attn_mask = None + if decode_attn_mask is not None: + request_decode_attn_mask = decode_attn_mask[batch_index : batch_index + 1].clone() + request_decode_attn_mask = _fit_decode_mask_length(request_decode_attn_mask, real_kv_len + 1) + if not request_decode_attn_mask.any().item(): + request_decode_attn_mask = None + + running_requests.append( + T2SRunningRequest( + state=state, + y_sequence=new_history, + prefix_len=prefix_len, + decode_attn_mask=request_decode_attn_mask, + k_cache=request_k_cache, + v_cache=request_v_cache, + step_idx=1, + ) + ) + + return running_requests, finished_items + + +def _build_decode_batch_from_running( + model: Any, + running_requests: Sequence[T2SRunningRequest], +) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor], Optional[torch.Tensor]]: + xy_pos = build_next_xy_pos(model, [item.y_sequence for item in running_requests]) + max_kv_len = max(item.k_cache[0].shape[1] for item in running_requests) + num_layers = len(running_requests[0].k_cache) + + batched_k_cache: List[torch.Tensor] = [] + batched_v_cache: List[torch.Tensor] = [] + for layer_index in range(num_layers): + batched_k_cache.append( + torch.cat([_pad_cache_left(item.k_cache[layer_index], max_kv_len) for item in running_requests], dim=0) + ) + batched_v_cache.append( + torch.cat([_pad_cache_left(item.v_cache[layer_index], max_kv_len) for item in running_requests], dim=0) + ) + + if all(item.decode_attn_mask is None for item in running_requests): + batched_decode_attn_mask = None + else: + materialized_masks = [_materialize_decode_mask_for_request(item) for item in running_requests] + max_mask_len = max(mask.shape[-1] for mask in materialized_masks) + batched_decode_attn_mask = torch.cat( + [_pad_decode_mask_left(mask, max_mask_len) for mask in materialized_masks], + dim=0, + ) + return xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask + + +@torch.inference_mode() +def run_decode_step_for_running( + model: Any, + running_requests: Sequence[T2SRunningRequest], + max_steps: int, +) -> Tuple[List[T2SRunningRequest], List[T2SFinishedItem]]: + if not running_requests: + return [], [] + + xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask = _build_decode_batch_from_running( + model, running_requests + ) + xy_dec, next_k_cache, next_v_cache = model.t2s_transformer.decode_next_token( + xy_pos, + batched_k_cache, + batched_v_cache, + batched_decode_attn_mask, + ) + logits = model.ar_predict_layer(xy_dec[:, -1]) + sampling_keys = [ + _sampling_group_key( + top_k=running_request.state.top_k, + top_p=running_request.state.top_p, + temperature=running_request.state.temperature, + repetition_penalty=running_request.state.repetition_penalty, + trim_eos=running_request.step_idx < 11, + ) + for running_request in running_requests + ] + histories = [running_request.y_sequence for running_request in running_requests] + sampled_items, argmax_tokens = _batched_sample_by_group( + logits=logits, + histories=histories, + sampling_keys=sampling_keys, + ) + + next_running: List[T2SRunningRequest] = [] + finished_items: List[T2SFinishedItem] = [] + + for batch_index, running_request in enumerate(running_requests): + current_idx = running_request.step_idx + sampled = sampled_items[batch_index] + sampled_token = int(sampled[0, 0].item()) + argmax_token = argmax_tokens[batch_index] + new_history = torch.cat([running_request.y_sequence, sampled.view(-1)], dim=0) + + finish_reason: Optional[str] = None + if running_request.state.early_stop_num != -1 and (new_history.shape[0] - running_request.prefix_len) > running_request.state.early_stop_num: + finish_reason = "early_stop" + elif current_idx + 1 >= max_steps: + finish_reason = "max_step" + elif sampled_token == model.EOS: + finish_reason = "eos_sample" + elif argmax_token == model.EOS: + finish_reason = "eos_argmax" + + if finish_reason is not None: + finished_items.append( + T2SFinishedItem( + request_id=running_request.state.request_id, + semantic_tokens=new_history[running_request.prefix_len:-1].clone(), + finish_idx=current_idx, + finish_reason=finish_reason, + ) + ) + continue + + real_next_kv_len = running_request.k_cache[0].shape[1] + 1 + request_k_cache = [layer[batch_index : batch_index + 1, -real_next_kv_len:, :].clone() for layer in next_k_cache] + request_v_cache = [layer[batch_index : batch_index + 1, -real_next_kv_len:, :].clone() for layer in next_v_cache] + if batched_decode_attn_mask is None: + next_decode_attn_mask = None + else: + current_decode_mask_len = running_request.k_cache[0].shape[1] + 1 + current_decode_attn_mask = batched_decode_attn_mask[ + batch_index : batch_index + 1, :, :, -current_decode_mask_len: + ] + next_decode_attn_mask = F.pad(current_decode_attn_mask, (0, 1), value=False) + next_decode_attn_mask = _fit_decode_mask_length(next_decode_attn_mask, real_next_kv_len + 1) + if not next_decode_attn_mask.any().item(): + next_decode_attn_mask = None + next_running.append( + T2SRunningRequest( + state=running_request.state, + y_sequence=new_history, + prefix_len=running_request.prefix_len, + decode_attn_mask=next_decode_attn_mask, + k_cache=request_k_cache, + v_cache=request_v_cache, + step_idx=current_idx + 1, + ) + ) + + return next_running, finished_items + + +@torch.inference_mode() +def run_scheduler_continuous( + model: Any, + states: Sequence[T2SRequestState], + max_steps: int, +) -> List[T2SFinishedItem]: + pending = sorted(states, key=lambda item: (item.ready_step, item.request_id)) + active_batch: Optional[T2SActiveBatch] = None + finished: List[T2SFinishedItem] = [] + current_tick = 0 + + while pending or active_batch is not None: + admitted: List[T2SRequestState] = [] + while pending and pending[0].ready_step <= current_tick: + admitted.append(pending.pop(0)) + + admitted_active_batch, admitted_finished = run_prefill_active_batch(model, admitted, max_steps=max_steps) + finished.extend(admitted_finished) + active_batch = merge_active_batches(model, active_batch, admitted_active_batch) + + if active_batch is not None: + active_batch, step_finished = decode_one_step(model, active_batch, max_steps=max_steps) + finished.extend(step_finished) + + if active_batch is None and pending: + current_tick = max(current_tick + 1, pending[0].ready_step) + continue + + current_tick += 1 + + finished.sort(key=lambda item: item.request_id) + return finished diff --git a/GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py b/GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py new file mode 100644 index 00000000..3d5b2de5 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py @@ -0,0 +1,112 @@ +import os +import re +import sys +from typing import Dict, List, Optional, Tuple + +now_dir = os.getcwd() +sys.path.append(now_dir) + +from text.LangSegmenter import LangSegmenter +from text import cleaned_text_to_sequence +from text import chinese2 +from text.cleaner import clean_text + + +PreparedTextSegmentPayload = Dict[str, object] + + +def split_text_by_language(text: str, language: str) -> Tuple[List[str], List[str]]: + textlist: List[str] = [] + langlist: List[str] = [] + if language == "all_zh": + for tmp in LangSegmenter.getTexts(text, "zh"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_yue": + for tmp in LangSegmenter.getTexts(text, "zh"): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_ja": + for tmp in LangSegmenter.getTexts(text, "ja"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_ko": + for tmp in LangSegmenter.getTexts(text, "ko"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "en": + langlist.append("en") + textlist.append(text) + elif language == "auto": + for tmp in LangSegmenter.getTexts(text): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "auto_yue": + for tmp in LangSegmenter.getTexts(text): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + else: + for tmp in LangSegmenter.getTexts(text): + if langlist: + same_group = (tmp["lang"] == "en" and langlist[-1] == "en") or ( + tmp["lang"] != "en" and langlist[-1] != "en" + ) + if same_group: + textlist[-1] += tmp["text"] + continue + if tmp["lang"] == "en": + langlist.append(tmp["lang"]) + else: + langlist.append(language) + textlist.append(tmp["text"]) + return textlist, langlist + + +def clean_text_segment(text: str, language: str, version: str) -> Tuple[List[int], Optional[List[int]], str]: + normalized_language = language.replace("all_", "") + phones, word2ph, norm_text = clean_text(text, normalized_language, version) + phones = cleaned_text_to_sequence(phones, version) + return list(phones), None if word2ph is None else list(word2ph), str(norm_text) + + +def preprocess_text_segments_payload( + text: str, + language: str, + version: str, + final: bool = False, +) -> List[PreparedTextSegmentPayload]: + text = re.sub(r" {2,}", " ", text) + textlist, langlist = split_text_by_language(text, language) + payloads: List[PreparedTextSegmentPayload] = [] + total_phones_len = 0 + for segment_text, segment_lang in zip(textlist, langlist): + normalized_language = segment_lang.replace("all_", "") + if normalized_language == "zh": + norm_text = chinese2.text_normalize(segment_text) + phones = [] + word2ph = None + needs_g2pw = True + estimated_phones_len = max(0, len(norm_text) * 2) + else: + phones, word2ph, norm_text = clean_text_segment(segment_text, segment_lang, version) + needs_g2pw = False + estimated_phones_len = len(phones) + payloads.append( + { + "language": normalized_language, + "phones": phones, + "word2ph": word2ph, + "norm_text": norm_text, + "needs_g2pw": needs_g2pw, + } + ) + total_phones_len += int(estimated_phones_len) + + if not final and total_phones_len < 6: + return preprocess_text_segments_payload("." + text, language, version, final=True) + + return payloads diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine.py b/GPT_SoVITS/TTS_infer_pack/unified_engine.py new file mode 100644 index 00000000..a6faaddb --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import os +from typing import Sequence + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.unified_engine_builder import EngineCompositionBuilder +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeControlCallbacks +from GPT_SoVITS.TTS_infer_pack.unified_engine_delegates import EngineApiDelegates, EngineBridgeDelegates, EngineRuntimeDelegates +from GPT_SoVITS.TTS_infer_pack.unified_engine_public import EngineCompatInterface, EnginePublicInterface + + +class UnifiedTTSEngine(EnginePublicInterface, EngineCompatInterface, EngineBridgeDelegates, EngineApiDelegates, EngineRuntimeDelegates): + @staticmethod + def _env_flag(name: str, default: bool) -> bool: + value = os.environ.get(name) + if value is None: + return bool(default) + return str(value).strip().lower() not in {"0", "false", "no", "off", ""} + + @staticmethod + def _env_int(name: str, default: int) -> int: + value = os.environ.get(name) + if value in [None, ""]: + return int(default) + return int(value) + + @staticmethod + def _env_float(name: str, default: float) -> float: + value = os.environ.get(name) + if value in [None, ""]: + return float(default) + return float(value) + + def __init__( + self, + tts: TTS, + cut_method_names: Sequence[str], + control_callbacks: RuntimeControlCallbacks | None = None, + max_steps: int = 1500, + micro_batch_wait_ms: int = 5, + ) -> None: + self.tts = tts + self.cut_method_names = set(cut_method_names) + self.control_callbacks = control_callbacks or RuntimeControlCallbacks() + EngineCompositionBuilder(self).build(max_steps=max_steps, micro_batch_wait_ms=micro_batch_wait_ms) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py new file mode 100644 index 00000000..0895cef4 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api.py @@ -0,0 +1,451 @@ +from __future__ import annotations + +from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple + +from GPT_SoVITS.TTS_infer_pack.unified_engine_api_direct import EngineApiDirectFlow +from GPT_SoVITS.TTS_infer_pack.unified_engine_api_profile import ( + aggregate_numeric_dicts, + build_direct_scheduler_profile, + build_direct_segment_trace, + build_legacy_direct_profile, + build_request_meta, + build_scheduler_debug_batch_profile, + build_scheduler_debug_request_profile, + build_scheduler_submit_headers, + build_scheduler_submit_profile, + format_ms_header, + sum_profile_field, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_api_request import ( + apply_default_reference, + base_request_defaults, + check_params, + is_aux_ref_enabled, + normalize_engine_request, + normalize_lang, + normalize_streaming_mode, + select_direct_backend, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_api_scheduler import EngineApiSchedulerFlow +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SFinishedItem, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import ( + DirectTTSExecution, + NormalizedEngineRequest, + SchedulerDebugExecution, + SchedulerSubmitExecution, +) + + +class EngineApiFacade: + def __init__(self, owner: Any) -> None: + self.owner = owner + self.direct_flow = EngineApiDirectFlow(self) + self.scheduler_flow = EngineApiSchedulerFlow(self) + + @property + def tts(self): + return self.owner.tts + + @property + def cut_method_names(self): + return self.owner.cut_method_names + + @property + def reference_registry(self): + return self.owner.reference_registry + + @property + def direct_tts_lock(self): + return self.owner.direct_tts_lock + + @property + def scheduler_worker(self): + return self.owner.scheduler_worker + + def _register_request_state( + self, + request_id: str, + api_mode: str, + backend: str, + media_type: str, + response_streaming: bool, + deadline_ts: float | None = None, + meta: Optional[Dict[str, Any]] = None, + ): + return self.owner._register_request_state( + request_id=request_id, + api_mode=api_mode, + backend=backend, + media_type=media_type, + response_streaming=response_streaming, + deadline_ts=deadline_ts, + meta=meta, + ) + + def _update_request_state( + self, + request_id: str, + status: str, + extra: Optional[Dict[str, Any]] = None, + ) -> None: + self.owner._update_request_state(request_id, status, extra) + + def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.owner._merge_request_state_profile(request_id, extra) + + def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.owner._complete_request_state(request_id, extra) + + def _fail_request_state(self, request_id: str, error: str) -> None: + self.owner._fail_request_state(request_id, error) + + async def _prepare_state_via_engine_gpu_queue( + self, + *, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + engine_request_id: str | None, + ) -> tuple[T2SRequestState, float, float]: + return await self.owner._prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=prepare_submit_at, + engine_request_id=engine_request_id, + ) + + async def _enqueue_prepared_state_for_dispatch( + self, + *, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + super_sampling: bool, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None, + done_future: asyncio.Future | None, + engine_request_id: str | None, + timeout_sec: float | None, + ): + return await self.owner._enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=speed_factor, + sample_steps=sample_steps, + media_type=media_type, + super_sampling=super_sampling, + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=done_loop, + done_future=done_future, + engine_request_id=engine_request_id, + timeout_sec=timeout_sec, + ) + + def _collect_request_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]: + return self.owner.request_registry.collect_summaries(request_ids) + + def _has_active_request(self, request_id: str) -> bool: + return self.owner.request_registry.has_active(request_id) + + @staticmethod + def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]: + return build_request_meta(payload) + + @staticmethod + def _sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float: + return sum_profile_field(items, key) + + def _build_direct_segment_trace( + self, + segment_texts: Sequence[str], + prepare_profiles: Sequence[Dict[str, Any]], + worker_profiles: Sequence[Dict[str, Any]], + ) -> List[Dict[str, Any]]: + return build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles) + + def _build_direct_scheduler_profile( + self, + *, + backend: str, + request_start: float, + response_ready_at: float, + audio_bytes: int, + sample_rate: int, + segment_texts: Sequence[str], + prepare_profiles: Sequence[Dict[str, Any]], + worker_profiles: Sequence[Dict[str, Any]], + pack_ms: float, + response_overhead_ms: float, + ) -> Dict[str, Any]: + return build_direct_scheduler_profile( + backend=backend, + request_start=request_start, + response_ready_at=response_ready_at, + audio_bytes=audio_bytes, + sample_rate=sample_rate, + segment_texts=segment_texts, + prepare_profiles=prepare_profiles, + worker_profiles=worker_profiles, + pack_ms=pack_ms, + response_overhead_ms=response_overhead_ms, + ) + + def _build_legacy_direct_profile( + self, + *, + backend: str, + fallback_reason: str | None, + request_start: float, + finished_at: float, + sample_rate: int | None = None, + audio_bytes: int = 0, + pack_ms: float = 0.0, + chunk_count: int = 0, + stream_total_bytes: int = 0, + first_chunk_ms: float | None = None, + ) -> Dict[str, Any]: + return build_legacy_direct_profile( + backend=backend, + fallback_reason=fallback_reason, + request_start=request_start, + finished_at=finished_at, + sample_rate=sample_rate, + audio_bytes=audio_bytes, + pack_ms=pack_ms, + chunk_count=chunk_count, + stream_total_bytes=stream_total_bytes, + first_chunk_ms=first_chunk_ms, + ) + + def _build_scheduler_submit_profile( + self, + *, + backend: str, + request_start: float, + response_ready_at: float, + audio_bytes: int, + sample_rate: int, + prepare_spec_build_ms: float, + prepare_wall_ms: float, + prepare_executor_queue_ms: float, + prepare_executor_run_ms: float, + prepare_profile_total_ms: float, + prepare_profile_wall_ms: float, + prepare_other_ms: float, + engine_policy_wait_ms: float, + api_after_prepare_ms: float, + api_wait_result_ms: float, + pack_ms: float, + response_overhead_ms: float, + worker_profile: Dict[str, Any], + ) -> Dict[str, Any]: + return build_scheduler_submit_profile( + backend=backend, + request_start=request_start, + response_ready_at=response_ready_at, + audio_bytes=audio_bytes, + sample_rate=sample_rate, + prepare_spec_build_ms=prepare_spec_build_ms, + prepare_wall_ms=prepare_wall_ms, + prepare_executor_queue_ms=prepare_executor_queue_ms, + prepare_executor_run_ms=prepare_executor_run_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + prepare_profile_wall_ms=prepare_profile_wall_ms, + prepare_other_ms=prepare_other_ms, + engine_policy_wait_ms=engine_policy_wait_ms, + api_after_prepare_ms=api_after_prepare_ms, + api_wait_result_ms=api_wait_result_ms, + pack_ms=pack_ms, + response_overhead_ms=response_overhead_ms, + worker_profile=worker_profile, + ) + + @staticmethod + def _format_ms_header(value: Any) -> str: + return format_ms_header(value) + + def _build_scheduler_submit_headers( + self, + *, + request_id: str, + media_type: str, + sample_rate: int, + profile: Dict[str, Any], + ) -> Dict[str, str]: + return build_scheduler_submit_headers( + request_id=request_id, + media_type=media_type, + sample_rate=sample_rate, + profile=profile, + ) + + def _build_scheduler_debug_request_profile( + self, + *, + state: T2SRequestState, + item: T2SFinishedItem, + batch_request_count: int, + prepare_batch_wall_ms: float, + decode_batch_wall_ms: float, + batch_request_total_ms: float, + ) -> Dict[str, Any]: + return build_scheduler_debug_request_profile( + state=state, + item=item, + batch_request_count=batch_request_count, + prepare_batch_wall_ms=prepare_batch_wall_ms, + decode_batch_wall_ms=decode_batch_wall_ms, + batch_request_total_ms=batch_request_total_ms, + ) + + @staticmethod + def _build_scheduler_debug_batch_profile( + *, + request_count: int, + max_steps: int, + prepare_batch_wall_ms: float, + decode_batch_wall_ms: float, + request_total_ms: float, + finished_items: Sequence[T2SFinishedItem], + ) -> Dict[str, Any]: + return build_scheduler_debug_batch_profile( + request_count=request_count, + max_steps=max_steps, + prepare_batch_wall_ms=prepare_batch_wall_ms, + decode_batch_wall_ms=decode_batch_wall_ms, + request_total_ms=request_total_ms, + finished_items=finished_items, + ) + + def _normalize_lang(self, value: str | None) -> str | None: + return normalize_lang(value) + + @staticmethod + def _aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]: + return aggregate_numeric_dicts(items) + + def _apply_default_reference(self, req: dict) -> dict: + return apply_default_reference(self.reference_registry, req) + + def check_params(self, req: dict) -> Optional[str]: + return check_params(self.tts, self.cut_method_names, req) + + @staticmethod + def _base_request_defaults() -> Dict[str, Any]: + return base_request_defaults() + + def _normalize_engine_request( + self, + payload: dict | NormalizedEngineRequest, + *, + request_id: str | None = None, + normalize_streaming: bool = False, + error_prefix: str = "request 参数非法: ", + ) -> NormalizedEngineRequest: + return normalize_engine_request( + tts=self.tts, + cut_method_names=self.cut_method_names, + reference_registry=self.reference_registry, + payload=payload, + request_id=request_id, + normalize_streaming=normalize_streaming, + error_prefix=error_prefix, + ) + + @staticmethod + def _normalize_streaming_mode(req: dict) -> dict: + return normalize_streaming_mode(req) + + @staticmethod + def _is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool: + return is_aux_ref_enabled(aux_ref_audio_paths) + + def _select_direct_backend(self, normalized: NormalizedEngineRequest) -> Tuple[str, str | None]: + return select_direct_backend(normalized) + + def _iter_legacy_direct_tts_bytes( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> Generator[bytes, None, None]: + yield from self.direct_flow._iter_legacy_direct_tts_bytes( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + + def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool: + return self.direct_flow._should_use_scheduler_backend_for_direct(req) + + def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]: + return self.direct_flow._segment_direct_text(normalized) + + def _build_segment_request( + self, + normalized: NormalizedEngineRequest, + *, + request_id: str, + text: str, + ) -> NormalizedEngineRequest: + return self.direct_flow._build_segment_request( + normalized, + request_id=request_id, + text=text, + ) + + async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution: + return await self.direct_flow._run_direct_tts_via_scheduler(normalized) + + def _run_legacy_direct_tts_blocking( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> DirectTTSExecution: + return self.direct_flow._run_legacy_direct_tts_blocking( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + + async def _run_direct_tts_via_legacy_backend( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> DirectTTSExecution: + return await self.direct_flow._run_direct_tts_via_legacy_backend( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + + async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution: + return await self.direct_flow.run_direct_tts_async(req) + + def run_direct_tts(self, req: dict) -> DirectTTSExecution: + return self.direct_flow.run_direct_tts(req) + + def _build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]: + return self.scheduler_flow._build_scheduler_request_specs(request_items) + + def _build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec: + return self.scheduler_flow._build_scheduler_submit_spec(payload) + + @staticmethod + def _summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: + return EngineApiSchedulerFlow._summarize_scheduler_states(states) + + @staticmethod + def _summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: + return EngineApiSchedulerFlow._summarize_scheduler_finished(items) + + async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution: + return await self.scheduler_flow.run_scheduler_debug(request_items, max_steps, seed) + + async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution: + return await self.scheduler_flow.run_scheduler_submit(payload) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api_delegates.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_delegates.py new file mode 100644 index 00000000..f42ec233 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_delegates.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple + +from GPT_SoVITS.TTS_infer_pack.unified_engine_api import EngineApiFacade +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, NormalizedEngineRequest + + +class EngineApiDelegates: + def _collect_request_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]: + return self.api_facade._collect_request_summaries(request_ids) + + def _has_active_request(self, request_id: str) -> bool: + return self.api_facade._has_active_request(request_id) + + @staticmethod + def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]: + return EngineApiFacade._build_request_meta(payload) + + @staticmethod + def _sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float: + return EngineApiFacade._sum_profile_field(items, key) + + def _build_direct_segment_trace( + self, + segment_texts: Sequence[str], + prepare_profiles: Sequence[Dict[str, Any]], + worker_profiles: Sequence[Dict[str, Any]], + ) -> List[Dict[str, Any]]: + return self.api_facade._build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles) + + def _build_direct_scheduler_profile(self, **kwargs: Any) -> Dict[str, Any]: + return self.api_facade._build_direct_scheduler_profile(**kwargs) + + def _build_legacy_direct_profile(self, **kwargs: Any) -> Dict[str, Any]: + return self.api_facade._build_legacy_direct_profile(**kwargs) + + def _build_scheduler_submit_profile(self, **kwargs: Any) -> Dict[str, Any]: + return self.api_facade._build_scheduler_submit_profile(**kwargs) + + @staticmethod + def _format_ms_header(value: Any) -> str: + return EngineApiFacade._format_ms_header(value) + + def _build_scheduler_submit_headers( + self, + *, + request_id: str, + media_type: str, + sample_rate: int, + profile: Dict[str, Any], + ) -> Dict[str, str]: + return self.api_facade._build_scheduler_submit_headers( + request_id=request_id, + media_type=media_type, + sample_rate=sample_rate, + profile=profile, + ) + + def _build_scheduler_debug_request_profile(self, **kwargs: Any) -> Dict[str, Any]: + return self.api_facade._build_scheduler_debug_request_profile(**kwargs) + + @staticmethod + def _build_scheduler_debug_batch_profile(**kwargs: Any) -> Dict[str, Any]: + return EngineApiFacade._build_scheduler_debug_batch_profile(**kwargs) + + def _normalize_lang(self, value: str | None) -> str | None: + return self.api_facade._normalize_lang(value) + + @staticmethod + def _aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]: + return EngineApiFacade._aggregate_numeric_dicts(items) + + def _apply_default_reference(self, req: dict) -> dict: + return self.api_facade._apply_default_reference(req) + + def check_params(self, req: dict) -> Optional[str]: + return self.api_facade.check_params(req) + + @staticmethod + def _base_request_defaults() -> Dict[str, Any]: + return EngineApiFacade._base_request_defaults() + + def _normalize_engine_request( + self, + payload: dict | NormalizedEngineRequest, + *, + request_id: str | None = None, + normalize_streaming: bool = False, + error_prefix: str = "request 参数非法: ", + ) -> NormalizedEngineRequest: + return self.api_facade._normalize_engine_request( + payload, + request_id=request_id, + normalize_streaming=normalize_streaming, + error_prefix=error_prefix, + ) + + @staticmethod + def _normalize_streaming_mode(req: dict) -> dict: + return EngineApiFacade._normalize_streaming_mode(req) + + @staticmethod + def _is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool: + return EngineApiFacade._is_aux_ref_enabled(aux_ref_audio_paths) + + def _select_direct_backend(self, normalized: NormalizedEngineRequest) -> Tuple[str, str | None]: + return self.api_facade._select_direct_backend(normalized) + + def _iter_legacy_direct_tts_bytes( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> Generator[bytes, None, None]: + return self.api_facade._iter_legacy_direct_tts_bytes( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + + def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool: + return self.api_facade._should_use_scheduler_backend_for_direct(req) + + def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]: + return self.api_facade._segment_direct_text(normalized) + + def _build_segment_request( + self, + normalized: NormalizedEngineRequest, + *, + request_id: str, + text: str, + ) -> NormalizedEngineRequest: + return self.api_facade._build_segment_request(normalized, request_id=request_id, text=text) + + async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution: + return await self.api_facade._run_direct_tts_via_scheduler(normalized) + + def _run_legacy_direct_tts_blocking( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> DirectTTSExecution: + return self.api_facade._run_legacy_direct_tts_blocking( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + + async def _run_direct_tts_via_legacy_backend( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> DirectTTSExecution: + return await self.api_facade._run_direct_tts_via_legacy_backend( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api_direct.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_direct.py new file mode 100644 index 00000000..f55da45d --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_direct.py @@ -0,0 +1,595 @@ +from __future__ import annotations + +import asyncio +import queue +import threading +import time +import uuid +from io import BytesIO +from typing import Any, Dict, Generator, List, Optional + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.unified_engine_audio import pack_audio, wave_header_chunk +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, EngineStatus, NormalizedEngineRequest, SchedulerPendingJob + + +class EngineApiDirectFlow: + def __init__(self, api: Any) -> None: + self.api = api + + def _iter_legacy_direct_tts_bytes( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> Generator[bytes, None, None]: + payload = normalized.to_payload() + media_type = normalized.media_type + request_id = normalized.request_id + request_start = time.perf_counter() + chunk_count = 0 + stream_total_bytes = 0 + first_chunk_ms: float | None = None + self.api._update_request_state( + request_id, + EngineStatus.ACTIVE_DECODE, + {"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason}, + ) + try: + with self.api.direct_tts_lock: + tts_generator = self.api.tts.run(payload) + first_chunk = True + current_media_type = media_type + for sr, chunk in tts_generator: + if first_chunk: + first_chunk_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0) + self.api._update_request_state( + request_id, + EngineStatus.STREAMING, + { + "backend": backend, + "backend_mode": backend, + "fallback_reason": fallback_reason, + "sample_rate": int(sr), + }, + ) + if first_chunk and media_type == "wav": + header = wave_header_chunk(sample_rate=sr) + chunk_count += 1 + stream_total_bytes += len(header) + yield header + current_media_type = "raw" + first_chunk = False + elif first_chunk: + first_chunk = False + packed_chunk = pack_audio(BytesIO(), chunk, sr, current_media_type).getvalue() + chunk_count += 1 + stream_total_bytes += len(packed_chunk) + yield packed_chunk + except Exception as exc: + self.api._fail_request_state(request_id, str(exc)) + raise + self.api._complete_request_state( + request_id, + dict( + self.api._build_legacy_direct_profile( + backend=backend, + fallback_reason=fallback_reason, + request_start=request_start, + finished_at=time.perf_counter(), + audio_bytes=stream_total_bytes, + chunk_count=chunk_count, + stream_total_bytes=stream_total_bytes, + first_chunk_ms=first_chunk_ms, + ), + streaming_completed=True, + ), + ) + + def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool: + if isinstance(req, NormalizedEngineRequest): + normalized = req + else: + normalized = self.api._normalize_engine_request( + req, + request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"), + normalize_streaming=True, + ) + backend, _ = self.api._select_direct_backend(normalized) + return backend == "scheduler_v1_direct" + + def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]: + payload = normalized.to_payload() if isinstance(normalized, NormalizedEngineRequest) else normalized + return self.api.tts.text_preprocessor.pre_seg_text( + str(payload["text"]), + str(payload["text_lang"]), + str(payload.get("text_split_method", "cut5")), + ) + + def _build_segment_request( + self, + normalized: NormalizedEngineRequest, + *, + request_id: str, + text: str, + ) -> NormalizedEngineRequest: + payload = normalized.to_payload() + payload["request_id"] = request_id + payload["text"] = text + payload["streaming_mode"] = False + payload["return_fragment"] = False + payload["fixed_length_chunk"] = False + payload["response_streaming"] = False + return self.api._normalize_engine_request(payload, error_prefix="segment request 参数非法: ") + + async def _execute_single_segment_scheduler_job( + self, + normalized: NormalizedEngineRequest, + *, + segment_request: NormalizedEngineRequest, + ) -> tuple[SchedulerPendingJob, Dict[str, Any]]: + spec = self.api._build_scheduler_submit_spec(segment_request) + state, prepare_exec_started_at, prepare_exec_finished_at = await self.api._prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=time.perf_counter(), + engine_request_id=None, + ) + prepare_wall_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0) + prepare_profile_total_ms = float(state.prepare_profile.get("wall_total_ms", prepare_wall_ms)) + loop = asyncio.get_running_loop() + done_future = loop.create_future() + await self.api._enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=float(normalized.speed_factor), + sample_steps=int(normalized.sample_steps), + media_type=normalized.media_type, + super_sampling=bool(normalized.super_sampling), + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=loop, + done_future=done_future, + engine_request_id=None, + timeout_sec=normalized.timeout_sec, + ) + timeout_sec = float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0) + job: SchedulerPendingJob = await asyncio.wait_for(done_future, timeout=timeout_sec) + return job, { + "request_id": spec.request_id, + "prepare_wall_ms": prepare_wall_ms, + "prepare_profile_total_ms": prepare_profile_total_ms, + "prepare_profile": dict(state.prepare_profile), + } + + def _iter_scheduler_direct_tts_bytes(self, normalized: NormalizedEngineRequest) -> Generator[bytes, None, None]: + request_start = time.perf_counter() + request_id = normalized.request_id + media_type = normalized.media_type + segment_texts = self._segment_direct_text(normalized) + if not segment_texts: + raise ValueError("text preprocessing returned no valid segments") + chunk_queue: queue.Queue[object] = queue.Queue(maxsize=8) + done_marker = object() + + async def _produce_chunks() -> None: + self.api._update_request_state( + request_id, + EngineStatus.CPU_PREPARING, + {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_texts)}, + ) + sample_rate: int | None = None + current_media_type = media_type + chunk_count = 0 + stream_total_bytes = 0 + first_chunk_ms: float | None = None + prepare_profiles: List[Dict[str, Any]] = [] + worker_profiles: List[Dict[str, Any]] = [] + try: + for segment_index, segment_text in enumerate(segment_texts): + segment_request = self._build_segment_request( + normalized, + request_id=f"{request_id}_seg_{segment_index:03d}", + text=segment_text, + ) + self.api._update_request_state( + request_id, + EngineStatus.READY_FOR_PREFILL, + { + "backend": "scheduler_v1_direct", + "backend_mode": "scheduler_v1_direct", + "segment_index": segment_index, + "segment_count": len(segment_texts), + }, + ) + job, prepare_profile = await self._execute_single_segment_scheduler_job( + normalized, + segment_request=segment_request, + ) + prepare_profiles.append(prepare_profile) + if job.error is not None: + raise RuntimeError(job.error) + if job.audio_data is None or job.sample_rate is None or job.result is None: + raise RuntimeError(f"{job.request_id} finished without audio result") + worker_profiles.append(dict(job.result)) + if sample_rate is None: + sample_rate = int(job.sample_rate) + first_chunk_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0) + self.api._update_request_state( + request_id, + EngineStatus.STREAMING, + { + "backend": "scheduler_v1_direct", + "backend_mode": "scheduler_v1_direct", + "sample_rate": int(sample_rate), + }, + ) + if media_type == "wav": + header = wave_header_chunk(sample_rate=int(sample_rate)) + chunk_count += 1 + stream_total_bytes += len(header) + chunk_queue.put(header) + current_media_type = "raw" + packed_chunk = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), current_media_type).getvalue() + chunk_count += 1 + stream_total_bytes += len(packed_chunk) + chunk_queue.put(packed_chunk) + if segment_index + 1 < len(segment_texts): + silence_samples = int(float(normalized.fragment_interval) * float(job.sample_rate)) + if silence_samples > 0: + silence_chunk = np.zeros(silence_samples, dtype=np.int16) + packed_silence = pack_audio( + BytesIO(), silence_chunk, int(job.sample_rate), current_media_type + ).getvalue() + chunk_count += 1 + stream_total_bytes += len(packed_silence) + chunk_queue.put(packed_silence) + except Exception as exc: + self.api._fail_request_state(request_id, str(exc)) + chunk_queue.put(exc) + else: + self.api._merge_request_state_profile( + request_id, + { + "prepare_aggregate": self.api._aggregate_numeric_dicts( + [item["prepare_profile"] for item in prepare_profiles] + ), + "engine_policy_wait_ms": sum( + float(item.get("engine_policy_wait_ms", 0.0)) for item in worker_profiles + ), + "engine_dispatch_wait_ms": sum( + float(item.get("engine_dispatch_wait_ms", 0.0)) for item in worker_profiles + ), + }, + ) + direct_profile = self.api._build_direct_scheduler_profile( + backend="scheduler_v1_direct", + request_start=request_start, + response_ready_at=time.perf_counter(), + audio_bytes=stream_total_bytes, + sample_rate=int(sample_rate or 0), + segment_texts=segment_texts, + prepare_profiles=prepare_profiles, + worker_profiles=worker_profiles, + pack_ms=0.0, + response_overhead_ms=0.0, + ) + self.api._complete_request_state( + request_id, + dict(direct_profile, streaming_completed=True, first_chunk_ms=first_chunk_ms), + ) + finally: + chunk_queue.put(done_marker) + + producer_thread = threading.Thread(target=lambda: asyncio.run(_produce_chunks()), daemon=True) + producer_thread.start() + while True: + item = chunk_queue.get() + if item is done_marker: + break + if isinstance(item, Exception): + raise item + yield item + + async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution: + request_start = time.perf_counter() + request_id = normalized.request_id + media_type = normalized.media_type + segment_texts = self._segment_direct_text(normalized) + if not segment_texts: + raise ValueError("text preprocessing returned no valid segments") + if normalized.response_streaming: + return DirectTTSExecution( + media_type=media_type, + streaming=True, + audio_generator=self._iter_scheduler_direct_tts_bytes(normalized), + request_id=request_id, + ) + self.api._update_request_state( + request_id, + EngineStatus.CPU_PREPARING, + {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_texts)}, + ) + segment_requests = [ + self._build_segment_request( + normalized, + request_id=f"{request_id}_seg_{segment_index:03d}", + text=segment_text, + ) + for segment_index, segment_text in enumerate(segment_texts) + ] + prepare_profiles: List[Dict[str, Any]] = [] + loop = asyncio.get_running_loop() + done_futures: List[asyncio.Future] = [] + self.api._update_request_state( + request_id, + EngineStatus.READY_FOR_PREFILL, + {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_requests)}, + ) + prepared_items = await asyncio.gather( + *[ + self._execute_single_segment_scheduler_job( + normalized, + segment_request=segment_request, + ) + for segment_request in segment_requests + ] + ) + for job, prepare_profile in prepared_items: + prepare_profiles.append(prepare_profile) + done_future = loop.create_future() + done_future.set_result(job) + done_futures.append(done_future) + self.api._update_request_state( + request_id, + EngineStatus.ACTIVE_DECODE, + {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct"}, + ) + timeout_sec = float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0) + jobs: List[SchedulerPendingJob] = list(await asyncio.wait_for(asyncio.gather(*done_futures), timeout=timeout_sec)) + for profile_item, job in zip(prepare_profiles, jobs): + profile_item["engine_policy_wait_ms"] = float(job.engine_policy_wait_ms) + profile_item["engine_dispatch_wait_ms"] = float(job.engine_dispatch_wait_ms) + self.api._merge_request_state_profile( + request_id, + { + "engine_policy_wait_ms": sum(float(job.engine_policy_wait_ms) for job in jobs), + "engine_dispatch_wait_ms": sum(float(job.engine_dispatch_wait_ms) for job in jobs), + "prepare_aggregate": self.api._aggregate_numeric_dicts([item["prepare_profile"] for item in prepare_profiles]), + }, + ) + + sample_rate: int | None = None + audio_parts: List[np.ndarray] = [] + worker_profiles: List[Dict[str, Any]] = [] + fragment_interval = float(normalized.fragment_interval) + silence_chunk: Optional[np.ndarray] = None + for job in jobs: + if job.error is not None: + raise RuntimeError(job.error) + if job.audio_data is None or job.sample_rate is None or job.result is None: + raise RuntimeError(f"{job.request_id} finished without audio result") + if sample_rate is None: + sample_rate = int(job.sample_rate) + silence_samples = int(fragment_interval * float(sample_rate)) + if silence_samples > 0: + silence_chunk = np.zeros(silence_samples, dtype=np.int16) + elif int(job.sample_rate) != sample_rate: + raise RuntimeError("segment sample rate mismatch") + audio_parts.append(job.audio_data) + if silence_chunk is not None: + audio_parts.append(silence_chunk.copy()) + worker_profiles.append(dict(job.result)) + if sample_rate is None or not audio_parts: + raise RuntimeError("direct scheduler backend produced no audio") + self.api._update_request_state( + request_id, + EngineStatus.FINALIZING, + {"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct"}, + ) + merged_audio = np.concatenate(audio_parts, axis=0) + pack_start = time.perf_counter() + audio_bytes = pack_audio(BytesIO(), merged_audio, sample_rate, media_type).getvalue() + pack_ms = max(0.0, (time.perf_counter() - pack_start) * 1000.0) + direct_profile = self.api._build_direct_scheduler_profile( + backend="scheduler_v1_direct", + request_start=request_start, + response_ready_at=time.perf_counter(), + audio_bytes=len(audio_bytes), + sample_rate=int(sample_rate), + segment_texts=segment_texts, + prepare_profiles=prepare_profiles, + worker_profiles=worker_profiles, + pack_ms=pack_ms, + response_overhead_ms=0.0, + ) + self.api._complete_request_state( + request_id, + dict(direct_profile, streaming_completed=False), + ) + return DirectTTSExecution( + media_type=media_type, + streaming=False, + audio_bytes=audio_bytes, + request_id=request_id, + ) + + def _run_legacy_direct_tts_blocking( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> DirectTTSExecution: + normalized_payload = normalized.to_payload() + request_id = normalized.request_id + media_type = normalized.media_type + request_start = time.perf_counter() + self.api._update_request_state( + request_id, + EngineStatus.ACTIVE_DECODE, + {"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason}, + ) + with self.api.direct_tts_lock: + tts_generator = self.api.tts.run(normalized_payload) + try: + sr, audio_data = next(tts_generator) + except Exception as exc: + self.api._fail_request_state(request_id, str(exc)) + raise + self.api._update_request_state( + request_id, + EngineStatus.FINALIZING, + {"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason}, + ) + pack_start = time.perf_counter() + packed_audio = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() + pack_ms = max(0.0, (time.perf_counter() - pack_start) * 1000.0) + self.api._complete_request_state( + request_id, + dict( + self.api._build_legacy_direct_profile( + backend=backend, + fallback_reason=fallback_reason, + request_start=request_start, + finished_at=time.perf_counter(), + sample_rate=int(sr), + audio_bytes=len(packed_audio), + pack_ms=pack_ms, + ), + streaming_completed=False, + ), + ) + return DirectTTSExecution( + media_type=media_type, + streaming=False, + audio_bytes=packed_audio, + request_id=request_id, + ) + + async def _run_direct_tts_via_legacy_backend( + self, + normalized: NormalizedEngineRequest, + *, + backend: str, + fallback_reason: str | None, + ) -> DirectTTSExecution: + if normalized.response_streaming: + return DirectTTSExecution( + media_type=normalized.media_type, + streaming=True, + audio_generator=self._iter_legacy_direct_tts_bytes( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ), + request_id=normalized.request_id, + ) + return await asyncio.to_thread( + self._run_legacy_direct_tts_blocking, + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + + async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution: + normalized = self.api._normalize_engine_request( + req, + request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"), + normalize_streaming=True, + error_prefix="", + ) + request_id = normalized.request_id + media_type = normalized.media_type + backend, fallback_reason = self.api._select_direct_backend(normalized) + self.api._register_request_state( + request_id=request_id, + api_mode="tts", + backend=backend, + media_type=media_type, + response_streaming=bool(normalized.response_streaming), + deadline_ts=(time.perf_counter() + float(normalized.timeout_sec) if normalized.timeout_sec is not None else None), + meta=self.api._build_request_meta(normalized.to_payload()), + ) + self.api._update_request_state( + request_id, + EngineStatus.VALIDATED, + { + "request_source": "direct_tts", + "selected_backend": backend, + "fallback_reason": fallback_reason, + }, + ) + if backend == "scheduler_v1_direct": + try: + return await self._run_direct_tts_via_scheduler(normalized) + except Exception as exc: + self.api._fail_request_state(request_id, str(exc)) + raise + return await self._run_direct_tts_via_legacy_backend( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + + def run_direct_tts(self, req: dict) -> DirectTTSExecution: + normalized = self.api._normalize_engine_request( + req, + request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"), + normalize_streaming=True, + error_prefix="", + ) + request_id = normalized.request_id + media_type = normalized.media_type + backend, fallback_reason = self.api._select_direct_backend(normalized) + if not self.api._has_active_request(request_id): + self.api._register_request_state( + request_id=request_id, + api_mode="tts", + backend=backend, + media_type=media_type, + response_streaming=bool(normalized.response_streaming), + meta=self.api._build_request_meta(normalized.to_payload()), + ) + self.api._update_request_state( + request_id, + EngineStatus.VALIDATED, + { + "request_source": "direct_tts", + "selected_backend": backend, + "fallback_reason": fallback_reason, + }, + ) + if backend != "scheduler_v1_direct": + if normalized.response_streaming: + return DirectTTSExecution( + media_type=media_type, + streaming=True, + audio_generator=self._iter_legacy_direct_tts_bytes( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ), + request_id=request_id, + ) + return self._run_legacy_direct_tts_blocking( + normalized, + backend=backend, + fallback_reason=fallback_reason, + ) + if normalized.response_streaming: + return DirectTTSExecution( + media_type=media_type, + streaming=True, + audio_generator=self._iter_legacy_direct_tts_bytes( + normalized, + backend="legacy_direct_sync_compat", + fallback_reason="sync_direct_compat", + ), + request_id=request_id, + ) + return self._run_legacy_direct_tts_blocking( + normalized, + backend="legacy_direct_sync_compat", + fallback_reason="sync_direct_compat", + ) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api_profile.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_profile.py new file mode 100644 index 00000000..e31c5dfe --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_profile.py @@ -0,0 +1,388 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Sequence + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem, T2SRequestState + + +def build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]: + text = payload.get("text") + prompt_text = payload.get("prompt_text") + return { + "text_len": 0 if text is None else len(str(text)), + "prompt_text_len": 0 if prompt_text is None else len(str(prompt_text)), + "text_lang": payload.get("text_lang"), + "prompt_lang": payload.get("prompt_lang"), + "ref_audio_path": payload.get("ref_audio_path"), + } + + +def sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float: + total = 0.0 + for item in items: + value = item.get(key, 0.0) + if isinstance(value, (int, float)): + total += float(value) + return total + + +def aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]: + totals: Dict[str, float] = {} + for item in items: + for key, value in item.items(): + if isinstance(value, (int, float)): + totals[key] = totals.get(key, 0.0) + float(value) + return totals + + +def build_direct_segment_trace( + segment_texts: Sequence[str], + prepare_profiles: Sequence[Dict[str, Any]], + worker_profiles: Sequence[Dict[str, Any]], +) -> List[Dict[str, Any]]: + results: List[Dict[str, Any]] = [] + for index, segment_text in enumerate(segment_texts): + prepare_item = prepare_profiles[index] if index < len(prepare_profiles) else {} + worker_item = worker_profiles[index] if index < len(worker_profiles) else {} + prepare_profile = dict(prepare_item.get("prepare_profile", {})) + results.append( + { + "segment_index": index, + "request_id": prepare_item.get("request_id") or worker_item.get("request_id"), + "text_len": len(str(segment_text)), + "prepare_wall_ms": float(prepare_item.get("prepare_wall_ms", 0.0)), + "prepare_profile_total_ms": float(prepare_item.get("prepare_profile_total_ms", 0.0)), + "prepare_engine_gpu_queue_wait_ms": float( + dict(prepare_item.get("prepare_profile", {})).get("engine_gpu_prepare_queue_wait_ms", 0.0) + ), + "engine_policy_wait_ms": float(prepare_item.get("engine_policy_wait_ms", 0.0)), + "engine_dispatch_wait_ms": float(prepare_item.get("engine_dispatch_wait_ms", 0.0)), + "decode_admission_wait_ms": float(worker_item.get("decode_admission_wait_ms", 0.0)), + "queue_wait_ms": float(worker_item.get("queue_wait_ms", 0.0)), + "prefill_ms": float(worker_item.get("prefill_ms", 0.0)), + "merge_ms": float(worker_item.get("merge_ms", 0.0)), + "decode_ms": float(worker_item.get("decode_ms", 0.0)), + "finalize_wait_ms": float(worker_item.get("finalize_wait_ms", 0.0)), + "synth_ms": float(worker_item.get("synth_ms", 0.0)), + "worker_total_ms": float(worker_item.get("worker_total_ms", 0.0)), + "decode_steps": int(worker_item.get("decode_steps", 0)), + "semantic_len": int(worker_item.get("semantic_len", 0)), + "finish_reason": worker_item.get("finish_reason"), + "norm_text": prepare_profile.get("norm_text"), + } + ) + return results + + +def build_direct_scheduler_profile( + *, + backend: str, + request_start: float, + response_ready_at: float, + audio_bytes: int, + sample_rate: int, + segment_texts: Sequence[str], + prepare_profiles: Sequence[Dict[str, Any]], + worker_profiles: Sequence[Dict[str, Any]], + pack_ms: float, + response_overhead_ms: float, +) -> Dict[str, Any]: + segment_trace = build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles) + prepare_profile_dicts = [dict(item.get("prepare_profile", {})) for item in prepare_profiles] + request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) + prepare_wall_ms = sum_profile_field(prepare_profiles, "prepare_wall_ms") + prepare_profile_total_ms = sum_profile_field(prepare_profiles, "prepare_profile_total_ms") + engine_policy_wait_ms = sum_profile_field(prepare_profiles, "engine_policy_wait_ms") + engine_dispatch_wait_ms = sum_profile_field(prepare_profiles, "engine_dispatch_wait_ms") + decode_admission_wait_ms = sum_profile_field(worker_profiles, "decode_admission_wait_ms") + queue_wait_ms = sum_profile_field(worker_profiles, "queue_wait_ms") + prefill_ms = sum_profile_field(worker_profiles, "prefill_ms") + merge_ms = sum_profile_field(worker_profiles, "merge_ms") + decode_ms = sum_profile_field(worker_profiles, "decode_ms") + finalize_wait_ms = sum_profile_field(worker_profiles, "finalize_wait_ms") + synth_ms = sum_profile_field(worker_profiles, "synth_ms") + worker_total_ms = sum_profile_field(worker_profiles, "worker_total_ms") + decode_steps = sum(int(item.get("decode_steps", 0)) for item in worker_profiles) + semantic_len = sum(int(item.get("semantic_len", 0)) for item in worker_profiles) + request_other_ms = max( + 0.0, + request_total_ms - prepare_wall_ms - engine_policy_wait_ms - worker_total_ms - pack_ms - response_overhead_ms, + ) + return { + "backend": backend, + "backend_mode": backend, + "segment_count": len(segment_texts), + "sample_rate": int(sample_rate), + "audio_bytes": int(audio_bytes), + "request_total_ms": request_total_ms, + "prepare_ms": prepare_wall_ms, + "prepare_wall_ms": prepare_wall_ms, + "prepare_profile_total_ms": prepare_profile_total_ms, + "engine_policy_wait_ms": engine_policy_wait_ms, + "engine_dispatch_wait_ms": engine_dispatch_wait_ms, + "decode_admission_wait_ms": decode_admission_wait_ms, + "queue_wait_ms": queue_wait_ms, + "prefill_ms": prefill_ms, + "merge_ms": merge_ms, + "decode_ms": decode_ms, + "finalize_wait_ms": finalize_wait_ms, + "synth_ms": synth_ms, + "pack_ms": pack_ms, + "response_overhead_ms": response_overhead_ms, + "worker_total_ms": worker_total_ms, + "request_other_ms": request_other_ms, + "decode_steps": decode_steps, + "semantic_len": semantic_len, + "prepare_segments": list(prepare_profiles), + "worker_segments": list(worker_profiles), + "segment_trace": segment_trace, + "prepare_aggregate": aggregate_numeric_dicts(prepare_profile_dicts), + } + + +def build_legacy_direct_profile( + *, + backend: str, + fallback_reason: str | None, + request_start: float, + finished_at: float, + sample_rate: int | None = None, + audio_bytes: int = 0, + pack_ms: float = 0.0, + chunk_count: int = 0, + stream_total_bytes: int = 0, + first_chunk_ms: float | None = None, +) -> Dict[str, Any]: + request_total_ms = max(0.0, (finished_at - request_start) * 1000.0) + legacy_infer_ms = max(0.0, request_total_ms - pack_ms) + return { + "backend": backend, + "backend_mode": backend, + "fallback_reason": fallback_reason, + "request_total_ms": request_total_ms, + "prepare_ms": 0.0, + "queue_wait_ms": 0.0, + "prefill_ms": 0.0, + "merge_ms": 0.0, + "decode_ms": 0.0, + "finalize_wait_ms": 0.0, + "synth_ms": 0.0, + "pack_ms": pack_ms, + "worker_total_ms": legacy_infer_ms, + "request_other_ms": 0.0, + "legacy_infer_ms": legacy_infer_ms, + "sample_rate": int(sample_rate) if sample_rate is not None else None, + "audio_bytes": int(audio_bytes), + "chunk_count": int(chunk_count), + "stream_total_bytes": int(stream_total_bytes), + "first_chunk_ms": None if first_chunk_ms is None else float(first_chunk_ms), + } + + +def build_scheduler_submit_profile( + *, + backend: str, + request_start: float, + response_ready_at: float, + audio_bytes: int, + sample_rate: int, + prepare_spec_build_ms: float, + prepare_wall_ms: float, + prepare_executor_queue_ms: float, + prepare_executor_run_ms: float, + prepare_profile_total_ms: float, + prepare_profile_wall_ms: float, + prepare_other_ms: float, + engine_policy_wait_ms: float, + api_after_prepare_ms: float, + api_wait_result_ms: float, + pack_ms: float, + response_overhead_ms: float, + worker_profile: Dict[str, Any], +) -> Dict[str, Any]: + worker_total_ms = float(worker_profile.get("worker_total_ms", 0.0)) + request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0) + request_other_ms = max( + 0.0, + request_total_ms + - prepare_wall_ms + - engine_policy_wait_ms + - api_after_prepare_ms + - worker_total_ms + - api_wait_result_ms + - pack_ms, + ) + result = { + "backend": backend, + "backend_mode": backend, + "audio_bytes": int(audio_bytes), + "sample_rate": int(sample_rate), + "prepare_spec_build_ms": prepare_spec_build_ms, + "prepare_ms": prepare_wall_ms, + "prepare_wall_ms": prepare_wall_ms, + "prepare_executor_queue_ms": prepare_executor_queue_ms, + "prepare_executor_run_ms": prepare_executor_run_ms, + "prepare_profile_total_ms": prepare_profile_total_ms, + "prepare_profile_wall_ms": prepare_profile_wall_ms, + "prepare_other_ms": prepare_other_ms, + "engine_policy_wait_ms": float(engine_policy_wait_ms), + "api_after_prepare_ms": api_after_prepare_ms, + "api_wait_result_ms": api_wait_result_ms, + "pack_ms": pack_ms, + "response_overhead_ms": response_overhead_ms, + "request_total_ms": request_total_ms, + "request_other_ms": request_other_ms, + } + result.update({key: value for key, value in worker_profile.items()}) + return result + + +def format_ms_header(value: Any) -> str: + return f"{float(value):.3f}" + + +def build_scheduler_submit_headers( + *, + request_id: str, + media_type: str, + sample_rate: int, + profile: Dict[str, Any], +) -> Dict[str, str]: + prepare_profile = dict(profile.get("prepare_profile", {})) + headers = { + "X-Request-Id": request_id, + "X-Semantic-Len": str(int(profile.get("semantic_len", 0))), + "X-Finish-Reason": str(profile.get("finish_reason", "unknown")), + "X-Queue-Wait-Ms": format_ms_header(profile.get("queue_wait_ms", 0.0)), + "X-Decode-Admission-Wait-Ms": format_ms_header(profile.get("decode_admission_wait_ms", 0.0)), + "X-Engine-Policy-Wait-Ms": format_ms_header(profile.get("engine_policy_wait_ms", 0.0)), + "X-Engine-Dispatch-Wait-Ms": format_ms_header(profile.get("engine_dispatch_wait_ms", 0.0)), + "X-Prepare-Ms": format_ms_header(profile.get("prepare_wall_ms", 0.0)), + "X-Prepare-Wall-Ms": format_ms_header(profile.get("prepare_wall_ms", 0.0)), + "X-Prepare-Spec-Build-Ms": format_ms_header(profile.get("prepare_spec_build_ms", 0.0)), + "X-Prepare-Executor-Queue-Ms": format_ms_header(profile.get("prepare_executor_queue_ms", 0.0)), + "X-Prepare-Admission-Wait-Ms": format_ms_header(prepare_profile.get("prepare_admission_wait_ms", 0.0)), + "X-Prepare-Executor-Run-Ms": format_ms_header(profile.get("prepare_executor_run_ms", 0.0)), + "X-Prepare-Profile-Total-Ms": format_ms_header(profile.get("prepare_profile_total_ms", 0.0)), + "X-Prepare-Profile-Wall-Ms": format_ms_header(profile.get("prepare_profile_wall_ms", 0.0)), + "X-Prepare-Other-Ms": format_ms_header(profile.get("prepare_other_ms", 0.0)), + "X-Api-After-Prepare-Ms": format_ms_header(profile.get("api_after_prepare_ms", 0.0)), + "X-Prefill-Ms": format_ms_header(profile.get("prefill_ms", 0.0)), + "X-Merge-Ms": format_ms_header(profile.get("merge_ms", 0.0)), + "X-Decode-Ms": format_ms_header(profile.get("decode_ms", 0.0)), + "X-Finalize-Wait-Ms": format_ms_header(profile.get("finalize_wait_ms", 0.0)), + "X-Synth-Ms": format_ms_header(profile.get("synth_ms", 0.0)), + "X-Worker-Residual-Ms": format_ms_header(profile.get("worker_residual_ms", 0.0)), + "X-Worker-Other-Ms": format_ms_header(profile.get("worker_other_ms", 0.0)), + "X-Pack-Ms": format_ms_header(profile.get("pack_ms", 0.0)), + "X-Worker-Total-Ms": format_ms_header(profile.get("worker_total_ms", 0.0)), + "X-Api-Wait-Result-Ms": format_ms_header(profile.get("api_wait_result_ms", 0.0)), + "X-Decode-Steps": str(int(profile.get("decode_steps", 0))), + "X-Sample-Rate": str(int(sample_rate)), + "X-Response-Overhead-Ms": format_ms_header(profile.get("response_overhead_ms", 0.0)), + "X-Request-Other-Ms": format_ms_header(profile.get("request_other_ms", 0.0)), + "X-Request-Total-Ms": format_ms_header(profile.get("request_total_ms", 0.0)), + } + headers.update( + { + "X-Prepare-Prompt-Text-Ms": format_ms_header(prepare_profile.get("prompt_text_features_ms", 0.0)), + "X-Prepare-Target-Text-Ms": format_ms_header(prepare_profile.get("text_features_ms", 0.0)), + "X-Prepare-Prompt-Text-CPU-Preprocess-Ms": format_ms_header(prepare_profile.get("prompt_text_cpu_preprocess_ms", 0.0)), + "X-Prepare-Target-Text-CPU-Preprocess-Ms": format_ms_header(prepare_profile.get("text_cpu_preprocess_ms", 0.0)), + "X-Prepare-Prompt-Text-CPU-Queue-Ms": format_ms_header(prepare_profile.get("prompt_text_cpu_queue_ms", 0.0)), + "X-Prepare-Target-Text-CPU-Queue-Ms": format_ms_header(prepare_profile.get("text_cpu_queue_ms", 0.0)), + "X-Prepare-Prompt-Text-Feature-Queue-Ms": format_ms_header(prepare_profile.get("prompt_text_feature_queue_ms", 0.0)), + "X-Prepare-Target-Text-Feature-Queue-Ms": format_ms_header(prepare_profile.get("text_feature_queue_ms", 0.0)), + "X-Prepare-Prompt-Bert-Wait-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_wait_ms", 0.0)), + "X-Prepare-Target-Bert-Wait-Ms": format_ms_header(prepare_profile.get("text_bert_wait_ms", 0.0)), + "X-Prepare-Prompt-Bert-Admission-Wait-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_admission_wait_ms", 0.0)), + "X-Prepare-Target-Bert-Admission-Wait-Ms": format_ms_header(prepare_profile.get("text_bert_admission_wait_ms", 0.0)), + "X-Prepare-Prompt-Bert-Queue-Wait-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_queue_wait_ms", 0.0)), + "X-Prepare-Target-Bert-Queue-Wait-Ms": format_ms_header(prepare_profile.get("text_bert_queue_wait_ms", 0.0)), + "X-Prepare-Prompt-Bert-Batch-Collect-Wait-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_batch_collect_wait_ms", 0.0)), + "X-Prepare-Target-Bert-Batch-Collect-Wait-Ms": format_ms_header(prepare_profile.get("text_bert_batch_collect_wait_ms", 0.0)), + "X-Prepare-Prompt-Bert-Forward-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_forward_ms", 0.0)), + "X-Prepare-Target-Bert-Forward-Ms": format_ms_header(prepare_profile.get("text_bert_forward_ms", 0.0)), + "X-Prepare-Prompt-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_enqueue_peak", 0.0))), + "X-Prepare-Target-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_enqueue_peak", 0.0))), + "X-Prepare-Prompt-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_collect_peak", 0.0))), + "X-Prepare-Target-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_collect_peak", 0.0))), + "X-Prepare-Prompt-Bert-High-Pressure-Peak": str(int(prepare_profile.get("prompt_text_bert_high_pressure_mode_peak", 0.0))), + "X-Prepare-Target-Bert-High-Pressure-Peak": str(int(prepare_profile.get("text_bert_high_pressure_mode_peak", 0.0))), + "X-Prepare-Prompt-Bert-Batch-Window-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_batch_window_ms", 0.0)), + "X-Prepare-Target-Bert-Batch-Window-Ms": format_ms_header(prepare_profile.get("text_bert_batch_window_ms", 0.0)), + "X-Prepare-Text-Pair-Wall-Ms": format_ms_header(prepare_profile.get("text_feature_pair_ms", 0.0)), + "X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))), + "X-Prepare-Engine-GPU-Queue-Wait-Ms": format_ms_header(prepare_profile.get("engine_gpu_prepare_queue_wait_ms", 0.0)), + "X-Prepare-Engine-GPU-Batch-Size": str(int(prepare_profile.get("engine_gpu_prepare_batch_size", 0.0))), + "X-Prepare-Audio-Load-Ms": format_ms_header(prepare_profile.get("audio_load_ms", 0.0)), + "X-Prepare-Audio-Stage-Wait-Ms": format_ms_header(prepare_profile.get("audio_stage_wait_ms", 0.0)), + "X-Prepare-Prompt-Semantic-Ms": format_ms_header(prepare_profile.get("prompt_semantic_ms", 0.0)), + "X-Prepare-Prompt-Semantic-Wait-Ms": format_ms_header(prepare_profile.get("prompt_semantic_wait_ms", 0.0)), + "X-Prepare-Prompt-Semantic-CPU-Ms": format_ms_header(prepare_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)), + "X-Prepare-Prompt-Semantic-Forward-Ms": format_ms_header(prepare_profile.get("prompt_semantic_forward_ms", 0.0)), + "X-Prepare-Ref-Spec-Ms": format_ms_header(prepare_profile.get("ref_spec_ms", 0.0)), + "X-Prepare-Ref-Spec-Wait-Ms": format_ms_header(prepare_profile.get("ref_spec_wait_ms", 0.0)), + "X-Prepare-Ref-Bundle-Ms": format_ms_header(prepare_profile.get("ref_audio_bundle_ms", 0.0)), + "X-Prepare-Tensorize-Ms": format_ms_header(prepare_profile.get("tensorize_ms", 0.0)), + "X-Prepare-Inflight-On-Enter": str(int(prepare_profile.get("worker_prepare_inflight_on_enter", 0.0))), + "X-Prepare-Inflight-Peak": str(int(prepare_profile.get("worker_prepare_peak_inflight", 0.0))), + } + ) + return headers + + +def build_scheduler_debug_request_profile( + *, + state: T2SRequestState, + item: T2SFinishedItem, + batch_request_count: int, + prepare_batch_wall_ms: float, + decode_batch_wall_ms: float, + batch_request_total_ms: float, +) -> Dict[str, Any]: + prepare_profile = dict(state.prepare_profile) + prepare_wall_ms = float(prepare_profile.get("wall_total_ms", 0.0)) + return { + "backend": "scheduler_debug", + "backend_mode": "scheduler_debug", + "batch_request_count": int(batch_request_count), + "batch_prepare_wall_ms": float(prepare_batch_wall_ms), + "batch_decode_wall_ms": float(decode_batch_wall_ms), + "batch_request_total_ms": float(batch_request_total_ms), + "prepare_ms": prepare_wall_ms, + "prepare_wall_ms": prepare_wall_ms, + "prepare_profile_total_ms": float(prepare_profile.get("wall_total_ms", prepare_wall_ms)), + "prepare_profile": prepare_profile, + "decode_steps": int(item.finish_idx), + "finish_idx": int(item.finish_idx), + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_reason": item.finish_reason, + "norm_text": state.norm_text, + "norm_prompt_text": state.norm_prompt_text, + } + + +def build_scheduler_debug_batch_profile( + *, + request_count: int, + max_steps: int, + prepare_batch_wall_ms: float, + decode_batch_wall_ms: float, + request_total_ms: float, + finished_items: Sequence[T2SFinishedItem], +) -> Dict[str, Any]: + finish_reason_counts: Dict[str, int] = {} + total_semantic_len = 0 + for item in finished_items: + finish_reason_counts[item.finish_reason] = finish_reason_counts.get(item.finish_reason, 0) + 1 + total_semantic_len += int(item.semantic_tokens.shape[0]) + return { + "request_count": int(request_count), + "max_steps": int(max_steps), + "prepare_batch_wall_ms": float(prepare_batch_wall_ms), + "decode_batch_wall_ms": float(decode_batch_wall_ms), + "request_total_ms": float(request_total_ms), + "total_semantic_len": int(total_semantic_len), + "finish_reason_counts": finish_reason_counts, + } diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api_request.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_request.py new file mode 100644 index 00000000..14d59f12 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_request.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Sequence, Tuple + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import NormalizedEngineRequest, ReferenceRegistry + + +def normalize_lang(value: str | None) -> str | None: + if value in [None, ""]: + return value + return str(value).lower() + + +def apply_default_reference(reference_registry: ReferenceRegistry, req: dict) -> dict: + normalized = dict(req) + default_ref = reference_registry.get_default() + if normalized.get("ref_audio_path") in [None, ""] and default_ref.ref_audio_path not in [None, ""]: + normalized["ref_audio_path"] = default_ref.ref_audio_path + if "text_lang" in normalized: + normalized["text_lang"] = normalize_lang(normalized.get("text_lang")) + if "prompt_lang" in normalized: + normalized["prompt_lang"] = normalize_lang(normalized.get("prompt_lang")) + return normalized + + +def check_params(tts: TTS, cut_method_names: Sequence[str], req: dict) -> Optional[str]: + text = req.get("text", "") + text_lang = req.get("text_lang", "") + ref_audio_path = req.get("ref_audio_path", "") + media_type = req.get("media_type", "wav") + prompt_lang = req.get("prompt_lang", "") + text_split_method = req.get("text_split_method", "cut5") + + if ref_audio_path in [None, ""]: + return "ref_audio_path is required" + if text in [None, ""]: + return "text is required" + if text_lang in [None, ""]: + return "text_lang is required" + if text_lang.lower() not in tts.configs.languages: + return f"text_lang: {text_lang} is not supported in version {tts.configs.version}" + if prompt_lang in [None, ""]: + return "prompt_lang is required" + if prompt_lang.lower() not in tts.configs.languages: + return f"prompt_lang: {prompt_lang} is not supported in version {tts.configs.version}" + if media_type not in ["wav", "raw", "ogg", "aac"]: + return f"media_type: {media_type} is not supported" + if text_split_method not in cut_method_names: + return f"text_split_method:{text_split_method} is not supported" + return None + + +def base_request_defaults() -> Dict[str, Any]: + return { + "request_id": None, + "text": None, + "text_lang": None, + "ref_audio_path": None, + "aux_ref_audio_paths": None, + "prompt_text": "", + "prompt_lang": None, + "top_k": 15, + "top_p": 1.0, + "temperature": 1.0, + "text_split_method": "cut5", + "batch_size": 1, + "batch_threshold": 0.75, + "speed_factor": 1.0, + "split_bucket": False, + "fragment_interval": 0.3, + "seed": -1, + "media_type": "wav", + "streaming_mode": False, + "return_fragment": False, + "fixed_length_chunk": False, + "response_streaming": False, + "parallel_infer": False, + "repetition_penalty": 1.35, + "sample_steps": 32, + "super_sampling": False, + "overlap_length": 2, + "min_chunk_length": 16, + "early_stop_num": -1, + "ready_step": 0, + "timeout_sec": None, + } + + +def normalize_streaming_mode(req: dict) -> dict: + normalized = dict(req) + streaming_mode = normalized.get("streaming_mode", False) + return_fragment = normalized.get("return_fragment", False) + if streaming_mode is False: + normalized["streaming_mode"] = False + normalized["return_fragment"] = False + normalized["fixed_length_chunk"] = False + elif streaming_mode == 0: + normalized["streaming_mode"] = False + normalized["return_fragment"] = False + normalized["fixed_length_chunk"] = False + elif streaming_mode == 1 or streaming_mode is True: + normalized["streaming_mode"] = False + normalized["return_fragment"] = True + normalized["fixed_length_chunk"] = False + elif streaming_mode == 2: + normalized["streaming_mode"] = True + normalized["return_fragment"] = False + normalized["fixed_length_chunk"] = False + elif streaming_mode == 3: + normalized["streaming_mode"] = True + normalized["return_fragment"] = False + normalized["fixed_length_chunk"] = True + else: + raise ValueError("the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)") + normalized["response_streaming"] = bool(normalized["streaming_mode"] or normalized["return_fragment"] or return_fragment) + return normalized + + +def is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool: + return aux_ref_audio_paths not in [None, [], ()] + + +def select_direct_backend(normalized: NormalizedEngineRequest) -> Tuple[str, str | None]: + return "scheduler_v1_direct", None + + +def normalize_engine_request( + *, + tts: TTS, + cut_method_names: Sequence[str], + reference_registry: ReferenceRegistry, + payload: dict | NormalizedEngineRequest, + request_id: str | None = None, + normalize_streaming: bool = False, + error_prefix: str = "request 参数非法: ", +) -> NormalizedEngineRequest: + if isinstance(payload, NormalizedEngineRequest): + normalized_payload = payload.to_payload() + else: + normalized_payload = base_request_defaults() + normalized_payload.update(dict(payload)) + if request_id not in [None, ""]: + normalized_payload["request_id"] = str(request_id) + elif normalized_payload.get("request_id") in [None, ""]: + raise ValueError("request_id is required after normalization") + normalized_payload = apply_default_reference(reference_registry, normalized_payload) + if normalize_streaming: + normalized_payload = normalize_streaming_mode(normalized_payload) + error = check_params(tts, cut_method_names, normalized_payload) + if error is not None: + raise ValueError(f"{error_prefix}{error}") + timeout_sec = normalized_payload.get("timeout_sec") + parsed_timeout = None if timeout_sec in [None, ""] else float(timeout_sec) + aux_ref_audio_paths = normalized_payload.get("aux_ref_audio_paths") + normalized_aux_ref_audio_paths = None if aux_ref_audio_paths in [None, "", []] else [str(item) for item in aux_ref_audio_paths] + return NormalizedEngineRequest( + request_id=str(normalized_payload["request_id"]), + text=str(normalized_payload["text"]), + text_lang=str(normalized_payload["text_lang"]), + ref_audio_path=str(normalized_payload["ref_audio_path"]), + prompt_lang=str(normalized_payload["prompt_lang"]), + prompt_text="" if normalized_payload.get("prompt_text") is None else str(normalized_payload.get("prompt_text")), + aux_ref_audio_paths=normalized_aux_ref_audio_paths, + top_k=int(normalized_payload["top_k"]), + top_p=float(normalized_payload["top_p"]), + temperature=float(normalized_payload["temperature"]), + repetition_penalty=float(normalized_payload["repetition_penalty"]), + early_stop_num=int(normalized_payload.get("early_stop_num", -1)), + ready_step=int(normalized_payload.get("ready_step", 0)), + text_split_method=str(normalized_payload["text_split_method"]), + batch_size=int(normalized_payload["batch_size"]), + batch_threshold=float(normalized_payload["batch_threshold"]), + split_bucket=bool(normalized_payload["split_bucket"]), + speed_factor=float(normalized_payload["speed_factor"]), + fragment_interval=float(normalized_payload["fragment_interval"]), + seed=int(normalized_payload["seed"]), + media_type=str(normalized_payload["media_type"]), + streaming_mode=normalized_payload["streaming_mode"], + return_fragment=bool(normalized_payload.get("return_fragment", False)), + fixed_length_chunk=bool(normalized_payload.get("fixed_length_chunk", False)), + response_streaming=bool(normalized_payload.get("response_streaming", False)), + parallel_infer=bool(normalized_payload["parallel_infer"]), + sample_steps=int(normalized_payload["sample_steps"]), + super_sampling=bool(normalized_payload["super_sampling"]), + overlap_length=int(normalized_payload["overlap_length"]), + min_chunk_length=int(normalized_payload["min_chunk_length"]), + timeout_sec=parsed_timeout, + ) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_api_scheduler.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_scheduler.py new file mode 100644 index 00000000..cf6677fb --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_api_scheduler.py @@ -0,0 +1,340 @@ +from __future__ import annotations + +import asyncio +import time +import uuid +from io import BytesIO +from typing import Any, Dict, List + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SFinishedItem, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_audio import pack_audio, set_scheduler_seed +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, NormalizedEngineRequest, SchedulerDebugExecution, SchedulerSubmitExecution + + +class EngineApiSchedulerFlow: + def __init__(self, api: Any) -> None: + self.api = api + + def _build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]: + specs: List[SchedulerRequestSpec] = [] + for index, payload in enumerate(request_items): + normalized = self.api._normalize_engine_request( + payload, + request_id=str(payload.get("request_id") or f"req_{index:03d}"), + error_prefix=f"request[{index}] 参数非法: ", + ) + specs.append(normalized.to_scheduler_spec()) + return specs + + def _build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec: + normalized = self.api._normalize_engine_request( + payload, + request_id=( + payload.request_id + if isinstance(payload, NormalizedEngineRequest) + else str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}") + ), + ) + return normalized.to_scheduler_spec() + + @staticmethod + def _summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]: + return [ + { + "request_id": state.request_id, + "ready_step": int(state.ready_step), + "ref_audio_path": str(state.ref_audio_path), + "prompt_semantic_len": int(state.prompt_semantic.shape[0]), + "all_phone_len": int(state.all_phones.shape[0]), + "bert_len": int(state.all_bert_features.shape[-1]), + "norm_text": state.norm_text, + } + for state in states + ] + + @staticmethod + def _summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]: + return [ + { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + } + for item in items + ] + + async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution: + request_start = time.perf_counter() + set_scheduler_seed(seed) + normalized_requests: List[NormalizedEngineRequest] = [] + for index, payload in enumerate(request_items): + normalized_requests.append( + self.api._normalize_engine_request( + payload, + request_id=str(payload.get("request_id") or f"req_{index:03d}"), + error_prefix=f"request[{index}] 参数非法: ", + ) + ) + specs = [normalized.to_scheduler_spec() for normalized in normalized_requests] + request_ids = [normalized.request_id for normalized in normalized_requests] + for normalized, spec in zip(normalized_requests, specs): + self.api._register_request_state( + request_id=normalized.request_id, + api_mode="scheduler_debug", + backend="scheduler_debug", + media_type=normalized.media_type, + response_streaming=False, + meta=self.api._build_request_meta(normalized.to_payload()), + ) + self.api._update_request_state(normalized.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_debug"}) + self.api._update_request_state(normalized.request_id, EngineStatus.CPU_PREPARING, None) + prepare_started_at = time.perf_counter() + original_worker_max_steps = int(self.api.scheduler_worker.max_steps) + original_decode_max_steps = int(self.api.scheduler_worker.decode_executor.max_steps) + try: + self.api.scheduler_worker.max_steps = int(max_steps) + self.api.scheduler_worker.decode_executor.max_steps = int(max_steps) + prepared_payloads = await asyncio.gather( + *[ + self.api._prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=time.perf_counter(), + engine_request_id=normalized.request_id, + ) + for normalized, spec in zip(normalized_requests, specs) + ] + ) + except Exception as exc: + for request_id in request_ids: + self.api._fail_request_state(request_id, str(exc)) + raise + finally: + self.api.scheduler_worker.max_steps = int(original_worker_max_steps) + self.api.scheduler_worker.decode_executor.max_steps = int(original_decode_max_steps) + prepare_finished_at = time.perf_counter() + prepare_batch_wall_ms = max(0.0, (prepare_finished_at - prepare_started_at) * 1000.0) + states = [payload[0] for payload in prepared_payloads] + for state in states: + self.api._update_request_state( + state.request_id, + EngineStatus.READY_FOR_PREFILL, + { + "prepare_profile": dict(state.prepare_profile), + "norm_text": state.norm_text, + "norm_prompt_text": state.norm_prompt_text, + }, + ) + decode_started_at = time.perf_counter() + try: + loop = asyncio.get_running_loop() + done_futures: List[asyncio.Future] = [] + for normalized, state in zip(normalized_requests, states): + done_future = loop.create_future() + done_futures.append(done_future) + await self.api._enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=float(normalized.speed_factor), + sample_steps=int(normalized.sample_steps), + media_type=normalized.media_type, + super_sampling=bool(normalized.super_sampling), + prepare_wall_ms=float(state.prepare_profile.get("wall_total_ms", 0.0)), + prepare_profile_total_ms=float(state.prepare_profile.get("wall_total_ms", 0.0)), + done_loop=loop, + done_future=done_future, + engine_request_id=normalized.request_id, + timeout_sec=normalized.timeout_sec, + ) + timeout_candidates = [float(item.timeout_sec) for item in normalized_requests if item.timeout_sec not in [None, ""]] + timeout_sec = max(timeout_candidates) if timeout_candidates else 60.0 + jobs = list(await asyncio.wait_for(asyncio.gather(*done_futures), timeout=float(timeout_sec))) + except Exception as exc: + for request_id in request_ids: + self.api._fail_request_state(request_id, str(exc)) + raise + decode_finished_at = time.perf_counter() + decode_batch_wall_ms = max(0.0, (decode_finished_at - decode_started_at) * 1000.0) + request_total_ms = max(0.0, (decode_finished_at - request_start) * 1000.0) + request_profiles: List[Dict[str, Any]] = [] + finished: List[Dict[str, Any]] = [] + finish_reason_counts: Dict[str, int] = {} + total_semantic_len = 0 + for state, job in zip(states, jobs): + if job.error is not None: + self.api._fail_request_state(state.request_id, str(job.error)) + raise RuntimeError(str(job.error)) + if job.result is None: + self.api._fail_request_state(state.request_id, "scheduler_debug finished without result") + raise RuntimeError(f"{state.request_id} finished without result") + job_result = dict(job.result) + request_profile = { + **job_result, + "backend": "scheduler_debug", + "backend_mode": "scheduler_debug", + "batch_request_count": int(len(states)), + "batch_prepare_wall_ms": float(prepare_batch_wall_ms), + "batch_decode_wall_ms": float(decode_batch_wall_ms), + "batch_request_total_ms": float(request_total_ms), + "prepare_ms": float(state.prepare_profile.get("wall_total_ms", 0.0)), + "prepare_wall_ms": float(state.prepare_profile.get("wall_total_ms", 0.0)), + "prepare_profile_total_ms": float(state.prepare_profile.get("wall_total_ms", 0.0)), + "prepare_profile": dict(state.prepare_profile), + "norm_text": state.norm_text, + "norm_prompt_text": state.norm_prompt_text, + } + request_profiles.append({"request_id": state.request_id, "profile": dict(request_profile)}) + self.api._merge_request_state_profile(state.request_id, request_profile) + semantic_len = int(job_result.get("semantic_len", 0)) + finish_reason = str(job_result.get("finish_reason", "unknown")) + finished.append( + { + "request_id": state.request_id, + "semantic_len": semantic_len, + "finish_idx": int(job_result.get("finish_idx", job_result.get("decode_steps", 0))), + "finish_reason": finish_reason, + } + ) + finish_reason_counts[finish_reason] = finish_reason_counts.get(finish_reason, 0) + 1 + total_semantic_len += semantic_len + return SchedulerDebugExecution( + payload={ + "message": "success", + "request_count": len(states), + "max_steps": int(max_steps), + "batch_profile": { + "request_count": int(len(states)), + "max_steps": int(max_steps), + "prepare_batch_wall_ms": float(prepare_batch_wall_ms), + "decode_batch_wall_ms": float(decode_batch_wall_ms), + "request_total_ms": float(request_total_ms), + "total_semantic_len": int(total_semantic_len), + "finish_reason_counts": finish_reason_counts, + }, + "requests": self._summarize_scheduler_states(states), + "finished": finished, + "request_profiles": request_profiles, + "request_traces": self.api._collect_request_summaries(request_ids), + } + ) + + async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution: + request_start = time.perf_counter() + prepare_start = request_start + normalized = self.api._normalize_engine_request( + payload, + request_id=str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}"), + ) + spec = self._build_scheduler_submit_spec(normalized) + deadline_ts = None + timeout_sec = normalized.timeout_sec + if timeout_sec is not None: + try: + deadline_ts = request_start + float(timeout_sec) + except Exception: + deadline_ts = None + self.api._register_request_state( + request_id=spec.request_id, + api_mode="scheduler_submit", + backend="scheduler_v1", + media_type=normalized.media_type, + response_streaming=False, + deadline_ts=deadline_ts, + meta=self.api._build_request_meta(normalized.to_payload()), + ) + self.api._update_request_state(spec.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_submit"}) + spec_ready_at = time.perf_counter() + prepare_spec_build_ms = max(0.0, (spec_ready_at - prepare_start) * 1000.0) + self.api._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, {"prepare_spec_build_ms": prepare_spec_build_ms}) + try: + state, prepare_exec_started_at, prepare_exec_finished_at = await self.api._prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=spec_ready_at, + engine_request_id=spec.request_id, + ) + except Exception as exc: + self.api._fail_request_state(spec.request_id, str(exc)) + raise + prepare_wall_ms = max(0.0, (prepare_exec_finished_at - spec_ready_at) * 1000.0) + prepare_executor_queue_ms = max(0.0, (prepare_exec_started_at - spec_ready_at) * 1000.0) + prepare_executor_run_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0) + prepare_profile = dict(state.prepare_profile) + prepare_profile_total_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms)) + prepare_profile_wall_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms)) + prepare_other_ms = max(0.0, prepare_wall_ms - prepare_spec_build_ms - prepare_executor_queue_ms - prepare_executor_run_ms) + self.api._update_request_state( + spec.request_id, + EngineStatus.READY_FOR_PREFILL, + { + "prepare_wall_ms": prepare_wall_ms, + "prepare_profile_total_ms": prepare_profile_total_ms, + "prepare_profile": prepare_profile, + }, + ) + api_after_prepare_start = time.perf_counter() + loop = asyncio.get_running_loop() + done_future = loop.create_future() + await self.api._enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=float(normalized.speed_factor), + sample_steps=int(normalized.sample_steps), + media_type=normalized.media_type, + super_sampling=bool(normalized.super_sampling), + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=loop, + done_future=done_future, + engine_request_id=spec.request_id, + timeout_sec=normalized.timeout_sec, + ) + api_after_prepare_ms = max(0.0, (time.perf_counter() - api_after_prepare_start) * 1000.0) + try: + job = await asyncio.wait_for(done_future, timeout=float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0)) + except Exception as exc: + self.api._fail_request_state(spec.request_id, str(exc)) + raise + wait_return_at = time.perf_counter() + if job.error is not None: + raise RuntimeError(job.error) + if job.audio_data is None or job.sample_rate is None or job.result is None: + self.api._fail_request_state(spec.request_id, f"{job.request_id} finished without audio result") + raise RuntimeError(f"{job.request_id} finished without audio result") + pack_start = time.perf_counter() + audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue() + pack_end = time.perf_counter() + pack_ms = (pack_end - pack_start) * 1000.0 + api_wait_result_ms = 0.0 + if job.result_ready_time is not None: + api_wait_result_ms = max(0.0, (wait_return_at - job.result_ready_time) * 1000.0) + response_ready_at = time.perf_counter() + response_overhead_ms = max(0.0, (response_ready_at - pack_end) * 1000.0) + submit_profile = self.api._build_scheduler_submit_profile( + backend="scheduler_v1", + request_start=request_start, + response_ready_at=response_ready_at, + audio_bytes=len(audio_data), + sample_rate=int(job.sample_rate), + prepare_spec_build_ms=prepare_spec_build_ms, + prepare_wall_ms=prepare_wall_ms, + prepare_executor_queue_ms=prepare_executor_queue_ms, + prepare_executor_run_ms=prepare_executor_run_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + prepare_profile_wall_ms=prepare_profile_wall_ms, + prepare_other_ms=prepare_other_ms, + engine_policy_wait_ms=float(job.result.get("engine_policy_wait_ms", 0.0)), + api_after_prepare_ms=api_after_prepare_ms, + api_wait_result_ms=api_wait_result_ms, + pack_ms=pack_ms, + response_overhead_ms=response_overhead_ms, + worker_profile=dict(job.result or {}), + ) + headers = self.api._build_scheduler_submit_headers( + request_id=job.request_id, + media_type=job.media_type, + sample_rate=int(job.sample_rate), + profile=submit_profile, + ) + self.api._merge_request_state_profile( + spec.request_id, + dict(submit_profile, response_headers_emitted=True), + ) + return SchedulerSubmitExecution(audio_bytes=audio_data, media_type=str(job.media_type), headers=headers) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_audio.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_audio.py new file mode 100644 index 00000000..5c3bd7a5 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_audio.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +import subprocess +import threading +import wave +from io import BytesIO + +import numpy as np +import soundfile as sf +import torch + + +def set_scheduler_seed(seed: int): + if seed in ["", None]: + return + seed = int(seed) + if seed < 0: + return + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int): + def handle_pack_ogg(): + with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: + audio_file.write(data) + + stack_size = 4096 * 4096 + try: + threading.stack_size(stack_size) + pack_ogg_thread = threading.Thread(target=handle_pack_ogg) + pack_ogg_thread.start() + pack_ogg_thread.join() + except (RuntimeError, ValueError): + handle_pack_ogg() + return io_buffer + + +def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int): + io_buffer.write(data.tobytes()) + return io_buffer + + +def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int): + io_buffer = BytesIO() + sf.write(io_buffer, data, rate, format="wav") + return io_buffer + + +def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int): + process = subprocess.Popen( + [ + "ffmpeg", + "-f", + "s16le", + "-ar", + str(rate), + "-ac", + "1", + "-i", + "pipe:0", + "-c:a", + "aac", + "-b:a", + "192k", + "-vn", + "-f", + "adts", + "pipe:1", + ], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + out, _ = process.communicate(input=data.tobytes()) + io_buffer.write(out) + return io_buffer + + +def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str): + if media_type == "ogg": + io_buffer = pack_ogg(io_buffer, data, rate) + elif media_type == "aac": + io_buffer = pack_aac(io_buffer, data, rate) + elif media_type == "wav": + io_buffer = pack_wav(io_buffer, data, rate) + else: + io_buffer = pack_raw(io_buffer, data, rate) + io_buffer.seek(0) + return io_buffer + + +def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000): + wav_buf = BytesIO() + with wave.open(wav_buf, "wb") as vfout: + vfout.setnchannels(channels) + vfout.setsampwidth(sample_width) + vfout.setframerate(sample_rate) + vfout.writeframes(frame_input) + wav_buf.seek(0) + return wav_buf.read() + + diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge.py new file mode 100644 index 00000000..d7740a52 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from typing import Any + +from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge_registry import EngineRegistryBridgeFacade +from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge_runtime import EngineRuntimeBridgeFacade +from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge_stage import EngineStageBridgeFacade + + +class EngineBridgeFacade: + def __init__(self, owner: Any) -> None: + self.owner = owner + self.registry_bridge = EngineRegistryBridgeFacade(owner) + self.stage_bridge = EngineStageBridgeFacade(owner) + self.runtime_bridge = EngineRuntimeBridgeFacade(owner) + + def __getattr__(self, name: str) -> Any: + for component in (self.registry_bridge, self.stage_bridge, self.runtime_bridge): + if hasattr(component, name): + return getattr(component, name) + raise AttributeError(name) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_delegates.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_delegates.py new file mode 100644 index 00000000..e2044ec4 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_delegates.py @@ -0,0 +1,202 @@ +from __future__ import annotations + +import asyncio +from typing import Any, Dict, List, Optional + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge import EngineBridgeFacade +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDispatchTask, EngineRequestState, SchedulerFinalizeTask, SchedulerPendingJob + + +class EngineBridgeDelegates: + def _register_request_state( + self, + request_id: str, + api_mode: str, + backend: str, + media_type: str, + response_streaming: bool, + deadline_ts: float | None = None, + meta: Optional[Dict[str, Any]] = None, + ) -> EngineRequestState: + return self.bridge_facade._register_request_state( + request_id=request_id, + api_mode=api_mode, + backend=backend, + media_type=media_type, + response_streaming=response_streaming, + deadline_ts=deadline_ts, + meta=meta, + ) + + def _update_request_state(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.bridge_facade._update_request_state(request_id, status, extra) + + def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.bridge_facade._merge_request_state_profile(request_id, extra) + + def _snapshot_engine_prepare_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_prepare_state() + + def _snapshot_engine_finalize_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_finalize_state() + + def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_dispatch_state() + + def _register_engine_job(self, job: SchedulerPendingJob) -> None: + self.bridge_facade._register_engine_job(job) + + def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None: + return self.bridge_facade._get_engine_job(request_id) + + def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None: + return self.bridge_facade._pop_engine_job(request_id) + + def _snapshot_engine_job_registry(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_job_registry() + + def _is_engine_drained(self) -> bool: + return self.bridge_facade._is_engine_drained() + + def _record_engine_job_done(self, request_id: str) -> None: + self.bridge_facade._record_engine_job_done(request_id) + + def _complete_engine_job( + self, + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + audio_data: np.ndarray, + ) -> None: + self.bridge_facade._complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) + + def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None: + self.bridge_facade._fail_engine_jobs(request_ids, error) + + def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None: + self.bridge_facade._add_engine_prefill_time(jobs, elapsed_s) + + def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: + self.bridge_facade._add_engine_merge_time(request_ids, elapsed_s) + + def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + self.bridge_facade._add_engine_decode_time(request_ids, elapsed_s) + + def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None: + self.bridge_facade._enqueue_engine_finished_items(items) + + def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_decode_pending_queue_state() + + @staticmethod + def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: + return EngineBridgeFacade._summarize_active_batch(active_batch) + + def _refresh_engine_decode_runtime_state(self, last_event: str) -> None: + self.bridge_facade._refresh_engine_decode_runtime_state(last_event) + + def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None: + self.bridge_facade._update_engine_decode_runtime_state(snapshot) + + def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_decode_runtime_state() + + def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_engine_arbiter_state() + + def _notify_engine_arbiter(self) -> None: + self.bridge_facade._notify_engine_arbiter() + + def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None: + self.bridge_facade._enqueue_engine_decode_pending_job(job) + + def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.bridge_facade._take_engine_decode_pending_jobs_nonblocking(wait_for_batch) + + def _peek_queue_age_ms(self, queue_name: str) -> float: + return self.bridge_facade._peek_queue_age_ms(queue_name) + + def _engine_has_pending_work(self) -> bool: + return self.bridge_facade._engine_has_pending_work() + + async def _prepare_state_via_engine_gpu_queue( + self, + *, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + engine_request_id: str | None, + ) -> tuple[T2SRequestState, float, float]: + return await self.bridge_facade._prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=prepare_submit_at, + engine_request_id=engine_request_id, + ) + + def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: + self.bridge_facade._enqueue_worker_finished_for_finalize(tasks) + + def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: + return self.bridge_facade._take_engine_finalize_batch_nonblocking() + + async def _enqueue_prepared_state_for_dispatch( + self, + *, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + super_sampling: bool, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None, + done_future: asyncio.Future | None, + engine_request_id: str | None, + timeout_sec: float | None, + ) -> EngineDispatchTask: + return await self.bridge_facade._enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=speed_factor, + sample_steps=sample_steps, + media_type=media_type, + super_sampling=super_sampling, + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=done_loop, + done_future=done_future, + engine_request_id=engine_request_id, + timeout_sec=timeout_sec, + ) + + def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: + self.bridge_facade._mark_arbiter_tick(stage=stage, reason=reason, policy_allowed=policy_allowed) + + def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: + return self.bridge_facade._select_engine_stage() + + def _run_engine_prepare_once(self) -> bool: + return self.bridge_facade._run_engine_prepare_once() + + def _run_engine_finalize_once(self) -> bool: + return self.bridge_facade._run_engine_finalize_once() + + def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: + return self.bridge_facade._run_engine_dispatch_once(policy_snapshot, worker_state) + + def _run_engine_decode_runtime_once(self) -> bool: + return self.bridge_facade._run_engine_decode_runtime_once() + + def _run_engine_arbiter_loop(self) -> None: + self.bridge_facade._run_engine_arbiter_loop() + + def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.bridge_facade._complete_request_state(request_id, extra) + + def _fail_request_state(self, request_id: str, error: str) -> None: + self.bridge_facade._fail_request_state(request_id, error) + + def _snapshot_request_registry(self) -> Dict[str, Any]: + return self.bridge_facade._snapshot_request_registry() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py new file mode 100644 index 00000000..f07250e1 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py @@ -0,0 +1,231 @@ +from __future__ import annotations + +import time +from typing import Any, Dict, List, Optional + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineRequestState, EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob + + +class EngineRegistryBridgeFacade: + def __init__(self, owner: Any) -> None: + self.owner = owner + + @property + def request_registry(self): + return self.owner.request_registry + + @property + def engine_prepare_queue_owner(self): + return self.owner.engine_prepare_queue_owner + + @property + def engine_prepare_text_queue_owner(self): + return self.owner.engine_prepare_text_queue_owner + + @property + def engine_prepare_ref_spec_queue_owner(self): + return self.owner.engine_prepare_ref_spec_queue_owner + + @property + def engine_finalize_queue_owner(self): + return self.owner.engine_finalize_queue_owner + + @property + def engine_dispatch_queue_owner(self): + return self.owner.engine_dispatch_queue_owner + + @property + def engine_decode_runtime_owner(self): + return self.owner.engine_decode_runtime_owner + + @property + def engine_job_registry(self): + return self.owner.engine_job_registry + + @property + def scheduler_worker(self): + return self.owner.scheduler_worker + + def _register_request_state( + self, + request_id: str, + api_mode: str, + backend: str, + media_type: str, + response_streaming: bool, + deadline_ts: float | None = None, + meta: Optional[Dict[str, Any]] = None, + ) -> EngineRequestState: + return self.request_registry.register( + request_id=request_id, + api_mode=api_mode, + backend=backend, + media_type=media_type, + response_streaming=response_streaming, + deadline_ts=deadline_ts, + meta=meta, + ) + + def _update_request_state( + self, + request_id: str, + status: str, + extra: Optional[Dict[str, Any]] = None, + ) -> None: + self.request_registry.update(request_id, status, extra) + + def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.request_registry.merge_profile(request_id, extra) + + def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + self.request_registry.complete(request_id, extra) + + def _fail_request_state(self, request_id: str, error: str) -> None: + self.request_registry.fail(request_id, error) + + def _snapshot_request_registry(self) -> Dict[str, Any]: + return self.request_registry.snapshot() + + def _snapshot_engine_prepare_state(self) -> Dict[str, Any]: + audio_snapshot = self.engine_prepare_queue_owner.snapshot(max_request_ids=16) + text_snapshot = self.engine_prepare_text_queue_owner.snapshot(max_request_ids=16) + ref_spec_snapshot = self.engine_prepare_ref_spec_queue_owner.snapshot(max_request_ids=16) + return { + "waiting_count": int(audio_snapshot.get("waiting_count", 0)) + + int(text_snapshot.get("waiting_count", 0)) + + int(ref_spec_snapshot.get("waiting_count", 0)), + "audio_waiting_count": int(audio_snapshot.get("waiting_count", 0)), + "text_waiting_count": int(text_snapshot.get("waiting_count", 0)), + "ref_spec_waiting_count": int(ref_spec_snapshot.get("waiting_count", 0)), + "audio_waiting_request_ids": list(audio_snapshot.get("waiting_request_ids", [])), + "text_waiting_request_ids": list(text_snapshot.get("waiting_request_ids", [])), + "ref_spec_waiting_request_ids": list(ref_spec_snapshot.get("waiting_request_ids", [])), + "peak_waiting": int( + max( + int(audio_snapshot.get("peak_waiting", 0)), + int(text_snapshot.get("peak_waiting", 0)), + int(ref_spec_snapshot.get("peak_waiting", 0)), + ) + ), + "total_submitted": int(audio_snapshot.get("total_submitted", 0)), + "total_completed": int(audio_snapshot.get("total_completed", 0)), + "text_total_submitted": int(text_snapshot.get("total_submitted", 0)), + "text_total_completed": int(text_snapshot.get("total_completed", 0)), + "ref_spec_total_submitted": int(ref_spec_snapshot.get("total_submitted", 0)), + "ref_spec_total_completed": int(ref_spec_snapshot.get("total_completed", 0)), + } + + def _snapshot_engine_finalize_state(self) -> Dict[str, Any]: + return self.engine_finalize_queue_owner.snapshot(max_request_ids=16) + + def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]: + return self.engine_dispatch_queue_owner.snapshot( + max_request_ids=16, + extra={"last_policy_snapshot": dict(self.owner.engine_dispatch_last_snapshot or {})}, + ) + + def _register_engine_job(self, job: SchedulerPendingJob) -> None: + self.engine_job_registry.register(job, keep_job=True) + + def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None: + return self.engine_job_registry.get(request_id) + + def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None: + return self.engine_job_registry.pop(request_id) + + def _snapshot_engine_job_registry(self) -> Dict[str, Any]: + return self.engine_job_registry.snapshot(max_request_ids=32) + + def _is_engine_drained(self) -> bool: + prepare_empty = self.engine_prepare_queue_owner.is_drained() + prepare_text_empty = self.engine_prepare_text_queue_owner.is_drained() + prepare_ref_spec_empty = self.engine_prepare_ref_spec_queue_owner.is_drained() + dispatch_empty = self.engine_dispatch_queue_owner.is_drained() + finalize_empty = self.engine_finalize_queue_owner.is_drained() + decode_pending_empty = not self.engine_decode_runtime_owner.has_pending_jobs() + job_empty = self.engine_job_registry.is_empty() + worker_state = self.scheduler_worker.snapshot() + return bool( + prepare_empty + and prepare_text_empty + and prepare_ref_spec_empty + and dispatch_empty + and finalize_empty + and decode_pending_empty + and job_empty + and self.engine_decode_runtime_owner.get_active_batch() is None + and int(worker_state.get("prepare_inflight", 0)) <= 0 + and int(worker_state.get("finalize_inflight", 0)) <= 0 + and int(worker_state.get("finalize_pending", 0)) <= 0 + ) + + def _record_engine_job_done(self, request_id: str) -> None: + self.engine_job_registry.mark_finished_and_remove(request_id) + self.scheduler_worker.record_external_job_done(request_id) + + def _complete_engine_job( + self, + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + audio_data: np.ndarray, + ) -> None: + completion_bridge = self.scheduler_worker.completion_bridge + completion_bridge.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data) + completion_bridge.complete_job( + job, + runtime_request_id=job.engine_request_id, + runtime_extra=completion_bridge.build_runtime_complete_payload(job, item, sample_rate=sample_rate), + on_job_finished=lambda rid=item.request_id: self._record_engine_job_done(rid), + ) + + def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None: + if not request_ids: + return + completion_bridge = self.scheduler_worker.completion_bridge + for request_id in request_ids: + job = self._get_engine_job(request_id) + if job is None: + continue + completion_bridge.fail_job( + job, + error=error, + on_job_finished=lambda rid=request_id: self._record_engine_job_done(rid), + ) + + def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + for job in jobs: + job.prefill_ms += delta_ms + + def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + for request_id in request_ids: + job = self._get_engine_job(request_id) + if job is not None: + job.merge_ms += delta_ms + + def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + activate_request_ids: List[str] = [] + for request_id in request_ids: + job = self._get_engine_job(request_id) + if job is None: + continue + if job.decode_steps == 0: + activate_request_ids.append(job.engine_request_id) + job.decode_ms += delta_ms + job.decode_steps += 1 + for engine_request_id in activate_request_ids: + self._update_request_state(engine_request_id, EngineStatus.ACTIVE_DECODE, None) + + def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None: + if not items: + return + enqueued_at = time.perf_counter() + tasks = [SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at) for item in items] + self.owner.engine_stage_coordinator.enqueue_worker_finished_for_finalize(tasks) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_runtime.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_runtime.py new file mode 100644 index 00000000..47be8b67 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_runtime.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from typing import Any, Dict + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SActiveBatch +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDecodeRuntimeOwner + + +class EngineRuntimeBridgeFacade: + def __init__(self, owner: Any) -> None: + self.owner = owner + + @property + def engine_policy_arbiter(self): + return self.owner.engine_policy_arbiter + + @staticmethod + def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: + return EngineDecodeRuntimeOwner.summarize_active_batch(active_batch) + + def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]: + return self.engine_policy_arbiter.snapshot_state() + + def _notify_engine_arbiter(self) -> None: + self.engine_policy_arbiter.notify() + + def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: + self.engine_policy_arbiter.mark_tick(stage=stage, reason=reason, policy_allowed=policy_allowed) + + def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: + stage, reason, policy_snapshot, worker_state = self.engine_policy_arbiter.select_stage() + self.owner.engine_dispatch_last_snapshot = dict(policy_snapshot) + return stage, reason, policy_snapshot, worker_state diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_stage.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_stage.py new file mode 100644 index 00000000..2a52e779 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_stage.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +import asyncio +from typing import Any, Dict, List + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDispatchTask, SchedulerFinalizeTask, SchedulerPendingJob + + +class EngineStageBridgeFacade: + def __init__(self, owner: Any) -> None: + self.owner = owner + + @property + def engine_decode_runtime_owner(self): + return self.owner.engine_decode_runtime_owner + + @property + def scheduler_worker(self): + return self.owner.scheduler_worker + + @property + def engine_stage_coordinator(self): + return self.owner.engine_stage_coordinator + + def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]: + return self.engine_decode_runtime_owner.snapshot_pending_queue_state() + + def _refresh_engine_decode_runtime_state(self, last_event: str) -> None: + self.engine_decode_runtime_owner.refresh_state(last_event) + + def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None: + if not snapshot: + return + if self.scheduler_worker.is_engine_decode_control_enabled(): + return + self.engine_decode_runtime_owner.update_from_worker_snapshot(snapshot) + + def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]: + return self.engine_decode_runtime_owner.snapshot_state() + + def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None: + self.engine_decode_runtime_owner.enqueue_pending_job(job) + self.owner.engine_policy_arbiter.notify() + + def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.engine_decode_runtime_owner.take_pending_jobs_nonblocking(wait_for_batch) + + def _peek_queue_age_ms(self, queue_name: str) -> float: + return self.engine_stage_coordinator.peek_queue_age_ms(queue_name) + + def _engine_has_pending_work(self) -> bool: + return self.engine_stage_coordinator.has_pending_work() + + async def _prepare_state_via_engine_gpu_queue( + self, + *, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + engine_request_id: str | None, + ) -> tuple[T2SRequestState, float, float]: + return await self.engine_stage_coordinator.prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=prepare_submit_at, + engine_request_id=engine_request_id, + ) + + def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: + self.engine_stage_coordinator.enqueue_worker_finished_for_finalize(tasks) + + def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: + return self.engine_stage_coordinator.take_engine_finalize_batch_nonblocking() + + async def _enqueue_prepared_state_for_dispatch( + self, + *, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + super_sampling: bool, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None, + done_future: asyncio.Future | None, + engine_request_id: str | None, + timeout_sec: float | None, + ) -> EngineDispatchTask: + return await self.engine_stage_coordinator.enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=speed_factor, + sample_steps=sample_steps, + media_type=media_type, + super_sampling=super_sampling, + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=done_loop, + done_future=done_future, + engine_request_id=engine_request_id, + timeout_sec=timeout_sec, + ) + + def _run_engine_prepare_once(self) -> bool: + return self.engine_stage_coordinator.run_engine_prepare_once() + + def _run_engine_finalize_once(self) -> bool: + return self.engine_stage_coordinator.run_engine_finalize_once() + + def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: + return self.engine_stage_coordinator.run_engine_dispatch_once(policy_snapshot, worker_state) + + def _run_engine_decode_runtime_once(self) -> bool: + return self.engine_stage_coordinator.run_engine_decode_runtime_once() + + def _run_engine_arbiter_loop(self) -> None: + return self.engine_stage_coordinator.run_engine_arbiter_loop() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py new file mode 100644 index 00000000..2cb1e175 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +import os +import threading +from typing import Any + +from GPT_SoVITS.TTS_infer_pack.unified_engine_api import EngineApiFacade +from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge import EngineBridgeFacade +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import ( + EngineArbiterConfig, + EngineDecodeRuntimeOwner, + EnginePolicyArbiterController, + EnginePolicyConfig, + EngineRequestRegistry, + EngineTaskQueueOwner, + ModelRegistry, + ReferenceRegistry, + RuntimeStateCallbacks, + SchedulerJobRegistry, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime import EngineRuntimeFacade +from GPT_SoVITS.TTS_infer_pack.unified_engine_stage import EngineStageCoordinator +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker + + +class EngineCompositionBuilder: + def __init__(self, owner: Any) -> None: + self.owner = owner + + def build(self, *, max_steps: int, micro_batch_wait_ms: int) -> None: + self._init_registries_and_locks() + self._init_worker(max_steps=max_steps, micro_batch_wait_ms=micro_batch_wait_ms) + self._init_policy_configs(micro_batch_wait_ms=micro_batch_wait_ms) + self._init_runtime_owners() + self._init_stage_coordinator() + self._init_arbiter() + self._init_facades() + self._start_arbiter_thread() + + def _init_registries_and_locks(self) -> None: + owner = self.owner + owner.reference_registry = ReferenceRegistry() + owner.model_registry = ModelRegistry( + t2s_weights_path=str(owner.tts.configs.t2s_weights_path), + vits_weights_path=str(owner.tts.configs.vits_weights_path), + ) + owner.request_registry = EngineRequestRegistry( + recent_limit=max(1, int(os.environ.get("GPTSOVITS_ENGINE_RECENT_REQUEST_LIMIT", "64"))) + ) + owner.engine_job_registry = SchedulerJobRegistry(threading.Lock()) + owner.direct_tts_lock = threading.RLock() + owner.management_lock = threading.RLock() + owner.engine_dispatch_last_snapshot = {} + + def _init_worker(self, *, max_steps: int, micro_batch_wait_ms: int) -> None: + owner = self.owner + owner.scheduler_worker = UnifiedSchedulerWorker( + owner.tts, + max_steps=max_steps, + micro_batch_wait_ms=micro_batch_wait_ms, + runtime_callbacks=RuntimeStateCallbacks( + update=owner._update_request_state, + complete=owner._complete_request_state, + fail=owner._fail_request_state, + decode_runtime_update=owner._update_engine_decode_runtime_state, + ), + external_finalize_submit=owner._enqueue_worker_finished_for_finalize, + ) + + def _init_policy_configs(self, *, micro_batch_wait_ms: int) -> None: + owner = self.owner + worker_capacity_limits = owner.scheduler_worker.get_capacity_limits() + prepare_max_inflight = int(owner.scheduler_worker.get_prepare_max_inflight()) + owner.engine_policy_config = EnginePolicyConfig( + enabled=owner._env_flag("GPTSOVITS_ENGINE_POLICY_ENABLE", True), + poll_wait_ms=max(1.0, owner._env_float("GPTSOVITS_ENGINE_POLICY_POLL_WAIT_MS", float(micro_batch_wait_ms))), + decode_backlog_soft_max=max( + 0, + owner._env_int( + "GPTSOVITS_ENGINE_POLICY_DECODE_BACKLOG_SOFT_MAX", + int(worker_capacity_limits["decode_backlog_max"]), + ), + ), + finalize_pending_soft_max=max( + 0, + owner._env_int( + "GPTSOVITS_ENGINE_POLICY_FINALIZE_PENDING_SOFT_MAX", + int(worker_capacity_limits["finalize_pending_max"]), + ), + ), + prepare_inflight_soft_max=max( + 0, + owner._env_int("GPTSOVITS_ENGINE_POLICY_PREPARE_INFLIGHT_SOFT_MAX", prepare_max_inflight), + ), + active_decode_soft_max=max(0, owner._env_int("GPTSOVITS_ENGINE_POLICY_ACTIVE_DECODE_SOFT_MAX", 0)), + ready_for_prefill_soft_max=max(0, owner._env_int("GPTSOVITS_ENGINE_POLICY_READY_FOR_PREFILL_SOFT_MAX", 0)), + active_request_soft_max=max(0, owner._env_int("GPTSOVITS_ENGINE_POLICY_ACTIVE_REQUEST_SOFT_MAX", 0)), + ) + owner.engine_arbiter_config = EngineArbiterConfig( + poll_wait_ms=max(1.0, owner._env_float("GPTSOVITS_ENGINE_ARBITER_POLL_WAIT_MS", float(micro_batch_wait_ms))), + decode_burst=max(1, owner._env_int("GPTSOVITS_ENGINE_ARBITER_DECODE_BURST", 4)), + prepare_aging_ms=max(0.0, owner._env_float("GPTSOVITS_ENGINE_ARBITER_PREPARE_AGING_MS", 10.0)), + finalize_aging_ms=max(0.0, owner._env_float("GPTSOVITS_ENGINE_ARBITER_FINALIZE_AGING_MS", 10.0)), + ) + + def _init_runtime_owners(self) -> None: + owner = self.owner + owner.engine_decode_runtime_owner = EngineDecodeRuntimeOwner( + get_decode_runtime_counters=owner.scheduler_worker.get_decode_runtime_counters, + get_micro_batch_wait_s=owner.scheduler_worker.get_micro_batch_wait_s, + ) + owner.engine_prepare_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") + owner.engine_prepare_text_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") + owner.engine_prepare_ref_spec_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") + owner.engine_finalize_queue_owner = EngineTaskQueueOwner(completion_key="total_completed") + owner.engine_dispatch_queue_owner = EngineTaskQueueOwner(completion_key="total_dispatched") + + def _init_stage_coordinator(self) -> None: + owner = self.owner + owner.engine_stage_coordinator = EngineStageCoordinator( + tts=owner.tts, + scheduler_worker=owner.scheduler_worker, + prepare_queue_owner=owner.engine_prepare_queue_owner, + prepare_text_queue_owner=owner.engine_prepare_text_queue_owner, + prepare_ref_spec_queue_owner=owner.engine_prepare_ref_spec_queue_owner, + finalize_queue_owner=owner.engine_finalize_queue_owner, + dispatch_queue_owner=owner.engine_dispatch_queue_owner, + decode_runtime_owner=owner.engine_decode_runtime_owner, + update_request_state=owner._update_request_state, + merge_request_state_profile=owner._merge_request_state_profile, + fail_request_state=owner._fail_request_state, + get_engine_job=owner._get_engine_job, + register_engine_job=owner._register_engine_job, + fail_engine_jobs=owner._fail_engine_jobs, + complete_engine_job=owner._complete_engine_job, + add_engine_prefill_time=owner._add_engine_prefill_time, + add_engine_merge_time=owner._add_engine_merge_time, + add_engine_decode_time=owner._add_engine_decode_time, + enqueue_engine_finished_items=owner._enqueue_engine_finished_items, + snapshot_engine_dispatch_state=owner._snapshot_engine_dispatch_state, + snapshot_engine_decode_runtime_state=owner._snapshot_engine_decode_runtime_state, + ) + + def _init_arbiter(self) -> None: + owner = self.owner + owner.engine_policy_arbiter = EnginePolicyArbiterController( + policy_config=owner.engine_policy_config, + arbiter_config=owner.engine_arbiter_config, + snapshot_request_registry=owner._snapshot_request_registry, + get_worker_state=owner.scheduler_worker.snapshot, + snapshot_prepare_state=owner._snapshot_engine_prepare_state, + snapshot_finalize_state=owner._snapshot_engine_finalize_state, + snapshot_dispatch_state=owner._snapshot_engine_dispatch_state, + snapshot_decode_runtime_state=owner._snapshot_engine_decode_runtime_state, + snapshot_job_registry=owner._snapshot_engine_job_registry, + peek_queue_age_ms=owner.engine_stage_coordinator.peek_queue_age_ms, + merge_request_state_profile=owner._merge_request_state_profile, + ) + owner.engine_stage_coordinator.bind_arbiter( + notify_arbiter=owner._notify_engine_arbiter, + select_stage=owner._select_engine_stage, + mark_arbiter_tick=lambda stage, reason, policy_allowed: owner._mark_arbiter_tick( + stage=stage, + reason=reason, + policy_allowed=policy_allowed, + ), + wait_arbiter=owner.engine_policy_arbiter.wait, + ) + + def _init_facades(self) -> None: + owner = self.owner + owner.bridge_facade = EngineBridgeFacade(owner) + owner.api_facade = EngineApiFacade(owner) + owner.runtime_facade = EngineRuntimeFacade(owner) + + def _start_arbiter_thread(self) -> None: + owner = self.owner + owner.engine_arbiter_thread = threading.Thread( + target=owner._run_engine_arbiter_loop, + name="unified-engine-arbiter", + daemon=True, + ) + owner.engine_arbiter_thread.start() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_models.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_models.py new file mode 100644 index 00000000..7b5ea5f8 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_models.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Dict, Generator, List, Optional + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec + + +@dataclass +class RuntimeControlCallbacks: + restart: Callable[[], None] | None = None + exit: Callable[[], None] | None = None + + +@dataclass +class DirectTTSExecution: + media_type: str + streaming: bool + audio_generator: Optional[Generator[bytes, None, None]] = None + audio_bytes: Optional[bytes] = None + request_id: Optional[str] = None + + +@dataclass +class NormalizedEngineRequest: + request_id: str + text: str + text_lang: str + ref_audio_path: str + prompt_lang: str + prompt_text: str = "" + aux_ref_audio_paths: List[str] | None = None + top_k: int = 15 + top_p: float = 1.0 + temperature: float = 1.0 + repetition_penalty: float = 1.35 + early_stop_num: int = -1 + ready_step: int = 0 + text_split_method: str = "cut5" + batch_size: int = 1 + batch_threshold: float = 0.75 + split_bucket: bool = False + speed_factor: float = 1.0 + fragment_interval: float = 0.3 + seed: int = -1 + media_type: str = "wav" + streaming_mode: bool | int = False + return_fragment: bool = False + fixed_length_chunk: bool = False + response_streaming: bool = False + parallel_infer: bool = False + sample_steps: int = 32 + super_sampling: bool = False + overlap_length: int = 2 + min_chunk_length: int = 16 + timeout_sec: float | None = None + + def to_payload(self) -> Dict[str, Any]: + return { + "request_id": self.request_id, + "text": self.text, + "text_lang": self.text_lang, + "ref_audio_path": self.ref_audio_path, + "aux_ref_audio_paths": list(self.aux_ref_audio_paths) if self.aux_ref_audio_paths else None, + "prompt_text": self.prompt_text, + "prompt_lang": self.prompt_lang, + "top_k": self.top_k, + "top_p": self.top_p, + "temperature": self.temperature, + "text_split_method": self.text_split_method, + "batch_size": self.batch_size, + "batch_threshold": self.batch_threshold, + "speed_factor": self.speed_factor, + "split_bucket": self.split_bucket, + "fragment_interval": self.fragment_interval, + "seed": self.seed, + "media_type": self.media_type, + "streaming_mode": self.streaming_mode, + "return_fragment": self.return_fragment, + "fixed_length_chunk": self.fixed_length_chunk, + "response_streaming": self.response_streaming, + "parallel_infer": self.parallel_infer, + "repetition_penalty": self.repetition_penalty, + "sample_steps": self.sample_steps, + "super_sampling": self.super_sampling, + "overlap_length": self.overlap_length, + "min_chunk_length": self.min_chunk_length, + "early_stop_num": self.early_stop_num, + "ready_step": self.ready_step, + "timeout_sec": self.timeout_sec, + } + + def to_scheduler_spec(self) -> SchedulerRequestSpec: + return SchedulerRequestSpec( + request_id=self.request_id, + ref_audio_path=Path(self.ref_audio_path), + prompt_text=self.prompt_text, + prompt_lang=self.prompt_lang, + text=self.text, + text_lang=self.text_lang, + top_k=self.top_k, + top_p=self.top_p, + temperature=self.temperature, + repetition_penalty=self.repetition_penalty, + early_stop_num=self.early_stop_num, + aux_ref_audio_paths=list(self.aux_ref_audio_paths or []), + ready_step=self.ready_step, + ) + + +@dataclass +class SchedulerDebugExecution: + payload: Dict[str, Any] + + +@dataclass +class SchedulerSubmitExecution: + audio_bytes: bytes + media_type: str + headers: Dict[str, str] diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py new file mode 100644 index 00000000..65953dd9 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py @@ -0,0 +1,363 @@ +from __future__ import annotations + +import asyncio +import threading +import time +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional + +from GPT_SoVITS.TTS_infer_pack.unified_engine_component_registry import EngineStatus + + +@dataclass +class EnginePolicyConfig: + enabled: bool = True + poll_wait_ms: float = 5.0 + decode_backlog_soft_max: int = 0 + finalize_pending_soft_max: int = 0 + prepare_inflight_soft_max: int = 0 + active_decode_soft_max: int = 0 + ready_for_prefill_soft_max: int = 0 + active_request_soft_max: int = 0 + + def to_dict(self) -> Dict[str, Any]: + return { + "enabled": bool(self.enabled), + "poll_wait_ms": float(self.poll_wait_ms), + "decode_backlog_soft_max": int(self.decode_backlog_soft_max), + "finalize_pending_soft_max": int(self.finalize_pending_soft_max), + "prepare_inflight_soft_max": int(self.prepare_inflight_soft_max), + "active_decode_soft_max": int(self.active_decode_soft_max), + "ready_for_prefill_soft_max": int(self.ready_for_prefill_soft_max), + "active_request_soft_max": int(self.active_request_soft_max), + } + + +@dataclass +class EngineArbiterConfig: + poll_wait_ms: float = 5.0 + decode_burst: int = 4 + prepare_aging_ms: float = 10.0 + finalize_aging_ms: float = 10.0 + + def to_dict(self) -> Dict[str, Any]: + return { + "poll_wait_ms": float(self.poll_wait_ms), + "decode_burst": int(self.decode_burst), + "prepare_aging_ms": float(self.prepare_aging_ms), + "finalize_aging_ms": float(self.finalize_aging_ms), + } + + +@dataclass +class EngineArbiterState: + total_ticks: int = 0 + total_idle_ticks: int = 0 + total_prepare_dispatches: int = 0 + total_decode_dispatches: int = 0 + total_decode_runtime_ticks: int = 0 + total_finalize_dispatches: int = 0 + decode_budget_remaining: int = 0 + last_stage: str = "idle" + last_reason: str = "init" + last_observed_at: float = 0.0 + last_policy_allowed: bool = True + + +class EnginePolicyArbiterController: + def __init__( + self, + *, + policy_config: EnginePolicyConfig, + arbiter_config: EngineArbiterConfig, + snapshot_request_registry: Callable[[], Dict[str, Any]], + get_worker_state: Callable[[], Dict[str, Any]], + snapshot_prepare_state: Callable[[], Dict[str, Any]], + snapshot_finalize_state: Callable[[], Dict[str, Any]], + snapshot_dispatch_state: Callable[[], Dict[str, Any]], + snapshot_decode_runtime_state: Callable[[], Dict[str, Any]], + snapshot_job_registry: Callable[[], Dict[str, Any]], + peek_queue_age_ms: Callable[[str], float], + merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None], + ) -> None: + self.policy_config = policy_config + self.policy_poll_s = max(0.001, float(self.policy_config.poll_wait_ms) / 1000.0) + self.arbiter_config = arbiter_config + self.arbiter_poll_s = max(0.001, float(self.arbiter_config.poll_wait_ms) / 1000.0) + self.condition = threading.Condition() + self.state = EngineArbiterState( + decode_budget_remaining=int(self.arbiter_config.decode_burst), + last_observed_at=time.perf_counter(), + ) + self.snapshot_request_registry = snapshot_request_registry + self.get_worker_state = get_worker_state + self.snapshot_prepare_state = snapshot_prepare_state + self.snapshot_finalize_state = snapshot_finalize_state + self.snapshot_dispatch_state = snapshot_dispatch_state + self.snapshot_decode_runtime_state = snapshot_decode_runtime_state + self.snapshot_job_registry = snapshot_job_registry + self.peek_queue_age_ms = peek_queue_age_ms + self.merge_request_state_profile = merge_request_state_profile + + def snapshot_state(self) -> Dict[str, Any]: + with self.condition: + return { + "config": self.arbiter_config.to_dict(), + "total_ticks": int(self.state.total_ticks), + "total_idle_ticks": int(self.state.total_idle_ticks), + "total_prepare_dispatches": int(self.state.total_prepare_dispatches), + "total_decode_dispatches": int(self.state.total_decode_dispatches), + "total_decode_runtime_ticks": int(self.state.total_decode_runtime_ticks), + "total_finalize_dispatches": int(self.state.total_finalize_dispatches), + "decode_budget_remaining": int(self.state.decode_budget_remaining), + "last_stage": str(self.state.last_stage), + "last_reason": str(self.state.last_reason), + "last_policy_allowed": bool(self.state.last_policy_allowed), + "last_observed_at": float(self.state.last_observed_at), + } + + def notify(self) -> None: + with self.condition: + self.condition.notify_all() + + def wait(self) -> None: + with self.condition: + self.condition.wait(timeout=self.arbiter_poll_s) + + def mark_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None: + with self.condition: + self.state.total_ticks += 1 + if stage == "idle": + self.state.total_idle_ticks += 1 + elif stage in {"prepare", "prepare_audio", "prepare_text", "prepare_ref_spec"}: + self.state.total_prepare_dispatches += 1 + self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) + elif stage == "finalize": + self.state.total_finalize_dispatches += 1 + self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst) + elif stage == "decode_dispatch": + self.state.total_decode_dispatches += 1 + elif stage == "decode_runtime": + self.state.total_decode_runtime_ticks += 1 + self.state.decode_budget_remaining = max(0, int(self.state.decode_budget_remaining) - 1) + self.state.last_stage = str(stage) + self.state.last_reason = str(reason) + self.state.last_policy_allowed = bool(policy_allowed) + self.state.last_observed_at = time.perf_counter() + + def build_stage_counters( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + prepare_dispatcher_state = self.snapshot_prepare_state() + finalize_dispatcher_state = self.snapshot_finalize_state() + dispatcher_state = self.snapshot_dispatch_state() + active_requests = list(request_registry.get("active_requests", [])) + status_counts: Dict[str, int] = {} + for item in active_requests: + status = str(item.get("status", "UNKNOWN")) + status_counts[status] = status_counts.get(status, 0) + 1 + + worker_pending_jobs = int(worker_state.get("pending_jobs", 0)) + worker_decode_active_size = int(worker_state.get("running_requests", 0)) + worker_prepare_inflight = int(worker_state.get("prepare_inflight", 0)) + worker_finalize_pending = int(worker_state.get("finalize_pending", 0)) + worker_finalize_inflight = int(worker_state.get("finalize_inflight", 0)) + engine_decode_runtime_state = self.snapshot_decode_runtime_state() + engine_job_registry = self.snapshot_job_registry() + decode_runtime_pending_jobs = int(engine_decode_runtime_state.get("pending_jobs", 0)) + decode_runtime_active_size = int(engine_decode_runtime_state.get("active_request_count", 0)) + return { + "active_request_count": int(len(active_requests)), + "status_counts": status_counts, + "queued_request_count": int(status_counts.get(EngineStatus.QUEUED, 0)), + "cpu_prepare_request_count": int(status_counts.get(EngineStatus.CPU_PREPARING, 0)), + "gpu_prepare_request_count": int(status_counts.get(EngineStatus.GPU_PREPARING, 0)), + "ready_for_prefill_request_count": int(status_counts.get(EngineStatus.READY_FOR_PREFILL, 0)), + "active_decode_request_count": int(status_counts.get(EngineStatus.ACTIVE_DECODE, 0)), + "ready_for_finalize_request_count": int(status_counts.get(EngineStatus.READY_FOR_FINALIZE, 0)), + "finalizing_request_count": int(status_counts.get(EngineStatus.FINALIZING, 0)), + "streaming_request_count": int(status_counts.get(EngineStatus.STREAMING, 0)), + "worker_pending_jobs": worker_pending_jobs, + "worker_decode_active_size": worker_decode_active_size, + "worker_decode_control_enabled": bool(worker_state.get("engine_decode_control_enabled", False)), + "worker_decode_runtime_has_work": bool(worker_state.get("decode_runtime_has_work", False)), + "engine_decode_runtime_pending_jobs": decode_runtime_pending_jobs, + "engine_decode_runtime_active_request_count": decode_runtime_active_size, + "engine_decode_runtime_has_work": bool(engine_decode_runtime_state.get("has_work", False)), + "engine_job_registry_count": int(engine_job_registry.get("job_count", 0)), + "worker_prepare_inflight": worker_prepare_inflight, + "worker_finalize_pending": worker_finalize_pending, + "worker_finalize_inflight": worker_finalize_inflight, + "engine_gpu_prepare_queue_count": int(prepare_dispatcher_state.get("waiting_count", 0)), + "engine_finalize_queue_count": int(finalize_dispatcher_state.get("waiting_count", 0)), + "engine_decode_waiting_queue_count": int(dispatcher_state.get("waiting_count", 0)), + "decode_backlog": int( + decode_runtime_pending_jobs + decode_runtime_active_size + if bool(worker_state.get("engine_decode_control_enabled", False)) + else worker_pending_jobs + worker_decode_active_size + ), + } + + def build_policy_snapshot( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + counters = self.build_stage_counters(request_registry, worker_state) + config = self.policy_config.to_dict() + blocked_reasons: List[Dict[str, Any]] = [] + finalize_pending_total = int(counters["worker_finalize_pending"]) + int(counters.get("engine_finalize_queue_count", 0)) + limit_checks = [ + ("decode_backlog", counters["decode_backlog"], int(config["decode_backlog_soft_max"])), + ("finalize_pending", finalize_pending_total, int(config["finalize_pending_soft_max"])), + ("prepare_inflight", counters["worker_prepare_inflight"], int(config["prepare_inflight_soft_max"])), + ("active_decode_requests", counters["active_decode_request_count"], int(config["active_decode_soft_max"])), + ("ready_for_prefill_requests", counters["ready_for_prefill_request_count"], int(config["ready_for_prefill_soft_max"])), + ("active_requests", counters["active_request_count"], int(config["active_request_soft_max"])), + ] + if bool(config["enabled"]): + for name, value, limit in limit_checks: + if limit > 0 and int(value) >= int(limit): + blocked_reasons.append({"metric": name, "value": int(value), "limit": int(limit)}) + return { + "enabled": bool(config["enabled"]), + "allowed": (not bool(config["enabled"])) or not blocked_reasons, + "blocked_reasons": blocked_reasons, + "config": config, + "metrics": { + "active_request_count": int(counters["active_request_count"]), + "queued_request_count": int(counters["queued_request_count"]), + "ready_for_prefill_request_count": int(counters["ready_for_prefill_request_count"]), + "active_decode_request_count": int(counters["active_decode_request_count"]), + "engine_gpu_prepare_queue_count": int(counters["engine_gpu_prepare_queue_count"]), + "engine_decode_waiting_queue_count": int(counters["engine_decode_waiting_queue_count"]), + "decode_backlog": int(counters["decode_backlog"]), + "prepare_inflight": int(counters["worker_prepare_inflight"]), + "finalize_pending": int(finalize_pending_total), + "engine_finalize_queue_count": int(counters.get("engine_finalize_queue_count", 0)), + "finalize_inflight": int(counters["worker_finalize_inflight"]), + }, + "observed_at": time.perf_counter(), + } + + async def wait_for_policy_admission( + self, + *, + request_id: str | None, + timeout_sec: float | None, + ) -> tuple[float, Dict[str, Any]]: + request_registry = self.snapshot_request_registry() + worker_state = self.get_worker_state() + snapshot = self.build_policy_snapshot(request_registry, worker_state) + if not self.policy_config.enabled: + return 0.0, snapshot + start = time.perf_counter() + deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) + while True: + request_registry = self.snapshot_request_registry() + worker_state = self.get_worker_state() + snapshot = self.build_policy_snapshot(request_registry, worker_state) + if snapshot["allowed"]: + wait_ms = max(0.0, (time.perf_counter() - start) * 1000.0) + if request_id not in [None, ""]: + self.merge_request_state_profile( + str(request_id), + { + "engine_policy_wait_ms": float(wait_ms), + "engine_policy_snapshot": snapshot, + }, + ) + return wait_ms, snapshot + now = time.perf_counter() + if deadline is not None and now >= deadline: + blocked_summary = ", ".join( + f"{item['metric']}={item['value']}/{item['limit']}" for item in snapshot.get("blocked_reasons", []) + ) + raise TimeoutError(f"engine policy admission timeout ({blocked_summary})") + await asyncio.sleep(self.policy_poll_s) + + def select_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]: + request_registry = self.snapshot_request_registry() + worker_state = self.get_worker_state() + policy_snapshot = self.build_policy_snapshot(request_registry, worker_state) + prepare_state = self.snapshot_prepare_state() + prepare_waiting = int(prepare_state.get("waiting_count", 0)) + prepare_audio_waiting = int(prepare_state.get("audio_waiting_count", 0)) + prepare_text_waiting = int(prepare_state.get("text_waiting_count", 0)) + prepare_ref_spec_waiting = int(prepare_state.get("ref_spec_waiting_count", 0)) + finalize_waiting = int(self.snapshot_finalize_state().get("waiting_count", 0)) + decode_waiting = int(self.snapshot_dispatch_state().get("waiting_count", 0)) + decode_runtime_state = self.snapshot_decode_runtime_state() + worker_decode_has_work = bool(decode_runtime_state.get("has_work", False)) + worker_decode_control_enabled = bool(worker_state.get("engine_decode_control_enabled", False)) + worker_pending_jobs = int(decode_runtime_state.get("pending_jobs", 0)) + worker_running_requests = int(decode_runtime_state.get("active_request_count", 0)) + prepare_age_ms = float(self.peek_queue_age_ms("prepare")) + prepare_audio_age_ms = float(self.peek_queue_age_ms("prepare_audio")) + prepare_text_age_ms = float(self.peek_queue_age_ms("prepare_text")) + prepare_ref_spec_age_ms = float(self.peek_queue_age_ms("prepare_ref_spec")) + finalize_age_ms = float(self.peek_queue_age_ms("finalize")) + decode_runtime_pending_age_ms = float(self.peek_queue_age_ms("decode_runtime_pending")) + decode_budget_remaining = int(self.snapshot_state().get("decode_budget_remaining", 0)) + policy_allowed = bool(policy_snapshot.get("allowed", True)) + if ( + worker_decode_control_enabled + and worker_decode_has_work + and policy_allowed + and decode_budget_remaining > 0 + and (worker_running_requests > 0 or worker_pending_jobs > 0) + ): + return "decode_runtime", "worker_active_batch_progress", policy_snapshot, worker_state + if ( + worker_decode_control_enabled + and worker_pending_jobs > 0 + and policy_allowed + and decode_runtime_pending_age_ms >= float(self.arbiter_config.prepare_aging_ms) + ): + return "decode_runtime", "decode_runtime_pending_aging", policy_snapshot, worker_state + if ( + decode_waiting > 0 + and policy_allowed + and (not worker_decode_control_enabled or not worker_decode_has_work or worker_pending_jobs <= 0) + ): + return "decode_dispatch", "dispatch_prepared_state", policy_snapshot, worker_state + if ( + finalize_waiting > 0 + and prepare_ref_spec_waiting > 0 + and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0) + ): + return "prepare_ref_spec", "finalize_waiting_for_ref_spec", policy_snapshot, worker_state + if finalize_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): + return "finalize", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state + if finalize_waiting > 0 and finalize_age_ms >= float(self.arbiter_config.finalize_aging_ms): + return "finalize", "finalize_aging", policy_snapshot, worker_state + if prepare_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0): + if prepare_text_waiting > 0 and (prepare_audio_waiting <= 0 or prepare_text_age_ms >= prepare_audio_age_ms): + return "prepare_text", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state + if prepare_ref_spec_waiting > 0 and prepare_audio_waiting <= 0 and prepare_text_waiting <= 0: + return "prepare_ref_spec", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state + return "prepare_audio", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state + if prepare_waiting > 0 and prepare_age_ms >= float(self.arbiter_config.prepare_aging_ms): + if prepare_text_waiting > 0 and prepare_text_age_ms >= max(prepare_audio_age_ms, prepare_age_ms - 1e-6): + return "prepare_text", "prepare_aging", policy_snapshot, worker_state + if ( + prepare_ref_spec_waiting > 0 + and prepare_ref_spec_age_ms >= max(prepare_audio_age_ms, prepare_text_age_ms, prepare_age_ms - 1e-6) + ): + return "prepare_ref_spec", "prepare_aging", policy_snapshot, worker_state + return "prepare_audio", "prepare_aging", policy_snapshot, worker_state + if worker_decode_control_enabled and worker_decode_has_work and policy_allowed: + return "decode_runtime", "worker_active_batch_progress_fallback", policy_snapshot, worker_state + if decode_waiting > 0 and policy_allowed: + return "decode_dispatch", "decode_priority_fallback", policy_snapshot, worker_state + if finalize_waiting > 0: + return "finalize", "finalize_fallback", policy_snapshot, worker_state + if prepare_waiting > 0: + if prepare_text_waiting > 0 and (prepare_audio_waiting <= 0 or prepare_text_age_ms >= prepare_audio_age_ms): + return "prepare_text", "prepare_fallback", policy_snapshot, worker_state + if prepare_ref_spec_waiting > 0 and prepare_audio_waiting <= 0: + return "prepare_ref_spec", "prepare_fallback", policy_snapshot, worker_state + return "prepare_audio", "prepare_fallback", policy_snapshot, worker_state + return "idle", "no_pending_work", policy_snapshot, worker_state diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_registry.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_registry.py new file mode 100644 index 00000000..1aaa89c1 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_registry.py @@ -0,0 +1,382 @@ +from __future__ import annotations + +import asyncio +import threading +import time +from collections import deque +from dataclasses import dataclass, field +from typing import Any, Deque, Dict, Optional, Sequence + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SRequestState + + +@dataclass +class DefaultReferenceState: + ref_audio_path: str | None = None + updated_at: float = 0.0 + + +class ReferenceRegistry: + def __init__(self) -> None: + self._lock = threading.Lock() + self._state = DefaultReferenceState() + + def set_default(self, ref_audio_path: str) -> DefaultReferenceState: + with self._lock: + self._state = DefaultReferenceState(ref_audio_path=str(ref_audio_path), updated_at=time.time()) + return self._state + + def clear(self) -> DefaultReferenceState: + with self._lock: + self._state = DefaultReferenceState() + return self._state + + def get_default(self) -> DefaultReferenceState: + with self._lock: + return DefaultReferenceState( + ref_audio_path=self._state.ref_audio_path, + updated_at=self._state.updated_at, + ) + + +@dataclass +class ModelRegistryState: + t2s_weights_path: str + vits_weights_path: str + generation: int = 0 + t2s_generation: int = 0 + vits_generation: int = 0 + updated_at: float = field(default_factory=time.time) + + +class ModelRegistry: + def __init__(self, t2s_weights_path: str, vits_weights_path: str) -> None: + self._lock = threading.Lock() + self._state = ModelRegistryState( + t2s_weights_path=str(t2s_weights_path), + vits_weights_path=str(vits_weights_path), + ) + + def snapshot(self) -> ModelRegistryState: + with self._lock: + return ModelRegistryState( + t2s_weights_path=self._state.t2s_weights_path, + vits_weights_path=self._state.vits_weights_path, + generation=self._state.generation, + t2s_generation=self._state.t2s_generation, + vits_generation=self._state.vits_generation, + updated_at=self._state.updated_at, + ) + + def mark_t2s_reload(self, weights_path: str) -> ModelRegistryState: + with self._lock: + self._state.t2s_weights_path = str(weights_path) + self._state.generation += 1 + self._state.t2s_generation += 1 + self._state.updated_at = time.time() + return ModelRegistryState( + t2s_weights_path=self._state.t2s_weights_path, + vits_weights_path=self._state.vits_weights_path, + generation=self._state.generation, + t2s_generation=self._state.t2s_generation, + vits_generation=self._state.vits_generation, + updated_at=self._state.updated_at, + ) + + def mark_vits_reload(self, weights_path: str) -> ModelRegistryState: + with self._lock: + self._state.vits_weights_path = str(weights_path) + self._state.generation += 1 + self._state.vits_generation += 1 + self._state.updated_at = time.time() + return ModelRegistryState( + t2s_weights_path=self._state.t2s_weights_path, + vits_weights_path=self._state.vits_weights_path, + generation=self._state.generation, + t2s_generation=self._state.t2s_generation, + vits_generation=self._state.vits_generation, + updated_at=self._state.updated_at, + ) + + +class EngineStatus: + NEW = "NEW" + QUEUED = "QUEUED" + VALIDATED = "VALIDATED" + CPU_PREPARING = "CPU_PREPARING" + GPU_PREPARING = "GPU_PREPARING" + READY_FOR_PREFILL = "READY_FOR_PREFILL" + ACTIVE_DECODE = "ACTIVE_DECODE" + READY_FOR_FINALIZE = "READY_FOR_FINALIZE" + FINALIZING = "FINALIZING" + STREAMING = "STREAMING" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + + +@dataclass +class EngineRequestState: + request_id: str + api_mode: str + backend: str + media_type: str + response_streaming: bool + submit_ts: float + deadline_ts: float | None = None + status: str = EngineStatus.NEW + updated_ts: float = 0.0 + error: str | None = None + finish_reason: str | None = None + meta: Dict[str, Any] = field(default_factory=dict) + profile: Dict[str, Any] = field(default_factory=dict) + lifecycle_timestamps: Dict[str, float] = field(default_factory=dict) + + def to_summary(self) -> Dict[str, Any]: + return { + "request_id": self.request_id, + "api_mode": self.api_mode, + "backend": self.backend, + "media_type": self.media_type, + "response_streaming": self.response_streaming, + "status": self.status, + "submit_ts": self.submit_ts, + "updated_ts": self.updated_ts, + "deadline_ts": self.deadline_ts, + "error": self.error, + "finish_reason": self.finish_reason, + "meta": dict(self.meta), + "profile": dict(self.profile), + "lifecycle_timestamps": dict(self.lifecycle_timestamps), + } + + +class EngineRequestRegistry: + def __init__(self, recent_limit: int) -> None: + self.lock = threading.Lock() + self.active_requests: Dict[str, EngineRequestState] = {} + self.recent_requests: Deque[EngineRequestState] = deque() + self.recent_limit = max(1, int(recent_limit)) + + def register( + self, + *, + request_id: str, + api_mode: str, + backend: str, + media_type: str, + response_streaming: bool, + deadline_ts: float | None = None, + meta: Optional[Dict[str, Any]] = None, + ) -> EngineRequestState: + now = time.perf_counter() + state = EngineRequestState( + request_id=request_id, + api_mode=api_mode, + backend=backend, + media_type=media_type, + response_streaming=bool(response_streaming), + submit_ts=now, + deadline_ts=deadline_ts, + updated_ts=now, + meta=dict(meta or {}), + lifecycle_timestamps={EngineStatus.NEW: now}, + ) + with self.lock: + self.active_requests[request_id] = state + return state + + def _move_to_recent_locked(self, state: EngineRequestState) -> None: + self.recent_requests.appendleft(state) + while len(self.recent_requests) > self.recent_limit: + self.recent_requests.pop() + + @staticmethod + def _apply_state_extra(state: EngineRequestState, extra: Optional[Dict[str, Any]]) -> None: + if not extra: + return + payload = dict(extra) + backend = payload.pop("backend", None) + if backend is not None: + state.backend = str(backend) + finish_reason = payload.pop("finish_reason", None) + if finish_reason is not None: + state.finish_reason = str(finish_reason) + error = payload.pop("error", None) + if error is not None: + state.error = str(error) + state.profile.update(payload) + + def update(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None: + now = time.perf_counter() + with self.lock: + state = self.active_requests.get(request_id) + if state is None: + return + state.status = str(status) + state.updated_ts = now + state.lifecycle_timestamps[str(status)] = now + self._apply_state_extra(state, extra) + + def merge_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + if not extra: + return + now = time.perf_counter() + with self.lock: + state = self.active_requests.get(request_id) + if state is None: + for recent_state in self.recent_requests: + if recent_state.request_id == request_id: + state = recent_state + break + if state is None: + return + state.updated_ts = now + self._apply_state_extra(state, extra) + + def complete(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None: + now = time.perf_counter() + with self.lock: + state = self.active_requests.pop(request_id, None) + if state is None: + return + state.status = EngineStatus.COMPLETED + state.updated_ts = now + state.lifecycle_timestamps[EngineStatus.COMPLETED] = now + self._apply_state_extra(state, extra) + self._move_to_recent_locked(state) + + def fail(self, request_id: str, error: str) -> None: + now = time.perf_counter() + with self.lock: + state = self.active_requests.pop(request_id, None) + if state is None: + return + state.status = EngineStatus.FAILED + state.updated_ts = now + state.error = str(error) + state.lifecycle_timestamps[EngineStatus.FAILED] = now + self._move_to_recent_locked(state) + + def snapshot(self) -> Dict[str, Any]: + with self.lock: + active = [state.to_summary() for state in self.active_requests.values()] + recent = [state.to_summary() for state in list(self.recent_requests)] + recent_limit = self.recent_limit + active.sort(key=lambda item: item["submit_ts"]) + return { + "active_count": len(active), + "recent_count": len(recent), + "recent_limit": recent_limit, + "active_requests": active, + "recent_requests": recent, + } + + def collect_summaries(self, request_ids: Sequence[str]) -> list[Dict[str, Any]]: + requested = set(request_ids) + results: list[Dict[str, Any]] = [] + with self.lock: + for state in self.active_requests.values(): + if state.request_id in requested: + results.append(state.to_summary()) + existing_ids = {item["request_id"] for item in results} + for state in self.recent_requests: + if state.request_id in requested and state.request_id not in existing_ids: + results.append(state.to_summary()) + results.sort(key=lambda item: item["request_id"]) + return results + + def has_active(self, request_id: str) -> bool: + with self.lock: + return request_id in self.active_requests + + +@dataclass +class SchedulerPendingJob: + request_id: str + state: T2SRequestState + done_event: threading.Event + done_loop: asyncio.AbstractEventLoop | None + done_future: asyncio.Future | None + enqueue_time: float + speed_factor: float + sample_steps: int + media_type: str + super_sampling: bool = False + admission_wait_ms: float = 0.0 + engine_policy_wait_ms: float = 0.0 + engine_dispatch_wait_ms: float = 0.0 + prepare_wall_ms: float = 0.0 + prepare_profile_total_ms: float = 0.0 + first_schedule_time: float | None = None + prefill_ms: float = 0.0 + merge_ms: float = 0.0 + decode_ms: float = 0.0 + finalize_wait_ms: float = 0.0 + synth_ms: float = 0.0 + pack_ms: float = 0.0 + decode_steps: int = 0 + result_ready_time: float | None = None + result: dict | None = None + sample_rate: int | None = None + audio_data: np.ndarray | None = None + error: str | None = None + engine_request_id: str | None = None + + +class SchedulerJobRegistry: + def __init__(self, lock: threading.Lock | threading.RLock | threading.Condition) -> None: + self._lock = lock + self._job_map: Dict[str, SchedulerPendingJob] = {} + self._total_submitted = 0 + self._total_finished = 0 + + def register(self, job: SchedulerPendingJob, *, keep_job: bool = True) -> None: + with self._lock: + if keep_job: + self._job_map[job.request_id] = job + self._total_submitted += 1 + + def get(self, request_id: str) -> SchedulerPendingJob | None: + with self._lock: + return self._job_map.get(request_id) + + def pop(self, request_id: str) -> SchedulerPendingJob | None: + with self._lock: + return self._job_map.pop(request_id, None) + + def remove(self, request_id: str) -> None: + with self._lock: + self._job_map.pop(request_id, None) + + def mark_finished(self) -> None: + with self._lock: + self._total_finished += 1 + + def mark_finished_and_remove(self, request_id: str) -> None: + with self._lock: + self._job_map.pop(request_id, None) + self._total_finished += 1 + + def is_empty(self) -> bool: + with self._lock: + return not self._job_map + + def submitted_count(self) -> int: + with self._lock: + return int(self._total_submitted) + + def finished_count(self) -> int: + with self._lock: + return int(self._total_finished) + + def snapshot(self, max_request_ids: int = 32) -> Dict[str, Any]: + with self._lock: + request_ids = list(self._job_map.keys()) + return { + "job_count": int(len(request_ids)), + "request_ids": request_ids[: max(0, int(max_request_ids))], + "total_submitted": int(self._total_submitted), + "total_finished": int(self._total_finished), + } diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py new file mode 100644 index 00000000..600e1e83 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py @@ -0,0 +1,362 @@ +from __future__ import annotations + +import asyncio +import threading +import time +from collections import deque +from dataclasses import dataclass, field +from typing import Any, Callable, Deque, Dict, List, Optional, Sequence + +from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PreparedCpuStage +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SActiveBatch, T2SFinishedItem, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_component_registry import SchedulerPendingJob + + +class EngineTaskQueueOwner: + def __init__(self, completion_key: str = "total_completed") -> None: + self.condition = threading.Condition() + self.queue: Deque[Any] = deque() + self.total_submitted = 0 + self.total_completed = 0 + self.peak_waiting = 0 + self.completion_key = str(completion_key) + + def enqueue(self, item: Any) -> None: + with self.condition: + self.queue.append(item) + self.total_submitted += 1 + self.peak_waiting = max(self.peak_waiting, len(self.queue)) + self.condition.notify_all() + + def enqueue_many(self, items: Sequence[Any]) -> None: + if not items: + return + with self.condition: + for item in items: + self.queue.append(item) + self.total_submitted += len(items) + self.peak_waiting = max(self.peak_waiting, len(self.queue)) + self.condition.notify_all() + + def pop_left(self) -> Any | None: + with self.condition: + if not self.queue: + return None + return self.queue.popleft() + + def pop_left_many(self, max_items: int) -> List[Any]: + limit = max(1, int(max_items)) + with self.condition: + if not self.queue: + return [] + selected: List[Any] = [] + while self.queue and len(selected) < limit: + selected.append(self.queue.popleft()) + return selected + + def mark_completed(self, count: int = 1, *, notify: bool = False) -> None: + if count <= 0: + return + with self.condition: + self.total_completed += int(count) + if notify: + self.condition.notify_all() + + def has_items(self) -> bool: + with self.condition: + return bool(self.queue) + + def waiting_count(self) -> int: + with self.condition: + return int(len(self.queue)) + + def snapshot(self, *, max_request_ids: int = 16, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + with self.condition: + waiting_items = list(self.queue)[: max(0, int(max_request_ids))] + snapshot = { + "waiting_count": int(len(self.queue)), + "waiting_request_ids": [str(getattr(item, "request_id", "")) for item in waiting_items], + "peak_waiting": int(self.peak_waiting), + "total_submitted": int(self.total_submitted), + self.completion_key: int(self.total_completed), + } + if extra: + snapshot.update(dict(extra)) + return snapshot + + def peek_oldest_age_ms(self, timestamp_attr: str) -> float: + with self.condition: + if not self.queue: + return 0.0 + enqueue_time = float(getattr(self.queue[0], timestamp_attr)) + return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0) + + def is_drained(self) -> bool: + with self.condition: + return not self.queue and self.total_submitted == self.total_completed + + def take_finalize_batch( + self, + *, + finalize_mode: str, + batch_max_items: int, + batch_wait_s: float, + use_vocoder: bool, + ) -> List[SchedulerFinalizeTask]: + with self.condition: + if not self.queue: + return [] + selected_tasks = [self.queue.popleft()] + if finalize_mode == "sync" or use_vocoder: + return selected_tasks + if batch_max_items <= 1: + return selected_tasks + first_task = selected_tasks[0] + oldest_age_s = max(0.0, time.perf_counter() - first_task.enqueued_time) + if len(self.queue) + 1 < batch_max_items and oldest_age_s < batch_wait_s: + self.queue.appendleft(first_task) + return [] + while len(selected_tasks) < batch_max_items: + if not self.queue: + break + matched_index = None + for index, task in enumerate(self.queue): + if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: + matched_index = index + break + if matched_index is None: + break + selected_tasks.append(self.queue[matched_index]) + del self.queue[matched_index] + return selected_tasks + + +@dataclass +class EngineDecodeRuntimeState: + pending_jobs: int = 0 + pending_request_ids: List[str] = field(default_factory=list) + active_request_count: int = 0 + active_request_ids: List[str] = field(default_factory=list) + prefill_done: bool = False + decode_step_index_max: int = 0 + total_cycles: int = 0 + prefill_cycles: int = 0 + step_cycles: int = 0 + has_work: bool = False + last_event: str = "init" + updated_at: float = 0.0 + + +class EngineDecodeRuntimeOwner: + def __init__( + self, + *, + get_decode_runtime_counters: Callable[[], Dict[str, int]], + get_micro_batch_wait_s: Callable[[], float], + ) -> None: + self.get_decode_runtime_counters = get_decode_runtime_counters + self.get_micro_batch_wait_s = get_micro_batch_wait_s + self.condition = threading.Condition() + self.pending_jobs: Deque[SchedulerPendingJob] = deque() + self.active_batch: T2SActiveBatch | None = None + self.state_lock = threading.Lock() + self.state = EngineDecodeRuntimeState(updated_at=time.perf_counter()) + + @staticmethod + def summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]: + if active_batch is None: + return {} + decode_step_index_max = 0 + if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0: + decode_step_index_max = int(active_batch.step_indices.max().item()) + return { + "request_count": int(len(active_batch.request_ids)), + "request_ids": list(active_batch.request_ids), + "prefill_done": bool(active_batch.prefill_done), + "decode_step_index_max": int(decode_step_index_max), + } + + def snapshot_pending_queue_state(self) -> Dict[str, Any]: + with self.condition: + return { + "pending_jobs": int(len(self.pending_jobs)), + "pending_request_ids": [job.request_id for job in list(self.pending_jobs)[:32]], + } + + def enqueue_pending_job(self, job: SchedulerPendingJob) -> None: + with self.condition: + self.pending_jobs.append(job) + self.condition.notify_all() + self.refresh_state("engine_decode_pending_enqueue") + + def take_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs: + return [] + if wait_for_batch: + oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time) + if (time.perf_counter() - oldest_enqueue_time) < self.get_micro_batch_wait_s(): + return [] + pending_jobs = list(self.pending_jobs) + self.pending_jobs.clear() + self.refresh_state("engine_decode_pending_dequeue") + return pending_jobs + + def pending_age_ms(self) -> float: + with self.condition: + if not self.pending_jobs: + return 0.0 + enqueue_time = float(self.pending_jobs[0].enqueue_time) + return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0) + + def has_pending_jobs(self) -> bool: + with self.condition: + return bool(self.pending_jobs) + + def get_active_batch(self) -> T2SActiveBatch | None: + return self.active_batch + + def set_active_batch(self, active_batch: T2SActiveBatch | None) -> None: + self.active_batch = active_batch + + def active_batch_summary(self) -> Dict[str, Any]: + return self.summarize_active_batch(self.active_batch) + + def refresh_state(self, last_event: str) -> None: + pending_state = self.snapshot_pending_queue_state() + active_batch_summary = self.active_batch_summary() + worker_decode_counters = self.get_decode_runtime_counters() + with self.state_lock: + self.state.pending_jobs = int(pending_state.get("pending_jobs", 0)) + self.state.pending_request_ids = list(pending_state.get("pending_request_ids", [])) + self.state.active_request_count = int(active_batch_summary.get("request_count", 0)) + self.state.active_request_ids = list(active_batch_summary.get("request_ids", []))[:32] + self.state.prefill_done = bool(active_batch_summary.get("prefill_done", False)) + self.state.decode_step_index_max = int(active_batch_summary.get("decode_step_index_max", 0)) + self.state.total_cycles = int(worker_decode_counters.get("total_cycles", 0)) + self.state.prefill_cycles = int(worker_decode_counters.get("prefill_cycles", 0)) + self.state.step_cycles = int(worker_decode_counters.get("step_cycles", 0)) + self.state.has_work = bool(pending_state.get("pending_jobs", 0) or active_batch_summary.get("request_count", 0)) + self.state.last_event = str(last_event) + self.state.updated_at = float(time.perf_counter()) + + def update_from_worker_snapshot(self, snapshot: Dict[str, Any]) -> None: + if not snapshot: + return + pending_state = self.snapshot_pending_queue_state() + with self.state_lock: + self.state.pending_jobs = int(pending_state.get("pending_jobs", 0)) + self.state.pending_request_ids = list(pending_state.get("pending_request_ids", [])) + self.state.active_request_count = int(snapshot.get("active_request_count", 0)) + self.state.active_request_ids = list(snapshot.get("active_request_ids", []))[:32] + self.state.prefill_done = bool(snapshot.get("prefill_done", False)) + self.state.decode_step_index_max = int(snapshot.get("decode_step_index_max", 0)) + self.state.total_cycles = int(snapshot.get("total_cycles", 0)) + self.state.prefill_cycles = int(snapshot.get("prefill_cycles", 0)) + self.state.step_cycles = int(snapshot.get("step_cycles", 0)) + self.state.has_work = bool( + pending_state.get("pending_jobs", 0) + or snapshot.get("active_request_count", 0) + or snapshot.get("has_work", False) + ) + self.state.last_event = str(snapshot.get("last_event", "unknown")) + self.state.updated_at = float(snapshot.get("updated_at", time.perf_counter())) + + def snapshot_state(self) -> Dict[str, Any]: + pending_state = self.snapshot_pending_queue_state() + active_batch_summary = self.active_batch_summary() + worker_decode_counters = self.get_decode_runtime_counters() + with self.state_lock: + return { + "pending_jobs": int(pending_state.get("pending_jobs", self.state.pending_jobs)), + "pending_request_ids": list(pending_state.get("pending_request_ids", self.state.pending_request_ids)), + "active_request_count": int(active_batch_summary.get("request_count", self.state.active_request_count)), + "active_request_ids": list(active_batch_summary.get("request_ids", self.state.active_request_ids)), + "prefill_done": bool(active_batch_summary.get("prefill_done", self.state.prefill_done)), + "decode_step_index_max": int(active_batch_summary.get("decode_step_index_max", self.state.decode_step_index_max)), + "total_cycles": int(worker_decode_counters.get("total_cycles", 0)), + "prefill_cycles": int(worker_decode_counters.get("prefill_cycles", 0)), + "step_cycles": int(worker_decode_counters.get("step_cycles", 0)), + "has_work": bool( + pending_state.get("pending_jobs", 0) + or active_batch_summary.get("request_count", self.state.active_request_count) + or self.state.has_work + ), + "last_event": str(self.state.last_event), + "updated_at": float(self.state.updated_at), + } + + +@dataclass +class SchedulerFinalizeTask: + request_id: str + item: T2SFinishedItem + enqueued_time: float + + +@dataclass +class EngineDispatchTask: + request_id: str + state: T2SRequestState + speed_factor: float + sample_steps: int + media_type: str + super_sampling: bool + prepare_wall_ms: float + prepare_profile_total_ms: float + done_loop: asyncio.AbstractEventLoop | None + done_future: asyncio.Future | None + engine_request_id: str | None + timeout_sec: float | None + enqueue_time: float + worker_job: SchedulerPendingJob | None = None + engine_policy_wait_ms: float = 0.0 + engine_dispatch_wait_ms: float = 0.0 + engine_policy_snapshot: Dict[str, Any] | None = None + error: str | None = None + + +@dataclass +class EngineGpuPrepareTask: + request_id: str + cpu_stage: PreparedCpuStage + done_loop: asyncio.AbstractEventLoop | None + done_future: asyncio.Future | None + engine_request_id: str | None + enqueue_time: float + phase: str = "audio" + audio_enqueue_time: float = 0.0 + audio_start_time: float = 0.0 + audio_end_time: float = 0.0 + text_enqueue_time: float = 0.0 + text_start_time: float = 0.0 + text_end_time: float = 0.0 + ref_spec_enqueue_time: float = 0.0 + ref_spec_start_time: float = 0.0 + ref_spec_end_time: float = 0.0 + audio_queue_wait_ms: float = 0.0 + text_queue_wait_ms: float = 0.0 + ref_spec_queue_wait_ms: float = 0.0 + admission_wait_ms: float = 0.0 + phase_one: Dict[str, Any] | None = None + ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]] | None = None + state_result: T2SRequestState | None = None + cancelled: bool = False + error: str | None = None + + +@dataclass +class EngineFinalizeQueueState: + waiting_count: int + waiting_request_ids: List[str] + peak_waiting: int + total_submitted: int + total_completed: int + + +@dataclass +class RuntimeStateCallbacks: + update: Callable[[str, str, Optional[Dict[str, Any]]], None] | None = None + complete: Callable[[str, Optional[Dict[str, Any]]], None] | None = None + fail: Callable[[str, str], None] | None = None + decode_runtime_update: Callable[[Dict[str, Any]], None] | None = None diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_components.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_components.py new file mode 100644 index 00000000..ac1adac5 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_components.py @@ -0,0 +1,63 @@ +from GPT_SoVITS.TTS_infer_pack.unified_engine_component_models import ( + DirectTTSExecution, + NormalizedEngineRequest, + RuntimeControlCallbacks, + SchedulerDebugExecution, + SchedulerSubmitExecution, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_component_policy import ( + EngineArbiterConfig, + EngineArbiterState, + EnginePolicyArbiterController, + EnginePolicyConfig, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_component_registry import ( + DefaultReferenceState, + EngineRequestRegistry, + EngineRequestState, + EngineStatus, + ModelRegistry, + ModelRegistryState, + ReferenceRegistry, + SchedulerJobRegistry, + SchedulerPendingJob, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_component_runtime import ( + EngineDecodeRuntimeOwner, + EngineDecodeRuntimeState, + EngineDispatchTask, + EngineFinalizeQueueState, + EngineGpuPrepareTask, + EngineTaskQueueOwner, + RuntimeStateCallbacks, + SchedulerFinalizeTask, +) + +__all__ = [ + "DefaultReferenceState", + "DirectTTSExecution", + "EngineArbiterConfig", + "EngineArbiterState", + "EngineDecodeRuntimeOwner", + "EngineDecodeRuntimeState", + "EngineDispatchTask", + "EngineFinalizeQueueState", + "EngineGpuPrepareTask", + "EnginePolicyArbiterController", + "EnginePolicyConfig", + "EngineRequestRegistry", + "EngineRequestState", + "EngineStatus", + "EngineTaskQueueOwner", + "ModelRegistry", + "ModelRegistryState", + "NormalizedEngineRequest", + "ReferenceRegistry", + "RuntimeControlCallbacks", + "RuntimeStateCallbacks", + "SchedulerDebugExecution", + "SchedulerFinalizeTask", + "SchedulerJobRegistry", + "SchedulerPendingJob", + "SchedulerSubmitExecution", +] diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py new file mode 100644 index 00000000..d60a3bb8 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py @@ -0,0 +1,9 @@ +from GPT_SoVITS.TTS_infer_pack.unified_engine_api_delegates import EngineApiDelegates +from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge_delegates import EngineBridgeDelegates +from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime_delegates import EngineRuntimeDelegates + +__all__ = [ + "EngineApiDelegates", + "EngineBridgeDelegates", + "EngineRuntimeDelegates", +] diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py new file mode 100644 index 00000000..0c73616f --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +from typing import Any, Callable, Dict + +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDecodeRuntimeOwner, EngineTaskQueueOwner +from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_executor import EngineStageExecutor +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker + + +class EngineStageOrchestrator: + def __init__( + self, + *, + executor: EngineStageExecutor, + scheduler_worker: UnifiedSchedulerWorker, + prepare_queue_owner: EngineTaskQueueOwner, + prepare_text_queue_owner: EngineTaskQueueOwner, + prepare_ref_spec_queue_owner: EngineTaskQueueOwner, + finalize_queue_owner: EngineTaskQueueOwner, + dispatch_queue_owner: EngineTaskQueueOwner, + decode_runtime_owner: EngineDecodeRuntimeOwner, + snapshot_engine_decode_runtime_state: Callable[[], Dict[str, Any]], + ) -> None: + self.executor = executor + self.scheduler_worker = scheduler_worker + self.prepare_queue_owner = prepare_queue_owner + self.prepare_text_queue_owner = prepare_text_queue_owner + self.prepare_ref_spec_queue_owner = prepare_ref_spec_queue_owner + self.finalize_queue_owner = finalize_queue_owner + self.dispatch_queue_owner = dispatch_queue_owner + self.decode_runtime_owner = decode_runtime_owner + self.snapshot_engine_decode_runtime_state = snapshot_engine_decode_runtime_state + self._select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]] | None = None + self._mark_arbiter_tick: Callable[[str, str, bool], None] | None = None + self._wait_arbiter: Callable[[], None] | None = None + + def bind_arbiter( + self, + *, + notify_arbiter: Callable[[], None], + select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]], + mark_arbiter_tick: Callable[[str, str, bool], None], + wait_arbiter: Callable[[], None], + ) -> None: + self.executor.bind_notify_arbiter(notify_arbiter) + self._select_stage = select_stage + self._mark_arbiter_tick = mark_arbiter_tick + self._wait_arbiter = wait_arbiter + + def peek_queue_age_ms(self, queue_name: str) -> float: + if queue_name == "prepare": + return max( + self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time"), + self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time"), + self.prepare_ref_spec_queue_owner.peek_oldest_age_ms("enqueue_time"), + ) + if queue_name == "prepare_audio": + return self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time") + if queue_name == "prepare_text": + return self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time") + if queue_name == "prepare_ref_spec": + return self.prepare_ref_spec_queue_owner.peek_oldest_age_ms("enqueue_time") + if queue_name == "finalize": + return self.finalize_queue_owner.peek_oldest_age_ms("enqueued_time") + if queue_name == "decode_runtime_pending": + return self.decode_runtime_owner.pending_age_ms() + return self.dispatch_queue_owner.peek_oldest_age_ms("enqueue_time") + + def has_pending_work(self) -> bool: + if self.scheduler_worker.is_engine_decode_control_enabled(): + if self.decode_runtime_owner.has_pending_jobs(): + return True + if self.scheduler_worker.is_engine_decode_control_enabled() and self.snapshot_engine_decode_runtime_state().get( + "active_request_count", 0 + ) > 0: + return True + if self.prepare_queue_owner.has_items(): + return True + if self.prepare_text_queue_owner.has_items(): + return True + if self.prepare_ref_spec_queue_owner.has_items(): + return True + if self.finalize_queue_owner.has_items(): + return True + return self.dispatch_queue_owner.has_items() + + def run_engine_arbiter_loop(self) -> None: + if self._select_stage is None or self._mark_arbiter_tick is None or self._wait_arbiter is None: + raise RuntimeError("arbiter callbacks are not bound") + while True: + if not self.has_pending_work(): + self._mark_arbiter_tick("idle", "no_pending_work", True) + self._wait_arbiter() + continue + stage, reason, policy_snapshot, worker_state = self._select_stage() + policy_allowed = bool(policy_snapshot.get("allowed", True)) + executed = False + if stage == "prepare": + executed = self.executor.run_engine_prepare_once() + elif stage == "prepare_audio": + executed = self.executor.run_engine_prepare_audio_once() + elif stage == "prepare_text": + executed = self.executor.run_engine_prepare_text_once() + elif stage == "prepare_ref_spec": + executed = self.executor.run_engine_prepare_ref_spec_once() + elif stage == "finalize": + executed = self.executor.run_engine_finalize_once() + elif stage == "decode_dispatch": + executed = self.executor.run_engine_dispatch_once(policy_snapshot, worker_state) + elif stage == "decode_runtime": + executed = self.executor.run_engine_decode_runtime_once() + if not executed: + self._mark_arbiter_tick("idle", f"{stage}_not_ready", policy_allowed) + self._wait_arbiter() + continue + self._mark_arbiter_tick(stage, reason, policy_allowed) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_public.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_public.py new file mode 100644 index 00000000..fbe88b85 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_public.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, SchedulerDebugExecution, SchedulerSubmitExecution + + +class EnginePublicInterface: + PUBLIC_API_METHODS = ( + "run_direct_tts_async", + "run_scheduler_submit", + "run_scheduler_debug", + "get_runtime_state", + "set_refer_audio", + "set_gpt_weights", + "set_sovits_weights", + "handle_control", + ) + + async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution: + return await self.api_facade.run_direct_tts_async(req) + + async def run_scheduler_debug(self, request_items: list[dict], max_steps: int, seed: int) -> SchedulerDebugExecution: + return await self.api_facade.run_scheduler_debug(request_items, max_steps, seed) + + async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution: + return await self.api_facade.run_scheduler_submit(payload) + + def get_runtime_state(self) -> dict: + return self.runtime_facade.get_runtime_state() + + def set_refer_audio(self, refer_audio_path: str | None) -> dict: + return self.runtime_facade.set_refer_audio(refer_audio_path) + + def set_gpt_weights(self, weights_path: str) -> dict: + return self.runtime_facade.set_gpt_weights(weights_path) + + def set_sovits_weights(self, weights_path: str) -> dict: + return self.runtime_facade.set_sovits_weights(weights_path) + + def handle_control(self, command: str) -> None: + self.runtime_facade.handle_control(command) + + +class EngineCompatInterface: + COMPAT_API_METHODS = ( + "run_direct_tts", + "get_scheduler_state", + ) + + def run_direct_tts(self, req: dict) -> DirectTTSExecution: + return self.api_facade.run_direct_tts(req) + + def get_scheduler_state(self) -> dict: + return self.runtime_facade.get_scheduler_state() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_runtime.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_runtime.py new file mode 100644 index 00000000..70212a4d --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_runtime.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +import os +import signal +import sys +from typing import Any, Dict, Optional + + +class EngineRuntimeFacade: + def __init__(self, owner: Any) -> None: + self.owner = owner + + @property + def tts(self): + return self.owner.tts + + @property + def reference_registry(self): + return self.owner.reference_registry + + @property + def model_registry(self): + return self.owner.model_registry + + @property + def scheduler_worker(self): + return self.owner.scheduler_worker + + @property + def engine_decode_runtime_owner(self): + return self.owner.engine_decode_runtime_owner + + @property + def engine_policy_arbiter(self): + return self.owner.engine_policy_arbiter + + @property + def management_lock(self): + return self.owner.management_lock + + @property + def direct_tts_lock(self): + return self.owner.direct_tts_lock + + @property + def control_callbacks(self): + return self.owner.control_callbacks + + @staticmethod + def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None: + if component is None or not hasattr(component, "snapshot"): + return None + try: + return dict(component.snapshot()) + except Exception: + return None + + def _build_stage_counters( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.engine_policy_arbiter.build_stage_counters(request_registry, worker_state) + + def _build_engine_policy_snapshot( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.engine_policy_arbiter.build_policy_snapshot(request_registry, worker_state) + + def _build_stage_summary( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + counters = self._build_stage_counters(request_registry, worker_state) + bert_worker_state = self._safe_component_snapshot(getattr(self.tts, "prepare_bert_batch_worker", None)) + ref_semantic_worker_state = self._safe_component_snapshot(getattr(self.tts, "prepare_ref_semantic_batch_worker", None)) + text_preprocessor_state = self._safe_component_snapshot(getattr(self.tts, "text_preprocessor", None)) + + return { + **counters, + "engine_drained": bool(self.owner._is_engine_drained()), + "admission_config": { + "decode_backlog_max": int(worker_state.get("decode_backlog_max", 0)), + "finalize_pending_max": int(worker_state.get("finalize_pending_max", 0)), + }, + "engine_policy": self._build_engine_policy_snapshot(request_registry, worker_state), + "engine_arbiter_state": self.owner._snapshot_engine_arbiter_state(), + "engine_decode_runtime_state": self.owner._snapshot_engine_decode_runtime_state(), + "engine_job_registry": self.owner._snapshot_engine_job_registry(), + "engine_active_batch_state": self.engine_decode_runtime_owner.active_batch_summary(), + "engine_prepare_state": self.owner._snapshot_engine_prepare_state(), + "engine_finalize_state": self.owner._snapshot_engine_finalize_state(), + "engine_dispatcher_state": self.owner._snapshot_engine_dispatch_state(), + "active_batch": dict(worker_state.get("active_batch") or {}), + "prepare_state": dict(worker_state.get("prepare_state") or {}), + "bert_batch_worker_state": bert_worker_state, + "ref_semantic_worker_state": ref_semantic_worker_state, + "text_preprocessor_state": text_preprocessor_state, + } + + def get_scheduler_state(self) -> dict: + return self.scheduler_worker.snapshot() + + def get_runtime_state(self) -> dict: + model_state = self.model_registry.snapshot() + default_ref = self.reference_registry.get_default() + scheduler_state = self.get_scheduler_state() + request_registry = self.owner._snapshot_request_registry() + engine_policy = self._build_engine_policy_snapshot(request_registry, scheduler_state) + engine_arbiter_state = self.owner._snapshot_engine_arbiter_state() + engine_decode_runtime_state = self.owner._snapshot_engine_decode_runtime_state() + engine_job_registry = self.owner._snapshot_engine_job_registry() + engine_prepare_state = self.owner._snapshot_engine_prepare_state() + engine_finalize_state = self.owner._snapshot_engine_finalize_state() + engine_dispatcher_state = self.owner._snapshot_engine_dispatch_state() + engine_drained = self.owner._is_engine_drained() + return { + "message": "success", + "default_reference": { + "ref_audio_path": default_ref.ref_audio_path, + "updated_at": default_ref.updated_at, + }, + "model_registry": { + "generation": model_state.generation, + "t2s_generation": model_state.t2s_generation, + "vits_generation": model_state.vits_generation, + "t2s_weights_path": model_state.t2s_weights_path, + "vits_weights_path": model_state.vits_weights_path, + "updated_at": model_state.updated_at, + }, + "worker_state": scheduler_state, + "engine_policy": engine_policy, + "engine_arbiter_state": engine_arbiter_state, + "engine_decode_runtime_state": engine_decode_runtime_state, + "engine_job_registry": engine_job_registry, + "engine_active_batch_state": self.engine_decode_runtime_owner.active_batch_summary(), + "engine_prepare_state": engine_prepare_state, + "engine_finalize_state": engine_finalize_state, + "engine_dispatcher_state": engine_dispatcher_state, + "engine_drained": bool(engine_drained), + "request_registry": request_registry, + "stage_summary": self._build_stage_summary(request_registry, scheduler_state), + } + + def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None: + if not self.scheduler_worker.wait_until_idle(timeout_sec=timeout_sec): + raise TimeoutError("scheduler worker did not drain before model reload") + + def set_refer_audio(self, refer_audio_path: str | None) -> dict: + if refer_audio_path in [None, ""]: + state = self.reference_registry.clear() + return {"message": "success", "default_ref_audio_path": state.ref_audio_path} + if not os.path.exists(str(refer_audio_path)): + raise FileNotFoundError(f"{refer_audio_path} not exists") + with self.management_lock: + with self.direct_tts_lock: + self.tts.set_ref_audio(str(refer_audio_path)) + state = self.reference_registry.set_default(str(refer_audio_path)) + return {"message": "success", "default_ref_audio_path": state.ref_audio_path} + + def set_gpt_weights(self, weights_path: str) -> dict: + if weights_path in ["", None]: + raise ValueError("gpt weight path is required") + with self.management_lock: + self._wait_for_safe_reload() + with self.direct_tts_lock: + self.tts.init_t2s_weights(weights_path) + self.tts.refresh_runtime_components() + state = self.model_registry.mark_t2s_reload(str(weights_path)) + return {"message": "success", "t2s_generation": state.t2s_generation, "generation": state.generation} + + def set_sovits_weights(self, weights_path: str) -> dict: + if weights_path in ["", None]: + raise ValueError("sovits weight path is required") + with self.management_lock: + self._wait_for_safe_reload() + with self.direct_tts_lock: + self.tts.init_vits_weights(weights_path) + self.tts.refresh_runtime_components() + state = self.model_registry.mark_vits_reload(str(weights_path)) + return {"message": "success", "vits_generation": state.vits_generation, "generation": state.generation} + + def handle_control(self, command: str) -> None: + if command == "restart": + if self.control_callbacks.restart is None: + os.execl(sys.executable, sys.executable, *sys.argv) + self.control_callbacks.restart() + return + if command == "exit": + if self.control_callbacks.exit is None: + os.kill(os.getpid(), signal.SIGTERM) + return + self.control_callbacks.exit() + return + raise ValueError(f"unsupported command: {command}") diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_runtime_delegates.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_runtime_delegates.py new file mode 100644 index 00000000..96153196 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_runtime_delegates.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import Any, Dict + +from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime import EngineRuntimeFacade + + +class EngineRuntimeDelegates: + @staticmethod + def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None: + return EngineRuntimeFacade._safe_component_snapshot(component) + + def _build_stage_counters( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.runtime_facade._build_stage_counters(request_registry, worker_state) + + def _build_engine_policy_snapshot( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.runtime_facade._build_engine_policy_snapshot(request_registry, worker_state) + + async def _wait_for_engine_policy_admission( + self, + *, + request_id: str | None, + timeout_sec: float | None, + ) -> tuple[float, Dict[str, Any]]: + return await self.engine_policy_arbiter.wait_for_policy_admission( + request_id=request_id, + timeout_sec=timeout_sec, + ) + + def _build_stage_summary( + self, + request_registry: Dict[str, Any], + worker_state: Dict[str, Any], + ) -> Dict[str, Any]: + return self.runtime_facade._build_stage_summary(request_registry, worker_state) + + def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None: + self.runtime_facade._wait_for_safe_reload(timeout_sec=timeout_sec) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py new file mode 100644 index 00000000..27ed3bf5 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +import asyncio +from typing import Callable, Dict, List, Optional + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import ( + EngineDecodeRuntimeOwner, + EngineDispatchTask, + EngineTaskQueueOwner, + SchedulerFinalizeTask, + SchedulerPendingJob, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_orchestration import EngineStageOrchestrator +from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_executor import EngineStageExecutor +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker + + +class EngineStageCoordinator: + def __init__( + self, + *, + tts: TTS, + scheduler_worker: UnifiedSchedulerWorker, + prepare_queue_owner: EngineTaskQueueOwner, + prepare_text_queue_owner: EngineTaskQueueOwner, + prepare_ref_spec_queue_owner: EngineTaskQueueOwner, + finalize_queue_owner: EngineTaskQueueOwner, + dispatch_queue_owner: EngineTaskQueueOwner, + decode_runtime_owner: EngineDecodeRuntimeOwner, + update_request_state: Callable[[str, str, Optional[Dict[str, Any]]], None], + merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None], + fail_request_state: Callable[[str, str], None], + get_engine_job: Callable[[str], SchedulerPendingJob | None], + register_engine_job: Callable[[SchedulerPendingJob], None], + fail_engine_jobs: Callable[[List[str], str], None], + complete_engine_job: Callable[..., None], + add_engine_prefill_time: Callable[[List[SchedulerPendingJob], float], None], + add_engine_merge_time: Callable[[List[str], float], None], + add_engine_decode_time: Callable[[List[str], float], None], + enqueue_engine_finished_items: Callable[[List[T2SFinishedItem]], None], + snapshot_engine_dispatch_state: Callable[[], Dict[str, Any]], + snapshot_engine_decode_runtime_state: Callable[[], Dict[str, Any]], + ) -> None: + self.executor = EngineStageExecutor( + tts=tts, + scheduler_worker=scheduler_worker, + prepare_queue_owner=prepare_queue_owner, + prepare_text_queue_owner=prepare_text_queue_owner, + prepare_ref_spec_queue_owner=prepare_ref_spec_queue_owner, + finalize_queue_owner=finalize_queue_owner, + dispatch_queue_owner=dispatch_queue_owner, + decode_runtime_owner=decode_runtime_owner, + update_request_state=update_request_state, + merge_request_state_profile=merge_request_state_profile, + fail_request_state=fail_request_state, + get_engine_job=get_engine_job, + register_engine_job=register_engine_job, + fail_engine_jobs=fail_engine_jobs, + complete_engine_job=complete_engine_job, + add_engine_prefill_time=add_engine_prefill_time, + add_engine_merge_time=add_engine_merge_time, + add_engine_decode_time=add_engine_decode_time, + enqueue_engine_finished_items=enqueue_engine_finished_items, + snapshot_engine_dispatch_state=snapshot_engine_dispatch_state, + snapshot_engine_decode_runtime_state=snapshot_engine_decode_runtime_state, + ) + self.orchestrator = EngineStageOrchestrator( + executor=self.executor, + scheduler_worker=scheduler_worker, + prepare_queue_owner=prepare_queue_owner, + prepare_text_queue_owner=prepare_text_queue_owner, + prepare_ref_spec_queue_owner=prepare_ref_spec_queue_owner, + finalize_queue_owner=finalize_queue_owner, + dispatch_queue_owner=dispatch_queue_owner, + decode_runtime_owner=decode_runtime_owner, + snapshot_engine_decode_runtime_state=snapshot_engine_decode_runtime_state, + ) + + def bind_arbiter( + self, + *, + notify_arbiter: Callable[[], None], + select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]], + mark_arbiter_tick: Callable[[str, str, bool], None], + wait_arbiter: Callable[[], None], + ) -> None: + self.orchestrator.bind_arbiter( + notify_arbiter=notify_arbiter, + select_stage=select_stage, + mark_arbiter_tick=mark_arbiter_tick, + wait_arbiter=wait_arbiter, + ) + + async def prepare_state_via_engine_gpu_queue( + self, + *, + spec, + prepare_submit_at: float, + engine_request_id: str | None, + ) -> tuple[T2SRequestState, float, float]: + return await self.executor.prepare_state_via_engine_gpu_queue( + spec=spec, + prepare_submit_at=prepare_submit_at, + engine_request_id=engine_request_id, + ) + + def enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: + self.executor.enqueue_worker_finished_for_finalize(tasks) + + def take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: + return self.executor.take_engine_finalize_batch_nonblocking() + + async def enqueue_prepared_state_for_dispatch( + self, + *, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + super_sampling: bool, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None, + done_future: asyncio.Future | None, + engine_request_id: str | None, + timeout_sec: float | None, + ) -> EngineDispatchTask: + return await self.executor.enqueue_prepared_state_for_dispatch( + state=state, + speed_factor=speed_factor, + sample_steps=sample_steps, + media_type=media_type, + super_sampling=super_sampling, + prepare_wall_ms=prepare_wall_ms, + prepare_profile_total_ms=prepare_profile_total_ms, + done_loop=done_loop, + done_future=done_future, + engine_request_id=engine_request_id, + timeout_sec=timeout_sec, + ) + + def peek_queue_age_ms(self, queue_name: str) -> float: + return self.orchestrator.peek_queue_age_ms(queue_name) + + def has_pending_work(self) -> bool: + return self.orchestrator.has_pending_work() + + def run_engine_prepare_once(self) -> bool: + return self.executor.run_engine_prepare_once() + + def run_engine_prepare_audio_once(self) -> bool: + return self.executor.run_engine_prepare_audio_once() + + def run_engine_prepare_text_once(self) -> bool: + return self.executor.run_engine_prepare_text_once() + + def run_engine_prepare_ref_spec_once(self) -> bool: + return self.executor.run_engine_prepare_ref_spec_once() + + def run_engine_finalize_once(self) -> bool: + return self.executor.run_engine_finalize_once() + + def run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool: + return self.executor.run_engine_dispatch_once(policy_snapshot, worker_state) + + def run_engine_decode_runtime_once(self) -> bool: + return self.executor.run_engine_decode_runtime_once() + + def run_engine_arbiter_loop(self) -> None: + self.orchestrator.run_engine_arbiter_loop() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_decode.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_decode.py new file mode 100644 index 00000000..d3a7a8cf --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_decode.py @@ -0,0 +1,40 @@ +from __future__ import annotations + + +class EngineDecodeStageMixin: + def run_engine_decode_runtime_once(self) -> bool: + if not self.scheduler_worker.is_engine_decode_control_enabled(): + return False + runtime_state = self.snapshot_engine_decode_runtime_state() + pending_jobs = self.decode_runtime_owner.take_pending_jobs_nonblocking( + wait_for_batch=int(runtime_state.get("active_request_count", 0)) <= 0 + ) + result = self.scheduler_worker.execute_decode_cycle( + pending_jobs=pending_jobs, + active_batch=self.decode_runtime_owner.get_active_batch(), + external_bookkeeping=True, + ) + prefill_phase = dict(result.get("prefill_phase") or {}) + if prefill_phase.get("error"): + self.fail_engine_jobs(list(prefill_phase.get("error_request_ids") or []), str(prefill_phase.get("error"))) + else: + prefill_jobs = list(prefill_phase.get("pending_jobs") or []) + self.add_engine_prefill_time(prefill_jobs, float(prefill_phase.get("prefill_elapsed_s", 0.0))) + self.add_engine_merge_time( + [] if result.get("active_batch") is None else list(result["active_batch"].request_ids), + float(prefill_phase.get("merge_elapsed_s", 0.0)), + ) + self.enqueue_engine_finished_items(list(prefill_phase.get("finished_items") or [])) + decode_phase = dict(result.get("decode_phase") or {}) + if decode_phase.get("error"): + self.fail_engine_jobs(list(decode_phase.get("error_request_ids") or []), str(decode_phase.get("error"))) + else: + self.add_engine_decode_time( + list(decode_phase.get("request_ids") or []), + float(decode_phase.get("decode_elapsed_s", 0.0)), + ) + self.enqueue_engine_finished_items(list(decode_phase.get("finished_items") or [])) + self.decode_runtime_owner.set_active_batch(result.get("active_batch")) + if result.get("executed", False): + self.decode_runtime_owner.refresh_state("engine_decode_cycle") + return bool(result.get("executed", False)) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py new file mode 100644 index 00000000..f6a249fa --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import asyncio +import time +from typing import Dict + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDispatchTask + + +class EngineDispatchStageMixin: + async def enqueue_prepared_state_for_dispatch( + self, + *, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + super_sampling: bool, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None, + done_future: asyncio.Future | None, + engine_request_id: str | None, + timeout_sec: float | None, + ) -> EngineDispatchTask: + if float(state.prepare_profile.get("ref_spec_async_failed", 0.0) or 0.0) > 0.0: + error = RuntimeError("ref_spec async stage failed before dispatch") + self.fail_request_state(engine_request_id or state.request_id, str(error)) + raise error + task = EngineDispatchTask( + request_id=state.request_id, + state=state, + speed_factor=float(speed_factor), + sample_steps=int(sample_steps), + media_type=media_type, + super_sampling=bool(super_sampling), + prepare_wall_ms=float(prepare_wall_ms), + prepare_profile_total_ms=float(prepare_profile_total_ms), + done_loop=done_loop, + done_future=done_future, + engine_request_id=engine_request_id or state.request_id, + timeout_sec=timeout_sec, + enqueue_time=time.perf_counter(), + ) + self.dispatch_queue_owner.enqueue(task) + self.notify_arbiter() + self.merge_request_state_profile( + task.engine_request_id or task.request_id, + { + "engine_dispatch_queue_depth_on_enqueue": int( + self.snapshot_engine_dispatch_state()["waiting_count"] + ), + }, + ) + return task + + def run_engine_dispatch_once(self, policy_snapshot: Dict[str, object], worker_state: Dict[str, object]) -> bool: + if not bool(policy_snapshot.get("allowed", True)): + return False + dispatch_task = self.dispatch_queue_owner.pop_left() + if dispatch_task is None: + return False + dispatched_at = time.perf_counter() + dispatch_wait_ms = max(0.0, (dispatched_at - dispatch_task.enqueue_time) * 1000.0) + dispatch_task.engine_policy_wait_ms = float(dispatch_wait_ms) + dispatch_task.engine_dispatch_wait_ms = float(dispatch_wait_ms) + dispatch_task.engine_policy_snapshot = dict(policy_snapshot) + try: + worker_job = self.scheduler_worker.submit( + state=dispatch_task.state, + speed_factor=dispatch_task.speed_factor, + sample_steps=dispatch_task.sample_steps, + media_type=dispatch_task.media_type, + super_sampling=dispatch_task.super_sampling, + prepare_wall_ms=dispatch_task.prepare_wall_ms, + prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms, + done_loop=dispatch_task.done_loop, + done_future=dispatch_task.done_future, + engine_request_id=dispatch_task.engine_request_id, + timeout_sec=dispatch_task.timeout_sec, + skip_capacity_wait=True, + admission_wait_ms_override=0.0, + admission_snapshot_override=dict(worker_state), + engine_policy_wait_ms=dispatch_task.engine_policy_wait_ms, + engine_dispatch_wait_ms=dispatch_task.engine_dispatch_wait_ms, + enqueue_pending=not self.scheduler_worker.is_engine_decode_control_enabled(), + ) + dispatch_task.worker_job = worker_job + self.register_engine_job(worker_job) + if self.scheduler_worker.is_engine_decode_control_enabled(): + self.decode_runtime_owner.enqueue_pending_job(worker_job) + self.notify_arbiter() + self.dispatch_queue_owner.mark_completed(1) + return True + except Exception as exc: + dispatch_task.error = str(exc) + self.fail_request_state(dispatch_task.engine_request_id or dispatch_task.request_id, str(exc)) + self._notify_dispatch_error(dispatch_task, exc) + return True diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py new file mode 100644 index 00000000..6d06f0c6 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from typing import Any, Callable, Dict, List, Optional + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import ( + EngineDecodeRuntimeOwner, + EngineTaskQueueOwner, + SchedulerFinalizeTask, + SchedulerPendingJob, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_decode import EngineDecodeStageMixin +from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_dispatch import EngineDispatchStageMixin +from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_finalize import EngineFinalizeStageMixin +from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_futures import EngineStageFutureMixin +from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_prepare import EnginePrepareStageMixin +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker + + +class EngineStageExecutor( + EngineStageFutureMixin, + EnginePrepareStageMixin, + EngineFinalizeStageMixin, + EngineDispatchStageMixin, + EngineDecodeStageMixin, +): + def __init__( + self, + *, + tts: TTS, + scheduler_worker: UnifiedSchedulerWorker, + prepare_queue_owner: EngineTaskQueueOwner, + prepare_text_queue_owner: EngineTaskQueueOwner, + prepare_ref_spec_queue_owner: EngineTaskQueueOwner, + finalize_queue_owner: EngineTaskQueueOwner, + dispatch_queue_owner: EngineTaskQueueOwner, + decode_runtime_owner: EngineDecodeRuntimeOwner, + update_request_state: Callable[[str, str, Optional[Dict[str, Any]]], None], + merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None], + fail_request_state: Callable[[str, str], None], + get_engine_job: Callable[[str], SchedulerPendingJob | None], + register_engine_job: Callable[[SchedulerPendingJob], None], + fail_engine_jobs: Callable[[List[str], str], None], + complete_engine_job: Callable[..., None], + add_engine_prefill_time: Callable[[List[SchedulerPendingJob], float], None], + add_engine_merge_time: Callable[[List[str], float], None], + add_engine_decode_time: Callable[[List[str], float], None], + enqueue_engine_finished_items: Callable[[List[T2SFinishedItem]], None], + snapshot_engine_dispatch_state: Callable[[], Dict[str, Any]], + snapshot_engine_decode_runtime_state: Callable[[], Dict[str, Any]], + ) -> None: + self.tts = tts + self.scheduler_worker = scheduler_worker + self.prepare_queue_owner = prepare_queue_owner + self.prepare_text_queue_owner = prepare_text_queue_owner + self.prepare_ref_spec_queue_owner = prepare_ref_spec_queue_owner + self.finalize_queue_owner = finalize_queue_owner + self.dispatch_queue_owner = dispatch_queue_owner + self.decode_runtime_owner = decode_runtime_owner + self.update_request_state = update_request_state + self.merge_request_state_profile = merge_request_state_profile + self.fail_request_state = fail_request_state + self.get_engine_job = get_engine_job + self.register_engine_job = register_engine_job + self.fail_engine_jobs = fail_engine_jobs + self.complete_engine_job = complete_engine_job + self.add_engine_prefill_time = add_engine_prefill_time + self.add_engine_merge_time = add_engine_merge_time + self.add_engine_decode_time = add_engine_decode_time + self.enqueue_engine_finished_items = enqueue_engine_finished_items + self.snapshot_engine_dispatch_state = snapshot_engine_dispatch_state + self.snapshot_engine_decode_runtime_state = snapshot_engine_decode_runtime_state + self._notify_arbiter: Callable[[], None] | None = None diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_finalize.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_finalize.py new file mode 100644 index 00000000..4b61993e --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_finalize.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import time +from typing import List + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob + + +class EngineFinalizeStageMixin: + def enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None: + if not tasks: + return + for task in tasks: + job = self.get_engine_job(task.request_id) + if job is not None: + self.update_request_state( + job.engine_request_id, + EngineStatus.READY_FOR_FINALIZE, + { + "finish_reason": task.item.finish_reason, + "semantic_len": int(task.item.semantic_tokens.shape[0]), + "finish_idx": int(task.item.finish_idx), + }, + ) + self.finalize_queue_owner.enqueue_many(tasks) + self.notify_arbiter() + + def take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]: + finalize_policy = self.scheduler_worker.get_finalize_batch_policy() + return self.finalize_queue_owner.take_finalize_batch( + finalize_mode=str(finalize_policy.get("finalize_mode", "async")), + batch_max_items=int(finalize_policy.get("finalize_batch_max_items", 1)), + batch_wait_s=float(finalize_policy.get("finalize_batch_wait_s", 0.0)), + use_vocoder=bool(self.tts.configs.use_vocoder), + ) + + def run_engine_finalize_once(self) -> bool: + tasks = self.take_engine_finalize_batch_nonblocking() + if not tasks: + return False + ready_tasks: List[SchedulerFinalizeTask] = [] + failed_tasks: List[SchedulerFinalizeTask] = [] + deferred_tasks: List[SchedulerFinalizeTask] = [] + for task in tasks: + job = self.get_engine_job(task.request_id) + if job is None: + continue + if float(job.state.prepare_profile.get("ref_spec_async_failed", 0.0) or 0.0) > 0.0: + failed_tasks.append(task) + continue + if job.state.refer_spec is None: + deferred_tasks.append(task) + self.merge_request_state_profile( + job.engine_request_id or job.request_id, + { + "engine_finalize_ref_spec_blocked": 1.0, + }, + ) + continue + ready_tasks.append(task) + if deferred_tasks: + self.finalize_queue_owner.enqueue_many(deferred_tasks) + if failed_tasks: + self.fail_engine_jobs([task.request_id for task in failed_tasks], "ref_spec async stage failed") + if not ready_tasks: + self.finalize_queue_owner.mark_completed(len(failed_tasks), notify=True) + return False + self.scheduler_worker.begin_finalize_execution(len(ready_tasks)) + try: + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + for task in ready_tasks: + job = self.get_engine_job(task.request_id) + if job is None: + continue + jobs_and_items.append((job, task.item)) + if not jobs_and_items: + return False + now = time.perf_counter() + for task in ready_tasks: + job = self.get_engine_job(task.request_id) + if job is not None: + job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0) + for job, item in jobs_and_items: + self.update_request_state( + job.engine_request_id, + EngineStatus.FINALIZING, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + }, + ) + synth_ms, batch_results = self.scheduler_worker.synthesize_finalize_jobs(jobs_and_items) + for job, _ in jobs_and_items: + job.synth_ms += float(synth_ms) + for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): + self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data) + except Exception as exc: + self.fail_engine_jobs([task.request_id for task in ready_tasks], str(exc)) + finally: + self.scheduler_worker.end_finalize_execution(len(ready_tasks)) + self.finalize_queue_owner.mark_completed(len(ready_tasks) + len(failed_tasks), notify=True) + return True diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_futures.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_futures.py new file mode 100644 index 00000000..43fdd0bf --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_futures.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import asyncio +from typing import Callable + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDispatchTask, EngineGpuPrepareTask + + +class EngineStageFutureMixin: + def bind_notify_arbiter(self, notify_arbiter: Callable[[], None]) -> None: + self._notify_arbiter = notify_arbiter + + def notify_arbiter(self) -> None: + if self._notify_arbiter is not None: + self._notify_arbiter() + + @staticmethod + def _resolve_dispatch_error_future(future: asyncio.Future, error: Exception) -> None: + if future.done(): + return + future.set_exception(error) + + @staticmethod + def _resolve_prepare_future( + future: asyncio.Future, + payload: tuple[T2SRequestState, float, float], + ) -> None: + if future.done(): + return + future.set_result(payload) + + def _notify_dispatch_error(self, task: EngineDispatchTask, error: Exception) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) + except RuntimeError: + pass + + def _notify_prepare_error(self, task: EngineGpuPrepareTask, error: Exception) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error) + except RuntimeError: + pass + + def _notify_prepare_result( + self, + task: EngineGpuPrepareTask, + payload: tuple[T2SRequestState, float, float], + ) -> None: + if task.done_loop is None or task.done_future is None: + return + try: + task.done_loop.call_soon_threadsafe(self._resolve_prepare_future, task.done_future, payload) + except RuntimeError: + pass diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py new file mode 100644 index 00000000..1ea7c45b --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py @@ -0,0 +1,306 @@ +from __future__ import annotations + +import asyncio +import os +import time +from typing import Any + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineGpuPrepareTask, EngineStatus + + +class EnginePrepareStageMixin: + def _prepare_waiting_total(self) -> int: + return ( + int(self.prepare_queue_owner.waiting_count()) + + int(self.prepare_text_queue_owner.waiting_count()) + + int(self.prepare_ref_spec_queue_owner.waiting_count()) + ) + + async def _wait_prepare_queue_admission(self) -> float: + soft_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_SOFT_MAX", "0"))) + if soft_max <= 0: + return 0.0 + poll_s = max( + 0.0005, + float(max(1, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_ADMISSION_POLL_MS", "1")))) / 1000.0, + ) + wait_start = time.perf_counter() + while self._prepare_waiting_total() >= soft_max: + await asyncio.sleep(poll_s) + return max(0.0, (time.perf_counter() - wait_start) * 1000.0) + + async def prepare_state_via_engine_gpu_queue( + self, + *, + spec: Any, + prepare_submit_at: float, + engine_request_id: str | None, + ) -> tuple[T2SRequestState, float, float]: + prepare_queue_admission_wait_ms = await self._wait_prepare_queue_admission() + cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + if engine_request_id not in [None, ""]: + self.update_request_state( + str(engine_request_id), + EngineStatus.GPU_PREPARING, + { + "engine_prepare_queue_admission_wait_ms": float(prepare_queue_admission_wait_ms), + "prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms), + "prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms), + "text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms), + "text_cpu_run_ms": float(cpu_stage.target_cpu_profiled.run_ms), + }, + ) + loop = asyncio.get_running_loop() + done_future = loop.create_future() + task = EngineGpuPrepareTask( + request_id=spec.request_id, + cpu_stage=cpu_stage, + done_loop=loop, + done_future=done_future, + engine_request_id=engine_request_id or spec.request_id, + enqueue_time=time.perf_counter(), + phase="audio", + audio_enqueue_time=time.perf_counter(), + admission_wait_ms=float(prepare_queue_admission_wait_ms), + ) + self.prepare_queue_owner.enqueue(task) + self.notify_arbiter() + return await done_future + + def _should_chain_prepare_text_after_audio(self) -> bool: + if str(os.environ.get("GPTSOVITS_ENGINE_PREPARE_CHAIN_TEXT", "1")).strip().lower() in {"0", "false", "no", "off"}: + return False + if self.finalize_queue_owner.has_items() or self.dispatch_queue_owner.has_items(): + return False + decode_runtime_state = self.snapshot_engine_decode_runtime_state() + if bool(decode_runtime_state.get("has_work", False)): + return False + return True + + def _maybe_apply_ref_spec_to_state(self, task: EngineGpuPrepareTask) -> None: + if task.state_result is None or task.ref_spec_result is None: + return + self.scheduler_worker.apply_ref_spec_result_to_state(task.state_result, task.ref_spec_result) + if task.engine_request_id not in [None, ""]: + self.merge_request_state_profile( + str(task.engine_request_id), + { + "engine_prepare_ref_spec_queue_wait_ms": float(task.ref_spec_queue_wait_ms), + "ref_spec_wait_ms": float(task.ref_spec_result[1].get("ref_spec_wait_ms", 0.0)), + "ref_spec_ms": float(task.ref_spec_result[1].get("ref_spec_ms", 0.0)), + "ref_spec_to_device_ms": float(task.ref_spec_result[1].get("ref_spec_to_device_ms", 0.0)), + "ref_spec_main_resample_ms": float(task.ref_spec_result[1].get("ref_spec_main_resample_ms", 0.0)), + "ref_spec_norm_ms": float(task.ref_spec_result[1].get("ref_spec_norm_ms", 0.0)), + "ref_spec_spectrogram_ms": float(task.ref_spec_result[1].get("ref_spec_spectrogram_ms", 0.0)), + "ref_spec_post_resample_ms": float(task.ref_spec_result[1].get("ref_spec_post_resample_ms", 0.0)), + }, + ) + + def _mark_ref_spec_async_failed( + self, + task: EngineGpuPrepareTask, + error: Exception, + *, + queue_wait_ms: float, + ) -> None: + task.error = str(error) + task.cancelled = True + if task.state_result is not None: + task.state_result.prepare_profile["ref_spec_async_failed"] = 1.0 + task.state_result.prepare_profile["engine_prepare_ref_spec_queue_wait_ms"] = float(queue_wait_ms) + if task.engine_request_id not in [None, ""]: + self.merge_request_state_profile( + str(task.engine_request_id), + { + "ref_spec_async_failed": 1.0, + "engine_prepare_ref_spec_queue_wait_ms": float(queue_wait_ms), + }, + ) + self.fail_request_state(task.engine_request_id or task.request_id, str(error)) + self.fail_engine_jobs([task.request_id], str(error)) + self.notify_arbiter() + + def _run_engine_prepare_audio_once(self, batch_max_items: int) -> bool: + tasks = self.prepare_queue_owner.pop_left_many(batch_max_items) + if not tasks: + return False + now = time.perf_counter() + queue_wait_ms_list = [max(0.0, (now - task.enqueue_time) * 1000.0) for task in tasks] + for task in tasks: + task.audio_start_time = float(now) + batch_results = asyncio.run(self.scheduler_worker.prepare_gpu_audio_phases_async([task.cpu_stage for task in tasks])) + completed_count = 0 + for task, queue_wait_ms, result in zip(tasks, queue_wait_ms_list, batch_results): + task.audio_end_time = time.perf_counter() + if isinstance(result, Exception): + task.error = str(result) + self.fail_request_state(task.engine_request_id or task.request_id, str(result)) + self._notify_prepare_error(task, result) + completed_count += 1 + continue + task.audio_queue_wait_ms = float(queue_wait_ms) + task.phase_one = result + task.phase = "text" + task.enqueue_time = time.perf_counter() + task.text_enqueue_time = float(task.enqueue_time) + task.ref_spec_enqueue_time = float(task.enqueue_time) + self.prepare_text_queue_owner.enqueue(task) + self.prepare_ref_spec_queue_owner.enqueue(task) + if task.engine_request_id not in [None, ""]: + self.merge_request_state_profile( + str(task.engine_request_id), + { + "engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms), + "engine_prepare_audio_queue_wait_ms": float(queue_wait_ms), + "engine_prepare_audio_batch_size": float(len(tasks)), + "engine_prepare_audio_phase_wall_ms": float(result.get("phase_wall_ms", 0.0)), + "engine_prepare_audio_enqueue_ts": float(task.audio_enqueue_time), + "engine_prepare_audio_start_ts": float(task.audio_start_time), + "engine_prepare_audio_end_ts": float(task.audio_end_time), + "engine_prepare_text_enqueue_ts": float(task.text_enqueue_time), + "engine_prepare_ref_spec_enqueue_ts": float(task.ref_spec_enqueue_time), + }, + ) + completed_count += 1 + self.prepare_queue_owner.mark_completed(completed_count) + if completed_count > 0 and self._should_chain_prepare_text_after_audio(): + self._run_engine_prepare_text_once(min(batch_max_items, completed_count)) + return True + if completed_count > 0: + self.notify_arbiter() + return True + + def _run_engine_prepare_text_once(self, batch_max_items: int) -> bool: + tasks = self.prepare_text_queue_owner.pop_left_many(batch_max_items) + if not tasks: + return False + now = time.perf_counter() + queue_wait_ms_list = [max(0.0, (now - task.enqueue_time) * 1000.0) for task in tasks] + for task in tasks: + task.text_start_time = float(now) + items = [(task.cpu_stage, task.phase_one) for task in tasks if task.phase_one is not None] + batch_results = asyncio.run(self.scheduler_worker.prepare_gpu_text_phases_async(items)) + completed_count = 0 + for task, queue_wait_ms, result in zip(tasks, queue_wait_ms_list, batch_results): + task.text_end_time = time.perf_counter() + if isinstance(result, Exception): + task.error = str(result) + task.cancelled = True + self.fail_request_state(task.engine_request_id or task.request_id, str(result)) + self._notify_prepare_error(task, result) + completed_count += 1 + continue + task.text_queue_wait_ms = float(queue_wait_ms) + state, prepare_exec_started_at, prepare_exec_finished_at = self.scheduler_worker.build_gpu_prepare_result_from_phases( + task.cpu_stage, + task.phase_one or {}, + result, + extra_profile={ + "engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms), + "engine_prepare_audio_queue_wait_ms": float(task.audio_queue_wait_ms), + "engine_prepare_text_queue_wait_ms": float(task.text_queue_wait_ms), + "engine_gpu_prepare_queue_wait_ms": float(task.audio_queue_wait_ms + task.text_queue_wait_ms), + "engine_prepare_audio_batch_size": float(len(tasks)), + "engine_prepare_text_batch_size": float(len(tasks)), + "engine_prepare_audio_phase_mode": 2.0, + "engine_prepare_audio_phase_wall_ms": float((task.phase_one or {}).get("phase_wall_ms", 0.0)), + "engine_prepare_text_phase_wall_ms": float(result.get("phase_wall_ms", 0.0)), + "engine_prepare_text_phase_batch_size": float(len(tasks)), + "engine_prepare_audio_enqueue_ts": float(task.audio_enqueue_time), + "engine_prepare_audio_start_ts": float(task.audio_start_time), + "engine_prepare_audio_end_ts": float(task.audio_end_time), + "engine_prepare_text_enqueue_ts": float(task.text_enqueue_time), + "engine_prepare_text_start_ts": float(task.text_start_time), + "engine_prepare_text_end_ts": float(task.text_end_time), + "engine_prepare_ref_spec_enqueue_ts": float(task.ref_spec_enqueue_time), + }, + ) + task.state_result = state + self._maybe_apply_ref_spec_to_state(task) + state.prepare_profile["engine_gpu_prepare_batch_size"] = float(len(tasks)) + if task.engine_request_id not in [None, ""]: + self.merge_request_state_profile( + str(task.engine_request_id), + { + "engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms), + "engine_prepare_audio_queue_wait_ms": float(task.audio_queue_wait_ms), + "engine_prepare_text_queue_wait_ms": float(task.text_queue_wait_ms), + "engine_gpu_prepare_queue_wait_ms": float(task.audio_queue_wait_ms + task.text_queue_wait_ms), + "engine_gpu_prepare_batch_size": float(len(tasks)), + }, + ) + self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at)) + completed_count += 1 + self.prepare_text_queue_owner.mark_completed(completed_count) + return True + + def _run_engine_prepare_ref_spec_once(self, batch_max_items: int) -> bool: + tasks = self.prepare_ref_spec_queue_owner.pop_left_many(batch_max_items) + if not tasks: + return False + now = time.perf_counter() + runnable_tasks: list[EngineGpuPrepareTask] = [] + queue_wait_ms_list: list[float] = [] + completed_count = 0 + for task in tasks: + if task.cancelled or task.phase_one is None: + completed_count += 1 + continue + task.ref_spec_start_time = float(now) + runnable_tasks.append(task) + queue_wait_ms_list.append(max(0.0, (now - task.ref_spec_enqueue_time) * 1000.0)) + if not runnable_tasks: + self.prepare_ref_spec_queue_owner.mark_completed(completed_count) + return True + batch_results = asyncio.run( + self.scheduler_worker.prepare_ref_spec_stages_async([task.phase_one or {} for task in runnable_tasks]) + ) + for task, queue_wait_ms, result in zip(runnable_tasks, queue_wait_ms_list, batch_results): + task.ref_spec_end_time = time.perf_counter() + task.ref_spec_queue_wait_ms = float(queue_wait_ms) + if isinstance(result, Exception): + self._mark_ref_spec_async_failed(task, result, queue_wait_ms=float(queue_wait_ms)) + completed_count += 1 + continue + task.ref_spec_result = result + self._maybe_apply_ref_spec_to_state(task) + if task.state_result is not None: + task.state_result.prepare_profile["engine_prepare_ref_spec_queue_wait_ms"] = float(queue_wait_ms) + task.state_result.prepare_profile["engine_prepare_ref_spec_enqueue_ts"] = float(task.ref_spec_enqueue_time) + task.state_result.prepare_profile["engine_prepare_ref_spec_start_ts"] = float(task.ref_spec_start_time) + task.state_result.prepare_profile["engine_prepare_ref_spec_end_ts"] = float(task.ref_spec_end_time) + completed_count += 1 + self.prepare_ref_spec_queue_owner.mark_completed(completed_count) + return True + + def run_engine_prepare_once(self) -> bool: + prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy() + batch_max_items = int(prepare_batch_policy.get("prepare_batch_max_items", 1)) + audio_age_ms = self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time") + text_age_ms = self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time") + if self.prepare_text_queue_owner.has_items() and ( + not self.prepare_queue_owner.has_items() or text_age_ms >= audio_age_ms + ): + return self._run_engine_prepare_text_once(batch_max_items) + if self.prepare_queue_owner.has_items(): + return self._run_engine_prepare_audio_once(batch_max_items) + if self.prepare_ref_spec_queue_owner.has_items(): + return self._run_engine_prepare_ref_spec_once(batch_max_items) + if self.prepare_text_queue_owner.has_items(): + return self._run_engine_prepare_text_once(batch_max_items) + if self.prepare_ref_spec_queue_owner.has_items(): + return self._run_engine_prepare_ref_spec_once(batch_max_items) + return False + + def run_engine_prepare_audio_once(self) -> bool: + prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy() + return self._run_engine_prepare_audio_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1))) + + def run_engine_prepare_text_once(self) -> bool: + prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy() + return self._run_engine_prepare_text_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1))) + + def run_engine_prepare_ref_spec_once(self) -> bool: + prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy() + return self._run_engine_prepare_ref_spec_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1))) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py new file mode 100644 index 00000000..ae46536f --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import os +import threading +from typing import Callable, List + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeStateCallbacks, SchedulerFinalizeTask, SchedulerJobRegistry +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_completion import WorkerCompletionBridge +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_decode import WorkerDecodeExecutor, WorkerDecodeLegacyShell, WorkerDecodeRuntimeTracker +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_execution import WorkerExecutionMixin +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_finalize import WorkerFinalizeExecutor +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_prepare import WorkerPrepareExecutor +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_runtime import WorkerRuntimeBookkeepingMixin +from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_submit import WorkerSubmitLifecycleMixin + + +class UnifiedSchedulerWorker( + WorkerSubmitLifecycleMixin, + WorkerRuntimeBookkeepingMixin, + WorkerExecutionMixin, +): + def __init__( + self, + tts: TTS, + max_steps: int = 1500, + micro_batch_wait_ms: int = 5, + runtime_callbacks: RuntimeStateCallbacks | None = None, + external_finalize_submit: Callable[[List[SchedulerFinalizeTask]], None] | None = None, + ): + self.tts = tts + self.max_steps = int(max_steps) + self.micro_batch_wait_s = float(micro_batch_wait_ms) / 1000.0 + self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() + self.condition = threading.Condition() + self.completion_bridge = WorkerCompletionBridge(self.runtime_callbacks) + self.decode_executor = WorkerDecodeExecutor(tts, max_steps=max_steps) + self.decode_legacy_shell = WorkerDecodeLegacyShell(self.condition, self.micro_batch_wait_s) + self.decode_runtime_tracker = WorkerDecodeRuntimeTracker(self.runtime_callbacks) + self.prepare_executor = WorkerPrepareExecutor(tts, on_state_change=self._notify_worker_state_change) + self.finalize_executor = WorkerFinalizeExecutor( + tts, + on_state_change=self._notify_worker_state_change, + external_submit=external_finalize_submit, + ) + self.decode_backlog_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_DECODE_BACKLOG_MAX", "0"))) + self.finalize_pending_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_FINALIZE_PENDING_MAX", "0"))) + self.engine_decode_control_enabled = ( + str(os.environ.get("GPTSOVITS_ENGINE_DRIVE_DECODE", "1")).strip().lower() in {"1", "true", "yes", "on"} + ) + self.job_registry = SchedulerJobRegistry(self.condition) + self.worker_thread: threading.Thread | None = None + if not self.engine_decode_control_enabled: + self.worker_thread = threading.Thread(target=self._run_loop, name="unified-t2s-scheduler-worker", daemon=True) + self.worker_thread.start() + self.finalize_threads = [] + if external_finalize_submit is None: + self.finalize_threads = [ + threading.Thread( + target=self._run_finalize_loop, + name=f"unified-t2s-finalize-{worker_index}", + daemon=True, + ) + for worker_index in range(self.finalize_executor.get_worker_count()) + ] + for finalize_thread in self.finalize_threads: + finalize_thread.start() + + def _notify_worker_state_change(self) -> None: + with self.condition: + self.condition.notify_all() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_completion.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_completion.py new file mode 100644 index 00000000..da2c057a --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_completion.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +import threading +import time +from typing import Any, Callable, Dict, List, Optional + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeStateCallbacks, SchedulerJobRegistry, SchedulerPendingJob + + +class WorkerCompletionBridge: + def __init__(self, runtime_callbacks: RuntimeStateCallbacks | None = None) -> None: + self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() + + @staticmethod + def _resolve_done_future(job: SchedulerPendingJob) -> None: + future = job.done_future + if future is None or future.done(): + return + future.set_result(job) + + def notify_done_future(self, job: SchedulerPendingJob) -> None: + if job.done_loop is None or job.done_future is None: + return + try: + job.done_loop.call_soon_threadsafe(self._resolve_done_future, job) + except RuntimeError: + pass + + def runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: + if request_id is None or self.runtime_callbacks.complete is None: + return + self.runtime_callbacks.complete(request_id, extra) + + def runtime_fail(self, request_id: str | None, error: str) -> None: + if request_id is None or self.runtime_callbacks.fail is None: + return + self.runtime_callbacks.fail(request_id, error) + + @staticmethod + def build_completed_job_result( + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + audio_data: np.ndarray, + finished_at: float | None = None, + ) -> Dict[str, Any]: + finished_at = float(time.perf_counter() if finished_at is None else finished_at) + queue_wait_ms = 0.0 + if job.first_schedule_time is not None: + queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0) + worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0) + worker_residual_ms = max( + 0.0, + worker_total_ms + - queue_wait_ms + - job.prefill_ms + - job.merge_ms + - job.decode_ms + - job.finalize_wait_ms + - job.synth_ms, + ) + worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms) + job.sample_rate = int(sample_rate) + job.audio_data = audio_data + job.result_ready_time = finished_at + prepare_profile = dict(job.state.prepare_profile) + result = { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + "decode_admission_wait_ms": float(job.admission_wait_ms), + "engine_policy_wait_ms": float(job.engine_policy_wait_ms), + "engine_dispatch_wait_ms": float(job.engine_dispatch_wait_ms), + "prepare_ms": job.prepare_wall_ms, + "prepare_wall_ms": job.prepare_wall_ms, + "prepare_profile_total_ms": job.prepare_profile_total_ms, + "prepare_profile": prepare_profile, + "queue_wait_ms": queue_wait_ms, + "prefill_ms": job.prefill_ms, + "merge_ms": job.merge_ms, + "decode_ms": job.decode_ms, + "finalize_wait_ms": job.finalize_wait_ms, + "synth_ms": job.synth_ms, + "worker_residual_ms": worker_residual_ms, + "worker_other_ms": worker_other_ms, + "worker_total_ms": worker_total_ms, + "decode_steps": int(job.decode_steps), + "sample_rate": int(sample_rate), + "media_type": job.media_type, + } + job.result = result + return result + + @staticmethod + def build_runtime_complete_payload( + job: SchedulerPendingJob, + item: T2SFinishedItem, + *, + sample_rate: int, + ) -> Dict[str, Any]: + return { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "sample_rate": int(sample_rate), + "worker_profile": dict(job.result or {}), + } + + def complete_job( + self, + job: SchedulerPendingJob, + *, + runtime_request_id: str | None, + runtime_extra: Optional[Dict[str, Any]] = None, + remove_job: Callable[[], None] | None = None, + on_job_finished: Callable[[], None] | None = None, + notify_waiters: Callable[[], None] | None = None, + ) -> None: + job.done_event.set() + self.notify_done_future(job) + if remove_job is not None: + remove_job() + if on_job_finished is not None: + on_job_finished() + if notify_waiters is not None: + notify_waiters() + self.runtime_complete(runtime_request_id, runtime_extra) + + def fail_job( + self, + job: SchedulerPendingJob, + *, + error: str, + remove_job: Callable[[], None] | None = None, + on_job_finished: Callable[[], None] | None = None, + notify_waiters: Callable[[], None] | None = None, + ) -> None: + job.error = str(error) + job.done_event.set() + self.notify_done_future(job) + if remove_job is not None: + remove_job() + if on_job_finished is not None: + on_job_finished() + if notify_waiters is not None: + notify_waiters() + self.runtime_fail(job.engine_request_id, str(error)) + + def complete_finalize_task( + self, + *, + condition: threading.Condition, + job_registry: SchedulerJobRegistry, + job: SchedulerPendingJob, + item: T2SFinishedItem, + sample_rate: int, + audio_data: np.ndarray, + ) -> None: + runtime_extra: Optional[Dict[str, Any]] = None + with condition: + if job_registry.get(item.request_id) is not job: + return + self.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data) + runtime_extra = self.build_runtime_complete_payload(job, item, sample_rate=sample_rate) + self.complete_job( + job, + runtime_request_id=job.engine_request_id, + runtime_extra=runtime_extra, + on_job_finished=lambda: job_registry.mark_finished_and_remove(item.request_id), + notify_waiters=condition.notify_all, + ) + + def fail_jobs( + self, + *, + condition: threading.Condition, + job_registry: SchedulerJobRegistry, + request_ids: List[str], + error: str, + ) -> None: + if not request_ids: + return + with condition: + for request_id in request_ids: + job = job_registry.get(request_id) + if job is None: + continue + self.fail_job( + job, + error=error, + on_job_finished=lambda rid=request_id: job_registry.mark_finished_and_remove(rid), + ) + condition.notify_all() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_decode.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_decode.py new file mode 100644 index 00000000..784f71d0 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_decode.py @@ -0,0 +1,430 @@ +from __future__ import annotations + +import threading +import time +from typing import Any, Callable, Dict, List, Optional + +import torch + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( + T2SActiveBatch, + T2SFinishedItem, + decode_one_step, + merge_active_batches, + run_prefill_active_batch, +) +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeStateCallbacks, SchedulerPendingJob + + +class WorkerDecodeExecutor: + def __init__(self, tts: TTS, max_steps: int) -> None: + self.tts = tts + self.max_steps = int(max_steps) + + def _sync_device(self) -> None: + try: + device_str = str(self.tts.configs.device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(self.tts.configs.device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + def execute_prefill_merge( + self, + *, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], + add_prefill_time: Callable[[List[str], float], None] | None, + add_merge_time: Callable[[List[str], float], None] | None, + enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, + finalize_error: Callable[[List[str], str], None] | None, + ) -> Dict[str, Any]: + if not pending_jobs: + return { + "executed": False, + "active_batch": active_batch, + "pending_jobs": [], + "prefill_elapsed_s": 0.0, + "merge_elapsed_s": 0.0, + "finished_items": [], + "error": None, + "error_request_ids": [], + } + admitted_finished: List[T2SFinishedItem] = [] + prefill_elapsed_s = 0.0 + merge_elapsed_s = 0.0 + error: str | None = None + error_request_ids: List[str] = [] + try: + self._sync_device() + prefill_start = time.perf_counter() + mark_prefill_started(pending_jobs, prefill_start) + admitted_active_batch, admitted_finished = run_prefill_active_batch( + self.tts.t2s_model.model, + [job.state for job in pending_jobs], + max_steps=self.max_steps, + ) + self._sync_device() + prefill_elapsed_s = time.perf_counter() - prefill_start + if add_prefill_time is not None: + add_prefill_time([job.request_id for job in pending_jobs], prefill_elapsed_s) + if enqueue_finished is not None: + enqueue_finished(admitted_finished) + merge_start = time.perf_counter() + active_batch = merge_active_batches( + self.tts.t2s_model.model, + active_batch, + admitted_active_batch, + ) + merge_elapsed_s = time.perf_counter() - merge_start + if add_merge_time is not None: + add_merge_time( + [] if active_batch is None else list(active_batch.request_ids), + merge_elapsed_s, + ) + except Exception as exc: + error = str(exc) + error_request_ids = [job.request_id for job in pending_jobs] + if finalize_error is not None: + finalize_error(error_request_ids, error) + return { + "executed": True, + "active_batch": active_batch, + "pending_jobs": list(pending_jobs), + "prefill_elapsed_s": float(prefill_elapsed_s), + "merge_elapsed_s": float(merge_elapsed_s), + "finished_items": list(admitted_finished), + "error": error, + "error_request_ids": error_request_ids, + } + + def execute_decode_step( + self, + *, + active_batch: Optional[T2SActiveBatch], + add_decode_time: Callable[[List[str], float], None] | None, + enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, + finalize_error: Callable[[List[str], str], None] | None, + ) -> Dict[str, Any]: + if active_batch is None: + return { + "executed": False, + "active_batch": None, + "request_ids": [], + "decode_elapsed_s": 0.0, + "finished_items": [], + "error": None, + "error_request_ids": [], + } + active_request_ids: List[str] = [] + step_finished: List[T2SFinishedItem] = [] + decode_elapsed_s = 0.0 + error: str | None = None + error_request_ids: List[str] = [] + try: + active_request_ids = [state.request_id for state in active_batch.states] + self._sync_device() + decode_start = time.perf_counter() + active_batch, step_finished = decode_one_step( + self.tts.t2s_model.model, + active_batch, + max_steps=self.max_steps, + ) + self._sync_device() + decode_elapsed_s = time.perf_counter() - decode_start + if add_decode_time is not None: + add_decode_time(active_request_ids, decode_elapsed_s) + if enqueue_finished is not None: + enqueue_finished(step_finished) + except Exception as exc: + error = str(exc) + error_request_ids = list(active_request_ids) + if finalize_error is not None: + finalize_error(error_request_ids, error) + active_batch = None + return { + "executed": True, + "active_batch": active_batch, + "request_ids": active_request_ids, + "decode_elapsed_s": float(decode_elapsed_s), + "finished_items": list(step_finished), + "error": error, + "error_request_ids": error_request_ids, + } + + def execute_decode_cycle( + self, + *, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None], + add_prefill_time: Callable[[List[str], float], None] | None, + add_merge_time: Callable[[List[str], float], None] | None, + add_decode_time: Callable[[List[str], float], None] | None, + enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None, + finalize_error: Callable[[List[str], str], None] | None, + ) -> Dict[str, Any]: + result = { + "executed": False, + "prefill_merge_executed": False, + "decode_step_executed": False, + "active_batch": active_batch, + "prefill_phase": {}, + "decode_phase": {}, + } + prefill_phase = self.execute_prefill_merge( + pending_jobs=list(pending_jobs), + active_batch=result["active_batch"], + mark_prefill_started=mark_prefill_started, + add_prefill_time=add_prefill_time, + add_merge_time=add_merge_time, + enqueue_finished=enqueue_finished, + finalize_error=finalize_error, + ) + prefill_executed = bool(prefill_phase.get("executed", False)) + result["prefill_phase"] = prefill_phase + result["active_batch"] = prefill_phase.get("active_batch") + if prefill_executed: + result["executed"] = True + result["prefill_merge_executed"] = True + decode_phase = self.execute_decode_step( + active_batch=result["active_batch"], + add_decode_time=add_decode_time, + enqueue_finished=enqueue_finished, + finalize_error=finalize_error, + ) + decode_executed = bool(decode_phase.get("executed", False)) + result["decode_phase"] = decode_phase + result["active_batch"] = decode_phase.get("active_batch") + if decode_executed: + result["executed"] = True + result["decode_step_executed"] = True + return result + + +class WorkerDecodeLegacyShell: + def __init__(self, condition: threading.Condition, micro_batch_wait_s: float) -> None: + self.condition = condition + self.micro_batch_wait_s = float(micro_batch_wait_s) + self.pending_jobs: List[SchedulerPendingJob] = [] + self.active_batch: T2SActiveBatch | None = None + + @staticmethod + def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any] | None: + if active_batch is None: + return None + return { + "request_count": int(len(active_batch.request_ids)), + "request_ids": list(active_batch.request_ids), + "prefill_done": bool(active_batch.prefill_done), + "decode_step_index_max": ( + int(active_batch.step_indices.max().item()) + if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0 + else 0 + ), + } + + def current_backlog_locked(self) -> int: + running_requests = 0 if self.active_batch is None else len(self.active_batch.request_ids) + return int(len(self.pending_jobs) + running_requests) + + def enqueue_pending_job_locked(self, job: SchedulerPendingJob) -> None: + self.pending_jobs.append(job) + + def snapshot_locked(self) -> Dict[str, Any]: + active_batch_summary = self._summarize_active_batch(self.active_batch) + executor_local_pending_jobs = int(len(self.pending_jobs)) + executor_local_running_requests = 0 if self.active_batch is None else int(len(self.active_batch.request_ids)) + executor_local_has_work = bool(self.pending_jobs or self.active_batch is not None) + return { + "executor_local_pending_jobs": executor_local_pending_jobs, + "executor_local_running_requests": executor_local_running_requests, + "executor_local_has_work": executor_local_has_work, + "executor_local_active_batch": active_batch_summary, + } + + def is_idle_locked(self) -> bool: + return self.active_batch is None and not self.pending_jobs + + def take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs and self.active_batch is None: + self.condition.wait(timeout=self.micro_batch_wait_s) + elif wait_for_batch and self.pending_jobs: + self.condition.wait(timeout=self.micro_batch_wait_s) + if not self.pending_jobs: + return [] + pending = list(self.pending_jobs) + self.pending_jobs.clear() + return pending + + def take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + with self.condition: + if not self.pending_jobs: + return [] + if wait_for_batch: + oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time) + if (time.perf_counter() - oldest_enqueue_time) < self.micro_batch_wait_s: + return [] + pending = list(self.pending_jobs) + self.pending_jobs.clear() + return pending + + def has_decode_runtime_work(self) -> bool: + with self.condition: + return bool(self.pending_jobs or self.active_batch is not None) + + def build_runtime_summary_locked(self, *, total_cycles: int, prefill_cycles: int, step_cycles: int, last_event: str) -> Dict[str, Any]: + active_request_ids = [] if self.active_batch is None else list(self.active_batch.request_ids) + decode_step_index_max = 0 + prefill_done = False + if self.active_batch is not None: + prefill_done = bool(self.active_batch.prefill_done) + if self.active_batch.step_indices is not None and self.active_batch.step_indices.numel() > 0: + decode_step_index_max = int(self.active_batch.step_indices.max().item()) + return { + "pending_jobs": int(len(self.pending_jobs)), + "active_request_count": int(len(active_request_ids)), + "active_request_ids": active_request_ids[:32], + "prefill_done": bool(prefill_done), + "decode_step_index_max": int(decode_step_index_max), + "total_cycles": int(total_cycles), + "prefill_cycles": int(prefill_cycles), + "step_cycles": int(step_cycles), + "has_work": bool(self.pending_jobs or self.active_batch is not None), + "last_event": str(last_event), + "updated_at": float(time.perf_counter()), + } + + def run_prefill_merge_once_nonblocking( + self, + *, + external_pending_jobs: Optional[List[SchedulerPendingJob]], + external_active_batch: Optional[T2SActiveBatch], + execute_prefill_merge: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], + ) -> Dict[str, Any]: + pending_jobs = ( + list(external_pending_jobs) + if external_pending_jobs is not None + else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) + ) + active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch + result = execute_prefill_merge(pending_jobs, active_batch) + if external_pending_jobs is None: + with self.condition: + self.active_batch = result.get("active_batch") + self.condition.notify_all() + return result + + def run_decode_step_once_nonblocking( + self, + *, + external_active_batch: Optional[T2SActiveBatch], + execute_decode_step: Callable[[Optional[T2SActiveBatch]], Dict[str, Any]], + ) -> Dict[str, Any]: + active_batch = self.active_batch if external_active_batch is None else external_active_batch + result = execute_decode_step(active_batch) + if external_active_batch is None: + with self.condition: + self.active_batch = result.get("active_batch") + self.condition.notify_all() + return result + + def run_decode_cycle_nonblocking( + self, + *, + external_pending_jobs: Optional[List[SchedulerPendingJob]], + external_active_batch: Optional[T2SActiveBatch], + execute_decode_cycle: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]], + on_cycle_executed: Callable[[Dict[str, Any]], None] | None, + ) -> Dict[str, Any]: + pending_jobs = ( + list(external_pending_jobs) + if external_pending_jobs is not None + else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None) + ) + active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch + result = execute_decode_cycle(pending_jobs, active_batch) + if external_pending_jobs is None: + with self.condition: + self.active_batch = result.get("active_batch") + self.condition.notify_all() + if result.get("executed") and on_cycle_executed is not None: + on_cycle_executed(result) + return result + + def run_loop( + self, + *, + run_decode_cycle_nonblocking: Callable[[], Dict[str, Any]], + ) -> None: + while True: + executed = run_decode_cycle_nonblocking() + if executed.get("executed"): + continue + wait_for_batch = self.active_batch is None + pending_jobs = self.take_pending_snapshot(wait_for_batch=wait_for_batch) + if pending_jobs: + with self.condition: + self.pending_jobs = pending_jobs + self.pending_jobs + self.condition.notify_all() + continue + time.sleep(self.micro_batch_wait_s) + + +class WorkerDecodeRuntimeTracker: + def __init__( + self, + runtime_callbacks: RuntimeStateCallbacks | None = None, + ) -> None: + self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks() + self.total_cycles = 0 + self.prefill_cycles = 0 + self.step_cycles = 0 + + def get_counters(self) -> Dict[str, int]: + return { + "total_cycles": int(self.total_cycles), + "prefill_cycles": int(self.prefill_cycles), + "step_cycles": int(self.step_cycles), + } + + def record_cycle(self, result: Dict[str, Any]) -> None: + if not bool(result.get("executed")): + return + self.total_cycles += 1 + if bool(result.get("prefill_merge_executed")): + self.prefill_cycles += 1 + if bool(result.get("decode_step_executed")): + self.step_cycles += 1 + + def build_runtime_summary_locked( + self, + *, + legacy_shell: WorkerDecodeLegacyShell, + last_event: str, + ) -> Dict[str, Any]: + return legacy_shell.build_runtime_summary_locked( + total_cycles=int(self.total_cycles), + prefill_cycles=int(self.prefill_cycles), + step_cycles=int(self.step_cycles), + last_event=str(last_event), + ) + + def notify_runtime_update_locked( + self, + *, + legacy_shell: WorkerDecodeLegacyShell, + last_event: str, + ) -> None: + if self.runtime_callbacks.decode_runtime_update is None: + return + snapshot = self.build_runtime_summary_locked( + legacy_shell=legacy_shell, + last_event=last_event, + ) + self.runtime_callbacks.decode_runtime_update(snapshot) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_execution.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_execution.py new file mode 100644 index 00000000..465f7a2c --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_execution.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +import time +from typing import Any, Dict, List, Optional + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SActiveBatch, T2SFinishedItem +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob + + +class WorkerExecutionMixin: + def execute_prefill_merge( + self, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + return self.decode_executor.execute_prefill_merge( + pending_jobs=pending_jobs, + active_batch=active_batch, + mark_prefill_started=self._mark_prefill_started, + add_prefill_time=None if external_bookkeeping else self._add_prefill_time, + add_merge_time=None if external_bookkeeping else self._add_merge_time, + enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, + finalize_error=None if external_bookkeeping else self._finalize_error, + ) + + def execute_decode_step( + self, + active_batch: Optional[T2SActiveBatch], + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + return self.decode_executor.execute_decode_step( + active_batch=active_batch, + add_decode_time=None if external_bookkeeping else self._add_decode_time, + enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, + finalize_error=None if external_bookkeeping else self._finalize_error, + ) + + def execute_decode_cycle( + self, + pending_jobs: List[SchedulerPendingJob], + active_batch: Optional[T2SActiveBatch], + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_executor.execute_decode_cycle( + pending_jobs=pending_jobs, + active_batch=active_batch, + mark_prefill_started=self._mark_prefill_started, + add_prefill_time=None if external_bookkeeping else self._add_prefill_time, + add_merge_time=None if external_bookkeeping else self._add_merge_time, + add_decode_time=None if external_bookkeeping else self._add_decode_time, + enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished, + finalize_error=None if external_bookkeeping else self._finalize_error, + ) + self._record_decode_runtime_cycle(result) + return result + + def run_prefill_merge_once_nonblocking( + self, + external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, + external_active_batch: Optional[T2SActiveBatch] = None, + emit_runtime_state: bool = True, + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_legacy_shell.run_prefill_merge_once_nonblocking( + external_pending_jobs=external_pending_jobs, + external_active_batch=external_active_batch, + execute_prefill_merge=lambda batch_jobs, batch_state: self.execute_prefill_merge( + pending_jobs=batch_jobs, + active_batch=batch_state, + external_bookkeeping=external_bookkeeping, + ), + ) + if emit_runtime_state: + self._notify_decode_runtime_state("prefill_merge") + return result + + def run_decode_step_once_nonblocking( + self, + external_active_batch: Optional[T2SActiveBatch] = None, + emit_runtime_state: bool = True, + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_legacy_shell.run_decode_step_once_nonblocking( + external_active_batch=external_active_batch, + execute_decode_step=lambda batch_state: self.execute_decode_step( + active_batch=batch_state, + external_bookkeeping=external_bookkeeping, + ), + ) + if emit_runtime_state: + self._notify_decode_runtime_state("decode_step") + return result + + def run_decode_cycle_nonblocking( + self, + external_pending_jobs: Optional[List[SchedulerPendingJob]] = None, + external_active_batch: Optional[T2SActiveBatch] = None, + emit_runtime_state: bool = True, + external_bookkeeping: bool = False, + ) -> Dict[str, Any]: + result = self.decode_legacy_shell.run_decode_cycle_nonblocking( + external_pending_jobs=external_pending_jobs, + external_active_batch=external_active_batch, + execute_decode_cycle=lambda batch_jobs, batch_state: self.execute_decode_cycle( + pending_jobs=batch_jobs, + active_batch=batch_state, + external_bookkeeping=external_bookkeeping, + ), + on_cycle_executed=None, + ) + if result.get("executed") and emit_runtime_state: + self._notify_decode_runtime_state("decode_cycle") + return result + + def execute_finalize_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None: + if not tasks: + return + try: + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = [] + with self.condition: + for task in tasks: + job = self.job_registry.get(task.request_id) + if job is None: + continue + jobs_and_items.append((job, task.item)) + if not jobs_and_items: + return + now = time.perf_counter() + for task in tasks: + self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0)) + for job, item in jobs_and_items: + self._runtime_update( + job.engine_request_id, + EngineStatus.FINALIZING, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + }, + ) + synth_ms, batch_results = self.synthesize_finalize_jobs(jobs_and_items) + with self.condition: + for job, _ in jobs_and_items: + tracked_job = self.job_registry.get(job.request_id) + if tracked_job is not None: + tracked_job.synth_ms += synth_ms + for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results): + self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data) + except Exception as exc: + self._finalize_error([task.request_id for task in tasks], str(exc)) + finally: + self.finalize_executor.end_execution(len(tasks)) + + def _run_finalize_loop(self) -> None: + while True: + tasks = self.finalize_executor.take_task_batch_blocking() + self.execute_finalize_tasks(tasks) + + def _run_loop(self) -> None: + self.decode_legacy_shell.run_loop( + run_decode_cycle_nonblocking=lambda: self.run_decode_cycle_nonblocking() + ) diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py new file mode 100644 index 00000000..bb9cb3cb --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py @@ -0,0 +1,251 @@ +from __future__ import annotations + +import os +import threading +import time +from collections import deque +from typing import Any, Callable, Deque, Dict, List + +import numpy as np +import torch + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import SchedulerFinalizeTask, SchedulerPendingJob + + +class WorkerFinalizeExecutor: + def __init__( + self, + tts: TTS, + on_state_change: Callable[[], None] | None = None, + external_submit: Callable[[List[SchedulerFinalizeTask]], None] | None = None, + ) -> None: + self.tts = tts + self.on_state_change = on_state_change + self.external_submit = external_submit + self.condition = threading.Condition() + self.pending_tasks: Deque[SchedulerFinalizeTask] = deque() + self.pending_peak = 0 + self.inflight = 0 + self.inflight_peak = 0 + self.worker_count = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1))) + self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower() + self.batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16))) + self.batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0) + + def _notify_state_change(self) -> None: + if self.on_state_change is None: + return + try: + self.on_state_change() + except Exception: + pass + + def get_worker_count(self) -> int: + return int(self.worker_count) + + def get_batch_policy(self) -> Dict[str, Any]: + return { + "finalize_mode": str(self.finalize_mode), + "finalize_batch_max_items": int(self.batch_max_items), + "finalize_batch_wait_s": float(self.batch_wait_s), + } + + def get_pending_count(self) -> int: + with self.condition: + return int(len(self.pending_tasks)) + + def snapshot(self) -> Dict[str, Any]: + with self.condition: + return { + "finalize_pending": int(len(self.pending_tasks)), + "finalize_pending_peak": int(self.pending_peak), + "finalize_inflight": int(self.inflight), + "finalize_inflight_peak": int(self.inflight_peak), + "finalize_workers": int(self.worker_count), + "finalize_mode": str(self.finalize_mode), + "finalize_batch_max_items": int(self.batch_max_items), + "finalize_batch_wait_ms": float(self.batch_wait_s * 1000.0), + } + + def is_idle(self) -> bool: + with self.condition: + return self.inflight <= 0 and not self.pending_tasks + + def enqueue_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None: + if not tasks: + return + if self.external_submit is not None: + self.external_submit(tasks) + self._notify_state_change() + return + with self.condition: + for task in tasks: + self.pending_tasks.append(task) + self.pending_peak = max(self.pending_peak, len(self.pending_tasks)) + self.condition.notify_all() + self._notify_state_change() + + def begin_execution(self, task_count: int) -> None: + if task_count <= 0: + return + with self.condition: + self.inflight += int(task_count) + self.inflight_peak = max(self.inflight_peak, self.inflight) + self.condition.notify_all() + self._notify_state_change() + + def end_execution(self, task_count: int) -> None: + with self.condition: + self.inflight = max(0, self.inflight - int(task_count)) + self.condition.notify_all() + self._notify_state_change() + + def take_task_batch_blocking(self) -> List[SchedulerFinalizeTask]: + with self.condition: + while not self.pending_tasks: + self.condition.wait() + selected_tasks = [self.pending_tasks.popleft()] + if self.finalize_mode == "sync" or self.tts.configs.use_vocoder: + self.inflight += len(selected_tasks) + self.inflight_peak = max(self.inflight_peak, self.inflight) + self._notify_state_change() + return selected_tasks + batch_deadline = time.perf_counter() + self.batch_wait_s + while len(selected_tasks) < self.batch_max_items: + if not self.pending_tasks: + remaining = batch_deadline - time.perf_counter() + if remaining <= 0: + break + self.condition.wait(timeout=remaining) + continue + first_task = selected_tasks[0] + matched_index = None + for index, task in enumerate(self.pending_tasks): + if abs(task.enqueued_time - first_task.enqueued_time) < 1.0: + matched_index = index + break + if matched_index is not None: + selected_tasks.append(self.pending_tasks[matched_index]) + del self.pending_tasks[matched_index] + continue + remaining = batch_deadline - time.perf_counter() + if remaining <= 0: + break + self.condition.wait(timeout=remaining) + self.inflight += len(selected_tasks) + self.inflight_peak = max(self.inflight_peak, self.inflight) + self._notify_state_change() + return selected_tasks + + def _sync_device(self) -> None: + try: + device_str = str(self.tts.configs.device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(self.tts.configs.device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + @staticmethod + def _collect_job_refer_specs(job: SchedulerPendingJob) -> List[tuple]: + refer_specs = [] + if job.state.refer_spec is not None: + refer_specs.append(job.state.refer_spec) + refer_specs.extend(list(getattr(job.state, "aux_refer_specs", []) or [])) + return refer_specs + + def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]: + audio_fragment = self.tts.synthesize_audio_request_local( + semantic_tokens=item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0), + phones=job.state.phones.detach().clone().unsqueeze(0), + prompt_semantic=job.state.prompt_semantic.detach().clone(), + prompt_phones=job.state.prompt_phones.detach().clone(), + refer_spec=[ + ( + refer_spec_item[0].detach().clone(), + None if refer_spec_item[1] is None else refer_spec_item[1].detach().clone(), + ) + for refer_spec_item in self._collect_job_refer_specs(job) + ], + raw_audio=job.state.raw_audio.detach().clone(), + raw_sr=int(job.state.raw_sr), + speed=float(job.speed_factor), + sample_steps=int(job.sample_steps), + ) + output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] + return self.tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=float(job.speed_factor), + split_bucket=False, + fragment_interval=0.0, + super_sampling=bool(job.super_sampling), + ) + + def _synthesize_finished_audio_batch( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> List[tuple[int, np.ndarray]]: + semantic_tokens_list = [item.semantic_tokens.detach().clone() for _, item in jobs_and_items] + phones_list = [job.state.phones.detach().clone() for job, _ in jobs_and_items] + refer_specs = [] + speeds = [] + sample_steps_list = [] + for job, _ in jobs_and_items: + refer_spec_group = self._collect_job_refer_specs(job) + if len(refer_spec_group) != 1: + raise ValueError("batched finalize 暂不支持单请求多参考音频") + refer_specs.append( + [( + refer_spec_group[0][0].detach().clone(), + None if refer_spec_group[0][1] is None else refer_spec_group[0][1].detach().clone(), + )] + ) + speeds.append(float(job.speed_factor)) + sample_steps_list.append(int(job.sample_steps)) + audio_fragments = self.tts.synthesize_audio_requests_local_batched( + semantic_tokens_list=semantic_tokens_list, + phones_list=phones_list, + refer_specs=refer_specs, + speeds=speeds, + sample_steps_list=sample_steps_list, + ) + output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"] + results: List[tuple[int, np.ndarray]] = [] + for (job, _), audio_fragment in zip(jobs_and_items, audio_fragments): + results.append( + self.tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=float(job.speed_factor), + split_bucket=False, + fragment_interval=0.0, + super_sampling=bool(job.super_sampling), + ) + ) + return results + + def synthesize_finalize_jobs( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> tuple[float, List[tuple[int, np.ndarray]]]: + if not jobs_and_items: + return 0.0, [] + self._sync_device() + synth_start = time.perf_counter() + if ( + len(jobs_and_items) == 1 + or self.tts.configs.use_vocoder + or any(len(self._collect_job_refer_specs(job)) != 1 for job, _ in jobs_and_items) + ): + batch_results = [self._synthesize_finished_audio(job, item) for job, item in jobs_and_items] + else: + batch_results = self._synthesize_finished_audio_batch(jobs_and_items) + self._sync_device() + synth_ms = (time.perf_counter() - synth_start) * 1000.0 + return float(synth_ms), batch_results diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py new file mode 100644 index 00000000..9fb7c8d9 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import asyncio +import os +import time +from typing import Any, Callable, Dict, List + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS +from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator, PreparedCpuStage +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SRequestState + + +class WorkerPrepareExecutor: + def __init__( + self, + tts: TTS, + on_state_change: Callable[[], None] | None = None, + ) -> None: + self.coordinator = PrepareCoordinator(tts) + self.on_state_change = on_state_change + + def _notify_state_change(self) -> None: + if self.on_state_change is None: + return + try: + self.on_state_change() + except Exception: + pass + + def snapshot(self) -> Dict[str, int]: + return dict(self.coordinator.snapshot()) + + def get_max_inflight(self) -> int: + return int(self.coordinator.snapshot().get("max_inflight", 0)) + + def get_batch_policy(self) -> Dict[str, int]: + return { + "prepare_batch_max_items": max(1, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_BATCH_MAX_ITEMS", 8))), + } + + def is_idle(self) -> bool: + return int(self.coordinator.snapshot().get("inflight", 0)) <= 0 + + async def prepare_state_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> tuple[T2SRequestState, float, float]: + try: + return await self.coordinator.prepare_state_profiled_async(spec, prepare_submit_at) + finally: + self._notify_state_change() + + async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: + results = await asyncio.gather( + *[self.prepare_state_profiled_async(spec, time.perf_counter()) for spec in specs] + ) + return [state for state, _, _ in results] + + async def prepare_cpu_stage_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> PreparedCpuStage: + try: + return await self.coordinator.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + finally: + self._notify_state_change() + + async def prepare_gpu_stage_profiled_async( + self, + cpu_stage: PreparedCpuStage, + ) -> tuple[T2SRequestState, float, float]: + try: + return await self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage) + finally: + self._notify_state_change() + + async def prepare_gpu_stages_profiled_async( + self, + cpu_stages: List[PreparedCpuStage], + ) -> List[tuple[T2SRequestState, float, float] | Exception]: + try: + return await self.coordinator.prepare_gpu_stages_profiled_async(cpu_stages) + finally: + self._notify_state_change() + + async def prepare_gpu_audio_phases_async( + self, + cpu_stages: List[PreparedCpuStage], + ) -> List[Dict[str, Any] | Exception]: + try: + return await self.coordinator.prepare_gpu_audio_phases_async(cpu_stages) + finally: + self._notify_state_change() + + async def prepare_gpu_text_phases_async( + self, + items: List[tuple[PreparedCpuStage, Dict[str, Any]]], + ) -> List[Dict[str, Any] | Exception]: + try: + return await self.coordinator.prepare_gpu_text_phases_async(items) + finally: + self._notify_state_change() + + def build_gpu_prepare_result_from_phases( + self, + cpu_stage: PreparedCpuStage, + phase_one: Dict[str, Any], + phase_two: Dict[str, Any], + extra_profile: Dict[str, float] | None = None, + ) -> tuple[T2SRequestState, float, float]: + try: + return self.coordinator.build_gpu_prepare_result_from_phases( + cpu_stage, + phase_one, + phase_two, + extra_profile=extra_profile, + ) + finally: + self._notify_state_change() + + async def prepare_ref_spec_stages_async( + self, + phase_ones: List[Dict[str, Any]], + ) -> List[tuple[tuple[Any, Any], Dict[str, float]] | Exception]: + try: + return await self.coordinator.prepare_ref_spec_stages_async(phase_ones) + finally: + self._notify_state_change() + + def apply_ref_spec_result_to_state( + self, + state: T2SRequestState, + ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]], + ) -> None: + try: + self.coordinator.apply_ref_spec_result_to_state(state, ref_spec_result) + finally: + self._notify_state_change() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_runtime.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_runtime.py new file mode 100644 index 00000000..de12f5e1 --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_runtime.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +import threading +import time +from typing import Any, Dict, List, Optional + +import numpy as np + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob + + +class WorkerRuntimeBookkeepingMixin: + def _mark_prefill_started(self, pending_jobs: List[SchedulerPendingJob], started_at: float) -> None: + with self.condition: + for job in pending_jobs: + job.first_schedule_time = float(started_at) + self._runtime_update( + job.engine_request_id, + EngineStatus.GPU_PREPARING, + {"scheduler_request_id": job.request_id, "prefill_started_at": float(started_at)}, + ) + + def _add_prefill_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + job.prefill_ms += delta_ms + + def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + job.merge_ms += delta_ms + + def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None: + delta_ms = float(elapsed_s) * 1000.0 + if not request_ids: + return + activate_request_ids: List[str] = [] + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + if job.decode_steps == 0: + activate_request_ids.append(job.engine_request_id) + job.decode_ms += delta_ms + job.decode_steps += 1 + for engine_request_id in activate_request_ids: + self._runtime_update(engine_request_id, EngineStatus.ACTIVE_DECODE, None) + + def _add_finalize_wait_ms(self, request_ids: List[str], delta_ms: float) -> None: + if not request_ids: + return + with self.condition: + for request_id in request_ids: + job = self.job_registry.get(request_id) + if job is not None: + job.finalize_wait_ms += float(delta_ms) + + def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None: + if not items: + return + enqueued_at = time.perf_counter() + tasks: List[SchedulerFinalizeTask] = [] + with self.condition: + for item in items: + job = self.job_registry.get(item.request_id) + if job is not None: + self._runtime_update( + job.engine_request_id, + EngineStatus.READY_FOR_FINALIZE, + { + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + }, + ) + tasks.append(SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at)) + self.finalize_executor.enqueue_tasks(tasks) + + def begin_finalize_execution(self, task_count: int) -> None: + self.finalize_executor.begin_execution(task_count) + + def end_finalize_execution(self, task_count: int) -> None: + self.finalize_executor.end_execution(task_count) + + def record_external_job_done(self, request_id: str) -> None: + with self.condition: + self.job_registry.mark_finished_and_remove(request_id) + self.condition.notify_all() + + def synthesize_finalize_jobs( + self, + jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]], + ) -> tuple[float, List[tuple[int, np.ndarray]]]: + return self.finalize_executor.synthesize_finalize_jobs(jobs_and_items) + + def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None: + self.completion_bridge.complete_finalize_task( + condition=self.condition, + job_registry=self.job_registry, + job=job, + item=item, + sample_rate=sample_rate, + audio_data=audio_data, + ) + + def _finalize_error(self, request_ids: List[str], error: str) -> None: + self.completion_bridge.fail_jobs( + condition=self.condition, + job_registry=self.job_registry, + request_ids=request_ids, + error=error, + ) + + @staticmethod + def _resolve_done_future(job: SchedulerPendingJob) -> None: + future = job.done_future + if future is None or future.done(): + return + future.set_result(job) + + def _notify_done_future(self, job: SchedulerPendingJob) -> None: + self.completion_bridge.notify_done_future(job) + + def _runtime_update(self, request_id: str | None, status: str, extra: Optional[Dict[str, Any]] = None) -> None: + if request_id is None or self.runtime_callbacks.update is None: + return + self.runtime_callbacks.update(request_id, status, extra) + + def _runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None: + self.completion_bridge.runtime_complete(request_id, extra) + + def _runtime_fail(self, request_id: str | None, error: str) -> None: + self.completion_bridge.runtime_fail(request_id, error) + + def _build_decode_runtime_summary_locked(self, last_event: str) -> Dict[str, Any]: + return self.decode_runtime_tracker.build_runtime_summary_locked( + legacy_shell=self.decode_legacy_shell, + last_event=str(last_event), + ) + + def _notify_decode_runtime_state(self, last_event: str) -> None: + with self.condition: + self.decode_runtime_tracker.notify_runtime_update_locked( + legacy_shell=self.decode_legacy_shell, + last_event=str(last_event), + ) + + def _record_decode_runtime_cycle(self, result: Dict[str, Any]) -> None: + with self.condition: + self.decode_runtime_tracker.record_cycle(result) + + def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.decode_legacy_shell.take_pending_snapshot(wait_for_batch) + + def _take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]: + return self.decode_legacy_shell.take_pending_snapshot_nonblocking(wait_for_batch) + + def has_decode_runtime_work(self) -> bool: + return self.decode_legacy_shell.has_decode_runtime_work() diff --git a/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py new file mode 100644 index 00000000..2ac636fe --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py @@ -0,0 +1,308 @@ +from __future__ import annotations + +import asyncio +import threading +import time +from typing import Any, Dict, List + +from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PreparedCpuStage +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SRequestState +from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerPendingJob + + +class WorkerSubmitLifecycleMixin: + def _current_decode_backlog_locked(self) -> int: + return self.decode_legacy_shell.current_backlog_locked() + + def get_micro_batch_wait_s(self) -> float: + return float(self.micro_batch_wait_s) + + def is_engine_decode_control_enabled(self) -> bool: + return bool(self.engine_decode_control_enabled) + + def get_prepare_max_inflight(self) -> int: + return int(self.prepare_executor.get_max_inflight()) + + def get_capacity_limits(self) -> Dict[str, int]: + return { + "decode_backlog_max": int(self.decode_backlog_max), + "finalize_pending_max": int(self.finalize_pending_max), + } + + def get_finalize_batch_policy(self) -> Dict[str, Any]: + return dict(self.finalize_executor.get_batch_policy()) + + def get_prepare_batch_policy(self) -> Dict[str, int]: + return dict(self.prepare_executor.get_batch_policy()) + + def get_decode_runtime_counters(self) -> Dict[str, int]: + with self.condition: + return self.decode_runtime_tracker.get_counters() + + def _can_accept_submit_locked(self) -> tuple[bool, Dict[str, int]]: + decode_backlog = self._current_decode_backlog_locked() + finalize_pending = int(self.finalize_executor.get_pending_count()) + prepare_inflight = int(self.prepare_executor.snapshot()["inflight"]) + blocked_decode = self.decode_backlog_max > 0 and decode_backlog >= self.decode_backlog_max + blocked_finalize = self.finalize_pending_max > 0 and finalize_pending >= self.finalize_pending_max + return ( + not blocked_decode and not blocked_finalize, + { + "decode_backlog": decode_backlog, + "finalize_pending": finalize_pending, + "prepare_inflight": prepare_inflight, + "decode_backlog_max": int(self.decode_backlog_max), + "finalize_pending_max": int(self.finalize_pending_max), + }, + ) + + def wait_for_submit_capacity_blocking(self, timeout_sec: float | None = None) -> tuple[float, Dict[str, int]]: + start = time.perf_counter() + deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec))) + while True: + with self.condition: + allowed, snapshot = self._can_accept_submit_locked() + if allowed: + return max(0.0, (time.perf_counter() - start) * 1000.0), snapshot + if deadline is not None and time.perf_counter() >= deadline: + raise TimeoutError( + "scheduler submit admission timeout " + f"(decode_backlog={snapshot['decode_backlog']}, finalize_pending={snapshot['finalize_pending']})" + ) + self.condition.wait(timeout=self.micro_batch_wait_s) + + def _admission_snapshot_locked(self) -> Dict[str, int]: + _, snapshot = self._can_accept_submit_locked() + return snapshot + + async def submit_async( + self, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + super_sampling: bool, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None = None, + done_future: asyncio.Future | None = None, + engine_request_id: str | None = None, + timeout_sec: float | None = None, + skip_capacity_wait: bool = False, + admission_wait_ms_override: float | None = None, + admission_snapshot_override: Dict[str, Any] | None = None, + engine_policy_wait_ms: float = 0.0, + engine_dispatch_wait_ms: float = 0.0, + enqueue_pending: bool = True, + ) -> SchedulerPendingJob: + return await asyncio.to_thread( + self.submit, + state, + speed_factor, + sample_steps, + media_type, + super_sampling, + prepare_wall_ms, + prepare_profile_total_ms, + done_loop, + done_future, + engine_request_id, + timeout_sec, + skip_capacity_wait, + admission_wait_ms_override, + admission_snapshot_override, + engine_policy_wait_ms, + engine_dispatch_wait_ms, + enqueue_pending, + ) + + def snapshot(self) -> dict: + with self.condition: + prepare_state = self.prepare_executor.snapshot() + finalize_state = self.finalize_executor.snapshot() + shell_state = self.decode_legacy_shell.snapshot_locked() + decode_runtime_counters = self.decode_runtime_tracker.get_counters() + engine_owned_decode_state = bool(self.engine_decode_control_enabled) + active_batch_summary = shell_state.get("executor_local_active_batch") + executor_local_pending_jobs = int(shell_state.get("executor_local_pending_jobs", 0)) + executor_local_running_requests = int(shell_state.get("executor_local_running_requests", 0)) + executor_local_has_work = bool(shell_state.get("executor_local_has_work", False)) + return { + "pending_jobs": 0 if engine_owned_decode_state else executor_local_pending_jobs, + "running_requests": 0 if engine_owned_decode_state else executor_local_running_requests, + "engine_decode_control_enabled": bool(self.engine_decode_control_enabled), + "legacy_state_owner_mode": not engine_owned_decode_state, + "decode_state_owner": "engine" if engine_owned_decode_state else "worker", + "decode_runtime_has_work": False if engine_owned_decode_state else executor_local_has_work, + "executor_local_pending_jobs": executor_local_pending_jobs, + "executor_local_running_requests": executor_local_running_requests, + "executor_local_has_work": executor_local_has_work, + "decode_runtime_total_cycles": int(decode_runtime_counters.get("total_cycles", 0)), + "decode_runtime_prefill_cycles": int(decode_runtime_counters.get("prefill_cycles", 0)), + "decode_runtime_step_cycles": int(decode_runtime_counters.get("step_cycles", 0)), + "prepare_inflight": prepare_state["inflight"], + "prepare_peak_inflight": prepare_state["peak_inflight"], + "prepare_max_inflight": prepare_state.get("max_inflight", 0), + "prepare_state": dict(prepare_state), + **finalize_state, + "decode_backlog_max": self.decode_backlog_max, + "finalize_pending_max": self.finalize_pending_max, + "active_batch": {} if engine_owned_decode_state else active_batch_summary, + "executor_local_active_batch": active_batch_summary if engine_owned_decode_state else None, + "total_submitted": self.job_registry.submitted_count(), + "total_finished": self.job_registry.finished_count(), + "drained": self.is_drained(), + } + + def is_drained(self) -> bool: + with self.condition: + return ( + self.decode_legacy_shell.is_idle_locked() + and self.job_registry.is_empty() + and self.prepare_executor.is_idle() + and self.finalize_executor.is_idle() + ) + + def wait_until_idle(self, timeout_sec: float = 60.0, poll_interval_sec: float = 0.01) -> bool: + deadline = time.perf_counter() + max(0.0, timeout_sec) + while time.perf_counter() < deadline: + if self.is_drained(): + return True + time.sleep(poll_interval_sec) + return self.is_drained() + + def submit( + self, + state: T2SRequestState, + speed_factor: float, + sample_steps: int, + media_type: str, + super_sampling: bool, + prepare_wall_ms: float, + prepare_profile_total_ms: float, + done_loop: asyncio.AbstractEventLoop | None = None, + done_future: asyncio.Future | None = None, + engine_request_id: str | None = None, + timeout_sec: float | None = None, + skip_capacity_wait: bool = False, + admission_wait_ms_override: float | None = None, + admission_snapshot_override: Dict[str, Any] | None = None, + engine_policy_wait_ms: float = 0.0, + engine_dispatch_wait_ms: float = 0.0, + enqueue_pending: bool = True, + ) -> SchedulerPendingJob: + if skip_capacity_wait: + with self.condition: + admission_snapshot = ( + dict(admission_snapshot_override) + if admission_snapshot_override is not None + else dict(self._admission_snapshot_locked()) + ) + admission_wait_ms = 0.0 if admission_wait_ms_override is None else float(admission_wait_ms_override) + else: + admission_wait_ms, admission_snapshot = self.wait_for_submit_capacity_blocking(timeout_sec=timeout_sec) + job = SchedulerPendingJob( + request_id=state.request_id, + state=state, + done_event=threading.Event(), + done_loop=done_loop, + done_future=done_future, + enqueue_time=time.perf_counter(), + speed_factor=float(speed_factor), + sample_steps=int(sample_steps), + media_type=media_type, + super_sampling=bool(super_sampling), + admission_wait_ms=float(admission_wait_ms), + engine_policy_wait_ms=float(engine_policy_wait_ms), + engine_dispatch_wait_ms=float(engine_dispatch_wait_ms), + prepare_wall_ms=float(prepare_wall_ms), + prepare_profile_total_ms=float(prepare_profile_total_ms), + engine_request_id=engine_request_id or state.request_id, + ) + with self.condition: + self.job_registry.register(job, keep_job=not self.engine_decode_control_enabled) + if enqueue_pending: + self.decode_legacy_shell.enqueue_pending_job_locked(job) + self.condition.notify_all() + if enqueue_pending: + self._notify_decode_runtime_state("submit") + self._runtime_update( + job.engine_request_id, + EngineStatus.QUEUED, + { + "scheduler_request_id": job.request_id, + "decode_admission_wait_ms": float(admission_wait_ms), + "engine_policy_wait_ms": float(engine_policy_wait_ms), + "engine_dispatch_wait_ms": float(engine_dispatch_wait_ms), + "admission_snapshot": dict(admission_snapshot), + }, + ) + return job + + async def prepare_state_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> tuple[T2SRequestState, float, float]: + return await self.prepare_executor.prepare_state_profiled_async(spec, prepare_submit_at) + + async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]: + return await self.prepare_executor.prepare_states_batch_async(specs) + + async def prepare_cpu_stage_profiled_async( + self, + spec: SchedulerRequestSpec, + prepare_submit_at: float, + ) -> PreparedCpuStage: + return await self.prepare_executor.prepare_cpu_stage_profiled_async(spec, prepare_submit_at) + + async def prepare_gpu_stage_profiled_async( + self, + cpu_stage: PreparedCpuStage, + ) -> tuple[T2SRequestState, float, float]: + return await self.prepare_executor.prepare_gpu_stage_profiled_async(cpu_stage) + + async def prepare_gpu_stages_profiled_async( + self, + cpu_stages: List[PreparedCpuStage], + ) -> List[tuple[T2SRequestState, float, float] | Exception]: + return await self.prepare_executor.prepare_gpu_stages_profiled_async(cpu_stages) + + async def prepare_gpu_audio_phases_async( + self, + cpu_stages: List[PreparedCpuStage], + ) -> List[Dict[str, Any] | Exception]: + return await self.prepare_executor.prepare_gpu_audio_phases_async(cpu_stages) + + async def prepare_gpu_text_phases_async( + self, + items: List[tuple[PreparedCpuStage, Dict[str, Any]]], + ) -> List[Dict[str, Any] | Exception]: + return await self.prepare_executor.prepare_gpu_text_phases_async(items) + + def build_gpu_prepare_result_from_phases( + self, + cpu_stage: PreparedCpuStage, + phase_one: Dict[str, Any], + phase_two: Dict[str, Any], + extra_profile: Dict[str, float] | None = None, + ) -> tuple[T2SRequestState, float, float]: + return self.prepare_executor.build_gpu_prepare_result_from_phases( + cpu_stage, + phase_one, + phase_two, + extra_profile=extra_profile, + ) + + async def prepare_ref_spec_stages_async( + self, + phase_ones: List[Dict[str, Any]], + ) -> List[tuple[tuple[Any, Any], Dict[str, float]] | Exception]: + return await self.prepare_executor.prepare_ref_spec_stages_async(phase_ones) + + def apply_ref_spec_result_to_state( + self, + state: T2SRequestState, + ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]], + ) -> None: + self.prepare_executor.apply_ref_spec_result_to_state(state, ref_spec_result) diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py index 348ddb3f..c6d147cf 100644 --- a/GPT_SoVITS/module/models.py +++ b/GPT_SoVITS/module/models.py @@ -2,6 +2,7 @@ import warnings warnings.filterwarnings("ignore") import math +from typing import List import torch from torch import nn @@ -1038,6 +1039,67 @@ class SynthesizerTrn(nn.Module): o = self.dec((z * y_mask)[:, :, :], g=ge) return o + @torch.no_grad() + def decode_batched_request_local( + self, + codes: torch.Tensor, + code_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + refer_list: List[torch.Tensor], + noise_scale: float = 0.5, + speed: float = 1, + sv_emb: torch.Tensor | None = None, + ): + batch_size = int(codes.size(1)) + if batch_size <= 0: + raise ValueError("decode_batched_request_local 收到空 batch") + if len(refer_list) != batch_size: + raise ValueError("refer_list 数量与 batch size 不一致") + + refer_lengths = torch.LongTensor([int(item.size(2)) for item in refer_list]).to(codes.device) + max_refer_len = int(refer_lengths.max().item()) + refer_batch = torch.zeros( + (batch_size, int(refer_list[0].size(1)), max_refer_len), + dtype=refer_list[0].dtype, + device=codes.device, + ) + for batch_index, refer in enumerate(refer_list): + refer_batch[batch_index, :, : int(refer.size(2))] = refer.squeeze(0) + refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, max_refer_len), 1).to(refer_batch.dtype) + if self.version == "v1": + ge = self.ref_enc(refer_batch * refer_mask, refer_mask) + else: + ge = self.ref_enc(refer_batch[:, :704] * refer_mask, refer_mask) + if self.is_v2pro: + if sv_emb is None: + raise ValueError("v2Pro batched request-local synthesis 缺少 sv_emb") + ge = ge + self.sv_emb(sv_emb).unsqueeze(-1) + ge = self.prelu(ge) + + quantized = self.quantizer.decode(codes) + if self.semantic_frame_rate == "25hz": + quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") + y_lengths = code_lengths.to(device=codes.device, dtype=torch.long) * 2 + text_lengths = text_lengths.to(device=text.device, dtype=torch.long) + x, m_p, logs_p, y_mask, _, _ = self.enc_p( + quantized, + y_lengths, + text, + text_lengths, + self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge, + speed, + ) + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + z = self.flow(z_p, y_mask, g=ge, reverse=True) + audio = self.dec((z * y_mask)[:, :, :], g=ge) + upsample_factor = 1 + for up_layer in self.dec.ups: + stride = up_layer.stride[0] if isinstance(up_layer.stride, tuple) else int(up_layer.stride) + upsample_factor *= int(stride) + audio_lengths = y_mask.squeeze(1).sum(dim=1).to(dtype=torch.long) * int(upsample_factor) + return audio, audio_lengths + @torch.no_grad() def decode_streaming(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None): diff --git a/GPT_SoVITS/text/chinese2.py b/GPT_SoVITS/text/chinese2.py index dcce0d96..a5d32490 100644 --- a/GPT_SoVITS/text/chinese2.py +++ b/GPT_SoVITS/text/chinese2.py @@ -1,5 +1,6 @@ import os import re +import time import cn2an from pypinyin import lazy_pinyin, Style @@ -77,6 +78,205 @@ def g2p(text): return phones, word2ph +def _prepare_g2p_segments(segments): + prepared_segments = [] + batch_inputs = [] + for segment in segments: + processed_segment = re.sub("[a-zA-Z]+", "", segment) + seg_cut = psg.lcut(processed_segment) + seg_cut = tone_modifier.pre_merge_for_modify(seg_cut) + prepared_segments.append( + { + "segment": processed_segment, + "seg_cut": seg_cut, + } + ) + if processed_segment: + batch_inputs.append(processed_segment) + return prepared_segments, batch_inputs + + +def _build_segment_from_g2pw(segment: str, seg_cut, pinyins): + phones_list = [] + word2ph = [] + initials = [] + finals = [] + pre_word_length = 0 + for word, pos in seg_cut: + sub_initials = [] + sub_finals = [] + now_word_length = pre_word_length + len(word) + + if pos == "eng": + pre_word_length = now_word_length + continue + + word_pinyins = pinyins[pre_word_length:now_word_length] + word_pinyins = correct_pronunciation(word, word_pinyins) + + for pinyin in word_pinyins: + if pinyin[0].isalpha(): + sub_initials.append(to_initials(pinyin)) + sub_finals.append(to_finals_tone3(pinyin, neutral_tone_with_five=True)) + else: + sub_initials.append(pinyin) + sub_finals.append(pinyin) + + pre_word_length = now_word_length + sub_finals = tone_modifier.modified_tone(word, pos, sub_finals) + sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos) + initials.append(sub_initials) + finals.append(sub_finals) + + initials = sum(initials, []) + finals = sum(finals, []) + for c, v in zip(initials, finals): + raw_pinyin = c + v + if c == v: + assert c in punctuation + phone = [c] + word2ph.append(1) + else: + v_without_tone = v[:-1] + tone = v[-1] + + pinyin = c + v_without_tone + assert tone in "12345" + + if c: + v_rep_map = { + "uei": "ui", + "iou": "iu", + "uen": "un", + } + if v_without_tone in v_rep_map.keys(): + pinyin = c + v_rep_map[v_without_tone] + else: + pinyin_rep_map = { + "ing": "ying", + "i": "yi", + "in": "yin", + "u": "wu", + } + if pinyin in pinyin_rep_map.keys(): + pinyin = pinyin_rep_map[pinyin] + else: + single_rep_map = { + "v": "yu", + "e": "e", + "i": "y", + "u": "w", + } + if pinyin[0] in single_rep_map.keys(): + pinyin = single_rep_map[pinyin[0]] + pinyin[1:] + + assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, segment, raw_pinyin) + new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ") + new_v = new_v + tone + phone = [new_c, new_v] + word2ph.append(len(phone)) + + phones_list += phone + return phones_list, word2ph + + +def _build_segment_without_g2pw(segment: str, seg_cut): + initials = [] + finals = [] + for word, pos in seg_cut: + if pos == "eng": + continue + sub_initials, sub_finals = _get_initials_finals(word) + sub_finals = tone_modifier.modified_tone(word, pos, sub_finals) + sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos) + initials.append(sub_initials) + finals.append(sub_finals) + phones_list = [] + word2ph = [] + for c, v in zip(sum(initials, []), sum(finals, [])): + raw_pinyin = c + v + if c == v: + assert c in punctuation + phone = [c] + word2ph.append(1) + else: + v_without_tone = v[:-1] + tone = v[-1] + pinyin = c + v_without_tone + assert tone in "12345" + if c: + v_rep_map = {"uei": "ui", "iou": "iu", "uen": "un"} + if v_without_tone in v_rep_map: + pinyin = c + v_rep_map[v_without_tone] + else: + pinyin_rep_map = {"ing": "ying", "i": "yi", "in": "yin", "u": "wu"} + if pinyin in pinyin_rep_map: + pinyin = pinyin_rep_map[pinyin] + else: + single_rep_map = {"v": "yu", "e": "e", "i": "y", "u": "w"} + if pinyin[0] in single_rep_map: + pinyin = single_rep_map[pinyin[0]] + pinyin[1:] + assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, segment, raw_pinyin) + new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ") + new_v = new_v + tone + phone = [new_c, new_v] + word2ph.append(len(phone)) + phones_list += phone + return phones_list, word2ph + + +def g2p_segments(segments, return_profile: bool = False): + prepare_start = time.perf_counter() + prepared_segments, batch_inputs = _prepare_g2p_segments(segments) + profile = { + "g2pw_prepare_ms": 0.0, + "g2pw_predict_ms": 0.0, + "g2pw_post_ms": 0.0, + "g2pw_runtime_total_ms": 0.0, + "g2pw_runtime_queue_wait_ms": 0.0, + "g2pw_runtime_collect_wait_ms": 0.0, + "g2pw_runtime_run_ms": 0.0, + "g2pw_runtime_batch_rows": 0.0, + "g2pw_runtime_batch_requests": 0.0, + "g2pw_runtime_pool_workers": 0.0, + "g2pw_runtime_shard_index": 0.0, + } + profile["g2pw_prepare_ms"] = float((time.perf_counter() - prepare_start) * 1000.0) + if is_g2pw and batch_inputs: + converter = g2pw._g2pw + if hasattr(converter, "predict_sentences_with_profile"): + g2pw_batch_results, predict_profile = converter.predict_sentences_with_profile(batch_inputs) + for key, value in dict(predict_profile or {}).items(): + profile[key] = float(value) + else: + predict_start = time.perf_counter() + g2pw_batch_results = converter(batch_inputs) + profile["g2pw_predict_ms"] = float((time.perf_counter() - predict_start) * 1000.0) + else: + g2pw_batch_results = [] + post_start = time.perf_counter() + results = [] + batch_cursor = 0 + for item in prepared_segments: + segment = item["segment"] + if not segment: + results.append(([], [], segment)) + continue + if not is_g2pw: + phones, word2ph = _build_segment_without_g2pw(segment, item["seg_cut"]) + results.append((phones, word2ph, segment)) + continue + pinyins = g2pw_batch_results[batch_cursor] + batch_cursor += 1 + phones, word2ph = _build_segment_from_g2pw(segment, item["seg_cut"], pinyins) + results.append((phones, word2ph, segment)) + profile["g2pw_post_ms"] = float((time.perf_counter() - post_start) * 1000.0) + profile["g2pw_total_ms"] = float(profile["g2pw_prepare_ms"] + profile["g2pw_predict_ms"] + profile["g2pw_post_ms"]) + if return_profile: + return results, profile + return results + + def _get_initials_finals(word): initials = [] finals = [] @@ -180,118 +380,9 @@ def _merge_erhua(initials: list[str], finals: list[str], word: str, pos: str) -> def _g2p(segments): phones_list = [] word2ph = [] - for seg in segments: - pinyins = [] - # Replace all English words in the sentence - seg = re.sub("[a-zA-Z]+", "", seg) - seg_cut = psg.lcut(seg) - seg_cut = tone_modifier.pre_merge_for_modify(seg_cut) - initials = [] - finals = [] - - if not is_g2pw: - for word, pos in seg_cut: - if pos == "eng": - continue - sub_initials, sub_finals = _get_initials_finals(word) - sub_finals = tone_modifier.modified_tone(word, pos, sub_finals) - # 儿化 - sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos) - initials.append(sub_initials) - finals.append(sub_finals) - # assert len(sub_initials) == len(sub_finals) == len(word) - initials = sum(initials, []) - finals = sum(finals, []) - print("pypinyin结果", initials, finals) - else: - # g2pw采用整句推理 - pinyins = g2pw.lazy_pinyin(seg, neutral_tone_with_five=True, style=Style.TONE3) - - pre_word_length = 0 - for word, pos in seg_cut: - sub_initials = [] - sub_finals = [] - now_word_length = pre_word_length + len(word) - - if pos == "eng": - pre_word_length = now_word_length - continue - - word_pinyins = pinyins[pre_word_length:now_word_length] - - # 多音字消歧 - word_pinyins = correct_pronunciation(word, word_pinyins) - - for pinyin in word_pinyins: - if pinyin[0].isalpha(): - sub_initials.append(to_initials(pinyin)) - sub_finals.append(to_finals_tone3(pinyin, neutral_tone_with_five=True)) - else: - sub_initials.append(pinyin) - sub_finals.append(pinyin) - - pre_word_length = now_word_length - sub_finals = tone_modifier.modified_tone(word, pos, sub_finals) - # 儿化 - sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos) - initials.append(sub_initials) - finals.append(sub_finals) - - initials = sum(initials, []) - finals = sum(finals, []) - # print("g2pw结果",initials,finals) - - for c, v in zip(initials, finals): - raw_pinyin = c + v - # NOTE: post process for pypinyin outputs - # we discriminate i, ii and iii - if c == v: - assert c in punctuation - phone = [c] - word2ph.append(1) - else: - v_without_tone = v[:-1] - tone = v[-1] - - pinyin = c + v_without_tone - assert tone in "12345" - - if c: - # 多音节 - v_rep_map = { - "uei": "ui", - "iou": "iu", - "uen": "un", - } - if v_without_tone in v_rep_map.keys(): - pinyin = c + v_rep_map[v_without_tone] - else: - # 单音节 - pinyin_rep_map = { - "ing": "ying", - "i": "yi", - "in": "yin", - "u": "wu", - } - if pinyin in pinyin_rep_map.keys(): - pinyin = pinyin_rep_map[pinyin] - else: - single_rep_map = { - "v": "yu", - "e": "e", - "i": "y", - "u": "w", - } - if pinyin[0] in single_rep_map.keys(): - pinyin = single_rep_map[pinyin[0]] + pinyin[1:] - - assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin) - new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ") - new_v = new_v + tone - phone = [new_c, new_v] - word2ph.append(len(phone)) - - phones_list += phone + for phones, item_word2ph, _segment in g2p_segments(segments): + phones_list += phones + word2ph += item_word2ph return phones_list, word2ph diff --git a/GPT_SoVITS/text/g2pw/cuda_api.py b/GPT_SoVITS/text/g2pw/cuda_api.py new file mode 100644 index 00000000..881d6123 --- /dev/null +++ b/GPT_SoVITS/text/g2pw/cuda_api.py @@ -0,0 +1,685 @@ +import ctypes +import fcntl +import os +import subprocess +import threading +import time +from collections import deque +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Deque, Dict, List, Tuple + +import numpy as np + +from .onnx_api import _G2PWBaseOnnxConverter + + +class G2PWCudaError(RuntimeError): + pass + + +@dataclass +class G2PWBatchTask: + model_input: Dict[str, np.ndarray] + created_at: float = field(default_factory=time.perf_counter) + enqueued_at: float = 0.0 + done_event: threading.Event = field(default_factory=threading.Event) + output: np.ndarray | None = None + profile: Dict[str, float] = field(default_factory=dict) + error: Exception | None = None + + +_ROOT_DIR = Path(__file__).resolve().parents[3] +_PACKAGE_DIR = Path(__file__).resolve().parent +_OUTPUT_DIR = _ROOT_DIR / "outputs" / "g2pw_cuda_bridge" +_WRAPPER_SOURCE = _PACKAGE_DIR / "g2pw_cuda_bridge.cpp" +_LOCK_PATH = _OUTPUT_DIR / "build.lock" + + +def _env_flag(name: str, default: bool) -> int: + raw = os.environ.get(name) + if raw is None: + return 1 if default else 0 + return 0 if raw.strip().lower() in {"0", "false", "no", "off"} else 1 + + +def _env_int(name: str, default: int) -> int: + raw = os.environ.get(name) + if raw is None or raw.strip() == "": + return int(default) + return int(raw) + + +def _resolve_cuda_root() -> Path: + env_root = os.environ.get("GPTSOVITS_G2PW_CUDA_ROOT", "").strip() + candidates = [ + env_root, + _ROOT_DIR / "third_party" / "g2pw-cu", + ] + for candidate in candidates: + if not candidate: + continue + path = Path(candidate).expanduser().resolve() + if path.exists(): + return path + checked = [ + str(Path(candidate).expanduser().resolve()) + for candidate in candidates + if str(candidate).strip() != "" + ] + raise G2PWCudaError( + "Cannot locate g2pw-cu root. " + "Expected one of: " + f"{checked}. " + "Recommended: clone https://github.com/baicai-1145/g2pw-cu.git into " + f"{(_ROOT_DIR / 'third_party' / 'g2pw-cu').as_posix()} " + "or set GPTSOVITS_G2PW_CUDA_ROOT explicitly." + ) + + +def _resolve_runtime_paths() -> tuple[Path, Path, Path]: + cuda_root = _resolve_cuda_root() + runtime_lib = Path( + os.environ.get("GPTSOVITS_G2PW_CUDA_RUNTIME_LIB", str(cuda_root / "build" / "libg2pw_runtime.so")) + ).expanduser() + manifest_path = Path( + os.environ.get("GPTSOVITS_G2PW_CUDA_MANIFEST", str(cuda_root / "artifacts" / "model" / "manifest.txt")) + ).expanduser() + weights_path = Path( + os.environ.get("GPTSOVITS_G2PW_CUDA_WEIGHTS", str(cuda_root / "artifacts" / "model" / "weights.bin")) + ).expanduser() + for path in (runtime_lib, manifest_path, weights_path): + if not path.exists(): + raise G2PWCudaError(f"Missing g2pw-cu artifact: {path}") + return runtime_lib.resolve(), manifest_path.resolve(), weights_path.resolve() + + +def _build_bridge(wrapper_output: Path, runtime_lib: Path) -> None: + _OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + compile_cmd = [ + os.environ.get("CXX", "g++"), + "-O3", + "-std=c++17", + "-shared", + "-fPIC", + str(_WRAPPER_SOURCE), + "-I", + str(runtime_lib.parent.parent / "include"), + "-L", + str(runtime_lib.parent), + "-lg2pw_runtime", + f"-Wl,-rpath,{runtime_lib.parent}", + "-o", + str(wrapper_output), + ] + result = subprocess.run(compile_cmd, capture_output=True, text=True, check=False) + if result.returncode != 0: + raise G2PWCudaError( + "Failed to build g2pw-cu bridge:\n" + f"cmd={' '.join(compile_cmd)}\n" + f"stdout={result.stdout}\n" + f"stderr={result.stderr}" + ) + + +def _ensure_bridge_built(runtime_lib: Path) -> Path: + wrapper_output = _OUTPUT_DIR / "g2pw_cuda_bridge.so" + _OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + with _LOCK_PATH.open("w", encoding="utf-8") as lock_file: + fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) + needs_build = not wrapper_output.exists() + if not needs_build: + so_mtime = wrapper_output.stat().st_mtime + needs_build = so_mtime < _WRAPPER_SOURCE.stat().st_mtime or so_mtime < runtime_lib.stat().st_mtime + if needs_build: + tmp_output = wrapper_output.with_suffix(".tmp.so") + if tmp_output.exists(): + tmp_output.unlink() + _build_bridge(tmp_output, runtime_lib) + tmp_output.replace(wrapper_output) + return wrapper_output + + +def _load_bridge(): + runtime_lib, manifest_path, weights_path = _resolve_runtime_paths() + bridge_path = _ensure_bridge_built(runtime_lib) + global_mode = getattr(ctypes, "RTLD_GLOBAL", getattr(os, "RTLD_GLOBAL", 0)) + ctypes.CDLL(str(runtime_lib), mode=global_mode) + lib = ctypes.CDLL(str(bridge_path)) + lib.g2pw_runtime_create.argtypes = [ + ctypes.c_char_p, + ctypes.c_char_p, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ] + lib.g2pw_runtime_create.restype = ctypes.c_void_p + lib.g2pw_runtime_destroy.argtypes = [ctypes.c_void_p] + lib.g2pw_runtime_destroy.restype = None + lib.g2pw_runtime_last_error.argtypes = [ctypes.c_void_p] + lib.g2pw_runtime_last_error.restype = ctypes.c_char_p + lib.g2pw_runtime_num_labels.argtypes = [ctypes.c_void_p] + lib.g2pw_runtime_num_labels.restype = ctypes.c_int + lib.g2pw_runtime_run.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_int32, + ctypes.c_int32, + ctypes.c_void_p, + ] + lib.g2pw_runtime_run.restype = ctypes.c_int + return lib, manifest_path, weights_path, runtime_lib + + +def _gemm_precision_value() -> int: + precision = os.environ.get("GPTSOVITS_G2PW_CUDA_GEMM_PRECISION", "fp32").strip().lower() + if precision == "fp16": + return 1 + if precision == "bf16": + return 2 + return 0 + + +class G2PWRuntimeWrapper: + def __init__(self, shard_index: int = 0) -> None: + self.lib, self.manifest_path, self.weights_path, self.runtime_lib = _load_bridge() + self.shard_index = int(shard_index) + self.device_ordinal = _env_int("GPTSOVITS_G2PW_CUDA_DEVICE", 0) + self.allow_tensor_cores = _env_flag("GPTSOVITS_G2PW_CUDA_ALLOW_TENSOR_CORES", False) + self.use_cublaslt_bias_epilogue = _env_flag("GPTSOVITS_G2PW_CUDA_USE_CUBLASLT_BIAS_EPILOGUE", False) + self.enable_profiling = _env_flag("GPTSOVITS_G2PW_CUDA_ENABLE_PROFILE", False) + self.enable_cuda_graph = _env_flag("GPTSOVITS_G2PW_CUDA_ENABLE_GRAPH", True) + self.dump_graph_cache_stats = _env_flag("GPTSOVITS_G2PW_CUDA_DUMP_GRAPH_CACHE_STATS", False) + self.full_graph_cache_limit = _env_int("GPTSOVITS_G2PW_CUDA_FULL_GRAPH_CACHE_LIMIT", 0) + self.tail_graph_cache_limit = _env_int("GPTSOVITS_G2PW_CUDA_TAIL_GRAPH_CACHE_LIMIT", 0) + self.gemm_precision = _gemm_precision_value() + self.lock = threading.Lock() + self.handle = None + self.max_batch_size = 0 + self.max_seq_len = 0 + self.num_labels = 0 + self.batch_enabled = _env_flag("GPTSOVITS_G2PW_CUDA_BATCHING", True) != 0 + self.batch_window_s = max(0.0, float(_env_int("GPTSOVITS_G2PW_CUDA_BATCH_WINDOW_MS", 1)) / 1000.0) + self.batch_max_requests = max(1, _env_int("GPTSOVITS_G2PW_CUDA_BATCH_MAX_REQUESTS", 64)) + self.batch_max_rows = max(1, _env_int("GPTSOVITS_G2PW_CUDA_BATCH_MAX_ROWS", 96)) + self.batch_max_tokens = max(1, _env_int("GPTSOVITS_G2PW_CUDA_BATCH_MAX_TOKENS", 4096)) + self.batch_condition = threading.Condition() + self.pending_tasks: Deque[G2PWBatchTask] = deque() + self.batch_total_tasks = 0 + self.batch_total_batches = 0 + self.batch_total_rows = 0 + self.batch_total_queue_wait_ms = 0.0 + self.batch_queue_wait_peak_ms = 0.0 + self.batch_total_collect_wait_ms = 0.0 + self.batch_collect_wait_peak_ms = 0.0 + self.batch_total_run_ms = 0.0 + self.batch_run_peak_ms = 0.0 + self.batch_rows_peak = 0 + self.batch_requests_peak = 0 + self.batch_pending_peak = 0 + self.closed = False + self._ensure_capacity( + batch_size=max(1, _env_int("GPTSOVITS_G2PW_CUDA_MAX_BATCH_SIZE", 256)), + seq_len=max(1, _env_int("GPTSOVITS_G2PW_CUDA_MAX_SEQ_LEN", 128)), + ) + self.batch_worker = None + if self.batch_enabled: + self.batch_worker = threading.Thread( + target=self._batch_loop, + name=f"g2pw-cuda-batch-worker-{self.shard_index}", + daemon=True, + ) + self.batch_worker.start() + + def _sync_runtime_env_overrides(self) -> None: + os.environ["G2PW_ENABLE_CUDA_GRAPH"] = "1" if self.enable_cuda_graph else "0" + os.environ["G2PW_ENABLE_PROFILE"] = "1" if self.enable_profiling else "0" + os.environ["G2PW_DUMP_GRAPH_CACHE_STATS"] = "1" if self.dump_graph_cache_stats else "0" + os.environ["G2PW_FULL_GRAPH_CACHE_LIMIT"] = str(int(self.full_graph_cache_limit)) + os.environ["G2PW_TAIL_GRAPH_CACHE_LIMIT"] = str(int(self.tail_graph_cache_limit)) + os.environ["G2PW_ALLOW_TENSOR_CORES"] = "1" if self.allow_tensor_cores else "0" + os.environ["G2PW_USE_CUBLASLT_BIAS_EPILOGUE"] = "1" if self.use_cublaslt_bias_epilogue else "0" + os.environ["G2PW_GEMM_PRECISION"] = {0: "fp32", 1: "fp16", 2: "bf16"}.get(int(self.gemm_precision), "fp32") + + def _destroy_handle(self) -> None: + if self.handle: + self.lib.g2pw_runtime_destroy(self.handle) + self.handle = None + + def close(self) -> None: + with self.batch_condition: + self.closed = True + self.batch_condition.notify_all() + self._destroy_handle() + + def __del__(self): + try: + self.close() + except Exception: + pass + + def _last_error(self) -> str: + if not self.handle: + return "uninitialized runtime" + message = self.lib.g2pw_runtime_last_error(self.handle) + return "" if not message else message.decode("utf-8", errors="replace") + + def _create_handle(self, batch_size: int, seq_len: int) -> None: + self._sync_runtime_env_overrides() + new_handle = self.lib.g2pw_runtime_create( + str(self.manifest_path).encode("utf-8"), + str(self.weights_path).encode("utf-8"), + int(self.device_ordinal), + int(batch_size), + int(seq_len), + int(self.full_graph_cache_limit), + int(self.tail_graph_cache_limit), + int(self.allow_tensor_cores), + int(self.use_cublaslt_bias_epilogue), + int(self.enable_profiling), + int(self.enable_cuda_graph), + int(self.dump_graph_cache_stats), + int(self.gemm_precision), + ) + if not new_handle: + raise G2PWCudaError("g2pw-cu returned null runtime handle") + self.handle = new_handle + self.max_batch_size = int(batch_size) + self.max_seq_len = int(seq_len) + self.num_labels = int(self.lib.g2pw_runtime_num_labels(self.handle)) + last_error = self._last_error() + if self.num_labels <= 0 or last_error: + self.close() + raise G2PWCudaError(f"Failed to initialize g2pw-cu runtime: {last_error or 'num_labels <= 0'}") + + def _ensure_capacity(self, batch_size: int, seq_len: int) -> None: + target_batch = max(1, int(batch_size)) + target_seq = max(1, int(seq_len)) + if self.handle and target_batch <= self.max_batch_size and target_seq <= self.max_seq_len: + return + next_batch = max(target_batch, self.max_batch_size * 2 if self.max_batch_size else 0) + next_seq = max(target_seq, self.max_seq_len * 2 if self.max_seq_len else 0) + self._destroy_handle() + self._create_handle(batch_size=next_batch, seq_len=next_seq) + + @staticmethod + def _normalize_model_input(model_input: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + input_ids = np.ascontiguousarray(model_input["input_ids"], dtype=np.int64) + token_type_ids = np.ascontiguousarray(model_input["token_type_ids"], dtype=np.int64) + attention_masks = np.ascontiguousarray(model_input["attention_masks"], dtype=np.int64) + phoneme_masks = np.ascontiguousarray(model_input["phoneme_masks"], dtype=np.float32) + char_ids = np.ascontiguousarray(model_input["char_ids"], dtype=np.int64) + position_ids = np.ascontiguousarray(model_input["position_ids"], dtype=np.int64) + batch_size = int(char_ids.shape[0]) + if input_ids.shape[0] == 1 and batch_size > 1: + input_ids = np.ascontiguousarray(np.repeat(input_ids, batch_size, axis=0), dtype=np.int64) + token_type_ids = np.ascontiguousarray(np.repeat(token_type_ids, batch_size, axis=0), dtype=np.int64) + attention_masks = np.ascontiguousarray(np.repeat(attention_masks, batch_size, axis=0), dtype=np.int64) + return { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + "attention_masks": attention_masks, + "phoneme_masks": phoneme_masks, + "char_ids": char_ids, + "position_ids": position_ids, + } + + def _run_direct(self, model_input: Dict[str, np.ndarray]) -> np.ndarray: + normalized = self._normalize_model_input(model_input) + input_ids = normalized["input_ids"] + token_type_ids = normalized["token_type_ids"] + attention_masks = normalized["attention_masks"] + phoneme_masks = normalized["phoneme_masks"] + char_ids = normalized["char_ids"] + position_ids = normalized["position_ids"] + batch_size = int(char_ids.shape[0]) + seq_len = int(input_ids.shape[1]) + probs = np.empty((batch_size, self.num_labels), dtype=np.float32) + with self.lock: + self._ensure_capacity(batch_size=batch_size, seq_len=seq_len) + status = self.lib.g2pw_runtime_run( + self.handle, + input_ids.ctypes.data_as(ctypes.c_void_p), + token_type_ids.ctypes.data_as(ctypes.c_void_p), + attention_masks.ctypes.data_as(ctypes.c_void_p), + phoneme_masks.ctypes.data_as(ctypes.c_void_p), + char_ids.ctypes.data_as(ctypes.c_void_p), + position_ids.ctypes.data_as(ctypes.c_void_p), + batch_size, + seq_len, + probs.ctypes.data_as(ctypes.c_void_p), + ) + if int(status) != 0: + raise G2PWCudaError(f"g2pw-cu inference failed: {self._last_error()}") + return probs + + def _can_append_task(self, tasks: List[G2PWBatchTask], candidate: G2PWBatchTask) -> bool: + request_count = len(tasks) + 1 + if request_count > self.batch_max_requests: + return False + total_rows = sum(int(item.model_input["char_ids"].shape[0]) for item in tasks) + int( + candidate.model_input["char_ids"].shape[0] + ) + if total_rows > self.batch_max_rows: + return False + total_tokens = sum( + int(item.model_input["char_ids"].shape[0]) * int(item.model_input["input_ids"].shape[1]) for item in tasks + ) + int(candidate.model_input["char_ids"].shape[0]) * int(candidate.model_input["input_ids"].shape[1]) + return total_tokens <= self.batch_max_tokens + + def _merge_batch_inputs(self, tasks: List[G2PWBatchTask]) -> Tuple[Dict[str, np.ndarray], List[Tuple[int, int]]]: + normalized_inputs = [self._normalize_model_input(task.model_input) for task in tasks] + total_rows = sum(int(item["char_ids"].shape[0]) for item in normalized_inputs) + max_seq_len = max(int(item["input_ids"].shape[1]) for item in normalized_inputs) + input_ids = np.zeros((total_rows, max_seq_len), dtype=np.int64) + token_type_ids = np.zeros((total_rows, max_seq_len), dtype=np.int64) + attention_masks = np.zeros((total_rows, max_seq_len), dtype=np.int64) + phoneme_masks = np.zeros((total_rows, normalized_inputs[0]["phoneme_masks"].shape[1]), dtype=np.float32) + char_ids = np.zeros((total_rows,), dtype=np.int64) + position_ids = np.zeros((total_rows,), dtype=np.int64) + slices: List[Tuple[int, int]] = [] + cursor = 0 + for item in normalized_inputs: + rows = int(item["char_ids"].shape[0]) + seq_len = int(item["input_ids"].shape[1]) + next_cursor = cursor + rows + input_ids[cursor:next_cursor, :seq_len] = item["input_ids"] + token_type_ids[cursor:next_cursor, :seq_len] = item["token_type_ids"] + attention_masks[cursor:next_cursor, :seq_len] = item["attention_masks"] + phoneme_masks[cursor:next_cursor] = item["phoneme_masks"] + char_ids[cursor:next_cursor] = item["char_ids"] + position_ids[cursor:next_cursor] = item["position_ids"] + slices.append((cursor, next_cursor)) + cursor = next_cursor + return { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + "attention_masks": attention_masks, + "phoneme_masks": phoneme_masks, + "char_ids": char_ids, + "position_ids": position_ids, + }, slices + + def _finish_task( + self, + task: G2PWBatchTask, + output: np.ndarray | None = None, + profile: Dict[str, float] | None = None, + error: Exception | None = None, + ) -> None: + task.output = output + task.profile = dict(profile or {}) + task.error = error + task.done_event.set() + + def _batch_loop(self) -> None: + while True: + with self.batch_condition: + while not self.pending_tasks and not self.closed: + self.batch_condition.wait() + if self.closed and not self.pending_tasks: + return + first_task = self.pending_tasks.popleft() + batch_tasks = [first_task] + collect_started = time.perf_counter() + deadline = collect_started + self.batch_window_s + while True: + if len(batch_tasks) >= self.batch_max_requests: + break + remaining = deadline - time.perf_counter() + if remaining <= 0.0: + break + if not self.pending_tasks: + self.batch_condition.wait(timeout=remaining) + continue + candidate = self.pending_tasks[0] + if not self._can_append_task(batch_tasks, candidate): + break + batch_tasks.append(self.pending_tasks.popleft()) + collect_wait_ms = max(0.0, (time.perf_counter() - collect_started) * 1000.0) + + now = time.perf_counter() + queue_wait_values = [max(0.0, (now - task.enqueued_at) * 1000.0) for task in batch_tasks] + try: + merged_input, row_slices = self._merge_batch_inputs(batch_tasks) + run_started = time.perf_counter() + merged_output = self._run_direct(merged_input) + run_ms = max(0.0, (time.perf_counter() - run_started) * 1000.0) + for task, (start, end) in zip(batch_tasks, row_slices): + task_rows = int(task.model_input["char_ids"].shape[0]) + task_seq_len = int(task.model_input["input_ids"].shape[1]) + self._finish_task( + task, + output=np.ascontiguousarray(merged_output[start:end]), + profile={ + "g2pw_runtime_queue_wait_ms": float(max(0.0, (run_started - task.enqueued_at) * 1000.0)), + "g2pw_runtime_collect_wait_ms": float(collect_wait_ms), + "g2pw_runtime_run_ms": float(run_ms), + "g2pw_runtime_batch_rows": float(sum(int(item.model_input["char_ids"].shape[0]) for item in batch_tasks)), + "g2pw_runtime_batch_requests": float(len(batch_tasks)), + "g2pw_runtime_task_rows": float(task_rows), + "g2pw_runtime_task_seq_len": float(task_seq_len), + "g2pw_runtime_shard_index": float(self.shard_index), + }, + ) + except Exception as exc: + run_ms = 0.0 + for task in batch_tasks: + self._finish_task(task, error=exc) + finally: + with self.batch_condition: + self.batch_total_batches += 1 + self.batch_total_tasks += len(batch_tasks) + self.batch_total_rows += sum(int(task.model_input["char_ids"].shape[0]) for task in batch_tasks) + self.batch_total_queue_wait_ms += float(sum(queue_wait_values)) + self.batch_queue_wait_peak_ms = max(self.batch_queue_wait_peak_ms, max(queue_wait_values or [0.0])) + self.batch_total_collect_wait_ms += float(collect_wait_ms) * float(len(batch_tasks)) + self.batch_collect_wait_peak_ms = max(self.batch_collect_wait_peak_ms, float(collect_wait_ms)) + self.batch_total_run_ms += float(run_ms) + self.batch_run_peak_ms = max(self.batch_run_peak_ms, float(run_ms)) + self.batch_rows_peak = max( + self.batch_rows_peak, sum(int(task.model_input["char_ids"].shape[0]) for task in batch_tasks) + ) + self.batch_requests_peak = max(self.batch_requests_peak, len(batch_tasks)) + + def _submit_batched(self, model_input: Dict[str, np.ndarray]) -> tuple[np.ndarray, Dict[str, float]]: + task = G2PWBatchTask(model_input=model_input) + with self.batch_condition: + if self.closed: + raise G2PWCudaError("g2pw-cu batch worker already closed") + task.enqueued_at = time.perf_counter() + self.pending_tasks.append(task) + self.batch_pending_peak = max(self.batch_pending_peak, len(self.pending_tasks)) + self.batch_condition.notify_all() + task.done_event.wait() + if task.error is not None: + raise task.error + assert task.output is not None + return task.output, dict(task.profile) + + def snapshot(self) -> Dict[str, float | int | bool]: + with self.batch_condition: + average_tasks_per_batch = ( + float(self.batch_total_tasks) / float(self.batch_total_batches) if self.batch_total_batches > 0 else 0.0 + ) + average_rows_per_batch = ( + float(self.batch_total_rows) / float(self.batch_total_batches) if self.batch_total_batches > 0 else 0.0 + ) + average_queue_wait_ms = ( + float(self.batch_total_queue_wait_ms) / float(self.batch_total_tasks) if self.batch_total_tasks > 0 else 0.0 + ) + average_collect_wait_ms = ( + float(self.batch_total_collect_wait_ms) / float(self.batch_total_tasks) + if self.batch_total_tasks > 0 + else 0.0 + ) + return { + "shard_index": int(self.shard_index), + "enabled": bool(self.batch_enabled), + "enable_cuda_graph": bool(self.enable_cuda_graph), + "enable_profiling": bool(self.enable_profiling), + "full_graph_cache_limit": int(self.full_graph_cache_limit), + "tail_graph_cache_limit": int(self.tail_graph_cache_limit), + "window_ms": float(self.batch_window_s * 1000.0), + "max_requests": int(self.batch_max_requests), + "max_rows": int(self.batch_max_rows), + "max_tokens": int(self.batch_max_tokens), + "pending": int(len(self.pending_tasks)), + "pending_peak": int(self.batch_pending_peak), + "total_batches": int(self.batch_total_batches), + "total_tasks": int(self.batch_total_tasks), + "total_rows": int(self.batch_total_rows), + "avg_tasks_per_batch": float(average_tasks_per_batch), + "avg_rows_per_batch": float(average_rows_per_batch), + "avg_queue_wait_ms": float(average_queue_wait_ms), + "queue_wait_peak_ms": float(self.batch_queue_wait_peak_ms), + "avg_collect_wait_ms": float(average_collect_wait_ms), + "collect_wait_peak_ms": float(self.batch_collect_wait_peak_ms), + "run_total_ms": float(self.batch_total_run_ms), + "run_peak_ms": float(self.batch_run_peak_ms), + "batch_rows_peak": int(self.batch_rows_peak), + "batch_requests_peak": int(self.batch_requests_peak), + } + + def pending_rows(self) -> int: + with self.batch_condition: + return int(sum(int(task.model_input["char_ids"].shape[0]) for task in self.pending_tasks)) + + def pending_count(self) -> int: + with self.batch_condition: + return int(len(self.pending_tasks)) + + def run_with_profile(self, model_input: Dict[str, np.ndarray]) -> tuple[np.ndarray, Dict[str, float]]: + if not self.batch_enabled: + started = time.perf_counter() + output = self._run_direct(model_input) + return output, { + "g2pw_runtime_queue_wait_ms": 0.0, + "g2pw_runtime_collect_wait_ms": 0.0, + "g2pw_runtime_run_ms": float((time.perf_counter() - started) * 1000.0), + "g2pw_runtime_batch_rows": float(model_input["char_ids"].shape[0]), + "g2pw_runtime_batch_requests": 1.0, + "g2pw_runtime_task_rows": float(model_input["char_ids"].shape[0]), + "g2pw_runtime_task_seq_len": float(model_input["input_ids"].shape[1]), + "g2pw_runtime_shard_index": float(self.shard_index), + } + return self._submit_batched(model_input) + + def run(self, model_input: Dict[str, np.ndarray]) -> np.ndarray: + output, _profile = self.run_with_profile(model_input) + return output + + +class G2PWRuntimePool: + def __init__(self) -> None: + self.worker_count = max(1, _env_int("GPTSOVITS_G2PW_CUDA_WORKERS", 2)) + self.shards = [G2PWRuntimeWrapper(shard_index=index) for index in range(self.worker_count)] + self.lock = threading.Lock() + + def _pick_shard(self) -> G2PWRuntimeWrapper: + with self.lock: + return min( + self.shards, + key=lambda shard: ( + shard.pending_rows(), + shard.pending_count(), + shard.snapshot().get("avg_queue_wait_ms", 0.0), + ), + ) + + def run_with_profile(self, model_input: Dict[str, np.ndarray]) -> tuple[np.ndarray, Dict[str, float]]: + shard = self._pick_shard() + output, profile = shard.run_with_profile(model_input) + profile["g2pw_runtime_pool_workers"] = float(self.worker_count) + return output, profile + + def run(self, model_input: Dict[str, np.ndarray]) -> np.ndarray: + output, _profile = self.run_with_profile(model_input) + return output + + def snapshot(self) -> Dict[str, float | int | bool | List[Dict[str, float | int | bool]]]: + shard_snapshots = [dict(shard.snapshot()) for shard in self.shards] + avg_queue_wait_ms = 0.0 + total_tasks = 0.0 + pending = 0 + pending_peak = 0 + total_batches = 0 + total_rows = 0 + batch_rows_peak = 0 + batch_requests_peak = 0 + for snapshot in shard_snapshots: + tasks = float(snapshot.get("total_tasks", 0.0)) + avg_queue_wait_ms += float(snapshot.get("avg_queue_wait_ms", 0.0)) * tasks + total_tasks += tasks + pending += int(snapshot.get("pending", 0)) + pending_peak = max(pending_peak, int(snapshot.get("pending_peak", 0))) + total_batches += int(snapshot.get("total_batches", 0)) + total_rows += int(snapshot.get("total_rows", 0)) + batch_rows_peak = max(batch_rows_peak, int(snapshot.get("batch_rows_peak", 0))) + batch_requests_peak = max(batch_requests_peak, int(snapshot.get("batch_requests_peak", 0))) + return { + "worker_count": int(self.worker_count), + "pending": int(pending), + "pending_peak": int(pending_peak), + "total_batches": int(total_batches), + "total_tasks": int(total_tasks), + "total_rows": int(total_rows), + "avg_queue_wait_ms": float(avg_queue_wait_ms / total_tasks) if total_tasks > 0 else 0.0, + "batch_rows_peak": int(batch_rows_peak), + "batch_requests_peak": int(batch_requests_peak), + "shards": shard_snapshots, + } + + +class G2PWCudaConverter(_G2PWBaseOnnxConverter): + def __init__( + self, + model_dir: str = "G2PWModel/", + style: str = "bopomofo", + model_source: str = None, + enable_non_tradional_chinese: bool = False, + ): + super().__init__( + model_dir=model_dir, + style=style, + model_source=model_source, + enable_non_tradional_chinese=enable_non_tradional_chinese, + ) + self.runtime = G2PWRuntimePool() + self.backend = "cuda" + primary_runtime = self.runtime.shards[0] + self.device = f"cuda:{primary_runtime.device_ordinal}" + self.checkpoint_path = str(primary_runtime.weights_path) + self.providers = ["g2pw-cu"] + + def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]: + probs = self.runtime.run(model_input) + preds = np.argmax(probs, axis=1).tolist() + confidences = probs[np.arange(len(preds)), preds].astype(np.float32, copy=False).tolist() + return [self.labels[pred] for pred in preds], confidences + + def _predict_with_profile(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float], Dict[str, float]]: + started = time.perf_counter() + probs, runtime_profile = self.runtime.run_with_profile(model_input) + preds = np.argmax(probs, axis=1).tolist() + confidences = probs[np.arange(len(preds)), preds].astype(np.float32, copy=False).tolist() + profile = dict(runtime_profile) + profile["g2pw_runtime_total_ms"] = float((time.perf_counter() - started) * 1000.0) + profile["g2pw_predict_ms"] = float(profile["g2pw_runtime_total_ms"]) + return [self.labels[pred] for pred in preds], confidences, profile + + def snapshot(self) -> Dict[str, float | int | bool]: + return dict(self.runtime.snapshot()) diff --git a/GPT_SoVITS/text/g2pw/dataset.py b/GPT_SoVITS/text/g2pw/dataset.py index ff09cbc2..e464c29a 100644 --- a/GPT_SoVITS/text/g2pw/dataset.py +++ b/GPT_SoVITS/text/g2pw/dataset.py @@ -18,6 +18,7 @@ Credits from typing import Dict from typing import List +from typing import Optional from typing import Tuple import numpy as np @@ -37,6 +38,8 @@ def prepare_onnx_input( use_mask: bool = False, window_size: int = None, max_len: int = 512, + char2id: Optional[Dict[str, int]] = None, + char_phoneme_masks: Optional[Dict[str, List[int]]] = None, ) -> Dict[str, np.array]: if window_size is not None: truncated_texts, truncated_query_ids = _truncate_texts( @@ -48,33 +51,88 @@ def prepare_onnx_input( phoneme_masks = [] char_ids = [] position_ids = [] + tokenized_cache = {} + + if char2id is None: + char2id = {char: idx for idx, char in enumerate(chars)} + if use_mask: + if char_phoneme_masks is None: + char_phoneme_masks = { + char: [1 if i in char2phonemes[char] else 0 for i in range(len(labels))] + for char in char2phonemes + } + else: + full_phoneme_mask = [1] * len(labels) for idx in range(len(texts)): text = (truncated_texts if window_size else texts)[idx].lower() query_id = (truncated_query_ids if window_size else query_ids)[idx] - try: - tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text) - except Exception: - print(f'warning: text "{text}" is invalid') - return {} + cached = tokenized_cache.get(text) + if cached is None: + try: + tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text) + except Exception: + print(f'warning: text "{text}" is invalid') + return {} - text, query_id, tokens, text2token, token2text = _truncate( - max_len=max_len, text=text, query_id=query_id, tokens=tokens, text2token=text2token, token2text=token2text - ) + if len(tokens) <= max_len - 2: + processed_tokens = ["[CLS]"] + tokens + ["[SEP]"] + shared_input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens))) + shared_token_type_id = list(np.zeros((len(processed_tokens),), dtype=int)) + shared_attention_mask = list(np.ones((len(processed_tokens),), dtype=int)) + cached = { + "is_short": True, + "tokens": tokens, + "text2token": text2token, + "token2text": token2text, + "input_id": shared_input_id, + "token_type_id": shared_token_type_id, + "attention_mask": shared_attention_mask, + } + else: + cached = { + "is_short": False, + "tokens": tokens, + "text2token": text2token, + "token2text": token2text, + } + tokenized_cache[text] = cached - processed_tokens = ["[CLS]"] + tokens + ["[SEP]"] + if cached["is_short"]: + text_for_query = text + query_id_for_query = query_id + text2token_for_query = cached["text2token"] + input_id = cached["input_id"] + token_type_id = cached["token_type_id"] + attention_mask = cached["attention_mask"] + else: + ( + text_for_query, + query_id_for_query, + tokens_for_query, + text2token_for_query, + _token2text_for_query, + ) = _truncate( + max_len=max_len, + text=text, + query_id=query_id, + tokens=cached["tokens"], + text2token=cached["text2token"], + token2text=cached["token2text"], + ) + processed_tokens = ["[CLS]"] + tokens_for_query + ["[SEP]"] + input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens))) + token_type_id = list(np.zeros((len(processed_tokens),), dtype=int)) + attention_mask = list(np.ones((len(processed_tokens),), dtype=int)) - input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens))) - token_type_id = list(np.zeros((len(processed_tokens),), dtype=int)) - attention_mask = list(np.ones((len(processed_tokens),), dtype=int)) - - query_char = text[query_id] - phoneme_mask = ( - [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] if use_mask else [1] * len(labels) - ) - char_id = chars.index(query_char) - position_id = text2token[query_id] + 1 # [CLS] token locate at first place + query_char = text_for_query[query_id_for_query] + if use_mask: + phoneme_mask = char_phoneme_masks[query_char] + else: + phoneme_mask = full_phoneme_mask + char_id = char2id[query_char] + position_id = text2token_for_query[query_id_for_query] + 1 # [CLS] token locate at first place input_ids.append(input_id) token_type_ids.append(token_type_id) @@ -83,10 +141,15 @@ def prepare_onnx_input( char_ids.append(char_id) position_ids.append(position_id) + max_token_length = max(len(seq) for seq in input_ids) + + def _pad_sequences(sequences, pad_value=0): + return [seq + [pad_value] * (max_token_length - len(seq)) for seq in sequences] + outputs = { - "input_ids": np.array(input_ids).astype(np.int64), - "token_type_ids": np.array(token_type_ids).astype(np.int64), - "attention_masks": np.array(attention_masks).astype(np.int64), + "input_ids": np.array(_pad_sequences(input_ids, pad_value=0)).astype(np.int64), + "token_type_ids": np.array(_pad_sequences(token_type_ids, pad_value=0)).astype(np.int64), + "attention_masks": np.array(_pad_sequences(attention_masks, pad_value=0)).astype(np.int64), "phoneme_masks": np.array(phoneme_masks).astype(np.float32), "char_ids": np.array(char_ids).astype(np.int64), "position_ids": np.array(position_ids).astype(np.int64), diff --git a/GPT_SoVITS/text/g2pw/g2pw.py b/GPT_SoVITS/text/g2pw/g2pw.py index 08525e91..ccd05a1b 100644 --- a/GPT_SoVITS/text/g2pw/g2pw.py +++ b/GPT_SoVITS/text/g2pw/g2pw.py @@ -8,6 +8,7 @@ from pypinyin.core import Pinyin, Style from pypinyin.seg.simpleseg import simple_seg from pypinyin.converter import UltimateConverter from pypinyin.contrib.tone_convert import to_tone +from .cuda_api import G2PWCudaConverter from .onnx_api import G2PWOnnxConverter current_file_path = os.path.dirname(__file__) @@ -27,12 +28,36 @@ class G2PWPinyin(Pinyin): tone_sandhi=False, **kwargs, ): - self._g2pw = G2PWOnnxConverter( - model_dir=model_dir, - style="pinyin", - model_source=model_source, - enable_non_tradional_chinese=enable_non_tradional_chinese, - ) + backend = os.environ.get("GPTSOVITS_G2PW_BACKEND", "cuda").strip().lower() + last_error = None + self._g2pw = None + if backend in {"cuda", "auto"}: + try: + self._g2pw = G2PWCudaConverter( + model_dir=model_dir, + style="pinyin", + model_source=model_source, + enable_non_tradional_chinese=enable_non_tradional_chinese, + ) + except Exception as exc: + last_error = exc + strict_mode = os.environ.get("GPTSOVITS_G2PW_CUDA_STRICT", "0").strip().lower() in { + "1", + "true", + "yes", + "on", + } + if backend == "cuda" and strict_mode: + raise + if self._g2pw is None: + self._g2pw = G2PWOnnxConverter( + model_dir=model_dir, + style="pinyin", + model_source=model_source, + enable_non_tradional_chinese=enable_non_tradional_chinese, + ) + if last_error is not None: + print(f"[g2pw] cuda backend unavailable, fallback to onnx: {last_error}") self._converter = Converter( self._g2pw, v_to_u=v_to_u, diff --git a/GPT_SoVITS/text/g2pw/g2pw_cuda_bridge.cpp b/GPT_SoVITS/text/g2pw/g2pw_cuda_bridge.cpp new file mode 100644 index 00000000..dc8f29a8 --- /dev/null +++ b/GPT_SoVITS/text/g2pw/g2pw_cuda_bridge.cpp @@ -0,0 +1,183 @@ +#include +#include +#include +#include + +#include "g2pw/runtime.h" + +namespace { + +struct G2PWRuntimeHandle { + std::unique_ptr runtime; + std::string last_error; + int num_labels = 0; +}; + +void SetError(G2PWRuntimeHandle* handle, const g2pw::Status& status) { + if (handle == nullptr) { + return; + } + handle->last_error = status.message; +} + +g2pw::RuntimeConfig BuildConfig( + int device_ordinal, + int max_batch_size, + int max_seq_len, + int full_graph_cache_limit, + int tail_graph_cache_limit, + int allow_tensor_cores, + int use_cublaslt_bias_epilogue, + int enable_profiling, + int enable_cuda_graph, + int dump_graph_cache_stats, + int gemm_precision) { + g2pw::RuntimeConfig config{}; + config.device_ordinal = device_ordinal; + config.max_batch_size = max_batch_size; + config.max_seq_len = max_seq_len; + config.full_graph_cache_limit = full_graph_cache_limit; + config.tail_graph_cache_limit = tail_graph_cache_limit; + config.allow_tensor_cores = allow_tensor_cores != 0; + config.use_cublaslt_bias_epilogue = use_cublaslt_bias_epilogue != 0; + config.enable_profiling = enable_profiling != 0; + config.enable_cuda_graph = enable_cuda_graph != 0; + config.dump_graph_cache_stats = dump_graph_cache_stats != 0; + switch (gemm_precision) { + case 1: + config.gemm_precision = g2pw::GemmPrecision::kFp16; + break; + case 2: + config.gemm_precision = g2pw::GemmPrecision::kBf16; + break; + default: + config.gemm_precision = g2pw::GemmPrecision::kFp32; + break; + } + return config; +} + +} // namespace + +extern "C" { + +void* g2pw_runtime_create( + const char* manifest_path, + const char* binary_path, + int device_ordinal, + int max_batch_size, + int max_seq_len, + int full_graph_cache_limit, + int tail_graph_cache_limit, + int allow_tensor_cores, + int use_cublaslt_bias_epilogue, + int enable_profiling, + int enable_cuda_graph, + int dump_graph_cache_stats, + int gemm_precision) { + auto* handle = new G2PWRuntimeHandle(); + try { + if (manifest_path == nullptr || binary_path == nullptr) { + handle->last_error = "manifest_path and binary_path must be non-null"; + return handle; + } + g2pw::RuntimeConfig config = BuildConfig( + device_ordinal, + max_batch_size, + max_seq_len, + full_graph_cache_limit, + tail_graph_cache_limit, + allow_tensor_cores, + use_cublaslt_bias_epilogue, + enable_profiling, + enable_cuda_graph, + dump_graph_cache_stats, + gemm_precision); + g2pw::Status status = g2pw::Runtime::Create( + config, + std::string(manifest_path), + std::string(binary_path), + &handle->runtime); + if (!status.ok()) { + SetError(handle, status); + return handle; + } + handle->num_labels = handle->runtime != nullptr ? handle->runtime->weights().manifest().num_labels : 0; + handle->last_error.clear(); + return handle; + } catch (const std::exception& exc) { + handle->last_error = exc.what(); + return handle; + } catch (...) { + handle->last_error = "unknown exception"; + return handle; + } +} + +void g2pw_runtime_destroy(void* raw_handle) { + auto* handle = static_cast(raw_handle); + delete handle; +} + +const char* g2pw_runtime_last_error(void* raw_handle) { + auto* handle = static_cast(raw_handle); + if (handle == nullptr) { + return "invalid runtime handle"; + } + return handle->last_error.c_str(); +} + +int g2pw_runtime_num_labels(void* raw_handle) { + auto* handle = static_cast(raw_handle); + if (handle == nullptr || handle->runtime == nullptr) { + return 0; + } + return handle->num_labels; +} + +int g2pw_runtime_run( + void* raw_handle, + const std::int64_t* input_ids, + const std::int64_t* token_type_ids, + const std::int64_t* attention_mask, + const float* phoneme_mask, + const std::int64_t* char_ids, + const std::int64_t* position_ids, + std::int32_t batch_size, + std::int32_t seq_len, + float* probs) { + auto* handle = static_cast(raw_handle); + if (handle == nullptr || handle->runtime == nullptr) { + return static_cast(g2pw::StatusCode::kInvalidArgument); + } + try { + g2pw::InferenceInputs inputs{}; + inputs.input_ids = input_ids; + inputs.token_type_ids = token_type_ids; + inputs.attention_mask = attention_mask; + inputs.phoneme_mask = phoneme_mask; + inputs.char_ids = char_ids; + inputs.position_ids = position_ids; + inputs.batch_size = batch_size; + inputs.seq_len = seq_len; + + g2pw::InferenceOutputs outputs{}; + outputs.probs = probs; + + const g2pw::Status status = handle->runtime->Run(inputs, outputs); + if (!status.ok()) { + SetError(handle, status); + return static_cast(status.code); + } + handle->last_error.clear(); + return static_cast(g2pw::StatusCode::kOk); + } catch (const std::exception& exc) { + handle->last_error = exc.what(); + return static_cast(g2pw::StatusCode::kInternalError); + } catch (...) { + handle->last_error = "unknown exception"; + return static_cast(g2pw::StatusCode::kInternalError); + } +} + +} diff --git a/GPT_SoVITS/text/g2pw/onnx_api.py b/GPT_SoVITS/text/g2pw/onnx_api.py index 1d5e4231..f6d7fab7 100644 --- a/GPT_SoVITS/text/g2pw/onnx_api.py +++ b/GPT_SoVITS/text/g2pw/onnx_api.py @@ -3,6 +3,7 @@ import json import os +import time import warnings import zipfile from typing import Any, Dict, List, Tuple @@ -10,7 +11,6 @@ from typing import Any, Dict, List, Tuple import numpy as np import onnxruntime import requests -import torch from opencc import OpenCC from pypinyin import Style, pinyin from transformers.models.auto.tokenization_auto import AutoTokenizer @@ -22,9 +22,8 @@ from .utils import load_config onnxruntime.set_default_logger_severity(3) try: onnxruntime.preload_dlls() -except: +except Exception: pass - # traceback.print_exc() warnings.filterwarnings("ignore") model_version = "1.1" @@ -55,6 +54,41 @@ def predict(session, onnx_input: Dict[str, Any], labels: List[str]) -> Tuple[Lis return all_preds, all_confidences +def _load_json_from_candidates(filename: str, candidate_dirs: List[str]) -> Dict[str, Any]: + for candidate_dir in candidate_dirs: + if not candidate_dir: + continue + json_path = os.path.join(candidate_dir, filename) + if os.path.exists(json_path): + with open(json_path, "r", encoding="utf-8") as fr: + return json.load(fr) + raise FileNotFoundError(f"Cannot locate {filename} in candidate dirs: {candidate_dirs}") + + +def _find_first_existing_file(*paths: str) -> str: + for path in paths: + if path and os.path.exists(path): + return path + raise FileNotFoundError(f"Files not found: {paths}") + + +def _resolve_tokenizer_source(model_source: str | None) -> str: + candidate_paths = [] + if model_source: + candidate_paths.append(model_source) + repo_root = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..")) + candidate_paths.extend( + [ + os.path.join(repo_root, "pretrained_models", "g2pw-chinese"), + os.path.join(repo_root, "pretrained_models", "chinese-roberta-wwm-ext-large"), + ] + ) + for candidate in candidate_paths: + if candidate and os.path.exists(candidate): + return candidate + return model_source or "bert-base-chinese" + + def download_and_decompress(model_dir: str = "G2PWModel/"): if not os.path.exists(model_dir): parent_directory = os.path.dirname(model_dir) @@ -62,7 +96,7 @@ def download_and_decompress(model_dir: str = "G2PWModel/"): extract_dir = os.path.join(parent_directory, "G2PWModel_1.1") extract_dir_new = os.path.join(parent_directory, "G2PWModel") print("Downloading g2pw model...") - modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip" # "https://paddlespeech.cdn.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip" + modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip" with requests.get(modelscope_url, stream=True) as r: r.raise_for_status() with open(zip_dir, "wb") as f: @@ -79,7 +113,7 @@ def download_and_decompress(model_dir: str = "G2PWModel/"): return model_dir -class G2PWOnnxConverter: +class _G2PWBaseOnnxConverter: def __init__( self, model_dir: str = "G2PWModel/", @@ -87,33 +121,16 @@ class G2PWOnnxConverter: model_source: str = None, enable_non_tradional_chinese: bool = False, ): - uncompress_path = download_and_decompress(model_dir) + self.model_dir = download_and_decompress(model_dir) + self.config = load_config(config_path=os.path.join(self.model_dir, "config.py"), use_default=True) - sess_options = onnxruntime.SessionOptions() - sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL - sess_options.intra_op_num_threads = 2 if torch.cuda.is_available() else 0 - if "CUDAExecutionProvider" in onnxruntime.get_available_providers(): - self.session_g2pW = onnxruntime.InferenceSession( - os.path.join(uncompress_path, "g2pW.onnx"), - sess_options=sess_options, - providers=["CUDAExecutionProvider", "CPUExecutionProvider"], - ) - else: - self.session_g2pW = onnxruntime.InferenceSession( - os.path.join(uncompress_path, "g2pW.onnx"), - sess_options=sess_options, - providers=["CPUExecutionProvider"], - ) - self.config = load_config(config_path=os.path.join(uncompress_path, "config.py"), use_default=True) - - self.model_source = model_source if model_source else self.config.model_source + self.model_source = _resolve_tokenizer_source(model_source if model_source else self.config.model_source) self.enable_opencc = enable_non_tradional_chinese + self.tokenizer = AutoTokenizer.from_pretrained(self.model_source, local_files_only=True) - self.tokenizer = AutoTokenizer.from_pretrained(self.model_source) + polyphonic_chars_path = os.path.join(self.model_dir, "POLYPHONIC_CHARS.txt") + monophonic_chars_path = os.path.join(self.model_dir, "MONOPHONIC_CHARS.txt") - polyphonic_chars_path = os.path.join(uncompress_path, "POLYPHONIC_CHARS.txt") - monophonic_chars_path = os.path.join(uncompress_path, "MONOPHONIC_CHARS.txt") self.polyphonic_chars = [ line.split("\t") for line in open(polyphonic_chars_path, encoding="utf-8").read().strip().split("\n") ] @@ -149,31 +166,47 @@ class G2PWOnnxConverter: ) self.chars = sorted(list(self.char2phonemes.keys())) + self.char2id = {char: idx for idx, char in enumerate(self.chars)} + self.char_phoneme_masks = ( + { + char: [1 if i in self.char2phonemes[char] else 0 for i in range(len(self.labels))] + for char in self.char2phonemes + } + if self.config.use_mask + else None + ) self.polyphonic_chars_new = set(self.chars) for char in self.non_polyphonic: - if char in self.polyphonic_chars_new: - self.polyphonic_chars_new.remove(char) + self.polyphonic_chars_new.discard(char) self.monophonic_chars_dict = {char: phoneme for char, phoneme in self.monophonic_chars} for char in self.non_monophonic: - if char in self.monophonic_chars_dict: - self.monophonic_chars_dict.pop(char) + self.monophonic_chars_dict.pop(char, None) - self.pos_tags = ["UNK", "A", "C", "D", "I", "N", "P", "T", "V", "DE", "SHI"] + default_asset_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "G2PWModel")) + candidate_asset_dirs = [self.model_dir, default_asset_dir] + self.bopomofo_convert_dict = _load_json_from_candidates( + "bopomofo_to_pinyin_wo_tune_dict.json", candidate_asset_dirs + ) + self.char_bopomofo_dict = _load_json_from_candidates("char_bopomofo_dict.json", candidate_asset_dirs) - with open(os.path.join(uncompress_path, "bopomofo_to_pinyin_wo_tune_dict.json"), "r", encoding="utf-8") as fr: - self.bopomofo_convert_dict = json.load(fr) self.style_convert_func = { "bopomofo": lambda x: x, "pinyin": self._convert_bopomofo_to_pinyin, }[style] - with open(os.path.join(uncompress_path, "char_bopomofo_dict.json"), "r", encoding="utf-8") as fr: - self.char_bopomofo_dict = json.load(fr) - if self.enable_opencc: self.cc = OpenCC("s2tw") + self.enable_sentence_dedup = os.getenv("g2pw_sentence_dedup", "true").strip().lower() in { + "1", + "true", + "yes", + "y", + "on", + } + # 聚焦到多音字附近上下文,默认左右各16字;设为0表示关闭裁剪(整句)。 + self.polyphonic_context_chars = max(0, int(os.getenv("g2pw_polyphonic_context_chars", "16"))) def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str: tone = bopomofo[-1] @@ -181,11 +214,14 @@ class G2PWOnnxConverter: component = self.bopomofo_convert_dict.get(bopomofo[:-1]) if component: return component + tone - else: - print(f'Warning: "{bopomofo}" cannot convert to pinyin') - return None + print(f'Warning: "{bopomofo}" cannot convert to pinyin') + return None def __call__(self, sentences: List[str]) -> List[List[str]]: + results, _profile = self.predict_sentences_with_profile(sentences) + return results + + def predict_sentences_with_profile(self, sentences: List[str]) -> Tuple[List[List[str]], Dict[str, float]]: if isinstance(sentences, str): sentences = [sentences] @@ -197,51 +233,202 @@ class G2PWOnnxConverter: translated_sentences.append(translated_sent) sentences = translated_sentences - texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences) + texts, model_query_ids, result_query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences) if len(texts) == 0: - # sentences no polyphonic words - return partial_results + return partial_results, {} - onnx_input = prepare_onnx_input( + model_input = prepare_onnx_input( tokenizer=self.tokenizer, labels=self.labels, char2phonemes=self.char2phonemes, chars=self.chars, texts=texts, - query_ids=query_ids, + query_ids=model_query_ids, use_mask=self.config.use_mask, window_size=None, + char2id=self.char2id, + char_phoneme_masks=self.char_phoneme_masks, ) - preds, confidences = predict(session=self.session_g2pW, onnx_input=onnx_input, labels=self.labels) + if not model_input: + return partial_results, {} + + predict_profile: Dict[str, float] = {} + if self.enable_sentence_dedup: + preds, _confidences, predict_profile = self._predict_with_sentence_dedup_profiled( + model_input=model_input, + texts=texts, + ) + else: + if hasattr(self, "_predict_with_profile"): + preds, _confidences, predict_profile = self._predict_with_profile(model_input=model_input) + else: + predict_started = time.perf_counter() + preds, _confidences = self._predict(model_input=model_input) + predict_profile["g2pw_predict_ms"] = float((time.perf_counter() - predict_started) * 1000.0) + if self.config.use_char_phoneme: preds = [pred.split(" ")[1] for pred in preds] results = partial_results - for sent_id, query_id, pred in zip(sent_ids, query_ids, preds): + for sent_id, query_id, pred in zip(sent_ids, result_query_ids, preds): results[sent_id][query_id] = self.style_convert_func(pred) - return results + return results, predict_profile - def _prepare_data(self, sentences: List[str]) -> Tuple[List[str], List[int], List[int], List[List[str]]]: - texts, query_ids, sent_ids, partial_results = [], [], [], [] + def _prepare_data( + self, sentences: List[str] + ) -> Tuple[List[str], List[int], List[int], List[int], List[List[str]]]: + texts, model_query_ids, result_query_ids, sent_ids, partial_results = [], [], [], [], [] for sent_id, sent in enumerate(sentences): - # pypinyin works well for Simplified Chinese than Traditional Chinese sent_s = tranditional_to_simplified(sent) pypinyin_result = pinyin(sent_s, neutral_tone_with_five=True, style=Style.TONE3) partial_result = [None] * len(sent) + polyphonic_indices: List[int] = [] for i, char in enumerate(sent): if char in self.polyphonic_chars_new: - texts.append(sent) - query_ids.append(i) - sent_ids.append(sent_id) + polyphonic_indices.append(i) elif char in self.monophonic_chars_dict: partial_result[i] = self.style_convert_func(self.monophonic_chars_dict[char]) elif char in self.char_bopomofo_dict: partial_result[i] = pypinyin_result[i][0] - # partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0]) else: partial_result[i] = pypinyin_result[i][0] + if polyphonic_indices: + if self.polyphonic_context_chars > 0: + left = max(0, polyphonic_indices[0] - self.polyphonic_context_chars) + right = min(len(sent), polyphonic_indices[-1] + self.polyphonic_context_chars + 1) + sent_for_predict = sent[left:right] + query_offset = left + else: + sent_for_predict = sent + query_offset = 0 + + for index in polyphonic_indices: + texts.append(sent_for_predict) + model_query_ids.append(index - query_offset) + result_query_ids.append(index) + sent_ids.append(sent_id) + partial_results.append(partial_result) - return texts, query_ids, sent_ids, partial_results + return texts, model_query_ids, result_query_ids, sent_ids, partial_results + + def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]: + raise NotImplementedError + + def _predict_with_sentence_dedup( + self, model_input: Dict[str, Any], texts: List[str] + ) -> Tuple[List[str], List[float]]: + if len(texts) <= 1: + return self._predict(model_input=model_input) + + grouped_indices: Dict[str, List[int]] = {} + for idx, text in enumerate(texts): + grouped_indices.setdefault(text, []).append(idx) + + if all(len(indices) == 1 for indices in grouped_indices.values()): + return self._predict(model_input=model_input) + + preds: List[str] = [""] * len(texts) + confidences: List[float] = [0.0] * len(texts) + for indices in grouped_indices.values(): + group_input = {name: value[indices] for name, value in model_input.items()} + if len(indices) > 1: + for name in ("input_ids", "token_type_ids", "attention_masks"): + group_input[name] = group_input[name][:1] + + group_preds, group_confidences = self._predict(model_input=group_input) + for output_idx, pred, confidence in zip(indices, group_preds, group_confidences): + preds[output_idx] = pred + confidences[output_idx] = confidence + + return preds, confidences + + def _predict_with_sentence_dedup_profiled( + self, + model_input: Dict[str, Any], + texts: List[str], + ) -> Tuple[List[str], List[float], Dict[str, float]]: + if len(texts) <= 1: + if hasattr(self, "_predict_with_profile"): + return self._predict_with_profile(model_input=model_input) + predict_started = time.perf_counter() + preds, confidences = self._predict(model_input=model_input) + return preds, confidences, {"g2pw_predict_ms": float((time.perf_counter() - predict_started) * 1000.0)} + + grouped_indices: Dict[str, List[int]] = {} + for idx, text in enumerate(texts): + grouped_indices.setdefault(text, []).append(idx) + + if all(len(indices) == 1 for indices in grouped_indices.values()): + if hasattr(self, "_predict_with_profile"): + return self._predict_with_profile(model_input=model_input) + predict_started = time.perf_counter() + preds, confidences = self._predict(model_input=model_input) + return preds, confidences, {"g2pw_predict_ms": float((time.perf_counter() - predict_started) * 1000.0)} + + preds: List[str] = [""] * len(texts) + confidences: List[float] = [0.0] * len(texts) + merged_profile: Dict[str, float] = {} + for indices in grouped_indices.values(): + group_input = {name: value[indices] for name, value in model_input.items()} + if len(indices) > 1: + for name in ("input_ids", "token_type_ids", "attention_masks"): + group_input[name] = group_input[name][:1] + if hasattr(self, "_predict_with_profile"): + group_preds, group_confidences, group_profile = self._predict_with_profile(model_input=group_input) + for key, value in dict(group_profile or {}).items(): + merged_profile[key] = float(merged_profile.get(key, 0.0)) + float(value) + else: + predict_started = time.perf_counter() + group_preds, group_confidences = self._predict(model_input=group_input) + merged_profile["g2pw_predict_ms"] = float( + merged_profile.get("g2pw_predict_ms", 0.0) + (time.perf_counter() - predict_started) * 1000.0 + ) + for output_idx, pred, confidence in zip(indices, group_preds, group_confidences): + preds[output_idx] = pred + confidences[output_idx] = confidence + return preds, confidences, merged_profile + + +class G2PWOnnxConverter(_G2PWBaseOnnxConverter): + def __init__( + self, + model_dir: str = "G2PWModel/", + style: str = "bopomofo", + model_source: str = None, + enable_non_tradional_chinese: bool = False, + ): + super().__init__( + model_dir=model_dir, + style=style, + model_source=model_source, + enable_non_tradional_chinese=enable_non_tradional_chinese, + ) + + sess_options = onnxruntime.SessionOptions() + sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL + sess_options.intra_op_num_threads = 2 + + onnx_path = _find_first_existing_file( + os.path.join(self.model_dir, "g2pW.onnx"), + os.path.join(self.model_dir, "g2pw.onnx"), + ) + + if "CUDAExecutionProvider" in onnxruntime.get_available_providers(): + self.session_g2pw = onnxruntime.InferenceSession( + onnx_path, + sess_options=sess_options, + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], + ) + else: + self.session_g2pw = onnxruntime.InferenceSession( + onnx_path, + sess_options=sess_options, + providers=["CPUExecutionProvider"], + ) + + def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]: + return predict(session=self.session_g2pw, onnx_input=model_input, labels=self.labels) diff --git a/api_v2.py b/api_v2.py index 21511db3..9c29989f 100644 --- a/api_v2.py +++ b/api_v2.py @@ -39,8 +39,8 @@ POST: "seed": -1, # int. random seed for reproducibility. "parallel_infer": True, # bool. whether to use parallel inference. "repetition_penalty": 1.35, # float. repetition penalty for T2S model. - "sample_steps": 32, # int. number of sampling steps for VITS model V3. - "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "sample_steps": 32, # int. 仅 v3/v4 vocoder 路径使用;当前 v2/v2ProPlus 主线可忽略。 + "super_sampling": False, # bool. 仅 v3/v4 路径使用;不属于当前 v2/v2ProPlus 正式支持目标。 "streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed ) "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) @@ -79,7 +79,7 @@ endpoint: `/set_gpt_weights` GET: ``` -http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt +http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1v3.ckpt ``` RESP: 成功: 返回"success", http code 200 @@ -92,7 +92,7 @@ endpoint: `/set_sovits_weights` GET: ``` -http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/s2G488k.pth +http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth ``` RESP: @@ -104,27 +104,22 @@ RESP: import os import sys import traceback -from typing import Generator, Union +from typing import Union now_dir = os.getcwd() sys.path.append(now_dir) sys.path.append("%s/GPT_SoVITS" % (now_dir)) import argparse -import subprocess -import wave import signal -import numpy as np -import soundfile as sf from fastapi import FastAPI, Response from fastapi.responses import StreamingResponse, JSONResponse import uvicorn -from io import BytesIO from tools.i18n.i18n import I18nAuto from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names +from GPT_SoVITS.TTS_infer_pack.unified_engine import RuntimeControlCallbacks, UnifiedTTSEngine from pydantic import BaseModel -import threading # print(sys.path) i18n = I18nAuto() @@ -147,6 +142,14 @@ if config_path in [None, ""]: tts_config = TTS_Config(config_path) print(tts_config) tts_pipeline = TTS(tts_config) +tts_engine = UnifiedTTSEngine( + tts_pipeline, + cut_method_names=cut_method_names, + control_callbacks=RuntimeControlCallbacks( + restart=lambda: os.execl(sys.executable, sys.executable, *argv), + exit=lambda: os.kill(os.getpid(), signal.SIGTERM), + ), +) APP = FastAPI() @@ -178,168 +181,8 @@ class TTS_Request(BaseModel): min_chunk_length: int = 16 -def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int): - # Author: AkagawaTsurunaki - # Issue: - # Stack overflow probabilistically occurs - # when the function `sf_writef_short` of `libsndfile_64bit.dll` is called - # using the Python library `soundfile` - # Note: - # This is an issue related to `libsndfile`, not this project itself. - # It happens when you generate a large audio tensor (about 499804 frames in my PC) - # and try to convert it to an ogg file. - # Related: - # https://github.com/RVC-Boss/GPT-SoVITS/issues/1199 - # https://github.com/libsndfile/libsndfile/issues/1023 - # https://github.com/bastibe/python-soundfile/issues/396 - # Suggestion: - # Or split the whole audio data into smaller audio segment to avoid stack overflow? - - def handle_pack_ogg(): - with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: - audio_file.write(data) - - - - # See: https://docs.python.org/3/library/threading.html - # The stack size of this thread is at least 32768 - # If stack overflow error still occurs, just modify the `stack_size`. - # stack_size = n * 4096, where n should be a positive integer. - # Here we chose n = 4096. - stack_size = 4096 * 4096 - try: - threading.stack_size(stack_size) - pack_ogg_thread = threading.Thread(target=handle_pack_ogg) - pack_ogg_thread.start() - pack_ogg_thread.join() - except RuntimeError as e: - # If changing the thread stack size is unsupported, a RuntimeError is raised. - print("RuntimeError: {}".format(e)) - print("Changing the thread stack size is unsupported.") - except ValueError as e: - # If the specified stack size is invalid, a ValueError is raised and the stack size is unmodified. - print("ValueError: {}".format(e)) - print("The specified stack size is invalid.") - - return io_buffer - - -def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int): - io_buffer.write(data.tobytes()) - return io_buffer - - -def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int): - io_buffer = BytesIO() - sf.write(io_buffer, data, rate, format="wav") - return io_buffer - - -def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int): - process = subprocess.Popen( - [ - "ffmpeg", - "-f", - "s16le", # 输入16位有符号小端整数PCM - "-ar", - str(rate), # 设置采样率 - "-ac", - "1", # 单声道 - "-i", - "pipe:0", # 从管道读取输入 - "-c:a", - "aac", # 音频编码器为AAC - "-b:a", - "192k", # 比特率 - "-vn", # 不包含视频 - "-f", - "adts", # 输出AAC数据流格式 - "pipe:1", # 将输出写入管道 - ], - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - out, _ = process.communicate(input=data.tobytes()) - io_buffer.write(out) - return io_buffer - - -def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str): - if media_type == "ogg": - io_buffer = pack_ogg(io_buffer, data, rate) - elif media_type == "aac": - io_buffer = pack_aac(io_buffer, data, rate) - elif media_type == "wav": - io_buffer = pack_wav(io_buffer, data, rate) - else: - io_buffer = pack_raw(io_buffer, data, rate) - io_buffer.seek(0) - return io_buffer - - -# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py -def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000): - # This will create a wave header then append the frame input - # It should be first on a streaming wav file - # Other frames better should not have it (else you will hear some artifacts each chunk start) - wav_buf = BytesIO() - with wave.open(wav_buf, "wb") as vfout: - vfout.setnchannels(channels) - vfout.setsampwidth(sample_width) - vfout.setframerate(sample_rate) - vfout.writeframes(frame_input) - - wav_buf.seek(0) - return wav_buf.read() - - -def handle_control(command: str): - if command == "restart": - os.execl(sys.executable, sys.executable, *argv) - elif command == "exit": - os.kill(os.getpid(), signal.SIGTERM) - exit(0) - - -def check_params(req: dict): - text: str = req.get("text", "") - text_lang: str = req.get("text_lang", "") - ref_audio_path: str = req.get("ref_audio_path", "") - streaming_mode: bool = req.get("streaming_mode", False) - media_type: str = req.get("media_type", "wav") - prompt_lang: str = req.get("prompt_lang", "") - text_split_method: str = req.get("text_split_method", "cut5") - - if ref_audio_path in [None, ""]: - return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"}) - if text in [None, ""]: - return JSONResponse(status_code=400, content={"message": "text is required"}) - if text_lang in [None, ""]: - return JSONResponse(status_code=400, content={"message": "text_lang is required"}) - elif text_lang.lower() not in tts_config.languages: - return JSONResponse( - status_code=400, - content={"message": f"text_lang: {text_lang} is not supported in version {tts_config.version}"}, - ) - if prompt_lang in [None, ""]: - return JSONResponse(status_code=400, content={"message": "prompt_lang is required"}) - elif prompt_lang.lower() not in tts_config.languages: - return JSONResponse( - status_code=400, - content={"message": f"prompt_lang: {prompt_lang} is not supported in version {tts_config.version}"}, - ) - if media_type not in ["wav", "raw", "ogg", "aac"]: - return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"}) - # elif media_type == "ogg" and not streaming_mode: - # return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"}) - - if text_split_method not in cut_method_names: - return JSONResponse( - status_code=400, content={"message": f"text_split_method:{text_split_method} is not supported"} - ) - - return None +def _lower_or_none(value: str | None) -> str | None: + return value.lower() if isinstance(value, str) else value async def tts_handle(req: dict): @@ -368,7 +211,7 @@ async def tts_handle(req: dict): "parallel_infer": True, # bool. whether to use parallel inference. "repetition_penalty": 1.35, # float. repetition penalty for T2S model. "sample_steps": 32, # int. number of sampling steps for VITS model V3. - "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + "super_sampling": False, # bool. only for v3/v4; not part of current v2/v2ProPlus mainline. "streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed ) "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) @@ -377,70 +220,11 @@ async def tts_handle(req: dict): StreamingResponse: audio stream response. """ - streaming_mode = req.get("streaming_mode", False) - return_fragment = req.get("return_fragment", False) - media_type = req.get("media_type", "wav") - - check_res = check_params(req) - if check_res is not None: - return check_res - - if streaming_mode == 0: - streaming_mode = False - return_fragment = False - fixed_length_chunk = False - elif streaming_mode == 1: - streaming_mode = False - return_fragment = True - fixed_length_chunk = False - elif streaming_mode == 2: - streaming_mode = True - return_fragment = False - fixed_length_chunk = False - elif streaming_mode == 3: - streaming_mode = True - return_fragment = False - fixed_length_chunk = True - - else: - return JSONResponse(status_code=400, content={"message": f"the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)"}) - - req["streaming_mode"] = streaming_mode - req["return_fragment"] = return_fragment - req["fixed_length_chunk"] = fixed_length_chunk - - print(f"{streaming_mode} {return_fragment} {fixed_length_chunk}") - - streaming_mode = streaming_mode or return_fragment - - try: - tts_generator = tts_pipeline.run(req) - - if streaming_mode: - - def streaming_generator(tts_generator: Generator, media_type: str): - if_frist_chunk = True - for sr, chunk in tts_generator: - if if_frist_chunk and media_type == "wav": - yield wave_header_chunk(sample_rate=sr) - media_type = "raw" - if_frist_chunk = False - yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue() - - # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}" - return StreamingResponse( - streaming_generator( - tts_generator, - media_type, - ), - media_type=f"audio/{media_type}", - ) - - else: - sr, audio_data = next(tts_generator) - audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() - return Response(audio_data, media_type=f"audio/{media_type}") + result = await tts_engine.run_direct_tts_async(req) + if result.streaming: + return StreamingResponse(result.audio_generator, media_type=f"audio/{result.media_type}") + return Response(result.audio_bytes, media_type=f"audio/{result.media_type}") except Exception as e: return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(e)}) @@ -449,7 +233,11 @@ async def tts_handle(req: dict): async def control(command: str = None): if command is None: return JSONResponse(status_code=400, content={"message": "command is required"}) - handle_control(command) + try: + tts_engine.handle_control(command) + return JSONResponse(status_code=200, content={"message": "success"}) + except Exception as e: + return JSONResponse(status_code=400, content={"message": "control failed", "Exception": str(e)}) @APP.get("/tts") @@ -481,11 +269,11 @@ async def tts_get_endpoint( ): req = { "text": text, - "text_lang": text_lang.lower(), + "text_lang": _lower_or_none(text_lang), "ref_audio_path": ref_audio_path, "aux_ref_audio_paths": aux_ref_audio_paths, "prompt_text": prompt_text, - "prompt_lang": prompt_lang.lower(), + "prompt_lang": _lower_or_none(prompt_lang), "top_k": top_k, "top_p": top_p, "temperature": temperature, @@ -517,10 +305,10 @@ async def tts_post_endpoint(request: TTS_Request): @APP.get("/set_refer_audio") async def set_refer_aduio(refer_audio_path: str = None): try: - tts_pipeline.set_ref_audio(refer_audio_path) + payload = tts_engine.set_refer_audio(refer_audio_path) except Exception as e: return JSONResponse(status_code=400, content={"message": "set refer audio failed", "Exception": str(e)}) - return JSONResponse(status_code=200, content={"message": "success"}) + return JSONResponse(status_code=200, content=payload) # @APP.post("/set_refer_audio") @@ -545,24 +333,19 @@ async def set_refer_aduio(refer_audio_path: str = None): @APP.get("/set_gpt_weights") async def set_gpt_weights(weights_path: str = None): try: - if weights_path in ["", None]: - return JSONResponse(status_code=400, content={"message": "gpt weight path is required"}) - tts_pipeline.init_t2s_weights(weights_path) + payload = tts_engine.set_gpt_weights(weights_path) except Exception as e: return JSONResponse(status_code=400, content={"message": "change gpt weight failed", "Exception": str(e)}) - - return JSONResponse(status_code=200, content={"message": "success"}) + return JSONResponse(status_code=200, content=payload) @APP.get("/set_sovits_weights") async def set_sovits_weights(weights_path: str = None): try: - if weights_path in ["", None]: - return JSONResponse(status_code=400, content={"message": "sovits weight path is required"}) - tts_pipeline.init_vits_weights(weights_path) + payload = tts_engine.set_sovits_weights(weights_path) except Exception as e: return JSONResponse(status_code=400, content={"message": "change sovits weight failed", "Exception": str(e)}) - return JSONResponse(status_code=200, content={"message": "success"}) + return JSONResponse(status_code=200, content=payload) if __name__ == "__main__": diff --git a/api_v3.py b/api_v3.py new file mode 100644 index 00000000..35ecf240 --- /dev/null +++ b/api_v3.py @@ -0,0 +1,443 @@ +""" +# WebAPI文档 + +` python api_v2.py -a 127.0.0.1 -p 9880 -c GPT_SoVITS/configs/tts_infer.yaml ` + +## 执行参数: + `-a` - `绑定地址, 默认"127.0.0.1"` + `-p` - `绑定端口, 默认9880` + `-c` - `TTS配置文件路径, 默认"GPT_SoVITS/configs/tts_infer.yaml"` + +## 调用: + +### 推理 + +endpoint: `/tts` +GET: +``` +http://127.0.0.1:9880/tts?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_lang=zh&ref_audio_path=archive_jingyuan_1.wav&prompt_lang=zh&prompt_text=我是「罗浮」云骑将军景元。不必拘谨,「将军」只是一时的身份,你称呼我景元便可&text_split_method=cut5&batch_size=1&media_type=wav&streaming_mode=true +``` + +POST: +```json +{ + "text": "", # str.(required) text to be synthesized + "text_lang: "", # str.(required) language of the text to be synthesized + "ref_audio_path": "", # str.(required) reference audio path + "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion + "prompt_text": "", # str.(optional) prompt text for the reference audio + "prompt_lang": "", # str.(required) language of the prompt text for the reference audio + "top_k": 15, # int. top k sampling + "top_p": 1, # float. top p sampling + "temperature": 1, # float. temperature for sampling + "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details. + "batch_size": 1, # int. batch size for inference + "batch_threshold": 0.75, # float. threshold for batch splitting. + "split_bucket": True, # bool. whether to split the batch into multiple buckets. + "speed_factor":1.0, # float. control the speed of the synthesized audio. + "fragment_interval":0.3, # float. to control the interval of the audio fragment. + "seed": -1, # int. random seed for reproducibility. + "parallel_infer": True, # bool. whether to use parallel inference. + "repetition_penalty": 1.35, # float. repetition penalty for T2S model. + "sample_steps": 32, # int. 仅 v3/v4 vocoder 路径使用;当前 v2/v2ProPlus 主线可忽略。 + "super_sampling": False, # bool. 仅 v3/v4 路径使用;不属于当前 v2/v2ProPlus 正式支持目标。 + "streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed ) + "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. + "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) +} +``` + +RESP: +成功: 直接返回 wav 音频流, http code 200 +失败: 返回包含错误信息的 json, http code 400 + +### 命令控制 + +endpoint: `/control` + +command: +"restart": 重新运行 +"exit": 结束运行 + +GET: +``` +http://127.0.0.1:9880/control?command=restart +``` +POST: +```json +{ + "command": "restart" +} +``` + +RESP: 无 + + +### 切换GPT模型 + +endpoint: `/set_gpt_weights` + +GET: +``` +http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1v3.ckpt +``` +RESP: +成功: 返回"success", http code 200 +失败: 返回包含错误信息的 json, http code 400 + + +### 切换Sovits模型 + +endpoint: `/set_sovits_weights` + +GET: +``` +http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth +``` + +RESP: +成功: 返回"success", http code 200 +失败: 返回包含错误信息的 json, http code 400 + +""" + +import os +import sys +import traceback +from typing import List, Union + +now_dir = os.getcwd() +sys.path.append(now_dir) +sys.path.append("%s/GPT_SoVITS" % (now_dir)) + +from runtime_preload import preload_text_runtime_deps + +preload_text_runtime_deps() + +import argparse +import signal +from fastapi import FastAPI, Response +from fastapi.responses import StreamingResponse, JSONResponse +import uvicorn +from tools.i18n.i18n import I18nAuto +from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config +from GPT_SoVITS.TTS_infer_pack.unified_engine import RuntimeControlCallbacks, UnifiedTTSEngine +from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names +from pydantic import BaseModel + +# print(sys.path) +i18n = I18nAuto() +cut_method_names = get_cut_method_names() + +parser = argparse.ArgumentParser(description="GPT-SoVITS api") +parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径") +parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1") +parser.add_argument("-p", "--port", type=int, default="9880", help="default: 9880") +args = parser.parse_args() +config_path = args.tts_config +# device = args.device +port = args.port +host = args.bind_addr +argv = sys.argv + +if config_path in [None, ""]: + config_path = "GPT-SoVITS/configs/tts_infer.yaml" + +tts_config = TTS_Config(config_path) +print(tts_config) +tts_pipeline = TTS(tts_config) +tts_engine = UnifiedTTSEngine( + tts_pipeline, + cut_method_names=cut_method_names, + control_callbacks=RuntimeControlCallbacks( + restart=lambda: os.execl(sys.executable, sys.executable, *argv), + exit=lambda: os.kill(os.getpid(), signal.SIGTERM), + ), +) + +APP = FastAPI() + + +class TTS_Request(BaseModel): + text: str = None + text_lang: str = None + ref_audio_path: str = None + aux_ref_audio_paths: list = None + prompt_lang: str = None + prompt_text: str = "" + top_k: int = 15 + top_p: float = 1 + temperature: float = 1 + text_split_method: str = "cut5" + batch_size: int = 1 + batch_threshold: float = 0.75 + split_bucket: bool = True + speed_factor: float = 1.0 + fragment_interval: float = 0.3 + seed: int = -1 + media_type: str = "wav" + streaming_mode: Union[bool, int] = False + parallel_infer: bool = True + repetition_penalty: float = 1.35 + sample_steps: int = 32 + super_sampling: bool = False + overlap_length: int = 2 + min_chunk_length: int = 16 + + +class Scheduler_Debug_Request_Item(BaseModel): + request_id: str | None = None + text: str + text_lang: str + ref_audio_path: str + prompt_lang: str + prompt_text: str = "" + top_k: int = 15 + top_p: float = 1 + temperature: float = 1 + repetition_penalty: float = 1.35 + early_stop_num: int = -1 + ready_step: int = 0 + + +class Scheduler_Debug_Request(BaseModel): + requests: List[Scheduler_Debug_Request_Item] + max_steps: int = 1500 + seed: int = -1 + + +class Scheduler_Submit_Request(BaseModel): + request_id: str | None = None + text: str + text_lang: str + ref_audio_path: str + prompt_lang: str + prompt_text: str = "" + top_k: int = 15 + top_p: float = 1 + temperature: float = 1 + repetition_penalty: float = 1.35 + early_stop_num: int = -1 + speed_factor: float = 1.0 + sample_steps: int = 32 + media_type: str = "wav" + timeout_sec: float = 30.0 + + +def _lower_or_none(value: str | None) -> str | None: + return value.lower() if isinstance(value, str) else value + + +async def tts_scheduler_debug_handle(request: Scheduler_Debug_Request): + try: + result = await tts_engine.run_scheduler_debug( + request_items=[item.dict() for item in request.requests], + max_steps=int(request.max_steps), + seed=int(request.seed), + ) + return JSONResponse(status_code=200, content=result.payload) + except Exception as e: + return JSONResponse( + status_code=400, + content={"message": "scheduler debug failed", "Exception": str(e)}, + ) + + +async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request): + try: + result = await tts_engine.run_scheduler_submit(request.dict()) + return Response(result.audio_bytes, media_type=result.media_type, headers=result.headers) + except Exception as e: + return JSONResponse( + status_code=400, + content={"message": "scheduler submit failed", "Exception": str(e)}, + ) + + +async def tts_handle(req: dict): + """ + Text to speech handler. + + Args: + req (dict): + { + "text": "", # str.(required) text to be synthesized + "text_lang: "", # str.(required) language of the text to be synthesized + "ref_audio_path": "", # str.(required) reference audio path + "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion + "prompt_text": "", # str.(optional) prompt text for the reference audio + "prompt_lang": "", # str.(required) language of the prompt text for the reference audio + "top_k": 15, # int. top k sampling + "top_p": 1, # float. top p sampling + "temperature": 1, # float. temperature for sampling + "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details. + "batch_size": 1, # int. batch size for inference + "batch_threshold": 0.75, # float. threshold for batch splitting. + "split_bucket": True, # bool. whether to split the batch into multiple buckets. + "speed_factor":1.0, # float. control the speed of the synthesized audio. + "fragment_interval":0.3, # float. to control the interval of the audio fragment. + "seed": -1, # int. random seed for reproducibility. + "parallel_infer": True, # bool. whether to use parallel inference. + "repetition_penalty": 1.35, # float. repetition penalty for T2S model. + "sample_steps": 32, # int. number of sampling steps for VITS model V3. + "super_sampling": False, # bool. only for v3/v4; not part of current v2/v2ProPlus mainline. + "streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed ) + "overlap_length": 2, # int. overlap length of semantic tokens for streaming mode. + "min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size) + } + returns: + StreamingResponse: audio stream response. + """ + + try: + result = await tts_engine.run_direct_tts_async(req) + if result.streaming: + return StreamingResponse(result.audio_generator, media_type=f"audio/{result.media_type}") + return Response(result.audio_bytes, media_type=f"audio/{result.media_type}") + except Exception as e: + return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(e)}) + + +@APP.get("/control") +async def control(command: str = None): + if command is None: + return JSONResponse(status_code=400, content={"message": "command is required"}) + try: + tts_engine.handle_control(command) + return JSONResponse(status_code=200, content={"message": "success"}) + except Exception as e: + return JSONResponse(status_code=400, content={"message": "control failed", "Exception": str(e)}) + + +@APP.get("/tts") +async def tts_get_endpoint( + text: str = None, + text_lang: str = None, + ref_audio_path: str = None, + aux_ref_audio_paths: list = None, + prompt_lang: str = None, + prompt_text: str = "", + top_k: int = 15, + top_p: float = 1, + temperature: float = 1, + text_split_method: str = "cut5", + batch_size: int = 1, + batch_threshold: float = 0.75, + split_bucket: bool = True, + speed_factor: float = 1.0, + fragment_interval: float = 0.3, + seed: int = -1, + media_type: str = "wav", + parallel_infer: bool = True, + repetition_penalty: float = 1.35, + sample_steps: int = 32, + super_sampling: bool = False, + streaming_mode: Union[bool, int] = False, + overlap_length: int = 2, + min_chunk_length: int = 16, +): + req = { + "text": text, + "text_lang": _lower_or_none(text_lang), + "ref_audio_path": ref_audio_path, + "aux_ref_audio_paths": aux_ref_audio_paths, + "prompt_text": prompt_text, + "prompt_lang": _lower_or_none(prompt_lang), + "top_k": top_k, + "top_p": top_p, + "temperature": temperature, + "text_split_method": text_split_method, + "batch_size": int(batch_size), + "batch_threshold": float(batch_threshold), + "speed_factor": float(speed_factor), + "split_bucket": split_bucket, + "fragment_interval": fragment_interval, + "seed": seed, + "media_type": media_type, + "streaming_mode": streaming_mode, + "parallel_infer": parallel_infer, + "repetition_penalty": float(repetition_penalty), + "sample_steps": int(sample_steps), + "super_sampling": super_sampling, + "overlap_length": int(overlap_length), + "min_chunk_length": int(min_chunk_length), + } + return await tts_handle(req) + + +@APP.post("/tts") +async def tts_post_endpoint(request: TTS_Request): + req = request.dict() + return await tts_handle(req) + + +@APP.post("/tts_scheduler_debug") +async def tts_scheduler_debug_endpoint(request: Scheduler_Debug_Request): + return await tts_scheduler_debug_handle(request) + + +@APP.post("/tts_scheduler_submit") +async def tts_scheduler_submit_endpoint(request: Scheduler_Submit_Request): + return await tts_scheduler_submit_handle(request) + + +@APP.get("/tts_scheduler_state") +async def tts_scheduler_state_endpoint(): + return JSONResponse(status_code=200, content=tts_engine.get_runtime_state()) + + +@APP.get("/set_refer_audio") +async def set_refer_aduio(refer_audio_path: str = None): + try: + payload = tts_engine.set_refer_audio(refer_audio_path) + except Exception as e: + return JSONResponse(status_code=400, content={"message": "set refer audio failed", "Exception": str(e)}) + return JSONResponse(status_code=200, content=payload) + + +# @APP.post("/set_refer_audio") +# async def set_refer_aduio_post(audio_file: UploadFile = File(...)): +# try: +# # 检查文件类型,确保是音频文件 +# if not audio_file.content_type.startswith("audio/"): +# return JSONResponse(status_code=400, content={"message": "file type is not supported"}) + +# os.makedirs("uploaded_audio", exist_ok=True) +# save_path = os.path.join("uploaded_audio", audio_file.filename) +# # 保存音频文件到服务器上的一个目录 +# with open(save_path , "wb") as buffer: +# buffer.write(await audio_file.read()) + +# tts_pipeline.set_ref_audio(save_path) +# except Exception as e: +# return JSONResponse(status_code=400, content={"message": f"set refer audio failed", "Exception": str(e)}) +# return JSONResponse(status_code=200, content={"message": "success"}) + + +@APP.get("/set_gpt_weights") +async def set_gpt_weights(weights_path: str = None): + try: + payload = tts_engine.set_gpt_weights(weights_path) + except Exception as e: + return JSONResponse(status_code=400, content={"message": "change gpt weight failed", "Exception": str(e)}) + return JSONResponse(status_code=200, content=payload) + + +@APP.get("/set_sovits_weights") +async def set_sovits_weights(weights_path: str = None): + try: + payload = tts_engine.set_sovits_weights(weights_path) + except Exception as e: + return JSONResponse(status_code=400, content={"message": "change sovits weight failed", "Exception": str(e)}) + return JSONResponse(status_code=200, content=payload) + + +if __name__ == "__main__": + try: + if host == "None": # 在调用时使用 -a None 参数,可以让api监听双栈 + host = None + uvicorn.run(app=APP, host=host, port=port, workers=1) + except Exception: + traceback.print_exc() + os.kill(os.getpid(), signal.SIGTERM) + exit(0) diff --git a/third_party/g2pw-cu b/third_party/g2pw-cu new file mode 160000 index 00000000..a53cf4ee --- /dev/null +++ b/third_party/g2pw-cu @@ -0,0 +1 @@ +Subproject commit a53cf4eed5759f7b5d4563ce6e4b13557e054d98 diff --git a/tools/bench_api_v3_scheduler_submit.py b/tools/bench_api_v3_scheduler_submit.py new file mode 100644 index 00000000..c16468e1 --- /dev/null +++ b/tools/bench_api_v3_scheduler_submit.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import argparse +import asyncio +import json +import subprocess +import threading +import time +import wave +from pathlib import Path +from typing import Any, Dict, List, Optional + +import httpx + +ROOT_DIR = Path(__file__).resolve().parents[1] + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Benchmark api_v3 /tts_scheduler_submit concurrency and GPU memory.") + parser.add_argument("--base-url", type=str, default="http://127.0.0.1:9880") + parser.add_argument("--endpoint", type=str, default="/tts_scheduler_submit") + parser.add_argument("--concurrency", type=int, required=True) + parser.add_argument("--timeout-sec", type=float, default=120.0) + parser.add_argument("--server-pid", type=int, default=None) + parser.add_argument("--poll-interval-sec", type=float, default=0.1) + parser.add_argument("--text-lang", type=str, default="zh") + parser.add_argument("--prompt-lang", type=str, default="zh") + parser.add_argument("--media-type", type=str, default="wav") + parser.add_argument("--top-k", type=int, default=15) + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--repetition-penalty", type=float, default=1.35) + parser.add_argument("--sample-steps", type=int, default=32) + parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_cn.txt") + parser.add_argument("--wav-dir", type=Path, default=ROOT_DIR / "testwav") + parser.add_argument("--output-dir", type=Path, default=ROOT_DIR / "TEMP/api_v3_bench") + return parser.parse_args() + + +def load_requests(args: argparse.Namespace) -> List[Dict[str, Any]]: + wav_paths_all = sorted(args.wav_dir.glob("*.wav")) + wav_paths: List[Path] = [] + for wav_path in wav_paths_all: + with wave.open(str(wav_path), "rb") as handle: + duration = handle.getnframes() / float(handle.getframerate()) + if 3.0 <= duration <= 10.0: + wav_paths.append(wav_path) + if not wav_paths: + raise FileNotFoundError(f"没有找到 3-10 秒合法 wav: {args.wav_dir}") + text_lines = [line.strip() for line in args.text_file.read_text(encoding="utf-8").splitlines() if line.strip()] + if not text_lines: + raise ValueError(f"没有找到有效文本行: {args.text_file}") + + requests: List[Dict[str, Any]] = [] + for index in range(args.concurrency): + wav_path = wav_paths[index % len(wav_paths)] + lab_path = wav_path.with_suffix(".lab") + if not lab_path.exists(): + raise FileNotFoundError(f"缺少参考文本: {lab_path}") + requests.append( + { + "request_id": f"bench_{args.concurrency:03d}_{index:03d}", + "text": text_lines[index % len(text_lines)], + "text_lang": args.text_lang, + "ref_audio_path": str(wav_path), + "prompt_lang": args.prompt_lang, + "prompt_text": lab_path.read_text(encoding="utf-8").strip(), + "top_k": int(args.top_k), + "top_p": float(args.top_p), + "temperature": float(args.temperature), + "repetition_penalty": float(args.repetition_penalty), + "sample_steps": int(args.sample_steps), + "media_type": args.media_type, + "timeout_sec": float(args.timeout_sec), + } + ) + return requests + + +class GpuMemoryPoller: + def __init__(self, server_pid: Optional[int], interval_sec: float): + self.server_pid = server_pid + self.interval_sec = interval_sec + self._stop = threading.Event() + self.samples: List[Dict[str, Any]] = [] + self.thread: Optional[threading.Thread] = None + + def _query_memory_mb(self) -> Optional[int]: + try: + result = subprocess.run( + [ + "nvidia-smi", + "--query-compute-apps=pid,used_gpu_memory", + "--format=csv,noheader,nounits", + ], + check=True, + capture_output=True, + text=True, + ) + except Exception: + return None + total = 0 + found = False + for line in result.stdout.splitlines(): + line = line.strip() + if not line: + continue + parts = [item.strip() for item in line.split(",")] + if len(parts) != 2: + continue + try: + pid = int(parts[0]) + used_mb = int(parts[1]) + except ValueError: + continue + if self.server_pid is None or pid == self.server_pid: + total += used_mb + found = True + if self.server_pid is None: + return total + return total if found else 0 + + def _run(self) -> None: + while not self._stop.is_set(): + used_mb = self._query_memory_mb() + self.samples.append({"ts": time.time(), "used_mb": used_mb}) + self._stop.wait(self.interval_sec) + + def start(self) -> None: + self.thread = threading.Thread(target=self._run, daemon=True) + self.thread.start() + + def stop(self) -> None: + self._stop.set() + if self.thread is not None: + self.thread.join(timeout=2.0) + + def summary(self) -> Dict[str, Any]: + valid = [item for item in self.samples if item["used_mb"] is not None] + peak = max(valid, key=lambda item: item["used_mb"]) if valid else None + first = valid[0] if valid else None + last = valid[-1] if valid else None + return { + "server_pid": self.server_pid, + "sample_count": int(len(self.samples)), + "start_used_mb": None if first is None else int(first["used_mb"]), + "peak_used_mb": None if peak is None else int(peak["used_mb"]), + "peak_delta_mb": None if peak is None or first is None else int(peak["used_mb"] - first["used_mb"]), + "end_used_mb": None if last is None else int(last["used_mb"]), + "peak_ts": None if peak is None else float(peak["ts"]), + "samples": self.samples, + } + + +async def submit_one(client: httpx.AsyncClient, url: str, payload: Dict[str, Any]) -> Dict[str, Any]: + started = time.perf_counter() + try: + response = await client.post(url, json=payload) + elapsed_ms = (time.perf_counter() - started) * 1000.0 + item = { + "request_id": payload["request_id"], + "status_code": int(response.status_code), + "elapsed_ms": float(elapsed_ms), + "content_type": response.headers.get("content-type"), + "audio_bytes": int(len(response.content)), + "headers": {key: value for key, value in response.headers.items() if key.lower().startswith("x-")}, + } + if response.status_code != 200: + try: + item["error_body"] = response.json() + except Exception: + item["error_body"] = response.text + return item + except Exception as exc: + return { + "request_id": payload["request_id"], + "status_code": -1, + "elapsed_ms": float((time.perf_counter() - started) * 1000.0), + "exception": repr(exc), + } + + +async def run_benchmark(args: argparse.Namespace) -> Dict[str, Any]: + payloads = load_requests(args) + url = args.base_url.rstrip("/") + args.endpoint + poller = GpuMemoryPoller(server_pid=args.server_pid, interval_sec=args.poll_interval_sec) + + limits = httpx.Limits(max_connections=args.concurrency, max_keepalive_connections=args.concurrency) + timeout = httpx.Timeout(connect=10.0, read=args.timeout_sec + 10.0, write=10.0, pool=10.0) + + started = time.perf_counter() + poller.start() + try: + async with httpx.AsyncClient(limits=limits, timeout=timeout) as client: + results = await asyncio.gather(*[submit_one(client, url, payload) for payload in payloads]) + finally: + poller.stop() + wall_ms = (time.perf_counter() - started) * 1000.0 + + ok_results = [item for item in results if item["status_code"] == 200] + failed_results = [item for item in results if item["status_code"] != 200] + request_total_ms = [] + worker_total_ms = [] + for item in ok_results: + headers = item.get("headers", {}) + if "x-request-total-ms" in headers: + request_total_ms.append(float(headers["x-request-total-ms"])) + if "x-worker-total-ms" in headers: + worker_total_ms.append(float(headers["x-worker-total-ms"])) + + return { + "concurrency": int(args.concurrency), + "server_pid": args.server_pid, + "request_count": int(len(payloads)), + "wall_ms": float(wall_ms), + "success_count": int(len(ok_results)), + "failure_count": int(len(failed_results)), + "request_total_ms_avg": float(sum(request_total_ms) / len(request_total_ms)) if request_total_ms else None, + "request_total_ms_max": float(max(request_total_ms)) if request_total_ms else None, + "worker_total_ms_avg": float(sum(worker_total_ms) / len(worker_total_ms)) if worker_total_ms else None, + "worker_total_ms_max": float(max(worker_total_ms)) if worker_total_ms else None, + "gpu_memory": poller.summary(), + "results": results, + } + + +def main() -> None: + args = parse_args() + output_dir = args.output_dir / f"concurrency_{args.concurrency:02d}" + output_dir.mkdir(parents=True, exist_ok=True) + summary = asyncio.run(run_benchmark(args)) + summary_path = output_dir / "summary.json" + summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8") + print(json.dumps({ + "concurrency": summary["concurrency"], + "success_count": summary["success_count"], + "failure_count": summary["failure_count"], + "wall_ms": summary["wall_ms"], + "gpu_peak_used_mb": summary["gpu_memory"]["peak_used_mb"], + "request_total_ms_avg": summary["request_total_ms_avg"], + "request_total_ms_max": summary["request_total_ms_max"], + "summary_path": str(summary_path), + }, ensure_ascii=False, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/tools/t2s_memory_breakdown.py b/tools/t2s_memory_breakdown.py new file mode 100644 index 00000000..18127953 --- /dev/null +++ b/tools/t2s_memory_breakdown.py @@ -0,0 +1,887 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import argparse +import gc +import contextlib +import json +import random +import sys +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import numpy as np +import torch + +ROOT_DIR = Path(__file__).resolve().parents[1] +if str(ROOT_DIR) not in sys.path: + sys.path.append(str(ROOT_DIR)) +gpt_sovits_dir = ROOT_DIR / "GPT_SoVITS" +if str(gpt_sovits_dir) not in sys.path: + sys.path.append(str(gpt_sovits_dir)) + +from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config # noqa: E402 +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( # noqa: E402 + SchedulerRequestSpec, + T2SRequestState, + T2SRunningRequest, + _build_decode_batch_from_running, + build_prefill_batch, + prepare_request_state, + run_decode_step_for_running, + run_prefill_step, +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Break down T2S CUDA memory by stage and tensor groups.") + parser.add_argument("--config", type=Path, default=ROOT_DIR / "GPT_SoVITS/configs/tts_infer.yaml") + parser.add_argument("--request-manifest", type=Path, default=None) + parser.add_argument("--scenario", type=str, default="auto4", choices=["auto4", "single"]) + parser.add_argument("--auto-count", type=int, default=4) + parser.add_argument("--auto-wav-dir", type=Path, default=ROOT_DIR / "testwav") + parser.add_argument("--auto-text-file", type=Path, default=ROOT_DIR / "test_cn.txt") + parser.add_argument("--ref-audio", type=Path, default=ROOT_DIR / "test.wav") + parser.add_argument("--prompt-text", type=str, default="是啊,主要是因为有调研需求的学者少了。") + parser.add_argument("--prompt-lang", type=str, default="zh") + parser.add_argument("--text", type=str, default=None) + parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_en.txt") + parser.add_argument("--text-lang", type=str, default="zh") + parser.add_argument("--top-k", type=int, default=15) + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--repetition-penalty", type=float, default=1.35) + parser.add_argument("--early-stop-num", type=int, default=-1) + parser.add_argument("--max-steps", type=int, default=1500) + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--warmup", action="store_true", default=False) + parser.add_argument("--worker-rounds", type=int, default=1) + parser.add_argument("--worker-grad-mode", type=str, default="default", choices=["default", "inference_mode"]) + parser.add_argument("--compare-worker-grad-modes", action="store_true", default=False) + parser.add_argument( + "--output-dir", + type=Path, + default=ROOT_DIR / "TEMP/t2s_memory_breakdown/run1", + ) + return parser.parse_args() + + +def set_seed(seed: int, use_cuda: bool) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if use_cuda and torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def _sync_device(device: Any) -> None: + try: + device_str = str(device) + if device_str.startswith("cuda") and torch.cuda.is_available(): + torch.cuda.synchronize(device) + elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + except Exception: + pass + + +def bytes_to_mb(num_bytes: int) -> float: + return float(num_bytes) / (1024.0 * 1024.0) + + +def tensor_nbytes(tensor: Optional[torch.Tensor]) -> int: + if tensor is None: + return 0 + return int(tensor.numel() * tensor.element_size()) + + +def tensor_list_nbytes(items: Sequence[torch.Tensor]) -> int: + return int(sum(tensor_nbytes(item) for item in items)) + + +def model_nbytes(module: torch.nn.Module) -> int: + total = 0 + for parameter in module.parameters(): + total += tensor_nbytes(parameter) + for buffer in module.buffers(): + total += tensor_nbytes(buffer) + return int(total) + + +def build_module_weight_summary(tts: TTS) -> Dict[str, Any]: + modules = { + "t2s_model": tts.t2s_model, + "t2s_core": tts.t2s_model.model if tts.t2s_model is not None else None, + "vits_model": tts.vits_model, + "bert_model": tts.bert_model, + "cnhuhbert_model": tts.cnhuhbert_model, + "vocoder": tts.vocoder, + "sv_model": tts.sv_model, + } + by_module = {} + total_bytes = 0 + for name, module in modules.items(): + module_bytes = model_nbytes(module) if module is not None else 0 + by_module[name] = { + "bytes": int(module_bytes), + "mb": bytes_to_mb(module_bytes), + } + total_bytes += module_bytes + return { + "by_module": by_module, + "total_bytes": int(total_bytes), + "total_mb": bytes_to_mb(total_bytes), + } + + +def snapshot_live_cuda_tensors(top_k: int = 40) -> Dict[str, Any]: + storages: Dict[int, Dict[str, Any]] = {} + tensor_views: List[Dict[str, Any]] = [] + for obj in gc.get_objects(): + try: + tensor = None + if torch.is_tensor(obj): + tensor = obj + elif hasattr(obj, "data") and torch.is_tensor(obj.data): + tensor = obj.data + if tensor is None or not tensor.is_cuda: + continue + storage = tensor.untyped_storage() + storage_ptr = int(storage.data_ptr()) + if storage_ptr not in storages: + storages[storage_ptr] = { + "storage_ptr": storage_ptr, + "storage_bytes": int(storage.nbytes()), + "dtype": str(tensor.dtype), + "shape": list(tensor.shape), + "device": str(tensor.device), + } + tensor_views.append( + { + "shape": list(tensor.shape), + "dtype": str(tensor.dtype), + "bytes": tensor_nbytes(tensor), + "device": str(tensor.device), + } + ) + except Exception: + continue + storage_list = sorted(storages.values(), key=lambda item: item["storage_bytes"], reverse=True) + tensor_views.sort(key=lambda item: item["bytes"], reverse=True) + return { + "unique_storage_count": int(len(storage_list)), + "unique_storage_total_bytes": int(sum(item["storage_bytes"] for item in storage_list)), + "unique_storage_total_mb": bytes_to_mb(sum(item["storage_bytes"] for item in storage_list)), + "top_storages": storage_list[:top_k], + "top_tensor_views": tensor_views[:top_k], + } + + +def build_single_spec(args: argparse.Namespace) -> List[SchedulerRequestSpec]: + text = args.text if args.text is not None else args.text_file.read_text(encoding="utf-8").strip() + return [ + SchedulerRequestSpec( + request_id="req_000", + ref_audio_path=args.ref_audio, + prompt_text=args.prompt_text, + prompt_lang=args.prompt_lang, + text=text, + text_lang=args.text_lang, + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + repetition_penalty=args.repetition_penalty, + early_stop_num=args.early_stop_num, + ready_step=0, + ) + ] + + +def build_auto_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]: + wav_paths = sorted(args.auto_wav_dir.glob("*.wav"))[: args.auto_count] + if len(wav_paths) < args.auto_count: + raise ValueError(f"auto wav count不足,目录 {args.auto_wav_dir} 只有 {len(wav_paths)} 条 wav") + text_lines = [line.strip() for line in args.auto_text_file.read_text(encoding="utf-8").splitlines() if line.strip()] + if len(text_lines) < args.auto_count: + raise ValueError(f"auto text lines不足,文件 {args.auto_text_file} 只有 {len(text_lines)} 行有效文本") + specs: List[SchedulerRequestSpec] = [] + for index, wav_path in enumerate(wav_paths): + lab_path = wav_path.with_suffix(".lab") + if not lab_path.exists(): + raise FileNotFoundError(f"找不到参考文本 {lab_path}") + specs.append( + SchedulerRequestSpec( + request_id=f"req_{index:03d}", + ref_audio_path=wav_path, + prompt_text=lab_path.read_text(encoding="utf-8").strip(), + prompt_lang="zh", + text=text_lines[index], + text_lang=args.text_lang, + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + repetition_penalty=args.repetition_penalty, + early_stop_num=args.early_stop_num, + ready_step=0, + ) + ) + return specs + + +def load_request_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]: + if args.request_manifest is not None: + payload = json.loads(args.request_manifest.read_text(encoding="utf-8")) + raw_requests = payload["requests"] if isinstance(payload, dict) else payload + specs: List[SchedulerRequestSpec] = [] + for index, item in enumerate(raw_requests): + text = item.get("text") + text_file = item.get("text_file") + if text is None and text_file is None: + raise ValueError(f"request[{index}] must provide text or text_file") + if text is None: + text = Path(text_file).read_text(encoding="utf-8").strip() + specs.append( + SchedulerRequestSpec( + request_id=item.get("request_id", f"req_{index:03d}"), + ref_audio_path=Path(item["ref_audio_path"]), + prompt_text=item["prompt_text"], + prompt_lang=item.get("prompt_lang", "zh"), + text=text, + text_lang=item.get("text_lang", "zh"), + top_k=int(item.get("top_k", args.top_k)), + top_p=float(item.get("top_p", args.top_p)), + temperature=float(item.get("temperature", args.temperature)), + repetition_penalty=float(item.get("repetition_penalty", args.repetition_penalty)), + early_stop_num=int(item.get("early_stop_num", args.early_stop_num)), + ready_step=int(item.get("ready_step", 0)), + ) + ) + return specs + if args.scenario == "single": + return build_single_spec(args) + return build_auto_specs(args) + + +def load_pipeline(config_path: Path) -> TTS: + tts_config = TTS_Config(str(config_path)) + print(tts_config) + return TTS(tts_config) + + +def cuda_mem_snapshot(device: Any) -> Dict[str, float]: + if not (str(device).startswith("cuda") and torch.cuda.is_available()): + return { + "allocated_mb": 0.0, + "reserved_mb": 0.0, + "max_allocated_mb": 0.0, + "max_reserved_mb": 0.0, + } + _sync_device(device) + return { + "allocated_mb": bytes_to_mb(torch.cuda.memory_allocated(device)), + "reserved_mb": bytes_to_mb(torch.cuda.memory_reserved(device)), + "max_allocated_mb": bytes_to_mb(torch.cuda.max_memory_allocated(device)), + "max_reserved_mb": bytes_to_mb(torch.cuda.max_memory_reserved(device)), + } + + +def stage_run(device: Any, fn) -> Tuple[Any, Dict[str, float]]: + if str(device).startswith("cuda") and torch.cuda.is_available(): + gc.collect() + _sync_device(device) + torch.cuda.reset_peak_memory_stats(device) + before = cuda_mem_snapshot(device) + started = time.perf_counter() + result = fn() + _sync_device(device) + elapsed_ms = (time.perf_counter() - started) * 1000.0 + after = cuda_mem_snapshot(device) + after["elapsed_ms"] = float(elapsed_ms) + after["delta_allocated_mb"] = float(after["allocated_mb"] - before["allocated_mb"]) + after["delta_reserved_mb"] = float(after["reserved_mb"] - before["reserved_mb"]) + after["stage_peak_over_before_mb"] = float(max(after["max_allocated_mb"] - before["allocated_mb"], 0.0)) + return result, after + + +class GlobalPeakRecorder: + def __init__(self, device: Any): + self.device = device + self.checkpoints: List[Dict[str, Any]] = [] + if str(device).startswith("cuda") and torch.cuda.is_available(): + gc.collect() + _sync_device(device) + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats(device) + + def record(self, label: str, **extra: Any) -> None: + snapshot = cuda_mem_snapshot(self.device) + snapshot["label"] = label + snapshot.update(extra) + self.checkpoints.append(snapshot) + + def summary(self) -> Dict[str, Any]: + peak = max(self.checkpoints, key=lambda item: item["max_allocated_mb"]) if self.checkpoints else None + return { + "peak_allocated_mb": 0.0 if peak is None else float(peak["max_allocated_mb"]), + "peak_reserved_mb": 0.0 if peak is None else float(peak["max_reserved_mb"]), + "peak_label": None if peak is None else peak["label"], + "checkpoints": self.checkpoints, + } + + +def summarise_state_tensors(states: Sequence[T2SRequestState]) -> Dict[str, Any]: + per_request = [] + total = { + "phones_bytes": 0, + "prompt_phones_bytes": 0, + "all_phones_bytes": 0, + "all_bert_features_bytes": 0, + "prompt_semantic_bytes": 0, + "refer_spec_bytes": 0, + "raw_audio_bytes": 0, + "audio_16k_bytes": 0, + } + for state in states: + spec_audio, audio_16k = state.refer_spec + item = { + "request_id": state.request_id, + "prompt_semantic_len": int(state.prompt_semantic.shape[0]), + "phones_len": int(state.phones.shape[0]), + "all_phones_len": int(state.all_phones.shape[0]), + "bert_frames": int(state.all_bert_features.shape[-1]), + "phones_bytes": tensor_nbytes(state.phones), + "prompt_phones_bytes": tensor_nbytes(state.prompt_phones), + "all_phones_bytes": tensor_nbytes(state.all_phones), + "all_bert_features_bytes": tensor_nbytes(state.all_bert_features), + "prompt_semantic_bytes": tensor_nbytes(state.prompt_semantic), + "refer_spec_bytes": tensor_nbytes(spec_audio), + "audio_16k_bytes": tensor_nbytes(audio_16k), + "raw_audio_bytes": tensor_nbytes(state.raw_audio), + } + for key in total: + total[key] += int(item[key]) + per_request.append(item) + total["total_bytes"] = int(sum(total.values())) + total["total_mb"] = bytes_to_mb(total["total_bytes"]) + return {"per_request": per_request, "total": total} + + +def summarise_prefill_batch(active_batch: Any) -> Dict[str, Any]: + y_sequence_bytes = int(sum(tensor_nbytes(item) for item in active_batch.y_sequences)) + fields = { + "x_bytes": tensor_nbytes(active_batch.x), + "x_lens_bytes": tensor_nbytes(active_batch.x_lens), + "prefix_lens_bytes": tensor_nbytes(active_batch.prefix_lens), + "xy_pos_bytes": tensor_nbytes(active_batch.xy_pos), + "key_padding_mask_bytes": tensor_nbytes(active_batch.key_padding_mask), + "prefill_attn_mask_bytes": tensor_nbytes(active_batch.prefill_attn_mask), + "y_sequence_bytes": y_sequence_bytes, + } + fields["total_bytes"] = int(sum(fields.values())) + fields["total_mb"] = bytes_to_mb(fields["total_bytes"]) + fields["batch_size"] = int(len(active_batch.states)) + fields["max_x_len"] = int(active_batch.x.shape[1]) + fields["src_len"] = int(active_batch.xy_pos.shape[1]) + fields["prefill_attn_mask_shape"] = list(active_batch.prefill_attn_mask.shape) + return fields + + +def summarise_running_requests(running_requests: Sequence[T2SRunningRequest]) -> Dict[str, Any]: + per_request = [] + total_private_k_bytes = 0 + total_private_v_bytes = 0 + total_decode_mask_bytes = 0 + total_y_sequence_bytes = 0 + for item in running_requests: + k_bytes = tensor_list_nbytes(item.k_cache) + v_bytes = tensor_list_nbytes(item.v_cache) + mask_bytes = tensor_nbytes(item.decode_attn_mask) + y_bytes = tensor_nbytes(item.y_sequence) + total_private_k_bytes += k_bytes + total_private_v_bytes += v_bytes + total_decode_mask_bytes += mask_bytes + total_y_sequence_bytes += y_bytes + per_request.append( + { + "request_id": item.state.request_id, + "step_idx": int(item.step_idx), + "prefix_len": int(item.prefix_len), + "history_len": int(item.y_sequence.shape[0]), + "kv_len": int(item.k_cache[0].shape[1]), + "k_cache_bytes": k_bytes, + "v_cache_bytes": v_bytes, + "decode_mask_bytes": mask_bytes, + "y_sequence_bytes": y_bytes, + } + ) + total_bytes = total_private_k_bytes + total_private_v_bytes + total_decode_mask_bytes + total_y_sequence_bytes + return { + "per_request": per_request, + "totals": { + "private_k_cache_bytes": int(total_private_k_bytes), + "private_v_cache_bytes": int(total_private_v_bytes), + "private_kv_cache_bytes": int(total_private_k_bytes + total_private_v_bytes), + "decode_mask_bytes": int(total_decode_mask_bytes), + "y_sequence_bytes": int(total_y_sequence_bytes), + "total_bytes": int(total_bytes), + "total_mb": bytes_to_mb(total_bytes), + }, + } + + +def summarise_decode_batch( + xy_pos: torch.Tensor, + batched_k_cache: Sequence[torch.Tensor], + batched_v_cache: Sequence[torch.Tensor], + batched_decode_attn_mask: Optional[torch.Tensor], + running_requests: Sequence[T2SRunningRequest], +) -> Dict[str, Any]: + private_k_bytes = int(sum(tensor_list_nbytes(item.k_cache) for item in running_requests)) + private_v_bytes = int(sum(tensor_list_nbytes(item.v_cache) for item in running_requests)) + batched_k_bytes = tensor_list_nbytes(batched_k_cache) + batched_v_bytes = tensor_list_nbytes(batched_v_cache) + batched_mask_bytes = tensor_nbytes(batched_decode_attn_mask) + xy_pos_bytes = tensor_nbytes(xy_pos) + total_bytes = batched_k_bytes + batched_v_bytes + batched_mask_bytes + xy_pos_bytes + return { + "batch_size": int(len(running_requests)), + "xy_pos_bytes": int(xy_pos_bytes), + "batched_k_cache_bytes": int(batched_k_bytes), + "batched_v_cache_bytes": int(batched_v_bytes), + "batched_kv_cache_bytes": int(batched_k_bytes + batched_v_bytes), + "batched_decode_mask_bytes": int(batched_mask_bytes), + "private_kv_cache_bytes_reference": int(private_k_bytes + private_v_bytes), + "kv_padding_overhead_bytes": int((batched_k_bytes + batched_v_bytes) - (private_k_bytes + private_v_bytes)), + "total_bytes": int(total_bytes), + "total_mb": bytes_to_mb(total_bytes), + "xy_pos_shape": list(xy_pos.shape), + "batched_decode_mask_shape": None if batched_decode_attn_mask is None else list(batched_decode_attn_mask.shape), + "layer_k_cache_shape": list(batched_k_cache[0].shape), + } + + +def summarise_decode_outputs( + xy_dec: torch.Tensor, + next_k_cache: Sequence[torch.Tensor], + next_v_cache: Sequence[torch.Tensor], +) -> Dict[str, Any]: + xy_dec_bytes = tensor_nbytes(xy_dec) + next_k_bytes = tensor_list_nbytes(next_k_cache) + next_v_bytes = tensor_list_nbytes(next_v_cache) + total_bytes = xy_dec_bytes + next_k_bytes + next_v_bytes + return { + "xy_dec_bytes": int(xy_dec_bytes), + "next_k_cache_bytes": int(next_k_bytes), + "next_v_cache_bytes": int(next_v_bytes), + "next_kv_cache_bytes": int(next_k_bytes + next_v_bytes), + "total_bytes": int(total_bytes), + "total_mb": bytes_to_mb(total_bytes), + "xy_dec_shape": list(xy_dec.shape), + "layer_next_k_cache_shape": list(next_k_cache[0].shape), + } + + +def top_rankings(summary: Dict[str, Any]) -> List[Dict[str, Any]]: + ranking = [ + ("request_state_total", summary["prepare_stage"]["request_state"]["total"]["total_bytes"]), + ("prefill_batch_total", summary["prefill_batch"]["tensor_bytes"]["total_bytes"]), + ("running_private_kv", summary["prefill_step"]["running_requests"]["totals"]["private_kv_cache_bytes"]), + ("decode_batched_kv", summary["decode_batch"]["tensor_bytes"]["batched_kv_cache_bytes"]), + ("decode_kv_padding_overhead", summary["decode_batch"]["tensor_bytes"]["kv_padding_overhead_bytes"]), + ("decode_outputs_next_kv", summary["decode_outputs"]["tensor_bytes"]["next_kv_cache_bytes"]), + ("prefill_attn_mask", summary["prefill_batch"]["tensor_bytes"]["prefill_attn_mask_bytes"]), + ] + ranking.sort(key=lambda item: item[1], reverse=True) + return [{"name": name, "bytes": int(value), "mb": bytes_to_mb(int(value))} for name, value in ranking] + + +def synthesize_finished_item(tts: TTS, state: T2SRequestState, semantic_tokens: torch.Tensor) -> Tuple[int, np.ndarray]: + semantic_tokens = semantic_tokens.unsqueeze(0).unsqueeze(0).to(tts.configs.device) + phones = state.phones.unsqueeze(0).to(tts.configs.device) + audio_fragment = tts.synthesize_audio_request_local( + semantic_tokens=semantic_tokens, + phones=phones, + prompt_semantic=state.prompt_semantic, + prompt_phones=state.prompt_phones, + refer_spec=state.refer_spec, + raw_audio=state.raw_audio, + raw_sr=state.raw_sr, + speed=1.0, + sample_steps=32, + ) + output_sr = tts.configs.sampling_rate if not tts.configs.use_vocoder else tts.vocoder_configs["sr"] + return tts.audio_postprocess( + audio=[[audio_fragment]], + sr=int(output_sr), + batch_index_list=None, + speed_factor=1.0, + split_bucket=False, + fragment_interval=0.0, + super_sampling=False, + ) + + +def simulate_worker_end_to_end( + tts: TTS, + specs: Sequence[SchedulerRequestSpec], + max_steps: int, + rounds: int, + grad_mode: str = "default", +) -> Dict[str, Any]: + device = tts.configs.device + recorder = GlobalPeakRecorder(device) + recorder.record("after_model_load") + + state_map: Dict[str, T2SRequestState] = {} + per_round: List[Dict[str, Any]] = [] + + for round_index in range(rounds): + grad_context = torch.inference_mode if grad_mode == "inference_mode" else contextlib.nullcontext + with grad_context(): + states = [prepare_request_state(tts, spec) for spec in specs] + state_map = {state.request_id: state for state in states} + recorder.record( + "after_prepare_states", + round_index=int(round_index), + request_count=int(len(states)), + grad_mode=grad_mode, + ) + + pending = list(states) + running_requests: List[T2SRunningRequest] = [] + round_events: List[Dict[str, Any]] = [] + current_tick = 0 + + while pending or running_requests: + admitted = pending + pending = [] + + if admitted: + recorder.record( + "before_prefill", + round_index=int(round_index), + tick=int(current_tick), + admitted_count=int(len(admitted)), + running_count=int(len(running_requests)), + grad_mode=grad_mode, + ) + with grad_context(): + admitted_running, admitted_finished = run_prefill_step(tts.t2s_model.model, admitted, max_steps=max_steps) + recorder.record( + "after_prefill", + round_index=int(round_index), + tick=int(current_tick), + admitted_running_count=int(len(admitted_running)), + admitted_finished_count=int(len(admitted_finished)), + running_count=int(len(running_requests)), + grad_mode=grad_mode, + ) + round_events.append( + { + "tick": int(current_tick), + "event": "prefill", + "admitted_count": int(len(admitted)), + "admitted_running_count": int(len(admitted_running)), + "admitted_finished_count": int(len(admitted_finished)), + } + ) + for item in admitted_finished: + recorder.record( + "before_synth_prefill_finished", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + finished_request_id=item.request_id, + semantic_len=int(item.semantic_tokens.shape[0]), + grad_mode=grad_mode, + ) + with grad_context(): + sample_rate, audio_data = synthesize_finished_item(tts, state_map[item.request_id], item.semantic_tokens) + recorder.record( + "after_synth_prefill_finished", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + finished_request_id=item.request_id, + sample_rate=int(sample_rate), + audio_samples=int(audio_data.shape[0]), + grad_mode=grad_mode, + ) + running_requests.extend(admitted_running) + recorder.record( + "after_extend_running", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + grad_mode=grad_mode, + ) + + if running_requests: + recorder.record( + "before_decode", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + grad_mode=grad_mode, + ) + with grad_context(): + running_requests, step_finished = run_decode_step_for_running( + tts.t2s_model.model, + running_requests, + max_steps=max_steps, + ) + recorder.record( + "after_decode", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + finished_count=int(len(step_finished)), + grad_mode=grad_mode, + ) + round_events.append( + { + "tick": int(current_tick), + "event": "decode", + "running_count_after_decode": int(len(running_requests)), + "finished_count": int(len(step_finished)), + } + ) + for item in step_finished: + recorder.record( + "before_synth_decode_finished", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + finished_request_id=item.request_id, + semantic_len=int(item.semantic_tokens.shape[0]), + grad_mode=grad_mode, + ) + with grad_context(): + sample_rate, audio_data = synthesize_finished_item(tts, state_map[item.request_id], item.semantic_tokens) + recorder.record( + "after_synth_decode_finished", + round_index=int(round_index), + tick=int(current_tick), + running_count=int(len(running_requests)), + finished_request_id=item.request_id, + sample_rate=int(sample_rate), + audio_samples=int(audio_data.shape[0]), + grad_mode=grad_mode, + ) + current_tick += 1 + + recorder.record( + "after_round_complete", + round_index=int(round_index), + running_count=0, + grad_mode=grad_mode, + ) + per_round.append( + { + "round_index": int(round_index), + "events": round_events, + } + ) + + return { + "grad_mode": grad_mode, + "rounds": per_round, + "timeline": recorder.summary(), + } + + +def main() -> None: + args = parse_args() + args.output_dir.mkdir(parents=True, exist_ok=True) + + tts = load_pipeline(args.config) + model = tts.t2s_model.model + device = tts.configs.device + use_cuda = str(device).startswith("cuda") and torch.cuda.is_available() + set_seed(args.seed, use_cuda) + + specs = load_request_specs(args) + if args.early_stop_num == -1: + for spec in specs: + spec.early_stop_num = int(tts.configs.hz * tts.configs.max_sec) + + if args.warmup and specs: + warmup_spec = specs[:1] + _ = [prepare_request_state(tts, spec) for spec in warmup_spec] + gc.collect() + if use_cuda: + torch.cuda.empty_cache() + _sync_device(device) + + states, prepare_mem = stage_run(device, lambda: [prepare_request_state(tts, spec) for spec in specs]) + request_state_summary = summarise_state_tensors(states) + + active_batch, prefill_batch_mem = stage_run(device, lambda: build_prefill_batch(model, states)) + prefill_batch_tensor_summary = summarise_prefill_batch(active_batch) + + prefill_result, prefill_step_mem = stage_run(device, lambda: run_prefill_step(model, states, max_steps=args.max_steps)) + running_requests, finished_items = prefill_result + running_requests_summary = summarise_running_requests(running_requests) + finished_after_prefill_summary = [ + { + "request_id": item.request_id, + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + "semantic_len": int(item.semantic_tokens.shape[0]), + } + for item in finished_items + ] + + if not running_requests: + raise RuntimeError(f"prefill 后没有 running requests,全部在首步结束: {[item.request_id for item in finished_items]}") + + decode_batch_result, decode_batch_mem = stage_run( + device, + lambda: _build_decode_batch_from_running(model, running_requests), + ) + xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask = decode_batch_result + decode_batch_tensor_summary = summarise_decode_batch( + xy_pos, + batched_k_cache, + batched_v_cache, + batched_decode_attn_mask, + running_requests, + ) + + decode_out_result, decode_step_mem = stage_run( + device, + lambda: model.t2s_transformer.decode_next_token( + xy_pos, + batched_k_cache, + batched_v_cache, + batched_decode_attn_mask, + ), + ) + xy_dec, next_k_cache, next_v_cache = decode_out_result + decode_output_tensor_summary = summarise_decode_outputs(xy_dec, next_k_cache, next_v_cache) + del active_batch + del running_requests + del finished_items + del xy_pos + del batched_k_cache + del batched_v_cache + del batched_decode_attn_mask + del xy_dec + del next_k_cache + del next_v_cache + gc.collect() + if use_cuda: + _sync_device(device) + torch.cuda.empty_cache() + end_to_end_worker = simulate_worker_end_to_end( + tts=tts, + specs=specs, + max_steps=args.max_steps, + rounds=args.worker_rounds, + grad_mode=args.worker_grad_mode, + ) + live_cuda_tensors_after_worker = snapshot_live_cuda_tensors() + worker_inference_mode = None + if args.compare_worker_grad_modes: + gc.collect() + if use_cuda: + _sync_device(device) + torch.cuda.empty_cache() + worker_inference_mode = simulate_worker_end_to_end( + tts=tts, + specs=specs, + max_steps=args.max_steps, + rounds=args.worker_rounds, + grad_mode="inference_mode", + ) + + summary = { + "meta": { + "scenario": args.scenario if args.request_manifest is None else "manifest", + "seed": int(args.seed), + "device": str(device), + "dtype": str(next(model.parameters()).dtype), + "request_count": int(len(specs)), + "num_layers": int(model.num_layers), + "num_heads": int(model.num_head), + "model_dim": int(model.model_dim), + "model_weights_mb": bytes_to_mb(model_nbytes(model)), + }, + "loaded_module_weights": build_module_weight_summary(tts), + "requests": [ + { + "request_id": spec.request_id, + "ref_audio_path": str(spec.ref_audio_path), + "prompt_text": spec.prompt_text, + "text": spec.text, + } + for spec in specs + ], + "prepare_stage": { + "memory": prepare_mem, + "request_state": request_state_summary, + }, + "prefill_batch": { + "memory": prefill_batch_mem, + "tensor_bytes": prefill_batch_tensor_summary, + }, + "prefill_step": { + "memory": prefill_step_mem, + "running_requests": running_requests_summary, + "finished_after_prefill": finished_after_prefill_summary, + }, + "decode_batch": { + "memory": decode_batch_mem, + "tensor_bytes": decode_batch_tensor_summary, + }, + "decode_outputs": { + "memory": decode_step_mem, + "tensor_bytes": decode_output_tensor_summary, + }, + "end_to_end_worker": end_to_end_worker, + "live_cuda_tensors_after_worker": live_cuda_tensors_after_worker, + "end_to_end_worker_inference_mode": worker_inference_mode, + } + summary["top_rankings"] = top_rankings(summary) + + summary_path = args.output_dir / "t2s_memory_breakdown_summary.json" + summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8") + + print(json.dumps(summary["meta"], ensure_ascii=False, indent=2)) + print("[top_rankings]") + for item in summary["top_rankings"]: + print(f"- {item['name']}: {item['mb']:.3f} MB") + print("[worker_peak]") + print( + json.dumps( + { + "peak_label": summary["end_to_end_worker"]["timeline"]["peak_label"], + "peak_allocated_mb": summary["end_to_end_worker"]["timeline"]["peak_allocated_mb"], + "peak_reserved_mb": summary["end_to_end_worker"]["timeline"]["peak_reserved_mb"], + }, + ensure_ascii=False, + indent=2, + ) + ) + if worker_inference_mode is not None: + print("[worker_peak_inference_mode]") + print( + json.dumps( + { + "peak_label": worker_inference_mode["timeline"]["peak_label"], + "peak_allocated_mb": worker_inference_mode["timeline"]["peak_allocated_mb"], + "peak_reserved_mb": worker_inference_mode["timeline"]["peak_reserved_mb"], + }, + ensure_ascii=False, + indent=2, + ) + ) + print(f"[summary] {summary_path}") + + +if __name__ == "__main__": + main() diff --git a/tools/t2s_scheduler_prototype.py b/tools/t2s_scheduler_prototype.py new file mode 100644 index 00000000..cd4b9c6d --- /dev/null +++ b/tools/t2s_scheduler_prototype.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import argparse +import json +import random +import sys +from pathlib import Path +from typing import Any, Dict, List + +import numpy as np +import torch + +ROOT_DIR = Path(__file__).resolve().parents[1] +if str(ROOT_DIR) not in sys.path: + sys.path.append(str(ROOT_DIR)) +gpt_sovits_dir = ROOT_DIR / "GPT_SoVITS" +if str(gpt_sovits_dir) not in sys.path: + sys.path.append(str(gpt_sovits_dir)) + +from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( # noqa: E402 + SchedulerRequestSpec, + T2SFinishedItem, + T2SRequestState, + prepare_request_state, + run_scheduler_continuous, +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="T2S request-local scheduler prototype.") + parser.add_argument("--config", type=Path, default=ROOT_DIR / "GPT_SoVITS/configs/tts_infer.yaml") + parser.add_argument("--request-manifest", type=Path, default=None) + parser.add_argument("--ref-audio", type=Path, default=ROOT_DIR / "test.wav") + parser.add_argument("--prompt-text", type=str, default="是啊,主要是因为有调研需求的学者少了。") + parser.add_argument("--prompt-lang", type=str, default="zh") + parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_en.txt") + parser.add_argument("--text", type=str, default=None) + parser.add_argument("--text-lang", type=str, default="en") + parser.add_argument("--top-k", type=int, default=15) + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--repetition-penalty", type=float, default=1.35) + parser.add_argument("--early-stop-num", type=int, default=-1) + parser.add_argument("--max-steps", type=int, default=1500) + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--output-dir", type=Path, default=ROOT_DIR / "TEMP/t2s_scheduler/output_run") + return parser.parse_args() + + +def set_seed(seed: int, use_cuda: bool) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if use_cuda and torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def load_pipeline(config_path: Path): + try: + from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "缺少运行依赖,请先在 GPT-SoVITS 推理环境中安装 requirements 后再运行该脚本。" + ) from exc + tts_config = TTS_Config(str(config_path)) + print(tts_config) + return TTS(tts_config) + + +def load_request_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]: + if args.request_manifest is not None: + payload = json.loads(args.request_manifest.read_text(encoding="utf-8")) + raw_requests = payload["requests"] if isinstance(payload, dict) else payload + specs: List[SchedulerRequestSpec] = [] + for index, item in enumerate(raw_requests): + text = item.get("text") + text_file = item.get("text_file") + if text is None and text_file is None: + raise ValueError(f"request[{index}] must provide text or text_file") + if text is None: + text = Path(text_file).read_text(encoding="utf-8") + specs.append( + SchedulerRequestSpec( + request_id=item.get("request_id", f"req_{index:03d}"), + ref_audio_path=Path(item["ref_audio_path"]), + prompt_text=item["prompt_text"], + prompt_lang=item.get("prompt_lang", "zh"), + text=text, + text_lang=item.get("text_lang", "zh"), + top_k=int(item.get("top_k", args.top_k)), + top_p=float(item.get("top_p", args.top_p)), + temperature=float(item.get("temperature", args.temperature)), + repetition_penalty=float(item.get("repetition_penalty", args.repetition_penalty)), + early_stop_num=int(item.get("early_stop_num", args.early_stop_num)), + ready_step=int(item.get("ready_step", 0)), + ) + ) + return specs + + text = args.text if args.text is not None else args.text_file.read_text(encoding="utf-8") + return [ + SchedulerRequestSpec( + request_id="req_000", + ref_audio_path=args.ref_audio, + prompt_text=args.prompt_text, + prompt_lang=args.prompt_lang, + text=text, + text_lang=args.text_lang, + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + repetition_penalty=args.repetition_penalty, + early_stop_num=args.early_stop_num, + ready_step=0, + ) + ] + + +def summarise_requests(states: List[T2SRequestState]) -> List[Dict[str, Any]]: + return [ + { + "request_id": state.request_id, + "ready_step": int(state.ready_step), + "ref_audio_path": str(state.ref_audio_path), + "prompt_semantic_len": int(state.prompt_semantic.shape[0]), + "all_phone_len": int(state.all_phones.shape[0]), + "bert_len": int(state.all_bert_features.shape[-1]), + "norm_text": state.norm_text, + } + for state in states + ] + + +def summarise_finished(items: List[T2SFinishedItem]) -> List[Dict[str, Any]]: + return [ + { + "request_id": item.request_id, + "semantic_len": int(item.semantic_tokens.shape[0]), + "finish_idx": int(item.finish_idx), + "finish_reason": item.finish_reason, + } + for item in items + ] + + +def main() -> None: + args = parse_args() + args.output_dir.mkdir(parents=True, exist_ok=True) + + tts = load_pipeline(args.config) + model = tts.t2s_model.model + use_cuda = str(tts.configs.device).startswith("cuda") + set_seed(args.seed, use_cuda) + + request_specs = load_request_specs(args) + states = [prepare_request_state(tts, spec) for spec in request_specs] + finished = run_scheduler_continuous(model, states, max_steps=args.max_steps) + + summary = { + "request_count": len(states), + "max_steps": args.max_steps, + "requests": summarise_requests(states), + "finished": summarise_finished(finished), + } + output_path = args.output_dir / "scheduler_prototype_summary.json" + output_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8") + print(json.dumps(summary, ensure_ascii=False, indent=2)) + print(f"[saved] {output_path}") + + +if __name__ == "__main__": + try: + main() + except ModuleNotFoundError as exc: + print(f"[error] {exc}") + raise SystemExit(1) from None