Merge 8a444c10b72aadeb1f8be9210b45f2a519d383e4 into 2d9193b0d3c0eae0c3a14d8c68a839f1bae157dc

This commit is contained in:
白菜工厂1145号员工 2026-03-13 08:46:05 +00:00 committed by GitHub
commit d54ea4d5aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
64 changed files with 15450 additions and 617 deletions

3
.gitmodules vendored Normal file
View File

@ -0,0 +1,3 @@
[submodule "third_party/g2pw-cu"]
path = third_party/g2pw-cu
url = https://github.com/baicai-1145/g2pw-cu.git

View File

@ -351,6 +351,13 @@ class Text2SemanticDecoder(nn.Module):
blocks.append(block)
self.t2s_transformer = T2STransformer(self.num_layers, blocks)
self.last_infer_stats = {}
def _set_last_infer_stats(self, stats):
self.last_infer_stats = stats
def get_last_infer_stats(self):
return dict(self.last_infer_stats)
def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
x = self.ar_text_embedding(x)
@ -593,7 +600,19 @@ class Text2SemanticDecoder(nn.Module):
repetition_penalty: float = 1.35,
**kwargs,
):
requested_enable_mask_free_fastpath = bool(kwargs.get("enable_mask_free_fastpath", True))
if prompts is None:
self._set_last_infer_stats(
{
"infer_mode": "batch_infer_prompt_free_fallback",
"requested_enable_mask_free_fastpath": requested_enable_mask_free_fastpath,
"batch_size": int(len(x)),
"prefill_after_mask_all_visible": None,
"fastpath_hit": False,
"generated_token_count": 0,
"generated_token_count_list": [],
}
)
print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
return self.infer_panel_naive_batched(
x,
@ -608,6 +627,7 @@ class Text2SemanticDecoder(nn.Module):
)
max_len = kwargs.get("max_len", x_lens.max())
enable_mask_free_fastpath = requested_enable_mask_free_fastpath
x_list = []
for x_item, bert_item in zip(x, bert_feature):
# max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
@ -698,17 +718,30 @@ class Text2SemanticDecoder(nn.Module):
y_list = [None] * y.shape[0]
batch_idx_map = list(range(y.shape[0]))
idx_list = [None] * y.shape[0]
decode_attn_mask = attn_mask
prefill_after_mask_all_visible = None
fastpath_hit = False
for idx in tqdm(range(1500)):
if idx == 0:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(
xy_pos, k_cache, v_cache, decode_attn_mask
)
logits = self.ar_predict_layer(xy_dec[:, -1])
if idx == 0:
attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
prefill_after_mask_all_visible = not attn_mask.any().item()
if enable_mask_free_fastpath and y.shape[0] == 1 and prefill_after_mask_all_visible:
decode_attn_mask = None
fastpath_hit = True
else:
decode_attn_mask = attn_mask
else:
attn_mask = F.pad(attn_mask, (0, 1), value=False)
if decode_attn_mask is not None:
attn_mask = F.pad(attn_mask, (0, 1), value=False)
decode_attn_mask = attn_mask
if idx < 11: ###至少预测出10个token不然不给停止0.4s
logits = logits[:, :-1]
@ -740,7 +773,9 @@ class Text2SemanticDecoder(nn.Module):
if reserved_idx_of_batch_for_y is not None:
# index = torch.LongTensor(batch_idx_map).to(y.device)
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
if decode_attn_mask is not None:
attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
decode_attn_mask = attn_mask
if k_cache is not None:
for i in range(len(k_cache)):
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
@ -775,6 +810,18 @@ class Text2SemanticDecoder(nn.Module):
if idx_list[i] is None:
idx_list[i] = 1500 - 1 ###如果没有生成到EOS就用最大长度代替
self._set_last_infer_stats(
{
"infer_mode": "batch_infer",
"requested_enable_mask_free_fastpath": enable_mask_free_fastpath,
"batch_size": int(len(x)),
"prefill_after_mask_all_visible": prefill_after_mask_all_visible,
"fastpath_hit": fastpath_hit,
"generated_token_count": int(sum(idx_list)),
"generated_token_count_list": [int(item) for item in idx_list],
"max_len": int(max_len),
}
)
if ref_free:
return y_list, [0] * x.shape[0]
# print(idx_list)
@ -811,6 +858,17 @@ class Text2SemanticDecoder(nn.Module):
y_list.append(y[0])
idx_list.append(idx)
self._set_last_infer_stats(
{
"infer_mode": "naive_batched",
"requested_enable_mask_free_fastpath": bool(kwargs.get("enable_mask_free_fastpath", True)),
"batch_size": int(len(x)),
"prefill_after_mask_all_visible": None,
"fastpath_hit": False,
"generated_token_count": int(sum(idx_list)),
"generated_token_count_list": [int(item) for item in idx_list],
}
)
return y_list, idx_list
def infer_panel_naive(
@ -957,6 +1015,18 @@ class Text2SemanticDecoder(nn.Module):
if not streaming_mode:
generated_token_count = max(int(y.shape[1] - prefix_len), 0)
self._set_last_infer_stats(
{
"infer_mode": "naive",
"requested_enable_mask_free_fastpath": bool(kwargs.get("enable_mask_free_fastpath", True)),
"batch_size": int(x.shape[0]),
"prefill_after_mask_all_visible": True if prompts is not None else None,
"fastpath_hit": True if prompts is not None else False,
"generated_token_count": generated_token_count,
"generated_token_count_list": [generated_token_count],
}
)
if ref_free:
yield y, 0
yield y, idx

View File

@ -147,6 +147,7 @@ def multinomial_sample_one_no_sync(
def logits_to_probs(
logits,
previous_tokens: Optional[torch.Tensor] = None,
previous_token_mask: Optional[torch.Tensor] = None,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[int] = None,
@ -158,13 +159,27 @@ def logits_to_probs(
# pdb.set_trace()
if previous_tokens is not None and repetition_penalty != 1.0:
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=1, index=previous_tokens)
score = torch.where(
score < 0,
score * repetition_penalty,
score / repetition_penalty,
)
logits.scatter_(dim=1, index=previous_tokens, src=score)
if previous_token_mask is None:
score = torch.gather(logits, dim=1, index=previous_tokens)
score = torch.where(
score < 0,
score * repetition_penalty,
score / repetition_penalty,
)
logits.scatter_(dim=1, index=previous_tokens, src=score)
else:
previous_token_mask = previous_token_mask.to(dtype=torch.bool, device=logits.device)
if previous_token_mask.any():
batch_index = torch.arange(logits.size(0), device=logits.device).unsqueeze(1).expand_as(previous_tokens)
valid_batch_index = batch_index[previous_token_mask]
valid_token_index = previous_tokens[previous_token_mask]
score = logits[valid_batch_index, valid_token_index]
score = torch.where(
score < 0,
score * repetition_penalty,
score / repetition_penalty,
)
logits[valid_batch_index, valid_token_index] = score
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
@ -192,9 +207,15 @@ def logits_to_probs(
def sample(
logits,
previous_tokens: Optional[torch.Tensor] = None,
previous_token_mask: Optional[torch.Tensor] = None,
**sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
probs = logits_to_probs(logits=logits, previous_tokens=previous_tokens, **sampling_kwargs)
probs = logits_to_probs(
logits=logits,
previous_tokens=previous_tokens,
previous_token_mask=previous_token_mask,
**sampling_kwargs,
)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,10 @@
import asyncio
import os
import sys
import threading
import time
from contextlib import contextmanager
from dataclasses import dataclass
from tqdm import tqdm
@ -11,11 +15,13 @@ import re
import torch
from text.LangSegmenter import LangSegmenter
from text import chinese
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
from text.cleaner import clean_text
from text import cleaned_text_to_sequence
from transformers import AutoModelForMaskedLM, AutoTokenizer
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker
from TTS_infer_pack.text_cpu_preprocess import preprocess_text_segments_payload
from tools.i18n.i18n import I18nAuto, scan_language_list
@ -49,12 +55,80 @@ def merge_short_text_in_array(texts: str, threshold: int) -> list:
return result
class StageLimiter:
def __init__(self, slots: int):
self.slots = max(1, int(slots))
self.semaphore = threading.BoundedSemaphore(self.slots)
self.lock = threading.Lock()
self.inflight = 0
self.peak_inflight = 0
@contextmanager
def enter(self):
wait_start = time.perf_counter()
self.semaphore.acquire()
wait_ms = (time.perf_counter() - wait_start) * 1000.0
with self.lock:
self.inflight += 1
current_inflight = self.inflight
if current_inflight > self.peak_inflight:
self.peak_inflight = current_inflight
peak_inflight = self.peak_inflight
try:
yield {
"wait_ms": wait_ms,
"inflight": current_inflight,
"peak_inflight": peak_inflight,
"slots": self.slots,
}
finally:
with self.lock:
self.inflight = max(0, self.inflight - 1)
self.semaphore.release()
def snapshot(self) -> Dict[str, int]:
with self.lock:
return {
"slots": self.slots,
"inflight": self.inflight,
"peak_inflight": self.peak_inflight,
}
@dataclass
class PreparedTextSegment:
language: str
phones: List[int]
word2ph: Optional[List[int]]
norm_text: str
needs_g2pw: bool = False
class TextPreprocessor:
def __init__(self, bert_model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, device: torch.device):
def __init__(
self,
bert_model: AutoModelForMaskedLM,
tokenizer: AutoTokenizer,
device: torch.device,
version: str = "v2",
bert_stage_limiter: StageLimiter | None = None,
bert_batch_worker: PrepareBertBatchWorker | None = None,
):
self.bert_model = bert_model
self.tokenizer = tokenizer
self.device = device
self.bert_lock = threading.RLock()
self.version = str(version)
self.bert_stage_limiter = bert_stage_limiter
self.bert_batch_worker = bert_batch_worker
def snapshot(self) -> Dict[str, object]:
return {
"device": str(self.device),
"bert_stage_limiter": (
None if self.bert_stage_limiter is None else dict(self.bert_stage_limiter.snapshot())
),
"bert_batch_worker": None if self.bert_batch_worker is None else dict(self.bert_batch_worker.snapshot()),
}
def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
print(f"############ {i18n('切分文本')} ############")
@ -98,7 +172,7 @@ class TextPreprocessor:
# 解决输入目标文本的空行导致报错的问题
if len(text.strip()) == 0:
continue
if not re.sub("\W+", "", text):
if not re.sub(r"\W+", "", text):
# 检测一下,如果是纯符号,就跳过。
continue
if text[-1] not in splits:
@ -115,86 +189,233 @@ class TextPreprocessor:
return texts
def segment_and_extract_feature_for_text(
self, text: str, language: str, version: str = "v1"
self, text: str, language: str, version: str = "v1", profile: Dict | None = None
) -> Tuple[list, torch.Tensor, str]:
return self.get_phones_and_bert(text, language, version)
prepared_segments = self.preprocess_text_segments(text, language, version)
return self.build_phones_and_bert_from_segments(prepared_segments, profile=profile)
def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False):
with self.bert_lock:
text = re.sub(r' {2,}', ' ', text)
textlist = []
langlist = []
if language == "all_zh":
for tmp in LangSegmenter.getTexts(text,"zh"):
def _split_text_by_language(self, text: str, language: str) -> Tuple[List[str], List[str]]:
textlist = []
langlist = []
if language == "all_zh":
for tmp in LangSegmenter.getTexts(text, "zh"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_yue":
for tmp in LangSegmenter.getTexts(text, "zh"):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ja":
for tmp in LangSegmenter.getTexts(text, "ja"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ko":
for tmp in LangSegmenter.getTexts(text, "ko"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "en":
langlist.append("en")
textlist.append(text)
elif language == "auto":
for tmp in LangSegmenter.getTexts(text):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "auto_yue":
for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(text):
if langlist:
same_group = (tmp["lang"] == "en" and langlist[-1] == "en") or (
tmp["lang"] != "en" and langlist[-1] != "en"
)
if same_group:
textlist[-1] += tmp["text"]
continue
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_yue":
for tmp in LangSegmenter.getTexts(text,"zh"):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ja":
for tmp in LangSegmenter.getTexts(text,"ja"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ko":
for tmp in LangSegmenter.getTexts(text,"ko"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "en":
langlist.append("en")
textlist.append(text)
elif language == "auto":
for tmp in LangSegmenter.getTexts(text):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "auto_yue":
for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(text):
if langlist:
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
textlist[-1] += tmp["text"]
continue
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
# 因无法区别中日韩文汉字,以用户输入为准
langlist.append(language)
textlist.append(tmp["text"])
# print(textlist)
# print(langlist)
phones_list = []
bert_list = []
norm_text_list = []
for i in range(len(textlist)):
lang = langlist[i]
phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
phones_list.append(phones)
norm_text_list.append(norm_text)
bert_list.append(bert)
bert = torch.cat(bert_list, dim=1)
phones = sum(phones_list, [])
norm_text = "".join(norm_text_list)
else:
langlist.append(language)
textlist.append(tmp["text"])
return textlist, langlist
if not final and len(phones) < 6:
return self.get_phones_and_bert("." + text, language, version, final=True)
def get_phones_and_bert(
self, text: str, language: str, version: str, final: bool = False, profile: Dict | None = None
):
prepared_segments = self.preprocess_text_segments(text, language, version, final=final)
return self.build_phones_and_bert_from_segments(prepared_segments, profile=profile)
return phones, bert, norm_text
def preprocess_text_segments(
self,
text: str,
language: str,
version: str,
final: bool = False,
) -> List[PreparedTextSegment]:
payloads = preprocess_text_segments_payload(text, language, version, final=final)
return [
PreparedTextSegment(
language=str(payload["language"]),
phones=list(payload["phones"]),
word2ph=None if payload["word2ph"] is None else list(payload["word2ph"]),
norm_text=str(payload["norm_text"]),
needs_g2pw=bool(payload.get("needs_g2pw", False)),
)
for payload in payloads
]
def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
with torch.no_grad():
inputs = self.tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(self.device)
res = self.bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
def resolve_g2pw_segments(
self,
prepared_segments: List[PreparedTextSegment],
profile: Dict | None = None,
) -> List[PreparedTextSegment]:
zh_indices = [index for index, segment in enumerate(prepared_segments) if bool(segment.needs_g2pw)]
if not zh_indices:
return prepared_segments
from text import chinese2
normalized_segments = [prepared_segments[index].norm_text for index in zh_indices]
resolved_segments, g2pw_profile = chinese2.g2p_segments(normalized_segments, return_profile=True)
self._accumulate_profile(profile, "g2pw_prepare_ms", g2pw_profile.get("g2pw_prepare_ms", 0.0))
self._accumulate_profile(profile, "g2pw_predict_ms", g2pw_profile.get("g2pw_predict_ms", 0.0))
self._accumulate_profile(profile, "g2pw_post_ms", g2pw_profile.get("g2pw_post_ms", 0.0))
self._accumulate_profile(profile, "g2pw_total_ms", g2pw_profile.get("g2pw_total_ms", 0.0))
self._accumulate_profile(profile, "g2pw_runtime_total_ms", g2pw_profile.get("g2pw_runtime_total_ms", 0.0))
self._accumulate_profile(profile, "g2pw_runtime_queue_wait_ms", g2pw_profile.get("g2pw_runtime_queue_wait_ms", 0.0))
self._accumulate_profile(
profile,
"g2pw_runtime_collect_wait_ms",
g2pw_profile.get("g2pw_runtime_collect_wait_ms", 0.0),
)
self._accumulate_profile(profile, "g2pw_runtime_run_ms", g2pw_profile.get("g2pw_runtime_run_ms", 0.0))
self._update_profile_peak(
profile,
"g2pw_runtime_batch_rows_peak",
g2pw_profile.get("g2pw_runtime_batch_rows", 0.0),
)
self._update_profile_peak(
profile,
"g2pw_runtime_batch_requests_peak",
g2pw_profile.get("g2pw_runtime_batch_requests", 0.0),
)
self._update_profile_peak(
profile,
"g2pw_runtime_pool_workers",
g2pw_profile.get("g2pw_runtime_pool_workers", 0.0),
)
for index, (phones, word2ph, norm_text) in zip(zh_indices, resolved_segments):
prepared_segments[index] = PreparedTextSegment(
language=prepared_segments[index].language,
phones=list(cleaned_text_to_sequence(phones, self.version)),
word2ph=None if word2ph is None else list(word2ph),
norm_text=str(norm_text),
needs_g2pw=False,
)
return prepared_segments
def build_phones_and_bert_from_segments(
self,
prepared_segments: List[PreparedTextSegment],
profile: Dict | None = None,
) -> Tuple[list, torch.Tensor, str]:
prepared_segments = self.resolve_g2pw_segments(prepared_segments, profile=profile)
phones_list: List[List[int]] = []
bert_list: List[torch.Tensor] = []
norm_text_list: List[str] = []
for segment in prepared_segments:
bert = self.get_bert_inf(
segment.phones,
segment.word2ph,
segment.norm_text,
segment.language,
profile=profile,
)
phones_list.append(segment.phones)
norm_text_list.append(segment.norm_text)
bert_list.append(bert)
bert = torch.cat(bert_list, dim=1)
phones = sum(phones_list, [])
norm_text = "".join(norm_text_list)
return phones, bert, norm_text
def _accumulate_profile(self, profile: Dict | None, key: str, value: float) -> None:
if profile is None:
return
profile[key] = float(profile.get(key, 0.0)) + float(value)
def _update_profile_peak(self, profile: Dict | None, key: str, value: float) -> None:
if profile is None:
return
profile[key] = float(max(float(profile.get(key, 0.0)), float(value)))
def _merge_bert_worker_profile(self, profile: Dict | None, worker_profile: Dict[str, float]) -> None:
self._accumulate_profile(profile, "bert_wait_ms", worker_profile.get("bert_wait_ms", 0.0))
self._accumulate_profile(profile, "bert_admission_wait_ms", worker_profile.get("bert_admission_wait_ms", 0.0))
self._accumulate_profile(profile, "bert_queue_wait_ms", worker_profile.get("bert_queue_wait_ms", 0.0))
self._accumulate_profile(
profile,
"bert_batch_collect_wait_ms",
worker_profile.get("bert_batch_collect_wait_ms", 0.0),
)
self._accumulate_profile(profile, "bert_forward_ms", worker_profile.get("bert_forward_ms", 0.0))
self._accumulate_profile(profile, "bert_tokenize_ms", worker_profile.get("bert_tokenize_ms", 0.0))
self._accumulate_profile(profile, "bert_scatter_ms", worker_profile.get("bert_scatter_ms", 0.0))
self._accumulate_profile(profile, "bert_calls", worker_profile.get("bert_calls", 1.0))
self._update_profile_peak(profile, "bert_stage_inflight_peak", worker_profile.get("bert_stage_inflight_peak", 0.0))
self._update_profile_peak(profile, "bert_batch_size_peak", worker_profile.get("bert_batch_size", 0.0))
self._update_profile_peak(profile, "bert_batch_tokens_peak", worker_profile.get("bert_batch_tokens", 0.0))
self._update_profile_peak(
profile,
"bert_pending_depth_on_enqueue_peak",
worker_profile.get("bert_pending_depth_on_enqueue", 0.0),
)
self._update_profile_peak(
profile,
"bert_pending_depth_on_collect_peak",
worker_profile.get("bert_pending_depth_on_collect", 0.0),
)
self._update_profile_peak(profile, "bert_high_pressure_mode_peak", worker_profile.get("bert_high_pressure_mode", 0.0))
if profile is not None:
profile["bert_stage_slots"] = float(worker_profile.get("bert_stage_slots", 0.0))
profile["bert_batch_window_ms"] = float(worker_profile.get("bert_batch_window_ms", 0.0))
def get_bert_feature(self, text: str, word2ph: list, profile: Dict | None = None) -> torch.Tensor:
if self.bert_batch_worker is not None:
feature, worker_profile = self.bert_batch_worker.submit(text, word2ph)
self._merge_bert_worker_profile(profile, worker_profile)
return feature
limiter_stats = {"wait_ms": 0.0, "inflight": 1, "peak_inflight": 1, "slots": 0}
if self.bert_stage_limiter is None:
forward_start = time.perf_counter()
with torch.no_grad():
inputs = self.tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(self.device)
res = self.bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
forward_ms = (time.perf_counter() - forward_start) * 1000.0
else:
with self.bert_stage_limiter.enter() as limiter_stats:
forward_start = time.perf_counter()
with torch.no_grad():
inputs = self.tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(self.device)
res = self.bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
forward_ms = (time.perf_counter() - forward_start) * 1000.0
self._accumulate_profile(profile, "bert_wait_ms", limiter_stats["wait_ms"])
self._accumulate_profile(profile, "bert_forward_ms", forward_ms)
self._accumulate_profile(profile, "bert_calls", 1.0)
self._update_profile_peak(profile, "bert_stage_inflight_peak", limiter_stats["peak_inflight"])
if profile is not None:
profile["bert_stage_slots"] = float(limiter_stats["slots"])
assert len(word2ph) == len(text)
phone_level_feature = []
for i in range(len(word2ph)):
@ -209,10 +430,19 @@ class TextPreprocessor:
phones = cleaned_text_to_sequence(phones, version)
return phones, word2ph, norm_text
def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str):
def get_bert_inf(
self,
phones: list,
word2ph: Optional[list],
norm_text: str,
language: str,
profile: Dict | None = None,
):
language = language.replace("all_", "")
if language == "zh":
feature = self.get_bert_feature(norm_text, word2ph).to(self.device)
if word2ph is None:
raise ValueError("中文文本缺少 word2ph无法提取 BERT 特征")
feature = self.get_bert_feature(norm_text, word2ph, profile=profile).to(self.device)
else:
feature = torch.zeros(
(1024, len(phones)),
@ -221,6 +451,115 @@ class TextPreprocessor:
return feature
async def build_phones_and_bert_from_segments_async(
self,
prepared_segments: List[PreparedTextSegment],
profile: Dict | None = None,
) -> Tuple[list, torch.Tensor, str]:
prepared_segments = self.resolve_g2pw_segments(prepared_segments, profile=profile)
segment_jobs = self._build_async_segment_jobs(prepared_segments, profile)
pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = []
for segment_index, segment in enumerate(prepared_segments):
if segment.language.replace("all_", "") != "zh" or self.bert_batch_worker is None:
continue
if segment.word2ph is None:
raise ValueError("中文文本缺少 word2ph无法提取 BERT 特征")
pending_items.append(
(
segment_jobs["bert_list"],
segment_index,
profile,
self.bert_batch_worker.submit_async(segment.norm_text, segment.word2ph),
)
)
if pending_items:
pending_results = await asyncio.gather(*[future for _, _, _, future in pending_items])
for (bert_list, bert_index, item_profile, _), (feature, worker_profile) in zip(pending_items, pending_results):
self._merge_bert_worker_profile(item_profile, worker_profile)
bert_list[bert_index] = feature.to(self.device)
return self._finalize_async_segment_jobs(segment_jobs)
def _build_async_segment_jobs(
self,
prepared_segments: List[PreparedTextSegment],
profile: Dict | None,
) -> Dict[str, List]:
phones_list: List[List[int]] = []
bert_list: List[torch.Tensor | None] = []
norm_text_list: List[str] = []
for segment in prepared_segments:
phones_list.append(segment.phones)
norm_text_list.append(segment.norm_text)
segment_language = segment.language.replace("all_", "")
if segment_language == "zh" and self.bert_batch_worker is not None:
if segment.word2ph is None:
raise ValueError("中文文本缺少 word2ph无法提取 BERT 特征")
bert_list.append(None)
continue
bert_list.append(
self.get_bert_inf(
segment.phones,
segment.word2ph,
segment.norm_text,
segment.language,
profile=profile,
)
)
return {
"phones_list": phones_list,
"bert_list": bert_list,
"norm_text_list": norm_text_list,
}
@staticmethod
def _finalize_async_segment_jobs(segment_jobs: Dict[str, List]) -> Tuple[list, torch.Tensor, str]:
bert = torch.cat([feature for feature in segment_jobs["bert_list"] if feature is not None], dim=1)
phones = sum(segment_jobs["phones_list"], [])
norm_text = "".join(segment_jobs["norm_text_list"])
return phones, bert, norm_text
async def build_phones_and_bert_pair_from_segments_async(
self,
prompt_segments: List[PreparedTextSegment],
target_segments: List[PreparedTextSegment],
prompt_profile: Dict | None = None,
target_profile: Dict | None = None,
) -> Tuple[Tuple[list, torch.Tensor, str], Tuple[list, torch.Tensor, str]]:
prompt_segments = self.resolve_g2pw_segments(prompt_segments, profile=prompt_profile)
target_segments = self.resolve_g2pw_segments(target_segments, profile=target_profile)
prompt_jobs = self._build_async_segment_jobs(prompt_segments, prompt_profile)
target_jobs = self._build_async_segment_jobs(target_segments, target_profile)
pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = []
for segment_jobs, prepared_segments, profile in (
(prompt_jobs, prompt_segments, prompt_profile),
(target_jobs, target_segments, target_profile),
):
for segment_index, segment in enumerate(prepared_segments):
if segment.language.replace("all_", "") != "zh" or self.bert_batch_worker is None:
continue
if segment.word2ph is None:
raise ValueError("中文文本缺少 word2ph无法提取 BERT 特征")
pending_items.append(
(
segment_jobs["bert_list"],
segment_index,
profile,
self.bert_batch_worker.submit_async(segment.norm_text, segment.word2ph),
)
)
if pending_items:
pending_results = await asyncio.gather(*[future for _, _, _, future in pending_items])
for (bert_list, bert_index, profile, _), (feature, worker_profile) in zip(pending_items, pending_results):
self._merge_bert_worker_profile(profile, worker_profile)
bert_list[bert_index] = feature.to(self.device)
return self._finalize_async_segment_jobs(prompt_jobs), self._finalize_async_segment_jobs(target_jobs)
def filter_text(self, texts):
_text = []
if all(text in [None, " ", "\n", ""] for text in texts):
@ -236,4 +575,4 @@ class TextPreprocessor:
punctuations = "".join(re.escape(p) for p in punctuation)
pattern = f"([{punctuations}])([{punctuations}])+"
result = re.sub(pattern, r"\1", text)
return result
return result

View File

@ -1 +1,11 @@
from . import TTS, text_segmentation_method
from __future__ import annotations
import importlib
__all__ = ["TTS", "TextPreprocessor", "text_segmentation_method", "t2s_scheduler"]
def __getattr__(name: str):
if name in __all__:
return importlib.import_module(f"{__name__}.{name}")
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@ -0,0 +1,346 @@
import asyncio
import threading
import time
import uuid
from collections import deque
from dataclasses import dataclass, field
from typing import Deque, Dict, List, Tuple
import torch
@dataclass
class BertFeatureTask:
norm_text: str
word2ph: List[int]
task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
created_at: float = field(default_factory=time.perf_counter)
enqueued_at: float = 0.0
admission_wait_ms: float = 0.0
pending_depth_on_enqueue: int = 0
done_event: threading.Event = field(default_factory=threading.Event)
done_loop: asyncio.AbstractEventLoop | None = None
done_future: asyncio.Future | None = None
result_feature: torch.Tensor | None = None
error: Exception | None = None
profile: Dict[str, float] = field(default_factory=dict)
class PrepareBertBatchWorker:
def __init__(
self,
bert_model,
tokenizer,
device,
stage_limiter=None,
batch_window_ms: int = 5,
max_batch_items: int = 16,
max_batch_tokens: int = 4096,
max_pending_tasks: int = 0,
admission_poll_ms: int = 1,
high_pressure_pending_threshold: int = 0,
high_pressure_batch_window_ms: int | None = None,
high_pressure_max_batch_items: int | None = None,
high_pressure_max_batch_tokens: int | None = None,
):
self.bert_model = bert_model
self.tokenizer = tokenizer
self.device = device
self.stage_limiter = stage_limiter
self.batch_window_ms = max(0, int(batch_window_ms))
self.batch_window_s = float(self.batch_window_ms) / 1000.0
self.max_batch_items = max(1, int(max_batch_items))
self.max_batch_tokens = max(16, int(max_batch_tokens))
self.max_pending_tasks = max(0, int(max_pending_tasks))
self.admission_poll_s = max(0.0005, float(max(1, int(admission_poll_ms))) / 1000.0)
self.high_pressure_pending_threshold = max(
0,
int(high_pressure_pending_threshold)
if int(high_pressure_pending_threshold) > 0
else max(self.max_batch_items * 2, 32),
)
hp_window_ms = self.batch_window_ms if high_pressure_batch_window_ms is None else int(high_pressure_batch_window_ms)
hp_items = self.max_batch_items if high_pressure_max_batch_items is None else int(high_pressure_max_batch_items)
hp_tokens = self.max_batch_tokens if high_pressure_max_batch_tokens is None else int(high_pressure_max_batch_tokens)
self.high_pressure_batch_window_ms = max(0, hp_window_ms)
self.high_pressure_batch_window_s = float(self.high_pressure_batch_window_ms) / 1000.0
self.high_pressure_max_batch_items = max(self.max_batch_items, hp_items)
self.high_pressure_max_batch_tokens = max(self.max_batch_tokens, hp_tokens)
self.condition = threading.Condition()
self.pending_tasks: Deque[BertFeatureTask] = deque()
self.pending_peak = 0
self.total_submitted = 0
self.total_finished = 0
self.total_batches = 0
self.active_batch_size = 0
self.active_batch_peak = 0
self.active_batch_tokens = 0
self.active_batch_tokens_peak = 0
self.high_pressure_batches = 0
self.admission_wait_total_ms = 0.0
self.admission_wait_peak_ms = 0.0
self.worker_thread = threading.Thread(target=self._run_loop, name="prepare-bert-batch-worker", daemon=True)
self.worker_thread.start()
def _estimate_task_tokens(self, task: BertFeatureTask) -> int:
return max(1, len(task.norm_text) + 2)
def _can_enqueue_locked(self) -> bool:
if self.max_pending_tasks <= 0:
return True
return (len(self.pending_tasks) + self.active_batch_size) < self.max_pending_tasks
def _record_enqueue_locked(self, task: BertFeatureTask, admission_wait_ms: float) -> None:
task.admission_wait_ms = float(max(0.0, admission_wait_ms))
task.enqueued_at = time.perf_counter()
task.pending_depth_on_enqueue = int(len(self.pending_tasks))
self.pending_tasks.append(task)
self.total_submitted += 1
self.admission_wait_total_ms += task.admission_wait_ms
self.admission_wait_peak_ms = max(self.admission_wait_peak_ms, task.admission_wait_ms)
if len(self.pending_tasks) > self.pending_peak:
self.pending_peak = len(self.pending_tasks)
self.condition.notify_all()
def _enqueue_task(self, task: BertFeatureTask) -> None:
admission_started = time.perf_counter()
with self.condition:
while not self._can_enqueue_locked():
self.condition.wait(timeout=self.admission_poll_s)
self._record_enqueue_locked(task, (time.perf_counter() - admission_started) * 1000.0)
async def _enqueue_task_async(self, task: BertFeatureTask) -> None:
admission_started = time.perf_counter()
while True:
with self.condition:
if self._can_enqueue_locked():
self._record_enqueue_locked(task, (time.perf_counter() - admission_started) * 1000.0)
return
await asyncio.sleep(self.admission_poll_s)
def submit(self, norm_text: str, word2ph: List[int]) -> Tuple[torch.Tensor, Dict[str, float]]:
task = BertFeatureTask(norm_text=str(norm_text), word2ph=list(word2ph))
self._enqueue_task(task)
task.done_event.wait()
if task.error is not None:
raise task.error
assert task.result_feature is not None
return task.result_feature, dict(task.profile)
async def submit_async(self, norm_text: str, word2ph: List[int]) -> Tuple[torch.Tensor, Dict[str, float]]:
loop = asyncio.get_running_loop()
task = BertFeatureTask(
norm_text=str(norm_text),
word2ph=list(word2ph),
done_loop=loop,
done_future=loop.create_future(),
)
await self._enqueue_task_async(task)
return await task.done_future
def snapshot(self) -> Dict[str, int]:
with self.condition:
return {
"pending": len(self.pending_tasks),
"pending_peak": self.pending_peak,
"total_submitted": self.total_submitted,
"total_finished": self.total_finished,
"total_batches": self.total_batches,
"active_batch_size": self.active_batch_size,
"active_batch_peak": self.active_batch_peak,
"active_batch_tokens": self.active_batch_tokens,
"active_batch_tokens_peak": self.active_batch_tokens_peak,
"batch_window_ms": int(self.batch_window_s * 1000.0),
"max_batch_items": self.max_batch_items,
"max_batch_tokens": self.max_batch_tokens,
"max_pending_tasks": self.max_pending_tasks,
"high_pressure_pending_threshold": self.high_pressure_pending_threshold,
"high_pressure_batch_window_ms": self.high_pressure_batch_window_ms,
"high_pressure_max_batch_items": self.high_pressure_max_batch_items,
"high_pressure_max_batch_tokens": self.high_pressure_max_batch_tokens,
"high_pressure_batches": self.high_pressure_batches,
"admission_wait_total_ms": self.admission_wait_total_ms,
"admission_wait_peak_ms": self.admission_wait_peak_ms,
}
def _select_batch_policy_locked(self) -> Tuple[float, int, int, bool, int]:
pending_depth = len(self.pending_tasks)
use_high_pressure = (
self.high_pressure_pending_threshold > 0
and pending_depth >= self.high_pressure_pending_threshold
)
if use_high_pressure:
return (
self.high_pressure_batch_window_s,
self.high_pressure_max_batch_items,
self.high_pressure_max_batch_tokens,
True,
pending_depth,
)
return (
self.batch_window_s,
self.max_batch_items,
self.max_batch_tokens,
False,
pending_depth,
)
def _collect_batch(self) -> Tuple[List[BertFeatureTask], Dict[str, float]]:
with self.condition:
while not self.pending_tasks:
self.condition.wait()
collect_started = time.perf_counter()
batch_window_s, max_batch_items, max_batch_tokens, use_high_pressure, pending_depth_on_collect = (
self._select_batch_policy_locked()
)
batch: List[BertFeatureTask] = [self.pending_tasks.popleft()]
batch_tokens = self._estimate_task_tokens(batch[0])
deadline = time.perf_counter() + batch_window_s
while len(batch) < max_batch_items:
remaining = deadline - time.perf_counter()
if remaining <= 0:
break
if not self.pending_tasks:
self.condition.wait(timeout=remaining)
continue
next_task = self.pending_tasks[0]
next_tokens = self._estimate_task_tokens(next_task)
if len(batch) >= max_batch_items or (batch_tokens + next_tokens) > max_batch_tokens:
break
batch.append(self.pending_tasks.popleft())
batch_tokens += next_tokens
self.active_batch_size = len(batch)
self.active_batch_tokens = batch_tokens
if self.active_batch_size > self.active_batch_peak:
self.active_batch_peak = self.active_batch_size
if self.active_batch_tokens > self.active_batch_tokens_peak:
self.active_batch_tokens_peak = self.active_batch_tokens
if use_high_pressure:
self.high_pressure_batches += 1
return batch, {
"collect_wait_ms": (time.perf_counter() - collect_started) * 1000.0,
"batch_tokens": float(batch_tokens),
"pending_depth_on_collect": float(pending_depth_on_collect),
"high_pressure_mode": 1.0 if use_high_pressure else 0.0,
"batch_window_ms": float(self.high_pressure_batch_window_ms if use_high_pressure else self.batch_window_ms),
}
def _finalize_batch(self, batch: List[BertFeatureTask]) -> None:
with self.condition:
self.active_batch_size = 0
self.active_batch_tokens = 0
self.total_batches += 1
self.total_finished += len(batch)
self.condition.notify_all()
def _run_batch(self, batch: List[BertFeatureTask], batch_meta: Dict[str, float]) -> None:
batch_started = time.perf_counter()
texts = [task.norm_text for task in batch]
batch_tokens = int(batch_meta["batch_tokens"])
limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0}
if self.stage_limiter is None:
tokenize_start = time.perf_counter()
inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
tokenize_ms = (time.perf_counter() - tokenize_start) * 1000.0
attention_mask_cpu = inputs["attention_mask"].cpu()
for key in inputs:
inputs[key] = inputs[key].to(self.device)
forward_start = time.perf_counter()
with torch.no_grad():
outputs = self.bert_model(**inputs, output_hidden_states=True)
forward_ms = (time.perf_counter() - forward_start) * 1000.0
else:
with self.stage_limiter.enter() as limiter_stats:
tokenize_start = time.perf_counter()
inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
tokenize_ms = (time.perf_counter() - tokenize_start) * 1000.0
attention_mask_cpu = inputs["attention_mask"].cpu()
for key in inputs:
inputs[key] = inputs[key].to(self.device)
forward_start = time.perf_counter()
with torch.no_grad():
outputs = self.bert_model(**inputs, output_hidden_states=True)
forward_ms = (time.perf_counter() - forward_start) * 1000.0
hidden = outputs["hidden_states"][-3].detach().cpu()
scatter_start = time.perf_counter()
for batch_index, task in enumerate(batch):
try:
text_len = len(task.word2ph)
if text_len != len(task.norm_text):
raise AssertionError(
f"word2ph/text length mismatch: task={task.task_id} word2ph={text_len} text={len(task.norm_text)}"
)
seq_len = int(attention_mask_cpu[batch_index].sum().item())
char_features = hidden[batch_index, 1 : seq_len - 1]
if char_features.shape[0] != text_len:
raise AssertionError(
f"bert token length mismatch: task={task.task_id} token_len={char_features.shape[0]} text_len={text_len}"
)
phone_level_feature = []
for char_index, repeat_count in enumerate(task.word2ph):
phone_level_feature.append(char_features[char_index].repeat(repeat_count, 1))
task.result_feature = torch.cat(phone_level_feature, dim=0).T
task.profile = {
"bert_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]),
"bert_admission_wait_ms": float(task.admission_wait_ms),
"bert_queue_wait_ms": max(0.0, (batch_started - task.enqueued_at) * 1000.0),
"bert_batch_collect_wait_ms": float(batch_meta["collect_wait_ms"]),
"bert_forward_ms": float(forward_ms),
"bert_tokenize_ms": float(tokenize_ms),
"bert_scatter_ms": 0.0,
"bert_calls": 1.0,
"bert_stage_slots": float(limiter_stats["slots"]),
"bert_stage_inflight_peak": float(limiter_stats["peak_inflight"]),
"bert_batch_size": float(len(batch)),
"bert_batch_tokens": float(batch_tokens),
"bert_pending_depth_on_enqueue": float(task.pending_depth_on_enqueue),
"bert_pending_depth_on_collect": float(batch_meta["pending_depth_on_collect"]),
"bert_high_pressure_mode": float(batch_meta["high_pressure_mode"]),
"bert_batch_window_ms": float(batch_meta["batch_window_ms"]),
}
except Exception as exc: # noqa: PERF203
task.error = exc
scatter_ms = (time.perf_counter() - scatter_start) * 1000.0
for task in batch:
if task.result_feature is not None:
task.profile["bert_scatter_ms"] = float(scatter_ms)
task.done_event.set()
self._notify_done_future(task)
@staticmethod
def _resolve_done_future(task: BertFeatureTask) -> None:
if task.done_future is None or task.done_future.done():
return
if task.error is not None:
task.done_future.set_exception(task.error)
return
assert task.result_feature is not None
task.done_future.set_result((task.result_feature, dict(task.profile)))
def _notify_done_future(self, task: BertFeatureTask) -> None:
if task.done_loop is None or task.done_future is None:
return
try:
task.done_loop.call_soon_threadsafe(self._resolve_done_future, task)
except RuntimeError:
pass
def _run_loop(self) -> None:
while True:
batch, batch_meta = self._collect_batch()
try:
self._run_batch(batch, batch_meta)
except Exception as exc: # noqa: PERF203
for task in batch:
task.error = exc
task.done_event.set()
self._notify_done_future(task)
finally:
self._finalize_batch(batch)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,382 @@
import asyncio
import os
import threading
import time
import uuid
from collections import deque
from dataclasses import dataclass, field
from typing import Deque, Dict, List, Tuple
import torch
import torchaudio
REF_AUDIO_MIN_SAMPLES_16K = 48000
REF_AUDIO_MAX_SAMPLES_16K = 160000
_RESAMPLE_CACHE_LOCK = threading.Lock()
_RESAMPLE_CACHE: Dict[Tuple[int, int, str], torchaudio.transforms.Resample] = {}
_RESAMPLE_STREAM_CACHE: Dict[str, torch.cuda.Stream] = {}
def _get_resampler(orig_sr: int, target_sr: int, device: str) -> torchaudio.transforms.Resample:
device_key = str(device)
key = (int(orig_sr), int(target_sr), device_key)
with _RESAMPLE_CACHE_LOCK:
transform = _RESAMPLE_CACHE.get(key)
if transform is None:
transform = torchaudio.transforms.Resample(orig_freq=int(orig_sr), new_freq=int(target_sr)).to(device_key)
_RESAMPLE_CACHE[key] = transform
return transform
def _get_resample_stream(device: str) -> torch.cuda.Stream:
device_key = str(device)
with _RESAMPLE_CACHE_LOCK:
stream = _RESAMPLE_STREAM_CACHE.get(device_key)
if stream is None:
stream = torch.cuda.Stream(device=device_key)
_RESAMPLE_STREAM_CACHE[device_key] = stream
return stream
def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wav_samples: int) -> torch.Tensor:
resample_device = os.environ.get("GPTSOVITS_PREPARE_REF_RESAMPLE_DEVICE", "cpu").strip().lower() or "cpu"
if resample_device not in {"cpu", "cuda"}:
resample_device = "cpu"
if resample_device == "cuda" and not torch.cuda.is_available():
resample_device = "cpu"
wav_mono = raw_audio
if wav_mono.dim() == 2 and wav_mono.shape[0] != 1:
wav_mono = wav_mono.mean(0, keepdim=True)
if resample_device == "cuda":
stream = _get_resample_stream(resample_device)
with torch.cuda.stream(stream):
wav16k = wav_mono.to(dtype=torch.float32, device=resample_device)
if raw_sr != 16000:
wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k)
wav16k = wav16k.squeeze(0).contiguous()
stream.synchronize()
wav16k = wav16k.detach().to(device="cpu", dtype=torch.float32).contiguous()
else:
wav16k = wav_mono.to(dtype=torch.float32, device=resample_device)
if raw_sr != 16000:
wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k)
wav16k = wav16k.squeeze(0).contiguous()
if wav16k.shape[0] > REF_AUDIO_MAX_SAMPLES_16K or wav16k.shape[0] < REF_AUDIO_MIN_SAMPLES_16K:
raise OSError("参考音频在3~10秒范围外请更换")
if zero_wav_samples > 0:
wav16k = torch.cat(
[wav16k, torch.zeros(int(zero_wav_samples), dtype=torch.float32, device=wav16k.device)],
dim=0,
)
return wav16k.contiguous()
def conv1d_output_lengths(input_lengths: torch.Tensor, conv1d: torch.nn.Conv1d | None) -> torch.Tensor:
if conv1d is None:
return input_lengths.to(dtype=torch.long)
kernel_size = int(conv1d.kernel_size[0])
stride = int(conv1d.stride[0])
padding = int(conv1d.padding[0])
dilation = int(conv1d.dilation[0])
output_lengths = torch.div(
input_lengths + 2 * padding - dilation * (kernel_size - 1) - 1,
stride,
rounding_mode="floor",
) + 1
return torch.clamp(output_lengths, min=0).to(dtype=torch.long)
@dataclass
class RefSemanticTask:
raw_audio: torch.Tensor
raw_sr: int
task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
created_at: float = field(default_factory=time.perf_counter)
batch_popped_at: float = 0.0
done_event: threading.Event = field(default_factory=threading.Event)
done_loop: asyncio.AbstractEventLoop | None = None
done_future: asyncio.Future | None = None
result_prompt_semantic: torch.Tensor | None = None
error: Exception | None = None
profile: Dict[str, float] = field(default_factory=dict)
class PrepareRefSemanticBatchWorker:
def __init__(
self,
ssl_model,
vits_model,
device,
is_half: bool,
zero_wav_samples: int,
stage_limiter=None,
batch_window_ms: int = 5,
max_batch_items: int = 8,
max_batch_samples: int = 960000,
):
self.ssl_model = ssl_model
self.vits_model = vits_model
self.device = device
self.is_half = bool(is_half)
self.zero_wav_samples = max(0, int(zero_wav_samples))
self.stage_limiter = stage_limiter
self.batch_window_s = max(0.0, float(batch_window_ms) / 1000.0)
self.max_batch_items = max(1, int(max_batch_items))
self.max_batch_samples = max(REF_AUDIO_MIN_SAMPLES_16K + self.zero_wav_samples, int(max_batch_samples))
self.condition = threading.Condition()
self.pending_tasks: Deque[RefSemanticTask] = deque()
self.pending_peak = 0
self.total_submitted = 0
self.total_finished = 0
self.total_batches = 0
self.active_batch_size = 0
self.active_batch_peak = 0
self.active_batch_samples = 0
self.active_batch_samples_peak = 0
self.worker_thread = threading.Thread(
target=self._run_loop,
name="prepare-ref-semantic-batch-worker",
daemon=True,
)
self.worker_thread.start()
def _estimate_task_samples(self, task: RefSemanticTask) -> int:
raw_len = int(task.raw_audio.shape[-1]) if task.raw_audio.dim() > 0 else 0
base = int(round(raw_len * 16000.0 / max(1, int(task.raw_sr))))
return max(REF_AUDIO_MIN_SAMPLES_16K, base) + self.zero_wav_samples
def submit(self, raw_audio: torch.Tensor, raw_sr: int) -> Tuple[torch.Tensor, Dict[str, float]]:
task = RefSemanticTask(raw_audio=raw_audio, raw_sr=int(raw_sr))
with self.condition:
self.pending_tasks.append(task)
self.total_submitted += 1
if len(self.pending_tasks) > self.pending_peak:
self.pending_peak = len(self.pending_tasks)
self.condition.notify_all()
task.done_event.wait()
if task.error is not None:
raise task.error
assert task.result_prompt_semantic is not None
return task.result_prompt_semantic, dict(task.profile)
async def submit_async(self, raw_audio: torch.Tensor, raw_sr: int) -> Tuple[torch.Tensor, Dict[str, float]]:
loop = asyncio.get_running_loop()
task = RefSemanticTask(
raw_audio=raw_audio,
raw_sr=int(raw_sr),
done_loop=loop,
done_future=loop.create_future(),
)
with self.condition:
self.pending_tasks.append(task)
self.total_submitted += 1
if len(self.pending_tasks) > self.pending_peak:
self.pending_peak = len(self.pending_tasks)
self.condition.notify_all()
return await task.done_future
@staticmethod
def _resolve_done_future(task: RefSemanticTask) -> None:
if task.done_future is None or task.done_future.done():
return
if task.error is not None:
task.done_future.set_exception(task.error)
return
assert task.result_prompt_semantic is not None
task.done_future.set_result((task.result_prompt_semantic, dict(task.profile)))
def _notify_task_done(self, task: RefSemanticTask) -> None:
task.done_event.set()
if task.done_loop is None or task.done_future is None:
return
try:
task.done_loop.call_soon_threadsafe(self._resolve_done_future, task)
except RuntimeError:
pass
def snapshot(self) -> Dict[str, int]:
with self.condition:
return {
"pending": len(self.pending_tasks),
"pending_peak": self.pending_peak,
"total_submitted": self.total_submitted,
"total_finished": self.total_finished,
"total_batches": self.total_batches,
"active_batch_size": self.active_batch_size,
"active_batch_peak": self.active_batch_peak,
"active_batch_samples": self.active_batch_samples,
"active_batch_samples_peak": self.active_batch_samples_peak,
"batch_window_ms": int(self.batch_window_s * 1000.0),
"max_batch_items": self.max_batch_items,
"max_batch_samples": self.max_batch_samples,
}
def _collect_batch(self) -> tuple[List[RefSemanticTask], float]:
with self.condition:
while not self.pending_tasks:
self.condition.wait()
first_task = self.pending_tasks.popleft()
first_task.batch_popped_at = time.perf_counter()
batch: List[RefSemanticTask] = [first_task]
batch_samples = self._estimate_task_samples(batch[0])
deadline = time.perf_counter() + self.batch_window_s
while len(batch) < self.max_batch_items:
remaining = deadline - time.perf_counter()
if remaining <= 0:
break
if not self.pending_tasks:
self.condition.wait(timeout=remaining)
continue
next_task = self.pending_tasks[0]
next_samples = self._estimate_task_samples(next_task)
if len(batch) >= self.max_batch_items or (batch_samples + next_samples) > self.max_batch_samples:
break
popped_task = self.pending_tasks.popleft()
popped_task.batch_popped_at = time.perf_counter()
batch.append(popped_task)
batch_samples += next_samples
self.active_batch_size = len(batch)
self.active_batch_samples = batch_samples
if self.active_batch_size > self.active_batch_peak:
self.active_batch_peak = self.active_batch_size
if self.active_batch_samples > self.active_batch_samples_peak:
self.active_batch_samples_peak = self.active_batch_samples
return batch, time.perf_counter()
def _finalize_batch(self, batch: List[RefSemanticTask]) -> None:
with self.condition:
self.active_batch_size = 0
self.active_batch_samples = 0
self.total_batches += 1
self.total_finished += len(batch)
def _get_hidden_lengths(self, attention_mask: torch.Tensor, hidden_length: int) -> torch.Tensor:
model = self.ssl_model.model
if hasattr(model, "_get_feature_vector_attention_mask"):
feature_mask = model._get_feature_vector_attention_mask(hidden_length, attention_mask)
return feature_mask.to(dtype=torch.long).sum(dim=1)
raw_lengths = attention_mask.to(dtype=torch.long).sum(dim=1)
if hasattr(model, "_get_feat_extract_output_lengths"):
return model._get_feat_extract_output_lengths(raw_lengths).to(dtype=torch.long)
return torch.full((attention_mask.shape[0],), int(hidden_length), dtype=torch.long, device=attention_mask.device)
@torch.inference_mode()
def _run_batch(self, batch: List[RefSemanticTask], batch_collected_at: float) -> None:
batch_started = time.perf_counter()
prepared_start = time.perf_counter()
prepared_wavs = [
prepare_prompt_semantic_wav16k(task.raw_audio, int(task.raw_sr), self.zero_wav_samples) for task in batch
]
cpu_prepare_ms = (time.perf_counter() - prepared_start) * 1000.0
wav_lengths = torch.tensor([int(wav.shape[0]) for wav in prepared_wavs], dtype=torch.long)
batch_samples = int(wav_lengths.sum().item())
max_wav_len = int(wav_lengths.max().item())
pack_start = time.perf_counter()
input_values_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.float32)
attention_mask_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.long)
for batch_index, wav in enumerate(prepared_wavs):
wav_len = int(wav.shape[0])
input_values_cpu[batch_index, :wav_len] = wav
attention_mask_cpu[batch_index, :wav_len] = 1
pack_ms = (time.perf_counter() - pack_start) * 1000.0
limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0}
h2d_ms = 0.0
ssl_forward_ms = 0.0
hidden_length_ms = 0.0
extract_latent_ms = 0.0
if self.stage_limiter is None:
h2d_start = time.perf_counter()
input_values = input_values_cpu.to(self.device)
attention_mask = attention_mask_cpu.to(self.device)
if self.is_half:
input_values = input_values.half()
h2d_ms = (time.perf_counter() - h2d_start) * 1000.0
ssl_start = time.perf_counter()
outputs = self.ssl_model.model(input_values, attention_mask=attention_mask)
ssl_forward_ms = (time.perf_counter() - ssl_start) * 1000.0
hubert_feature = outputs["last_hidden_state"].transpose(1, 2)
hidden_length_start = time.perf_counter()
hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1]))
hidden_length_ms = (time.perf_counter() - hidden_length_start) * 1000.0
latent_start = time.perf_counter()
codes = self.vits_model.extract_latent(hubert_feature)
extract_latent_ms = (time.perf_counter() - latent_start) * 1000.0
else:
with self.stage_limiter.enter() as limiter_stats:
h2d_start = time.perf_counter()
input_values = input_values_cpu.to(self.device)
attention_mask = attention_mask_cpu.to(self.device)
if self.is_half:
input_values = input_values.half()
h2d_ms = (time.perf_counter() - h2d_start) * 1000.0
ssl_start = time.perf_counter()
outputs = self.ssl_model.model(input_values, attention_mask=attention_mask)
ssl_forward_ms = (time.perf_counter() - ssl_start) * 1000.0
hubert_feature = outputs["last_hidden_state"].transpose(1, 2)
hidden_length_start = time.perf_counter()
hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1]))
hidden_length_ms = (time.perf_counter() - hidden_length_start) * 1000.0
latent_start = time.perf_counter()
codes = self.vits_model.extract_latent(hubert_feature)
extract_latent_ms = (time.perf_counter() - latent_start) * 1000.0
forward_ms = float(h2d_ms + ssl_forward_ms + hidden_length_ms + extract_latent_ms)
code_lengths = conv1d_output_lengths(hidden_lengths.detach().cpu(), getattr(self.vits_model, "ssl_proj", None))
scatter_start = time.perf_counter()
for batch_index, task in enumerate(batch):
try:
code_len = int(code_lengths[batch_index].item())
task.result_prompt_semantic = codes[batch_index, 0, :code_len].detach().clone()
worker_queue_wait_ms = max(0.0, (float(task.batch_popped_at) - float(task.created_at)) * 1000.0)
batch_collect_wait_ms = max(0.0, (float(batch_collected_at) - float(task.batch_popped_at)) * 1000.0)
stage_limiter_wait_ms = float(limiter_stats["wait_ms"])
task.profile = {
"prompt_semantic_wait_ms": worker_queue_wait_ms
+ batch_collect_wait_ms
+ stage_limiter_wait_ms,
"prompt_semantic_worker_queue_wait_ms": worker_queue_wait_ms,
"prompt_semantic_batch_collect_wait_ms": batch_collect_wait_ms,
"prompt_semantic_stage_limiter_wait_ms": stage_limiter_wait_ms,
"prompt_semantic_batch_dispatch_delay_ms": max(
0.0, (float(batch_started) - float(batch_collected_at)) * 1000.0
),
"prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms),
"prompt_semantic_pack_ms": float(pack_ms),
"prompt_semantic_h2d_ms": float(h2d_ms),
"prompt_semantic_ssl_forward_ms": float(ssl_forward_ms),
"prompt_semantic_hidden_length_ms": float(hidden_length_ms),
"prompt_semantic_extract_latent_ms": float(extract_latent_ms),
"prompt_semantic_forward_ms": float(forward_ms),
"prompt_semantic_scatter_ms": 0.0,
"prompt_semantic_calls": 1.0,
"prompt_semantic_stage_slots": float(limiter_stats["slots"]),
"prompt_semantic_stage_inflight_peak": float(limiter_stats["peak_inflight"]),
"prompt_semantic_batch_size": float(len(batch)),
"prompt_semantic_batch_samples": float(batch_samples),
}
except Exception as exc: # noqa: PERF203
task.error = exc
scatter_ms = (time.perf_counter() - scatter_start) * 1000.0
for task in batch:
if task.result_prompt_semantic is not None:
task.profile["prompt_semantic_scatter_ms"] = float(scatter_ms)
self._notify_task_done(task)
def _run_loop(self) -> None:
while True:
batch, batch_collected_at = self._collect_batch()
try:
self._run_batch(batch, batch_collected_at)
except Exception as exc: # noqa: PERF203
for task in batch:
task.error = exc
self._notify_task_done(task)
finally:
self._finalize_batch(batch)

View File

@ -0,0 +1,215 @@
import asyncio
import threading
import time
import uuid
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Callable, Deque, Dict, Tuple
@dataclass
class TextCpuTask:
text: str
language: str
task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
created_at: float = field(default_factory=time.perf_counter)
enqueued_at: float = 0.0
admission_wait_ms: float = 0.0
backpressure_wait_ms: float = 0.0
capacity_wait_ms: float = 0.0
pending_depth_on_enqueue: int = 0
done_event: threading.Event = field(default_factory=threading.Event)
done_loop: asyncio.AbstractEventLoop | None = None
done_future: asyncio.Future | None = None
result: Any = None
error: Exception | None = None
profile: Dict[str, float] = field(default_factory=dict)
class PrepareTextCpuWorker:
def __init__(
self,
process_fn: Callable[[str, str], Any],
worker_count: int,
max_pending_tasks: int = 0,
admission_poll_ms: int = 1,
admission_controller: Callable[[], Dict[str, float | int | bool]] | None = None,
) -> None:
self.process_fn = process_fn
self.worker_count = max(1, int(worker_count))
self.max_pending_tasks = max(0, int(max_pending_tasks))
self.admission_poll_s = max(0.0005, float(max(1, int(admission_poll_ms))) / 1000.0)
self.admission_controller = admission_controller
self.condition = threading.Condition()
self.pending_tasks: Deque[TextCpuTask] = deque()
self.pending_peak = 0
self.total_submitted = 0
self.total_finished = 0
self.active_workers = 0
self.active_workers_peak = 0
self.admission_wait_total_ms = 0.0
self.admission_wait_peak_ms = 0.0
self.backpressure_wait_total_ms = 0.0
self.backpressure_wait_peak_ms = 0.0
self.capacity_wait_total_ms = 0.0
self.capacity_wait_peak_ms = 0.0
self.backpressure_blocked_total = 0
self.worker_threads = [
threading.Thread(target=self._run_loop, name=f"prepare-text-cpu-worker-{index}", daemon=True)
for index in range(self.worker_count)
]
for thread in self.worker_threads:
thread.start()
def _can_enqueue_locked(self) -> bool:
if self.max_pending_tasks <= 0:
return True
return (len(self.pending_tasks) + self.active_workers) < self.max_pending_tasks
def _get_admission_state(self) -> Dict[str, float | int | bool]:
if self.admission_controller is None:
return {"blocked": False}
try:
state = dict(self.admission_controller() or {})
except Exception:
return {"blocked": False}
state["blocked"] = bool(state.get("blocked", False))
return state
def _record_enqueue_locked(
self,
task: TextCpuTask,
*,
admission_wait_ms: float,
backpressure_wait_ms: float,
capacity_wait_ms: float,
) -> None:
task.admission_wait_ms = float(max(0.0, admission_wait_ms))
task.backpressure_wait_ms = float(max(0.0, backpressure_wait_ms))
task.capacity_wait_ms = float(max(0.0, capacity_wait_ms))
task.enqueued_at = time.perf_counter()
task.pending_depth_on_enqueue = int(len(self.pending_tasks))
self.pending_tasks.append(task)
self.total_submitted += 1
self.admission_wait_total_ms += task.admission_wait_ms
self.admission_wait_peak_ms = max(self.admission_wait_peak_ms, task.admission_wait_ms)
self.backpressure_wait_total_ms += task.backpressure_wait_ms
self.backpressure_wait_peak_ms = max(self.backpressure_wait_peak_ms, task.backpressure_wait_ms)
self.capacity_wait_total_ms += task.capacity_wait_ms
self.capacity_wait_peak_ms = max(self.capacity_wait_peak_ms, task.capacity_wait_ms)
if task.backpressure_wait_ms > 0.0:
self.backpressure_blocked_total += 1
if len(self.pending_tasks) > self.pending_peak:
self.pending_peak = len(self.pending_tasks)
self.condition.notify_all()
async def _enqueue_task_async(self, task: TextCpuTask) -> None:
admission_started = time.perf_counter()
backpressure_wait_ms = 0.0
capacity_wait_ms = 0.0
while True:
loop_start = time.perf_counter()
admission_state = self._get_admission_state()
blocked = bool(admission_state.get("blocked", False))
with self.condition:
if not blocked and self._can_enqueue_locked():
self._record_enqueue_locked(
task,
admission_wait_ms=(time.perf_counter() - admission_started) * 1000.0,
backpressure_wait_ms=backpressure_wait_ms,
capacity_wait_ms=capacity_wait_ms,
)
return
await asyncio.sleep(self.admission_poll_s)
waited_ms = (time.perf_counter() - loop_start) * 1000.0
if blocked:
backpressure_wait_ms += waited_ms
else:
capacity_wait_ms += waited_ms
def submit(self, text: str, language: str) -> Tuple[Any, Dict[str, float]]:
task = TextCpuTask(text=str(text), language=str(language))
asyncio.run(self._enqueue_task_async(task))
task.done_event.wait()
if task.error is not None:
raise task.error
return task.result, dict(task.profile)
async def submit_async(self, text: str, language: str) -> Tuple[Any, Dict[str, float]]:
loop = asyncio.get_running_loop()
task = TextCpuTask(
text=str(text),
language=str(language),
done_loop=loop,
done_future=loop.create_future(),
)
await self._enqueue_task_async(task)
return await task.done_future
@staticmethod
def _resolve_done_future(task: TextCpuTask) -> None:
if task.done_future is None or task.done_future.done():
return
if task.error is not None:
task.done_future.set_exception(task.error)
return
task.done_future.set_result((task.result, dict(task.profile)))
def _notify_task_done(self, task: TextCpuTask) -> None:
task.done_event.set()
if task.done_loop is None or task.done_future is None:
return
try:
task.done_loop.call_soon_threadsafe(self._resolve_done_future, task)
except RuntimeError:
pass
def snapshot(self) -> Dict[str, int | float]:
with self.condition:
return {
"worker_count": int(self.worker_count),
"pending": int(len(self.pending_tasks)),
"pending_peak": int(self.pending_peak),
"active_workers": int(self.active_workers),
"active_workers_peak": int(self.active_workers_peak),
"total_submitted": int(self.total_submitted),
"total_finished": int(self.total_finished),
"max_pending_tasks": int(self.max_pending_tasks),
"admission_wait_total_ms": float(self.admission_wait_total_ms),
"admission_wait_peak_ms": float(self.admission_wait_peak_ms),
"backpressure_wait_total_ms": float(self.backpressure_wait_total_ms),
"backpressure_wait_peak_ms": float(self.backpressure_wait_peak_ms),
"capacity_wait_total_ms": float(self.capacity_wait_total_ms),
"capacity_wait_peak_ms": float(self.capacity_wait_peak_ms),
"backpressure_blocked_total": int(self.backpressure_blocked_total),
}
def _run_loop(self) -> None:
while True:
with self.condition:
while not self.pending_tasks:
self.condition.wait()
task = self.pending_tasks.popleft()
self.active_workers += 1
self.active_workers_peak = max(self.active_workers_peak, self.active_workers)
started_at = time.perf_counter()
try:
task.result = self.process_fn(task.text, task.language)
task.profile = {
"text_cpu_admission_wait_ms": float(task.admission_wait_ms),
"text_cpu_backpressure_wait_ms": float(task.backpressure_wait_ms),
"text_cpu_capacity_wait_ms": float(task.capacity_wait_ms),
"text_cpu_queue_wait_ms": max(0.0, (started_at - task.enqueued_at) * 1000.0),
"text_cpu_pending_depth_on_enqueue": float(task.pending_depth_on_enqueue),
"text_cpu_run_ms": max(0.0, (time.perf_counter() - started_at) * 1000.0),
}
except Exception as exc: # noqa: PERF203
task.error = exc
finally:
with self.condition:
self.active_workers = max(0, self.active_workers - 1)
self.total_finished += 1
self.condition.notify_all()
self._notify_task_done(task)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,112 @@
import os
import re
import sys
from typing import Dict, List, Optional, Tuple
now_dir = os.getcwd()
sys.path.append(now_dir)
from text.LangSegmenter import LangSegmenter
from text import cleaned_text_to_sequence
from text import chinese2
from text.cleaner import clean_text
PreparedTextSegmentPayload = Dict[str, object]
def split_text_by_language(text: str, language: str) -> Tuple[List[str], List[str]]:
textlist: List[str] = []
langlist: List[str] = []
if language == "all_zh":
for tmp in LangSegmenter.getTexts(text, "zh"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_yue":
for tmp in LangSegmenter.getTexts(text, "zh"):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ja":
for tmp in LangSegmenter.getTexts(text, "ja"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ko":
for tmp in LangSegmenter.getTexts(text, "ko"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "en":
langlist.append("en")
textlist.append(text)
elif language == "auto":
for tmp in LangSegmenter.getTexts(text):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "auto_yue":
for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(text):
if langlist:
same_group = (tmp["lang"] == "en" and langlist[-1] == "en") or (
tmp["lang"] != "en" and langlist[-1] != "en"
)
if same_group:
textlist[-1] += tmp["text"]
continue
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
langlist.append(language)
textlist.append(tmp["text"])
return textlist, langlist
def clean_text_segment(text: str, language: str, version: str) -> Tuple[List[int], Optional[List[int]], str]:
normalized_language = language.replace("all_", "")
phones, word2ph, norm_text = clean_text(text, normalized_language, version)
phones = cleaned_text_to_sequence(phones, version)
return list(phones), None if word2ph is None else list(word2ph), str(norm_text)
def preprocess_text_segments_payload(
text: str,
language: str,
version: str,
final: bool = False,
) -> List[PreparedTextSegmentPayload]:
text = re.sub(r" {2,}", " ", text)
textlist, langlist = split_text_by_language(text, language)
payloads: List[PreparedTextSegmentPayload] = []
total_phones_len = 0
for segment_text, segment_lang in zip(textlist, langlist):
normalized_language = segment_lang.replace("all_", "")
if normalized_language == "zh":
norm_text = chinese2.text_normalize(segment_text)
phones = []
word2ph = None
needs_g2pw = True
estimated_phones_len = max(0, len(norm_text) * 2)
else:
phones, word2ph, norm_text = clean_text_segment(segment_text, segment_lang, version)
needs_g2pw = False
estimated_phones_len = len(phones)
payloads.append(
{
"language": normalized_language,
"phones": phones,
"word2ph": word2ph,
"norm_text": norm_text,
"needs_g2pw": needs_g2pw,
}
)
total_phones_len += int(estimated_phones_len)
if not final and total_phones_len < 6:
return preprocess_text_segments_payload("." + text, language, version, final=True)
return payloads

View File

@ -0,0 +1,46 @@
from __future__ import annotations
import os
from typing import Sequence
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
from GPT_SoVITS.TTS_infer_pack.unified_engine_builder import EngineCompositionBuilder
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeControlCallbacks
from GPT_SoVITS.TTS_infer_pack.unified_engine_delegates import EngineApiDelegates, EngineBridgeDelegates, EngineRuntimeDelegates
from GPT_SoVITS.TTS_infer_pack.unified_engine_public import EngineCompatInterface, EnginePublicInterface
class UnifiedTTSEngine(EnginePublicInterface, EngineCompatInterface, EngineBridgeDelegates, EngineApiDelegates, EngineRuntimeDelegates):
@staticmethod
def _env_flag(name: str, default: bool) -> bool:
value = os.environ.get(name)
if value is None:
return bool(default)
return str(value).strip().lower() not in {"0", "false", "no", "off", ""}
@staticmethod
def _env_int(name: str, default: int) -> int:
value = os.environ.get(name)
if value in [None, ""]:
return int(default)
return int(value)
@staticmethod
def _env_float(name: str, default: float) -> float:
value = os.environ.get(name)
if value in [None, ""]:
return float(default)
return float(value)
def __init__(
self,
tts: TTS,
cut_method_names: Sequence[str],
control_callbacks: RuntimeControlCallbacks | None = None,
max_steps: int = 1500,
micro_batch_wait_ms: int = 5,
) -> None:
self.tts = tts
self.cut_method_names = set(cut_method_names)
self.control_callbacks = control_callbacks or RuntimeControlCallbacks()
EngineCompositionBuilder(self).build(max_steps=max_steps, micro_batch_wait_ms=micro_batch_wait_ms)

View File

@ -0,0 +1,451 @@
from __future__ import annotations
from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple
from GPT_SoVITS.TTS_infer_pack.unified_engine_api_direct import EngineApiDirectFlow
from GPT_SoVITS.TTS_infer_pack.unified_engine_api_profile import (
aggregate_numeric_dicts,
build_direct_scheduler_profile,
build_direct_segment_trace,
build_legacy_direct_profile,
build_request_meta,
build_scheduler_debug_batch_profile,
build_scheduler_debug_request_profile,
build_scheduler_submit_headers,
build_scheduler_submit_profile,
format_ms_header,
sum_profile_field,
)
from GPT_SoVITS.TTS_infer_pack.unified_engine_api_request import (
apply_default_reference,
base_request_defaults,
check_params,
is_aux_ref_enabled,
normalize_engine_request,
normalize_lang,
normalize_streaming_mode,
select_direct_backend,
)
from GPT_SoVITS.TTS_infer_pack.unified_engine_api_scheduler import EngineApiSchedulerFlow
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SFinishedItem, T2SRequestState
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import (
DirectTTSExecution,
NormalizedEngineRequest,
SchedulerDebugExecution,
SchedulerSubmitExecution,
)
class EngineApiFacade:
def __init__(self, owner: Any) -> None:
self.owner = owner
self.direct_flow = EngineApiDirectFlow(self)
self.scheduler_flow = EngineApiSchedulerFlow(self)
@property
def tts(self):
return self.owner.tts
@property
def cut_method_names(self):
return self.owner.cut_method_names
@property
def reference_registry(self):
return self.owner.reference_registry
@property
def direct_tts_lock(self):
return self.owner.direct_tts_lock
@property
def scheduler_worker(self):
return self.owner.scheduler_worker
def _register_request_state(
self,
request_id: str,
api_mode: str,
backend: str,
media_type: str,
response_streaming: bool,
deadline_ts: float | None = None,
meta: Optional[Dict[str, Any]] = None,
):
return self.owner._register_request_state(
request_id=request_id,
api_mode=api_mode,
backend=backend,
media_type=media_type,
response_streaming=response_streaming,
deadline_ts=deadline_ts,
meta=meta,
)
def _update_request_state(
self,
request_id: str,
status: str,
extra: Optional[Dict[str, Any]] = None,
) -> None:
self.owner._update_request_state(request_id, status, extra)
def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
self.owner._merge_request_state_profile(request_id, extra)
def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
self.owner._complete_request_state(request_id, extra)
def _fail_request_state(self, request_id: str, error: str) -> None:
self.owner._fail_request_state(request_id, error)
async def _prepare_state_via_engine_gpu_queue(
self,
*,
spec: SchedulerRequestSpec,
prepare_submit_at: float,
engine_request_id: str | None,
) -> tuple[T2SRequestState, float, float]:
return await self.owner._prepare_state_via_engine_gpu_queue(
spec=spec,
prepare_submit_at=prepare_submit_at,
engine_request_id=engine_request_id,
)
async def _enqueue_prepared_state_for_dispatch(
self,
*,
state: T2SRequestState,
speed_factor: float,
sample_steps: int,
media_type: str,
super_sampling: bool,
prepare_wall_ms: float,
prepare_profile_total_ms: float,
done_loop: asyncio.AbstractEventLoop | None,
done_future: asyncio.Future | None,
engine_request_id: str | None,
timeout_sec: float | None,
):
return await self.owner._enqueue_prepared_state_for_dispatch(
state=state,
speed_factor=speed_factor,
sample_steps=sample_steps,
media_type=media_type,
super_sampling=super_sampling,
prepare_wall_ms=prepare_wall_ms,
prepare_profile_total_ms=prepare_profile_total_ms,
done_loop=done_loop,
done_future=done_future,
engine_request_id=engine_request_id,
timeout_sec=timeout_sec,
)
def _collect_request_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]:
return self.owner.request_registry.collect_summaries(request_ids)
def _has_active_request(self, request_id: str) -> bool:
return self.owner.request_registry.has_active(request_id)
@staticmethod
def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]:
return build_request_meta(payload)
@staticmethod
def _sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float:
return sum_profile_field(items, key)
def _build_direct_segment_trace(
self,
segment_texts: Sequence[str],
prepare_profiles: Sequence[Dict[str, Any]],
worker_profiles: Sequence[Dict[str, Any]],
) -> List[Dict[str, Any]]:
return build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles)
def _build_direct_scheduler_profile(
self,
*,
backend: str,
request_start: float,
response_ready_at: float,
audio_bytes: int,
sample_rate: int,
segment_texts: Sequence[str],
prepare_profiles: Sequence[Dict[str, Any]],
worker_profiles: Sequence[Dict[str, Any]],
pack_ms: float,
response_overhead_ms: float,
) -> Dict[str, Any]:
return build_direct_scheduler_profile(
backend=backend,
request_start=request_start,
response_ready_at=response_ready_at,
audio_bytes=audio_bytes,
sample_rate=sample_rate,
segment_texts=segment_texts,
prepare_profiles=prepare_profiles,
worker_profiles=worker_profiles,
pack_ms=pack_ms,
response_overhead_ms=response_overhead_ms,
)
def _build_legacy_direct_profile(
self,
*,
backend: str,
fallback_reason: str | None,
request_start: float,
finished_at: float,
sample_rate: int | None = None,
audio_bytes: int = 0,
pack_ms: float = 0.0,
chunk_count: int = 0,
stream_total_bytes: int = 0,
first_chunk_ms: float | None = None,
) -> Dict[str, Any]:
return build_legacy_direct_profile(
backend=backend,
fallback_reason=fallback_reason,
request_start=request_start,
finished_at=finished_at,
sample_rate=sample_rate,
audio_bytes=audio_bytes,
pack_ms=pack_ms,
chunk_count=chunk_count,
stream_total_bytes=stream_total_bytes,
first_chunk_ms=first_chunk_ms,
)
def _build_scheduler_submit_profile(
self,
*,
backend: str,
request_start: float,
response_ready_at: float,
audio_bytes: int,
sample_rate: int,
prepare_spec_build_ms: float,
prepare_wall_ms: float,
prepare_executor_queue_ms: float,
prepare_executor_run_ms: float,
prepare_profile_total_ms: float,
prepare_profile_wall_ms: float,
prepare_other_ms: float,
engine_policy_wait_ms: float,
api_after_prepare_ms: float,
api_wait_result_ms: float,
pack_ms: float,
response_overhead_ms: float,
worker_profile: Dict[str, Any],
) -> Dict[str, Any]:
return build_scheduler_submit_profile(
backend=backend,
request_start=request_start,
response_ready_at=response_ready_at,
audio_bytes=audio_bytes,
sample_rate=sample_rate,
prepare_spec_build_ms=prepare_spec_build_ms,
prepare_wall_ms=prepare_wall_ms,
prepare_executor_queue_ms=prepare_executor_queue_ms,
prepare_executor_run_ms=prepare_executor_run_ms,
prepare_profile_total_ms=prepare_profile_total_ms,
prepare_profile_wall_ms=prepare_profile_wall_ms,
prepare_other_ms=prepare_other_ms,
engine_policy_wait_ms=engine_policy_wait_ms,
api_after_prepare_ms=api_after_prepare_ms,
api_wait_result_ms=api_wait_result_ms,
pack_ms=pack_ms,
response_overhead_ms=response_overhead_ms,
worker_profile=worker_profile,
)
@staticmethod
def _format_ms_header(value: Any) -> str:
return format_ms_header(value)
def _build_scheduler_submit_headers(
self,
*,
request_id: str,
media_type: str,
sample_rate: int,
profile: Dict[str, Any],
) -> Dict[str, str]:
return build_scheduler_submit_headers(
request_id=request_id,
media_type=media_type,
sample_rate=sample_rate,
profile=profile,
)
def _build_scheduler_debug_request_profile(
self,
*,
state: T2SRequestState,
item: T2SFinishedItem,
batch_request_count: int,
prepare_batch_wall_ms: float,
decode_batch_wall_ms: float,
batch_request_total_ms: float,
) -> Dict[str, Any]:
return build_scheduler_debug_request_profile(
state=state,
item=item,
batch_request_count=batch_request_count,
prepare_batch_wall_ms=prepare_batch_wall_ms,
decode_batch_wall_ms=decode_batch_wall_ms,
batch_request_total_ms=batch_request_total_ms,
)
@staticmethod
def _build_scheduler_debug_batch_profile(
*,
request_count: int,
max_steps: int,
prepare_batch_wall_ms: float,
decode_batch_wall_ms: float,
request_total_ms: float,
finished_items: Sequence[T2SFinishedItem],
) -> Dict[str, Any]:
return build_scheduler_debug_batch_profile(
request_count=request_count,
max_steps=max_steps,
prepare_batch_wall_ms=prepare_batch_wall_ms,
decode_batch_wall_ms=decode_batch_wall_ms,
request_total_ms=request_total_ms,
finished_items=finished_items,
)
def _normalize_lang(self, value: str | None) -> str | None:
return normalize_lang(value)
@staticmethod
def _aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]:
return aggregate_numeric_dicts(items)
def _apply_default_reference(self, req: dict) -> dict:
return apply_default_reference(self.reference_registry, req)
def check_params(self, req: dict) -> Optional[str]:
return check_params(self.tts, self.cut_method_names, req)
@staticmethod
def _base_request_defaults() -> Dict[str, Any]:
return base_request_defaults()
def _normalize_engine_request(
self,
payload: dict | NormalizedEngineRequest,
*,
request_id: str | None = None,
normalize_streaming: bool = False,
error_prefix: str = "request 参数非法: ",
) -> NormalizedEngineRequest:
return normalize_engine_request(
tts=self.tts,
cut_method_names=self.cut_method_names,
reference_registry=self.reference_registry,
payload=payload,
request_id=request_id,
normalize_streaming=normalize_streaming,
error_prefix=error_prefix,
)
@staticmethod
def _normalize_streaming_mode(req: dict) -> dict:
return normalize_streaming_mode(req)
@staticmethod
def _is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool:
return is_aux_ref_enabled(aux_ref_audio_paths)
def _select_direct_backend(self, normalized: NormalizedEngineRequest) -> Tuple[str, str | None]:
return select_direct_backend(normalized)
def _iter_legacy_direct_tts_bytes(
self,
normalized: NormalizedEngineRequest,
*,
backend: str,
fallback_reason: str | None,
) -> Generator[bytes, None, None]:
yield from self.direct_flow._iter_legacy_direct_tts_bytes(
normalized,
backend=backend,
fallback_reason=fallback_reason,
)
def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool:
return self.direct_flow._should_use_scheduler_backend_for_direct(req)
def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]:
return self.direct_flow._segment_direct_text(normalized)
def _build_segment_request(
self,
normalized: NormalizedEngineRequest,
*,
request_id: str,
text: str,
) -> NormalizedEngineRequest:
return self.direct_flow._build_segment_request(
normalized,
request_id=request_id,
text=text,
)
async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution:
return await self.direct_flow._run_direct_tts_via_scheduler(normalized)
def _run_legacy_direct_tts_blocking(
self,
normalized: NormalizedEngineRequest,
*,
backend: str,
fallback_reason: str | None,
) -> DirectTTSExecution:
return self.direct_flow._run_legacy_direct_tts_blocking(
normalized,
backend=backend,
fallback_reason=fallback_reason,
)
async def _run_direct_tts_via_legacy_backend(
self,
normalized: NormalizedEngineRequest,
*,
backend: str,
fallback_reason: str | None,
) -> DirectTTSExecution:
return await self.direct_flow._run_direct_tts_via_legacy_backend(
normalized,
backend=backend,
fallback_reason=fallback_reason,
)
async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution:
return await self.direct_flow.run_direct_tts_async(req)
def run_direct_tts(self, req: dict) -> DirectTTSExecution:
return self.direct_flow.run_direct_tts(req)
def _build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]:
return self.scheduler_flow._build_scheduler_request_specs(request_items)
def _build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec:
return self.scheduler_flow._build_scheduler_submit_spec(payload)
@staticmethod
def _summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]:
return EngineApiSchedulerFlow._summarize_scheduler_states(states)
@staticmethod
def _summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]:
return EngineApiSchedulerFlow._summarize_scheduler_finished(items)
async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution:
return await self.scheduler_flow.run_scheduler_debug(request_items, max_steps, seed)
async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution:
return await self.scheduler_flow.run_scheduler_submit(payload)

View File

@ -0,0 +1,165 @@
from __future__ import annotations
from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple
from GPT_SoVITS.TTS_infer_pack.unified_engine_api import EngineApiFacade
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, NormalizedEngineRequest
class EngineApiDelegates:
def _collect_request_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]:
return self.api_facade._collect_request_summaries(request_ids)
def _has_active_request(self, request_id: str) -> bool:
return self.api_facade._has_active_request(request_id)
@staticmethod
def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]:
return EngineApiFacade._build_request_meta(payload)
@staticmethod
def _sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float:
return EngineApiFacade._sum_profile_field(items, key)
def _build_direct_segment_trace(
self,
segment_texts: Sequence[str],
prepare_profiles: Sequence[Dict[str, Any]],
worker_profiles: Sequence[Dict[str, Any]],
) -> List[Dict[str, Any]]:
return self.api_facade._build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles)
def _build_direct_scheduler_profile(self, **kwargs: Any) -> Dict[str, Any]:
return self.api_facade._build_direct_scheduler_profile(**kwargs)
def _build_legacy_direct_profile(self, **kwargs: Any) -> Dict[str, Any]:
return self.api_facade._build_legacy_direct_profile(**kwargs)
def _build_scheduler_submit_profile(self, **kwargs: Any) -> Dict[str, Any]:
return self.api_facade._build_scheduler_submit_profile(**kwargs)
@staticmethod
def _format_ms_header(value: Any) -> str:
return EngineApiFacade._format_ms_header(value)
def _build_scheduler_submit_headers(
self,
*,
request_id: str,
media_type: str,
sample_rate: int,
profile: Dict[str, Any],
) -> Dict[str, str]:
return self.api_facade._build_scheduler_submit_headers(
request_id=request_id,
media_type=media_type,
sample_rate=sample_rate,
profile=profile,
)
def _build_scheduler_debug_request_profile(self, **kwargs: Any) -> Dict[str, Any]:
return self.api_facade._build_scheduler_debug_request_profile(**kwargs)
@staticmethod
def _build_scheduler_debug_batch_profile(**kwargs: Any) -> Dict[str, Any]:
return EngineApiFacade._build_scheduler_debug_batch_profile(**kwargs)
def _normalize_lang(self, value: str | None) -> str | None:
return self.api_facade._normalize_lang(value)
@staticmethod
def _aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]:
return EngineApiFacade._aggregate_numeric_dicts(items)
def _apply_default_reference(self, req: dict) -> dict:
return self.api_facade._apply_default_reference(req)
def check_params(self, req: dict) -> Optional[str]:
return self.api_facade.check_params(req)
@staticmethod
def _base_request_defaults() -> Dict[str, Any]:
return EngineApiFacade._base_request_defaults()
def _normalize_engine_request(
self,
payload: dict | NormalizedEngineRequest,
*,
request_id: str | None = None,
normalize_streaming: bool = False,
error_prefix: str = "request 参数非法: ",
) -> NormalizedEngineRequest:
return self.api_facade._normalize_engine_request(
payload,
request_id=request_id,
normalize_streaming=normalize_streaming,
error_prefix=error_prefix,
)
@staticmethod
def _normalize_streaming_mode(req: dict) -> dict:
return EngineApiFacade._normalize_streaming_mode(req)
@staticmethod
def _is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool:
return EngineApiFacade._is_aux_ref_enabled(aux_ref_audio_paths)
def _select_direct_backend(self, normalized: NormalizedEngineRequest) -> Tuple[str, str | None]:
return self.api_facade._select_direct_backend(normalized)
def _iter_legacy_direct_tts_bytes(
self,
normalized: NormalizedEngineRequest,
*,
backend: str,
fallback_reason: str | None,
) -> Generator[bytes, None, None]:
return self.api_facade._iter_legacy_direct_tts_bytes(
normalized,
backend=backend,
fallback_reason=fallback_reason,
)
def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool:
return self.api_facade._should_use_scheduler_backend_for_direct(req)
def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]:
return self.api_facade._segment_direct_text(normalized)
def _build_segment_request(
self,
normalized: NormalizedEngineRequest,
*,
request_id: str,
text: str,
) -> NormalizedEngineRequest:
return self.api_facade._build_segment_request(normalized, request_id=request_id, text=text)
async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution:
return await self.api_facade._run_direct_tts_via_scheduler(normalized)
def _run_legacy_direct_tts_blocking(
self,
normalized: NormalizedEngineRequest,
*,
backend: str,
fallback_reason: str | None,
) -> DirectTTSExecution:
return self.api_facade._run_legacy_direct_tts_blocking(
normalized,
backend=backend,
fallback_reason=fallback_reason,
)
async def _run_direct_tts_via_legacy_backend(
self,
normalized: NormalizedEngineRequest,
*,
backend: str,
fallback_reason: str | None,
) -> DirectTTSExecution:
return await self.api_facade._run_direct_tts_via_legacy_backend(
normalized,
backend=backend,
fallback_reason=fallback_reason,
)

View File

@ -0,0 +1,595 @@
from __future__ import annotations
import asyncio
import queue
import threading
import time
import uuid
from io import BytesIO
from typing import Any, Dict, Generator, List, Optional
import numpy as np
from GPT_SoVITS.TTS_infer_pack.unified_engine_audio import pack_audio, wave_header_chunk
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, EngineStatus, NormalizedEngineRequest, SchedulerPendingJob
class EngineApiDirectFlow:
def __init__(self, api: Any) -> None:
self.api = api
def _iter_legacy_direct_tts_bytes(
self,
normalized: NormalizedEngineRequest,
*,
backend: str,
fallback_reason: str | None,
) -> Generator[bytes, None, None]:
payload = normalized.to_payload()
media_type = normalized.media_type
request_id = normalized.request_id
request_start = time.perf_counter()
chunk_count = 0
stream_total_bytes = 0
first_chunk_ms: float | None = None
self.api._update_request_state(
request_id,
EngineStatus.ACTIVE_DECODE,
{"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason},
)
try:
with self.api.direct_tts_lock:
tts_generator = self.api.tts.run(payload)
first_chunk = True
current_media_type = media_type
for sr, chunk in tts_generator:
if first_chunk:
first_chunk_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0)
self.api._update_request_state(
request_id,
EngineStatus.STREAMING,
{
"backend": backend,
"backend_mode": backend,
"fallback_reason": fallback_reason,
"sample_rate": int(sr),
},
)
if first_chunk and media_type == "wav":
header = wave_header_chunk(sample_rate=sr)
chunk_count += 1
stream_total_bytes += len(header)
yield header
current_media_type = "raw"
first_chunk = False
elif first_chunk:
first_chunk = False
packed_chunk = pack_audio(BytesIO(), chunk, sr, current_media_type).getvalue()
chunk_count += 1
stream_total_bytes += len(packed_chunk)
yield packed_chunk
except Exception as exc:
self.api._fail_request_state(request_id, str(exc))
raise
self.api._complete_request_state(
request_id,
dict(
self.api._build_legacy_direct_profile(
backend=backend,
fallback_reason=fallback_reason,
request_start=request_start,
finished_at=time.perf_counter(),
audio_bytes=stream_total_bytes,
chunk_count=chunk_count,
stream_total_bytes=stream_total_bytes,
first_chunk_ms=first_chunk_ms,
),
streaming_completed=True,
),
)
def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool:
if isinstance(req, NormalizedEngineRequest):
normalized = req
else:
normalized = self.api._normalize_engine_request(
req,
request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"),
normalize_streaming=True,
)
backend, _ = self.api._select_direct_backend(normalized)
return backend == "scheduler_v1_direct"
def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]:
payload = normalized.to_payload() if isinstance(normalized, NormalizedEngineRequest) else normalized
return self.api.tts.text_preprocessor.pre_seg_text(
str(payload["text"]),
str(payload["text_lang"]),
str(payload.get("text_split_method", "cut5")),
)
def _build_segment_request(
self,
normalized: NormalizedEngineRequest,
*,
request_id: str,
text: str,
) -> NormalizedEngineRequest:
payload = normalized.to_payload()
payload["request_id"] = request_id
payload["text"] = text
payload["streaming_mode"] = False
payload["return_fragment"] = False
payload["fixed_length_chunk"] = False
payload["response_streaming"] = False
return self.api._normalize_engine_request(payload, error_prefix="segment request 参数非法: ")
async def _execute_single_segment_scheduler_job(
self,
normalized: NormalizedEngineRequest,
*,
segment_request: NormalizedEngineRequest,
) -> tuple[SchedulerPendingJob, Dict[str, Any]]:
spec = self.api._build_scheduler_submit_spec(segment_request)
state, prepare_exec_started_at, prepare_exec_finished_at = await self.api._prepare_state_via_engine_gpu_queue(
spec=spec,
prepare_submit_at=time.perf_counter(),
engine_request_id=None,
)
prepare_wall_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0)
prepare_profile_total_ms = float(state.prepare_profile.get("wall_total_ms", prepare_wall_ms))
loop = asyncio.get_running_loop()
done_future = loop.create_future()
await self.api._enqueue_prepared_state_for_dispatch(
state=state,
speed_factor=float(normalized.speed_factor),
sample_steps=int(normalized.sample_steps),
media_type=normalized.media_type,
super_sampling=bool(normalized.super_sampling),
prepare_wall_ms=prepare_wall_ms,
prepare_profile_total_ms=prepare_profile_total_ms,
done_loop=loop,
done_future=done_future,
engine_request_id=None,
timeout_sec=normalized.timeout_sec,
)
timeout_sec = float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0)
job: SchedulerPendingJob = await asyncio.wait_for(done_future, timeout=timeout_sec)
return job, {
"request_id": spec.request_id,
"prepare_wall_ms": prepare_wall_ms,
"prepare_profile_total_ms": prepare_profile_total_ms,
"prepare_profile": dict(state.prepare_profile),
}
def _iter_scheduler_direct_tts_bytes(self, normalized: NormalizedEngineRequest) -> Generator[bytes, None, None]:
request_start = time.perf_counter()
request_id = normalized.request_id
media_type = normalized.media_type
segment_texts = self._segment_direct_text(normalized)
if not segment_texts:
raise ValueError("text preprocessing returned no valid segments")
chunk_queue: queue.Queue[object] = queue.Queue(maxsize=8)
done_marker = object()
async def _produce_chunks() -> None:
self.api._update_request_state(
request_id,
EngineStatus.CPU_PREPARING,
{"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_texts)},
)
sample_rate: int | None = None
current_media_type = media_type
chunk_count = 0
stream_total_bytes = 0
first_chunk_ms: float | None = None
prepare_profiles: List[Dict[str, Any]] = []
worker_profiles: List[Dict[str, Any]] = []
try:
for segment_index, segment_text in enumerate(segment_texts):
segment_request = self._build_segment_request(
normalized,
request_id=f"{request_id}_seg_{segment_index:03d}",
text=segment_text,
)
self.api._update_request_state(
request_id,
EngineStatus.READY_FOR_PREFILL,
{
"backend": "scheduler_v1_direct",
"backend_mode": "scheduler_v1_direct",
"segment_index": segment_index,
"segment_count": len(segment_texts),
},
)
job, prepare_profile = await self._execute_single_segment_scheduler_job(
normalized,
segment_request=segment_request,
)
prepare_profiles.append(prepare_profile)
if job.error is not None:
raise RuntimeError(job.error)
if job.audio_data is None or job.sample_rate is None or job.result is None:
raise RuntimeError(f"{job.request_id} finished without audio result")
worker_profiles.append(dict(job.result))
if sample_rate is None:
sample_rate = int(job.sample_rate)
first_chunk_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0)
self.api._update_request_state(
request_id,
EngineStatus.STREAMING,
{
"backend": "scheduler_v1_direct",
"backend_mode": "scheduler_v1_direct",
"sample_rate": int(sample_rate),
},
)
if media_type == "wav":
header = wave_header_chunk(sample_rate=int(sample_rate))
chunk_count += 1
stream_total_bytes += len(header)
chunk_queue.put(header)
current_media_type = "raw"
packed_chunk = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), current_media_type).getvalue()
chunk_count += 1
stream_total_bytes += len(packed_chunk)
chunk_queue.put(packed_chunk)
if segment_index + 1 < len(segment_texts):
silence_samples = int(float(normalized.fragment_interval) * float(job.sample_rate))
if silence_samples > 0:
silence_chunk = np.zeros(silence_samples, dtype=np.int16)
packed_silence = pack_audio(
BytesIO(), silence_chunk, int(job.sample_rate), current_media_type
).getvalue()
chunk_count += 1
stream_total_bytes += len(packed_silence)
chunk_queue.put(packed_silence)
except Exception as exc:
self.api._fail_request_state(request_id, str(exc))
chunk_queue.put(exc)
else:
self.api._merge_request_state_profile(
request_id,
{
"prepare_aggregate": self.api._aggregate_numeric_dicts(
[item["prepare_profile"] for item in prepare_profiles]
),
"engine_policy_wait_ms": sum(
float(item.get("engine_policy_wait_ms", 0.0)) for item in worker_profiles
),
"engine_dispatch_wait_ms": sum(
float(item.get("engine_dispatch_wait_ms", 0.0)) for item in worker_profiles
),
},
)
direct_profile = self.api._build_direct_scheduler_profile(
backend="scheduler_v1_direct",
request_start=request_start,
response_ready_at=time.perf_counter(),
audio_bytes=stream_total_bytes,
sample_rate=int(sample_rate or 0),
segment_texts=segment_texts,
prepare_profiles=prepare_profiles,
worker_profiles=worker_profiles,
pack_ms=0.0,
response_overhead_ms=0.0,
)
self.api._complete_request_state(
request_id,
dict(direct_profile, streaming_completed=True, first_chunk_ms=first_chunk_ms),
)
finally:
chunk_queue.put(done_marker)
producer_thread = threading.Thread(target=lambda: asyncio.run(_produce_chunks()), daemon=True)
producer_thread.start()
while True:
item = chunk_queue.get()
if item is done_marker:
break
if isinstance(item, Exception):
raise item
yield item
async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution:
request_start = time.perf_counter()
request_id = normalized.request_id
media_type = normalized.media_type
segment_texts = self._segment_direct_text(normalized)
if not segment_texts:
raise ValueError("text preprocessing returned no valid segments")
if normalized.response_streaming:
return DirectTTSExecution(
media_type=media_type,
streaming=True,
audio_generator=self._iter_scheduler_direct_tts_bytes(normalized),
request_id=request_id,
)
self.api._update_request_state(
request_id,
EngineStatus.CPU_PREPARING,
{"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_texts)},
)
segment_requests = [
self._build_segment_request(
normalized,
request_id=f"{request_id}_seg_{segment_index:03d}",
text=segment_text,
)
for segment_index, segment_text in enumerate(segment_texts)
]
prepare_profiles: List[Dict[str, Any]] = []
loop = asyncio.get_running_loop()
done_futures: List[asyncio.Future] = []
self.api._update_request_state(
request_id,
EngineStatus.READY_FOR_PREFILL,
{"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_requests)},
)
prepared_items = await asyncio.gather(
*[
self._execute_single_segment_scheduler_job(
normalized,
segment_request=segment_request,
)
for segment_request in segment_requests
]
)
for job, prepare_profile in prepared_items:
prepare_profiles.append(prepare_profile)
done_future = loop.create_future()
done_future.set_result(job)
done_futures.append(done_future)
self.api._update_request_state(
request_id,
EngineStatus.ACTIVE_DECODE,
{"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct"},
)
timeout_sec = float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0)
jobs: List[SchedulerPendingJob] = list(await asyncio.wait_for(asyncio.gather(*done_futures), timeout=timeout_sec))
for profile_item, job in zip(prepare_profiles, jobs):
profile_item["engine_policy_wait_ms"] = float(job.engine_policy_wait_ms)
profile_item["engine_dispatch_wait_ms"] = float(job.engine_dispatch_wait_ms)
self.api._merge_request_state_profile(
request_id,
{
"engine_policy_wait_ms": sum(float(job.engine_policy_wait_ms) for job in jobs),
"engine_dispatch_wait_ms": sum(float(job.engine_dispatch_wait_ms) for job in jobs),
"prepare_aggregate": self.api._aggregate_numeric_dicts([item["prepare_profile"] for item in prepare_profiles]),
},
)
sample_rate: int | None = None
audio_parts: List[np.ndarray] = []
worker_profiles: List[Dict[str, Any]] = []
fragment_interval = float(normalized.fragment_interval)
silence_chunk: Optional[np.ndarray] = None
for job in jobs:
if job.error is not None:
raise RuntimeError(job.error)
if job.audio_data is None or job.sample_rate is None or job.result is None:
raise RuntimeError(f"{job.request_id} finished without audio result")
if sample_rate is None:
sample_rate = int(job.sample_rate)
silence_samples = int(fragment_interval * float(sample_rate))
if silence_samples > 0:
silence_chunk = np.zeros(silence_samples, dtype=np.int16)
elif int(job.sample_rate) != sample_rate:
raise RuntimeError("segment sample rate mismatch")
audio_parts.append(job.audio_data)
if silence_chunk is not None:
audio_parts.append(silence_chunk.copy())
worker_profiles.append(dict(job.result))
if sample_rate is None or not audio_parts:
raise RuntimeError("direct scheduler backend produced no audio")
self.api._update_request_state(
request_id,
EngineStatus.FINALIZING,
{"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct"},
)
merged_audio = np.concatenate(audio_parts, axis=0)
pack_start = time.perf_counter()
audio_bytes = pack_audio(BytesIO(), merged_audio, sample_rate, media_type).getvalue()
pack_ms = max(0.0, (time.perf_counter() - pack_start) * 1000.0)
direct_profile = self.api._build_direct_scheduler_profile(
backend="scheduler_v1_direct",
request_start=request_start,
response_ready_at=time.perf_counter(),
audio_bytes=len(audio_bytes),
sample_rate=int(sample_rate),
segment_texts=segment_texts,
prepare_profiles=prepare_profiles,
worker_profiles=worker_profiles,
pack_ms=pack_ms,
response_overhead_ms=0.0,
)
self.api._complete_request_state(
request_id,
dict(direct_profile, streaming_completed=False),
)
return DirectTTSExecution(
media_type=media_type,
streaming=False,
audio_bytes=audio_bytes,
request_id=request_id,
)
def _run_legacy_direct_tts_blocking(
self,
normalized: NormalizedEngineRequest,
*,
backend: str,
fallback_reason: str | None,
) -> DirectTTSExecution:
normalized_payload = normalized.to_payload()
request_id = normalized.request_id
media_type = normalized.media_type
request_start = time.perf_counter()
self.api._update_request_state(
request_id,
EngineStatus.ACTIVE_DECODE,
{"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason},
)
with self.api.direct_tts_lock:
tts_generator = self.api.tts.run(normalized_payload)
try:
sr, audio_data = next(tts_generator)
except Exception as exc:
self.api._fail_request_state(request_id, str(exc))
raise
self.api._update_request_state(
request_id,
EngineStatus.FINALIZING,
{"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason},
)
pack_start = time.perf_counter()
packed_audio = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
pack_ms = max(0.0, (time.perf_counter() - pack_start) * 1000.0)
self.api._complete_request_state(
request_id,
dict(
self.api._build_legacy_direct_profile(
backend=backend,
fallback_reason=fallback_reason,
request_start=request_start,
finished_at=time.perf_counter(),
sample_rate=int(sr),
audio_bytes=len(packed_audio),
pack_ms=pack_ms,
),
streaming_completed=False,
),
)
return DirectTTSExecution(
media_type=media_type,
streaming=False,
audio_bytes=packed_audio,
request_id=request_id,
)
async def _run_direct_tts_via_legacy_backend(
self,
normalized: NormalizedEngineRequest,
*,
backend: str,
fallback_reason: str | None,
) -> DirectTTSExecution:
if normalized.response_streaming:
return DirectTTSExecution(
media_type=normalized.media_type,
streaming=True,
audio_generator=self._iter_legacy_direct_tts_bytes(
normalized,
backend=backend,
fallback_reason=fallback_reason,
),
request_id=normalized.request_id,
)
return await asyncio.to_thread(
self._run_legacy_direct_tts_blocking,
normalized,
backend=backend,
fallback_reason=fallback_reason,
)
async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution:
normalized = self.api._normalize_engine_request(
req,
request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"),
normalize_streaming=True,
error_prefix="",
)
request_id = normalized.request_id
media_type = normalized.media_type
backend, fallback_reason = self.api._select_direct_backend(normalized)
self.api._register_request_state(
request_id=request_id,
api_mode="tts",
backend=backend,
media_type=media_type,
response_streaming=bool(normalized.response_streaming),
deadline_ts=(time.perf_counter() + float(normalized.timeout_sec) if normalized.timeout_sec is not None else None),
meta=self.api._build_request_meta(normalized.to_payload()),
)
self.api._update_request_state(
request_id,
EngineStatus.VALIDATED,
{
"request_source": "direct_tts",
"selected_backend": backend,
"fallback_reason": fallback_reason,
},
)
if backend == "scheduler_v1_direct":
try:
return await self._run_direct_tts_via_scheduler(normalized)
except Exception as exc:
self.api._fail_request_state(request_id, str(exc))
raise
return await self._run_direct_tts_via_legacy_backend(
normalized,
backend=backend,
fallback_reason=fallback_reason,
)
def run_direct_tts(self, req: dict) -> DirectTTSExecution:
normalized = self.api._normalize_engine_request(
req,
request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"),
normalize_streaming=True,
error_prefix="",
)
request_id = normalized.request_id
media_type = normalized.media_type
backend, fallback_reason = self.api._select_direct_backend(normalized)
if not self.api._has_active_request(request_id):
self.api._register_request_state(
request_id=request_id,
api_mode="tts",
backend=backend,
media_type=media_type,
response_streaming=bool(normalized.response_streaming),
meta=self.api._build_request_meta(normalized.to_payload()),
)
self.api._update_request_state(
request_id,
EngineStatus.VALIDATED,
{
"request_source": "direct_tts",
"selected_backend": backend,
"fallback_reason": fallback_reason,
},
)
if backend != "scheduler_v1_direct":
if normalized.response_streaming:
return DirectTTSExecution(
media_type=media_type,
streaming=True,
audio_generator=self._iter_legacy_direct_tts_bytes(
normalized,
backend=backend,
fallback_reason=fallback_reason,
),
request_id=request_id,
)
return self._run_legacy_direct_tts_blocking(
normalized,
backend=backend,
fallback_reason=fallback_reason,
)
if normalized.response_streaming:
return DirectTTSExecution(
media_type=media_type,
streaming=True,
audio_generator=self._iter_legacy_direct_tts_bytes(
normalized,
backend="legacy_direct_sync_compat",
fallback_reason="sync_direct_compat",
),
request_id=request_id,
)
return self._run_legacy_direct_tts_blocking(
normalized,
backend="legacy_direct_sync_compat",
fallback_reason="sync_direct_compat",
)

View File

@ -0,0 +1,388 @@
from __future__ import annotations
from typing import Any, Dict, List, Sequence
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem, T2SRequestState
def build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]:
text = payload.get("text")
prompt_text = payload.get("prompt_text")
return {
"text_len": 0 if text is None else len(str(text)),
"prompt_text_len": 0 if prompt_text is None else len(str(prompt_text)),
"text_lang": payload.get("text_lang"),
"prompt_lang": payload.get("prompt_lang"),
"ref_audio_path": payload.get("ref_audio_path"),
}
def sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float:
total = 0.0
for item in items:
value = item.get(key, 0.0)
if isinstance(value, (int, float)):
total += float(value)
return total
def aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]:
totals: Dict[str, float] = {}
for item in items:
for key, value in item.items():
if isinstance(value, (int, float)):
totals[key] = totals.get(key, 0.0) + float(value)
return totals
def build_direct_segment_trace(
segment_texts: Sequence[str],
prepare_profiles: Sequence[Dict[str, Any]],
worker_profiles: Sequence[Dict[str, Any]],
) -> List[Dict[str, Any]]:
results: List[Dict[str, Any]] = []
for index, segment_text in enumerate(segment_texts):
prepare_item = prepare_profiles[index] if index < len(prepare_profiles) else {}
worker_item = worker_profiles[index] if index < len(worker_profiles) else {}
prepare_profile = dict(prepare_item.get("prepare_profile", {}))
results.append(
{
"segment_index": index,
"request_id": prepare_item.get("request_id") or worker_item.get("request_id"),
"text_len": len(str(segment_text)),
"prepare_wall_ms": float(prepare_item.get("prepare_wall_ms", 0.0)),
"prepare_profile_total_ms": float(prepare_item.get("prepare_profile_total_ms", 0.0)),
"prepare_engine_gpu_queue_wait_ms": float(
dict(prepare_item.get("prepare_profile", {})).get("engine_gpu_prepare_queue_wait_ms", 0.0)
),
"engine_policy_wait_ms": float(prepare_item.get("engine_policy_wait_ms", 0.0)),
"engine_dispatch_wait_ms": float(prepare_item.get("engine_dispatch_wait_ms", 0.0)),
"decode_admission_wait_ms": float(worker_item.get("decode_admission_wait_ms", 0.0)),
"queue_wait_ms": float(worker_item.get("queue_wait_ms", 0.0)),
"prefill_ms": float(worker_item.get("prefill_ms", 0.0)),
"merge_ms": float(worker_item.get("merge_ms", 0.0)),
"decode_ms": float(worker_item.get("decode_ms", 0.0)),
"finalize_wait_ms": float(worker_item.get("finalize_wait_ms", 0.0)),
"synth_ms": float(worker_item.get("synth_ms", 0.0)),
"worker_total_ms": float(worker_item.get("worker_total_ms", 0.0)),
"decode_steps": int(worker_item.get("decode_steps", 0)),
"semantic_len": int(worker_item.get("semantic_len", 0)),
"finish_reason": worker_item.get("finish_reason"),
"norm_text": prepare_profile.get("norm_text"),
}
)
return results
def build_direct_scheduler_profile(
*,
backend: str,
request_start: float,
response_ready_at: float,
audio_bytes: int,
sample_rate: int,
segment_texts: Sequence[str],
prepare_profiles: Sequence[Dict[str, Any]],
worker_profiles: Sequence[Dict[str, Any]],
pack_ms: float,
response_overhead_ms: float,
) -> Dict[str, Any]:
segment_trace = build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles)
prepare_profile_dicts = [dict(item.get("prepare_profile", {})) for item in prepare_profiles]
request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0)
prepare_wall_ms = sum_profile_field(prepare_profiles, "prepare_wall_ms")
prepare_profile_total_ms = sum_profile_field(prepare_profiles, "prepare_profile_total_ms")
engine_policy_wait_ms = sum_profile_field(prepare_profiles, "engine_policy_wait_ms")
engine_dispatch_wait_ms = sum_profile_field(prepare_profiles, "engine_dispatch_wait_ms")
decode_admission_wait_ms = sum_profile_field(worker_profiles, "decode_admission_wait_ms")
queue_wait_ms = sum_profile_field(worker_profiles, "queue_wait_ms")
prefill_ms = sum_profile_field(worker_profiles, "prefill_ms")
merge_ms = sum_profile_field(worker_profiles, "merge_ms")
decode_ms = sum_profile_field(worker_profiles, "decode_ms")
finalize_wait_ms = sum_profile_field(worker_profiles, "finalize_wait_ms")
synth_ms = sum_profile_field(worker_profiles, "synth_ms")
worker_total_ms = sum_profile_field(worker_profiles, "worker_total_ms")
decode_steps = sum(int(item.get("decode_steps", 0)) for item in worker_profiles)
semantic_len = sum(int(item.get("semantic_len", 0)) for item in worker_profiles)
request_other_ms = max(
0.0,
request_total_ms - prepare_wall_ms - engine_policy_wait_ms - worker_total_ms - pack_ms - response_overhead_ms,
)
return {
"backend": backend,
"backend_mode": backend,
"segment_count": len(segment_texts),
"sample_rate": int(sample_rate),
"audio_bytes": int(audio_bytes),
"request_total_ms": request_total_ms,
"prepare_ms": prepare_wall_ms,
"prepare_wall_ms": prepare_wall_ms,
"prepare_profile_total_ms": prepare_profile_total_ms,
"engine_policy_wait_ms": engine_policy_wait_ms,
"engine_dispatch_wait_ms": engine_dispatch_wait_ms,
"decode_admission_wait_ms": decode_admission_wait_ms,
"queue_wait_ms": queue_wait_ms,
"prefill_ms": prefill_ms,
"merge_ms": merge_ms,
"decode_ms": decode_ms,
"finalize_wait_ms": finalize_wait_ms,
"synth_ms": synth_ms,
"pack_ms": pack_ms,
"response_overhead_ms": response_overhead_ms,
"worker_total_ms": worker_total_ms,
"request_other_ms": request_other_ms,
"decode_steps": decode_steps,
"semantic_len": semantic_len,
"prepare_segments": list(prepare_profiles),
"worker_segments": list(worker_profiles),
"segment_trace": segment_trace,
"prepare_aggregate": aggregate_numeric_dicts(prepare_profile_dicts),
}
def build_legacy_direct_profile(
*,
backend: str,
fallback_reason: str | None,
request_start: float,
finished_at: float,
sample_rate: int | None = None,
audio_bytes: int = 0,
pack_ms: float = 0.0,
chunk_count: int = 0,
stream_total_bytes: int = 0,
first_chunk_ms: float | None = None,
) -> Dict[str, Any]:
request_total_ms = max(0.0, (finished_at - request_start) * 1000.0)
legacy_infer_ms = max(0.0, request_total_ms - pack_ms)
return {
"backend": backend,
"backend_mode": backend,
"fallback_reason": fallback_reason,
"request_total_ms": request_total_ms,
"prepare_ms": 0.0,
"queue_wait_ms": 0.0,
"prefill_ms": 0.0,
"merge_ms": 0.0,
"decode_ms": 0.0,
"finalize_wait_ms": 0.0,
"synth_ms": 0.0,
"pack_ms": pack_ms,
"worker_total_ms": legacy_infer_ms,
"request_other_ms": 0.0,
"legacy_infer_ms": legacy_infer_ms,
"sample_rate": int(sample_rate) if sample_rate is not None else None,
"audio_bytes": int(audio_bytes),
"chunk_count": int(chunk_count),
"stream_total_bytes": int(stream_total_bytes),
"first_chunk_ms": None if first_chunk_ms is None else float(first_chunk_ms),
}
def build_scheduler_submit_profile(
*,
backend: str,
request_start: float,
response_ready_at: float,
audio_bytes: int,
sample_rate: int,
prepare_spec_build_ms: float,
prepare_wall_ms: float,
prepare_executor_queue_ms: float,
prepare_executor_run_ms: float,
prepare_profile_total_ms: float,
prepare_profile_wall_ms: float,
prepare_other_ms: float,
engine_policy_wait_ms: float,
api_after_prepare_ms: float,
api_wait_result_ms: float,
pack_ms: float,
response_overhead_ms: float,
worker_profile: Dict[str, Any],
) -> Dict[str, Any]:
worker_total_ms = float(worker_profile.get("worker_total_ms", 0.0))
request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0)
request_other_ms = max(
0.0,
request_total_ms
- prepare_wall_ms
- engine_policy_wait_ms
- api_after_prepare_ms
- worker_total_ms
- api_wait_result_ms
- pack_ms,
)
result = {
"backend": backend,
"backend_mode": backend,
"audio_bytes": int(audio_bytes),
"sample_rate": int(sample_rate),
"prepare_spec_build_ms": prepare_spec_build_ms,
"prepare_ms": prepare_wall_ms,
"prepare_wall_ms": prepare_wall_ms,
"prepare_executor_queue_ms": prepare_executor_queue_ms,
"prepare_executor_run_ms": prepare_executor_run_ms,
"prepare_profile_total_ms": prepare_profile_total_ms,
"prepare_profile_wall_ms": prepare_profile_wall_ms,
"prepare_other_ms": prepare_other_ms,
"engine_policy_wait_ms": float(engine_policy_wait_ms),
"api_after_prepare_ms": api_after_prepare_ms,
"api_wait_result_ms": api_wait_result_ms,
"pack_ms": pack_ms,
"response_overhead_ms": response_overhead_ms,
"request_total_ms": request_total_ms,
"request_other_ms": request_other_ms,
}
result.update({key: value for key, value in worker_profile.items()})
return result
def format_ms_header(value: Any) -> str:
return f"{float(value):.3f}"
def build_scheduler_submit_headers(
*,
request_id: str,
media_type: str,
sample_rate: int,
profile: Dict[str, Any],
) -> Dict[str, str]:
prepare_profile = dict(profile.get("prepare_profile", {}))
headers = {
"X-Request-Id": request_id,
"X-Semantic-Len": str(int(profile.get("semantic_len", 0))),
"X-Finish-Reason": str(profile.get("finish_reason", "unknown")),
"X-Queue-Wait-Ms": format_ms_header(profile.get("queue_wait_ms", 0.0)),
"X-Decode-Admission-Wait-Ms": format_ms_header(profile.get("decode_admission_wait_ms", 0.0)),
"X-Engine-Policy-Wait-Ms": format_ms_header(profile.get("engine_policy_wait_ms", 0.0)),
"X-Engine-Dispatch-Wait-Ms": format_ms_header(profile.get("engine_dispatch_wait_ms", 0.0)),
"X-Prepare-Ms": format_ms_header(profile.get("prepare_wall_ms", 0.0)),
"X-Prepare-Wall-Ms": format_ms_header(profile.get("prepare_wall_ms", 0.0)),
"X-Prepare-Spec-Build-Ms": format_ms_header(profile.get("prepare_spec_build_ms", 0.0)),
"X-Prepare-Executor-Queue-Ms": format_ms_header(profile.get("prepare_executor_queue_ms", 0.0)),
"X-Prepare-Admission-Wait-Ms": format_ms_header(prepare_profile.get("prepare_admission_wait_ms", 0.0)),
"X-Prepare-Executor-Run-Ms": format_ms_header(profile.get("prepare_executor_run_ms", 0.0)),
"X-Prepare-Profile-Total-Ms": format_ms_header(profile.get("prepare_profile_total_ms", 0.0)),
"X-Prepare-Profile-Wall-Ms": format_ms_header(profile.get("prepare_profile_wall_ms", 0.0)),
"X-Prepare-Other-Ms": format_ms_header(profile.get("prepare_other_ms", 0.0)),
"X-Api-After-Prepare-Ms": format_ms_header(profile.get("api_after_prepare_ms", 0.0)),
"X-Prefill-Ms": format_ms_header(profile.get("prefill_ms", 0.0)),
"X-Merge-Ms": format_ms_header(profile.get("merge_ms", 0.0)),
"X-Decode-Ms": format_ms_header(profile.get("decode_ms", 0.0)),
"X-Finalize-Wait-Ms": format_ms_header(profile.get("finalize_wait_ms", 0.0)),
"X-Synth-Ms": format_ms_header(profile.get("synth_ms", 0.0)),
"X-Worker-Residual-Ms": format_ms_header(profile.get("worker_residual_ms", 0.0)),
"X-Worker-Other-Ms": format_ms_header(profile.get("worker_other_ms", 0.0)),
"X-Pack-Ms": format_ms_header(profile.get("pack_ms", 0.0)),
"X-Worker-Total-Ms": format_ms_header(profile.get("worker_total_ms", 0.0)),
"X-Api-Wait-Result-Ms": format_ms_header(profile.get("api_wait_result_ms", 0.0)),
"X-Decode-Steps": str(int(profile.get("decode_steps", 0))),
"X-Sample-Rate": str(int(sample_rate)),
"X-Response-Overhead-Ms": format_ms_header(profile.get("response_overhead_ms", 0.0)),
"X-Request-Other-Ms": format_ms_header(profile.get("request_other_ms", 0.0)),
"X-Request-Total-Ms": format_ms_header(profile.get("request_total_ms", 0.0)),
}
headers.update(
{
"X-Prepare-Prompt-Text-Ms": format_ms_header(prepare_profile.get("prompt_text_features_ms", 0.0)),
"X-Prepare-Target-Text-Ms": format_ms_header(prepare_profile.get("text_features_ms", 0.0)),
"X-Prepare-Prompt-Text-CPU-Preprocess-Ms": format_ms_header(prepare_profile.get("prompt_text_cpu_preprocess_ms", 0.0)),
"X-Prepare-Target-Text-CPU-Preprocess-Ms": format_ms_header(prepare_profile.get("text_cpu_preprocess_ms", 0.0)),
"X-Prepare-Prompt-Text-CPU-Queue-Ms": format_ms_header(prepare_profile.get("prompt_text_cpu_queue_ms", 0.0)),
"X-Prepare-Target-Text-CPU-Queue-Ms": format_ms_header(prepare_profile.get("text_cpu_queue_ms", 0.0)),
"X-Prepare-Prompt-Text-Feature-Queue-Ms": format_ms_header(prepare_profile.get("prompt_text_feature_queue_ms", 0.0)),
"X-Prepare-Target-Text-Feature-Queue-Ms": format_ms_header(prepare_profile.get("text_feature_queue_ms", 0.0)),
"X-Prepare-Prompt-Bert-Wait-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_wait_ms", 0.0)),
"X-Prepare-Target-Bert-Wait-Ms": format_ms_header(prepare_profile.get("text_bert_wait_ms", 0.0)),
"X-Prepare-Prompt-Bert-Admission-Wait-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_admission_wait_ms", 0.0)),
"X-Prepare-Target-Bert-Admission-Wait-Ms": format_ms_header(prepare_profile.get("text_bert_admission_wait_ms", 0.0)),
"X-Prepare-Prompt-Bert-Queue-Wait-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_queue_wait_ms", 0.0)),
"X-Prepare-Target-Bert-Queue-Wait-Ms": format_ms_header(prepare_profile.get("text_bert_queue_wait_ms", 0.0)),
"X-Prepare-Prompt-Bert-Batch-Collect-Wait-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_batch_collect_wait_ms", 0.0)),
"X-Prepare-Target-Bert-Batch-Collect-Wait-Ms": format_ms_header(prepare_profile.get("text_bert_batch_collect_wait_ms", 0.0)),
"X-Prepare-Prompt-Bert-Forward-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_forward_ms", 0.0)),
"X-Prepare-Target-Bert-Forward-Ms": format_ms_header(prepare_profile.get("text_bert_forward_ms", 0.0)),
"X-Prepare-Prompt-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_enqueue_peak", 0.0))),
"X-Prepare-Target-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_enqueue_peak", 0.0))),
"X-Prepare-Prompt-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_collect_peak", 0.0))),
"X-Prepare-Target-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_collect_peak", 0.0))),
"X-Prepare-Prompt-Bert-High-Pressure-Peak": str(int(prepare_profile.get("prompt_text_bert_high_pressure_mode_peak", 0.0))),
"X-Prepare-Target-Bert-High-Pressure-Peak": str(int(prepare_profile.get("text_bert_high_pressure_mode_peak", 0.0))),
"X-Prepare-Prompt-Bert-Batch-Window-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_batch_window_ms", 0.0)),
"X-Prepare-Target-Bert-Batch-Window-Ms": format_ms_header(prepare_profile.get("text_bert_batch_window_ms", 0.0)),
"X-Prepare-Text-Pair-Wall-Ms": format_ms_header(prepare_profile.get("text_feature_pair_ms", 0.0)),
"X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))),
"X-Prepare-Engine-GPU-Queue-Wait-Ms": format_ms_header(prepare_profile.get("engine_gpu_prepare_queue_wait_ms", 0.0)),
"X-Prepare-Engine-GPU-Batch-Size": str(int(prepare_profile.get("engine_gpu_prepare_batch_size", 0.0))),
"X-Prepare-Audio-Load-Ms": format_ms_header(prepare_profile.get("audio_load_ms", 0.0)),
"X-Prepare-Audio-Stage-Wait-Ms": format_ms_header(prepare_profile.get("audio_stage_wait_ms", 0.0)),
"X-Prepare-Prompt-Semantic-Ms": format_ms_header(prepare_profile.get("prompt_semantic_ms", 0.0)),
"X-Prepare-Prompt-Semantic-Wait-Ms": format_ms_header(prepare_profile.get("prompt_semantic_wait_ms", 0.0)),
"X-Prepare-Prompt-Semantic-CPU-Ms": format_ms_header(prepare_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)),
"X-Prepare-Prompt-Semantic-Forward-Ms": format_ms_header(prepare_profile.get("prompt_semantic_forward_ms", 0.0)),
"X-Prepare-Ref-Spec-Ms": format_ms_header(prepare_profile.get("ref_spec_ms", 0.0)),
"X-Prepare-Ref-Spec-Wait-Ms": format_ms_header(prepare_profile.get("ref_spec_wait_ms", 0.0)),
"X-Prepare-Ref-Bundle-Ms": format_ms_header(prepare_profile.get("ref_audio_bundle_ms", 0.0)),
"X-Prepare-Tensorize-Ms": format_ms_header(prepare_profile.get("tensorize_ms", 0.0)),
"X-Prepare-Inflight-On-Enter": str(int(prepare_profile.get("worker_prepare_inflight_on_enter", 0.0))),
"X-Prepare-Inflight-Peak": str(int(prepare_profile.get("worker_prepare_peak_inflight", 0.0))),
}
)
return headers
def build_scheduler_debug_request_profile(
*,
state: T2SRequestState,
item: T2SFinishedItem,
batch_request_count: int,
prepare_batch_wall_ms: float,
decode_batch_wall_ms: float,
batch_request_total_ms: float,
) -> Dict[str, Any]:
prepare_profile = dict(state.prepare_profile)
prepare_wall_ms = float(prepare_profile.get("wall_total_ms", 0.0))
return {
"backend": "scheduler_debug",
"backend_mode": "scheduler_debug",
"batch_request_count": int(batch_request_count),
"batch_prepare_wall_ms": float(prepare_batch_wall_ms),
"batch_decode_wall_ms": float(decode_batch_wall_ms),
"batch_request_total_ms": float(batch_request_total_ms),
"prepare_ms": prepare_wall_ms,
"prepare_wall_ms": prepare_wall_ms,
"prepare_profile_total_ms": float(prepare_profile.get("wall_total_ms", prepare_wall_ms)),
"prepare_profile": prepare_profile,
"decode_steps": int(item.finish_idx),
"finish_idx": int(item.finish_idx),
"semantic_len": int(item.semantic_tokens.shape[0]),
"finish_reason": item.finish_reason,
"norm_text": state.norm_text,
"norm_prompt_text": state.norm_prompt_text,
}
def build_scheduler_debug_batch_profile(
*,
request_count: int,
max_steps: int,
prepare_batch_wall_ms: float,
decode_batch_wall_ms: float,
request_total_ms: float,
finished_items: Sequence[T2SFinishedItem],
) -> Dict[str, Any]:
finish_reason_counts: Dict[str, int] = {}
total_semantic_len = 0
for item in finished_items:
finish_reason_counts[item.finish_reason] = finish_reason_counts.get(item.finish_reason, 0) + 1
total_semantic_len += int(item.semantic_tokens.shape[0])
return {
"request_count": int(request_count),
"max_steps": int(max_steps),
"prepare_batch_wall_ms": float(prepare_batch_wall_ms),
"decode_batch_wall_ms": float(decode_batch_wall_ms),
"request_total_ms": float(request_total_ms),
"total_semantic_len": int(total_semantic_len),
"finish_reason_counts": finish_reason_counts,
}

View File

@ -0,0 +1,189 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional, Sequence, Tuple
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import NormalizedEngineRequest, ReferenceRegistry
def normalize_lang(value: str | None) -> str | None:
if value in [None, ""]:
return value
return str(value).lower()
def apply_default_reference(reference_registry: ReferenceRegistry, req: dict) -> dict:
normalized = dict(req)
default_ref = reference_registry.get_default()
if normalized.get("ref_audio_path") in [None, ""] and default_ref.ref_audio_path not in [None, ""]:
normalized["ref_audio_path"] = default_ref.ref_audio_path
if "text_lang" in normalized:
normalized["text_lang"] = normalize_lang(normalized.get("text_lang"))
if "prompt_lang" in normalized:
normalized["prompt_lang"] = normalize_lang(normalized.get("prompt_lang"))
return normalized
def check_params(tts: TTS, cut_method_names: Sequence[str], req: dict) -> Optional[str]:
text = req.get("text", "")
text_lang = req.get("text_lang", "")
ref_audio_path = req.get("ref_audio_path", "")
media_type = req.get("media_type", "wav")
prompt_lang = req.get("prompt_lang", "")
text_split_method = req.get("text_split_method", "cut5")
if ref_audio_path in [None, ""]:
return "ref_audio_path is required"
if text in [None, ""]:
return "text is required"
if text_lang in [None, ""]:
return "text_lang is required"
if text_lang.lower() not in tts.configs.languages:
return f"text_lang: {text_lang} is not supported in version {tts.configs.version}"
if prompt_lang in [None, ""]:
return "prompt_lang is required"
if prompt_lang.lower() not in tts.configs.languages:
return f"prompt_lang: {prompt_lang} is not supported in version {tts.configs.version}"
if media_type not in ["wav", "raw", "ogg", "aac"]:
return f"media_type: {media_type} is not supported"
if text_split_method not in cut_method_names:
return f"text_split_method:{text_split_method} is not supported"
return None
def base_request_defaults() -> Dict[str, Any]:
return {
"request_id": None,
"text": None,
"text_lang": None,
"ref_audio_path": None,
"aux_ref_audio_paths": None,
"prompt_text": "",
"prompt_lang": None,
"top_k": 15,
"top_p": 1.0,
"temperature": 1.0,
"text_split_method": "cut5",
"batch_size": 1,
"batch_threshold": 0.75,
"speed_factor": 1.0,
"split_bucket": False,
"fragment_interval": 0.3,
"seed": -1,
"media_type": "wav",
"streaming_mode": False,
"return_fragment": False,
"fixed_length_chunk": False,
"response_streaming": False,
"parallel_infer": False,
"repetition_penalty": 1.35,
"sample_steps": 32,
"super_sampling": False,
"overlap_length": 2,
"min_chunk_length": 16,
"early_stop_num": -1,
"ready_step": 0,
"timeout_sec": None,
}
def normalize_streaming_mode(req: dict) -> dict:
normalized = dict(req)
streaming_mode = normalized.get("streaming_mode", False)
return_fragment = normalized.get("return_fragment", False)
if streaming_mode is False:
normalized["streaming_mode"] = False
normalized["return_fragment"] = False
normalized["fixed_length_chunk"] = False
elif streaming_mode == 0:
normalized["streaming_mode"] = False
normalized["return_fragment"] = False
normalized["fixed_length_chunk"] = False
elif streaming_mode == 1 or streaming_mode is True:
normalized["streaming_mode"] = False
normalized["return_fragment"] = True
normalized["fixed_length_chunk"] = False
elif streaming_mode == 2:
normalized["streaming_mode"] = True
normalized["return_fragment"] = False
normalized["fixed_length_chunk"] = False
elif streaming_mode == 3:
normalized["streaming_mode"] = True
normalized["return_fragment"] = False
normalized["fixed_length_chunk"] = True
else:
raise ValueError("the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)")
normalized["response_streaming"] = bool(normalized["streaming_mode"] or normalized["return_fragment"] or return_fragment)
return normalized
def is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool:
return aux_ref_audio_paths not in [None, [], ()]
def select_direct_backend(normalized: NormalizedEngineRequest) -> Tuple[str, str | None]:
return "scheduler_v1_direct", None
def normalize_engine_request(
*,
tts: TTS,
cut_method_names: Sequence[str],
reference_registry: ReferenceRegistry,
payload: dict | NormalizedEngineRequest,
request_id: str | None = None,
normalize_streaming: bool = False,
error_prefix: str = "request 参数非法: ",
) -> NormalizedEngineRequest:
if isinstance(payload, NormalizedEngineRequest):
normalized_payload = payload.to_payload()
else:
normalized_payload = base_request_defaults()
normalized_payload.update(dict(payload))
if request_id not in [None, ""]:
normalized_payload["request_id"] = str(request_id)
elif normalized_payload.get("request_id") in [None, ""]:
raise ValueError("request_id is required after normalization")
normalized_payload = apply_default_reference(reference_registry, normalized_payload)
if normalize_streaming:
normalized_payload = normalize_streaming_mode(normalized_payload)
error = check_params(tts, cut_method_names, normalized_payload)
if error is not None:
raise ValueError(f"{error_prefix}{error}")
timeout_sec = normalized_payload.get("timeout_sec")
parsed_timeout = None if timeout_sec in [None, ""] else float(timeout_sec)
aux_ref_audio_paths = normalized_payload.get("aux_ref_audio_paths")
normalized_aux_ref_audio_paths = None if aux_ref_audio_paths in [None, "", []] else [str(item) for item in aux_ref_audio_paths]
return NormalizedEngineRequest(
request_id=str(normalized_payload["request_id"]),
text=str(normalized_payload["text"]),
text_lang=str(normalized_payload["text_lang"]),
ref_audio_path=str(normalized_payload["ref_audio_path"]),
prompt_lang=str(normalized_payload["prompt_lang"]),
prompt_text="" if normalized_payload.get("prompt_text") is None else str(normalized_payload.get("prompt_text")),
aux_ref_audio_paths=normalized_aux_ref_audio_paths,
top_k=int(normalized_payload["top_k"]),
top_p=float(normalized_payload["top_p"]),
temperature=float(normalized_payload["temperature"]),
repetition_penalty=float(normalized_payload["repetition_penalty"]),
early_stop_num=int(normalized_payload.get("early_stop_num", -1)),
ready_step=int(normalized_payload.get("ready_step", 0)),
text_split_method=str(normalized_payload["text_split_method"]),
batch_size=int(normalized_payload["batch_size"]),
batch_threshold=float(normalized_payload["batch_threshold"]),
split_bucket=bool(normalized_payload["split_bucket"]),
speed_factor=float(normalized_payload["speed_factor"]),
fragment_interval=float(normalized_payload["fragment_interval"]),
seed=int(normalized_payload["seed"]),
media_type=str(normalized_payload["media_type"]),
streaming_mode=normalized_payload["streaming_mode"],
return_fragment=bool(normalized_payload.get("return_fragment", False)),
fixed_length_chunk=bool(normalized_payload.get("fixed_length_chunk", False)),
response_streaming=bool(normalized_payload.get("response_streaming", False)),
parallel_infer=bool(normalized_payload["parallel_infer"]),
sample_steps=int(normalized_payload["sample_steps"]),
super_sampling=bool(normalized_payload["super_sampling"]),
overlap_length=int(normalized_payload["overlap_length"]),
min_chunk_length=int(normalized_payload["min_chunk_length"]),
timeout_sec=parsed_timeout,
)

View File

@ -0,0 +1,340 @@
from __future__ import annotations
import asyncio
import time
import uuid
from io import BytesIO
from typing import Any, Dict, List
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SFinishedItem, T2SRequestState
from GPT_SoVITS.TTS_infer_pack.unified_engine_audio import pack_audio, set_scheduler_seed
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, NormalizedEngineRequest, SchedulerDebugExecution, SchedulerSubmitExecution
class EngineApiSchedulerFlow:
def __init__(self, api: Any) -> None:
self.api = api
def _build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]:
specs: List[SchedulerRequestSpec] = []
for index, payload in enumerate(request_items):
normalized = self.api._normalize_engine_request(
payload,
request_id=str(payload.get("request_id") or f"req_{index:03d}"),
error_prefix=f"request[{index}] 参数非法: ",
)
specs.append(normalized.to_scheduler_spec())
return specs
def _build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec:
normalized = self.api._normalize_engine_request(
payload,
request_id=(
payload.request_id
if isinstance(payload, NormalizedEngineRequest)
else str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}")
),
)
return normalized.to_scheduler_spec()
@staticmethod
def _summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]:
return [
{
"request_id": state.request_id,
"ready_step": int(state.ready_step),
"ref_audio_path": str(state.ref_audio_path),
"prompt_semantic_len": int(state.prompt_semantic.shape[0]),
"all_phone_len": int(state.all_phones.shape[0]),
"bert_len": int(state.all_bert_features.shape[-1]),
"norm_text": state.norm_text,
}
for state in states
]
@staticmethod
def _summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]:
return [
{
"request_id": item.request_id,
"semantic_len": int(item.semantic_tokens.shape[0]),
"finish_idx": int(item.finish_idx),
"finish_reason": item.finish_reason,
}
for item in items
]
async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution:
request_start = time.perf_counter()
set_scheduler_seed(seed)
normalized_requests: List[NormalizedEngineRequest] = []
for index, payload in enumerate(request_items):
normalized_requests.append(
self.api._normalize_engine_request(
payload,
request_id=str(payload.get("request_id") or f"req_{index:03d}"),
error_prefix=f"request[{index}] 参数非法: ",
)
)
specs = [normalized.to_scheduler_spec() for normalized in normalized_requests]
request_ids = [normalized.request_id for normalized in normalized_requests]
for normalized, spec in zip(normalized_requests, specs):
self.api._register_request_state(
request_id=normalized.request_id,
api_mode="scheduler_debug",
backend="scheduler_debug",
media_type=normalized.media_type,
response_streaming=False,
meta=self.api._build_request_meta(normalized.to_payload()),
)
self.api._update_request_state(normalized.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_debug"})
self.api._update_request_state(normalized.request_id, EngineStatus.CPU_PREPARING, None)
prepare_started_at = time.perf_counter()
original_worker_max_steps = int(self.api.scheduler_worker.max_steps)
original_decode_max_steps = int(self.api.scheduler_worker.decode_executor.max_steps)
try:
self.api.scheduler_worker.max_steps = int(max_steps)
self.api.scheduler_worker.decode_executor.max_steps = int(max_steps)
prepared_payloads = await asyncio.gather(
*[
self.api._prepare_state_via_engine_gpu_queue(
spec=spec,
prepare_submit_at=time.perf_counter(),
engine_request_id=normalized.request_id,
)
for normalized, spec in zip(normalized_requests, specs)
]
)
except Exception as exc:
for request_id in request_ids:
self.api._fail_request_state(request_id, str(exc))
raise
finally:
self.api.scheduler_worker.max_steps = int(original_worker_max_steps)
self.api.scheduler_worker.decode_executor.max_steps = int(original_decode_max_steps)
prepare_finished_at = time.perf_counter()
prepare_batch_wall_ms = max(0.0, (prepare_finished_at - prepare_started_at) * 1000.0)
states = [payload[0] for payload in prepared_payloads]
for state in states:
self.api._update_request_state(
state.request_id,
EngineStatus.READY_FOR_PREFILL,
{
"prepare_profile": dict(state.prepare_profile),
"norm_text": state.norm_text,
"norm_prompt_text": state.norm_prompt_text,
},
)
decode_started_at = time.perf_counter()
try:
loop = asyncio.get_running_loop()
done_futures: List[asyncio.Future] = []
for normalized, state in zip(normalized_requests, states):
done_future = loop.create_future()
done_futures.append(done_future)
await self.api._enqueue_prepared_state_for_dispatch(
state=state,
speed_factor=float(normalized.speed_factor),
sample_steps=int(normalized.sample_steps),
media_type=normalized.media_type,
super_sampling=bool(normalized.super_sampling),
prepare_wall_ms=float(state.prepare_profile.get("wall_total_ms", 0.0)),
prepare_profile_total_ms=float(state.prepare_profile.get("wall_total_ms", 0.0)),
done_loop=loop,
done_future=done_future,
engine_request_id=normalized.request_id,
timeout_sec=normalized.timeout_sec,
)
timeout_candidates = [float(item.timeout_sec) for item in normalized_requests if item.timeout_sec not in [None, ""]]
timeout_sec = max(timeout_candidates) if timeout_candidates else 60.0
jobs = list(await asyncio.wait_for(asyncio.gather(*done_futures), timeout=float(timeout_sec)))
except Exception as exc:
for request_id in request_ids:
self.api._fail_request_state(request_id, str(exc))
raise
decode_finished_at = time.perf_counter()
decode_batch_wall_ms = max(0.0, (decode_finished_at - decode_started_at) * 1000.0)
request_total_ms = max(0.0, (decode_finished_at - request_start) * 1000.0)
request_profiles: List[Dict[str, Any]] = []
finished: List[Dict[str, Any]] = []
finish_reason_counts: Dict[str, int] = {}
total_semantic_len = 0
for state, job in zip(states, jobs):
if job.error is not None:
self.api._fail_request_state(state.request_id, str(job.error))
raise RuntimeError(str(job.error))
if job.result is None:
self.api._fail_request_state(state.request_id, "scheduler_debug finished without result")
raise RuntimeError(f"{state.request_id} finished without result")
job_result = dict(job.result)
request_profile = {
**job_result,
"backend": "scheduler_debug",
"backend_mode": "scheduler_debug",
"batch_request_count": int(len(states)),
"batch_prepare_wall_ms": float(prepare_batch_wall_ms),
"batch_decode_wall_ms": float(decode_batch_wall_ms),
"batch_request_total_ms": float(request_total_ms),
"prepare_ms": float(state.prepare_profile.get("wall_total_ms", 0.0)),
"prepare_wall_ms": float(state.prepare_profile.get("wall_total_ms", 0.0)),
"prepare_profile_total_ms": float(state.prepare_profile.get("wall_total_ms", 0.0)),
"prepare_profile": dict(state.prepare_profile),
"norm_text": state.norm_text,
"norm_prompt_text": state.norm_prompt_text,
}
request_profiles.append({"request_id": state.request_id, "profile": dict(request_profile)})
self.api._merge_request_state_profile(state.request_id, request_profile)
semantic_len = int(job_result.get("semantic_len", 0))
finish_reason = str(job_result.get("finish_reason", "unknown"))
finished.append(
{
"request_id": state.request_id,
"semantic_len": semantic_len,
"finish_idx": int(job_result.get("finish_idx", job_result.get("decode_steps", 0))),
"finish_reason": finish_reason,
}
)
finish_reason_counts[finish_reason] = finish_reason_counts.get(finish_reason, 0) + 1
total_semantic_len += semantic_len
return SchedulerDebugExecution(
payload={
"message": "success",
"request_count": len(states),
"max_steps": int(max_steps),
"batch_profile": {
"request_count": int(len(states)),
"max_steps": int(max_steps),
"prepare_batch_wall_ms": float(prepare_batch_wall_ms),
"decode_batch_wall_ms": float(decode_batch_wall_ms),
"request_total_ms": float(request_total_ms),
"total_semantic_len": int(total_semantic_len),
"finish_reason_counts": finish_reason_counts,
},
"requests": self._summarize_scheduler_states(states),
"finished": finished,
"request_profiles": request_profiles,
"request_traces": self.api._collect_request_summaries(request_ids),
}
)
async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution:
request_start = time.perf_counter()
prepare_start = request_start
normalized = self.api._normalize_engine_request(
payload,
request_id=str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}"),
)
spec = self._build_scheduler_submit_spec(normalized)
deadline_ts = None
timeout_sec = normalized.timeout_sec
if timeout_sec is not None:
try:
deadline_ts = request_start + float(timeout_sec)
except Exception:
deadline_ts = None
self.api._register_request_state(
request_id=spec.request_id,
api_mode="scheduler_submit",
backend="scheduler_v1",
media_type=normalized.media_type,
response_streaming=False,
deadline_ts=deadline_ts,
meta=self.api._build_request_meta(normalized.to_payload()),
)
self.api._update_request_state(spec.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_submit"})
spec_ready_at = time.perf_counter()
prepare_spec_build_ms = max(0.0, (spec_ready_at - prepare_start) * 1000.0)
self.api._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, {"prepare_spec_build_ms": prepare_spec_build_ms})
try:
state, prepare_exec_started_at, prepare_exec_finished_at = await self.api._prepare_state_via_engine_gpu_queue(
spec=spec,
prepare_submit_at=spec_ready_at,
engine_request_id=spec.request_id,
)
except Exception as exc:
self.api._fail_request_state(spec.request_id, str(exc))
raise
prepare_wall_ms = max(0.0, (prepare_exec_finished_at - spec_ready_at) * 1000.0)
prepare_executor_queue_ms = max(0.0, (prepare_exec_started_at - spec_ready_at) * 1000.0)
prepare_executor_run_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0)
prepare_profile = dict(state.prepare_profile)
prepare_profile_total_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms))
prepare_profile_wall_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms))
prepare_other_ms = max(0.0, prepare_wall_ms - prepare_spec_build_ms - prepare_executor_queue_ms - prepare_executor_run_ms)
self.api._update_request_state(
spec.request_id,
EngineStatus.READY_FOR_PREFILL,
{
"prepare_wall_ms": prepare_wall_ms,
"prepare_profile_total_ms": prepare_profile_total_ms,
"prepare_profile": prepare_profile,
},
)
api_after_prepare_start = time.perf_counter()
loop = asyncio.get_running_loop()
done_future = loop.create_future()
await self.api._enqueue_prepared_state_for_dispatch(
state=state,
speed_factor=float(normalized.speed_factor),
sample_steps=int(normalized.sample_steps),
media_type=normalized.media_type,
super_sampling=bool(normalized.super_sampling),
prepare_wall_ms=prepare_wall_ms,
prepare_profile_total_ms=prepare_profile_total_ms,
done_loop=loop,
done_future=done_future,
engine_request_id=spec.request_id,
timeout_sec=normalized.timeout_sec,
)
api_after_prepare_ms = max(0.0, (time.perf_counter() - api_after_prepare_start) * 1000.0)
try:
job = await asyncio.wait_for(done_future, timeout=float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0))
except Exception as exc:
self.api._fail_request_state(spec.request_id, str(exc))
raise
wait_return_at = time.perf_counter()
if job.error is not None:
raise RuntimeError(job.error)
if job.audio_data is None or job.sample_rate is None or job.result is None:
self.api._fail_request_state(spec.request_id, f"{job.request_id} finished without audio result")
raise RuntimeError(f"{job.request_id} finished without audio result")
pack_start = time.perf_counter()
audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue()
pack_end = time.perf_counter()
pack_ms = (pack_end - pack_start) * 1000.0
api_wait_result_ms = 0.0
if job.result_ready_time is not None:
api_wait_result_ms = max(0.0, (wait_return_at - job.result_ready_time) * 1000.0)
response_ready_at = time.perf_counter()
response_overhead_ms = max(0.0, (response_ready_at - pack_end) * 1000.0)
submit_profile = self.api._build_scheduler_submit_profile(
backend="scheduler_v1",
request_start=request_start,
response_ready_at=response_ready_at,
audio_bytes=len(audio_data),
sample_rate=int(job.sample_rate),
prepare_spec_build_ms=prepare_spec_build_ms,
prepare_wall_ms=prepare_wall_ms,
prepare_executor_queue_ms=prepare_executor_queue_ms,
prepare_executor_run_ms=prepare_executor_run_ms,
prepare_profile_total_ms=prepare_profile_total_ms,
prepare_profile_wall_ms=prepare_profile_wall_ms,
prepare_other_ms=prepare_other_ms,
engine_policy_wait_ms=float(job.result.get("engine_policy_wait_ms", 0.0)),
api_after_prepare_ms=api_after_prepare_ms,
api_wait_result_ms=api_wait_result_ms,
pack_ms=pack_ms,
response_overhead_ms=response_overhead_ms,
worker_profile=dict(job.result or {}),
)
headers = self.api._build_scheduler_submit_headers(
request_id=job.request_id,
media_type=job.media_type,
sample_rate=int(job.sample_rate),
profile=submit_profile,
)
self.api._merge_request_state_profile(
spec.request_id,
dict(submit_profile, response_headers_emitted=True),
)
return SchedulerSubmitExecution(audio_bytes=audio_data, media_type=str(job.media_type), headers=headers)

View File

@ -0,0 +1,106 @@
from __future__ import annotations
import subprocess
import threading
import wave
from io import BytesIO
import numpy as np
import soundfile as sf
import torch
def set_scheduler_seed(seed: int):
if seed in ["", None]:
return
seed = int(seed)
if seed < 0:
return
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
def handle_pack_ogg():
with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
audio_file.write(data)
stack_size = 4096 * 4096
try:
threading.stack_size(stack_size)
pack_ogg_thread = threading.Thread(target=handle_pack_ogg)
pack_ogg_thread.start()
pack_ogg_thread.join()
except (RuntimeError, ValueError):
handle_pack_ogg()
return io_buffer
def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int):
io_buffer.write(data.tobytes())
return io_buffer
def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int):
io_buffer = BytesIO()
sf.write(io_buffer, data, rate, format="wav")
return io_buffer
def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
process = subprocess.Popen(
[
"ffmpeg",
"-f",
"s16le",
"-ar",
str(rate),
"-ac",
"1",
"-i",
"pipe:0",
"-c:a",
"aac",
"-b:a",
"192k",
"-vn",
"-f",
"adts",
"pipe:1",
],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
out, _ = process.communicate(input=data.tobytes())
io_buffer.write(out)
return io_buffer
def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str):
if media_type == "ogg":
io_buffer = pack_ogg(io_buffer, data, rate)
elif media_type == "aac":
io_buffer = pack_aac(io_buffer, data, rate)
elif media_type == "wav":
io_buffer = pack_wav(io_buffer, data, rate)
else:
io_buffer = pack_raw(io_buffer, data, rate)
io_buffer.seek(0)
return io_buffer
def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
wav_buf = BytesIO()
with wave.open(wav_buf, "wb") as vfout:
vfout.setnchannels(channels)
vfout.setsampwidth(sample_width)
vfout.setframerate(sample_rate)
vfout.writeframes(frame_input)
wav_buf.seek(0)
return wav_buf.read()

View File

@ -0,0 +1,21 @@
from __future__ import annotations
from typing import Any
from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge_registry import EngineRegistryBridgeFacade
from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge_runtime import EngineRuntimeBridgeFacade
from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge_stage import EngineStageBridgeFacade
class EngineBridgeFacade:
def __init__(self, owner: Any) -> None:
self.owner = owner
self.registry_bridge = EngineRegistryBridgeFacade(owner)
self.stage_bridge = EngineStageBridgeFacade(owner)
self.runtime_bridge = EngineRuntimeBridgeFacade(owner)
def __getattr__(self, name: str) -> Any:
for component in (self.registry_bridge, self.stage_bridge, self.runtime_bridge):
if hasattr(component, name):
return getattr(component, name)
raise AttributeError(name)

View File

@ -0,0 +1,202 @@
from __future__ import annotations
import asyncio
from typing import Any, Dict, List, Optional
import numpy as np
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState
from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge import EngineBridgeFacade
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDispatchTask, EngineRequestState, SchedulerFinalizeTask, SchedulerPendingJob
class EngineBridgeDelegates:
def _register_request_state(
self,
request_id: str,
api_mode: str,
backend: str,
media_type: str,
response_streaming: bool,
deadline_ts: float | None = None,
meta: Optional[Dict[str, Any]] = None,
) -> EngineRequestState:
return self.bridge_facade._register_request_state(
request_id=request_id,
api_mode=api_mode,
backend=backend,
media_type=media_type,
response_streaming=response_streaming,
deadline_ts=deadline_ts,
meta=meta,
)
def _update_request_state(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None:
self.bridge_facade._update_request_state(request_id, status, extra)
def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
self.bridge_facade._merge_request_state_profile(request_id, extra)
def _snapshot_engine_prepare_state(self) -> Dict[str, Any]:
return self.bridge_facade._snapshot_engine_prepare_state()
def _snapshot_engine_finalize_state(self) -> Dict[str, Any]:
return self.bridge_facade._snapshot_engine_finalize_state()
def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]:
return self.bridge_facade._snapshot_engine_dispatch_state()
def _register_engine_job(self, job: SchedulerPendingJob) -> None:
self.bridge_facade._register_engine_job(job)
def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
return self.bridge_facade._get_engine_job(request_id)
def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
return self.bridge_facade._pop_engine_job(request_id)
def _snapshot_engine_job_registry(self) -> Dict[str, Any]:
return self.bridge_facade._snapshot_engine_job_registry()
def _is_engine_drained(self) -> bool:
return self.bridge_facade._is_engine_drained()
def _record_engine_job_done(self, request_id: str) -> None:
self.bridge_facade._record_engine_job_done(request_id)
def _complete_engine_job(
self,
job: SchedulerPendingJob,
item: T2SFinishedItem,
*,
sample_rate: int,
audio_data: np.ndarray,
) -> None:
self.bridge_facade._complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data)
def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None:
self.bridge_facade._fail_engine_jobs(request_ids, error)
def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None:
self.bridge_facade._add_engine_prefill_time(jobs, elapsed_s)
def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None:
self.bridge_facade._add_engine_merge_time(request_ids, elapsed_s)
def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None:
self.bridge_facade._add_engine_decode_time(request_ids, elapsed_s)
def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None:
self.bridge_facade._enqueue_engine_finished_items(items)
def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]:
return self.bridge_facade._snapshot_engine_decode_pending_queue_state()
@staticmethod
def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]:
return EngineBridgeFacade._summarize_active_batch(active_batch)
def _refresh_engine_decode_runtime_state(self, last_event: str) -> None:
self.bridge_facade._refresh_engine_decode_runtime_state(last_event)
def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None:
self.bridge_facade._update_engine_decode_runtime_state(snapshot)
def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]:
return self.bridge_facade._snapshot_engine_decode_runtime_state()
def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]:
return self.bridge_facade._snapshot_engine_arbiter_state()
def _notify_engine_arbiter(self) -> None:
self.bridge_facade._notify_engine_arbiter()
def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None:
self.bridge_facade._enqueue_engine_decode_pending_job(job)
def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
return self.bridge_facade._take_engine_decode_pending_jobs_nonblocking(wait_for_batch)
def _peek_queue_age_ms(self, queue_name: str) -> float:
return self.bridge_facade._peek_queue_age_ms(queue_name)
def _engine_has_pending_work(self) -> bool:
return self.bridge_facade._engine_has_pending_work()
async def _prepare_state_via_engine_gpu_queue(
self,
*,
spec: SchedulerRequestSpec,
prepare_submit_at: float,
engine_request_id: str | None,
) -> tuple[T2SRequestState, float, float]:
return await self.bridge_facade._prepare_state_via_engine_gpu_queue(
spec=spec,
prepare_submit_at=prepare_submit_at,
engine_request_id=engine_request_id,
)
def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None:
self.bridge_facade._enqueue_worker_finished_for_finalize(tasks)
def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]:
return self.bridge_facade._take_engine_finalize_batch_nonblocking()
async def _enqueue_prepared_state_for_dispatch(
self,
*,
state: T2SRequestState,
speed_factor: float,
sample_steps: int,
media_type: str,
super_sampling: bool,
prepare_wall_ms: float,
prepare_profile_total_ms: float,
done_loop: asyncio.AbstractEventLoop | None,
done_future: asyncio.Future | None,
engine_request_id: str | None,
timeout_sec: float | None,
) -> EngineDispatchTask:
return await self.bridge_facade._enqueue_prepared_state_for_dispatch(
state=state,
speed_factor=speed_factor,
sample_steps=sample_steps,
media_type=media_type,
super_sampling=super_sampling,
prepare_wall_ms=prepare_wall_ms,
prepare_profile_total_ms=prepare_profile_total_ms,
done_loop=done_loop,
done_future=done_future,
engine_request_id=engine_request_id,
timeout_sec=timeout_sec,
)
def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None:
self.bridge_facade._mark_arbiter_tick(stage=stage, reason=reason, policy_allowed=policy_allowed)
def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]:
return self.bridge_facade._select_engine_stage()
def _run_engine_prepare_once(self) -> bool:
return self.bridge_facade._run_engine_prepare_once()
def _run_engine_finalize_once(self) -> bool:
return self.bridge_facade._run_engine_finalize_once()
def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool:
return self.bridge_facade._run_engine_dispatch_once(policy_snapshot, worker_state)
def _run_engine_decode_runtime_once(self) -> bool:
return self.bridge_facade._run_engine_decode_runtime_once()
def _run_engine_arbiter_loop(self) -> None:
self.bridge_facade._run_engine_arbiter_loop()
def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
self.bridge_facade._complete_request_state(request_id, extra)
def _fail_request_state(self, request_id: str, error: str) -> None:
self.bridge_facade._fail_request_state(request_id, error)
def _snapshot_request_registry(self) -> Dict[str, Any]:
return self.bridge_facade._snapshot_request_registry()

View File

@ -0,0 +1,231 @@
from __future__ import annotations
import time
from typing import Any, Dict, List, Optional
import numpy as np
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineRequestState, EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob
class EngineRegistryBridgeFacade:
def __init__(self, owner: Any) -> None:
self.owner = owner
@property
def request_registry(self):
return self.owner.request_registry
@property
def engine_prepare_queue_owner(self):
return self.owner.engine_prepare_queue_owner
@property
def engine_prepare_text_queue_owner(self):
return self.owner.engine_prepare_text_queue_owner
@property
def engine_prepare_ref_spec_queue_owner(self):
return self.owner.engine_prepare_ref_spec_queue_owner
@property
def engine_finalize_queue_owner(self):
return self.owner.engine_finalize_queue_owner
@property
def engine_dispatch_queue_owner(self):
return self.owner.engine_dispatch_queue_owner
@property
def engine_decode_runtime_owner(self):
return self.owner.engine_decode_runtime_owner
@property
def engine_job_registry(self):
return self.owner.engine_job_registry
@property
def scheduler_worker(self):
return self.owner.scheduler_worker
def _register_request_state(
self,
request_id: str,
api_mode: str,
backend: str,
media_type: str,
response_streaming: bool,
deadline_ts: float | None = None,
meta: Optional[Dict[str, Any]] = None,
) -> EngineRequestState:
return self.request_registry.register(
request_id=request_id,
api_mode=api_mode,
backend=backend,
media_type=media_type,
response_streaming=response_streaming,
deadline_ts=deadline_ts,
meta=meta,
)
def _update_request_state(
self,
request_id: str,
status: str,
extra: Optional[Dict[str, Any]] = None,
) -> None:
self.request_registry.update(request_id, status, extra)
def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
self.request_registry.merge_profile(request_id, extra)
def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
self.request_registry.complete(request_id, extra)
def _fail_request_state(self, request_id: str, error: str) -> None:
self.request_registry.fail(request_id, error)
def _snapshot_request_registry(self) -> Dict[str, Any]:
return self.request_registry.snapshot()
def _snapshot_engine_prepare_state(self) -> Dict[str, Any]:
audio_snapshot = self.engine_prepare_queue_owner.snapshot(max_request_ids=16)
text_snapshot = self.engine_prepare_text_queue_owner.snapshot(max_request_ids=16)
ref_spec_snapshot = self.engine_prepare_ref_spec_queue_owner.snapshot(max_request_ids=16)
return {
"waiting_count": int(audio_snapshot.get("waiting_count", 0))
+ int(text_snapshot.get("waiting_count", 0))
+ int(ref_spec_snapshot.get("waiting_count", 0)),
"audio_waiting_count": int(audio_snapshot.get("waiting_count", 0)),
"text_waiting_count": int(text_snapshot.get("waiting_count", 0)),
"ref_spec_waiting_count": int(ref_spec_snapshot.get("waiting_count", 0)),
"audio_waiting_request_ids": list(audio_snapshot.get("waiting_request_ids", [])),
"text_waiting_request_ids": list(text_snapshot.get("waiting_request_ids", [])),
"ref_spec_waiting_request_ids": list(ref_spec_snapshot.get("waiting_request_ids", [])),
"peak_waiting": int(
max(
int(audio_snapshot.get("peak_waiting", 0)),
int(text_snapshot.get("peak_waiting", 0)),
int(ref_spec_snapshot.get("peak_waiting", 0)),
)
),
"total_submitted": int(audio_snapshot.get("total_submitted", 0)),
"total_completed": int(audio_snapshot.get("total_completed", 0)),
"text_total_submitted": int(text_snapshot.get("total_submitted", 0)),
"text_total_completed": int(text_snapshot.get("total_completed", 0)),
"ref_spec_total_submitted": int(ref_spec_snapshot.get("total_submitted", 0)),
"ref_spec_total_completed": int(ref_spec_snapshot.get("total_completed", 0)),
}
def _snapshot_engine_finalize_state(self) -> Dict[str, Any]:
return self.engine_finalize_queue_owner.snapshot(max_request_ids=16)
def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]:
return self.engine_dispatch_queue_owner.snapshot(
max_request_ids=16,
extra={"last_policy_snapshot": dict(self.owner.engine_dispatch_last_snapshot or {})},
)
def _register_engine_job(self, job: SchedulerPendingJob) -> None:
self.engine_job_registry.register(job, keep_job=True)
def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
return self.engine_job_registry.get(request_id)
def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
return self.engine_job_registry.pop(request_id)
def _snapshot_engine_job_registry(self) -> Dict[str, Any]:
return self.engine_job_registry.snapshot(max_request_ids=32)
def _is_engine_drained(self) -> bool:
prepare_empty = self.engine_prepare_queue_owner.is_drained()
prepare_text_empty = self.engine_prepare_text_queue_owner.is_drained()
prepare_ref_spec_empty = self.engine_prepare_ref_spec_queue_owner.is_drained()
dispatch_empty = self.engine_dispatch_queue_owner.is_drained()
finalize_empty = self.engine_finalize_queue_owner.is_drained()
decode_pending_empty = not self.engine_decode_runtime_owner.has_pending_jobs()
job_empty = self.engine_job_registry.is_empty()
worker_state = self.scheduler_worker.snapshot()
return bool(
prepare_empty
and prepare_text_empty
and prepare_ref_spec_empty
and dispatch_empty
and finalize_empty
and decode_pending_empty
and job_empty
and self.engine_decode_runtime_owner.get_active_batch() is None
and int(worker_state.get("prepare_inflight", 0)) <= 0
and int(worker_state.get("finalize_inflight", 0)) <= 0
and int(worker_state.get("finalize_pending", 0)) <= 0
)
def _record_engine_job_done(self, request_id: str) -> None:
self.engine_job_registry.mark_finished_and_remove(request_id)
self.scheduler_worker.record_external_job_done(request_id)
def _complete_engine_job(
self,
job: SchedulerPendingJob,
item: T2SFinishedItem,
*,
sample_rate: int,
audio_data: np.ndarray,
) -> None:
completion_bridge = self.scheduler_worker.completion_bridge
completion_bridge.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data)
completion_bridge.complete_job(
job,
runtime_request_id=job.engine_request_id,
runtime_extra=completion_bridge.build_runtime_complete_payload(job, item, sample_rate=sample_rate),
on_job_finished=lambda rid=item.request_id: self._record_engine_job_done(rid),
)
def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None:
if not request_ids:
return
completion_bridge = self.scheduler_worker.completion_bridge
for request_id in request_ids:
job = self._get_engine_job(request_id)
if job is None:
continue
completion_bridge.fail_job(
job,
error=error,
on_job_finished=lambda rid=request_id: self._record_engine_job_done(rid),
)
def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None:
delta_ms = float(elapsed_s) * 1000.0
for job in jobs:
job.prefill_ms += delta_ms
def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None:
delta_ms = float(elapsed_s) * 1000.0
for request_id in request_ids:
job = self._get_engine_job(request_id)
if job is not None:
job.merge_ms += delta_ms
def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None:
delta_ms = float(elapsed_s) * 1000.0
activate_request_ids: List[str] = []
for request_id in request_ids:
job = self._get_engine_job(request_id)
if job is None:
continue
if job.decode_steps == 0:
activate_request_ids.append(job.engine_request_id)
job.decode_ms += delta_ms
job.decode_steps += 1
for engine_request_id in activate_request_ids:
self._update_request_state(engine_request_id, EngineStatus.ACTIVE_DECODE, None)
def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None:
if not items:
return
enqueued_at = time.perf_counter()
tasks = [SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at) for item in items]
self.owner.engine_stage_coordinator.enqueue_worker_finished_for_finalize(tasks)

View File

@ -0,0 +1,33 @@
from __future__ import annotations
from typing import Any, Dict
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SActiveBatch
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDecodeRuntimeOwner
class EngineRuntimeBridgeFacade:
def __init__(self, owner: Any) -> None:
self.owner = owner
@property
def engine_policy_arbiter(self):
return self.owner.engine_policy_arbiter
@staticmethod
def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]:
return EngineDecodeRuntimeOwner.summarize_active_batch(active_batch)
def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]:
return self.engine_policy_arbiter.snapshot_state()
def _notify_engine_arbiter(self) -> None:
self.engine_policy_arbiter.notify()
def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None:
self.engine_policy_arbiter.mark_tick(stage=stage, reason=reason, policy_allowed=policy_allowed)
def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]:
stage, reason, policy_snapshot, worker_state = self.engine_policy_arbiter.select_stage()
self.owner.engine_dispatch_last_snapshot = dict(policy_snapshot)
return stage, reason, policy_snapshot, worker_state

View File

@ -0,0 +1,116 @@
from __future__ import annotations
import asyncio
from typing import Any, Dict, List
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SRequestState
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDispatchTask, SchedulerFinalizeTask, SchedulerPendingJob
class EngineStageBridgeFacade:
def __init__(self, owner: Any) -> None:
self.owner = owner
@property
def engine_decode_runtime_owner(self):
return self.owner.engine_decode_runtime_owner
@property
def scheduler_worker(self):
return self.owner.scheduler_worker
@property
def engine_stage_coordinator(self):
return self.owner.engine_stage_coordinator
def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]:
return self.engine_decode_runtime_owner.snapshot_pending_queue_state()
def _refresh_engine_decode_runtime_state(self, last_event: str) -> None:
self.engine_decode_runtime_owner.refresh_state(last_event)
def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None:
if not snapshot:
return
if self.scheduler_worker.is_engine_decode_control_enabled():
return
self.engine_decode_runtime_owner.update_from_worker_snapshot(snapshot)
def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]:
return self.engine_decode_runtime_owner.snapshot_state()
def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None:
self.engine_decode_runtime_owner.enqueue_pending_job(job)
self.owner.engine_policy_arbiter.notify()
def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
return self.engine_decode_runtime_owner.take_pending_jobs_nonblocking(wait_for_batch)
def _peek_queue_age_ms(self, queue_name: str) -> float:
return self.engine_stage_coordinator.peek_queue_age_ms(queue_name)
def _engine_has_pending_work(self) -> bool:
return self.engine_stage_coordinator.has_pending_work()
async def _prepare_state_via_engine_gpu_queue(
self,
*,
spec: SchedulerRequestSpec,
prepare_submit_at: float,
engine_request_id: str | None,
) -> tuple[T2SRequestState, float, float]:
return await self.engine_stage_coordinator.prepare_state_via_engine_gpu_queue(
spec=spec,
prepare_submit_at=prepare_submit_at,
engine_request_id=engine_request_id,
)
def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None:
self.engine_stage_coordinator.enqueue_worker_finished_for_finalize(tasks)
def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]:
return self.engine_stage_coordinator.take_engine_finalize_batch_nonblocking()
async def _enqueue_prepared_state_for_dispatch(
self,
*,
state: T2SRequestState,
speed_factor: float,
sample_steps: int,
media_type: str,
super_sampling: bool,
prepare_wall_ms: float,
prepare_profile_total_ms: float,
done_loop: asyncio.AbstractEventLoop | None,
done_future: asyncio.Future | None,
engine_request_id: str | None,
timeout_sec: float | None,
) -> EngineDispatchTask:
return await self.engine_stage_coordinator.enqueue_prepared_state_for_dispatch(
state=state,
speed_factor=speed_factor,
sample_steps=sample_steps,
media_type=media_type,
super_sampling=super_sampling,
prepare_wall_ms=prepare_wall_ms,
prepare_profile_total_ms=prepare_profile_total_ms,
done_loop=done_loop,
done_future=done_future,
engine_request_id=engine_request_id,
timeout_sec=timeout_sec,
)
def _run_engine_prepare_once(self) -> bool:
return self.engine_stage_coordinator.run_engine_prepare_once()
def _run_engine_finalize_once(self) -> bool:
return self.engine_stage_coordinator.run_engine_finalize_once()
def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool:
return self.engine_stage_coordinator.run_engine_dispatch_once(policy_snapshot, worker_state)
def _run_engine_decode_runtime_once(self) -> bool:
return self.engine_stage_coordinator.run_engine_decode_runtime_once()
def _run_engine_arbiter_loop(self) -> None:
return self.engine_stage_coordinator.run_engine_arbiter_loop()

View File

@ -0,0 +1,183 @@
from __future__ import annotations
import os
import threading
from typing import Any
from GPT_SoVITS.TTS_infer_pack.unified_engine_api import EngineApiFacade
from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge import EngineBridgeFacade
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import (
EngineArbiterConfig,
EngineDecodeRuntimeOwner,
EnginePolicyArbiterController,
EnginePolicyConfig,
EngineRequestRegistry,
EngineTaskQueueOwner,
ModelRegistry,
ReferenceRegistry,
RuntimeStateCallbacks,
SchedulerJobRegistry,
)
from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime import EngineRuntimeFacade
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage import EngineStageCoordinator
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker
class EngineCompositionBuilder:
def __init__(self, owner: Any) -> None:
self.owner = owner
def build(self, *, max_steps: int, micro_batch_wait_ms: int) -> None:
self._init_registries_and_locks()
self._init_worker(max_steps=max_steps, micro_batch_wait_ms=micro_batch_wait_ms)
self._init_policy_configs(micro_batch_wait_ms=micro_batch_wait_ms)
self._init_runtime_owners()
self._init_stage_coordinator()
self._init_arbiter()
self._init_facades()
self._start_arbiter_thread()
def _init_registries_and_locks(self) -> None:
owner = self.owner
owner.reference_registry = ReferenceRegistry()
owner.model_registry = ModelRegistry(
t2s_weights_path=str(owner.tts.configs.t2s_weights_path),
vits_weights_path=str(owner.tts.configs.vits_weights_path),
)
owner.request_registry = EngineRequestRegistry(
recent_limit=max(1, int(os.environ.get("GPTSOVITS_ENGINE_RECENT_REQUEST_LIMIT", "64")))
)
owner.engine_job_registry = SchedulerJobRegistry(threading.Lock())
owner.direct_tts_lock = threading.RLock()
owner.management_lock = threading.RLock()
owner.engine_dispatch_last_snapshot = {}
def _init_worker(self, *, max_steps: int, micro_batch_wait_ms: int) -> None:
owner = self.owner
owner.scheduler_worker = UnifiedSchedulerWorker(
owner.tts,
max_steps=max_steps,
micro_batch_wait_ms=micro_batch_wait_ms,
runtime_callbacks=RuntimeStateCallbacks(
update=owner._update_request_state,
complete=owner._complete_request_state,
fail=owner._fail_request_state,
decode_runtime_update=owner._update_engine_decode_runtime_state,
),
external_finalize_submit=owner._enqueue_worker_finished_for_finalize,
)
def _init_policy_configs(self, *, micro_batch_wait_ms: int) -> None:
owner = self.owner
worker_capacity_limits = owner.scheduler_worker.get_capacity_limits()
prepare_max_inflight = int(owner.scheduler_worker.get_prepare_max_inflight())
owner.engine_policy_config = EnginePolicyConfig(
enabled=owner._env_flag("GPTSOVITS_ENGINE_POLICY_ENABLE", True),
poll_wait_ms=max(1.0, owner._env_float("GPTSOVITS_ENGINE_POLICY_POLL_WAIT_MS", float(micro_batch_wait_ms))),
decode_backlog_soft_max=max(
0,
owner._env_int(
"GPTSOVITS_ENGINE_POLICY_DECODE_BACKLOG_SOFT_MAX",
int(worker_capacity_limits["decode_backlog_max"]),
),
),
finalize_pending_soft_max=max(
0,
owner._env_int(
"GPTSOVITS_ENGINE_POLICY_FINALIZE_PENDING_SOFT_MAX",
int(worker_capacity_limits["finalize_pending_max"]),
),
),
prepare_inflight_soft_max=max(
0,
owner._env_int("GPTSOVITS_ENGINE_POLICY_PREPARE_INFLIGHT_SOFT_MAX", prepare_max_inflight),
),
active_decode_soft_max=max(0, owner._env_int("GPTSOVITS_ENGINE_POLICY_ACTIVE_DECODE_SOFT_MAX", 0)),
ready_for_prefill_soft_max=max(0, owner._env_int("GPTSOVITS_ENGINE_POLICY_READY_FOR_PREFILL_SOFT_MAX", 0)),
active_request_soft_max=max(0, owner._env_int("GPTSOVITS_ENGINE_POLICY_ACTIVE_REQUEST_SOFT_MAX", 0)),
)
owner.engine_arbiter_config = EngineArbiterConfig(
poll_wait_ms=max(1.0, owner._env_float("GPTSOVITS_ENGINE_ARBITER_POLL_WAIT_MS", float(micro_batch_wait_ms))),
decode_burst=max(1, owner._env_int("GPTSOVITS_ENGINE_ARBITER_DECODE_BURST", 4)),
prepare_aging_ms=max(0.0, owner._env_float("GPTSOVITS_ENGINE_ARBITER_PREPARE_AGING_MS", 10.0)),
finalize_aging_ms=max(0.0, owner._env_float("GPTSOVITS_ENGINE_ARBITER_FINALIZE_AGING_MS", 10.0)),
)
def _init_runtime_owners(self) -> None:
owner = self.owner
owner.engine_decode_runtime_owner = EngineDecodeRuntimeOwner(
get_decode_runtime_counters=owner.scheduler_worker.get_decode_runtime_counters,
get_micro_batch_wait_s=owner.scheduler_worker.get_micro_batch_wait_s,
)
owner.engine_prepare_queue_owner = EngineTaskQueueOwner(completion_key="total_completed")
owner.engine_prepare_text_queue_owner = EngineTaskQueueOwner(completion_key="total_completed")
owner.engine_prepare_ref_spec_queue_owner = EngineTaskQueueOwner(completion_key="total_completed")
owner.engine_finalize_queue_owner = EngineTaskQueueOwner(completion_key="total_completed")
owner.engine_dispatch_queue_owner = EngineTaskQueueOwner(completion_key="total_dispatched")
def _init_stage_coordinator(self) -> None:
owner = self.owner
owner.engine_stage_coordinator = EngineStageCoordinator(
tts=owner.tts,
scheduler_worker=owner.scheduler_worker,
prepare_queue_owner=owner.engine_prepare_queue_owner,
prepare_text_queue_owner=owner.engine_prepare_text_queue_owner,
prepare_ref_spec_queue_owner=owner.engine_prepare_ref_spec_queue_owner,
finalize_queue_owner=owner.engine_finalize_queue_owner,
dispatch_queue_owner=owner.engine_dispatch_queue_owner,
decode_runtime_owner=owner.engine_decode_runtime_owner,
update_request_state=owner._update_request_state,
merge_request_state_profile=owner._merge_request_state_profile,
fail_request_state=owner._fail_request_state,
get_engine_job=owner._get_engine_job,
register_engine_job=owner._register_engine_job,
fail_engine_jobs=owner._fail_engine_jobs,
complete_engine_job=owner._complete_engine_job,
add_engine_prefill_time=owner._add_engine_prefill_time,
add_engine_merge_time=owner._add_engine_merge_time,
add_engine_decode_time=owner._add_engine_decode_time,
enqueue_engine_finished_items=owner._enqueue_engine_finished_items,
snapshot_engine_dispatch_state=owner._snapshot_engine_dispatch_state,
snapshot_engine_decode_runtime_state=owner._snapshot_engine_decode_runtime_state,
)
def _init_arbiter(self) -> None:
owner = self.owner
owner.engine_policy_arbiter = EnginePolicyArbiterController(
policy_config=owner.engine_policy_config,
arbiter_config=owner.engine_arbiter_config,
snapshot_request_registry=owner._snapshot_request_registry,
get_worker_state=owner.scheduler_worker.snapshot,
snapshot_prepare_state=owner._snapshot_engine_prepare_state,
snapshot_finalize_state=owner._snapshot_engine_finalize_state,
snapshot_dispatch_state=owner._snapshot_engine_dispatch_state,
snapshot_decode_runtime_state=owner._snapshot_engine_decode_runtime_state,
snapshot_job_registry=owner._snapshot_engine_job_registry,
peek_queue_age_ms=owner.engine_stage_coordinator.peek_queue_age_ms,
merge_request_state_profile=owner._merge_request_state_profile,
)
owner.engine_stage_coordinator.bind_arbiter(
notify_arbiter=owner._notify_engine_arbiter,
select_stage=owner._select_engine_stage,
mark_arbiter_tick=lambda stage, reason, policy_allowed: owner._mark_arbiter_tick(
stage=stage,
reason=reason,
policy_allowed=policy_allowed,
),
wait_arbiter=owner.engine_policy_arbiter.wait,
)
def _init_facades(self) -> None:
owner = self.owner
owner.bridge_facade = EngineBridgeFacade(owner)
owner.api_facade = EngineApiFacade(owner)
owner.runtime_facade = EngineRuntimeFacade(owner)
def _start_arbiter_thread(self) -> None:
owner = self.owner
owner.engine_arbiter_thread = threading.Thread(
target=owner._run_engine_arbiter_loop,
name="unified-engine-arbiter",
daemon=True,
)
owner.engine_arbiter_thread.start()

View File

@ -0,0 +1,121 @@
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Optional
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec
@dataclass
class RuntimeControlCallbacks:
restart: Callable[[], None] | None = None
exit: Callable[[], None] | None = None
@dataclass
class DirectTTSExecution:
media_type: str
streaming: bool
audio_generator: Optional[Generator[bytes, None, None]] = None
audio_bytes: Optional[bytes] = None
request_id: Optional[str] = None
@dataclass
class NormalizedEngineRequest:
request_id: str
text: str
text_lang: str
ref_audio_path: str
prompt_lang: str
prompt_text: str = ""
aux_ref_audio_paths: List[str] | None = None
top_k: int = 15
top_p: float = 1.0
temperature: float = 1.0
repetition_penalty: float = 1.35
early_stop_num: int = -1
ready_step: int = 0
text_split_method: str = "cut5"
batch_size: int = 1
batch_threshold: float = 0.75
split_bucket: bool = False
speed_factor: float = 1.0
fragment_interval: float = 0.3
seed: int = -1
media_type: str = "wav"
streaming_mode: bool | int = False
return_fragment: bool = False
fixed_length_chunk: bool = False
response_streaming: bool = False
parallel_infer: bool = False
sample_steps: int = 32
super_sampling: bool = False
overlap_length: int = 2
min_chunk_length: int = 16
timeout_sec: float | None = None
def to_payload(self) -> Dict[str, Any]:
return {
"request_id": self.request_id,
"text": self.text,
"text_lang": self.text_lang,
"ref_audio_path": self.ref_audio_path,
"aux_ref_audio_paths": list(self.aux_ref_audio_paths) if self.aux_ref_audio_paths else None,
"prompt_text": self.prompt_text,
"prompt_lang": self.prompt_lang,
"top_k": self.top_k,
"top_p": self.top_p,
"temperature": self.temperature,
"text_split_method": self.text_split_method,
"batch_size": self.batch_size,
"batch_threshold": self.batch_threshold,
"speed_factor": self.speed_factor,
"split_bucket": self.split_bucket,
"fragment_interval": self.fragment_interval,
"seed": self.seed,
"media_type": self.media_type,
"streaming_mode": self.streaming_mode,
"return_fragment": self.return_fragment,
"fixed_length_chunk": self.fixed_length_chunk,
"response_streaming": self.response_streaming,
"parallel_infer": self.parallel_infer,
"repetition_penalty": self.repetition_penalty,
"sample_steps": self.sample_steps,
"super_sampling": self.super_sampling,
"overlap_length": self.overlap_length,
"min_chunk_length": self.min_chunk_length,
"early_stop_num": self.early_stop_num,
"ready_step": self.ready_step,
"timeout_sec": self.timeout_sec,
}
def to_scheduler_spec(self) -> SchedulerRequestSpec:
return SchedulerRequestSpec(
request_id=self.request_id,
ref_audio_path=Path(self.ref_audio_path),
prompt_text=self.prompt_text,
prompt_lang=self.prompt_lang,
text=self.text,
text_lang=self.text_lang,
top_k=self.top_k,
top_p=self.top_p,
temperature=self.temperature,
repetition_penalty=self.repetition_penalty,
early_stop_num=self.early_stop_num,
aux_ref_audio_paths=list(self.aux_ref_audio_paths or []),
ready_step=self.ready_step,
)
@dataclass
class SchedulerDebugExecution:
payload: Dict[str, Any]
@dataclass
class SchedulerSubmitExecution:
audio_bytes: bytes
media_type: str
headers: Dict[str, str]

View File

@ -0,0 +1,363 @@
from __future__ import annotations
import asyncio
import threading
import time
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
from GPT_SoVITS.TTS_infer_pack.unified_engine_component_registry import EngineStatus
@dataclass
class EnginePolicyConfig:
enabled: bool = True
poll_wait_ms: float = 5.0
decode_backlog_soft_max: int = 0
finalize_pending_soft_max: int = 0
prepare_inflight_soft_max: int = 0
active_decode_soft_max: int = 0
ready_for_prefill_soft_max: int = 0
active_request_soft_max: int = 0
def to_dict(self) -> Dict[str, Any]:
return {
"enabled": bool(self.enabled),
"poll_wait_ms": float(self.poll_wait_ms),
"decode_backlog_soft_max": int(self.decode_backlog_soft_max),
"finalize_pending_soft_max": int(self.finalize_pending_soft_max),
"prepare_inflight_soft_max": int(self.prepare_inflight_soft_max),
"active_decode_soft_max": int(self.active_decode_soft_max),
"ready_for_prefill_soft_max": int(self.ready_for_prefill_soft_max),
"active_request_soft_max": int(self.active_request_soft_max),
}
@dataclass
class EngineArbiterConfig:
poll_wait_ms: float = 5.0
decode_burst: int = 4
prepare_aging_ms: float = 10.0
finalize_aging_ms: float = 10.0
def to_dict(self) -> Dict[str, Any]:
return {
"poll_wait_ms": float(self.poll_wait_ms),
"decode_burst": int(self.decode_burst),
"prepare_aging_ms": float(self.prepare_aging_ms),
"finalize_aging_ms": float(self.finalize_aging_ms),
}
@dataclass
class EngineArbiterState:
total_ticks: int = 0
total_idle_ticks: int = 0
total_prepare_dispatches: int = 0
total_decode_dispatches: int = 0
total_decode_runtime_ticks: int = 0
total_finalize_dispatches: int = 0
decode_budget_remaining: int = 0
last_stage: str = "idle"
last_reason: str = "init"
last_observed_at: float = 0.0
last_policy_allowed: bool = True
class EnginePolicyArbiterController:
def __init__(
self,
*,
policy_config: EnginePolicyConfig,
arbiter_config: EngineArbiterConfig,
snapshot_request_registry: Callable[[], Dict[str, Any]],
get_worker_state: Callable[[], Dict[str, Any]],
snapshot_prepare_state: Callable[[], Dict[str, Any]],
snapshot_finalize_state: Callable[[], Dict[str, Any]],
snapshot_dispatch_state: Callable[[], Dict[str, Any]],
snapshot_decode_runtime_state: Callable[[], Dict[str, Any]],
snapshot_job_registry: Callable[[], Dict[str, Any]],
peek_queue_age_ms: Callable[[str], float],
merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None],
) -> None:
self.policy_config = policy_config
self.policy_poll_s = max(0.001, float(self.policy_config.poll_wait_ms) / 1000.0)
self.arbiter_config = arbiter_config
self.arbiter_poll_s = max(0.001, float(self.arbiter_config.poll_wait_ms) / 1000.0)
self.condition = threading.Condition()
self.state = EngineArbiterState(
decode_budget_remaining=int(self.arbiter_config.decode_burst),
last_observed_at=time.perf_counter(),
)
self.snapshot_request_registry = snapshot_request_registry
self.get_worker_state = get_worker_state
self.snapshot_prepare_state = snapshot_prepare_state
self.snapshot_finalize_state = snapshot_finalize_state
self.snapshot_dispatch_state = snapshot_dispatch_state
self.snapshot_decode_runtime_state = snapshot_decode_runtime_state
self.snapshot_job_registry = snapshot_job_registry
self.peek_queue_age_ms = peek_queue_age_ms
self.merge_request_state_profile = merge_request_state_profile
def snapshot_state(self) -> Dict[str, Any]:
with self.condition:
return {
"config": self.arbiter_config.to_dict(),
"total_ticks": int(self.state.total_ticks),
"total_idle_ticks": int(self.state.total_idle_ticks),
"total_prepare_dispatches": int(self.state.total_prepare_dispatches),
"total_decode_dispatches": int(self.state.total_decode_dispatches),
"total_decode_runtime_ticks": int(self.state.total_decode_runtime_ticks),
"total_finalize_dispatches": int(self.state.total_finalize_dispatches),
"decode_budget_remaining": int(self.state.decode_budget_remaining),
"last_stage": str(self.state.last_stage),
"last_reason": str(self.state.last_reason),
"last_policy_allowed": bool(self.state.last_policy_allowed),
"last_observed_at": float(self.state.last_observed_at),
}
def notify(self) -> None:
with self.condition:
self.condition.notify_all()
def wait(self) -> None:
with self.condition:
self.condition.wait(timeout=self.arbiter_poll_s)
def mark_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None:
with self.condition:
self.state.total_ticks += 1
if stage == "idle":
self.state.total_idle_ticks += 1
elif stage in {"prepare", "prepare_audio", "prepare_text", "prepare_ref_spec"}:
self.state.total_prepare_dispatches += 1
self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst)
elif stage == "finalize":
self.state.total_finalize_dispatches += 1
self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst)
elif stage == "decode_dispatch":
self.state.total_decode_dispatches += 1
elif stage == "decode_runtime":
self.state.total_decode_runtime_ticks += 1
self.state.decode_budget_remaining = max(0, int(self.state.decode_budget_remaining) - 1)
self.state.last_stage = str(stage)
self.state.last_reason = str(reason)
self.state.last_policy_allowed = bool(policy_allowed)
self.state.last_observed_at = time.perf_counter()
def build_stage_counters(
self,
request_registry: Dict[str, Any],
worker_state: Dict[str, Any],
) -> Dict[str, Any]:
prepare_dispatcher_state = self.snapshot_prepare_state()
finalize_dispatcher_state = self.snapshot_finalize_state()
dispatcher_state = self.snapshot_dispatch_state()
active_requests = list(request_registry.get("active_requests", []))
status_counts: Dict[str, int] = {}
for item in active_requests:
status = str(item.get("status", "UNKNOWN"))
status_counts[status] = status_counts.get(status, 0) + 1
worker_pending_jobs = int(worker_state.get("pending_jobs", 0))
worker_decode_active_size = int(worker_state.get("running_requests", 0))
worker_prepare_inflight = int(worker_state.get("prepare_inflight", 0))
worker_finalize_pending = int(worker_state.get("finalize_pending", 0))
worker_finalize_inflight = int(worker_state.get("finalize_inflight", 0))
engine_decode_runtime_state = self.snapshot_decode_runtime_state()
engine_job_registry = self.snapshot_job_registry()
decode_runtime_pending_jobs = int(engine_decode_runtime_state.get("pending_jobs", 0))
decode_runtime_active_size = int(engine_decode_runtime_state.get("active_request_count", 0))
return {
"active_request_count": int(len(active_requests)),
"status_counts": status_counts,
"queued_request_count": int(status_counts.get(EngineStatus.QUEUED, 0)),
"cpu_prepare_request_count": int(status_counts.get(EngineStatus.CPU_PREPARING, 0)),
"gpu_prepare_request_count": int(status_counts.get(EngineStatus.GPU_PREPARING, 0)),
"ready_for_prefill_request_count": int(status_counts.get(EngineStatus.READY_FOR_PREFILL, 0)),
"active_decode_request_count": int(status_counts.get(EngineStatus.ACTIVE_DECODE, 0)),
"ready_for_finalize_request_count": int(status_counts.get(EngineStatus.READY_FOR_FINALIZE, 0)),
"finalizing_request_count": int(status_counts.get(EngineStatus.FINALIZING, 0)),
"streaming_request_count": int(status_counts.get(EngineStatus.STREAMING, 0)),
"worker_pending_jobs": worker_pending_jobs,
"worker_decode_active_size": worker_decode_active_size,
"worker_decode_control_enabled": bool(worker_state.get("engine_decode_control_enabled", False)),
"worker_decode_runtime_has_work": bool(worker_state.get("decode_runtime_has_work", False)),
"engine_decode_runtime_pending_jobs": decode_runtime_pending_jobs,
"engine_decode_runtime_active_request_count": decode_runtime_active_size,
"engine_decode_runtime_has_work": bool(engine_decode_runtime_state.get("has_work", False)),
"engine_job_registry_count": int(engine_job_registry.get("job_count", 0)),
"worker_prepare_inflight": worker_prepare_inflight,
"worker_finalize_pending": worker_finalize_pending,
"worker_finalize_inflight": worker_finalize_inflight,
"engine_gpu_prepare_queue_count": int(prepare_dispatcher_state.get("waiting_count", 0)),
"engine_finalize_queue_count": int(finalize_dispatcher_state.get("waiting_count", 0)),
"engine_decode_waiting_queue_count": int(dispatcher_state.get("waiting_count", 0)),
"decode_backlog": int(
decode_runtime_pending_jobs + decode_runtime_active_size
if bool(worker_state.get("engine_decode_control_enabled", False))
else worker_pending_jobs + worker_decode_active_size
),
}
def build_policy_snapshot(
self,
request_registry: Dict[str, Any],
worker_state: Dict[str, Any],
) -> Dict[str, Any]:
counters = self.build_stage_counters(request_registry, worker_state)
config = self.policy_config.to_dict()
blocked_reasons: List[Dict[str, Any]] = []
finalize_pending_total = int(counters["worker_finalize_pending"]) + int(counters.get("engine_finalize_queue_count", 0))
limit_checks = [
("decode_backlog", counters["decode_backlog"], int(config["decode_backlog_soft_max"])),
("finalize_pending", finalize_pending_total, int(config["finalize_pending_soft_max"])),
("prepare_inflight", counters["worker_prepare_inflight"], int(config["prepare_inflight_soft_max"])),
("active_decode_requests", counters["active_decode_request_count"], int(config["active_decode_soft_max"])),
("ready_for_prefill_requests", counters["ready_for_prefill_request_count"], int(config["ready_for_prefill_soft_max"])),
("active_requests", counters["active_request_count"], int(config["active_request_soft_max"])),
]
if bool(config["enabled"]):
for name, value, limit in limit_checks:
if limit > 0 and int(value) >= int(limit):
blocked_reasons.append({"metric": name, "value": int(value), "limit": int(limit)})
return {
"enabled": bool(config["enabled"]),
"allowed": (not bool(config["enabled"])) or not blocked_reasons,
"blocked_reasons": blocked_reasons,
"config": config,
"metrics": {
"active_request_count": int(counters["active_request_count"]),
"queued_request_count": int(counters["queued_request_count"]),
"ready_for_prefill_request_count": int(counters["ready_for_prefill_request_count"]),
"active_decode_request_count": int(counters["active_decode_request_count"]),
"engine_gpu_prepare_queue_count": int(counters["engine_gpu_prepare_queue_count"]),
"engine_decode_waiting_queue_count": int(counters["engine_decode_waiting_queue_count"]),
"decode_backlog": int(counters["decode_backlog"]),
"prepare_inflight": int(counters["worker_prepare_inflight"]),
"finalize_pending": int(finalize_pending_total),
"engine_finalize_queue_count": int(counters.get("engine_finalize_queue_count", 0)),
"finalize_inflight": int(counters["worker_finalize_inflight"]),
},
"observed_at": time.perf_counter(),
}
async def wait_for_policy_admission(
self,
*,
request_id: str | None,
timeout_sec: float | None,
) -> tuple[float, Dict[str, Any]]:
request_registry = self.snapshot_request_registry()
worker_state = self.get_worker_state()
snapshot = self.build_policy_snapshot(request_registry, worker_state)
if not self.policy_config.enabled:
return 0.0, snapshot
start = time.perf_counter()
deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec)))
while True:
request_registry = self.snapshot_request_registry()
worker_state = self.get_worker_state()
snapshot = self.build_policy_snapshot(request_registry, worker_state)
if snapshot["allowed"]:
wait_ms = max(0.0, (time.perf_counter() - start) * 1000.0)
if request_id not in [None, ""]:
self.merge_request_state_profile(
str(request_id),
{
"engine_policy_wait_ms": float(wait_ms),
"engine_policy_snapshot": snapshot,
},
)
return wait_ms, snapshot
now = time.perf_counter()
if deadline is not None and now >= deadline:
blocked_summary = ", ".join(
f"{item['metric']}={item['value']}/{item['limit']}" for item in snapshot.get("blocked_reasons", [])
)
raise TimeoutError(f"engine policy admission timeout ({blocked_summary})")
await asyncio.sleep(self.policy_poll_s)
def select_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]:
request_registry = self.snapshot_request_registry()
worker_state = self.get_worker_state()
policy_snapshot = self.build_policy_snapshot(request_registry, worker_state)
prepare_state = self.snapshot_prepare_state()
prepare_waiting = int(prepare_state.get("waiting_count", 0))
prepare_audio_waiting = int(prepare_state.get("audio_waiting_count", 0))
prepare_text_waiting = int(prepare_state.get("text_waiting_count", 0))
prepare_ref_spec_waiting = int(prepare_state.get("ref_spec_waiting_count", 0))
finalize_waiting = int(self.snapshot_finalize_state().get("waiting_count", 0))
decode_waiting = int(self.snapshot_dispatch_state().get("waiting_count", 0))
decode_runtime_state = self.snapshot_decode_runtime_state()
worker_decode_has_work = bool(decode_runtime_state.get("has_work", False))
worker_decode_control_enabled = bool(worker_state.get("engine_decode_control_enabled", False))
worker_pending_jobs = int(decode_runtime_state.get("pending_jobs", 0))
worker_running_requests = int(decode_runtime_state.get("active_request_count", 0))
prepare_age_ms = float(self.peek_queue_age_ms("prepare"))
prepare_audio_age_ms = float(self.peek_queue_age_ms("prepare_audio"))
prepare_text_age_ms = float(self.peek_queue_age_ms("prepare_text"))
prepare_ref_spec_age_ms = float(self.peek_queue_age_ms("prepare_ref_spec"))
finalize_age_ms = float(self.peek_queue_age_ms("finalize"))
decode_runtime_pending_age_ms = float(self.peek_queue_age_ms("decode_runtime_pending"))
decode_budget_remaining = int(self.snapshot_state().get("decode_budget_remaining", 0))
policy_allowed = bool(policy_snapshot.get("allowed", True))
if (
worker_decode_control_enabled
and worker_decode_has_work
and policy_allowed
and decode_budget_remaining > 0
and (worker_running_requests > 0 or worker_pending_jobs > 0)
):
return "decode_runtime", "worker_active_batch_progress", policy_snapshot, worker_state
if (
worker_decode_control_enabled
and worker_pending_jobs > 0
and policy_allowed
and decode_runtime_pending_age_ms >= float(self.arbiter_config.prepare_aging_ms)
):
return "decode_runtime", "decode_runtime_pending_aging", policy_snapshot, worker_state
if (
decode_waiting > 0
and policy_allowed
and (not worker_decode_control_enabled or not worker_decode_has_work or worker_pending_jobs <= 0)
):
return "decode_dispatch", "dispatch_prepared_state", policy_snapshot, worker_state
if (
finalize_waiting > 0
and prepare_ref_spec_waiting > 0
and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0)
):
return "prepare_ref_spec", "finalize_waiting_for_ref_spec", policy_snapshot, worker_state
if finalize_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0):
return "finalize", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
if finalize_waiting > 0 and finalize_age_ms >= float(self.arbiter_config.finalize_aging_ms):
return "finalize", "finalize_aging", policy_snapshot, worker_state
if prepare_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0):
if prepare_text_waiting > 0 and (prepare_audio_waiting <= 0 or prepare_text_age_ms >= prepare_audio_age_ms):
return "prepare_text", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
if prepare_ref_spec_waiting > 0 and prepare_audio_waiting <= 0 and prepare_text_waiting <= 0:
return "prepare_ref_spec", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
return "prepare_audio", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
if prepare_waiting > 0 and prepare_age_ms >= float(self.arbiter_config.prepare_aging_ms):
if prepare_text_waiting > 0 and prepare_text_age_ms >= max(prepare_audio_age_ms, prepare_age_ms - 1e-6):
return "prepare_text", "prepare_aging", policy_snapshot, worker_state
if (
prepare_ref_spec_waiting > 0
and prepare_ref_spec_age_ms >= max(prepare_audio_age_ms, prepare_text_age_ms, prepare_age_ms - 1e-6)
):
return "prepare_ref_spec", "prepare_aging", policy_snapshot, worker_state
return "prepare_audio", "prepare_aging", policy_snapshot, worker_state
if worker_decode_control_enabled and worker_decode_has_work and policy_allowed:
return "decode_runtime", "worker_active_batch_progress_fallback", policy_snapshot, worker_state
if decode_waiting > 0 and policy_allowed:
return "decode_dispatch", "decode_priority_fallback", policy_snapshot, worker_state
if finalize_waiting > 0:
return "finalize", "finalize_fallback", policy_snapshot, worker_state
if prepare_waiting > 0:
if prepare_text_waiting > 0 and (prepare_audio_waiting <= 0 or prepare_text_age_ms >= prepare_audio_age_ms):
return "prepare_text", "prepare_fallback", policy_snapshot, worker_state
if prepare_ref_spec_waiting > 0 and prepare_audio_waiting <= 0:
return "prepare_ref_spec", "prepare_fallback", policy_snapshot, worker_state
return "prepare_audio", "prepare_fallback", policy_snapshot, worker_state
return "idle", "no_pending_work", policy_snapshot, worker_state

View File

@ -0,0 +1,382 @@
from __future__ import annotations
import asyncio
import threading
import time
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Deque, Dict, Optional, Sequence
import numpy as np
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SRequestState
@dataclass
class DefaultReferenceState:
ref_audio_path: str | None = None
updated_at: float = 0.0
class ReferenceRegistry:
def __init__(self) -> None:
self._lock = threading.Lock()
self._state = DefaultReferenceState()
def set_default(self, ref_audio_path: str) -> DefaultReferenceState:
with self._lock:
self._state = DefaultReferenceState(ref_audio_path=str(ref_audio_path), updated_at=time.time())
return self._state
def clear(self) -> DefaultReferenceState:
with self._lock:
self._state = DefaultReferenceState()
return self._state
def get_default(self) -> DefaultReferenceState:
with self._lock:
return DefaultReferenceState(
ref_audio_path=self._state.ref_audio_path,
updated_at=self._state.updated_at,
)
@dataclass
class ModelRegistryState:
t2s_weights_path: str
vits_weights_path: str
generation: int = 0
t2s_generation: int = 0
vits_generation: int = 0
updated_at: float = field(default_factory=time.time)
class ModelRegistry:
def __init__(self, t2s_weights_path: str, vits_weights_path: str) -> None:
self._lock = threading.Lock()
self._state = ModelRegistryState(
t2s_weights_path=str(t2s_weights_path),
vits_weights_path=str(vits_weights_path),
)
def snapshot(self) -> ModelRegistryState:
with self._lock:
return ModelRegistryState(
t2s_weights_path=self._state.t2s_weights_path,
vits_weights_path=self._state.vits_weights_path,
generation=self._state.generation,
t2s_generation=self._state.t2s_generation,
vits_generation=self._state.vits_generation,
updated_at=self._state.updated_at,
)
def mark_t2s_reload(self, weights_path: str) -> ModelRegistryState:
with self._lock:
self._state.t2s_weights_path = str(weights_path)
self._state.generation += 1
self._state.t2s_generation += 1
self._state.updated_at = time.time()
return ModelRegistryState(
t2s_weights_path=self._state.t2s_weights_path,
vits_weights_path=self._state.vits_weights_path,
generation=self._state.generation,
t2s_generation=self._state.t2s_generation,
vits_generation=self._state.vits_generation,
updated_at=self._state.updated_at,
)
def mark_vits_reload(self, weights_path: str) -> ModelRegistryState:
with self._lock:
self._state.vits_weights_path = str(weights_path)
self._state.generation += 1
self._state.vits_generation += 1
self._state.updated_at = time.time()
return ModelRegistryState(
t2s_weights_path=self._state.t2s_weights_path,
vits_weights_path=self._state.vits_weights_path,
generation=self._state.generation,
t2s_generation=self._state.t2s_generation,
vits_generation=self._state.vits_generation,
updated_at=self._state.updated_at,
)
class EngineStatus:
NEW = "NEW"
QUEUED = "QUEUED"
VALIDATED = "VALIDATED"
CPU_PREPARING = "CPU_PREPARING"
GPU_PREPARING = "GPU_PREPARING"
READY_FOR_PREFILL = "READY_FOR_PREFILL"
ACTIVE_DECODE = "ACTIVE_DECODE"
READY_FOR_FINALIZE = "READY_FOR_FINALIZE"
FINALIZING = "FINALIZING"
STREAMING = "STREAMING"
COMPLETED = "COMPLETED"
FAILED = "FAILED"
@dataclass
class EngineRequestState:
request_id: str
api_mode: str
backend: str
media_type: str
response_streaming: bool
submit_ts: float
deadline_ts: float | None = None
status: str = EngineStatus.NEW
updated_ts: float = 0.0
error: str | None = None
finish_reason: str | None = None
meta: Dict[str, Any] = field(default_factory=dict)
profile: Dict[str, Any] = field(default_factory=dict)
lifecycle_timestamps: Dict[str, float] = field(default_factory=dict)
def to_summary(self) -> Dict[str, Any]:
return {
"request_id": self.request_id,
"api_mode": self.api_mode,
"backend": self.backend,
"media_type": self.media_type,
"response_streaming": self.response_streaming,
"status": self.status,
"submit_ts": self.submit_ts,
"updated_ts": self.updated_ts,
"deadline_ts": self.deadline_ts,
"error": self.error,
"finish_reason": self.finish_reason,
"meta": dict(self.meta),
"profile": dict(self.profile),
"lifecycle_timestamps": dict(self.lifecycle_timestamps),
}
class EngineRequestRegistry:
def __init__(self, recent_limit: int) -> None:
self.lock = threading.Lock()
self.active_requests: Dict[str, EngineRequestState] = {}
self.recent_requests: Deque[EngineRequestState] = deque()
self.recent_limit = max(1, int(recent_limit))
def register(
self,
*,
request_id: str,
api_mode: str,
backend: str,
media_type: str,
response_streaming: bool,
deadline_ts: float | None = None,
meta: Optional[Dict[str, Any]] = None,
) -> EngineRequestState:
now = time.perf_counter()
state = EngineRequestState(
request_id=request_id,
api_mode=api_mode,
backend=backend,
media_type=media_type,
response_streaming=bool(response_streaming),
submit_ts=now,
deadline_ts=deadline_ts,
updated_ts=now,
meta=dict(meta or {}),
lifecycle_timestamps={EngineStatus.NEW: now},
)
with self.lock:
self.active_requests[request_id] = state
return state
def _move_to_recent_locked(self, state: EngineRequestState) -> None:
self.recent_requests.appendleft(state)
while len(self.recent_requests) > self.recent_limit:
self.recent_requests.pop()
@staticmethod
def _apply_state_extra(state: EngineRequestState, extra: Optional[Dict[str, Any]]) -> None:
if not extra:
return
payload = dict(extra)
backend = payload.pop("backend", None)
if backend is not None:
state.backend = str(backend)
finish_reason = payload.pop("finish_reason", None)
if finish_reason is not None:
state.finish_reason = str(finish_reason)
error = payload.pop("error", None)
if error is not None:
state.error = str(error)
state.profile.update(payload)
def update(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None:
now = time.perf_counter()
with self.lock:
state = self.active_requests.get(request_id)
if state is None:
return
state.status = str(status)
state.updated_ts = now
state.lifecycle_timestamps[str(status)] = now
self._apply_state_extra(state, extra)
def merge_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
if not extra:
return
now = time.perf_counter()
with self.lock:
state = self.active_requests.get(request_id)
if state is None:
for recent_state in self.recent_requests:
if recent_state.request_id == request_id:
state = recent_state
break
if state is None:
return
state.updated_ts = now
self._apply_state_extra(state, extra)
def complete(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
now = time.perf_counter()
with self.lock:
state = self.active_requests.pop(request_id, None)
if state is None:
return
state.status = EngineStatus.COMPLETED
state.updated_ts = now
state.lifecycle_timestamps[EngineStatus.COMPLETED] = now
self._apply_state_extra(state, extra)
self._move_to_recent_locked(state)
def fail(self, request_id: str, error: str) -> None:
now = time.perf_counter()
with self.lock:
state = self.active_requests.pop(request_id, None)
if state is None:
return
state.status = EngineStatus.FAILED
state.updated_ts = now
state.error = str(error)
state.lifecycle_timestamps[EngineStatus.FAILED] = now
self._move_to_recent_locked(state)
def snapshot(self) -> Dict[str, Any]:
with self.lock:
active = [state.to_summary() for state in self.active_requests.values()]
recent = [state.to_summary() for state in list(self.recent_requests)]
recent_limit = self.recent_limit
active.sort(key=lambda item: item["submit_ts"])
return {
"active_count": len(active),
"recent_count": len(recent),
"recent_limit": recent_limit,
"active_requests": active,
"recent_requests": recent,
}
def collect_summaries(self, request_ids: Sequence[str]) -> list[Dict[str, Any]]:
requested = set(request_ids)
results: list[Dict[str, Any]] = []
with self.lock:
for state in self.active_requests.values():
if state.request_id in requested:
results.append(state.to_summary())
existing_ids = {item["request_id"] for item in results}
for state in self.recent_requests:
if state.request_id in requested and state.request_id not in existing_ids:
results.append(state.to_summary())
results.sort(key=lambda item: item["request_id"])
return results
def has_active(self, request_id: str) -> bool:
with self.lock:
return request_id in self.active_requests
@dataclass
class SchedulerPendingJob:
request_id: str
state: T2SRequestState
done_event: threading.Event
done_loop: asyncio.AbstractEventLoop | None
done_future: asyncio.Future | None
enqueue_time: float
speed_factor: float
sample_steps: int
media_type: str
super_sampling: bool = False
admission_wait_ms: float = 0.0
engine_policy_wait_ms: float = 0.0
engine_dispatch_wait_ms: float = 0.0
prepare_wall_ms: float = 0.0
prepare_profile_total_ms: float = 0.0
first_schedule_time: float | None = None
prefill_ms: float = 0.0
merge_ms: float = 0.0
decode_ms: float = 0.0
finalize_wait_ms: float = 0.0
synth_ms: float = 0.0
pack_ms: float = 0.0
decode_steps: int = 0
result_ready_time: float | None = None
result: dict | None = None
sample_rate: int | None = None
audio_data: np.ndarray | None = None
error: str | None = None
engine_request_id: str | None = None
class SchedulerJobRegistry:
def __init__(self, lock: threading.Lock | threading.RLock | threading.Condition) -> None:
self._lock = lock
self._job_map: Dict[str, SchedulerPendingJob] = {}
self._total_submitted = 0
self._total_finished = 0
def register(self, job: SchedulerPendingJob, *, keep_job: bool = True) -> None:
with self._lock:
if keep_job:
self._job_map[job.request_id] = job
self._total_submitted += 1
def get(self, request_id: str) -> SchedulerPendingJob | None:
with self._lock:
return self._job_map.get(request_id)
def pop(self, request_id: str) -> SchedulerPendingJob | None:
with self._lock:
return self._job_map.pop(request_id, None)
def remove(self, request_id: str) -> None:
with self._lock:
self._job_map.pop(request_id, None)
def mark_finished(self) -> None:
with self._lock:
self._total_finished += 1
def mark_finished_and_remove(self, request_id: str) -> None:
with self._lock:
self._job_map.pop(request_id, None)
self._total_finished += 1
def is_empty(self) -> bool:
with self._lock:
return not self._job_map
def submitted_count(self) -> int:
with self._lock:
return int(self._total_submitted)
def finished_count(self) -> int:
with self._lock:
return int(self._total_finished)
def snapshot(self, max_request_ids: int = 32) -> Dict[str, Any]:
with self._lock:
request_ids = list(self._job_map.keys())
return {
"job_count": int(len(request_ids)),
"request_ids": request_ids[: max(0, int(max_request_ids))],
"total_submitted": int(self._total_submitted),
"total_finished": int(self._total_finished),
}

View File

@ -0,0 +1,362 @@
from __future__ import annotations
import asyncio
import threading
import time
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Callable, Deque, Dict, List, Optional, Sequence
from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PreparedCpuStage
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SActiveBatch, T2SFinishedItem, T2SRequestState
from GPT_SoVITS.TTS_infer_pack.unified_engine_component_registry import SchedulerPendingJob
class EngineTaskQueueOwner:
def __init__(self, completion_key: str = "total_completed") -> None:
self.condition = threading.Condition()
self.queue: Deque[Any] = deque()
self.total_submitted = 0
self.total_completed = 0
self.peak_waiting = 0
self.completion_key = str(completion_key)
def enqueue(self, item: Any) -> None:
with self.condition:
self.queue.append(item)
self.total_submitted += 1
self.peak_waiting = max(self.peak_waiting, len(self.queue))
self.condition.notify_all()
def enqueue_many(self, items: Sequence[Any]) -> None:
if not items:
return
with self.condition:
for item in items:
self.queue.append(item)
self.total_submitted += len(items)
self.peak_waiting = max(self.peak_waiting, len(self.queue))
self.condition.notify_all()
def pop_left(self) -> Any | None:
with self.condition:
if not self.queue:
return None
return self.queue.popleft()
def pop_left_many(self, max_items: int) -> List[Any]:
limit = max(1, int(max_items))
with self.condition:
if not self.queue:
return []
selected: List[Any] = []
while self.queue and len(selected) < limit:
selected.append(self.queue.popleft())
return selected
def mark_completed(self, count: int = 1, *, notify: bool = False) -> None:
if count <= 0:
return
with self.condition:
self.total_completed += int(count)
if notify:
self.condition.notify_all()
def has_items(self) -> bool:
with self.condition:
return bool(self.queue)
def waiting_count(self) -> int:
with self.condition:
return int(len(self.queue))
def snapshot(self, *, max_request_ids: int = 16, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
with self.condition:
waiting_items = list(self.queue)[: max(0, int(max_request_ids))]
snapshot = {
"waiting_count": int(len(self.queue)),
"waiting_request_ids": [str(getattr(item, "request_id", "")) for item in waiting_items],
"peak_waiting": int(self.peak_waiting),
"total_submitted": int(self.total_submitted),
self.completion_key: int(self.total_completed),
}
if extra:
snapshot.update(dict(extra))
return snapshot
def peek_oldest_age_ms(self, timestamp_attr: str) -> float:
with self.condition:
if not self.queue:
return 0.0
enqueue_time = float(getattr(self.queue[0], timestamp_attr))
return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0)
def is_drained(self) -> bool:
with self.condition:
return not self.queue and self.total_submitted == self.total_completed
def take_finalize_batch(
self,
*,
finalize_mode: str,
batch_max_items: int,
batch_wait_s: float,
use_vocoder: bool,
) -> List[SchedulerFinalizeTask]:
with self.condition:
if not self.queue:
return []
selected_tasks = [self.queue.popleft()]
if finalize_mode == "sync" or use_vocoder:
return selected_tasks
if batch_max_items <= 1:
return selected_tasks
first_task = selected_tasks[0]
oldest_age_s = max(0.0, time.perf_counter() - first_task.enqueued_time)
if len(self.queue) + 1 < batch_max_items and oldest_age_s < batch_wait_s:
self.queue.appendleft(first_task)
return []
while len(selected_tasks) < batch_max_items:
if not self.queue:
break
matched_index = None
for index, task in enumerate(self.queue):
if abs(task.enqueued_time - first_task.enqueued_time) < 1.0:
matched_index = index
break
if matched_index is None:
break
selected_tasks.append(self.queue[matched_index])
del self.queue[matched_index]
return selected_tasks
@dataclass
class EngineDecodeRuntimeState:
pending_jobs: int = 0
pending_request_ids: List[str] = field(default_factory=list)
active_request_count: int = 0
active_request_ids: List[str] = field(default_factory=list)
prefill_done: bool = False
decode_step_index_max: int = 0
total_cycles: int = 0
prefill_cycles: int = 0
step_cycles: int = 0
has_work: bool = False
last_event: str = "init"
updated_at: float = 0.0
class EngineDecodeRuntimeOwner:
def __init__(
self,
*,
get_decode_runtime_counters: Callable[[], Dict[str, int]],
get_micro_batch_wait_s: Callable[[], float],
) -> None:
self.get_decode_runtime_counters = get_decode_runtime_counters
self.get_micro_batch_wait_s = get_micro_batch_wait_s
self.condition = threading.Condition()
self.pending_jobs: Deque[SchedulerPendingJob] = deque()
self.active_batch: T2SActiveBatch | None = None
self.state_lock = threading.Lock()
self.state = EngineDecodeRuntimeState(updated_at=time.perf_counter())
@staticmethod
def summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]:
if active_batch is None:
return {}
decode_step_index_max = 0
if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0:
decode_step_index_max = int(active_batch.step_indices.max().item())
return {
"request_count": int(len(active_batch.request_ids)),
"request_ids": list(active_batch.request_ids),
"prefill_done": bool(active_batch.prefill_done),
"decode_step_index_max": int(decode_step_index_max),
}
def snapshot_pending_queue_state(self) -> Dict[str, Any]:
with self.condition:
return {
"pending_jobs": int(len(self.pending_jobs)),
"pending_request_ids": [job.request_id for job in list(self.pending_jobs)[:32]],
}
def enqueue_pending_job(self, job: SchedulerPendingJob) -> None:
with self.condition:
self.pending_jobs.append(job)
self.condition.notify_all()
self.refresh_state("engine_decode_pending_enqueue")
def take_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
with self.condition:
if not self.pending_jobs:
return []
if wait_for_batch:
oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time)
if (time.perf_counter() - oldest_enqueue_time) < self.get_micro_batch_wait_s():
return []
pending_jobs = list(self.pending_jobs)
self.pending_jobs.clear()
self.refresh_state("engine_decode_pending_dequeue")
return pending_jobs
def pending_age_ms(self) -> float:
with self.condition:
if not self.pending_jobs:
return 0.0
enqueue_time = float(self.pending_jobs[0].enqueue_time)
return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0)
def has_pending_jobs(self) -> bool:
with self.condition:
return bool(self.pending_jobs)
def get_active_batch(self) -> T2SActiveBatch | None:
return self.active_batch
def set_active_batch(self, active_batch: T2SActiveBatch | None) -> None:
self.active_batch = active_batch
def active_batch_summary(self) -> Dict[str, Any]:
return self.summarize_active_batch(self.active_batch)
def refresh_state(self, last_event: str) -> None:
pending_state = self.snapshot_pending_queue_state()
active_batch_summary = self.active_batch_summary()
worker_decode_counters = self.get_decode_runtime_counters()
with self.state_lock:
self.state.pending_jobs = int(pending_state.get("pending_jobs", 0))
self.state.pending_request_ids = list(pending_state.get("pending_request_ids", []))
self.state.active_request_count = int(active_batch_summary.get("request_count", 0))
self.state.active_request_ids = list(active_batch_summary.get("request_ids", []))[:32]
self.state.prefill_done = bool(active_batch_summary.get("prefill_done", False))
self.state.decode_step_index_max = int(active_batch_summary.get("decode_step_index_max", 0))
self.state.total_cycles = int(worker_decode_counters.get("total_cycles", 0))
self.state.prefill_cycles = int(worker_decode_counters.get("prefill_cycles", 0))
self.state.step_cycles = int(worker_decode_counters.get("step_cycles", 0))
self.state.has_work = bool(pending_state.get("pending_jobs", 0) or active_batch_summary.get("request_count", 0))
self.state.last_event = str(last_event)
self.state.updated_at = float(time.perf_counter())
def update_from_worker_snapshot(self, snapshot: Dict[str, Any]) -> None:
if not snapshot:
return
pending_state = self.snapshot_pending_queue_state()
with self.state_lock:
self.state.pending_jobs = int(pending_state.get("pending_jobs", 0))
self.state.pending_request_ids = list(pending_state.get("pending_request_ids", []))
self.state.active_request_count = int(snapshot.get("active_request_count", 0))
self.state.active_request_ids = list(snapshot.get("active_request_ids", []))[:32]
self.state.prefill_done = bool(snapshot.get("prefill_done", False))
self.state.decode_step_index_max = int(snapshot.get("decode_step_index_max", 0))
self.state.total_cycles = int(snapshot.get("total_cycles", 0))
self.state.prefill_cycles = int(snapshot.get("prefill_cycles", 0))
self.state.step_cycles = int(snapshot.get("step_cycles", 0))
self.state.has_work = bool(
pending_state.get("pending_jobs", 0)
or snapshot.get("active_request_count", 0)
or snapshot.get("has_work", False)
)
self.state.last_event = str(snapshot.get("last_event", "unknown"))
self.state.updated_at = float(snapshot.get("updated_at", time.perf_counter()))
def snapshot_state(self) -> Dict[str, Any]:
pending_state = self.snapshot_pending_queue_state()
active_batch_summary = self.active_batch_summary()
worker_decode_counters = self.get_decode_runtime_counters()
with self.state_lock:
return {
"pending_jobs": int(pending_state.get("pending_jobs", self.state.pending_jobs)),
"pending_request_ids": list(pending_state.get("pending_request_ids", self.state.pending_request_ids)),
"active_request_count": int(active_batch_summary.get("request_count", self.state.active_request_count)),
"active_request_ids": list(active_batch_summary.get("request_ids", self.state.active_request_ids)),
"prefill_done": bool(active_batch_summary.get("prefill_done", self.state.prefill_done)),
"decode_step_index_max": int(active_batch_summary.get("decode_step_index_max", self.state.decode_step_index_max)),
"total_cycles": int(worker_decode_counters.get("total_cycles", 0)),
"prefill_cycles": int(worker_decode_counters.get("prefill_cycles", 0)),
"step_cycles": int(worker_decode_counters.get("step_cycles", 0)),
"has_work": bool(
pending_state.get("pending_jobs", 0)
or active_batch_summary.get("request_count", self.state.active_request_count)
or self.state.has_work
),
"last_event": str(self.state.last_event),
"updated_at": float(self.state.updated_at),
}
@dataclass
class SchedulerFinalizeTask:
request_id: str
item: T2SFinishedItem
enqueued_time: float
@dataclass
class EngineDispatchTask:
request_id: str
state: T2SRequestState
speed_factor: float
sample_steps: int
media_type: str
super_sampling: bool
prepare_wall_ms: float
prepare_profile_total_ms: float
done_loop: asyncio.AbstractEventLoop | None
done_future: asyncio.Future | None
engine_request_id: str | None
timeout_sec: float | None
enqueue_time: float
worker_job: SchedulerPendingJob | None = None
engine_policy_wait_ms: float = 0.0
engine_dispatch_wait_ms: float = 0.0
engine_policy_snapshot: Dict[str, Any] | None = None
error: str | None = None
@dataclass
class EngineGpuPrepareTask:
request_id: str
cpu_stage: PreparedCpuStage
done_loop: asyncio.AbstractEventLoop | None
done_future: asyncio.Future | None
engine_request_id: str | None
enqueue_time: float
phase: str = "audio"
audio_enqueue_time: float = 0.0
audio_start_time: float = 0.0
audio_end_time: float = 0.0
text_enqueue_time: float = 0.0
text_start_time: float = 0.0
text_end_time: float = 0.0
ref_spec_enqueue_time: float = 0.0
ref_spec_start_time: float = 0.0
ref_spec_end_time: float = 0.0
audio_queue_wait_ms: float = 0.0
text_queue_wait_ms: float = 0.0
ref_spec_queue_wait_ms: float = 0.0
admission_wait_ms: float = 0.0
phase_one: Dict[str, Any] | None = None
ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]] | None = None
state_result: T2SRequestState | None = None
cancelled: bool = False
error: str | None = None
@dataclass
class EngineFinalizeQueueState:
waiting_count: int
waiting_request_ids: List[str]
peak_waiting: int
total_submitted: int
total_completed: int
@dataclass
class RuntimeStateCallbacks:
update: Callable[[str, str, Optional[Dict[str, Any]]], None] | None = None
complete: Callable[[str, Optional[Dict[str, Any]]], None] | None = None
fail: Callable[[str, str], None] | None = None
decode_runtime_update: Callable[[Dict[str, Any]], None] | None = None

View File

@ -0,0 +1,63 @@
from GPT_SoVITS.TTS_infer_pack.unified_engine_component_models import (
DirectTTSExecution,
NormalizedEngineRequest,
RuntimeControlCallbacks,
SchedulerDebugExecution,
SchedulerSubmitExecution,
)
from GPT_SoVITS.TTS_infer_pack.unified_engine_component_policy import (
EngineArbiterConfig,
EngineArbiterState,
EnginePolicyArbiterController,
EnginePolicyConfig,
)
from GPT_SoVITS.TTS_infer_pack.unified_engine_component_registry import (
DefaultReferenceState,
EngineRequestRegistry,
EngineRequestState,
EngineStatus,
ModelRegistry,
ModelRegistryState,
ReferenceRegistry,
SchedulerJobRegistry,
SchedulerPendingJob,
)
from GPT_SoVITS.TTS_infer_pack.unified_engine_component_runtime import (
EngineDecodeRuntimeOwner,
EngineDecodeRuntimeState,
EngineDispatchTask,
EngineFinalizeQueueState,
EngineGpuPrepareTask,
EngineTaskQueueOwner,
RuntimeStateCallbacks,
SchedulerFinalizeTask,
)
__all__ = [
"DefaultReferenceState",
"DirectTTSExecution",
"EngineArbiterConfig",
"EngineArbiterState",
"EngineDecodeRuntimeOwner",
"EngineDecodeRuntimeState",
"EngineDispatchTask",
"EngineFinalizeQueueState",
"EngineGpuPrepareTask",
"EnginePolicyArbiterController",
"EnginePolicyConfig",
"EngineRequestRegistry",
"EngineRequestState",
"EngineStatus",
"EngineTaskQueueOwner",
"ModelRegistry",
"ModelRegistryState",
"NormalizedEngineRequest",
"ReferenceRegistry",
"RuntimeControlCallbacks",
"RuntimeStateCallbacks",
"SchedulerDebugExecution",
"SchedulerFinalizeTask",
"SchedulerJobRegistry",
"SchedulerPendingJob",
"SchedulerSubmitExecution",
]

View File

@ -0,0 +1,9 @@
from GPT_SoVITS.TTS_infer_pack.unified_engine_api_delegates import EngineApiDelegates
from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge_delegates import EngineBridgeDelegates
from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime_delegates import EngineRuntimeDelegates
__all__ = [
"EngineApiDelegates",
"EngineBridgeDelegates",
"EngineRuntimeDelegates",
]

View File

@ -0,0 +1,116 @@
from __future__ import annotations
from typing import Any, Callable, Dict
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDecodeRuntimeOwner, EngineTaskQueueOwner
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_executor import EngineStageExecutor
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker
class EngineStageOrchestrator:
def __init__(
self,
*,
executor: EngineStageExecutor,
scheduler_worker: UnifiedSchedulerWorker,
prepare_queue_owner: EngineTaskQueueOwner,
prepare_text_queue_owner: EngineTaskQueueOwner,
prepare_ref_spec_queue_owner: EngineTaskQueueOwner,
finalize_queue_owner: EngineTaskQueueOwner,
dispatch_queue_owner: EngineTaskQueueOwner,
decode_runtime_owner: EngineDecodeRuntimeOwner,
snapshot_engine_decode_runtime_state: Callable[[], Dict[str, Any]],
) -> None:
self.executor = executor
self.scheduler_worker = scheduler_worker
self.prepare_queue_owner = prepare_queue_owner
self.prepare_text_queue_owner = prepare_text_queue_owner
self.prepare_ref_spec_queue_owner = prepare_ref_spec_queue_owner
self.finalize_queue_owner = finalize_queue_owner
self.dispatch_queue_owner = dispatch_queue_owner
self.decode_runtime_owner = decode_runtime_owner
self.snapshot_engine_decode_runtime_state = snapshot_engine_decode_runtime_state
self._select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]] | None = None
self._mark_arbiter_tick: Callable[[str, str, bool], None] | None = None
self._wait_arbiter: Callable[[], None] | None = None
def bind_arbiter(
self,
*,
notify_arbiter: Callable[[], None],
select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]],
mark_arbiter_tick: Callable[[str, str, bool], None],
wait_arbiter: Callable[[], None],
) -> None:
self.executor.bind_notify_arbiter(notify_arbiter)
self._select_stage = select_stage
self._mark_arbiter_tick = mark_arbiter_tick
self._wait_arbiter = wait_arbiter
def peek_queue_age_ms(self, queue_name: str) -> float:
if queue_name == "prepare":
return max(
self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time"),
self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time"),
self.prepare_ref_spec_queue_owner.peek_oldest_age_ms("enqueue_time"),
)
if queue_name == "prepare_audio":
return self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time")
if queue_name == "prepare_text":
return self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time")
if queue_name == "prepare_ref_spec":
return self.prepare_ref_spec_queue_owner.peek_oldest_age_ms("enqueue_time")
if queue_name == "finalize":
return self.finalize_queue_owner.peek_oldest_age_ms("enqueued_time")
if queue_name == "decode_runtime_pending":
return self.decode_runtime_owner.pending_age_ms()
return self.dispatch_queue_owner.peek_oldest_age_ms("enqueue_time")
def has_pending_work(self) -> bool:
if self.scheduler_worker.is_engine_decode_control_enabled():
if self.decode_runtime_owner.has_pending_jobs():
return True
if self.scheduler_worker.is_engine_decode_control_enabled() and self.snapshot_engine_decode_runtime_state().get(
"active_request_count", 0
) > 0:
return True
if self.prepare_queue_owner.has_items():
return True
if self.prepare_text_queue_owner.has_items():
return True
if self.prepare_ref_spec_queue_owner.has_items():
return True
if self.finalize_queue_owner.has_items():
return True
return self.dispatch_queue_owner.has_items()
def run_engine_arbiter_loop(self) -> None:
if self._select_stage is None or self._mark_arbiter_tick is None or self._wait_arbiter is None:
raise RuntimeError("arbiter callbacks are not bound")
while True:
if not self.has_pending_work():
self._mark_arbiter_tick("idle", "no_pending_work", True)
self._wait_arbiter()
continue
stage, reason, policy_snapshot, worker_state = self._select_stage()
policy_allowed = bool(policy_snapshot.get("allowed", True))
executed = False
if stage == "prepare":
executed = self.executor.run_engine_prepare_once()
elif stage == "prepare_audio":
executed = self.executor.run_engine_prepare_audio_once()
elif stage == "prepare_text":
executed = self.executor.run_engine_prepare_text_once()
elif stage == "prepare_ref_spec":
executed = self.executor.run_engine_prepare_ref_spec_once()
elif stage == "finalize":
executed = self.executor.run_engine_finalize_once()
elif stage == "decode_dispatch":
executed = self.executor.run_engine_dispatch_once(policy_snapshot, worker_state)
elif stage == "decode_runtime":
executed = self.executor.run_engine_decode_runtime_once()
if not executed:
self._mark_arbiter_tick("idle", f"{stage}_not_ready", policy_allowed)
self._wait_arbiter()
continue
self._mark_arbiter_tick(stage, reason, policy_allowed)

View File

@ -0,0 +1,53 @@
from __future__ import annotations
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, SchedulerDebugExecution, SchedulerSubmitExecution
class EnginePublicInterface:
PUBLIC_API_METHODS = (
"run_direct_tts_async",
"run_scheduler_submit",
"run_scheduler_debug",
"get_runtime_state",
"set_refer_audio",
"set_gpt_weights",
"set_sovits_weights",
"handle_control",
)
async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution:
return await self.api_facade.run_direct_tts_async(req)
async def run_scheduler_debug(self, request_items: list[dict], max_steps: int, seed: int) -> SchedulerDebugExecution:
return await self.api_facade.run_scheduler_debug(request_items, max_steps, seed)
async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution:
return await self.api_facade.run_scheduler_submit(payload)
def get_runtime_state(self) -> dict:
return self.runtime_facade.get_runtime_state()
def set_refer_audio(self, refer_audio_path: str | None) -> dict:
return self.runtime_facade.set_refer_audio(refer_audio_path)
def set_gpt_weights(self, weights_path: str) -> dict:
return self.runtime_facade.set_gpt_weights(weights_path)
def set_sovits_weights(self, weights_path: str) -> dict:
return self.runtime_facade.set_sovits_weights(weights_path)
def handle_control(self, command: str) -> None:
self.runtime_facade.handle_control(command)
class EngineCompatInterface:
COMPAT_API_METHODS = (
"run_direct_tts",
"get_scheduler_state",
)
def run_direct_tts(self, req: dict) -> DirectTTSExecution:
return self.api_facade.run_direct_tts(req)
def get_scheduler_state(self) -> dict:
return self.runtime_facade.get_scheduler_state()

View File

@ -0,0 +1,198 @@
from __future__ import annotations
import os
import signal
import sys
from typing import Any, Dict, Optional
class EngineRuntimeFacade:
def __init__(self, owner: Any) -> None:
self.owner = owner
@property
def tts(self):
return self.owner.tts
@property
def reference_registry(self):
return self.owner.reference_registry
@property
def model_registry(self):
return self.owner.model_registry
@property
def scheduler_worker(self):
return self.owner.scheduler_worker
@property
def engine_decode_runtime_owner(self):
return self.owner.engine_decode_runtime_owner
@property
def engine_policy_arbiter(self):
return self.owner.engine_policy_arbiter
@property
def management_lock(self):
return self.owner.management_lock
@property
def direct_tts_lock(self):
return self.owner.direct_tts_lock
@property
def control_callbacks(self):
return self.owner.control_callbacks
@staticmethod
def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None:
if component is None or not hasattr(component, "snapshot"):
return None
try:
return dict(component.snapshot())
except Exception:
return None
def _build_stage_counters(
self,
request_registry: Dict[str, Any],
worker_state: Dict[str, Any],
) -> Dict[str, Any]:
return self.engine_policy_arbiter.build_stage_counters(request_registry, worker_state)
def _build_engine_policy_snapshot(
self,
request_registry: Dict[str, Any],
worker_state: Dict[str, Any],
) -> Dict[str, Any]:
return self.engine_policy_arbiter.build_policy_snapshot(request_registry, worker_state)
def _build_stage_summary(
self,
request_registry: Dict[str, Any],
worker_state: Dict[str, Any],
) -> Dict[str, Any]:
counters = self._build_stage_counters(request_registry, worker_state)
bert_worker_state = self._safe_component_snapshot(getattr(self.tts, "prepare_bert_batch_worker", None))
ref_semantic_worker_state = self._safe_component_snapshot(getattr(self.tts, "prepare_ref_semantic_batch_worker", None))
text_preprocessor_state = self._safe_component_snapshot(getattr(self.tts, "text_preprocessor", None))
return {
**counters,
"engine_drained": bool(self.owner._is_engine_drained()),
"admission_config": {
"decode_backlog_max": int(worker_state.get("decode_backlog_max", 0)),
"finalize_pending_max": int(worker_state.get("finalize_pending_max", 0)),
},
"engine_policy": self._build_engine_policy_snapshot(request_registry, worker_state),
"engine_arbiter_state": self.owner._snapshot_engine_arbiter_state(),
"engine_decode_runtime_state": self.owner._snapshot_engine_decode_runtime_state(),
"engine_job_registry": self.owner._snapshot_engine_job_registry(),
"engine_active_batch_state": self.engine_decode_runtime_owner.active_batch_summary(),
"engine_prepare_state": self.owner._snapshot_engine_prepare_state(),
"engine_finalize_state": self.owner._snapshot_engine_finalize_state(),
"engine_dispatcher_state": self.owner._snapshot_engine_dispatch_state(),
"active_batch": dict(worker_state.get("active_batch") or {}),
"prepare_state": dict(worker_state.get("prepare_state") or {}),
"bert_batch_worker_state": bert_worker_state,
"ref_semantic_worker_state": ref_semantic_worker_state,
"text_preprocessor_state": text_preprocessor_state,
}
def get_scheduler_state(self) -> dict:
return self.scheduler_worker.snapshot()
def get_runtime_state(self) -> dict:
model_state = self.model_registry.snapshot()
default_ref = self.reference_registry.get_default()
scheduler_state = self.get_scheduler_state()
request_registry = self.owner._snapshot_request_registry()
engine_policy = self._build_engine_policy_snapshot(request_registry, scheduler_state)
engine_arbiter_state = self.owner._snapshot_engine_arbiter_state()
engine_decode_runtime_state = self.owner._snapshot_engine_decode_runtime_state()
engine_job_registry = self.owner._snapshot_engine_job_registry()
engine_prepare_state = self.owner._snapshot_engine_prepare_state()
engine_finalize_state = self.owner._snapshot_engine_finalize_state()
engine_dispatcher_state = self.owner._snapshot_engine_dispatch_state()
engine_drained = self.owner._is_engine_drained()
return {
"message": "success",
"default_reference": {
"ref_audio_path": default_ref.ref_audio_path,
"updated_at": default_ref.updated_at,
},
"model_registry": {
"generation": model_state.generation,
"t2s_generation": model_state.t2s_generation,
"vits_generation": model_state.vits_generation,
"t2s_weights_path": model_state.t2s_weights_path,
"vits_weights_path": model_state.vits_weights_path,
"updated_at": model_state.updated_at,
},
"worker_state": scheduler_state,
"engine_policy": engine_policy,
"engine_arbiter_state": engine_arbiter_state,
"engine_decode_runtime_state": engine_decode_runtime_state,
"engine_job_registry": engine_job_registry,
"engine_active_batch_state": self.engine_decode_runtime_owner.active_batch_summary(),
"engine_prepare_state": engine_prepare_state,
"engine_finalize_state": engine_finalize_state,
"engine_dispatcher_state": engine_dispatcher_state,
"engine_drained": bool(engine_drained),
"request_registry": request_registry,
"stage_summary": self._build_stage_summary(request_registry, scheduler_state),
}
def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None:
if not self.scheduler_worker.wait_until_idle(timeout_sec=timeout_sec):
raise TimeoutError("scheduler worker did not drain before model reload")
def set_refer_audio(self, refer_audio_path: str | None) -> dict:
if refer_audio_path in [None, ""]:
state = self.reference_registry.clear()
return {"message": "success", "default_ref_audio_path": state.ref_audio_path}
if not os.path.exists(str(refer_audio_path)):
raise FileNotFoundError(f"{refer_audio_path} not exists")
with self.management_lock:
with self.direct_tts_lock:
self.tts.set_ref_audio(str(refer_audio_path))
state = self.reference_registry.set_default(str(refer_audio_path))
return {"message": "success", "default_ref_audio_path": state.ref_audio_path}
def set_gpt_weights(self, weights_path: str) -> dict:
if weights_path in ["", None]:
raise ValueError("gpt weight path is required")
with self.management_lock:
self._wait_for_safe_reload()
with self.direct_tts_lock:
self.tts.init_t2s_weights(weights_path)
self.tts.refresh_runtime_components()
state = self.model_registry.mark_t2s_reload(str(weights_path))
return {"message": "success", "t2s_generation": state.t2s_generation, "generation": state.generation}
def set_sovits_weights(self, weights_path: str) -> dict:
if weights_path in ["", None]:
raise ValueError("sovits weight path is required")
with self.management_lock:
self._wait_for_safe_reload()
with self.direct_tts_lock:
self.tts.init_vits_weights(weights_path)
self.tts.refresh_runtime_components()
state = self.model_registry.mark_vits_reload(str(weights_path))
return {"message": "success", "vits_generation": state.vits_generation, "generation": state.generation}
def handle_control(self, command: str) -> None:
if command == "restart":
if self.control_callbacks.restart is None:
os.execl(sys.executable, sys.executable, *sys.argv)
self.control_callbacks.restart()
return
if command == "exit":
if self.control_callbacks.exit is None:
os.kill(os.getpid(), signal.SIGTERM)
return
self.control_callbacks.exit()
return
raise ValueError(f"unsupported command: {command}")

View File

@ -0,0 +1,46 @@
from __future__ import annotations
from typing import Any, Dict
from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime import EngineRuntimeFacade
class EngineRuntimeDelegates:
@staticmethod
def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None:
return EngineRuntimeFacade._safe_component_snapshot(component)
def _build_stage_counters(
self,
request_registry: Dict[str, Any],
worker_state: Dict[str, Any],
) -> Dict[str, Any]:
return self.runtime_facade._build_stage_counters(request_registry, worker_state)
def _build_engine_policy_snapshot(
self,
request_registry: Dict[str, Any],
worker_state: Dict[str, Any],
) -> Dict[str, Any]:
return self.runtime_facade._build_engine_policy_snapshot(request_registry, worker_state)
async def _wait_for_engine_policy_admission(
self,
*,
request_id: str | None,
timeout_sec: float | None,
) -> tuple[float, Dict[str, Any]]:
return await self.engine_policy_arbiter.wait_for_policy_admission(
request_id=request_id,
timeout_sec=timeout_sec,
)
def _build_stage_summary(
self,
request_registry: Dict[str, Any],
worker_state: Dict[str, Any],
) -> Dict[str, Any]:
return self.runtime_facade._build_stage_summary(request_registry, worker_state)
def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None:
self.runtime_facade._wait_for_safe_reload(timeout_sec=timeout_sec)

View File

@ -0,0 +1,172 @@
from __future__ import annotations
import asyncio
from typing import Callable, Dict, List, Optional
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem, T2SRequestState
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import (
EngineDecodeRuntimeOwner,
EngineDispatchTask,
EngineTaskQueueOwner,
SchedulerFinalizeTask,
SchedulerPendingJob,
)
from GPT_SoVITS.TTS_infer_pack.unified_engine_orchestration import EngineStageOrchestrator
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_executor import EngineStageExecutor
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker
class EngineStageCoordinator:
def __init__(
self,
*,
tts: TTS,
scheduler_worker: UnifiedSchedulerWorker,
prepare_queue_owner: EngineTaskQueueOwner,
prepare_text_queue_owner: EngineTaskQueueOwner,
prepare_ref_spec_queue_owner: EngineTaskQueueOwner,
finalize_queue_owner: EngineTaskQueueOwner,
dispatch_queue_owner: EngineTaskQueueOwner,
decode_runtime_owner: EngineDecodeRuntimeOwner,
update_request_state: Callable[[str, str, Optional[Dict[str, Any]]], None],
merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None],
fail_request_state: Callable[[str, str], None],
get_engine_job: Callable[[str], SchedulerPendingJob | None],
register_engine_job: Callable[[SchedulerPendingJob], None],
fail_engine_jobs: Callable[[List[str], str], None],
complete_engine_job: Callable[..., None],
add_engine_prefill_time: Callable[[List[SchedulerPendingJob], float], None],
add_engine_merge_time: Callable[[List[str], float], None],
add_engine_decode_time: Callable[[List[str], float], None],
enqueue_engine_finished_items: Callable[[List[T2SFinishedItem]], None],
snapshot_engine_dispatch_state: Callable[[], Dict[str, Any]],
snapshot_engine_decode_runtime_state: Callable[[], Dict[str, Any]],
) -> None:
self.executor = EngineStageExecutor(
tts=tts,
scheduler_worker=scheduler_worker,
prepare_queue_owner=prepare_queue_owner,
prepare_text_queue_owner=prepare_text_queue_owner,
prepare_ref_spec_queue_owner=prepare_ref_spec_queue_owner,
finalize_queue_owner=finalize_queue_owner,
dispatch_queue_owner=dispatch_queue_owner,
decode_runtime_owner=decode_runtime_owner,
update_request_state=update_request_state,
merge_request_state_profile=merge_request_state_profile,
fail_request_state=fail_request_state,
get_engine_job=get_engine_job,
register_engine_job=register_engine_job,
fail_engine_jobs=fail_engine_jobs,
complete_engine_job=complete_engine_job,
add_engine_prefill_time=add_engine_prefill_time,
add_engine_merge_time=add_engine_merge_time,
add_engine_decode_time=add_engine_decode_time,
enqueue_engine_finished_items=enqueue_engine_finished_items,
snapshot_engine_dispatch_state=snapshot_engine_dispatch_state,
snapshot_engine_decode_runtime_state=snapshot_engine_decode_runtime_state,
)
self.orchestrator = EngineStageOrchestrator(
executor=self.executor,
scheduler_worker=scheduler_worker,
prepare_queue_owner=prepare_queue_owner,
prepare_text_queue_owner=prepare_text_queue_owner,
prepare_ref_spec_queue_owner=prepare_ref_spec_queue_owner,
finalize_queue_owner=finalize_queue_owner,
dispatch_queue_owner=dispatch_queue_owner,
decode_runtime_owner=decode_runtime_owner,
snapshot_engine_decode_runtime_state=snapshot_engine_decode_runtime_state,
)
def bind_arbiter(
self,
*,
notify_arbiter: Callable[[], None],
select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]],
mark_arbiter_tick: Callable[[str, str, bool], None],
wait_arbiter: Callable[[], None],
) -> None:
self.orchestrator.bind_arbiter(
notify_arbiter=notify_arbiter,
select_stage=select_stage,
mark_arbiter_tick=mark_arbiter_tick,
wait_arbiter=wait_arbiter,
)
async def prepare_state_via_engine_gpu_queue(
self,
*,
spec,
prepare_submit_at: float,
engine_request_id: str | None,
) -> tuple[T2SRequestState, float, float]:
return await self.executor.prepare_state_via_engine_gpu_queue(
spec=spec,
prepare_submit_at=prepare_submit_at,
engine_request_id=engine_request_id,
)
def enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None:
self.executor.enqueue_worker_finished_for_finalize(tasks)
def take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]:
return self.executor.take_engine_finalize_batch_nonblocking()
async def enqueue_prepared_state_for_dispatch(
self,
*,
state: T2SRequestState,
speed_factor: float,
sample_steps: int,
media_type: str,
super_sampling: bool,
prepare_wall_ms: float,
prepare_profile_total_ms: float,
done_loop: asyncio.AbstractEventLoop | None,
done_future: asyncio.Future | None,
engine_request_id: str | None,
timeout_sec: float | None,
) -> EngineDispatchTask:
return await self.executor.enqueue_prepared_state_for_dispatch(
state=state,
speed_factor=speed_factor,
sample_steps=sample_steps,
media_type=media_type,
super_sampling=super_sampling,
prepare_wall_ms=prepare_wall_ms,
prepare_profile_total_ms=prepare_profile_total_ms,
done_loop=done_loop,
done_future=done_future,
engine_request_id=engine_request_id,
timeout_sec=timeout_sec,
)
def peek_queue_age_ms(self, queue_name: str) -> float:
return self.orchestrator.peek_queue_age_ms(queue_name)
def has_pending_work(self) -> bool:
return self.orchestrator.has_pending_work()
def run_engine_prepare_once(self) -> bool:
return self.executor.run_engine_prepare_once()
def run_engine_prepare_audio_once(self) -> bool:
return self.executor.run_engine_prepare_audio_once()
def run_engine_prepare_text_once(self) -> bool:
return self.executor.run_engine_prepare_text_once()
def run_engine_prepare_ref_spec_once(self) -> bool:
return self.executor.run_engine_prepare_ref_spec_once()
def run_engine_finalize_once(self) -> bool:
return self.executor.run_engine_finalize_once()
def run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool:
return self.executor.run_engine_dispatch_once(policy_snapshot, worker_state)
def run_engine_decode_runtime_once(self) -> bool:
return self.executor.run_engine_decode_runtime_once()
def run_engine_arbiter_loop(self) -> None:
self.orchestrator.run_engine_arbiter_loop()

View File

@ -0,0 +1,40 @@
from __future__ import annotations
class EngineDecodeStageMixin:
def run_engine_decode_runtime_once(self) -> bool:
if not self.scheduler_worker.is_engine_decode_control_enabled():
return False
runtime_state = self.snapshot_engine_decode_runtime_state()
pending_jobs = self.decode_runtime_owner.take_pending_jobs_nonblocking(
wait_for_batch=int(runtime_state.get("active_request_count", 0)) <= 0
)
result = self.scheduler_worker.execute_decode_cycle(
pending_jobs=pending_jobs,
active_batch=self.decode_runtime_owner.get_active_batch(),
external_bookkeeping=True,
)
prefill_phase = dict(result.get("prefill_phase") or {})
if prefill_phase.get("error"):
self.fail_engine_jobs(list(prefill_phase.get("error_request_ids") or []), str(prefill_phase.get("error")))
else:
prefill_jobs = list(prefill_phase.get("pending_jobs") or [])
self.add_engine_prefill_time(prefill_jobs, float(prefill_phase.get("prefill_elapsed_s", 0.0)))
self.add_engine_merge_time(
[] if result.get("active_batch") is None else list(result["active_batch"].request_ids),
float(prefill_phase.get("merge_elapsed_s", 0.0)),
)
self.enqueue_engine_finished_items(list(prefill_phase.get("finished_items") or []))
decode_phase = dict(result.get("decode_phase") or {})
if decode_phase.get("error"):
self.fail_engine_jobs(list(decode_phase.get("error_request_ids") or []), str(decode_phase.get("error")))
else:
self.add_engine_decode_time(
list(decode_phase.get("request_ids") or []),
float(decode_phase.get("decode_elapsed_s", 0.0)),
)
self.enqueue_engine_finished_items(list(decode_phase.get("finished_items") or []))
self.decode_runtime_owner.set_active_batch(result.get("active_batch"))
if result.get("executed", False):
self.decode_runtime_owner.refresh_state("engine_decode_cycle")
return bool(result.get("executed", False))

View File

@ -0,0 +1,100 @@
from __future__ import annotations
import asyncio
import time
from typing import Dict
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SRequestState
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDispatchTask
class EngineDispatchStageMixin:
async def enqueue_prepared_state_for_dispatch(
self,
*,
state: T2SRequestState,
speed_factor: float,
sample_steps: int,
media_type: str,
super_sampling: bool,
prepare_wall_ms: float,
prepare_profile_total_ms: float,
done_loop: asyncio.AbstractEventLoop | None,
done_future: asyncio.Future | None,
engine_request_id: str | None,
timeout_sec: float | None,
) -> EngineDispatchTask:
if float(state.prepare_profile.get("ref_spec_async_failed", 0.0) or 0.0) > 0.0:
error = RuntimeError("ref_spec async stage failed before dispatch")
self.fail_request_state(engine_request_id or state.request_id, str(error))
raise error
task = EngineDispatchTask(
request_id=state.request_id,
state=state,
speed_factor=float(speed_factor),
sample_steps=int(sample_steps),
media_type=media_type,
super_sampling=bool(super_sampling),
prepare_wall_ms=float(prepare_wall_ms),
prepare_profile_total_ms=float(prepare_profile_total_ms),
done_loop=done_loop,
done_future=done_future,
engine_request_id=engine_request_id or state.request_id,
timeout_sec=timeout_sec,
enqueue_time=time.perf_counter(),
)
self.dispatch_queue_owner.enqueue(task)
self.notify_arbiter()
self.merge_request_state_profile(
task.engine_request_id or task.request_id,
{
"engine_dispatch_queue_depth_on_enqueue": int(
self.snapshot_engine_dispatch_state()["waiting_count"]
),
},
)
return task
def run_engine_dispatch_once(self, policy_snapshot: Dict[str, object], worker_state: Dict[str, object]) -> bool:
if not bool(policy_snapshot.get("allowed", True)):
return False
dispatch_task = self.dispatch_queue_owner.pop_left()
if dispatch_task is None:
return False
dispatched_at = time.perf_counter()
dispatch_wait_ms = max(0.0, (dispatched_at - dispatch_task.enqueue_time) * 1000.0)
dispatch_task.engine_policy_wait_ms = float(dispatch_wait_ms)
dispatch_task.engine_dispatch_wait_ms = float(dispatch_wait_ms)
dispatch_task.engine_policy_snapshot = dict(policy_snapshot)
try:
worker_job = self.scheduler_worker.submit(
state=dispatch_task.state,
speed_factor=dispatch_task.speed_factor,
sample_steps=dispatch_task.sample_steps,
media_type=dispatch_task.media_type,
super_sampling=dispatch_task.super_sampling,
prepare_wall_ms=dispatch_task.prepare_wall_ms,
prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms,
done_loop=dispatch_task.done_loop,
done_future=dispatch_task.done_future,
engine_request_id=dispatch_task.engine_request_id,
timeout_sec=dispatch_task.timeout_sec,
skip_capacity_wait=True,
admission_wait_ms_override=0.0,
admission_snapshot_override=dict(worker_state),
engine_policy_wait_ms=dispatch_task.engine_policy_wait_ms,
engine_dispatch_wait_ms=dispatch_task.engine_dispatch_wait_ms,
enqueue_pending=not self.scheduler_worker.is_engine_decode_control_enabled(),
)
dispatch_task.worker_job = worker_job
self.register_engine_job(worker_job)
if self.scheduler_worker.is_engine_decode_control_enabled():
self.decode_runtime_owner.enqueue_pending_job(worker_job)
self.notify_arbiter()
self.dispatch_queue_owner.mark_completed(1)
return True
except Exception as exc:
dispatch_task.error = str(exc)
self.fail_request_state(dispatch_task.engine_request_id or dispatch_task.request_id, str(exc))
self._notify_dispatch_error(dispatch_task, exc)
return True

View File

@ -0,0 +1,74 @@
from __future__ import annotations
from typing import Any, Callable, Dict, List, Optional
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import (
EngineDecodeRuntimeOwner,
EngineTaskQueueOwner,
SchedulerFinalizeTask,
SchedulerPendingJob,
)
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_decode import EngineDecodeStageMixin
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_dispatch import EngineDispatchStageMixin
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_finalize import EngineFinalizeStageMixin
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_futures import EngineStageFutureMixin
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_prepare import EnginePrepareStageMixin
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker
class EngineStageExecutor(
EngineStageFutureMixin,
EnginePrepareStageMixin,
EngineFinalizeStageMixin,
EngineDispatchStageMixin,
EngineDecodeStageMixin,
):
def __init__(
self,
*,
tts: TTS,
scheduler_worker: UnifiedSchedulerWorker,
prepare_queue_owner: EngineTaskQueueOwner,
prepare_text_queue_owner: EngineTaskQueueOwner,
prepare_ref_spec_queue_owner: EngineTaskQueueOwner,
finalize_queue_owner: EngineTaskQueueOwner,
dispatch_queue_owner: EngineTaskQueueOwner,
decode_runtime_owner: EngineDecodeRuntimeOwner,
update_request_state: Callable[[str, str, Optional[Dict[str, Any]]], None],
merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None],
fail_request_state: Callable[[str, str], None],
get_engine_job: Callable[[str], SchedulerPendingJob | None],
register_engine_job: Callable[[SchedulerPendingJob], None],
fail_engine_jobs: Callable[[List[str], str], None],
complete_engine_job: Callable[..., None],
add_engine_prefill_time: Callable[[List[SchedulerPendingJob], float], None],
add_engine_merge_time: Callable[[List[str], float], None],
add_engine_decode_time: Callable[[List[str], float], None],
enqueue_engine_finished_items: Callable[[List[T2SFinishedItem]], None],
snapshot_engine_dispatch_state: Callable[[], Dict[str, Any]],
snapshot_engine_decode_runtime_state: Callable[[], Dict[str, Any]],
) -> None:
self.tts = tts
self.scheduler_worker = scheduler_worker
self.prepare_queue_owner = prepare_queue_owner
self.prepare_text_queue_owner = prepare_text_queue_owner
self.prepare_ref_spec_queue_owner = prepare_ref_spec_queue_owner
self.finalize_queue_owner = finalize_queue_owner
self.dispatch_queue_owner = dispatch_queue_owner
self.decode_runtime_owner = decode_runtime_owner
self.update_request_state = update_request_state
self.merge_request_state_profile = merge_request_state_profile
self.fail_request_state = fail_request_state
self.get_engine_job = get_engine_job
self.register_engine_job = register_engine_job
self.fail_engine_jobs = fail_engine_jobs
self.complete_engine_job = complete_engine_job
self.add_engine_prefill_time = add_engine_prefill_time
self.add_engine_merge_time = add_engine_merge_time
self.add_engine_decode_time = add_engine_decode_time
self.enqueue_engine_finished_items = enqueue_engine_finished_items
self.snapshot_engine_dispatch_state = snapshot_engine_dispatch_state
self.snapshot_engine_decode_runtime_state = snapshot_engine_decode_runtime_state
self._notify_arbiter: Callable[[], None] | None = None

View File

@ -0,0 +1,103 @@
from __future__ import annotations
import time
from typing import List
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob
class EngineFinalizeStageMixin:
def enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None:
if not tasks:
return
for task in tasks:
job = self.get_engine_job(task.request_id)
if job is not None:
self.update_request_state(
job.engine_request_id,
EngineStatus.READY_FOR_FINALIZE,
{
"finish_reason": task.item.finish_reason,
"semantic_len": int(task.item.semantic_tokens.shape[0]),
"finish_idx": int(task.item.finish_idx),
},
)
self.finalize_queue_owner.enqueue_many(tasks)
self.notify_arbiter()
def take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]:
finalize_policy = self.scheduler_worker.get_finalize_batch_policy()
return self.finalize_queue_owner.take_finalize_batch(
finalize_mode=str(finalize_policy.get("finalize_mode", "async")),
batch_max_items=int(finalize_policy.get("finalize_batch_max_items", 1)),
batch_wait_s=float(finalize_policy.get("finalize_batch_wait_s", 0.0)),
use_vocoder=bool(self.tts.configs.use_vocoder),
)
def run_engine_finalize_once(self) -> bool:
tasks = self.take_engine_finalize_batch_nonblocking()
if not tasks:
return False
ready_tasks: List[SchedulerFinalizeTask] = []
failed_tasks: List[SchedulerFinalizeTask] = []
deferred_tasks: List[SchedulerFinalizeTask] = []
for task in tasks:
job = self.get_engine_job(task.request_id)
if job is None:
continue
if float(job.state.prepare_profile.get("ref_spec_async_failed", 0.0) or 0.0) > 0.0:
failed_tasks.append(task)
continue
if job.state.refer_spec is None:
deferred_tasks.append(task)
self.merge_request_state_profile(
job.engine_request_id or job.request_id,
{
"engine_finalize_ref_spec_blocked": 1.0,
},
)
continue
ready_tasks.append(task)
if deferred_tasks:
self.finalize_queue_owner.enqueue_many(deferred_tasks)
if failed_tasks:
self.fail_engine_jobs([task.request_id for task in failed_tasks], "ref_spec async stage failed")
if not ready_tasks:
self.finalize_queue_owner.mark_completed(len(failed_tasks), notify=True)
return False
self.scheduler_worker.begin_finalize_execution(len(ready_tasks))
try:
jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = []
for task in ready_tasks:
job = self.get_engine_job(task.request_id)
if job is None:
continue
jobs_and_items.append((job, task.item))
if not jobs_and_items:
return False
now = time.perf_counter()
for task in ready_tasks:
job = self.get_engine_job(task.request_id)
if job is not None:
job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0)
for job, item in jobs_and_items:
self.update_request_state(
job.engine_request_id,
EngineStatus.FINALIZING,
{
"finish_reason": item.finish_reason,
"semantic_len": int(item.semantic_tokens.shape[0]),
},
)
synth_ms, batch_results = self.scheduler_worker.synthesize_finalize_jobs(jobs_and_items)
for job, _ in jobs_and_items:
job.synth_ms += float(synth_ms)
for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results):
self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data)
except Exception as exc:
self.fail_engine_jobs([task.request_id for task in ready_tasks], str(exc))
finally:
self.scheduler_worker.end_finalize_execution(len(ready_tasks))
self.finalize_queue_owner.mark_completed(len(ready_tasks) + len(failed_tasks), notify=True)
return True

View File

@ -0,0 +1,59 @@
from __future__ import annotations
import asyncio
from typing import Callable
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SRequestState
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDispatchTask, EngineGpuPrepareTask
class EngineStageFutureMixin:
def bind_notify_arbiter(self, notify_arbiter: Callable[[], None]) -> None:
self._notify_arbiter = notify_arbiter
def notify_arbiter(self) -> None:
if self._notify_arbiter is not None:
self._notify_arbiter()
@staticmethod
def _resolve_dispatch_error_future(future: asyncio.Future, error: Exception) -> None:
if future.done():
return
future.set_exception(error)
@staticmethod
def _resolve_prepare_future(
future: asyncio.Future,
payload: tuple[T2SRequestState, float, float],
) -> None:
if future.done():
return
future.set_result(payload)
def _notify_dispatch_error(self, task: EngineDispatchTask, error: Exception) -> None:
if task.done_loop is None or task.done_future is None:
return
try:
task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error)
except RuntimeError:
pass
def _notify_prepare_error(self, task: EngineGpuPrepareTask, error: Exception) -> None:
if task.done_loop is None or task.done_future is None:
return
try:
task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error)
except RuntimeError:
pass
def _notify_prepare_result(
self,
task: EngineGpuPrepareTask,
payload: tuple[T2SRequestState, float, float],
) -> None:
if task.done_loop is None or task.done_future is None:
return
try:
task.done_loop.call_soon_threadsafe(self._resolve_prepare_future, task.done_future, payload)
except RuntimeError:
pass

View File

@ -0,0 +1,306 @@
from __future__ import annotations
import asyncio
import os
import time
from typing import Any
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SRequestState
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineGpuPrepareTask, EngineStatus
class EnginePrepareStageMixin:
def _prepare_waiting_total(self) -> int:
return (
int(self.prepare_queue_owner.waiting_count())
+ int(self.prepare_text_queue_owner.waiting_count())
+ int(self.prepare_ref_spec_queue_owner.waiting_count())
)
async def _wait_prepare_queue_admission(self) -> float:
soft_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_SOFT_MAX", "0")))
if soft_max <= 0:
return 0.0
poll_s = max(
0.0005,
float(max(1, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_ADMISSION_POLL_MS", "1")))) / 1000.0,
)
wait_start = time.perf_counter()
while self._prepare_waiting_total() >= soft_max:
await asyncio.sleep(poll_s)
return max(0.0, (time.perf_counter() - wait_start) * 1000.0)
async def prepare_state_via_engine_gpu_queue(
self,
*,
spec: Any,
prepare_submit_at: float,
engine_request_id: str | None,
) -> tuple[T2SRequestState, float, float]:
prepare_queue_admission_wait_ms = await self._wait_prepare_queue_admission()
cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at)
if engine_request_id not in [None, ""]:
self.update_request_state(
str(engine_request_id),
EngineStatus.GPU_PREPARING,
{
"engine_prepare_queue_admission_wait_ms": float(prepare_queue_admission_wait_ms),
"prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms),
"prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms),
"text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms),
"text_cpu_run_ms": float(cpu_stage.target_cpu_profiled.run_ms),
},
)
loop = asyncio.get_running_loop()
done_future = loop.create_future()
task = EngineGpuPrepareTask(
request_id=spec.request_id,
cpu_stage=cpu_stage,
done_loop=loop,
done_future=done_future,
engine_request_id=engine_request_id or spec.request_id,
enqueue_time=time.perf_counter(),
phase="audio",
audio_enqueue_time=time.perf_counter(),
admission_wait_ms=float(prepare_queue_admission_wait_ms),
)
self.prepare_queue_owner.enqueue(task)
self.notify_arbiter()
return await done_future
def _should_chain_prepare_text_after_audio(self) -> bool:
if str(os.environ.get("GPTSOVITS_ENGINE_PREPARE_CHAIN_TEXT", "1")).strip().lower() in {"0", "false", "no", "off"}:
return False
if self.finalize_queue_owner.has_items() or self.dispatch_queue_owner.has_items():
return False
decode_runtime_state = self.snapshot_engine_decode_runtime_state()
if bool(decode_runtime_state.get("has_work", False)):
return False
return True
def _maybe_apply_ref_spec_to_state(self, task: EngineGpuPrepareTask) -> None:
if task.state_result is None or task.ref_spec_result is None:
return
self.scheduler_worker.apply_ref_spec_result_to_state(task.state_result, task.ref_spec_result)
if task.engine_request_id not in [None, ""]:
self.merge_request_state_profile(
str(task.engine_request_id),
{
"engine_prepare_ref_spec_queue_wait_ms": float(task.ref_spec_queue_wait_ms),
"ref_spec_wait_ms": float(task.ref_spec_result[1].get("ref_spec_wait_ms", 0.0)),
"ref_spec_ms": float(task.ref_spec_result[1].get("ref_spec_ms", 0.0)),
"ref_spec_to_device_ms": float(task.ref_spec_result[1].get("ref_spec_to_device_ms", 0.0)),
"ref_spec_main_resample_ms": float(task.ref_spec_result[1].get("ref_spec_main_resample_ms", 0.0)),
"ref_spec_norm_ms": float(task.ref_spec_result[1].get("ref_spec_norm_ms", 0.0)),
"ref_spec_spectrogram_ms": float(task.ref_spec_result[1].get("ref_spec_spectrogram_ms", 0.0)),
"ref_spec_post_resample_ms": float(task.ref_spec_result[1].get("ref_spec_post_resample_ms", 0.0)),
},
)
def _mark_ref_spec_async_failed(
self,
task: EngineGpuPrepareTask,
error: Exception,
*,
queue_wait_ms: float,
) -> None:
task.error = str(error)
task.cancelled = True
if task.state_result is not None:
task.state_result.prepare_profile["ref_spec_async_failed"] = 1.0
task.state_result.prepare_profile["engine_prepare_ref_spec_queue_wait_ms"] = float(queue_wait_ms)
if task.engine_request_id not in [None, ""]:
self.merge_request_state_profile(
str(task.engine_request_id),
{
"ref_spec_async_failed": 1.0,
"engine_prepare_ref_spec_queue_wait_ms": float(queue_wait_ms),
},
)
self.fail_request_state(task.engine_request_id or task.request_id, str(error))
self.fail_engine_jobs([task.request_id], str(error))
self.notify_arbiter()
def _run_engine_prepare_audio_once(self, batch_max_items: int) -> bool:
tasks = self.prepare_queue_owner.pop_left_many(batch_max_items)
if not tasks:
return False
now = time.perf_counter()
queue_wait_ms_list = [max(0.0, (now - task.enqueue_time) * 1000.0) for task in tasks]
for task in tasks:
task.audio_start_time = float(now)
batch_results = asyncio.run(self.scheduler_worker.prepare_gpu_audio_phases_async([task.cpu_stage for task in tasks]))
completed_count = 0
for task, queue_wait_ms, result in zip(tasks, queue_wait_ms_list, batch_results):
task.audio_end_time = time.perf_counter()
if isinstance(result, Exception):
task.error = str(result)
self.fail_request_state(task.engine_request_id or task.request_id, str(result))
self._notify_prepare_error(task, result)
completed_count += 1
continue
task.audio_queue_wait_ms = float(queue_wait_ms)
task.phase_one = result
task.phase = "text"
task.enqueue_time = time.perf_counter()
task.text_enqueue_time = float(task.enqueue_time)
task.ref_spec_enqueue_time = float(task.enqueue_time)
self.prepare_text_queue_owner.enqueue(task)
self.prepare_ref_spec_queue_owner.enqueue(task)
if task.engine_request_id not in [None, ""]:
self.merge_request_state_profile(
str(task.engine_request_id),
{
"engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms),
"engine_prepare_audio_queue_wait_ms": float(queue_wait_ms),
"engine_prepare_audio_batch_size": float(len(tasks)),
"engine_prepare_audio_phase_wall_ms": float(result.get("phase_wall_ms", 0.0)),
"engine_prepare_audio_enqueue_ts": float(task.audio_enqueue_time),
"engine_prepare_audio_start_ts": float(task.audio_start_time),
"engine_prepare_audio_end_ts": float(task.audio_end_time),
"engine_prepare_text_enqueue_ts": float(task.text_enqueue_time),
"engine_prepare_ref_spec_enqueue_ts": float(task.ref_spec_enqueue_time),
},
)
completed_count += 1
self.prepare_queue_owner.mark_completed(completed_count)
if completed_count > 0 and self._should_chain_prepare_text_after_audio():
self._run_engine_prepare_text_once(min(batch_max_items, completed_count))
return True
if completed_count > 0:
self.notify_arbiter()
return True
def _run_engine_prepare_text_once(self, batch_max_items: int) -> bool:
tasks = self.prepare_text_queue_owner.pop_left_many(batch_max_items)
if not tasks:
return False
now = time.perf_counter()
queue_wait_ms_list = [max(0.0, (now - task.enqueue_time) * 1000.0) for task in tasks]
for task in tasks:
task.text_start_time = float(now)
items = [(task.cpu_stage, task.phase_one) for task in tasks if task.phase_one is not None]
batch_results = asyncio.run(self.scheduler_worker.prepare_gpu_text_phases_async(items))
completed_count = 0
for task, queue_wait_ms, result in zip(tasks, queue_wait_ms_list, batch_results):
task.text_end_time = time.perf_counter()
if isinstance(result, Exception):
task.error = str(result)
task.cancelled = True
self.fail_request_state(task.engine_request_id or task.request_id, str(result))
self._notify_prepare_error(task, result)
completed_count += 1
continue
task.text_queue_wait_ms = float(queue_wait_ms)
state, prepare_exec_started_at, prepare_exec_finished_at = self.scheduler_worker.build_gpu_prepare_result_from_phases(
task.cpu_stage,
task.phase_one or {},
result,
extra_profile={
"engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms),
"engine_prepare_audio_queue_wait_ms": float(task.audio_queue_wait_ms),
"engine_prepare_text_queue_wait_ms": float(task.text_queue_wait_ms),
"engine_gpu_prepare_queue_wait_ms": float(task.audio_queue_wait_ms + task.text_queue_wait_ms),
"engine_prepare_audio_batch_size": float(len(tasks)),
"engine_prepare_text_batch_size": float(len(tasks)),
"engine_prepare_audio_phase_mode": 2.0,
"engine_prepare_audio_phase_wall_ms": float((task.phase_one or {}).get("phase_wall_ms", 0.0)),
"engine_prepare_text_phase_wall_ms": float(result.get("phase_wall_ms", 0.0)),
"engine_prepare_text_phase_batch_size": float(len(tasks)),
"engine_prepare_audio_enqueue_ts": float(task.audio_enqueue_time),
"engine_prepare_audio_start_ts": float(task.audio_start_time),
"engine_prepare_audio_end_ts": float(task.audio_end_time),
"engine_prepare_text_enqueue_ts": float(task.text_enqueue_time),
"engine_prepare_text_start_ts": float(task.text_start_time),
"engine_prepare_text_end_ts": float(task.text_end_time),
"engine_prepare_ref_spec_enqueue_ts": float(task.ref_spec_enqueue_time),
},
)
task.state_result = state
self._maybe_apply_ref_spec_to_state(task)
state.prepare_profile["engine_gpu_prepare_batch_size"] = float(len(tasks))
if task.engine_request_id not in [None, ""]:
self.merge_request_state_profile(
str(task.engine_request_id),
{
"engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms),
"engine_prepare_audio_queue_wait_ms": float(task.audio_queue_wait_ms),
"engine_prepare_text_queue_wait_ms": float(task.text_queue_wait_ms),
"engine_gpu_prepare_queue_wait_ms": float(task.audio_queue_wait_ms + task.text_queue_wait_ms),
"engine_gpu_prepare_batch_size": float(len(tasks)),
},
)
self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at))
completed_count += 1
self.prepare_text_queue_owner.mark_completed(completed_count)
return True
def _run_engine_prepare_ref_spec_once(self, batch_max_items: int) -> bool:
tasks = self.prepare_ref_spec_queue_owner.pop_left_many(batch_max_items)
if not tasks:
return False
now = time.perf_counter()
runnable_tasks: list[EngineGpuPrepareTask] = []
queue_wait_ms_list: list[float] = []
completed_count = 0
for task in tasks:
if task.cancelled or task.phase_one is None:
completed_count += 1
continue
task.ref_spec_start_time = float(now)
runnable_tasks.append(task)
queue_wait_ms_list.append(max(0.0, (now - task.ref_spec_enqueue_time) * 1000.0))
if not runnable_tasks:
self.prepare_ref_spec_queue_owner.mark_completed(completed_count)
return True
batch_results = asyncio.run(
self.scheduler_worker.prepare_ref_spec_stages_async([task.phase_one or {} for task in runnable_tasks])
)
for task, queue_wait_ms, result in zip(runnable_tasks, queue_wait_ms_list, batch_results):
task.ref_spec_end_time = time.perf_counter()
task.ref_spec_queue_wait_ms = float(queue_wait_ms)
if isinstance(result, Exception):
self._mark_ref_spec_async_failed(task, result, queue_wait_ms=float(queue_wait_ms))
completed_count += 1
continue
task.ref_spec_result = result
self._maybe_apply_ref_spec_to_state(task)
if task.state_result is not None:
task.state_result.prepare_profile["engine_prepare_ref_spec_queue_wait_ms"] = float(queue_wait_ms)
task.state_result.prepare_profile["engine_prepare_ref_spec_enqueue_ts"] = float(task.ref_spec_enqueue_time)
task.state_result.prepare_profile["engine_prepare_ref_spec_start_ts"] = float(task.ref_spec_start_time)
task.state_result.prepare_profile["engine_prepare_ref_spec_end_ts"] = float(task.ref_spec_end_time)
completed_count += 1
self.prepare_ref_spec_queue_owner.mark_completed(completed_count)
return True
def run_engine_prepare_once(self) -> bool:
prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
batch_max_items = int(prepare_batch_policy.get("prepare_batch_max_items", 1))
audio_age_ms = self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time")
text_age_ms = self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time")
if self.prepare_text_queue_owner.has_items() and (
not self.prepare_queue_owner.has_items() or text_age_ms >= audio_age_ms
):
return self._run_engine_prepare_text_once(batch_max_items)
if self.prepare_queue_owner.has_items():
return self._run_engine_prepare_audio_once(batch_max_items)
if self.prepare_ref_spec_queue_owner.has_items():
return self._run_engine_prepare_ref_spec_once(batch_max_items)
if self.prepare_text_queue_owner.has_items():
return self._run_engine_prepare_text_once(batch_max_items)
if self.prepare_ref_spec_queue_owner.has_items():
return self._run_engine_prepare_ref_spec_once(batch_max_items)
return False
def run_engine_prepare_audio_once(self) -> bool:
prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
return self._run_engine_prepare_audio_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1)))
def run_engine_prepare_text_once(self) -> bool:
prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
return self._run_engine_prepare_text_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1)))
def run_engine_prepare_ref_spec_once(self) -> bool:
prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
return self._run_engine_prepare_ref_spec_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1)))

View File

@ -0,0 +1,71 @@
from __future__ import annotations
import os
import threading
from typing import Callable, List
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeStateCallbacks, SchedulerFinalizeTask, SchedulerJobRegistry
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_completion import WorkerCompletionBridge
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_decode import WorkerDecodeExecutor, WorkerDecodeLegacyShell, WorkerDecodeRuntimeTracker
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_execution import WorkerExecutionMixin
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_finalize import WorkerFinalizeExecutor
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_prepare import WorkerPrepareExecutor
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_runtime import WorkerRuntimeBookkeepingMixin
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_submit import WorkerSubmitLifecycleMixin
class UnifiedSchedulerWorker(
WorkerSubmitLifecycleMixin,
WorkerRuntimeBookkeepingMixin,
WorkerExecutionMixin,
):
def __init__(
self,
tts: TTS,
max_steps: int = 1500,
micro_batch_wait_ms: int = 5,
runtime_callbacks: RuntimeStateCallbacks | None = None,
external_finalize_submit: Callable[[List[SchedulerFinalizeTask]], None] | None = None,
):
self.tts = tts
self.max_steps = int(max_steps)
self.micro_batch_wait_s = float(micro_batch_wait_ms) / 1000.0
self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks()
self.condition = threading.Condition()
self.completion_bridge = WorkerCompletionBridge(self.runtime_callbacks)
self.decode_executor = WorkerDecodeExecutor(tts, max_steps=max_steps)
self.decode_legacy_shell = WorkerDecodeLegacyShell(self.condition, self.micro_batch_wait_s)
self.decode_runtime_tracker = WorkerDecodeRuntimeTracker(self.runtime_callbacks)
self.prepare_executor = WorkerPrepareExecutor(tts, on_state_change=self._notify_worker_state_change)
self.finalize_executor = WorkerFinalizeExecutor(
tts,
on_state_change=self._notify_worker_state_change,
external_submit=external_finalize_submit,
)
self.decode_backlog_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_DECODE_BACKLOG_MAX", "0")))
self.finalize_pending_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_FINALIZE_PENDING_MAX", "0")))
self.engine_decode_control_enabled = (
str(os.environ.get("GPTSOVITS_ENGINE_DRIVE_DECODE", "1")).strip().lower() in {"1", "true", "yes", "on"}
)
self.job_registry = SchedulerJobRegistry(self.condition)
self.worker_thread: threading.Thread | None = None
if not self.engine_decode_control_enabled:
self.worker_thread = threading.Thread(target=self._run_loop, name="unified-t2s-scheduler-worker", daemon=True)
self.worker_thread.start()
self.finalize_threads = []
if external_finalize_submit is None:
self.finalize_threads = [
threading.Thread(
target=self._run_finalize_loop,
name=f"unified-t2s-finalize-{worker_index}",
daemon=True,
)
for worker_index in range(self.finalize_executor.get_worker_count())
]
for finalize_thread in self.finalize_threads:
finalize_thread.start()
def _notify_worker_state_change(self) -> None:
with self.condition:
self.condition.notify_all()

View File

@ -0,0 +1,198 @@
from __future__ import annotations
import threading
import time
from typing import Any, Callable, Dict, List, Optional
import numpy as np
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeStateCallbacks, SchedulerJobRegistry, SchedulerPendingJob
class WorkerCompletionBridge:
def __init__(self, runtime_callbacks: RuntimeStateCallbacks | None = None) -> None:
self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks()
@staticmethod
def _resolve_done_future(job: SchedulerPendingJob) -> None:
future = job.done_future
if future is None or future.done():
return
future.set_result(job)
def notify_done_future(self, job: SchedulerPendingJob) -> None:
if job.done_loop is None or job.done_future is None:
return
try:
job.done_loop.call_soon_threadsafe(self._resolve_done_future, job)
except RuntimeError:
pass
def runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None:
if request_id is None or self.runtime_callbacks.complete is None:
return
self.runtime_callbacks.complete(request_id, extra)
def runtime_fail(self, request_id: str | None, error: str) -> None:
if request_id is None or self.runtime_callbacks.fail is None:
return
self.runtime_callbacks.fail(request_id, error)
@staticmethod
def build_completed_job_result(
job: SchedulerPendingJob,
item: T2SFinishedItem,
*,
sample_rate: int,
audio_data: np.ndarray,
finished_at: float | None = None,
) -> Dict[str, Any]:
finished_at = float(time.perf_counter() if finished_at is None else finished_at)
queue_wait_ms = 0.0
if job.first_schedule_time is not None:
queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0)
worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0)
worker_residual_ms = max(
0.0,
worker_total_ms
- queue_wait_ms
- job.prefill_ms
- job.merge_ms
- job.decode_ms
- job.finalize_wait_ms
- job.synth_ms,
)
worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms)
job.sample_rate = int(sample_rate)
job.audio_data = audio_data
job.result_ready_time = finished_at
prepare_profile = dict(job.state.prepare_profile)
result = {
"request_id": item.request_id,
"semantic_len": int(item.semantic_tokens.shape[0]),
"finish_idx": int(item.finish_idx),
"finish_reason": item.finish_reason,
"decode_admission_wait_ms": float(job.admission_wait_ms),
"engine_policy_wait_ms": float(job.engine_policy_wait_ms),
"engine_dispatch_wait_ms": float(job.engine_dispatch_wait_ms),
"prepare_ms": job.prepare_wall_ms,
"prepare_wall_ms": job.prepare_wall_ms,
"prepare_profile_total_ms": job.prepare_profile_total_ms,
"prepare_profile": prepare_profile,
"queue_wait_ms": queue_wait_ms,
"prefill_ms": job.prefill_ms,
"merge_ms": job.merge_ms,
"decode_ms": job.decode_ms,
"finalize_wait_ms": job.finalize_wait_ms,
"synth_ms": job.synth_ms,
"worker_residual_ms": worker_residual_ms,
"worker_other_ms": worker_other_ms,
"worker_total_ms": worker_total_ms,
"decode_steps": int(job.decode_steps),
"sample_rate": int(sample_rate),
"media_type": job.media_type,
}
job.result = result
return result
@staticmethod
def build_runtime_complete_payload(
job: SchedulerPendingJob,
item: T2SFinishedItem,
*,
sample_rate: int,
) -> Dict[str, Any]:
return {
"finish_reason": item.finish_reason,
"semantic_len": int(item.semantic_tokens.shape[0]),
"finish_idx": int(item.finish_idx),
"sample_rate": int(sample_rate),
"worker_profile": dict(job.result or {}),
}
def complete_job(
self,
job: SchedulerPendingJob,
*,
runtime_request_id: str | None,
runtime_extra: Optional[Dict[str, Any]] = None,
remove_job: Callable[[], None] | None = None,
on_job_finished: Callable[[], None] | None = None,
notify_waiters: Callable[[], None] | None = None,
) -> None:
job.done_event.set()
self.notify_done_future(job)
if remove_job is not None:
remove_job()
if on_job_finished is not None:
on_job_finished()
if notify_waiters is not None:
notify_waiters()
self.runtime_complete(runtime_request_id, runtime_extra)
def fail_job(
self,
job: SchedulerPendingJob,
*,
error: str,
remove_job: Callable[[], None] | None = None,
on_job_finished: Callable[[], None] | None = None,
notify_waiters: Callable[[], None] | None = None,
) -> None:
job.error = str(error)
job.done_event.set()
self.notify_done_future(job)
if remove_job is not None:
remove_job()
if on_job_finished is not None:
on_job_finished()
if notify_waiters is not None:
notify_waiters()
self.runtime_fail(job.engine_request_id, str(error))
def complete_finalize_task(
self,
*,
condition: threading.Condition,
job_registry: SchedulerJobRegistry,
job: SchedulerPendingJob,
item: T2SFinishedItem,
sample_rate: int,
audio_data: np.ndarray,
) -> None:
runtime_extra: Optional[Dict[str, Any]] = None
with condition:
if job_registry.get(item.request_id) is not job:
return
self.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data)
runtime_extra = self.build_runtime_complete_payload(job, item, sample_rate=sample_rate)
self.complete_job(
job,
runtime_request_id=job.engine_request_id,
runtime_extra=runtime_extra,
on_job_finished=lambda: job_registry.mark_finished_and_remove(item.request_id),
notify_waiters=condition.notify_all,
)
def fail_jobs(
self,
*,
condition: threading.Condition,
job_registry: SchedulerJobRegistry,
request_ids: List[str],
error: str,
) -> None:
if not request_ids:
return
with condition:
for request_id in request_ids:
job = job_registry.get(request_id)
if job is None:
continue
self.fail_job(
job,
error=error,
on_job_finished=lambda rid=request_id: job_registry.mark_finished_and_remove(rid),
)
condition.notify_all()

View File

@ -0,0 +1,430 @@
from __future__ import annotations
import threading
import time
from typing import Any, Callable, Dict, List, Optional
import torch
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import (
T2SActiveBatch,
T2SFinishedItem,
decode_one_step,
merge_active_batches,
run_prefill_active_batch,
)
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeStateCallbacks, SchedulerPendingJob
class WorkerDecodeExecutor:
def __init__(self, tts: TTS, max_steps: int) -> None:
self.tts = tts
self.max_steps = int(max_steps)
def _sync_device(self) -> None:
try:
device_str = str(self.tts.configs.device)
if device_str.startswith("cuda") and torch.cuda.is_available():
torch.cuda.synchronize(self.tts.configs.device)
elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"):
torch.mps.synchronize()
except Exception:
pass
def execute_prefill_merge(
self,
*,
pending_jobs: List[SchedulerPendingJob],
active_batch: Optional[T2SActiveBatch],
mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None],
add_prefill_time: Callable[[List[str], float], None] | None,
add_merge_time: Callable[[List[str], float], None] | None,
enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None,
finalize_error: Callable[[List[str], str], None] | None,
) -> Dict[str, Any]:
if not pending_jobs:
return {
"executed": False,
"active_batch": active_batch,
"pending_jobs": [],
"prefill_elapsed_s": 0.0,
"merge_elapsed_s": 0.0,
"finished_items": [],
"error": None,
"error_request_ids": [],
}
admitted_finished: List[T2SFinishedItem] = []
prefill_elapsed_s = 0.0
merge_elapsed_s = 0.0
error: str | None = None
error_request_ids: List[str] = []
try:
self._sync_device()
prefill_start = time.perf_counter()
mark_prefill_started(pending_jobs, prefill_start)
admitted_active_batch, admitted_finished = run_prefill_active_batch(
self.tts.t2s_model.model,
[job.state for job in pending_jobs],
max_steps=self.max_steps,
)
self._sync_device()
prefill_elapsed_s = time.perf_counter() - prefill_start
if add_prefill_time is not None:
add_prefill_time([job.request_id for job in pending_jobs], prefill_elapsed_s)
if enqueue_finished is not None:
enqueue_finished(admitted_finished)
merge_start = time.perf_counter()
active_batch = merge_active_batches(
self.tts.t2s_model.model,
active_batch,
admitted_active_batch,
)
merge_elapsed_s = time.perf_counter() - merge_start
if add_merge_time is not None:
add_merge_time(
[] if active_batch is None else list(active_batch.request_ids),
merge_elapsed_s,
)
except Exception as exc:
error = str(exc)
error_request_ids = [job.request_id for job in pending_jobs]
if finalize_error is not None:
finalize_error(error_request_ids, error)
return {
"executed": True,
"active_batch": active_batch,
"pending_jobs": list(pending_jobs),
"prefill_elapsed_s": float(prefill_elapsed_s),
"merge_elapsed_s": float(merge_elapsed_s),
"finished_items": list(admitted_finished),
"error": error,
"error_request_ids": error_request_ids,
}
def execute_decode_step(
self,
*,
active_batch: Optional[T2SActiveBatch],
add_decode_time: Callable[[List[str], float], None] | None,
enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None,
finalize_error: Callable[[List[str], str], None] | None,
) -> Dict[str, Any]:
if active_batch is None:
return {
"executed": False,
"active_batch": None,
"request_ids": [],
"decode_elapsed_s": 0.0,
"finished_items": [],
"error": None,
"error_request_ids": [],
}
active_request_ids: List[str] = []
step_finished: List[T2SFinishedItem] = []
decode_elapsed_s = 0.0
error: str | None = None
error_request_ids: List[str] = []
try:
active_request_ids = [state.request_id for state in active_batch.states]
self._sync_device()
decode_start = time.perf_counter()
active_batch, step_finished = decode_one_step(
self.tts.t2s_model.model,
active_batch,
max_steps=self.max_steps,
)
self._sync_device()
decode_elapsed_s = time.perf_counter() - decode_start
if add_decode_time is not None:
add_decode_time(active_request_ids, decode_elapsed_s)
if enqueue_finished is not None:
enqueue_finished(step_finished)
except Exception as exc:
error = str(exc)
error_request_ids = list(active_request_ids)
if finalize_error is not None:
finalize_error(error_request_ids, error)
active_batch = None
return {
"executed": True,
"active_batch": active_batch,
"request_ids": active_request_ids,
"decode_elapsed_s": float(decode_elapsed_s),
"finished_items": list(step_finished),
"error": error,
"error_request_ids": error_request_ids,
}
def execute_decode_cycle(
self,
*,
pending_jobs: List[SchedulerPendingJob],
active_batch: Optional[T2SActiveBatch],
mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None],
add_prefill_time: Callable[[List[str], float], None] | None,
add_merge_time: Callable[[List[str], float], None] | None,
add_decode_time: Callable[[List[str], float], None] | None,
enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None,
finalize_error: Callable[[List[str], str], None] | None,
) -> Dict[str, Any]:
result = {
"executed": False,
"prefill_merge_executed": False,
"decode_step_executed": False,
"active_batch": active_batch,
"prefill_phase": {},
"decode_phase": {},
}
prefill_phase = self.execute_prefill_merge(
pending_jobs=list(pending_jobs),
active_batch=result["active_batch"],
mark_prefill_started=mark_prefill_started,
add_prefill_time=add_prefill_time,
add_merge_time=add_merge_time,
enqueue_finished=enqueue_finished,
finalize_error=finalize_error,
)
prefill_executed = bool(prefill_phase.get("executed", False))
result["prefill_phase"] = prefill_phase
result["active_batch"] = prefill_phase.get("active_batch")
if prefill_executed:
result["executed"] = True
result["prefill_merge_executed"] = True
decode_phase = self.execute_decode_step(
active_batch=result["active_batch"],
add_decode_time=add_decode_time,
enqueue_finished=enqueue_finished,
finalize_error=finalize_error,
)
decode_executed = bool(decode_phase.get("executed", False))
result["decode_phase"] = decode_phase
result["active_batch"] = decode_phase.get("active_batch")
if decode_executed:
result["executed"] = True
result["decode_step_executed"] = True
return result
class WorkerDecodeLegacyShell:
def __init__(self, condition: threading.Condition, micro_batch_wait_s: float) -> None:
self.condition = condition
self.micro_batch_wait_s = float(micro_batch_wait_s)
self.pending_jobs: List[SchedulerPendingJob] = []
self.active_batch: T2SActiveBatch | None = None
@staticmethod
def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any] | None:
if active_batch is None:
return None
return {
"request_count": int(len(active_batch.request_ids)),
"request_ids": list(active_batch.request_ids),
"prefill_done": bool(active_batch.prefill_done),
"decode_step_index_max": (
int(active_batch.step_indices.max().item())
if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0
else 0
),
}
def current_backlog_locked(self) -> int:
running_requests = 0 if self.active_batch is None else len(self.active_batch.request_ids)
return int(len(self.pending_jobs) + running_requests)
def enqueue_pending_job_locked(self, job: SchedulerPendingJob) -> None:
self.pending_jobs.append(job)
def snapshot_locked(self) -> Dict[str, Any]:
active_batch_summary = self._summarize_active_batch(self.active_batch)
executor_local_pending_jobs = int(len(self.pending_jobs))
executor_local_running_requests = 0 if self.active_batch is None else int(len(self.active_batch.request_ids))
executor_local_has_work = bool(self.pending_jobs or self.active_batch is not None)
return {
"executor_local_pending_jobs": executor_local_pending_jobs,
"executor_local_running_requests": executor_local_running_requests,
"executor_local_has_work": executor_local_has_work,
"executor_local_active_batch": active_batch_summary,
}
def is_idle_locked(self) -> bool:
return self.active_batch is None and not self.pending_jobs
def take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
with self.condition:
if not self.pending_jobs and self.active_batch is None:
self.condition.wait(timeout=self.micro_batch_wait_s)
elif wait_for_batch and self.pending_jobs:
self.condition.wait(timeout=self.micro_batch_wait_s)
if not self.pending_jobs:
return []
pending = list(self.pending_jobs)
self.pending_jobs.clear()
return pending
def take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
with self.condition:
if not self.pending_jobs:
return []
if wait_for_batch:
oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time)
if (time.perf_counter() - oldest_enqueue_time) < self.micro_batch_wait_s:
return []
pending = list(self.pending_jobs)
self.pending_jobs.clear()
return pending
def has_decode_runtime_work(self) -> bool:
with self.condition:
return bool(self.pending_jobs or self.active_batch is not None)
def build_runtime_summary_locked(self, *, total_cycles: int, prefill_cycles: int, step_cycles: int, last_event: str) -> Dict[str, Any]:
active_request_ids = [] if self.active_batch is None else list(self.active_batch.request_ids)
decode_step_index_max = 0
prefill_done = False
if self.active_batch is not None:
prefill_done = bool(self.active_batch.prefill_done)
if self.active_batch.step_indices is not None and self.active_batch.step_indices.numel() > 0:
decode_step_index_max = int(self.active_batch.step_indices.max().item())
return {
"pending_jobs": int(len(self.pending_jobs)),
"active_request_count": int(len(active_request_ids)),
"active_request_ids": active_request_ids[:32],
"prefill_done": bool(prefill_done),
"decode_step_index_max": int(decode_step_index_max),
"total_cycles": int(total_cycles),
"prefill_cycles": int(prefill_cycles),
"step_cycles": int(step_cycles),
"has_work": bool(self.pending_jobs or self.active_batch is not None),
"last_event": str(last_event),
"updated_at": float(time.perf_counter()),
}
def run_prefill_merge_once_nonblocking(
self,
*,
external_pending_jobs: Optional[List[SchedulerPendingJob]],
external_active_batch: Optional[T2SActiveBatch],
execute_prefill_merge: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]],
) -> Dict[str, Any]:
pending_jobs = (
list(external_pending_jobs)
if external_pending_jobs is not None
else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None)
)
active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch
result = execute_prefill_merge(pending_jobs, active_batch)
if external_pending_jobs is None:
with self.condition:
self.active_batch = result.get("active_batch")
self.condition.notify_all()
return result
def run_decode_step_once_nonblocking(
self,
*,
external_active_batch: Optional[T2SActiveBatch],
execute_decode_step: Callable[[Optional[T2SActiveBatch]], Dict[str, Any]],
) -> Dict[str, Any]:
active_batch = self.active_batch if external_active_batch is None else external_active_batch
result = execute_decode_step(active_batch)
if external_active_batch is None:
with self.condition:
self.active_batch = result.get("active_batch")
self.condition.notify_all()
return result
def run_decode_cycle_nonblocking(
self,
*,
external_pending_jobs: Optional[List[SchedulerPendingJob]],
external_active_batch: Optional[T2SActiveBatch],
execute_decode_cycle: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]],
on_cycle_executed: Callable[[Dict[str, Any]], None] | None,
) -> Dict[str, Any]:
pending_jobs = (
list(external_pending_jobs)
if external_pending_jobs is not None
else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None)
)
active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch
result = execute_decode_cycle(pending_jobs, active_batch)
if external_pending_jobs is None:
with self.condition:
self.active_batch = result.get("active_batch")
self.condition.notify_all()
if result.get("executed") and on_cycle_executed is not None:
on_cycle_executed(result)
return result
def run_loop(
self,
*,
run_decode_cycle_nonblocking: Callable[[], Dict[str, Any]],
) -> None:
while True:
executed = run_decode_cycle_nonblocking()
if executed.get("executed"):
continue
wait_for_batch = self.active_batch is None
pending_jobs = self.take_pending_snapshot(wait_for_batch=wait_for_batch)
if pending_jobs:
with self.condition:
self.pending_jobs = pending_jobs + self.pending_jobs
self.condition.notify_all()
continue
time.sleep(self.micro_batch_wait_s)
class WorkerDecodeRuntimeTracker:
def __init__(
self,
runtime_callbacks: RuntimeStateCallbacks | None = None,
) -> None:
self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks()
self.total_cycles = 0
self.prefill_cycles = 0
self.step_cycles = 0
def get_counters(self) -> Dict[str, int]:
return {
"total_cycles": int(self.total_cycles),
"prefill_cycles": int(self.prefill_cycles),
"step_cycles": int(self.step_cycles),
}
def record_cycle(self, result: Dict[str, Any]) -> None:
if not bool(result.get("executed")):
return
self.total_cycles += 1
if bool(result.get("prefill_merge_executed")):
self.prefill_cycles += 1
if bool(result.get("decode_step_executed")):
self.step_cycles += 1
def build_runtime_summary_locked(
self,
*,
legacy_shell: WorkerDecodeLegacyShell,
last_event: str,
) -> Dict[str, Any]:
return legacy_shell.build_runtime_summary_locked(
total_cycles=int(self.total_cycles),
prefill_cycles=int(self.prefill_cycles),
step_cycles=int(self.step_cycles),
last_event=str(last_event),
)
def notify_runtime_update_locked(
self,
*,
legacy_shell: WorkerDecodeLegacyShell,
last_event: str,
) -> None:
if self.runtime_callbacks.decode_runtime_update is None:
return
snapshot = self.build_runtime_summary_locked(
legacy_shell=legacy_shell,
last_event=last_event,
)
self.runtime_callbacks.decode_runtime_update(snapshot)

View File

@ -0,0 +1,164 @@
from __future__ import annotations
import time
from typing import Any, Dict, List, Optional
import numpy as np
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SActiveBatch, T2SFinishedItem
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob
class WorkerExecutionMixin:
def execute_prefill_merge(
self,
pending_jobs: List[SchedulerPendingJob],
active_batch: Optional[T2SActiveBatch],
external_bookkeeping: bool = False,
) -> Dict[str, Any]:
return self.decode_executor.execute_prefill_merge(
pending_jobs=pending_jobs,
active_batch=active_batch,
mark_prefill_started=self._mark_prefill_started,
add_prefill_time=None if external_bookkeeping else self._add_prefill_time,
add_merge_time=None if external_bookkeeping else self._add_merge_time,
enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished,
finalize_error=None if external_bookkeeping else self._finalize_error,
)
def execute_decode_step(
self,
active_batch: Optional[T2SActiveBatch],
external_bookkeeping: bool = False,
) -> Dict[str, Any]:
return self.decode_executor.execute_decode_step(
active_batch=active_batch,
add_decode_time=None if external_bookkeeping else self._add_decode_time,
enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished,
finalize_error=None if external_bookkeeping else self._finalize_error,
)
def execute_decode_cycle(
self,
pending_jobs: List[SchedulerPendingJob],
active_batch: Optional[T2SActiveBatch],
external_bookkeeping: bool = False,
) -> Dict[str, Any]:
result = self.decode_executor.execute_decode_cycle(
pending_jobs=pending_jobs,
active_batch=active_batch,
mark_prefill_started=self._mark_prefill_started,
add_prefill_time=None if external_bookkeeping else self._add_prefill_time,
add_merge_time=None if external_bookkeeping else self._add_merge_time,
add_decode_time=None if external_bookkeeping else self._add_decode_time,
enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished,
finalize_error=None if external_bookkeeping else self._finalize_error,
)
self._record_decode_runtime_cycle(result)
return result
def run_prefill_merge_once_nonblocking(
self,
external_pending_jobs: Optional[List[SchedulerPendingJob]] = None,
external_active_batch: Optional[T2SActiveBatch] = None,
emit_runtime_state: bool = True,
external_bookkeeping: bool = False,
) -> Dict[str, Any]:
result = self.decode_legacy_shell.run_prefill_merge_once_nonblocking(
external_pending_jobs=external_pending_jobs,
external_active_batch=external_active_batch,
execute_prefill_merge=lambda batch_jobs, batch_state: self.execute_prefill_merge(
pending_jobs=batch_jobs,
active_batch=batch_state,
external_bookkeeping=external_bookkeeping,
),
)
if emit_runtime_state:
self._notify_decode_runtime_state("prefill_merge")
return result
def run_decode_step_once_nonblocking(
self,
external_active_batch: Optional[T2SActiveBatch] = None,
emit_runtime_state: bool = True,
external_bookkeeping: bool = False,
) -> Dict[str, Any]:
result = self.decode_legacy_shell.run_decode_step_once_nonblocking(
external_active_batch=external_active_batch,
execute_decode_step=lambda batch_state: self.execute_decode_step(
active_batch=batch_state,
external_bookkeeping=external_bookkeeping,
),
)
if emit_runtime_state:
self._notify_decode_runtime_state("decode_step")
return result
def run_decode_cycle_nonblocking(
self,
external_pending_jobs: Optional[List[SchedulerPendingJob]] = None,
external_active_batch: Optional[T2SActiveBatch] = None,
emit_runtime_state: bool = True,
external_bookkeeping: bool = False,
) -> Dict[str, Any]:
result = self.decode_legacy_shell.run_decode_cycle_nonblocking(
external_pending_jobs=external_pending_jobs,
external_active_batch=external_active_batch,
execute_decode_cycle=lambda batch_jobs, batch_state: self.execute_decode_cycle(
pending_jobs=batch_jobs,
active_batch=batch_state,
external_bookkeeping=external_bookkeeping,
),
on_cycle_executed=None,
)
if result.get("executed") and emit_runtime_state:
self._notify_decode_runtime_state("decode_cycle")
return result
def execute_finalize_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None:
if not tasks:
return
try:
jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = []
with self.condition:
for task in tasks:
job = self.job_registry.get(task.request_id)
if job is None:
continue
jobs_and_items.append((job, task.item))
if not jobs_and_items:
return
now = time.perf_counter()
for task in tasks:
self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0))
for job, item in jobs_and_items:
self._runtime_update(
job.engine_request_id,
EngineStatus.FINALIZING,
{
"finish_reason": item.finish_reason,
"semantic_len": int(item.semantic_tokens.shape[0]),
},
)
synth_ms, batch_results = self.synthesize_finalize_jobs(jobs_and_items)
with self.condition:
for job, _ in jobs_and_items:
tracked_job = self.job_registry.get(job.request_id)
if tracked_job is not None:
tracked_job.synth_ms += synth_ms
for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results):
self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data)
except Exception as exc:
self._finalize_error([task.request_id for task in tasks], str(exc))
finally:
self.finalize_executor.end_execution(len(tasks))
def _run_finalize_loop(self) -> None:
while True:
tasks = self.finalize_executor.take_task_batch_blocking()
self.execute_finalize_tasks(tasks)
def _run_loop(self) -> None:
self.decode_legacy_shell.run_loop(
run_decode_cycle_nonblocking=lambda: self.run_decode_cycle_nonblocking()
)

View File

@ -0,0 +1,251 @@
from __future__ import annotations
import os
import threading
import time
from collections import deque
from typing import Any, Callable, Deque, Dict, List
import numpy as np
import torch
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import SchedulerFinalizeTask, SchedulerPendingJob
class WorkerFinalizeExecutor:
def __init__(
self,
tts: TTS,
on_state_change: Callable[[], None] | None = None,
external_submit: Callable[[List[SchedulerFinalizeTask]], None] | None = None,
) -> None:
self.tts = tts
self.on_state_change = on_state_change
self.external_submit = external_submit
self.condition = threading.Condition()
self.pending_tasks: Deque[SchedulerFinalizeTask] = deque()
self.pending_peak = 0
self.inflight = 0
self.inflight_peak = 0
self.worker_count = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1)))
self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower()
self.batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16)))
self.batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0)
def _notify_state_change(self) -> None:
if self.on_state_change is None:
return
try:
self.on_state_change()
except Exception:
pass
def get_worker_count(self) -> int:
return int(self.worker_count)
def get_batch_policy(self) -> Dict[str, Any]:
return {
"finalize_mode": str(self.finalize_mode),
"finalize_batch_max_items": int(self.batch_max_items),
"finalize_batch_wait_s": float(self.batch_wait_s),
}
def get_pending_count(self) -> int:
with self.condition:
return int(len(self.pending_tasks))
def snapshot(self) -> Dict[str, Any]:
with self.condition:
return {
"finalize_pending": int(len(self.pending_tasks)),
"finalize_pending_peak": int(self.pending_peak),
"finalize_inflight": int(self.inflight),
"finalize_inflight_peak": int(self.inflight_peak),
"finalize_workers": int(self.worker_count),
"finalize_mode": str(self.finalize_mode),
"finalize_batch_max_items": int(self.batch_max_items),
"finalize_batch_wait_ms": float(self.batch_wait_s * 1000.0),
}
def is_idle(self) -> bool:
with self.condition:
return self.inflight <= 0 and not self.pending_tasks
def enqueue_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None:
if not tasks:
return
if self.external_submit is not None:
self.external_submit(tasks)
self._notify_state_change()
return
with self.condition:
for task in tasks:
self.pending_tasks.append(task)
self.pending_peak = max(self.pending_peak, len(self.pending_tasks))
self.condition.notify_all()
self._notify_state_change()
def begin_execution(self, task_count: int) -> None:
if task_count <= 0:
return
with self.condition:
self.inflight += int(task_count)
self.inflight_peak = max(self.inflight_peak, self.inflight)
self.condition.notify_all()
self._notify_state_change()
def end_execution(self, task_count: int) -> None:
with self.condition:
self.inflight = max(0, self.inflight - int(task_count))
self.condition.notify_all()
self._notify_state_change()
def take_task_batch_blocking(self) -> List[SchedulerFinalizeTask]:
with self.condition:
while not self.pending_tasks:
self.condition.wait()
selected_tasks = [self.pending_tasks.popleft()]
if self.finalize_mode == "sync" or self.tts.configs.use_vocoder:
self.inflight += len(selected_tasks)
self.inflight_peak = max(self.inflight_peak, self.inflight)
self._notify_state_change()
return selected_tasks
batch_deadline = time.perf_counter() + self.batch_wait_s
while len(selected_tasks) < self.batch_max_items:
if not self.pending_tasks:
remaining = batch_deadline - time.perf_counter()
if remaining <= 0:
break
self.condition.wait(timeout=remaining)
continue
first_task = selected_tasks[0]
matched_index = None
for index, task in enumerate(self.pending_tasks):
if abs(task.enqueued_time - first_task.enqueued_time) < 1.0:
matched_index = index
break
if matched_index is not None:
selected_tasks.append(self.pending_tasks[matched_index])
del self.pending_tasks[matched_index]
continue
remaining = batch_deadline - time.perf_counter()
if remaining <= 0:
break
self.condition.wait(timeout=remaining)
self.inflight += len(selected_tasks)
self.inflight_peak = max(self.inflight_peak, self.inflight)
self._notify_state_change()
return selected_tasks
def _sync_device(self) -> None:
try:
device_str = str(self.tts.configs.device)
if device_str.startswith("cuda") and torch.cuda.is_available():
torch.cuda.synchronize(self.tts.configs.device)
elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"):
torch.mps.synchronize()
except Exception:
pass
@staticmethod
def _collect_job_refer_specs(job: SchedulerPendingJob) -> List[tuple]:
refer_specs = []
if job.state.refer_spec is not None:
refer_specs.append(job.state.refer_spec)
refer_specs.extend(list(getattr(job.state, "aux_refer_specs", []) or []))
return refer_specs
def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]:
audio_fragment = self.tts.synthesize_audio_request_local(
semantic_tokens=item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0),
phones=job.state.phones.detach().clone().unsqueeze(0),
prompt_semantic=job.state.prompt_semantic.detach().clone(),
prompt_phones=job.state.prompt_phones.detach().clone(),
refer_spec=[
(
refer_spec_item[0].detach().clone(),
None if refer_spec_item[1] is None else refer_spec_item[1].detach().clone(),
)
for refer_spec_item in self._collect_job_refer_specs(job)
],
raw_audio=job.state.raw_audio.detach().clone(),
raw_sr=int(job.state.raw_sr),
speed=float(job.speed_factor),
sample_steps=int(job.sample_steps),
)
output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"]
return self.tts.audio_postprocess(
audio=[[audio_fragment]],
sr=int(output_sr),
batch_index_list=None,
speed_factor=float(job.speed_factor),
split_bucket=False,
fragment_interval=0.0,
super_sampling=bool(job.super_sampling),
)
def _synthesize_finished_audio_batch(
self,
jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]],
) -> List[tuple[int, np.ndarray]]:
semantic_tokens_list = [item.semantic_tokens.detach().clone() for _, item in jobs_and_items]
phones_list = [job.state.phones.detach().clone() for job, _ in jobs_and_items]
refer_specs = []
speeds = []
sample_steps_list = []
for job, _ in jobs_and_items:
refer_spec_group = self._collect_job_refer_specs(job)
if len(refer_spec_group) != 1:
raise ValueError("batched finalize 暂不支持单请求多参考音频")
refer_specs.append(
[(
refer_spec_group[0][0].detach().clone(),
None if refer_spec_group[0][1] is None else refer_spec_group[0][1].detach().clone(),
)]
)
speeds.append(float(job.speed_factor))
sample_steps_list.append(int(job.sample_steps))
audio_fragments = self.tts.synthesize_audio_requests_local_batched(
semantic_tokens_list=semantic_tokens_list,
phones_list=phones_list,
refer_specs=refer_specs,
speeds=speeds,
sample_steps_list=sample_steps_list,
)
output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"]
results: List[tuple[int, np.ndarray]] = []
for (job, _), audio_fragment in zip(jobs_and_items, audio_fragments):
results.append(
self.tts.audio_postprocess(
audio=[[audio_fragment]],
sr=int(output_sr),
batch_index_list=None,
speed_factor=float(job.speed_factor),
split_bucket=False,
fragment_interval=0.0,
super_sampling=bool(job.super_sampling),
)
)
return results
def synthesize_finalize_jobs(
self,
jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]],
) -> tuple[float, List[tuple[int, np.ndarray]]]:
if not jobs_and_items:
return 0.0, []
self._sync_device()
synth_start = time.perf_counter()
if (
len(jobs_and_items) == 1
or self.tts.configs.use_vocoder
or any(len(self._collect_job_refer_specs(job)) != 1 for job, _ in jobs_and_items)
):
batch_results = [self._synthesize_finished_audio(job, item) for job, item in jobs_and_items]
else:
batch_results = self._synthesize_finished_audio_batch(jobs_and_items)
self._sync_device()
synth_ms = (time.perf_counter() - synth_start) * 1000.0
return float(synth_ms), batch_results

View File

@ -0,0 +1,140 @@
from __future__ import annotations
import asyncio
import os
import time
from typing import Any, Callable, Dict, List
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator, PreparedCpuStage
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SRequestState
class WorkerPrepareExecutor:
def __init__(
self,
tts: TTS,
on_state_change: Callable[[], None] | None = None,
) -> None:
self.coordinator = PrepareCoordinator(tts)
self.on_state_change = on_state_change
def _notify_state_change(self) -> None:
if self.on_state_change is None:
return
try:
self.on_state_change()
except Exception:
pass
def snapshot(self) -> Dict[str, int]:
return dict(self.coordinator.snapshot())
def get_max_inflight(self) -> int:
return int(self.coordinator.snapshot().get("max_inflight", 0))
def get_batch_policy(self) -> Dict[str, int]:
return {
"prepare_batch_max_items": max(1, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_BATCH_MAX_ITEMS", 8))),
}
def is_idle(self) -> bool:
return int(self.coordinator.snapshot().get("inflight", 0)) <= 0
async def prepare_state_profiled_async(
self,
spec: SchedulerRequestSpec,
prepare_submit_at: float,
) -> tuple[T2SRequestState, float, float]:
try:
return await self.coordinator.prepare_state_profiled_async(spec, prepare_submit_at)
finally:
self._notify_state_change()
async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]:
results = await asyncio.gather(
*[self.prepare_state_profiled_async(spec, time.perf_counter()) for spec in specs]
)
return [state for state, _, _ in results]
async def prepare_cpu_stage_profiled_async(
self,
spec: SchedulerRequestSpec,
prepare_submit_at: float,
) -> PreparedCpuStage:
try:
return await self.coordinator.prepare_cpu_stage_profiled_async(spec, prepare_submit_at)
finally:
self._notify_state_change()
async def prepare_gpu_stage_profiled_async(
self,
cpu_stage: PreparedCpuStage,
) -> tuple[T2SRequestState, float, float]:
try:
return await self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage)
finally:
self._notify_state_change()
async def prepare_gpu_stages_profiled_async(
self,
cpu_stages: List[PreparedCpuStage],
) -> List[tuple[T2SRequestState, float, float] | Exception]:
try:
return await self.coordinator.prepare_gpu_stages_profiled_async(cpu_stages)
finally:
self._notify_state_change()
async def prepare_gpu_audio_phases_async(
self,
cpu_stages: List[PreparedCpuStage],
) -> List[Dict[str, Any] | Exception]:
try:
return await self.coordinator.prepare_gpu_audio_phases_async(cpu_stages)
finally:
self._notify_state_change()
async def prepare_gpu_text_phases_async(
self,
items: List[tuple[PreparedCpuStage, Dict[str, Any]]],
) -> List[Dict[str, Any] | Exception]:
try:
return await self.coordinator.prepare_gpu_text_phases_async(items)
finally:
self._notify_state_change()
def build_gpu_prepare_result_from_phases(
self,
cpu_stage: PreparedCpuStage,
phase_one: Dict[str, Any],
phase_two: Dict[str, Any],
extra_profile: Dict[str, float] | None = None,
) -> tuple[T2SRequestState, float, float]:
try:
return self.coordinator.build_gpu_prepare_result_from_phases(
cpu_stage,
phase_one,
phase_two,
extra_profile=extra_profile,
)
finally:
self._notify_state_change()
async def prepare_ref_spec_stages_async(
self,
phase_ones: List[Dict[str, Any]],
) -> List[tuple[tuple[Any, Any], Dict[str, float]] | Exception]:
try:
return await self.coordinator.prepare_ref_spec_stages_async(phase_ones)
finally:
self._notify_state_change()
def apply_ref_spec_result_to_state(
self,
state: T2SRequestState,
ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]],
) -> None:
try:
self.coordinator.apply_ref_spec_result_to_state(state, ref_spec_result)
finally:
self._notify_state_change()

View File

@ -0,0 +1,170 @@
from __future__ import annotations
import threading
import time
from typing import Any, Dict, List, Optional
import numpy as np
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob
class WorkerRuntimeBookkeepingMixin:
def _mark_prefill_started(self, pending_jobs: List[SchedulerPendingJob], started_at: float) -> None:
with self.condition:
for job in pending_jobs:
job.first_schedule_time = float(started_at)
self._runtime_update(
job.engine_request_id,
EngineStatus.GPU_PREPARING,
{"scheduler_request_id": job.request_id, "prefill_started_at": float(started_at)},
)
def _add_prefill_time(self, request_ids: List[str], elapsed_s: float) -> None:
delta_ms = float(elapsed_s) * 1000.0
if not request_ids:
return
with self.condition:
for request_id in request_ids:
job = self.job_registry.get(request_id)
if job is not None:
job.prefill_ms += delta_ms
def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None:
delta_ms = float(elapsed_s) * 1000.0
if not request_ids:
return
with self.condition:
for request_id in request_ids:
job = self.job_registry.get(request_id)
if job is not None:
job.merge_ms += delta_ms
def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None:
delta_ms = float(elapsed_s) * 1000.0
if not request_ids:
return
activate_request_ids: List[str] = []
with self.condition:
for request_id in request_ids:
job = self.job_registry.get(request_id)
if job is not None:
if job.decode_steps == 0:
activate_request_ids.append(job.engine_request_id)
job.decode_ms += delta_ms
job.decode_steps += 1
for engine_request_id in activate_request_ids:
self._runtime_update(engine_request_id, EngineStatus.ACTIVE_DECODE, None)
def _add_finalize_wait_ms(self, request_ids: List[str], delta_ms: float) -> None:
if not request_ids:
return
with self.condition:
for request_id in request_ids:
job = self.job_registry.get(request_id)
if job is not None:
job.finalize_wait_ms += float(delta_ms)
def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None:
if not items:
return
enqueued_at = time.perf_counter()
tasks: List[SchedulerFinalizeTask] = []
with self.condition:
for item in items:
job = self.job_registry.get(item.request_id)
if job is not None:
self._runtime_update(
job.engine_request_id,
EngineStatus.READY_FOR_FINALIZE,
{
"finish_reason": item.finish_reason,
"semantic_len": int(item.semantic_tokens.shape[0]),
"finish_idx": int(item.finish_idx),
},
)
tasks.append(SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at))
self.finalize_executor.enqueue_tasks(tasks)
def begin_finalize_execution(self, task_count: int) -> None:
self.finalize_executor.begin_execution(task_count)
def end_finalize_execution(self, task_count: int) -> None:
self.finalize_executor.end_execution(task_count)
def record_external_job_done(self, request_id: str) -> None:
with self.condition:
self.job_registry.mark_finished_and_remove(request_id)
self.condition.notify_all()
def synthesize_finalize_jobs(
self,
jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]],
) -> tuple[float, List[tuple[int, np.ndarray]]]:
return self.finalize_executor.synthesize_finalize_jobs(jobs_and_items)
def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None:
self.completion_bridge.complete_finalize_task(
condition=self.condition,
job_registry=self.job_registry,
job=job,
item=item,
sample_rate=sample_rate,
audio_data=audio_data,
)
def _finalize_error(self, request_ids: List[str], error: str) -> None:
self.completion_bridge.fail_jobs(
condition=self.condition,
job_registry=self.job_registry,
request_ids=request_ids,
error=error,
)
@staticmethod
def _resolve_done_future(job: SchedulerPendingJob) -> None:
future = job.done_future
if future is None or future.done():
return
future.set_result(job)
def _notify_done_future(self, job: SchedulerPendingJob) -> None:
self.completion_bridge.notify_done_future(job)
def _runtime_update(self, request_id: str | None, status: str, extra: Optional[Dict[str, Any]] = None) -> None:
if request_id is None or self.runtime_callbacks.update is None:
return
self.runtime_callbacks.update(request_id, status, extra)
def _runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None:
self.completion_bridge.runtime_complete(request_id, extra)
def _runtime_fail(self, request_id: str | None, error: str) -> None:
self.completion_bridge.runtime_fail(request_id, error)
def _build_decode_runtime_summary_locked(self, last_event: str) -> Dict[str, Any]:
return self.decode_runtime_tracker.build_runtime_summary_locked(
legacy_shell=self.decode_legacy_shell,
last_event=str(last_event),
)
def _notify_decode_runtime_state(self, last_event: str) -> None:
with self.condition:
self.decode_runtime_tracker.notify_runtime_update_locked(
legacy_shell=self.decode_legacy_shell,
last_event=str(last_event),
)
def _record_decode_runtime_cycle(self, result: Dict[str, Any]) -> None:
with self.condition:
self.decode_runtime_tracker.record_cycle(result)
def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
return self.decode_legacy_shell.take_pending_snapshot(wait_for_batch)
def _take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
return self.decode_legacy_shell.take_pending_snapshot_nonblocking(wait_for_batch)
def has_decode_runtime_work(self) -> bool:
return self.decode_legacy_shell.has_decode_runtime_work()

View File

@ -0,0 +1,308 @@
from __future__ import annotations
import asyncio
import threading
import time
from typing import Any, Dict, List
from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PreparedCpuStage
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SRequestState
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerPendingJob
class WorkerSubmitLifecycleMixin:
def _current_decode_backlog_locked(self) -> int:
return self.decode_legacy_shell.current_backlog_locked()
def get_micro_batch_wait_s(self) -> float:
return float(self.micro_batch_wait_s)
def is_engine_decode_control_enabled(self) -> bool:
return bool(self.engine_decode_control_enabled)
def get_prepare_max_inflight(self) -> int:
return int(self.prepare_executor.get_max_inflight())
def get_capacity_limits(self) -> Dict[str, int]:
return {
"decode_backlog_max": int(self.decode_backlog_max),
"finalize_pending_max": int(self.finalize_pending_max),
}
def get_finalize_batch_policy(self) -> Dict[str, Any]:
return dict(self.finalize_executor.get_batch_policy())
def get_prepare_batch_policy(self) -> Dict[str, int]:
return dict(self.prepare_executor.get_batch_policy())
def get_decode_runtime_counters(self) -> Dict[str, int]:
with self.condition:
return self.decode_runtime_tracker.get_counters()
def _can_accept_submit_locked(self) -> tuple[bool, Dict[str, int]]:
decode_backlog = self._current_decode_backlog_locked()
finalize_pending = int(self.finalize_executor.get_pending_count())
prepare_inflight = int(self.prepare_executor.snapshot()["inflight"])
blocked_decode = self.decode_backlog_max > 0 and decode_backlog >= self.decode_backlog_max
blocked_finalize = self.finalize_pending_max > 0 and finalize_pending >= self.finalize_pending_max
return (
not blocked_decode and not blocked_finalize,
{
"decode_backlog": decode_backlog,
"finalize_pending": finalize_pending,
"prepare_inflight": prepare_inflight,
"decode_backlog_max": int(self.decode_backlog_max),
"finalize_pending_max": int(self.finalize_pending_max),
},
)
def wait_for_submit_capacity_blocking(self, timeout_sec: float | None = None) -> tuple[float, Dict[str, int]]:
start = time.perf_counter()
deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec)))
while True:
with self.condition:
allowed, snapshot = self._can_accept_submit_locked()
if allowed:
return max(0.0, (time.perf_counter() - start) * 1000.0), snapshot
if deadline is not None and time.perf_counter() >= deadline:
raise TimeoutError(
"scheduler submit admission timeout "
f"(decode_backlog={snapshot['decode_backlog']}, finalize_pending={snapshot['finalize_pending']})"
)
self.condition.wait(timeout=self.micro_batch_wait_s)
def _admission_snapshot_locked(self) -> Dict[str, int]:
_, snapshot = self._can_accept_submit_locked()
return snapshot
async def submit_async(
self,
state: T2SRequestState,
speed_factor: float,
sample_steps: int,
media_type: str,
super_sampling: bool,
prepare_wall_ms: float,
prepare_profile_total_ms: float,
done_loop: asyncio.AbstractEventLoop | None = None,
done_future: asyncio.Future | None = None,
engine_request_id: str | None = None,
timeout_sec: float | None = None,
skip_capacity_wait: bool = False,
admission_wait_ms_override: float | None = None,
admission_snapshot_override: Dict[str, Any] | None = None,
engine_policy_wait_ms: float = 0.0,
engine_dispatch_wait_ms: float = 0.0,
enqueue_pending: bool = True,
) -> SchedulerPendingJob:
return await asyncio.to_thread(
self.submit,
state,
speed_factor,
sample_steps,
media_type,
super_sampling,
prepare_wall_ms,
prepare_profile_total_ms,
done_loop,
done_future,
engine_request_id,
timeout_sec,
skip_capacity_wait,
admission_wait_ms_override,
admission_snapshot_override,
engine_policy_wait_ms,
engine_dispatch_wait_ms,
enqueue_pending,
)
def snapshot(self) -> dict:
with self.condition:
prepare_state = self.prepare_executor.snapshot()
finalize_state = self.finalize_executor.snapshot()
shell_state = self.decode_legacy_shell.snapshot_locked()
decode_runtime_counters = self.decode_runtime_tracker.get_counters()
engine_owned_decode_state = bool(self.engine_decode_control_enabled)
active_batch_summary = shell_state.get("executor_local_active_batch")
executor_local_pending_jobs = int(shell_state.get("executor_local_pending_jobs", 0))
executor_local_running_requests = int(shell_state.get("executor_local_running_requests", 0))
executor_local_has_work = bool(shell_state.get("executor_local_has_work", False))
return {
"pending_jobs": 0 if engine_owned_decode_state else executor_local_pending_jobs,
"running_requests": 0 if engine_owned_decode_state else executor_local_running_requests,
"engine_decode_control_enabled": bool(self.engine_decode_control_enabled),
"legacy_state_owner_mode": not engine_owned_decode_state,
"decode_state_owner": "engine" if engine_owned_decode_state else "worker",
"decode_runtime_has_work": False if engine_owned_decode_state else executor_local_has_work,
"executor_local_pending_jobs": executor_local_pending_jobs,
"executor_local_running_requests": executor_local_running_requests,
"executor_local_has_work": executor_local_has_work,
"decode_runtime_total_cycles": int(decode_runtime_counters.get("total_cycles", 0)),
"decode_runtime_prefill_cycles": int(decode_runtime_counters.get("prefill_cycles", 0)),
"decode_runtime_step_cycles": int(decode_runtime_counters.get("step_cycles", 0)),
"prepare_inflight": prepare_state["inflight"],
"prepare_peak_inflight": prepare_state["peak_inflight"],
"prepare_max_inflight": prepare_state.get("max_inflight", 0),
"prepare_state": dict(prepare_state),
**finalize_state,
"decode_backlog_max": self.decode_backlog_max,
"finalize_pending_max": self.finalize_pending_max,
"active_batch": {} if engine_owned_decode_state else active_batch_summary,
"executor_local_active_batch": active_batch_summary if engine_owned_decode_state else None,
"total_submitted": self.job_registry.submitted_count(),
"total_finished": self.job_registry.finished_count(),
"drained": self.is_drained(),
}
def is_drained(self) -> bool:
with self.condition:
return (
self.decode_legacy_shell.is_idle_locked()
and self.job_registry.is_empty()
and self.prepare_executor.is_idle()
and self.finalize_executor.is_idle()
)
def wait_until_idle(self, timeout_sec: float = 60.0, poll_interval_sec: float = 0.01) -> bool:
deadline = time.perf_counter() + max(0.0, timeout_sec)
while time.perf_counter() < deadline:
if self.is_drained():
return True
time.sleep(poll_interval_sec)
return self.is_drained()
def submit(
self,
state: T2SRequestState,
speed_factor: float,
sample_steps: int,
media_type: str,
super_sampling: bool,
prepare_wall_ms: float,
prepare_profile_total_ms: float,
done_loop: asyncio.AbstractEventLoop | None = None,
done_future: asyncio.Future | None = None,
engine_request_id: str | None = None,
timeout_sec: float | None = None,
skip_capacity_wait: bool = False,
admission_wait_ms_override: float | None = None,
admission_snapshot_override: Dict[str, Any] | None = None,
engine_policy_wait_ms: float = 0.0,
engine_dispatch_wait_ms: float = 0.0,
enqueue_pending: bool = True,
) -> SchedulerPendingJob:
if skip_capacity_wait:
with self.condition:
admission_snapshot = (
dict(admission_snapshot_override)
if admission_snapshot_override is not None
else dict(self._admission_snapshot_locked())
)
admission_wait_ms = 0.0 if admission_wait_ms_override is None else float(admission_wait_ms_override)
else:
admission_wait_ms, admission_snapshot = self.wait_for_submit_capacity_blocking(timeout_sec=timeout_sec)
job = SchedulerPendingJob(
request_id=state.request_id,
state=state,
done_event=threading.Event(),
done_loop=done_loop,
done_future=done_future,
enqueue_time=time.perf_counter(),
speed_factor=float(speed_factor),
sample_steps=int(sample_steps),
media_type=media_type,
super_sampling=bool(super_sampling),
admission_wait_ms=float(admission_wait_ms),
engine_policy_wait_ms=float(engine_policy_wait_ms),
engine_dispatch_wait_ms=float(engine_dispatch_wait_ms),
prepare_wall_ms=float(prepare_wall_ms),
prepare_profile_total_ms=float(prepare_profile_total_ms),
engine_request_id=engine_request_id or state.request_id,
)
with self.condition:
self.job_registry.register(job, keep_job=not self.engine_decode_control_enabled)
if enqueue_pending:
self.decode_legacy_shell.enqueue_pending_job_locked(job)
self.condition.notify_all()
if enqueue_pending:
self._notify_decode_runtime_state("submit")
self._runtime_update(
job.engine_request_id,
EngineStatus.QUEUED,
{
"scheduler_request_id": job.request_id,
"decode_admission_wait_ms": float(admission_wait_ms),
"engine_policy_wait_ms": float(engine_policy_wait_ms),
"engine_dispatch_wait_ms": float(engine_dispatch_wait_ms),
"admission_snapshot": dict(admission_snapshot),
},
)
return job
async def prepare_state_profiled_async(
self,
spec: SchedulerRequestSpec,
prepare_submit_at: float,
) -> tuple[T2SRequestState, float, float]:
return await self.prepare_executor.prepare_state_profiled_async(spec, prepare_submit_at)
async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]:
return await self.prepare_executor.prepare_states_batch_async(specs)
async def prepare_cpu_stage_profiled_async(
self,
spec: SchedulerRequestSpec,
prepare_submit_at: float,
) -> PreparedCpuStage:
return await self.prepare_executor.prepare_cpu_stage_profiled_async(spec, prepare_submit_at)
async def prepare_gpu_stage_profiled_async(
self,
cpu_stage: PreparedCpuStage,
) -> tuple[T2SRequestState, float, float]:
return await self.prepare_executor.prepare_gpu_stage_profiled_async(cpu_stage)
async def prepare_gpu_stages_profiled_async(
self,
cpu_stages: List[PreparedCpuStage],
) -> List[tuple[T2SRequestState, float, float] | Exception]:
return await self.prepare_executor.prepare_gpu_stages_profiled_async(cpu_stages)
async def prepare_gpu_audio_phases_async(
self,
cpu_stages: List[PreparedCpuStage],
) -> List[Dict[str, Any] | Exception]:
return await self.prepare_executor.prepare_gpu_audio_phases_async(cpu_stages)
async def prepare_gpu_text_phases_async(
self,
items: List[tuple[PreparedCpuStage, Dict[str, Any]]],
) -> List[Dict[str, Any] | Exception]:
return await self.prepare_executor.prepare_gpu_text_phases_async(items)
def build_gpu_prepare_result_from_phases(
self,
cpu_stage: PreparedCpuStage,
phase_one: Dict[str, Any],
phase_two: Dict[str, Any],
extra_profile: Dict[str, float] | None = None,
) -> tuple[T2SRequestState, float, float]:
return self.prepare_executor.build_gpu_prepare_result_from_phases(
cpu_stage,
phase_one,
phase_two,
extra_profile=extra_profile,
)
async def prepare_ref_spec_stages_async(
self,
phase_ones: List[Dict[str, Any]],
) -> List[tuple[tuple[Any, Any], Dict[str, float]] | Exception]:
return await self.prepare_executor.prepare_ref_spec_stages_async(phase_ones)
def apply_ref_spec_result_to_state(
self,
state: T2SRequestState,
ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]],
) -> None:
self.prepare_executor.apply_ref_spec_result_to_state(state, ref_spec_result)

View File

@ -2,6 +2,7 @@ import warnings
warnings.filterwarnings("ignore")
import math
from typing import List
import torch
from torch import nn
@ -1038,6 +1039,67 @@ class SynthesizerTrn(nn.Module):
o = self.dec((z * y_mask)[:, :, :], g=ge)
return o
@torch.no_grad()
def decode_batched_request_local(
self,
codes: torch.Tensor,
code_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
refer_list: List[torch.Tensor],
noise_scale: float = 0.5,
speed: float = 1,
sv_emb: torch.Tensor | None = None,
):
batch_size = int(codes.size(1))
if batch_size <= 0:
raise ValueError("decode_batched_request_local 收到空 batch")
if len(refer_list) != batch_size:
raise ValueError("refer_list 数量与 batch size 不一致")
refer_lengths = torch.LongTensor([int(item.size(2)) for item in refer_list]).to(codes.device)
max_refer_len = int(refer_lengths.max().item())
refer_batch = torch.zeros(
(batch_size, int(refer_list[0].size(1)), max_refer_len),
dtype=refer_list[0].dtype,
device=codes.device,
)
for batch_index, refer in enumerate(refer_list):
refer_batch[batch_index, :, : int(refer.size(2))] = refer.squeeze(0)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, max_refer_len), 1).to(refer_batch.dtype)
if self.version == "v1":
ge = self.ref_enc(refer_batch * refer_mask, refer_mask)
else:
ge = self.ref_enc(refer_batch[:, :704] * refer_mask, refer_mask)
if self.is_v2pro:
if sv_emb is None:
raise ValueError("v2Pro batched request-local synthesis 缺少 sv_emb")
ge = ge + self.sv_emb(sv_emb).unsqueeze(-1)
ge = self.prelu(ge)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")
y_lengths = code_lengths.to(device=codes.device, dtype=torch.long) * 2
text_lengths = text_lengths.to(device=text.device, dtype=torch.long)
x, m_p, logs_p, y_mask, _, _ = self.enc_p(
quantized,
y_lengths,
text,
text_lengths,
self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
speed,
)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)
audio = self.dec((z * y_mask)[:, :, :], g=ge)
upsample_factor = 1
for up_layer in self.dec.ups:
stride = up_layer.stride[0] if isinstance(up_layer.stride, tuple) else int(up_layer.stride)
upsample_factor *= int(stride)
audio_lengths = y_mask.squeeze(1).sum(dim=1).to(dtype=torch.long) * int(upsample_factor)
return audio, audio_lengths
@torch.no_grad()
def decode_streaming(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None):

View File

@ -1,5 +1,6 @@
import os
import re
import time
import cn2an
from pypinyin import lazy_pinyin, Style
@ -77,6 +78,205 @@ def g2p(text):
return phones, word2ph
def _prepare_g2p_segments(segments):
prepared_segments = []
batch_inputs = []
for segment in segments:
processed_segment = re.sub("[a-zA-Z]+", "", segment)
seg_cut = psg.lcut(processed_segment)
seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
prepared_segments.append(
{
"segment": processed_segment,
"seg_cut": seg_cut,
}
)
if processed_segment:
batch_inputs.append(processed_segment)
return prepared_segments, batch_inputs
def _build_segment_from_g2pw(segment: str, seg_cut, pinyins):
phones_list = []
word2ph = []
initials = []
finals = []
pre_word_length = 0
for word, pos in seg_cut:
sub_initials = []
sub_finals = []
now_word_length = pre_word_length + len(word)
if pos == "eng":
pre_word_length = now_word_length
continue
word_pinyins = pinyins[pre_word_length:now_word_length]
word_pinyins = correct_pronunciation(word, word_pinyins)
for pinyin in word_pinyins:
if pinyin[0].isalpha():
sub_initials.append(to_initials(pinyin))
sub_finals.append(to_finals_tone3(pinyin, neutral_tone_with_five=True))
else:
sub_initials.append(pinyin)
sub_finals.append(pinyin)
pre_word_length = now_word_length
sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos)
initials.append(sub_initials)
finals.append(sub_finals)
initials = sum(initials, [])
finals = sum(finals, [])
for c, v in zip(initials, finals):
raw_pinyin = c + v
if c == v:
assert c in punctuation
phone = [c]
word2ph.append(1)
else:
v_without_tone = v[:-1]
tone = v[-1]
pinyin = c + v_without_tone
assert tone in "12345"
if c:
v_rep_map = {
"uei": "ui",
"iou": "iu",
"uen": "un",
}
if v_without_tone in v_rep_map.keys():
pinyin = c + v_rep_map[v_without_tone]
else:
pinyin_rep_map = {
"ing": "ying",
"i": "yi",
"in": "yin",
"u": "wu",
}
if pinyin in pinyin_rep_map.keys():
pinyin = pinyin_rep_map[pinyin]
else:
single_rep_map = {
"v": "yu",
"e": "e",
"i": "y",
"u": "w",
}
if pinyin[0] in single_rep_map.keys():
pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, segment, raw_pinyin)
new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ")
new_v = new_v + tone
phone = [new_c, new_v]
word2ph.append(len(phone))
phones_list += phone
return phones_list, word2ph
def _build_segment_without_g2pw(segment: str, seg_cut):
initials = []
finals = []
for word, pos in seg_cut:
if pos == "eng":
continue
sub_initials, sub_finals = _get_initials_finals(word)
sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos)
initials.append(sub_initials)
finals.append(sub_finals)
phones_list = []
word2ph = []
for c, v in zip(sum(initials, []), sum(finals, [])):
raw_pinyin = c + v
if c == v:
assert c in punctuation
phone = [c]
word2ph.append(1)
else:
v_without_tone = v[:-1]
tone = v[-1]
pinyin = c + v_without_tone
assert tone in "12345"
if c:
v_rep_map = {"uei": "ui", "iou": "iu", "uen": "un"}
if v_without_tone in v_rep_map:
pinyin = c + v_rep_map[v_without_tone]
else:
pinyin_rep_map = {"ing": "ying", "i": "yi", "in": "yin", "u": "wu"}
if pinyin in pinyin_rep_map:
pinyin = pinyin_rep_map[pinyin]
else:
single_rep_map = {"v": "yu", "e": "e", "i": "y", "u": "w"}
if pinyin[0] in single_rep_map:
pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, segment, raw_pinyin)
new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ")
new_v = new_v + tone
phone = [new_c, new_v]
word2ph.append(len(phone))
phones_list += phone
return phones_list, word2ph
def g2p_segments(segments, return_profile: bool = False):
prepare_start = time.perf_counter()
prepared_segments, batch_inputs = _prepare_g2p_segments(segments)
profile = {
"g2pw_prepare_ms": 0.0,
"g2pw_predict_ms": 0.0,
"g2pw_post_ms": 0.0,
"g2pw_runtime_total_ms": 0.0,
"g2pw_runtime_queue_wait_ms": 0.0,
"g2pw_runtime_collect_wait_ms": 0.0,
"g2pw_runtime_run_ms": 0.0,
"g2pw_runtime_batch_rows": 0.0,
"g2pw_runtime_batch_requests": 0.0,
"g2pw_runtime_pool_workers": 0.0,
"g2pw_runtime_shard_index": 0.0,
}
profile["g2pw_prepare_ms"] = float((time.perf_counter() - prepare_start) * 1000.0)
if is_g2pw and batch_inputs:
converter = g2pw._g2pw
if hasattr(converter, "predict_sentences_with_profile"):
g2pw_batch_results, predict_profile = converter.predict_sentences_with_profile(batch_inputs)
for key, value in dict(predict_profile or {}).items():
profile[key] = float(value)
else:
predict_start = time.perf_counter()
g2pw_batch_results = converter(batch_inputs)
profile["g2pw_predict_ms"] = float((time.perf_counter() - predict_start) * 1000.0)
else:
g2pw_batch_results = []
post_start = time.perf_counter()
results = []
batch_cursor = 0
for item in prepared_segments:
segment = item["segment"]
if not segment:
results.append(([], [], segment))
continue
if not is_g2pw:
phones, word2ph = _build_segment_without_g2pw(segment, item["seg_cut"])
results.append((phones, word2ph, segment))
continue
pinyins = g2pw_batch_results[batch_cursor]
batch_cursor += 1
phones, word2ph = _build_segment_from_g2pw(segment, item["seg_cut"], pinyins)
results.append((phones, word2ph, segment))
profile["g2pw_post_ms"] = float((time.perf_counter() - post_start) * 1000.0)
profile["g2pw_total_ms"] = float(profile["g2pw_prepare_ms"] + profile["g2pw_predict_ms"] + profile["g2pw_post_ms"])
if return_profile:
return results, profile
return results
def _get_initials_finals(word):
initials = []
finals = []
@ -180,118 +380,9 @@ def _merge_erhua(initials: list[str], finals: list[str], word: str, pos: str) ->
def _g2p(segments):
phones_list = []
word2ph = []
for seg in segments:
pinyins = []
# Replace all English words in the sentence
seg = re.sub("[a-zA-Z]+", "", seg)
seg_cut = psg.lcut(seg)
seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
initials = []
finals = []
if not is_g2pw:
for word, pos in seg_cut:
if pos == "eng":
continue
sub_initials, sub_finals = _get_initials_finals(word)
sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
# 儿化
sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos)
initials.append(sub_initials)
finals.append(sub_finals)
# assert len(sub_initials) == len(sub_finals) == len(word)
initials = sum(initials, [])
finals = sum(finals, [])
print("pypinyin结果", initials, finals)
else:
# g2pw采用整句推理
pinyins = g2pw.lazy_pinyin(seg, neutral_tone_with_five=True, style=Style.TONE3)
pre_word_length = 0
for word, pos in seg_cut:
sub_initials = []
sub_finals = []
now_word_length = pre_word_length + len(word)
if pos == "eng":
pre_word_length = now_word_length
continue
word_pinyins = pinyins[pre_word_length:now_word_length]
# 多音字消歧
word_pinyins = correct_pronunciation(word, word_pinyins)
for pinyin in word_pinyins:
if pinyin[0].isalpha():
sub_initials.append(to_initials(pinyin))
sub_finals.append(to_finals_tone3(pinyin, neutral_tone_with_five=True))
else:
sub_initials.append(pinyin)
sub_finals.append(pinyin)
pre_word_length = now_word_length
sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
# 儿化
sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos)
initials.append(sub_initials)
finals.append(sub_finals)
initials = sum(initials, [])
finals = sum(finals, [])
# print("g2pw结果",initials,finals)
for c, v in zip(initials, finals):
raw_pinyin = c + v
# NOTE: post process for pypinyin outputs
# we discriminate i, ii and iii
if c == v:
assert c in punctuation
phone = [c]
word2ph.append(1)
else:
v_without_tone = v[:-1]
tone = v[-1]
pinyin = c + v_without_tone
assert tone in "12345"
if c:
# 多音节
v_rep_map = {
"uei": "ui",
"iou": "iu",
"uen": "un",
}
if v_without_tone in v_rep_map.keys():
pinyin = c + v_rep_map[v_without_tone]
else:
# 单音节
pinyin_rep_map = {
"ing": "ying",
"i": "yi",
"in": "yin",
"u": "wu",
}
if pinyin in pinyin_rep_map.keys():
pinyin = pinyin_rep_map[pinyin]
else:
single_rep_map = {
"v": "yu",
"e": "e",
"i": "y",
"u": "w",
}
if pinyin[0] in single_rep_map.keys():
pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ")
new_v = new_v + tone
phone = [new_c, new_v]
word2ph.append(len(phone))
phones_list += phone
for phones, item_word2ph, _segment in g2p_segments(segments):
phones_list += phones
word2ph += item_word2ph
return phones_list, word2ph

View File

@ -0,0 +1,685 @@
import ctypes
import fcntl
import os
import subprocess
import threading
import time
from collections import deque
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Deque, Dict, List, Tuple
import numpy as np
from .onnx_api import _G2PWBaseOnnxConverter
class G2PWCudaError(RuntimeError):
pass
@dataclass
class G2PWBatchTask:
model_input: Dict[str, np.ndarray]
created_at: float = field(default_factory=time.perf_counter)
enqueued_at: float = 0.0
done_event: threading.Event = field(default_factory=threading.Event)
output: np.ndarray | None = None
profile: Dict[str, float] = field(default_factory=dict)
error: Exception | None = None
_ROOT_DIR = Path(__file__).resolve().parents[3]
_PACKAGE_DIR = Path(__file__).resolve().parent
_OUTPUT_DIR = _ROOT_DIR / "outputs" / "g2pw_cuda_bridge"
_WRAPPER_SOURCE = _PACKAGE_DIR / "g2pw_cuda_bridge.cpp"
_LOCK_PATH = _OUTPUT_DIR / "build.lock"
def _env_flag(name: str, default: bool) -> int:
raw = os.environ.get(name)
if raw is None:
return 1 if default else 0
return 0 if raw.strip().lower() in {"0", "false", "no", "off"} else 1
def _env_int(name: str, default: int) -> int:
raw = os.environ.get(name)
if raw is None or raw.strip() == "":
return int(default)
return int(raw)
def _resolve_cuda_root() -> Path:
env_root = os.environ.get("GPTSOVITS_G2PW_CUDA_ROOT", "").strip()
candidates = [
env_root,
_ROOT_DIR / "third_party" / "g2pw-cu",
]
for candidate in candidates:
if not candidate:
continue
path = Path(candidate).expanduser().resolve()
if path.exists():
return path
checked = [
str(Path(candidate).expanduser().resolve())
for candidate in candidates
if str(candidate).strip() != ""
]
raise G2PWCudaError(
"Cannot locate g2pw-cu root. "
"Expected one of: "
f"{checked}. "
"Recommended: clone https://github.com/baicai-1145/g2pw-cu.git into "
f"{(_ROOT_DIR / 'third_party' / 'g2pw-cu').as_posix()} "
"or set GPTSOVITS_G2PW_CUDA_ROOT explicitly."
)
def _resolve_runtime_paths() -> tuple[Path, Path, Path]:
cuda_root = _resolve_cuda_root()
runtime_lib = Path(
os.environ.get("GPTSOVITS_G2PW_CUDA_RUNTIME_LIB", str(cuda_root / "build" / "libg2pw_runtime.so"))
).expanduser()
manifest_path = Path(
os.environ.get("GPTSOVITS_G2PW_CUDA_MANIFEST", str(cuda_root / "artifacts" / "model" / "manifest.txt"))
).expanduser()
weights_path = Path(
os.environ.get("GPTSOVITS_G2PW_CUDA_WEIGHTS", str(cuda_root / "artifacts" / "model" / "weights.bin"))
).expanduser()
for path in (runtime_lib, manifest_path, weights_path):
if not path.exists():
raise G2PWCudaError(f"Missing g2pw-cu artifact: {path}")
return runtime_lib.resolve(), manifest_path.resolve(), weights_path.resolve()
def _build_bridge(wrapper_output: Path, runtime_lib: Path) -> None:
_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
compile_cmd = [
os.environ.get("CXX", "g++"),
"-O3",
"-std=c++17",
"-shared",
"-fPIC",
str(_WRAPPER_SOURCE),
"-I",
str(runtime_lib.parent.parent / "include"),
"-L",
str(runtime_lib.parent),
"-lg2pw_runtime",
f"-Wl,-rpath,{runtime_lib.parent}",
"-o",
str(wrapper_output),
]
result = subprocess.run(compile_cmd, capture_output=True, text=True, check=False)
if result.returncode != 0:
raise G2PWCudaError(
"Failed to build g2pw-cu bridge:\n"
f"cmd={' '.join(compile_cmd)}\n"
f"stdout={result.stdout}\n"
f"stderr={result.stderr}"
)
def _ensure_bridge_built(runtime_lib: Path) -> Path:
wrapper_output = _OUTPUT_DIR / "g2pw_cuda_bridge.so"
_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
with _LOCK_PATH.open("w", encoding="utf-8") as lock_file:
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
needs_build = not wrapper_output.exists()
if not needs_build:
so_mtime = wrapper_output.stat().st_mtime
needs_build = so_mtime < _WRAPPER_SOURCE.stat().st_mtime or so_mtime < runtime_lib.stat().st_mtime
if needs_build:
tmp_output = wrapper_output.with_suffix(".tmp.so")
if tmp_output.exists():
tmp_output.unlink()
_build_bridge(tmp_output, runtime_lib)
tmp_output.replace(wrapper_output)
return wrapper_output
def _load_bridge():
runtime_lib, manifest_path, weights_path = _resolve_runtime_paths()
bridge_path = _ensure_bridge_built(runtime_lib)
global_mode = getattr(ctypes, "RTLD_GLOBAL", getattr(os, "RTLD_GLOBAL", 0))
ctypes.CDLL(str(runtime_lib), mode=global_mode)
lib = ctypes.CDLL(str(bridge_path))
lib.g2pw_runtime_create.argtypes = [
ctypes.c_char_p,
ctypes.c_char_p,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
]
lib.g2pw_runtime_create.restype = ctypes.c_void_p
lib.g2pw_runtime_destroy.argtypes = [ctypes.c_void_p]
lib.g2pw_runtime_destroy.restype = None
lib.g2pw_runtime_last_error.argtypes = [ctypes.c_void_p]
lib.g2pw_runtime_last_error.restype = ctypes.c_char_p
lib.g2pw_runtime_num_labels.argtypes = [ctypes.c_void_p]
lib.g2pw_runtime_num_labels.restype = ctypes.c_int
lib.g2pw_runtime_run.argtypes = [
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_int32,
ctypes.c_int32,
ctypes.c_void_p,
]
lib.g2pw_runtime_run.restype = ctypes.c_int
return lib, manifest_path, weights_path, runtime_lib
def _gemm_precision_value() -> int:
precision = os.environ.get("GPTSOVITS_G2PW_CUDA_GEMM_PRECISION", "fp32").strip().lower()
if precision == "fp16":
return 1
if precision == "bf16":
return 2
return 0
class G2PWRuntimeWrapper:
def __init__(self, shard_index: int = 0) -> None:
self.lib, self.manifest_path, self.weights_path, self.runtime_lib = _load_bridge()
self.shard_index = int(shard_index)
self.device_ordinal = _env_int("GPTSOVITS_G2PW_CUDA_DEVICE", 0)
self.allow_tensor_cores = _env_flag("GPTSOVITS_G2PW_CUDA_ALLOW_TENSOR_CORES", False)
self.use_cublaslt_bias_epilogue = _env_flag("GPTSOVITS_G2PW_CUDA_USE_CUBLASLT_BIAS_EPILOGUE", False)
self.enable_profiling = _env_flag("GPTSOVITS_G2PW_CUDA_ENABLE_PROFILE", False)
self.enable_cuda_graph = _env_flag("GPTSOVITS_G2PW_CUDA_ENABLE_GRAPH", True)
self.dump_graph_cache_stats = _env_flag("GPTSOVITS_G2PW_CUDA_DUMP_GRAPH_CACHE_STATS", False)
self.full_graph_cache_limit = _env_int("GPTSOVITS_G2PW_CUDA_FULL_GRAPH_CACHE_LIMIT", 0)
self.tail_graph_cache_limit = _env_int("GPTSOVITS_G2PW_CUDA_TAIL_GRAPH_CACHE_LIMIT", 0)
self.gemm_precision = _gemm_precision_value()
self.lock = threading.Lock()
self.handle = None
self.max_batch_size = 0
self.max_seq_len = 0
self.num_labels = 0
self.batch_enabled = _env_flag("GPTSOVITS_G2PW_CUDA_BATCHING", True) != 0
self.batch_window_s = max(0.0, float(_env_int("GPTSOVITS_G2PW_CUDA_BATCH_WINDOW_MS", 1)) / 1000.0)
self.batch_max_requests = max(1, _env_int("GPTSOVITS_G2PW_CUDA_BATCH_MAX_REQUESTS", 64))
self.batch_max_rows = max(1, _env_int("GPTSOVITS_G2PW_CUDA_BATCH_MAX_ROWS", 96))
self.batch_max_tokens = max(1, _env_int("GPTSOVITS_G2PW_CUDA_BATCH_MAX_TOKENS", 4096))
self.batch_condition = threading.Condition()
self.pending_tasks: Deque[G2PWBatchTask] = deque()
self.batch_total_tasks = 0
self.batch_total_batches = 0
self.batch_total_rows = 0
self.batch_total_queue_wait_ms = 0.0
self.batch_queue_wait_peak_ms = 0.0
self.batch_total_collect_wait_ms = 0.0
self.batch_collect_wait_peak_ms = 0.0
self.batch_total_run_ms = 0.0
self.batch_run_peak_ms = 0.0
self.batch_rows_peak = 0
self.batch_requests_peak = 0
self.batch_pending_peak = 0
self.closed = False
self._ensure_capacity(
batch_size=max(1, _env_int("GPTSOVITS_G2PW_CUDA_MAX_BATCH_SIZE", 256)),
seq_len=max(1, _env_int("GPTSOVITS_G2PW_CUDA_MAX_SEQ_LEN", 128)),
)
self.batch_worker = None
if self.batch_enabled:
self.batch_worker = threading.Thread(
target=self._batch_loop,
name=f"g2pw-cuda-batch-worker-{self.shard_index}",
daemon=True,
)
self.batch_worker.start()
def _sync_runtime_env_overrides(self) -> None:
os.environ["G2PW_ENABLE_CUDA_GRAPH"] = "1" if self.enable_cuda_graph else "0"
os.environ["G2PW_ENABLE_PROFILE"] = "1" if self.enable_profiling else "0"
os.environ["G2PW_DUMP_GRAPH_CACHE_STATS"] = "1" if self.dump_graph_cache_stats else "0"
os.environ["G2PW_FULL_GRAPH_CACHE_LIMIT"] = str(int(self.full_graph_cache_limit))
os.environ["G2PW_TAIL_GRAPH_CACHE_LIMIT"] = str(int(self.tail_graph_cache_limit))
os.environ["G2PW_ALLOW_TENSOR_CORES"] = "1" if self.allow_tensor_cores else "0"
os.environ["G2PW_USE_CUBLASLT_BIAS_EPILOGUE"] = "1" if self.use_cublaslt_bias_epilogue else "0"
os.environ["G2PW_GEMM_PRECISION"] = {0: "fp32", 1: "fp16", 2: "bf16"}.get(int(self.gemm_precision), "fp32")
def _destroy_handle(self) -> None:
if self.handle:
self.lib.g2pw_runtime_destroy(self.handle)
self.handle = None
def close(self) -> None:
with self.batch_condition:
self.closed = True
self.batch_condition.notify_all()
self._destroy_handle()
def __del__(self):
try:
self.close()
except Exception:
pass
def _last_error(self) -> str:
if not self.handle:
return "uninitialized runtime"
message = self.lib.g2pw_runtime_last_error(self.handle)
return "" if not message else message.decode("utf-8", errors="replace")
def _create_handle(self, batch_size: int, seq_len: int) -> None:
self._sync_runtime_env_overrides()
new_handle = self.lib.g2pw_runtime_create(
str(self.manifest_path).encode("utf-8"),
str(self.weights_path).encode("utf-8"),
int(self.device_ordinal),
int(batch_size),
int(seq_len),
int(self.full_graph_cache_limit),
int(self.tail_graph_cache_limit),
int(self.allow_tensor_cores),
int(self.use_cublaslt_bias_epilogue),
int(self.enable_profiling),
int(self.enable_cuda_graph),
int(self.dump_graph_cache_stats),
int(self.gemm_precision),
)
if not new_handle:
raise G2PWCudaError("g2pw-cu returned null runtime handle")
self.handle = new_handle
self.max_batch_size = int(batch_size)
self.max_seq_len = int(seq_len)
self.num_labels = int(self.lib.g2pw_runtime_num_labels(self.handle))
last_error = self._last_error()
if self.num_labels <= 0 or last_error:
self.close()
raise G2PWCudaError(f"Failed to initialize g2pw-cu runtime: {last_error or 'num_labels <= 0'}")
def _ensure_capacity(self, batch_size: int, seq_len: int) -> None:
target_batch = max(1, int(batch_size))
target_seq = max(1, int(seq_len))
if self.handle and target_batch <= self.max_batch_size and target_seq <= self.max_seq_len:
return
next_batch = max(target_batch, self.max_batch_size * 2 if self.max_batch_size else 0)
next_seq = max(target_seq, self.max_seq_len * 2 if self.max_seq_len else 0)
self._destroy_handle()
self._create_handle(batch_size=next_batch, seq_len=next_seq)
@staticmethod
def _normalize_model_input(model_input: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
input_ids = np.ascontiguousarray(model_input["input_ids"], dtype=np.int64)
token_type_ids = np.ascontiguousarray(model_input["token_type_ids"], dtype=np.int64)
attention_masks = np.ascontiguousarray(model_input["attention_masks"], dtype=np.int64)
phoneme_masks = np.ascontiguousarray(model_input["phoneme_masks"], dtype=np.float32)
char_ids = np.ascontiguousarray(model_input["char_ids"], dtype=np.int64)
position_ids = np.ascontiguousarray(model_input["position_ids"], dtype=np.int64)
batch_size = int(char_ids.shape[0])
if input_ids.shape[0] == 1 and batch_size > 1:
input_ids = np.ascontiguousarray(np.repeat(input_ids, batch_size, axis=0), dtype=np.int64)
token_type_ids = np.ascontiguousarray(np.repeat(token_type_ids, batch_size, axis=0), dtype=np.int64)
attention_masks = np.ascontiguousarray(np.repeat(attention_masks, batch_size, axis=0), dtype=np.int64)
return {
"input_ids": input_ids,
"token_type_ids": token_type_ids,
"attention_masks": attention_masks,
"phoneme_masks": phoneme_masks,
"char_ids": char_ids,
"position_ids": position_ids,
}
def _run_direct(self, model_input: Dict[str, np.ndarray]) -> np.ndarray:
normalized = self._normalize_model_input(model_input)
input_ids = normalized["input_ids"]
token_type_ids = normalized["token_type_ids"]
attention_masks = normalized["attention_masks"]
phoneme_masks = normalized["phoneme_masks"]
char_ids = normalized["char_ids"]
position_ids = normalized["position_ids"]
batch_size = int(char_ids.shape[0])
seq_len = int(input_ids.shape[1])
probs = np.empty((batch_size, self.num_labels), dtype=np.float32)
with self.lock:
self._ensure_capacity(batch_size=batch_size, seq_len=seq_len)
status = self.lib.g2pw_runtime_run(
self.handle,
input_ids.ctypes.data_as(ctypes.c_void_p),
token_type_ids.ctypes.data_as(ctypes.c_void_p),
attention_masks.ctypes.data_as(ctypes.c_void_p),
phoneme_masks.ctypes.data_as(ctypes.c_void_p),
char_ids.ctypes.data_as(ctypes.c_void_p),
position_ids.ctypes.data_as(ctypes.c_void_p),
batch_size,
seq_len,
probs.ctypes.data_as(ctypes.c_void_p),
)
if int(status) != 0:
raise G2PWCudaError(f"g2pw-cu inference failed: {self._last_error()}")
return probs
def _can_append_task(self, tasks: List[G2PWBatchTask], candidate: G2PWBatchTask) -> bool:
request_count = len(tasks) + 1
if request_count > self.batch_max_requests:
return False
total_rows = sum(int(item.model_input["char_ids"].shape[0]) for item in tasks) + int(
candidate.model_input["char_ids"].shape[0]
)
if total_rows > self.batch_max_rows:
return False
total_tokens = sum(
int(item.model_input["char_ids"].shape[0]) * int(item.model_input["input_ids"].shape[1]) for item in tasks
) + int(candidate.model_input["char_ids"].shape[0]) * int(candidate.model_input["input_ids"].shape[1])
return total_tokens <= self.batch_max_tokens
def _merge_batch_inputs(self, tasks: List[G2PWBatchTask]) -> Tuple[Dict[str, np.ndarray], List[Tuple[int, int]]]:
normalized_inputs = [self._normalize_model_input(task.model_input) for task in tasks]
total_rows = sum(int(item["char_ids"].shape[0]) for item in normalized_inputs)
max_seq_len = max(int(item["input_ids"].shape[1]) for item in normalized_inputs)
input_ids = np.zeros((total_rows, max_seq_len), dtype=np.int64)
token_type_ids = np.zeros((total_rows, max_seq_len), dtype=np.int64)
attention_masks = np.zeros((total_rows, max_seq_len), dtype=np.int64)
phoneme_masks = np.zeros((total_rows, normalized_inputs[0]["phoneme_masks"].shape[1]), dtype=np.float32)
char_ids = np.zeros((total_rows,), dtype=np.int64)
position_ids = np.zeros((total_rows,), dtype=np.int64)
slices: List[Tuple[int, int]] = []
cursor = 0
for item in normalized_inputs:
rows = int(item["char_ids"].shape[0])
seq_len = int(item["input_ids"].shape[1])
next_cursor = cursor + rows
input_ids[cursor:next_cursor, :seq_len] = item["input_ids"]
token_type_ids[cursor:next_cursor, :seq_len] = item["token_type_ids"]
attention_masks[cursor:next_cursor, :seq_len] = item["attention_masks"]
phoneme_masks[cursor:next_cursor] = item["phoneme_masks"]
char_ids[cursor:next_cursor] = item["char_ids"]
position_ids[cursor:next_cursor] = item["position_ids"]
slices.append((cursor, next_cursor))
cursor = next_cursor
return {
"input_ids": input_ids,
"token_type_ids": token_type_ids,
"attention_masks": attention_masks,
"phoneme_masks": phoneme_masks,
"char_ids": char_ids,
"position_ids": position_ids,
}, slices
def _finish_task(
self,
task: G2PWBatchTask,
output: np.ndarray | None = None,
profile: Dict[str, float] | None = None,
error: Exception | None = None,
) -> None:
task.output = output
task.profile = dict(profile or {})
task.error = error
task.done_event.set()
def _batch_loop(self) -> None:
while True:
with self.batch_condition:
while not self.pending_tasks and not self.closed:
self.batch_condition.wait()
if self.closed and not self.pending_tasks:
return
first_task = self.pending_tasks.popleft()
batch_tasks = [first_task]
collect_started = time.perf_counter()
deadline = collect_started + self.batch_window_s
while True:
if len(batch_tasks) >= self.batch_max_requests:
break
remaining = deadline - time.perf_counter()
if remaining <= 0.0:
break
if not self.pending_tasks:
self.batch_condition.wait(timeout=remaining)
continue
candidate = self.pending_tasks[0]
if not self._can_append_task(batch_tasks, candidate):
break
batch_tasks.append(self.pending_tasks.popleft())
collect_wait_ms = max(0.0, (time.perf_counter() - collect_started) * 1000.0)
now = time.perf_counter()
queue_wait_values = [max(0.0, (now - task.enqueued_at) * 1000.0) for task in batch_tasks]
try:
merged_input, row_slices = self._merge_batch_inputs(batch_tasks)
run_started = time.perf_counter()
merged_output = self._run_direct(merged_input)
run_ms = max(0.0, (time.perf_counter() - run_started) * 1000.0)
for task, (start, end) in zip(batch_tasks, row_slices):
task_rows = int(task.model_input["char_ids"].shape[0])
task_seq_len = int(task.model_input["input_ids"].shape[1])
self._finish_task(
task,
output=np.ascontiguousarray(merged_output[start:end]),
profile={
"g2pw_runtime_queue_wait_ms": float(max(0.0, (run_started - task.enqueued_at) * 1000.0)),
"g2pw_runtime_collect_wait_ms": float(collect_wait_ms),
"g2pw_runtime_run_ms": float(run_ms),
"g2pw_runtime_batch_rows": float(sum(int(item.model_input["char_ids"].shape[0]) for item in batch_tasks)),
"g2pw_runtime_batch_requests": float(len(batch_tasks)),
"g2pw_runtime_task_rows": float(task_rows),
"g2pw_runtime_task_seq_len": float(task_seq_len),
"g2pw_runtime_shard_index": float(self.shard_index),
},
)
except Exception as exc:
run_ms = 0.0
for task in batch_tasks:
self._finish_task(task, error=exc)
finally:
with self.batch_condition:
self.batch_total_batches += 1
self.batch_total_tasks += len(batch_tasks)
self.batch_total_rows += sum(int(task.model_input["char_ids"].shape[0]) for task in batch_tasks)
self.batch_total_queue_wait_ms += float(sum(queue_wait_values))
self.batch_queue_wait_peak_ms = max(self.batch_queue_wait_peak_ms, max(queue_wait_values or [0.0]))
self.batch_total_collect_wait_ms += float(collect_wait_ms) * float(len(batch_tasks))
self.batch_collect_wait_peak_ms = max(self.batch_collect_wait_peak_ms, float(collect_wait_ms))
self.batch_total_run_ms += float(run_ms)
self.batch_run_peak_ms = max(self.batch_run_peak_ms, float(run_ms))
self.batch_rows_peak = max(
self.batch_rows_peak, sum(int(task.model_input["char_ids"].shape[0]) for task in batch_tasks)
)
self.batch_requests_peak = max(self.batch_requests_peak, len(batch_tasks))
def _submit_batched(self, model_input: Dict[str, np.ndarray]) -> tuple[np.ndarray, Dict[str, float]]:
task = G2PWBatchTask(model_input=model_input)
with self.batch_condition:
if self.closed:
raise G2PWCudaError("g2pw-cu batch worker already closed")
task.enqueued_at = time.perf_counter()
self.pending_tasks.append(task)
self.batch_pending_peak = max(self.batch_pending_peak, len(self.pending_tasks))
self.batch_condition.notify_all()
task.done_event.wait()
if task.error is not None:
raise task.error
assert task.output is not None
return task.output, dict(task.profile)
def snapshot(self) -> Dict[str, float | int | bool]:
with self.batch_condition:
average_tasks_per_batch = (
float(self.batch_total_tasks) / float(self.batch_total_batches) if self.batch_total_batches > 0 else 0.0
)
average_rows_per_batch = (
float(self.batch_total_rows) / float(self.batch_total_batches) if self.batch_total_batches > 0 else 0.0
)
average_queue_wait_ms = (
float(self.batch_total_queue_wait_ms) / float(self.batch_total_tasks) if self.batch_total_tasks > 0 else 0.0
)
average_collect_wait_ms = (
float(self.batch_total_collect_wait_ms) / float(self.batch_total_tasks)
if self.batch_total_tasks > 0
else 0.0
)
return {
"shard_index": int(self.shard_index),
"enabled": bool(self.batch_enabled),
"enable_cuda_graph": bool(self.enable_cuda_graph),
"enable_profiling": bool(self.enable_profiling),
"full_graph_cache_limit": int(self.full_graph_cache_limit),
"tail_graph_cache_limit": int(self.tail_graph_cache_limit),
"window_ms": float(self.batch_window_s * 1000.0),
"max_requests": int(self.batch_max_requests),
"max_rows": int(self.batch_max_rows),
"max_tokens": int(self.batch_max_tokens),
"pending": int(len(self.pending_tasks)),
"pending_peak": int(self.batch_pending_peak),
"total_batches": int(self.batch_total_batches),
"total_tasks": int(self.batch_total_tasks),
"total_rows": int(self.batch_total_rows),
"avg_tasks_per_batch": float(average_tasks_per_batch),
"avg_rows_per_batch": float(average_rows_per_batch),
"avg_queue_wait_ms": float(average_queue_wait_ms),
"queue_wait_peak_ms": float(self.batch_queue_wait_peak_ms),
"avg_collect_wait_ms": float(average_collect_wait_ms),
"collect_wait_peak_ms": float(self.batch_collect_wait_peak_ms),
"run_total_ms": float(self.batch_total_run_ms),
"run_peak_ms": float(self.batch_run_peak_ms),
"batch_rows_peak": int(self.batch_rows_peak),
"batch_requests_peak": int(self.batch_requests_peak),
}
def pending_rows(self) -> int:
with self.batch_condition:
return int(sum(int(task.model_input["char_ids"].shape[0]) for task in self.pending_tasks))
def pending_count(self) -> int:
with self.batch_condition:
return int(len(self.pending_tasks))
def run_with_profile(self, model_input: Dict[str, np.ndarray]) -> tuple[np.ndarray, Dict[str, float]]:
if not self.batch_enabled:
started = time.perf_counter()
output = self._run_direct(model_input)
return output, {
"g2pw_runtime_queue_wait_ms": 0.0,
"g2pw_runtime_collect_wait_ms": 0.0,
"g2pw_runtime_run_ms": float((time.perf_counter() - started) * 1000.0),
"g2pw_runtime_batch_rows": float(model_input["char_ids"].shape[0]),
"g2pw_runtime_batch_requests": 1.0,
"g2pw_runtime_task_rows": float(model_input["char_ids"].shape[0]),
"g2pw_runtime_task_seq_len": float(model_input["input_ids"].shape[1]),
"g2pw_runtime_shard_index": float(self.shard_index),
}
return self._submit_batched(model_input)
def run(self, model_input: Dict[str, np.ndarray]) -> np.ndarray:
output, _profile = self.run_with_profile(model_input)
return output
class G2PWRuntimePool:
def __init__(self) -> None:
self.worker_count = max(1, _env_int("GPTSOVITS_G2PW_CUDA_WORKERS", 2))
self.shards = [G2PWRuntimeWrapper(shard_index=index) for index in range(self.worker_count)]
self.lock = threading.Lock()
def _pick_shard(self) -> G2PWRuntimeWrapper:
with self.lock:
return min(
self.shards,
key=lambda shard: (
shard.pending_rows(),
shard.pending_count(),
shard.snapshot().get("avg_queue_wait_ms", 0.0),
),
)
def run_with_profile(self, model_input: Dict[str, np.ndarray]) -> tuple[np.ndarray, Dict[str, float]]:
shard = self._pick_shard()
output, profile = shard.run_with_profile(model_input)
profile["g2pw_runtime_pool_workers"] = float(self.worker_count)
return output, profile
def run(self, model_input: Dict[str, np.ndarray]) -> np.ndarray:
output, _profile = self.run_with_profile(model_input)
return output
def snapshot(self) -> Dict[str, float | int | bool | List[Dict[str, float | int | bool]]]:
shard_snapshots = [dict(shard.snapshot()) for shard in self.shards]
avg_queue_wait_ms = 0.0
total_tasks = 0.0
pending = 0
pending_peak = 0
total_batches = 0
total_rows = 0
batch_rows_peak = 0
batch_requests_peak = 0
for snapshot in shard_snapshots:
tasks = float(snapshot.get("total_tasks", 0.0))
avg_queue_wait_ms += float(snapshot.get("avg_queue_wait_ms", 0.0)) * tasks
total_tasks += tasks
pending += int(snapshot.get("pending", 0))
pending_peak = max(pending_peak, int(snapshot.get("pending_peak", 0)))
total_batches += int(snapshot.get("total_batches", 0))
total_rows += int(snapshot.get("total_rows", 0))
batch_rows_peak = max(batch_rows_peak, int(snapshot.get("batch_rows_peak", 0)))
batch_requests_peak = max(batch_requests_peak, int(snapshot.get("batch_requests_peak", 0)))
return {
"worker_count": int(self.worker_count),
"pending": int(pending),
"pending_peak": int(pending_peak),
"total_batches": int(total_batches),
"total_tasks": int(total_tasks),
"total_rows": int(total_rows),
"avg_queue_wait_ms": float(avg_queue_wait_ms / total_tasks) if total_tasks > 0 else 0.0,
"batch_rows_peak": int(batch_rows_peak),
"batch_requests_peak": int(batch_requests_peak),
"shards": shard_snapshots,
}
class G2PWCudaConverter(_G2PWBaseOnnxConverter):
def __init__(
self,
model_dir: str = "G2PWModel/",
style: str = "bopomofo",
model_source: str = None,
enable_non_tradional_chinese: bool = False,
):
super().__init__(
model_dir=model_dir,
style=style,
model_source=model_source,
enable_non_tradional_chinese=enable_non_tradional_chinese,
)
self.runtime = G2PWRuntimePool()
self.backend = "cuda"
primary_runtime = self.runtime.shards[0]
self.device = f"cuda:{primary_runtime.device_ordinal}"
self.checkpoint_path = str(primary_runtime.weights_path)
self.providers = ["g2pw-cu"]
def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]:
probs = self.runtime.run(model_input)
preds = np.argmax(probs, axis=1).tolist()
confidences = probs[np.arange(len(preds)), preds].astype(np.float32, copy=False).tolist()
return [self.labels[pred] for pred in preds], confidences
def _predict_with_profile(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float], Dict[str, float]]:
started = time.perf_counter()
probs, runtime_profile = self.runtime.run_with_profile(model_input)
preds = np.argmax(probs, axis=1).tolist()
confidences = probs[np.arange(len(preds)), preds].astype(np.float32, copy=False).tolist()
profile = dict(runtime_profile)
profile["g2pw_runtime_total_ms"] = float((time.perf_counter() - started) * 1000.0)
profile["g2pw_predict_ms"] = float(profile["g2pw_runtime_total_ms"])
return [self.labels[pred] for pred in preds], confidences, profile
def snapshot(self) -> Dict[str, float | int | bool]:
return dict(self.runtime.snapshot())

View File

@ -18,6 +18,7 @@ Credits
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import numpy as np
@ -37,6 +38,8 @@ def prepare_onnx_input(
use_mask: bool = False,
window_size: int = None,
max_len: int = 512,
char2id: Optional[Dict[str, int]] = None,
char_phoneme_masks: Optional[Dict[str, List[int]]] = None,
) -> Dict[str, np.array]:
if window_size is not None:
truncated_texts, truncated_query_ids = _truncate_texts(
@ -48,33 +51,88 @@ def prepare_onnx_input(
phoneme_masks = []
char_ids = []
position_ids = []
tokenized_cache = {}
if char2id is None:
char2id = {char: idx for idx, char in enumerate(chars)}
if use_mask:
if char_phoneme_masks is None:
char_phoneme_masks = {
char: [1 if i in char2phonemes[char] else 0 for i in range(len(labels))]
for char in char2phonemes
}
else:
full_phoneme_mask = [1] * len(labels)
for idx in range(len(texts)):
text = (truncated_texts if window_size else texts)[idx].lower()
query_id = (truncated_query_ids if window_size else query_ids)[idx]
try:
tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text)
except Exception:
print(f'warning: text "{text}" is invalid')
return {}
cached = tokenized_cache.get(text)
if cached is None:
try:
tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text)
except Exception:
print(f'warning: text "{text}" is invalid')
return {}
text, query_id, tokens, text2token, token2text = _truncate(
max_len=max_len, text=text, query_id=query_id, tokens=tokens, text2token=text2token, token2text=token2text
)
if len(tokens) <= max_len - 2:
processed_tokens = ["[CLS]"] + tokens + ["[SEP]"]
shared_input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
shared_token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
shared_attention_mask = list(np.ones((len(processed_tokens),), dtype=int))
cached = {
"is_short": True,
"tokens": tokens,
"text2token": text2token,
"token2text": token2text,
"input_id": shared_input_id,
"token_type_id": shared_token_type_id,
"attention_mask": shared_attention_mask,
}
else:
cached = {
"is_short": False,
"tokens": tokens,
"text2token": text2token,
"token2text": token2text,
}
tokenized_cache[text] = cached
processed_tokens = ["[CLS]"] + tokens + ["[SEP]"]
if cached["is_short"]:
text_for_query = text
query_id_for_query = query_id
text2token_for_query = cached["text2token"]
input_id = cached["input_id"]
token_type_id = cached["token_type_id"]
attention_mask = cached["attention_mask"]
else:
(
text_for_query,
query_id_for_query,
tokens_for_query,
text2token_for_query,
_token2text_for_query,
) = _truncate(
max_len=max_len,
text=text,
query_id=query_id,
tokens=cached["tokens"],
text2token=cached["text2token"],
token2text=cached["token2text"],
)
processed_tokens = ["[CLS]"] + tokens_for_query + ["[SEP]"]
input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
attention_mask = list(np.ones((len(processed_tokens),), dtype=int))
input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
attention_mask = list(np.ones((len(processed_tokens),), dtype=int))
query_char = text[query_id]
phoneme_mask = (
[1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] if use_mask else [1] * len(labels)
)
char_id = chars.index(query_char)
position_id = text2token[query_id] + 1 # [CLS] token locate at first place
query_char = text_for_query[query_id_for_query]
if use_mask:
phoneme_mask = char_phoneme_masks[query_char]
else:
phoneme_mask = full_phoneme_mask
char_id = char2id[query_char]
position_id = text2token_for_query[query_id_for_query] + 1 # [CLS] token locate at first place
input_ids.append(input_id)
token_type_ids.append(token_type_id)
@ -83,10 +141,15 @@ def prepare_onnx_input(
char_ids.append(char_id)
position_ids.append(position_id)
max_token_length = max(len(seq) for seq in input_ids)
def _pad_sequences(sequences, pad_value=0):
return [seq + [pad_value] * (max_token_length - len(seq)) for seq in sequences]
outputs = {
"input_ids": np.array(input_ids).astype(np.int64),
"token_type_ids": np.array(token_type_ids).astype(np.int64),
"attention_masks": np.array(attention_masks).astype(np.int64),
"input_ids": np.array(_pad_sequences(input_ids, pad_value=0)).astype(np.int64),
"token_type_ids": np.array(_pad_sequences(token_type_ids, pad_value=0)).astype(np.int64),
"attention_masks": np.array(_pad_sequences(attention_masks, pad_value=0)).astype(np.int64),
"phoneme_masks": np.array(phoneme_masks).astype(np.float32),
"char_ids": np.array(char_ids).astype(np.int64),
"position_ids": np.array(position_ids).astype(np.int64),

View File

@ -8,6 +8,7 @@ from pypinyin.core import Pinyin, Style
from pypinyin.seg.simpleseg import simple_seg
from pypinyin.converter import UltimateConverter
from pypinyin.contrib.tone_convert import to_tone
from .cuda_api import G2PWCudaConverter
from .onnx_api import G2PWOnnxConverter
current_file_path = os.path.dirname(__file__)
@ -27,12 +28,36 @@ class G2PWPinyin(Pinyin):
tone_sandhi=False,
**kwargs,
):
self._g2pw = G2PWOnnxConverter(
model_dir=model_dir,
style="pinyin",
model_source=model_source,
enable_non_tradional_chinese=enable_non_tradional_chinese,
)
backend = os.environ.get("GPTSOVITS_G2PW_BACKEND", "cuda").strip().lower()
last_error = None
self._g2pw = None
if backend in {"cuda", "auto"}:
try:
self._g2pw = G2PWCudaConverter(
model_dir=model_dir,
style="pinyin",
model_source=model_source,
enable_non_tradional_chinese=enable_non_tradional_chinese,
)
except Exception as exc:
last_error = exc
strict_mode = os.environ.get("GPTSOVITS_G2PW_CUDA_STRICT", "0").strip().lower() in {
"1",
"true",
"yes",
"on",
}
if backend == "cuda" and strict_mode:
raise
if self._g2pw is None:
self._g2pw = G2PWOnnxConverter(
model_dir=model_dir,
style="pinyin",
model_source=model_source,
enable_non_tradional_chinese=enable_non_tradional_chinese,
)
if last_error is not None:
print(f"[g2pw] cuda backend unavailable, fallback to onnx: {last_error}")
self._converter = Converter(
self._g2pw,
v_to_u=v_to_u,

View File

@ -0,0 +1,183 @@
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include "g2pw/runtime.h"
namespace {
struct G2PWRuntimeHandle {
std::unique_ptr<g2pw::Runtime> runtime;
std::string last_error;
int num_labels = 0;
};
void SetError(G2PWRuntimeHandle* handle, const g2pw::Status& status) {
if (handle == nullptr) {
return;
}
handle->last_error = status.message;
}
g2pw::RuntimeConfig BuildConfig(
int device_ordinal,
int max_batch_size,
int max_seq_len,
int full_graph_cache_limit,
int tail_graph_cache_limit,
int allow_tensor_cores,
int use_cublaslt_bias_epilogue,
int enable_profiling,
int enable_cuda_graph,
int dump_graph_cache_stats,
int gemm_precision) {
g2pw::RuntimeConfig config{};
config.device_ordinal = device_ordinal;
config.max_batch_size = max_batch_size;
config.max_seq_len = max_seq_len;
config.full_graph_cache_limit = full_graph_cache_limit;
config.tail_graph_cache_limit = tail_graph_cache_limit;
config.allow_tensor_cores = allow_tensor_cores != 0;
config.use_cublaslt_bias_epilogue = use_cublaslt_bias_epilogue != 0;
config.enable_profiling = enable_profiling != 0;
config.enable_cuda_graph = enable_cuda_graph != 0;
config.dump_graph_cache_stats = dump_graph_cache_stats != 0;
switch (gemm_precision) {
case 1:
config.gemm_precision = g2pw::GemmPrecision::kFp16;
break;
case 2:
config.gemm_precision = g2pw::GemmPrecision::kBf16;
break;
default:
config.gemm_precision = g2pw::GemmPrecision::kFp32;
break;
}
return config;
}
} // namespace
extern "C" {
void* g2pw_runtime_create(
const char* manifest_path,
const char* binary_path,
int device_ordinal,
int max_batch_size,
int max_seq_len,
int full_graph_cache_limit,
int tail_graph_cache_limit,
int allow_tensor_cores,
int use_cublaslt_bias_epilogue,
int enable_profiling,
int enable_cuda_graph,
int dump_graph_cache_stats,
int gemm_precision) {
auto* handle = new G2PWRuntimeHandle();
try {
if (manifest_path == nullptr || binary_path == nullptr) {
handle->last_error = "manifest_path and binary_path must be non-null";
return handle;
}
g2pw::RuntimeConfig config = BuildConfig(
device_ordinal,
max_batch_size,
max_seq_len,
full_graph_cache_limit,
tail_graph_cache_limit,
allow_tensor_cores,
use_cublaslt_bias_epilogue,
enable_profiling,
enable_cuda_graph,
dump_graph_cache_stats,
gemm_precision);
g2pw::Status status = g2pw::Runtime::Create(
config,
std::string(manifest_path),
std::string(binary_path),
&handle->runtime);
if (!status.ok()) {
SetError(handle, status);
return handle;
}
handle->num_labels = handle->runtime != nullptr ? handle->runtime->weights().manifest().num_labels : 0;
handle->last_error.clear();
return handle;
} catch (const std::exception& exc) {
handle->last_error = exc.what();
return handle;
} catch (...) {
handle->last_error = "unknown exception";
return handle;
}
}
void g2pw_runtime_destroy(void* raw_handle) {
auto* handle = static_cast<G2PWRuntimeHandle*>(raw_handle);
delete handle;
}
const char* g2pw_runtime_last_error(void* raw_handle) {
auto* handle = static_cast<G2PWRuntimeHandle*>(raw_handle);
if (handle == nullptr) {
return "invalid runtime handle";
}
return handle->last_error.c_str();
}
int g2pw_runtime_num_labels(void* raw_handle) {
auto* handle = static_cast<G2PWRuntimeHandle*>(raw_handle);
if (handle == nullptr || handle->runtime == nullptr) {
return 0;
}
return handle->num_labels;
}
int g2pw_runtime_run(
void* raw_handle,
const std::int64_t* input_ids,
const std::int64_t* token_type_ids,
const std::int64_t* attention_mask,
const float* phoneme_mask,
const std::int64_t* char_ids,
const std::int64_t* position_ids,
std::int32_t batch_size,
std::int32_t seq_len,
float* probs) {
auto* handle = static_cast<G2PWRuntimeHandle*>(raw_handle);
if (handle == nullptr || handle->runtime == nullptr) {
return static_cast<int>(g2pw::StatusCode::kInvalidArgument);
}
try {
g2pw::InferenceInputs inputs{};
inputs.input_ids = input_ids;
inputs.token_type_ids = token_type_ids;
inputs.attention_mask = attention_mask;
inputs.phoneme_mask = phoneme_mask;
inputs.char_ids = char_ids;
inputs.position_ids = position_ids;
inputs.batch_size = batch_size;
inputs.seq_len = seq_len;
g2pw::InferenceOutputs outputs{};
outputs.probs = probs;
const g2pw::Status status = handle->runtime->Run(inputs, outputs);
if (!status.ok()) {
SetError(handle, status);
return static_cast<int>(status.code);
}
handle->last_error.clear();
return static_cast<int>(g2pw::StatusCode::kOk);
} catch (const std::exception& exc) {
handle->last_error = exc.what();
return static_cast<int>(g2pw::StatusCode::kInternalError);
} catch (...) {
handle->last_error = "unknown exception";
return static_cast<int>(g2pw::StatusCode::kInternalError);
}
}
}

View File

@ -3,6 +3,7 @@
import json
import os
import time
import warnings
import zipfile
from typing import Any, Dict, List, Tuple
@ -10,7 +11,6 @@ from typing import Any, Dict, List, Tuple
import numpy as np
import onnxruntime
import requests
import torch
from opencc import OpenCC
from pypinyin import Style, pinyin
from transformers.models.auto.tokenization_auto import AutoTokenizer
@ -22,9 +22,8 @@ from .utils import load_config
onnxruntime.set_default_logger_severity(3)
try:
onnxruntime.preload_dlls()
except:
except Exception:
pass
# traceback.print_exc()
warnings.filterwarnings("ignore")
model_version = "1.1"
@ -55,6 +54,41 @@ def predict(session, onnx_input: Dict[str, Any], labels: List[str]) -> Tuple[Lis
return all_preds, all_confidences
def _load_json_from_candidates(filename: str, candidate_dirs: List[str]) -> Dict[str, Any]:
for candidate_dir in candidate_dirs:
if not candidate_dir:
continue
json_path = os.path.join(candidate_dir, filename)
if os.path.exists(json_path):
with open(json_path, "r", encoding="utf-8") as fr:
return json.load(fr)
raise FileNotFoundError(f"Cannot locate {filename} in candidate dirs: {candidate_dirs}")
def _find_first_existing_file(*paths: str) -> str:
for path in paths:
if path and os.path.exists(path):
return path
raise FileNotFoundError(f"Files not found: {paths}")
def _resolve_tokenizer_source(model_source: str | None) -> str:
candidate_paths = []
if model_source:
candidate_paths.append(model_source)
repo_root = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
candidate_paths.extend(
[
os.path.join(repo_root, "pretrained_models", "g2pw-chinese"),
os.path.join(repo_root, "pretrained_models", "chinese-roberta-wwm-ext-large"),
]
)
for candidate in candidate_paths:
if candidate and os.path.exists(candidate):
return candidate
return model_source or "bert-base-chinese"
def download_and_decompress(model_dir: str = "G2PWModel/"):
if not os.path.exists(model_dir):
parent_directory = os.path.dirname(model_dir)
@ -62,7 +96,7 @@ def download_and_decompress(model_dir: str = "G2PWModel/"):
extract_dir = os.path.join(parent_directory, "G2PWModel_1.1")
extract_dir_new = os.path.join(parent_directory, "G2PWModel")
print("Downloading g2pw model...")
modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip" # "https://paddlespeech.cdn.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip"
modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip"
with requests.get(modelscope_url, stream=True) as r:
r.raise_for_status()
with open(zip_dir, "wb") as f:
@ -79,7 +113,7 @@ def download_and_decompress(model_dir: str = "G2PWModel/"):
return model_dir
class G2PWOnnxConverter:
class _G2PWBaseOnnxConverter:
def __init__(
self,
model_dir: str = "G2PWModel/",
@ -87,33 +121,16 @@ class G2PWOnnxConverter:
model_source: str = None,
enable_non_tradional_chinese: bool = False,
):
uncompress_path = download_and_decompress(model_dir)
self.model_dir = download_and_decompress(model_dir)
self.config = load_config(config_path=os.path.join(self.model_dir, "config.py"), use_default=True)
sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
sess_options.intra_op_num_threads = 2 if torch.cuda.is_available() else 0
if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
self.session_g2pW = onnxruntime.InferenceSession(
os.path.join(uncompress_path, "g2pW.onnx"),
sess_options=sess_options,
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
else:
self.session_g2pW = onnxruntime.InferenceSession(
os.path.join(uncompress_path, "g2pW.onnx"),
sess_options=sess_options,
providers=["CPUExecutionProvider"],
)
self.config = load_config(config_path=os.path.join(uncompress_path, "config.py"), use_default=True)
self.model_source = model_source if model_source else self.config.model_source
self.model_source = _resolve_tokenizer_source(model_source if model_source else self.config.model_source)
self.enable_opencc = enable_non_tradional_chinese
self.tokenizer = AutoTokenizer.from_pretrained(self.model_source, local_files_only=True)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)
polyphonic_chars_path = os.path.join(self.model_dir, "POLYPHONIC_CHARS.txt")
monophonic_chars_path = os.path.join(self.model_dir, "MONOPHONIC_CHARS.txt")
polyphonic_chars_path = os.path.join(uncompress_path, "POLYPHONIC_CHARS.txt")
monophonic_chars_path = os.path.join(uncompress_path, "MONOPHONIC_CHARS.txt")
self.polyphonic_chars = [
line.split("\t") for line in open(polyphonic_chars_path, encoding="utf-8").read().strip().split("\n")
]
@ -149,31 +166,47 @@ class G2PWOnnxConverter:
)
self.chars = sorted(list(self.char2phonemes.keys()))
self.char2id = {char: idx for idx, char in enumerate(self.chars)}
self.char_phoneme_masks = (
{
char: [1 if i in self.char2phonemes[char] else 0 for i in range(len(self.labels))]
for char in self.char2phonemes
}
if self.config.use_mask
else None
)
self.polyphonic_chars_new = set(self.chars)
for char in self.non_polyphonic:
if char in self.polyphonic_chars_new:
self.polyphonic_chars_new.remove(char)
self.polyphonic_chars_new.discard(char)
self.monophonic_chars_dict = {char: phoneme for char, phoneme in self.monophonic_chars}
for char in self.non_monophonic:
if char in self.monophonic_chars_dict:
self.monophonic_chars_dict.pop(char)
self.monophonic_chars_dict.pop(char, None)
self.pos_tags = ["UNK", "A", "C", "D", "I", "N", "P", "T", "V", "DE", "SHI"]
default_asset_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "G2PWModel"))
candidate_asset_dirs = [self.model_dir, default_asset_dir]
self.bopomofo_convert_dict = _load_json_from_candidates(
"bopomofo_to_pinyin_wo_tune_dict.json", candidate_asset_dirs
)
self.char_bopomofo_dict = _load_json_from_candidates("char_bopomofo_dict.json", candidate_asset_dirs)
with open(os.path.join(uncompress_path, "bopomofo_to_pinyin_wo_tune_dict.json"), "r", encoding="utf-8") as fr:
self.bopomofo_convert_dict = json.load(fr)
self.style_convert_func = {
"bopomofo": lambda x: x,
"pinyin": self._convert_bopomofo_to_pinyin,
}[style]
with open(os.path.join(uncompress_path, "char_bopomofo_dict.json"), "r", encoding="utf-8") as fr:
self.char_bopomofo_dict = json.load(fr)
if self.enable_opencc:
self.cc = OpenCC("s2tw")
self.enable_sentence_dedup = os.getenv("g2pw_sentence_dedup", "true").strip().lower() in {
"1",
"true",
"yes",
"y",
"on",
}
# 聚焦到多音字附近上下文默认左右各16字设为0表示关闭裁剪整句
self.polyphonic_context_chars = max(0, int(os.getenv("g2pw_polyphonic_context_chars", "16")))
def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
tone = bopomofo[-1]
@ -181,11 +214,14 @@ class G2PWOnnxConverter:
component = self.bopomofo_convert_dict.get(bopomofo[:-1])
if component:
return component + tone
else:
print(f'Warning: "{bopomofo}" cannot convert to pinyin')
return None
print(f'Warning: "{bopomofo}" cannot convert to pinyin')
return None
def __call__(self, sentences: List[str]) -> List[List[str]]:
results, _profile = self.predict_sentences_with_profile(sentences)
return results
def predict_sentences_with_profile(self, sentences: List[str]) -> Tuple[List[List[str]], Dict[str, float]]:
if isinstance(sentences, str):
sentences = [sentences]
@ -197,51 +233,202 @@ class G2PWOnnxConverter:
translated_sentences.append(translated_sent)
sentences = translated_sentences
texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences)
texts, model_query_ids, result_query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences)
if len(texts) == 0:
# sentences no polyphonic words
return partial_results
return partial_results, {}
onnx_input = prepare_onnx_input(
model_input = prepare_onnx_input(
tokenizer=self.tokenizer,
labels=self.labels,
char2phonemes=self.char2phonemes,
chars=self.chars,
texts=texts,
query_ids=query_ids,
query_ids=model_query_ids,
use_mask=self.config.use_mask,
window_size=None,
char2id=self.char2id,
char_phoneme_masks=self.char_phoneme_masks,
)
preds, confidences = predict(session=self.session_g2pW, onnx_input=onnx_input, labels=self.labels)
if not model_input:
return partial_results, {}
predict_profile: Dict[str, float] = {}
if self.enable_sentence_dedup:
preds, _confidences, predict_profile = self._predict_with_sentence_dedup_profiled(
model_input=model_input,
texts=texts,
)
else:
if hasattr(self, "_predict_with_profile"):
preds, _confidences, predict_profile = self._predict_with_profile(model_input=model_input)
else:
predict_started = time.perf_counter()
preds, _confidences = self._predict(model_input=model_input)
predict_profile["g2pw_predict_ms"] = float((time.perf_counter() - predict_started) * 1000.0)
if self.config.use_char_phoneme:
preds = [pred.split(" ")[1] for pred in preds]
results = partial_results
for sent_id, query_id, pred in zip(sent_ids, query_ids, preds):
for sent_id, query_id, pred in zip(sent_ids, result_query_ids, preds):
results[sent_id][query_id] = self.style_convert_func(pred)
return results
return results, predict_profile
def _prepare_data(self, sentences: List[str]) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
texts, query_ids, sent_ids, partial_results = [], [], [], []
def _prepare_data(
self, sentences: List[str]
) -> Tuple[List[str], List[int], List[int], List[int], List[List[str]]]:
texts, model_query_ids, result_query_ids, sent_ids, partial_results = [], [], [], [], []
for sent_id, sent in enumerate(sentences):
# pypinyin works well for Simplified Chinese than Traditional Chinese
sent_s = tranditional_to_simplified(sent)
pypinyin_result = pinyin(sent_s, neutral_tone_with_five=True, style=Style.TONE3)
partial_result = [None] * len(sent)
polyphonic_indices: List[int] = []
for i, char in enumerate(sent):
if char in self.polyphonic_chars_new:
texts.append(sent)
query_ids.append(i)
sent_ids.append(sent_id)
polyphonic_indices.append(i)
elif char in self.monophonic_chars_dict:
partial_result[i] = self.style_convert_func(self.monophonic_chars_dict[char])
elif char in self.char_bopomofo_dict:
partial_result[i] = pypinyin_result[i][0]
# partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])
else:
partial_result[i] = pypinyin_result[i][0]
if polyphonic_indices:
if self.polyphonic_context_chars > 0:
left = max(0, polyphonic_indices[0] - self.polyphonic_context_chars)
right = min(len(sent), polyphonic_indices[-1] + self.polyphonic_context_chars + 1)
sent_for_predict = sent[left:right]
query_offset = left
else:
sent_for_predict = sent
query_offset = 0
for index in polyphonic_indices:
texts.append(sent_for_predict)
model_query_ids.append(index - query_offset)
result_query_ids.append(index)
sent_ids.append(sent_id)
partial_results.append(partial_result)
return texts, query_ids, sent_ids, partial_results
return texts, model_query_ids, result_query_ids, sent_ids, partial_results
def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]:
raise NotImplementedError
def _predict_with_sentence_dedup(
self, model_input: Dict[str, Any], texts: List[str]
) -> Tuple[List[str], List[float]]:
if len(texts) <= 1:
return self._predict(model_input=model_input)
grouped_indices: Dict[str, List[int]] = {}
for idx, text in enumerate(texts):
grouped_indices.setdefault(text, []).append(idx)
if all(len(indices) == 1 for indices in grouped_indices.values()):
return self._predict(model_input=model_input)
preds: List[str] = [""] * len(texts)
confidences: List[float] = [0.0] * len(texts)
for indices in grouped_indices.values():
group_input = {name: value[indices] for name, value in model_input.items()}
if len(indices) > 1:
for name in ("input_ids", "token_type_ids", "attention_masks"):
group_input[name] = group_input[name][:1]
group_preds, group_confidences = self._predict(model_input=group_input)
for output_idx, pred, confidence in zip(indices, group_preds, group_confidences):
preds[output_idx] = pred
confidences[output_idx] = confidence
return preds, confidences
def _predict_with_sentence_dedup_profiled(
self,
model_input: Dict[str, Any],
texts: List[str],
) -> Tuple[List[str], List[float], Dict[str, float]]:
if len(texts) <= 1:
if hasattr(self, "_predict_with_profile"):
return self._predict_with_profile(model_input=model_input)
predict_started = time.perf_counter()
preds, confidences = self._predict(model_input=model_input)
return preds, confidences, {"g2pw_predict_ms": float((time.perf_counter() - predict_started) * 1000.0)}
grouped_indices: Dict[str, List[int]] = {}
for idx, text in enumerate(texts):
grouped_indices.setdefault(text, []).append(idx)
if all(len(indices) == 1 for indices in grouped_indices.values()):
if hasattr(self, "_predict_with_profile"):
return self._predict_with_profile(model_input=model_input)
predict_started = time.perf_counter()
preds, confidences = self._predict(model_input=model_input)
return preds, confidences, {"g2pw_predict_ms": float((time.perf_counter() - predict_started) * 1000.0)}
preds: List[str] = [""] * len(texts)
confidences: List[float] = [0.0] * len(texts)
merged_profile: Dict[str, float] = {}
for indices in grouped_indices.values():
group_input = {name: value[indices] for name, value in model_input.items()}
if len(indices) > 1:
for name in ("input_ids", "token_type_ids", "attention_masks"):
group_input[name] = group_input[name][:1]
if hasattr(self, "_predict_with_profile"):
group_preds, group_confidences, group_profile = self._predict_with_profile(model_input=group_input)
for key, value in dict(group_profile or {}).items():
merged_profile[key] = float(merged_profile.get(key, 0.0)) + float(value)
else:
predict_started = time.perf_counter()
group_preds, group_confidences = self._predict(model_input=group_input)
merged_profile["g2pw_predict_ms"] = float(
merged_profile.get("g2pw_predict_ms", 0.0) + (time.perf_counter() - predict_started) * 1000.0
)
for output_idx, pred, confidence in zip(indices, group_preds, group_confidences):
preds[output_idx] = pred
confidences[output_idx] = confidence
return preds, confidences, merged_profile
class G2PWOnnxConverter(_G2PWBaseOnnxConverter):
def __init__(
self,
model_dir: str = "G2PWModel/",
style: str = "bopomofo",
model_source: str = None,
enable_non_tradional_chinese: bool = False,
):
super().__init__(
model_dir=model_dir,
style=style,
model_source=model_source,
enable_non_tradional_chinese=enable_non_tradional_chinese,
)
sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
sess_options.intra_op_num_threads = 2
onnx_path = _find_first_existing_file(
os.path.join(self.model_dir, "g2pW.onnx"),
os.path.join(self.model_dir, "g2pw.onnx"),
)
if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
self.session_g2pw = onnxruntime.InferenceSession(
onnx_path,
sess_options=sess_options,
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
else:
self.session_g2pw = onnxruntime.InferenceSession(
onnx_path,
sess_options=sess_options,
providers=["CPUExecutionProvider"],
)
def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]:
return predict(session=self.session_g2pw, onnx_input=model_input, labels=self.labels)

285
api_v2.py
View File

@ -39,8 +39,8 @@ POST:
"seed": -1, # int. random seed for reproducibility.
"parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
"sample_steps": 32, # int. 仅 v3/v4 vocoder 路径使用;当前 v2/v2ProPlus 主线可忽略。
"super_sampling": False, # bool. 仅 v3/v4 路径使用;不属于当前 v2/v2ProPlus 正式支持目标。
"streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed )
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
@ -79,7 +79,7 @@ endpoint: `/set_gpt_weights`
GET:
```
http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1v3.ckpt
```
RESP:
成功: 返回"success", http code 200
@ -92,7 +92,7 @@ endpoint: `/set_sovits_weights`
GET:
```
http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/s2G488k.pth
http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth
```
RESP:
@ -104,27 +104,22 @@ RESP:
import os
import sys
import traceback
from typing import Generator, Union
from typing import Union
now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir))
import argparse
import subprocess
import wave
import signal
import numpy as np
import soundfile as sf
from fastapi import FastAPI, Response
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn
from io import BytesIO
from tools.i18n.i18n import I18nAuto
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
from GPT_SoVITS.TTS_infer_pack.unified_engine import RuntimeControlCallbacks, UnifiedTTSEngine
from pydantic import BaseModel
import threading
# print(sys.path)
i18n = I18nAuto()
@ -147,6 +142,14 @@ if config_path in [None, ""]:
tts_config = TTS_Config(config_path)
print(tts_config)
tts_pipeline = TTS(tts_config)
tts_engine = UnifiedTTSEngine(
tts_pipeline,
cut_method_names=cut_method_names,
control_callbacks=RuntimeControlCallbacks(
restart=lambda: os.execl(sys.executable, sys.executable, *argv),
exit=lambda: os.kill(os.getpid(), signal.SIGTERM),
),
)
APP = FastAPI()
@ -178,168 +181,8 @@ class TTS_Request(BaseModel):
min_chunk_length: int = 16
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
# Author: AkagawaTsurunaki
# Issue:
# Stack overflow probabilistically occurs
# when the function `sf_writef_short` of `libsndfile_64bit.dll` is called
# using the Python library `soundfile`
# Note:
# This is an issue related to `libsndfile`, not this project itself.
# It happens when you generate a large audio tensor (about 499804 frames in my PC)
# and try to convert it to an ogg file.
# Related:
# https://github.com/RVC-Boss/GPT-SoVITS/issues/1199
# https://github.com/libsndfile/libsndfile/issues/1023
# https://github.com/bastibe/python-soundfile/issues/396
# Suggestion:
# Or split the whole audio data into smaller audio segment to avoid stack overflow?
def handle_pack_ogg():
with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
audio_file.write(data)
# See: https://docs.python.org/3/library/threading.html
# The stack size of this thread is at least 32768
# If stack overflow error still occurs, just modify the `stack_size`.
# stack_size = n * 4096, where n should be a positive integer.
# Here we chose n = 4096.
stack_size = 4096 * 4096
try:
threading.stack_size(stack_size)
pack_ogg_thread = threading.Thread(target=handle_pack_ogg)
pack_ogg_thread.start()
pack_ogg_thread.join()
except RuntimeError as e:
# If changing the thread stack size is unsupported, a RuntimeError is raised.
print("RuntimeError: {}".format(e))
print("Changing the thread stack size is unsupported.")
except ValueError as e:
# If the specified stack size is invalid, a ValueError is raised and the stack size is unmodified.
print("ValueError: {}".format(e))
print("The specified stack size is invalid.")
return io_buffer
def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int):
io_buffer.write(data.tobytes())
return io_buffer
def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int):
io_buffer = BytesIO()
sf.write(io_buffer, data, rate, format="wav")
return io_buffer
def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
process = subprocess.Popen(
[
"ffmpeg",
"-f",
"s16le", # 输入16位有符号小端整数PCM
"-ar",
str(rate), # 设置采样率
"-ac",
"1", # 单声道
"-i",
"pipe:0", # 从管道读取输入
"-c:a",
"aac", # 音频编码器为AAC
"-b:a",
"192k", # 比特率
"-vn", # 不包含视频
"-f",
"adts", # 输出AAC数据流格式
"pipe:1", # 将输出写入管道
],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
out, _ = process.communicate(input=data.tobytes())
io_buffer.write(out)
return io_buffer
def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str):
if media_type == "ogg":
io_buffer = pack_ogg(io_buffer, data, rate)
elif media_type == "aac":
io_buffer = pack_aac(io_buffer, data, rate)
elif media_type == "wav":
io_buffer = pack_wav(io_buffer, data, rate)
else:
io_buffer = pack_raw(io_buffer, data, rate)
io_buffer.seek(0)
return io_buffer
# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py
def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
# This will create a wave header then append the frame input
# It should be first on a streaming wav file
# Other frames better should not have it (else you will hear some artifacts each chunk start)
wav_buf = BytesIO()
with wave.open(wav_buf, "wb") as vfout:
vfout.setnchannels(channels)
vfout.setsampwidth(sample_width)
vfout.setframerate(sample_rate)
vfout.writeframes(frame_input)
wav_buf.seek(0)
return wav_buf.read()
def handle_control(command: str):
if command == "restart":
os.execl(sys.executable, sys.executable, *argv)
elif command == "exit":
os.kill(os.getpid(), signal.SIGTERM)
exit(0)
def check_params(req: dict):
text: str = req.get("text", "")
text_lang: str = req.get("text_lang", "")
ref_audio_path: str = req.get("ref_audio_path", "")
streaming_mode: bool = req.get("streaming_mode", False)
media_type: str = req.get("media_type", "wav")
prompt_lang: str = req.get("prompt_lang", "")
text_split_method: str = req.get("text_split_method", "cut5")
if ref_audio_path in [None, ""]:
return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
if text in [None, ""]:
return JSONResponse(status_code=400, content={"message": "text is required"})
if text_lang in [None, ""]:
return JSONResponse(status_code=400, content={"message": "text_lang is required"})
elif text_lang.lower() not in tts_config.languages:
return JSONResponse(
status_code=400,
content={"message": f"text_lang: {text_lang} is not supported in version {tts_config.version}"},
)
if prompt_lang in [None, ""]:
return JSONResponse(status_code=400, content={"message": "prompt_lang is required"})
elif prompt_lang.lower() not in tts_config.languages:
return JSONResponse(
status_code=400,
content={"message": f"prompt_lang: {prompt_lang} is not supported in version {tts_config.version}"},
)
if media_type not in ["wav", "raw", "ogg", "aac"]:
return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"})
# elif media_type == "ogg" and not streaming_mode:
# return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
if text_split_method not in cut_method_names:
return JSONResponse(
status_code=400, content={"message": f"text_split_method:{text_split_method} is not supported"}
)
return None
def _lower_or_none(value: str | None) -> str | None:
return value.lower() if isinstance(value, str) else value
async def tts_handle(req: dict):
@ -368,7 +211,7 @@ async def tts_handle(req: dict):
"parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
"super_sampling": False, # bool. only for v3/v4; not part of current v2/v2ProPlus mainline.
"streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed )
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
@ -377,70 +220,11 @@ async def tts_handle(req: dict):
StreamingResponse: audio stream response.
"""
streaming_mode = req.get("streaming_mode", False)
return_fragment = req.get("return_fragment", False)
media_type = req.get("media_type", "wav")
check_res = check_params(req)
if check_res is not None:
return check_res
if streaming_mode == 0:
streaming_mode = False
return_fragment = False
fixed_length_chunk = False
elif streaming_mode == 1:
streaming_mode = False
return_fragment = True
fixed_length_chunk = False
elif streaming_mode == 2:
streaming_mode = True
return_fragment = False
fixed_length_chunk = False
elif streaming_mode == 3:
streaming_mode = True
return_fragment = False
fixed_length_chunk = True
else:
return JSONResponse(status_code=400, content={"message": f"the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)"})
req["streaming_mode"] = streaming_mode
req["return_fragment"] = return_fragment
req["fixed_length_chunk"] = fixed_length_chunk
print(f"{streaming_mode} {return_fragment} {fixed_length_chunk}")
streaming_mode = streaming_mode or return_fragment
try:
tts_generator = tts_pipeline.run(req)
if streaming_mode:
def streaming_generator(tts_generator: Generator, media_type: str):
if_frist_chunk = True
for sr, chunk in tts_generator:
if if_frist_chunk and media_type == "wav":
yield wave_header_chunk(sample_rate=sr)
media_type = "raw"
if_frist_chunk = False
yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
# _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
return StreamingResponse(
streaming_generator(
tts_generator,
media_type,
),
media_type=f"audio/{media_type}",
)
else:
sr, audio_data = next(tts_generator)
audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
return Response(audio_data, media_type=f"audio/{media_type}")
result = await tts_engine.run_direct_tts_async(req)
if result.streaming:
return StreamingResponse(result.audio_generator, media_type=f"audio/{result.media_type}")
return Response(result.audio_bytes, media_type=f"audio/{result.media_type}")
except Exception as e:
return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(e)})
@ -449,7 +233,11 @@ async def tts_handle(req: dict):
async def control(command: str = None):
if command is None:
return JSONResponse(status_code=400, content={"message": "command is required"})
handle_control(command)
try:
tts_engine.handle_control(command)
return JSONResponse(status_code=200, content={"message": "success"})
except Exception as e:
return JSONResponse(status_code=400, content={"message": "control failed", "Exception": str(e)})
@APP.get("/tts")
@ -481,11 +269,11 @@ async def tts_get_endpoint(
):
req = {
"text": text,
"text_lang": text_lang.lower(),
"text_lang": _lower_or_none(text_lang),
"ref_audio_path": ref_audio_path,
"aux_ref_audio_paths": aux_ref_audio_paths,
"prompt_text": prompt_text,
"prompt_lang": prompt_lang.lower(),
"prompt_lang": _lower_or_none(prompt_lang),
"top_k": top_k,
"top_p": top_p,
"temperature": temperature,
@ -517,10 +305,10 @@ async def tts_post_endpoint(request: TTS_Request):
@APP.get("/set_refer_audio")
async def set_refer_aduio(refer_audio_path: str = None):
try:
tts_pipeline.set_ref_audio(refer_audio_path)
payload = tts_engine.set_refer_audio(refer_audio_path)
except Exception as e:
return JSONResponse(status_code=400, content={"message": "set refer audio failed", "Exception": str(e)})
return JSONResponse(status_code=200, content={"message": "success"})
return JSONResponse(status_code=200, content=payload)
# @APP.post("/set_refer_audio")
@ -545,24 +333,19 @@ async def set_refer_aduio(refer_audio_path: str = None):
@APP.get("/set_gpt_weights")
async def set_gpt_weights(weights_path: str = None):
try:
if weights_path in ["", None]:
return JSONResponse(status_code=400, content={"message": "gpt weight path is required"})
tts_pipeline.init_t2s_weights(weights_path)
payload = tts_engine.set_gpt_weights(weights_path)
except Exception as e:
return JSONResponse(status_code=400, content={"message": "change gpt weight failed", "Exception": str(e)})
return JSONResponse(status_code=200, content={"message": "success"})
return JSONResponse(status_code=200, content=payload)
@APP.get("/set_sovits_weights")
async def set_sovits_weights(weights_path: str = None):
try:
if weights_path in ["", None]:
return JSONResponse(status_code=400, content={"message": "sovits weight path is required"})
tts_pipeline.init_vits_weights(weights_path)
payload = tts_engine.set_sovits_weights(weights_path)
except Exception as e:
return JSONResponse(status_code=400, content={"message": "change sovits weight failed", "Exception": str(e)})
return JSONResponse(status_code=200, content={"message": "success"})
return JSONResponse(status_code=200, content=payload)
if __name__ == "__main__":

443
api_v3.py Normal file
View File

@ -0,0 +1,443 @@
"""
# WebAPI文档
` python api_v2.py -a 127.0.0.1 -p 9880 -c GPT_SoVITS/configs/tts_infer.yaml `
## 执行参数:
`-a` - `绑定地址, 默认"127.0.0.1"`
`-p` - `绑定端口, 默认9880`
`-c` - `TTS配置文件路径, 默认"GPT_SoVITS/configs/tts_infer.yaml"`
## 调用:
### 推理
endpoint: `/tts`
GET:
```
http://127.0.0.1:9880/tts?text=先帝创业未半而中道崩殂今天下三分益州疲弊此诚危急存亡之秋也&text_lang=zh&ref_audio_path=archive_jingyuan_1.wav&prompt_lang=zh&prompt_text=我是罗浮云骑将军景元不必拘谨将军只是一时的身份你称呼我景元便可&text_split_method=cut5&batch_size=1&media_type=wav&streaming_mode=true
```
POST:
```json
{
"text": "", # str.(required) text to be synthesized
"text_lang: "", # str.(required) language of the text to be synthesized
"ref_audio_path": "", # str.(required) reference audio path
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
"prompt_text": "", # str.(optional) prompt text for the reference audio
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
"top_k": 15, # int. top k sampling
"top_p": 1, # float. top p sampling
"temperature": 1, # float. temperature for sampling
"text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
"speed_factor":1.0, # float. control the speed of the synthesized audio.
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
"seed": -1, # int. random seed for reproducibility.
"parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
"sample_steps": 32, # int. 仅 v3/v4 vocoder 路径使用;当前 v2/v2ProPlus 主线可忽略。
"super_sampling": False, # bool. 仅 v3/v4 路径使用;不属于当前 v2/v2ProPlus 正式支持目标。
"streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed )
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
}
```
RESP:
成功: 直接返回 wav 音频流 http code 200
失败: 返回包含错误信息的 json, http code 400
### 命令控制
endpoint: `/control`
command:
"restart": 重新运行
"exit": 结束运行
GET:
```
http://127.0.0.1:9880/control?command=restart
```
POST:
```json
{
"command": "restart"
}
```
RESP:
### 切换GPT模型
endpoint: `/set_gpt_weights`
GET:
```
http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1v3.ckpt
```
RESP:
成功: 返回"success", http code 200
失败: 返回包含错误信息的 json, http code 400
### 切换Sovits模型
endpoint: `/set_sovits_weights`
GET:
```
http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth
```
RESP:
成功: 返回"success", http code 200
失败: 返回包含错误信息的 json, http code 400
"""
import os
import sys
import traceback
from typing import List, Union
now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir))
from runtime_preload import preload_text_runtime_deps
preload_text_runtime_deps()
import argparse
import signal
from fastapi import FastAPI, Response
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn
from tools.i18n.i18n import I18nAuto
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
from GPT_SoVITS.TTS_infer_pack.unified_engine import RuntimeControlCallbacks, UnifiedTTSEngine
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
from pydantic import BaseModel
# print(sys.path)
i18n = I18nAuto()
cut_method_names = get_cut_method_names()
parser = argparse.ArgumentParser(description="GPT-SoVITS api")
parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
parser.add_argument("-p", "--port", type=int, default="9880", help="default: 9880")
args = parser.parse_args()
config_path = args.tts_config
# device = args.device
port = args.port
host = args.bind_addr
argv = sys.argv
if config_path in [None, ""]:
config_path = "GPT-SoVITS/configs/tts_infer.yaml"
tts_config = TTS_Config(config_path)
print(tts_config)
tts_pipeline = TTS(tts_config)
tts_engine = UnifiedTTSEngine(
tts_pipeline,
cut_method_names=cut_method_names,
control_callbacks=RuntimeControlCallbacks(
restart=lambda: os.execl(sys.executable, sys.executable, *argv),
exit=lambda: os.kill(os.getpid(), signal.SIGTERM),
),
)
APP = FastAPI()
class TTS_Request(BaseModel):
text: str = None
text_lang: str = None
ref_audio_path: str = None
aux_ref_audio_paths: list = None
prompt_lang: str = None
prompt_text: str = ""
top_k: int = 15
top_p: float = 1
temperature: float = 1
text_split_method: str = "cut5"
batch_size: int = 1
batch_threshold: float = 0.75
split_bucket: bool = True
speed_factor: float = 1.0
fragment_interval: float = 0.3
seed: int = -1
media_type: str = "wav"
streaming_mode: Union[bool, int] = False
parallel_infer: bool = True
repetition_penalty: float = 1.35
sample_steps: int = 32
super_sampling: bool = False
overlap_length: int = 2
min_chunk_length: int = 16
class Scheduler_Debug_Request_Item(BaseModel):
request_id: str | None = None
text: str
text_lang: str
ref_audio_path: str
prompt_lang: str
prompt_text: str = ""
top_k: int = 15
top_p: float = 1
temperature: float = 1
repetition_penalty: float = 1.35
early_stop_num: int = -1
ready_step: int = 0
class Scheduler_Debug_Request(BaseModel):
requests: List[Scheduler_Debug_Request_Item]
max_steps: int = 1500
seed: int = -1
class Scheduler_Submit_Request(BaseModel):
request_id: str | None = None
text: str
text_lang: str
ref_audio_path: str
prompt_lang: str
prompt_text: str = ""
top_k: int = 15
top_p: float = 1
temperature: float = 1
repetition_penalty: float = 1.35
early_stop_num: int = -1
speed_factor: float = 1.0
sample_steps: int = 32
media_type: str = "wav"
timeout_sec: float = 30.0
def _lower_or_none(value: str | None) -> str | None:
return value.lower() if isinstance(value, str) else value
async def tts_scheduler_debug_handle(request: Scheduler_Debug_Request):
try:
result = await tts_engine.run_scheduler_debug(
request_items=[item.dict() for item in request.requests],
max_steps=int(request.max_steps),
seed=int(request.seed),
)
return JSONResponse(status_code=200, content=result.payload)
except Exception as e:
return JSONResponse(
status_code=400,
content={"message": "scheduler debug failed", "Exception": str(e)},
)
async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request):
try:
result = await tts_engine.run_scheduler_submit(request.dict())
return Response(result.audio_bytes, media_type=result.media_type, headers=result.headers)
except Exception as e:
return JSONResponse(
status_code=400,
content={"message": "scheduler submit failed", "Exception": str(e)},
)
async def tts_handle(req: dict):
"""
Text to speech handler.
Args:
req (dict):
{
"text": "", # str.(required) text to be synthesized
"text_lang: "", # str.(required) language of the text to be synthesized
"ref_audio_path": "", # str.(required) reference audio path
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
"prompt_text": "", # str.(optional) prompt text for the reference audio
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
"top_k": 15, # int. top k sampling
"top_p": 1, # float. top p sampling
"temperature": 1, # float. temperature for sampling
"text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
"speed_factor":1.0, # float. control the speed of the synthesized audio.
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
"seed": -1, # int. random seed for reproducibility.
"parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
"super_sampling": False, # bool. only for v3/v4; not part of current v2/v2ProPlus mainline.
"streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed )
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
}
returns:
StreamingResponse: audio stream response.
"""
try:
result = await tts_engine.run_direct_tts_async(req)
if result.streaming:
return StreamingResponse(result.audio_generator, media_type=f"audio/{result.media_type}")
return Response(result.audio_bytes, media_type=f"audio/{result.media_type}")
except Exception as e:
return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(e)})
@APP.get("/control")
async def control(command: str = None):
if command is None:
return JSONResponse(status_code=400, content={"message": "command is required"})
try:
tts_engine.handle_control(command)
return JSONResponse(status_code=200, content={"message": "success"})
except Exception as e:
return JSONResponse(status_code=400, content={"message": "control failed", "Exception": str(e)})
@APP.get("/tts")
async def tts_get_endpoint(
text: str = None,
text_lang: str = None,
ref_audio_path: str = None,
aux_ref_audio_paths: list = None,
prompt_lang: str = None,
prompt_text: str = "",
top_k: int = 15,
top_p: float = 1,
temperature: float = 1,
text_split_method: str = "cut5",
batch_size: int = 1,
batch_threshold: float = 0.75,
split_bucket: bool = True,
speed_factor: float = 1.0,
fragment_interval: float = 0.3,
seed: int = -1,
media_type: str = "wav",
parallel_infer: bool = True,
repetition_penalty: float = 1.35,
sample_steps: int = 32,
super_sampling: bool = False,
streaming_mode: Union[bool, int] = False,
overlap_length: int = 2,
min_chunk_length: int = 16,
):
req = {
"text": text,
"text_lang": _lower_or_none(text_lang),
"ref_audio_path": ref_audio_path,
"aux_ref_audio_paths": aux_ref_audio_paths,
"prompt_text": prompt_text,
"prompt_lang": _lower_or_none(prompt_lang),
"top_k": top_k,
"top_p": top_p,
"temperature": temperature,
"text_split_method": text_split_method,
"batch_size": int(batch_size),
"batch_threshold": float(batch_threshold),
"speed_factor": float(speed_factor),
"split_bucket": split_bucket,
"fragment_interval": fragment_interval,
"seed": seed,
"media_type": media_type,
"streaming_mode": streaming_mode,
"parallel_infer": parallel_infer,
"repetition_penalty": float(repetition_penalty),
"sample_steps": int(sample_steps),
"super_sampling": super_sampling,
"overlap_length": int(overlap_length),
"min_chunk_length": int(min_chunk_length),
}
return await tts_handle(req)
@APP.post("/tts")
async def tts_post_endpoint(request: TTS_Request):
req = request.dict()
return await tts_handle(req)
@APP.post("/tts_scheduler_debug")
async def tts_scheduler_debug_endpoint(request: Scheduler_Debug_Request):
return await tts_scheduler_debug_handle(request)
@APP.post("/tts_scheduler_submit")
async def tts_scheduler_submit_endpoint(request: Scheduler_Submit_Request):
return await tts_scheduler_submit_handle(request)
@APP.get("/tts_scheduler_state")
async def tts_scheduler_state_endpoint():
return JSONResponse(status_code=200, content=tts_engine.get_runtime_state())
@APP.get("/set_refer_audio")
async def set_refer_aduio(refer_audio_path: str = None):
try:
payload = tts_engine.set_refer_audio(refer_audio_path)
except Exception as e:
return JSONResponse(status_code=400, content={"message": "set refer audio failed", "Exception": str(e)})
return JSONResponse(status_code=200, content=payload)
# @APP.post("/set_refer_audio")
# async def set_refer_aduio_post(audio_file: UploadFile = File(...)):
# try:
# # 检查文件类型,确保是音频文件
# if not audio_file.content_type.startswith("audio/"):
# return JSONResponse(status_code=400, content={"message": "file type is not supported"})
# os.makedirs("uploaded_audio", exist_ok=True)
# save_path = os.path.join("uploaded_audio", audio_file.filename)
# # 保存音频文件到服务器上的一个目录
# with open(save_path , "wb") as buffer:
# buffer.write(await audio_file.read())
# tts_pipeline.set_ref_audio(save_path)
# except Exception as e:
# return JSONResponse(status_code=400, content={"message": f"set refer audio failed", "Exception": str(e)})
# return JSONResponse(status_code=200, content={"message": "success"})
@APP.get("/set_gpt_weights")
async def set_gpt_weights(weights_path: str = None):
try:
payload = tts_engine.set_gpt_weights(weights_path)
except Exception as e:
return JSONResponse(status_code=400, content={"message": "change gpt weight failed", "Exception": str(e)})
return JSONResponse(status_code=200, content=payload)
@APP.get("/set_sovits_weights")
async def set_sovits_weights(weights_path: str = None):
try:
payload = tts_engine.set_sovits_weights(weights_path)
except Exception as e:
return JSONResponse(status_code=400, content={"message": "change sovits weight failed", "Exception": str(e)})
return JSONResponse(status_code=200, content=payload)
if __name__ == "__main__":
try:
if host == "None": # 在调用时使用 -a None 参数可以让api监听双栈
host = None
uvicorn.run(app=APP, host=host, port=port, workers=1)
except Exception:
traceback.print_exc()
os.kill(os.getpid(), signal.SIGTERM)
exit(0)

1
third_party/g2pw-cu vendored Submodule

@ -0,0 +1 @@
Subproject commit a53cf4eed5759f7b5d4563ce6e4b13557e054d98

View File

@ -0,0 +1,250 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import argparse
import asyncio
import json
import subprocess
import threading
import time
import wave
from pathlib import Path
from typing import Any, Dict, List, Optional
import httpx
ROOT_DIR = Path(__file__).resolve().parents[1]
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Benchmark api_v3 /tts_scheduler_submit concurrency and GPU memory.")
parser.add_argument("--base-url", type=str, default="http://127.0.0.1:9880")
parser.add_argument("--endpoint", type=str, default="/tts_scheduler_submit")
parser.add_argument("--concurrency", type=int, required=True)
parser.add_argument("--timeout-sec", type=float, default=120.0)
parser.add_argument("--server-pid", type=int, default=None)
parser.add_argument("--poll-interval-sec", type=float, default=0.1)
parser.add_argument("--text-lang", type=str, default="zh")
parser.add_argument("--prompt-lang", type=str, default="zh")
parser.add_argument("--media-type", type=str, default="wav")
parser.add_argument("--top-k", type=int, default=15)
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--repetition-penalty", type=float, default=1.35)
parser.add_argument("--sample-steps", type=int, default=32)
parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_cn.txt")
parser.add_argument("--wav-dir", type=Path, default=ROOT_DIR / "testwav")
parser.add_argument("--output-dir", type=Path, default=ROOT_DIR / "TEMP/api_v3_bench")
return parser.parse_args()
def load_requests(args: argparse.Namespace) -> List[Dict[str, Any]]:
wav_paths_all = sorted(args.wav_dir.glob("*.wav"))
wav_paths: List[Path] = []
for wav_path in wav_paths_all:
with wave.open(str(wav_path), "rb") as handle:
duration = handle.getnframes() / float(handle.getframerate())
if 3.0 <= duration <= 10.0:
wav_paths.append(wav_path)
if not wav_paths:
raise FileNotFoundError(f"没有找到 3-10 秒合法 wav: {args.wav_dir}")
text_lines = [line.strip() for line in args.text_file.read_text(encoding="utf-8").splitlines() if line.strip()]
if not text_lines:
raise ValueError(f"没有找到有效文本行: {args.text_file}")
requests: List[Dict[str, Any]] = []
for index in range(args.concurrency):
wav_path = wav_paths[index % len(wav_paths)]
lab_path = wav_path.with_suffix(".lab")
if not lab_path.exists():
raise FileNotFoundError(f"缺少参考文本: {lab_path}")
requests.append(
{
"request_id": f"bench_{args.concurrency:03d}_{index:03d}",
"text": text_lines[index % len(text_lines)],
"text_lang": args.text_lang,
"ref_audio_path": str(wav_path),
"prompt_lang": args.prompt_lang,
"prompt_text": lab_path.read_text(encoding="utf-8").strip(),
"top_k": int(args.top_k),
"top_p": float(args.top_p),
"temperature": float(args.temperature),
"repetition_penalty": float(args.repetition_penalty),
"sample_steps": int(args.sample_steps),
"media_type": args.media_type,
"timeout_sec": float(args.timeout_sec),
}
)
return requests
class GpuMemoryPoller:
def __init__(self, server_pid: Optional[int], interval_sec: float):
self.server_pid = server_pid
self.interval_sec = interval_sec
self._stop = threading.Event()
self.samples: List[Dict[str, Any]] = []
self.thread: Optional[threading.Thread] = None
def _query_memory_mb(self) -> Optional[int]:
try:
result = subprocess.run(
[
"nvidia-smi",
"--query-compute-apps=pid,used_gpu_memory",
"--format=csv,noheader,nounits",
],
check=True,
capture_output=True,
text=True,
)
except Exception:
return None
total = 0
found = False
for line in result.stdout.splitlines():
line = line.strip()
if not line:
continue
parts = [item.strip() for item in line.split(",")]
if len(parts) != 2:
continue
try:
pid = int(parts[0])
used_mb = int(parts[1])
except ValueError:
continue
if self.server_pid is None or pid == self.server_pid:
total += used_mb
found = True
if self.server_pid is None:
return total
return total if found else 0
def _run(self) -> None:
while not self._stop.is_set():
used_mb = self._query_memory_mb()
self.samples.append({"ts": time.time(), "used_mb": used_mb})
self._stop.wait(self.interval_sec)
def start(self) -> None:
self.thread = threading.Thread(target=self._run, daemon=True)
self.thread.start()
def stop(self) -> None:
self._stop.set()
if self.thread is not None:
self.thread.join(timeout=2.0)
def summary(self) -> Dict[str, Any]:
valid = [item for item in self.samples if item["used_mb"] is not None]
peak = max(valid, key=lambda item: item["used_mb"]) if valid else None
first = valid[0] if valid else None
last = valid[-1] if valid else None
return {
"server_pid": self.server_pid,
"sample_count": int(len(self.samples)),
"start_used_mb": None if first is None else int(first["used_mb"]),
"peak_used_mb": None if peak is None else int(peak["used_mb"]),
"peak_delta_mb": None if peak is None or first is None else int(peak["used_mb"] - first["used_mb"]),
"end_used_mb": None if last is None else int(last["used_mb"]),
"peak_ts": None if peak is None else float(peak["ts"]),
"samples": self.samples,
}
async def submit_one(client: httpx.AsyncClient, url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
started = time.perf_counter()
try:
response = await client.post(url, json=payload)
elapsed_ms = (time.perf_counter() - started) * 1000.0
item = {
"request_id": payload["request_id"],
"status_code": int(response.status_code),
"elapsed_ms": float(elapsed_ms),
"content_type": response.headers.get("content-type"),
"audio_bytes": int(len(response.content)),
"headers": {key: value for key, value in response.headers.items() if key.lower().startswith("x-")},
}
if response.status_code != 200:
try:
item["error_body"] = response.json()
except Exception:
item["error_body"] = response.text
return item
except Exception as exc:
return {
"request_id": payload["request_id"],
"status_code": -1,
"elapsed_ms": float((time.perf_counter() - started) * 1000.0),
"exception": repr(exc),
}
async def run_benchmark(args: argparse.Namespace) -> Dict[str, Any]:
payloads = load_requests(args)
url = args.base_url.rstrip("/") + args.endpoint
poller = GpuMemoryPoller(server_pid=args.server_pid, interval_sec=args.poll_interval_sec)
limits = httpx.Limits(max_connections=args.concurrency, max_keepalive_connections=args.concurrency)
timeout = httpx.Timeout(connect=10.0, read=args.timeout_sec + 10.0, write=10.0, pool=10.0)
started = time.perf_counter()
poller.start()
try:
async with httpx.AsyncClient(limits=limits, timeout=timeout) as client:
results = await asyncio.gather(*[submit_one(client, url, payload) for payload in payloads])
finally:
poller.stop()
wall_ms = (time.perf_counter() - started) * 1000.0
ok_results = [item for item in results if item["status_code"] == 200]
failed_results = [item for item in results if item["status_code"] != 200]
request_total_ms = []
worker_total_ms = []
for item in ok_results:
headers = item.get("headers", {})
if "x-request-total-ms" in headers:
request_total_ms.append(float(headers["x-request-total-ms"]))
if "x-worker-total-ms" in headers:
worker_total_ms.append(float(headers["x-worker-total-ms"]))
return {
"concurrency": int(args.concurrency),
"server_pid": args.server_pid,
"request_count": int(len(payloads)),
"wall_ms": float(wall_ms),
"success_count": int(len(ok_results)),
"failure_count": int(len(failed_results)),
"request_total_ms_avg": float(sum(request_total_ms) / len(request_total_ms)) if request_total_ms else None,
"request_total_ms_max": float(max(request_total_ms)) if request_total_ms else None,
"worker_total_ms_avg": float(sum(worker_total_ms) / len(worker_total_ms)) if worker_total_ms else None,
"worker_total_ms_max": float(max(worker_total_ms)) if worker_total_ms else None,
"gpu_memory": poller.summary(),
"results": results,
}
def main() -> None:
args = parse_args()
output_dir = args.output_dir / f"concurrency_{args.concurrency:02d}"
output_dir.mkdir(parents=True, exist_ok=True)
summary = asyncio.run(run_benchmark(args))
summary_path = output_dir / "summary.json"
summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
print(json.dumps({
"concurrency": summary["concurrency"],
"success_count": summary["success_count"],
"failure_count": summary["failure_count"],
"wall_ms": summary["wall_ms"],
"gpu_peak_used_mb": summary["gpu_memory"]["peak_used_mb"],
"request_total_ms_avg": summary["request_total_ms_avg"],
"request_total_ms_max": summary["request_total_ms_max"],
"summary_path": str(summary_path),
}, ensure_ascii=False, indent=2))
if __name__ == "__main__":
main()

View File

@ -0,0 +1,887 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import argparse
import gc
import contextlib
import json
import random
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple
import numpy as np
import torch
ROOT_DIR = Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
sys.path.append(str(ROOT_DIR))
gpt_sovits_dir = ROOT_DIR / "GPT_SoVITS"
if str(gpt_sovits_dir) not in sys.path:
sys.path.append(str(gpt_sovits_dir))
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config # noqa: E402
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( # noqa: E402
SchedulerRequestSpec,
T2SRequestState,
T2SRunningRequest,
_build_decode_batch_from_running,
build_prefill_batch,
prepare_request_state,
run_decode_step_for_running,
run_prefill_step,
)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Break down T2S CUDA memory by stage and tensor groups.")
parser.add_argument("--config", type=Path, default=ROOT_DIR / "GPT_SoVITS/configs/tts_infer.yaml")
parser.add_argument("--request-manifest", type=Path, default=None)
parser.add_argument("--scenario", type=str, default="auto4", choices=["auto4", "single"])
parser.add_argument("--auto-count", type=int, default=4)
parser.add_argument("--auto-wav-dir", type=Path, default=ROOT_DIR / "testwav")
parser.add_argument("--auto-text-file", type=Path, default=ROOT_DIR / "test_cn.txt")
parser.add_argument("--ref-audio", type=Path, default=ROOT_DIR / "test.wav")
parser.add_argument("--prompt-text", type=str, default="是啊,主要是因为有调研需求的学者少了。")
parser.add_argument("--prompt-lang", type=str, default="zh")
parser.add_argument("--text", type=str, default=None)
parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_en.txt")
parser.add_argument("--text-lang", type=str, default="zh")
parser.add_argument("--top-k", type=int, default=15)
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--repetition-penalty", type=float, default=1.35)
parser.add_argument("--early-stop-num", type=int, default=-1)
parser.add_argument("--max-steps", type=int, default=1500)
parser.add_argument("--seed", type=int, default=1234)
parser.add_argument("--warmup", action="store_true", default=False)
parser.add_argument("--worker-rounds", type=int, default=1)
parser.add_argument("--worker-grad-mode", type=str, default="default", choices=["default", "inference_mode"])
parser.add_argument("--compare-worker-grad-modes", action="store_true", default=False)
parser.add_argument(
"--output-dir",
type=Path,
default=ROOT_DIR / "TEMP/t2s_memory_breakdown/run1",
)
return parser.parse_args()
def set_seed(seed: int, use_cuda: bool) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if use_cuda and torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def _sync_device(device: Any) -> None:
try:
device_str = str(device)
if device_str.startswith("cuda") and torch.cuda.is_available():
torch.cuda.synchronize(device)
elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"):
torch.mps.synchronize()
except Exception:
pass
def bytes_to_mb(num_bytes: int) -> float:
return float(num_bytes) / (1024.0 * 1024.0)
def tensor_nbytes(tensor: Optional[torch.Tensor]) -> int:
if tensor is None:
return 0
return int(tensor.numel() * tensor.element_size())
def tensor_list_nbytes(items: Sequence[torch.Tensor]) -> int:
return int(sum(tensor_nbytes(item) for item in items))
def model_nbytes(module: torch.nn.Module) -> int:
total = 0
for parameter in module.parameters():
total += tensor_nbytes(parameter)
for buffer in module.buffers():
total += tensor_nbytes(buffer)
return int(total)
def build_module_weight_summary(tts: TTS) -> Dict[str, Any]:
modules = {
"t2s_model": tts.t2s_model,
"t2s_core": tts.t2s_model.model if tts.t2s_model is not None else None,
"vits_model": tts.vits_model,
"bert_model": tts.bert_model,
"cnhuhbert_model": tts.cnhuhbert_model,
"vocoder": tts.vocoder,
"sv_model": tts.sv_model,
}
by_module = {}
total_bytes = 0
for name, module in modules.items():
module_bytes = model_nbytes(module) if module is not None else 0
by_module[name] = {
"bytes": int(module_bytes),
"mb": bytes_to_mb(module_bytes),
}
total_bytes += module_bytes
return {
"by_module": by_module,
"total_bytes": int(total_bytes),
"total_mb": bytes_to_mb(total_bytes),
}
def snapshot_live_cuda_tensors(top_k: int = 40) -> Dict[str, Any]:
storages: Dict[int, Dict[str, Any]] = {}
tensor_views: List[Dict[str, Any]] = []
for obj in gc.get_objects():
try:
tensor = None
if torch.is_tensor(obj):
tensor = obj
elif hasattr(obj, "data") and torch.is_tensor(obj.data):
tensor = obj.data
if tensor is None or not tensor.is_cuda:
continue
storage = tensor.untyped_storage()
storage_ptr = int(storage.data_ptr())
if storage_ptr not in storages:
storages[storage_ptr] = {
"storage_ptr": storage_ptr,
"storage_bytes": int(storage.nbytes()),
"dtype": str(tensor.dtype),
"shape": list(tensor.shape),
"device": str(tensor.device),
}
tensor_views.append(
{
"shape": list(tensor.shape),
"dtype": str(tensor.dtype),
"bytes": tensor_nbytes(tensor),
"device": str(tensor.device),
}
)
except Exception:
continue
storage_list = sorted(storages.values(), key=lambda item: item["storage_bytes"], reverse=True)
tensor_views.sort(key=lambda item: item["bytes"], reverse=True)
return {
"unique_storage_count": int(len(storage_list)),
"unique_storage_total_bytes": int(sum(item["storage_bytes"] for item in storage_list)),
"unique_storage_total_mb": bytes_to_mb(sum(item["storage_bytes"] for item in storage_list)),
"top_storages": storage_list[:top_k],
"top_tensor_views": tensor_views[:top_k],
}
def build_single_spec(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
text = args.text if args.text is not None else args.text_file.read_text(encoding="utf-8").strip()
return [
SchedulerRequestSpec(
request_id="req_000",
ref_audio_path=args.ref_audio,
prompt_text=args.prompt_text,
prompt_lang=args.prompt_lang,
text=text,
text_lang=args.text_lang,
top_k=args.top_k,
top_p=args.top_p,
temperature=args.temperature,
repetition_penalty=args.repetition_penalty,
early_stop_num=args.early_stop_num,
ready_step=0,
)
]
def build_auto_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
wav_paths = sorted(args.auto_wav_dir.glob("*.wav"))[: args.auto_count]
if len(wav_paths) < args.auto_count:
raise ValueError(f"auto wav count不足目录 {args.auto_wav_dir} 只有 {len(wav_paths)} 条 wav")
text_lines = [line.strip() for line in args.auto_text_file.read_text(encoding="utf-8").splitlines() if line.strip()]
if len(text_lines) < args.auto_count:
raise ValueError(f"auto text lines不足文件 {args.auto_text_file} 只有 {len(text_lines)} 行有效文本")
specs: List[SchedulerRequestSpec] = []
for index, wav_path in enumerate(wav_paths):
lab_path = wav_path.with_suffix(".lab")
if not lab_path.exists():
raise FileNotFoundError(f"找不到参考文本 {lab_path}")
specs.append(
SchedulerRequestSpec(
request_id=f"req_{index:03d}",
ref_audio_path=wav_path,
prompt_text=lab_path.read_text(encoding="utf-8").strip(),
prompt_lang="zh",
text=text_lines[index],
text_lang=args.text_lang,
top_k=args.top_k,
top_p=args.top_p,
temperature=args.temperature,
repetition_penalty=args.repetition_penalty,
early_stop_num=args.early_stop_num,
ready_step=0,
)
)
return specs
def load_request_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
if args.request_manifest is not None:
payload = json.loads(args.request_manifest.read_text(encoding="utf-8"))
raw_requests = payload["requests"] if isinstance(payload, dict) else payload
specs: List[SchedulerRequestSpec] = []
for index, item in enumerate(raw_requests):
text = item.get("text")
text_file = item.get("text_file")
if text is None and text_file is None:
raise ValueError(f"request[{index}] must provide text or text_file")
if text is None:
text = Path(text_file).read_text(encoding="utf-8").strip()
specs.append(
SchedulerRequestSpec(
request_id=item.get("request_id", f"req_{index:03d}"),
ref_audio_path=Path(item["ref_audio_path"]),
prompt_text=item["prompt_text"],
prompt_lang=item.get("prompt_lang", "zh"),
text=text,
text_lang=item.get("text_lang", "zh"),
top_k=int(item.get("top_k", args.top_k)),
top_p=float(item.get("top_p", args.top_p)),
temperature=float(item.get("temperature", args.temperature)),
repetition_penalty=float(item.get("repetition_penalty", args.repetition_penalty)),
early_stop_num=int(item.get("early_stop_num", args.early_stop_num)),
ready_step=int(item.get("ready_step", 0)),
)
)
return specs
if args.scenario == "single":
return build_single_spec(args)
return build_auto_specs(args)
def load_pipeline(config_path: Path) -> TTS:
tts_config = TTS_Config(str(config_path))
print(tts_config)
return TTS(tts_config)
def cuda_mem_snapshot(device: Any) -> Dict[str, float]:
if not (str(device).startswith("cuda") and torch.cuda.is_available()):
return {
"allocated_mb": 0.0,
"reserved_mb": 0.0,
"max_allocated_mb": 0.0,
"max_reserved_mb": 0.0,
}
_sync_device(device)
return {
"allocated_mb": bytes_to_mb(torch.cuda.memory_allocated(device)),
"reserved_mb": bytes_to_mb(torch.cuda.memory_reserved(device)),
"max_allocated_mb": bytes_to_mb(torch.cuda.max_memory_allocated(device)),
"max_reserved_mb": bytes_to_mb(torch.cuda.max_memory_reserved(device)),
}
def stage_run(device: Any, fn) -> Tuple[Any, Dict[str, float]]:
if str(device).startswith("cuda") and torch.cuda.is_available():
gc.collect()
_sync_device(device)
torch.cuda.reset_peak_memory_stats(device)
before = cuda_mem_snapshot(device)
started = time.perf_counter()
result = fn()
_sync_device(device)
elapsed_ms = (time.perf_counter() - started) * 1000.0
after = cuda_mem_snapshot(device)
after["elapsed_ms"] = float(elapsed_ms)
after["delta_allocated_mb"] = float(after["allocated_mb"] - before["allocated_mb"])
after["delta_reserved_mb"] = float(after["reserved_mb"] - before["reserved_mb"])
after["stage_peak_over_before_mb"] = float(max(after["max_allocated_mb"] - before["allocated_mb"], 0.0))
return result, after
class GlobalPeakRecorder:
def __init__(self, device: Any):
self.device = device
self.checkpoints: List[Dict[str, Any]] = []
if str(device).startswith("cuda") and torch.cuda.is_available():
gc.collect()
_sync_device(device)
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats(device)
def record(self, label: str, **extra: Any) -> None:
snapshot = cuda_mem_snapshot(self.device)
snapshot["label"] = label
snapshot.update(extra)
self.checkpoints.append(snapshot)
def summary(self) -> Dict[str, Any]:
peak = max(self.checkpoints, key=lambda item: item["max_allocated_mb"]) if self.checkpoints else None
return {
"peak_allocated_mb": 0.0 if peak is None else float(peak["max_allocated_mb"]),
"peak_reserved_mb": 0.0 if peak is None else float(peak["max_reserved_mb"]),
"peak_label": None if peak is None else peak["label"],
"checkpoints": self.checkpoints,
}
def summarise_state_tensors(states: Sequence[T2SRequestState]) -> Dict[str, Any]:
per_request = []
total = {
"phones_bytes": 0,
"prompt_phones_bytes": 0,
"all_phones_bytes": 0,
"all_bert_features_bytes": 0,
"prompt_semantic_bytes": 0,
"refer_spec_bytes": 0,
"raw_audio_bytes": 0,
"audio_16k_bytes": 0,
}
for state in states:
spec_audio, audio_16k = state.refer_spec
item = {
"request_id": state.request_id,
"prompt_semantic_len": int(state.prompt_semantic.shape[0]),
"phones_len": int(state.phones.shape[0]),
"all_phones_len": int(state.all_phones.shape[0]),
"bert_frames": int(state.all_bert_features.shape[-1]),
"phones_bytes": tensor_nbytes(state.phones),
"prompt_phones_bytes": tensor_nbytes(state.prompt_phones),
"all_phones_bytes": tensor_nbytes(state.all_phones),
"all_bert_features_bytes": tensor_nbytes(state.all_bert_features),
"prompt_semantic_bytes": tensor_nbytes(state.prompt_semantic),
"refer_spec_bytes": tensor_nbytes(spec_audio),
"audio_16k_bytes": tensor_nbytes(audio_16k),
"raw_audio_bytes": tensor_nbytes(state.raw_audio),
}
for key in total:
total[key] += int(item[key])
per_request.append(item)
total["total_bytes"] = int(sum(total.values()))
total["total_mb"] = bytes_to_mb(total["total_bytes"])
return {"per_request": per_request, "total": total}
def summarise_prefill_batch(active_batch: Any) -> Dict[str, Any]:
y_sequence_bytes = int(sum(tensor_nbytes(item) for item in active_batch.y_sequences))
fields = {
"x_bytes": tensor_nbytes(active_batch.x),
"x_lens_bytes": tensor_nbytes(active_batch.x_lens),
"prefix_lens_bytes": tensor_nbytes(active_batch.prefix_lens),
"xy_pos_bytes": tensor_nbytes(active_batch.xy_pos),
"key_padding_mask_bytes": tensor_nbytes(active_batch.key_padding_mask),
"prefill_attn_mask_bytes": tensor_nbytes(active_batch.prefill_attn_mask),
"y_sequence_bytes": y_sequence_bytes,
}
fields["total_bytes"] = int(sum(fields.values()))
fields["total_mb"] = bytes_to_mb(fields["total_bytes"])
fields["batch_size"] = int(len(active_batch.states))
fields["max_x_len"] = int(active_batch.x.shape[1])
fields["src_len"] = int(active_batch.xy_pos.shape[1])
fields["prefill_attn_mask_shape"] = list(active_batch.prefill_attn_mask.shape)
return fields
def summarise_running_requests(running_requests: Sequence[T2SRunningRequest]) -> Dict[str, Any]:
per_request = []
total_private_k_bytes = 0
total_private_v_bytes = 0
total_decode_mask_bytes = 0
total_y_sequence_bytes = 0
for item in running_requests:
k_bytes = tensor_list_nbytes(item.k_cache)
v_bytes = tensor_list_nbytes(item.v_cache)
mask_bytes = tensor_nbytes(item.decode_attn_mask)
y_bytes = tensor_nbytes(item.y_sequence)
total_private_k_bytes += k_bytes
total_private_v_bytes += v_bytes
total_decode_mask_bytes += mask_bytes
total_y_sequence_bytes += y_bytes
per_request.append(
{
"request_id": item.state.request_id,
"step_idx": int(item.step_idx),
"prefix_len": int(item.prefix_len),
"history_len": int(item.y_sequence.shape[0]),
"kv_len": int(item.k_cache[0].shape[1]),
"k_cache_bytes": k_bytes,
"v_cache_bytes": v_bytes,
"decode_mask_bytes": mask_bytes,
"y_sequence_bytes": y_bytes,
}
)
total_bytes = total_private_k_bytes + total_private_v_bytes + total_decode_mask_bytes + total_y_sequence_bytes
return {
"per_request": per_request,
"totals": {
"private_k_cache_bytes": int(total_private_k_bytes),
"private_v_cache_bytes": int(total_private_v_bytes),
"private_kv_cache_bytes": int(total_private_k_bytes + total_private_v_bytes),
"decode_mask_bytes": int(total_decode_mask_bytes),
"y_sequence_bytes": int(total_y_sequence_bytes),
"total_bytes": int(total_bytes),
"total_mb": bytes_to_mb(total_bytes),
},
}
def summarise_decode_batch(
xy_pos: torch.Tensor,
batched_k_cache: Sequence[torch.Tensor],
batched_v_cache: Sequence[torch.Tensor],
batched_decode_attn_mask: Optional[torch.Tensor],
running_requests: Sequence[T2SRunningRequest],
) -> Dict[str, Any]:
private_k_bytes = int(sum(tensor_list_nbytes(item.k_cache) for item in running_requests))
private_v_bytes = int(sum(tensor_list_nbytes(item.v_cache) for item in running_requests))
batched_k_bytes = tensor_list_nbytes(batched_k_cache)
batched_v_bytes = tensor_list_nbytes(batched_v_cache)
batched_mask_bytes = tensor_nbytes(batched_decode_attn_mask)
xy_pos_bytes = tensor_nbytes(xy_pos)
total_bytes = batched_k_bytes + batched_v_bytes + batched_mask_bytes + xy_pos_bytes
return {
"batch_size": int(len(running_requests)),
"xy_pos_bytes": int(xy_pos_bytes),
"batched_k_cache_bytes": int(batched_k_bytes),
"batched_v_cache_bytes": int(batched_v_bytes),
"batched_kv_cache_bytes": int(batched_k_bytes + batched_v_bytes),
"batched_decode_mask_bytes": int(batched_mask_bytes),
"private_kv_cache_bytes_reference": int(private_k_bytes + private_v_bytes),
"kv_padding_overhead_bytes": int((batched_k_bytes + batched_v_bytes) - (private_k_bytes + private_v_bytes)),
"total_bytes": int(total_bytes),
"total_mb": bytes_to_mb(total_bytes),
"xy_pos_shape": list(xy_pos.shape),
"batched_decode_mask_shape": None if batched_decode_attn_mask is None else list(batched_decode_attn_mask.shape),
"layer_k_cache_shape": list(batched_k_cache[0].shape),
}
def summarise_decode_outputs(
xy_dec: torch.Tensor,
next_k_cache: Sequence[torch.Tensor],
next_v_cache: Sequence[torch.Tensor],
) -> Dict[str, Any]:
xy_dec_bytes = tensor_nbytes(xy_dec)
next_k_bytes = tensor_list_nbytes(next_k_cache)
next_v_bytes = tensor_list_nbytes(next_v_cache)
total_bytes = xy_dec_bytes + next_k_bytes + next_v_bytes
return {
"xy_dec_bytes": int(xy_dec_bytes),
"next_k_cache_bytes": int(next_k_bytes),
"next_v_cache_bytes": int(next_v_bytes),
"next_kv_cache_bytes": int(next_k_bytes + next_v_bytes),
"total_bytes": int(total_bytes),
"total_mb": bytes_to_mb(total_bytes),
"xy_dec_shape": list(xy_dec.shape),
"layer_next_k_cache_shape": list(next_k_cache[0].shape),
}
def top_rankings(summary: Dict[str, Any]) -> List[Dict[str, Any]]:
ranking = [
("request_state_total", summary["prepare_stage"]["request_state"]["total"]["total_bytes"]),
("prefill_batch_total", summary["prefill_batch"]["tensor_bytes"]["total_bytes"]),
("running_private_kv", summary["prefill_step"]["running_requests"]["totals"]["private_kv_cache_bytes"]),
("decode_batched_kv", summary["decode_batch"]["tensor_bytes"]["batched_kv_cache_bytes"]),
("decode_kv_padding_overhead", summary["decode_batch"]["tensor_bytes"]["kv_padding_overhead_bytes"]),
("decode_outputs_next_kv", summary["decode_outputs"]["tensor_bytes"]["next_kv_cache_bytes"]),
("prefill_attn_mask", summary["prefill_batch"]["tensor_bytes"]["prefill_attn_mask_bytes"]),
]
ranking.sort(key=lambda item: item[1], reverse=True)
return [{"name": name, "bytes": int(value), "mb": bytes_to_mb(int(value))} for name, value in ranking]
def synthesize_finished_item(tts: TTS, state: T2SRequestState, semantic_tokens: torch.Tensor) -> Tuple[int, np.ndarray]:
semantic_tokens = semantic_tokens.unsqueeze(0).unsqueeze(0).to(tts.configs.device)
phones = state.phones.unsqueeze(0).to(tts.configs.device)
audio_fragment = tts.synthesize_audio_request_local(
semantic_tokens=semantic_tokens,
phones=phones,
prompt_semantic=state.prompt_semantic,
prompt_phones=state.prompt_phones,
refer_spec=state.refer_spec,
raw_audio=state.raw_audio,
raw_sr=state.raw_sr,
speed=1.0,
sample_steps=32,
)
output_sr = tts.configs.sampling_rate if not tts.configs.use_vocoder else tts.vocoder_configs["sr"]
return tts.audio_postprocess(
audio=[[audio_fragment]],
sr=int(output_sr),
batch_index_list=None,
speed_factor=1.0,
split_bucket=False,
fragment_interval=0.0,
super_sampling=False,
)
def simulate_worker_end_to_end(
tts: TTS,
specs: Sequence[SchedulerRequestSpec],
max_steps: int,
rounds: int,
grad_mode: str = "default",
) -> Dict[str, Any]:
device = tts.configs.device
recorder = GlobalPeakRecorder(device)
recorder.record("after_model_load")
state_map: Dict[str, T2SRequestState] = {}
per_round: List[Dict[str, Any]] = []
for round_index in range(rounds):
grad_context = torch.inference_mode if grad_mode == "inference_mode" else contextlib.nullcontext
with grad_context():
states = [prepare_request_state(tts, spec) for spec in specs]
state_map = {state.request_id: state for state in states}
recorder.record(
"after_prepare_states",
round_index=int(round_index),
request_count=int(len(states)),
grad_mode=grad_mode,
)
pending = list(states)
running_requests: List[T2SRunningRequest] = []
round_events: List[Dict[str, Any]] = []
current_tick = 0
while pending or running_requests:
admitted = pending
pending = []
if admitted:
recorder.record(
"before_prefill",
round_index=int(round_index),
tick=int(current_tick),
admitted_count=int(len(admitted)),
running_count=int(len(running_requests)),
grad_mode=grad_mode,
)
with grad_context():
admitted_running, admitted_finished = run_prefill_step(tts.t2s_model.model, admitted, max_steps=max_steps)
recorder.record(
"after_prefill",
round_index=int(round_index),
tick=int(current_tick),
admitted_running_count=int(len(admitted_running)),
admitted_finished_count=int(len(admitted_finished)),
running_count=int(len(running_requests)),
grad_mode=grad_mode,
)
round_events.append(
{
"tick": int(current_tick),
"event": "prefill",
"admitted_count": int(len(admitted)),
"admitted_running_count": int(len(admitted_running)),
"admitted_finished_count": int(len(admitted_finished)),
}
)
for item in admitted_finished:
recorder.record(
"before_synth_prefill_finished",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
finished_request_id=item.request_id,
semantic_len=int(item.semantic_tokens.shape[0]),
grad_mode=grad_mode,
)
with grad_context():
sample_rate, audio_data = synthesize_finished_item(tts, state_map[item.request_id], item.semantic_tokens)
recorder.record(
"after_synth_prefill_finished",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
finished_request_id=item.request_id,
sample_rate=int(sample_rate),
audio_samples=int(audio_data.shape[0]),
grad_mode=grad_mode,
)
running_requests.extend(admitted_running)
recorder.record(
"after_extend_running",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
grad_mode=grad_mode,
)
if running_requests:
recorder.record(
"before_decode",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
grad_mode=grad_mode,
)
with grad_context():
running_requests, step_finished = run_decode_step_for_running(
tts.t2s_model.model,
running_requests,
max_steps=max_steps,
)
recorder.record(
"after_decode",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
finished_count=int(len(step_finished)),
grad_mode=grad_mode,
)
round_events.append(
{
"tick": int(current_tick),
"event": "decode",
"running_count_after_decode": int(len(running_requests)),
"finished_count": int(len(step_finished)),
}
)
for item in step_finished:
recorder.record(
"before_synth_decode_finished",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
finished_request_id=item.request_id,
semantic_len=int(item.semantic_tokens.shape[0]),
grad_mode=grad_mode,
)
with grad_context():
sample_rate, audio_data = synthesize_finished_item(tts, state_map[item.request_id], item.semantic_tokens)
recorder.record(
"after_synth_decode_finished",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
finished_request_id=item.request_id,
sample_rate=int(sample_rate),
audio_samples=int(audio_data.shape[0]),
grad_mode=grad_mode,
)
current_tick += 1
recorder.record(
"after_round_complete",
round_index=int(round_index),
running_count=0,
grad_mode=grad_mode,
)
per_round.append(
{
"round_index": int(round_index),
"events": round_events,
}
)
return {
"grad_mode": grad_mode,
"rounds": per_round,
"timeline": recorder.summary(),
}
def main() -> None:
args = parse_args()
args.output_dir.mkdir(parents=True, exist_ok=True)
tts = load_pipeline(args.config)
model = tts.t2s_model.model
device = tts.configs.device
use_cuda = str(device).startswith("cuda") and torch.cuda.is_available()
set_seed(args.seed, use_cuda)
specs = load_request_specs(args)
if args.early_stop_num == -1:
for spec in specs:
spec.early_stop_num = int(tts.configs.hz * tts.configs.max_sec)
if args.warmup and specs:
warmup_spec = specs[:1]
_ = [prepare_request_state(tts, spec) for spec in warmup_spec]
gc.collect()
if use_cuda:
torch.cuda.empty_cache()
_sync_device(device)
states, prepare_mem = stage_run(device, lambda: [prepare_request_state(tts, spec) for spec in specs])
request_state_summary = summarise_state_tensors(states)
active_batch, prefill_batch_mem = stage_run(device, lambda: build_prefill_batch(model, states))
prefill_batch_tensor_summary = summarise_prefill_batch(active_batch)
prefill_result, prefill_step_mem = stage_run(device, lambda: run_prefill_step(model, states, max_steps=args.max_steps))
running_requests, finished_items = prefill_result
running_requests_summary = summarise_running_requests(running_requests)
finished_after_prefill_summary = [
{
"request_id": item.request_id,
"finish_idx": int(item.finish_idx),
"finish_reason": item.finish_reason,
"semantic_len": int(item.semantic_tokens.shape[0]),
}
for item in finished_items
]
if not running_requests:
raise RuntimeError(f"prefill 后没有 running requests全部在首步结束: {[item.request_id for item in finished_items]}")
decode_batch_result, decode_batch_mem = stage_run(
device,
lambda: _build_decode_batch_from_running(model, running_requests),
)
xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask = decode_batch_result
decode_batch_tensor_summary = summarise_decode_batch(
xy_pos,
batched_k_cache,
batched_v_cache,
batched_decode_attn_mask,
running_requests,
)
decode_out_result, decode_step_mem = stage_run(
device,
lambda: model.t2s_transformer.decode_next_token(
xy_pos,
batched_k_cache,
batched_v_cache,
batched_decode_attn_mask,
),
)
xy_dec, next_k_cache, next_v_cache = decode_out_result
decode_output_tensor_summary = summarise_decode_outputs(xy_dec, next_k_cache, next_v_cache)
del active_batch
del running_requests
del finished_items
del xy_pos
del batched_k_cache
del batched_v_cache
del batched_decode_attn_mask
del xy_dec
del next_k_cache
del next_v_cache
gc.collect()
if use_cuda:
_sync_device(device)
torch.cuda.empty_cache()
end_to_end_worker = simulate_worker_end_to_end(
tts=tts,
specs=specs,
max_steps=args.max_steps,
rounds=args.worker_rounds,
grad_mode=args.worker_grad_mode,
)
live_cuda_tensors_after_worker = snapshot_live_cuda_tensors()
worker_inference_mode = None
if args.compare_worker_grad_modes:
gc.collect()
if use_cuda:
_sync_device(device)
torch.cuda.empty_cache()
worker_inference_mode = simulate_worker_end_to_end(
tts=tts,
specs=specs,
max_steps=args.max_steps,
rounds=args.worker_rounds,
grad_mode="inference_mode",
)
summary = {
"meta": {
"scenario": args.scenario if args.request_manifest is None else "manifest",
"seed": int(args.seed),
"device": str(device),
"dtype": str(next(model.parameters()).dtype),
"request_count": int(len(specs)),
"num_layers": int(model.num_layers),
"num_heads": int(model.num_head),
"model_dim": int(model.model_dim),
"model_weights_mb": bytes_to_mb(model_nbytes(model)),
},
"loaded_module_weights": build_module_weight_summary(tts),
"requests": [
{
"request_id": spec.request_id,
"ref_audio_path": str(spec.ref_audio_path),
"prompt_text": spec.prompt_text,
"text": spec.text,
}
for spec in specs
],
"prepare_stage": {
"memory": prepare_mem,
"request_state": request_state_summary,
},
"prefill_batch": {
"memory": prefill_batch_mem,
"tensor_bytes": prefill_batch_tensor_summary,
},
"prefill_step": {
"memory": prefill_step_mem,
"running_requests": running_requests_summary,
"finished_after_prefill": finished_after_prefill_summary,
},
"decode_batch": {
"memory": decode_batch_mem,
"tensor_bytes": decode_batch_tensor_summary,
},
"decode_outputs": {
"memory": decode_step_mem,
"tensor_bytes": decode_output_tensor_summary,
},
"end_to_end_worker": end_to_end_worker,
"live_cuda_tensors_after_worker": live_cuda_tensors_after_worker,
"end_to_end_worker_inference_mode": worker_inference_mode,
}
summary["top_rankings"] = top_rankings(summary)
summary_path = args.output_dir / "t2s_memory_breakdown_summary.json"
summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
print(json.dumps(summary["meta"], ensure_ascii=False, indent=2))
print("[top_rankings]")
for item in summary["top_rankings"]:
print(f"- {item['name']}: {item['mb']:.3f} MB")
print("[worker_peak]")
print(
json.dumps(
{
"peak_label": summary["end_to_end_worker"]["timeline"]["peak_label"],
"peak_allocated_mb": summary["end_to_end_worker"]["timeline"]["peak_allocated_mb"],
"peak_reserved_mb": summary["end_to_end_worker"]["timeline"]["peak_reserved_mb"],
},
ensure_ascii=False,
indent=2,
)
)
if worker_inference_mode is not None:
print("[worker_peak_inference_mode]")
print(
json.dumps(
{
"peak_label": worker_inference_mode["timeline"]["peak_label"],
"peak_allocated_mb": worker_inference_mode["timeline"]["peak_allocated_mb"],
"peak_reserved_mb": worker_inference_mode["timeline"]["peak_reserved_mb"],
},
ensure_ascii=False,
indent=2,
)
)
print(f"[summary] {summary_path}")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,180 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import argparse
import json
import random
import sys
from pathlib import Path
from typing import Any, Dict, List
import numpy as np
import torch
ROOT_DIR = Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
sys.path.append(str(ROOT_DIR))
gpt_sovits_dir = ROOT_DIR / "GPT_SoVITS"
if str(gpt_sovits_dir) not in sys.path:
sys.path.append(str(gpt_sovits_dir))
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( # noqa: E402
SchedulerRequestSpec,
T2SFinishedItem,
T2SRequestState,
prepare_request_state,
run_scheduler_continuous,
)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="T2S request-local scheduler prototype.")
parser.add_argument("--config", type=Path, default=ROOT_DIR / "GPT_SoVITS/configs/tts_infer.yaml")
parser.add_argument("--request-manifest", type=Path, default=None)
parser.add_argument("--ref-audio", type=Path, default=ROOT_DIR / "test.wav")
parser.add_argument("--prompt-text", type=str, default="是啊,主要是因为有调研需求的学者少了。")
parser.add_argument("--prompt-lang", type=str, default="zh")
parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_en.txt")
parser.add_argument("--text", type=str, default=None)
parser.add_argument("--text-lang", type=str, default="en")
parser.add_argument("--top-k", type=int, default=15)
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--repetition-penalty", type=float, default=1.35)
parser.add_argument("--early-stop-num", type=int, default=-1)
parser.add_argument("--max-steps", type=int, default=1500)
parser.add_argument("--seed", type=int, default=1234)
parser.add_argument("--output-dir", type=Path, default=ROOT_DIR / "TEMP/t2s_scheduler/output_run")
return parser.parse_args()
def set_seed(seed: int, use_cuda: bool) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if use_cuda and torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def load_pipeline(config_path: Path):
try:
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"缺少运行依赖,请先在 GPT-SoVITS 推理环境中安装 requirements 后再运行该脚本。"
) from exc
tts_config = TTS_Config(str(config_path))
print(tts_config)
return TTS(tts_config)
def load_request_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
if args.request_manifest is not None:
payload = json.loads(args.request_manifest.read_text(encoding="utf-8"))
raw_requests = payload["requests"] if isinstance(payload, dict) else payload
specs: List[SchedulerRequestSpec] = []
for index, item in enumerate(raw_requests):
text = item.get("text")
text_file = item.get("text_file")
if text is None and text_file is None:
raise ValueError(f"request[{index}] must provide text or text_file")
if text is None:
text = Path(text_file).read_text(encoding="utf-8")
specs.append(
SchedulerRequestSpec(
request_id=item.get("request_id", f"req_{index:03d}"),
ref_audio_path=Path(item["ref_audio_path"]),
prompt_text=item["prompt_text"],
prompt_lang=item.get("prompt_lang", "zh"),
text=text,
text_lang=item.get("text_lang", "zh"),
top_k=int(item.get("top_k", args.top_k)),
top_p=float(item.get("top_p", args.top_p)),
temperature=float(item.get("temperature", args.temperature)),
repetition_penalty=float(item.get("repetition_penalty", args.repetition_penalty)),
early_stop_num=int(item.get("early_stop_num", args.early_stop_num)),
ready_step=int(item.get("ready_step", 0)),
)
)
return specs
text = args.text if args.text is not None else args.text_file.read_text(encoding="utf-8")
return [
SchedulerRequestSpec(
request_id="req_000",
ref_audio_path=args.ref_audio,
prompt_text=args.prompt_text,
prompt_lang=args.prompt_lang,
text=text,
text_lang=args.text_lang,
top_k=args.top_k,
top_p=args.top_p,
temperature=args.temperature,
repetition_penalty=args.repetition_penalty,
early_stop_num=args.early_stop_num,
ready_step=0,
)
]
def summarise_requests(states: List[T2SRequestState]) -> List[Dict[str, Any]]:
return [
{
"request_id": state.request_id,
"ready_step": int(state.ready_step),
"ref_audio_path": str(state.ref_audio_path),
"prompt_semantic_len": int(state.prompt_semantic.shape[0]),
"all_phone_len": int(state.all_phones.shape[0]),
"bert_len": int(state.all_bert_features.shape[-1]),
"norm_text": state.norm_text,
}
for state in states
]
def summarise_finished(items: List[T2SFinishedItem]) -> List[Dict[str, Any]]:
return [
{
"request_id": item.request_id,
"semantic_len": int(item.semantic_tokens.shape[0]),
"finish_idx": int(item.finish_idx),
"finish_reason": item.finish_reason,
}
for item in items
]
def main() -> None:
args = parse_args()
args.output_dir.mkdir(parents=True, exist_ok=True)
tts = load_pipeline(args.config)
model = tts.t2s_model.model
use_cuda = str(tts.configs.device).startswith("cuda")
set_seed(args.seed, use_cuda)
request_specs = load_request_specs(args)
states = [prepare_request_state(tts, spec) for spec in request_specs]
finished = run_scheduler_continuous(model, states, max_steps=args.max_steps)
summary = {
"request_count": len(states),
"max_steps": args.max_steps,
"requests": summarise_requests(states),
"finished": summarise_finished(finished),
}
output_path = args.output_dir / "scheduler_prototype_summary.json"
output_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
print(json.dumps(summary, ensure_ascii=False, indent=2))
print(f"[saved] {output_path}")
if __name__ == "__main__":
try:
main()
except ModuleNotFoundError as exc:
print(f"[error] {exc}")
raise SystemExit(1) from None