mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-06-07 23:28:17 +08:00
Merge 8a444c10b72aadeb1f8be9210b45f2a519d383e4 into 2d9193b0d3c0eae0c3a14d8c68a839f1bae157dc
This commit is contained in:
commit
d54ea4d5aa
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
[submodule "third_party/g2pw-cu"]
|
||||
path = third_party/g2pw-cu
|
||||
url = https://github.com/baicai-1145/g2pw-cu.git
|
||||
@ -351,6 +351,13 @@ class Text2SemanticDecoder(nn.Module):
|
||||
blocks.append(block)
|
||||
|
||||
self.t2s_transformer = T2STransformer(self.num_layers, blocks)
|
||||
self.last_infer_stats = {}
|
||||
|
||||
def _set_last_infer_stats(self, stats):
|
||||
self.last_infer_stats = stats
|
||||
|
||||
def get_last_infer_stats(self):
|
||||
return dict(self.last_infer_stats)
|
||||
|
||||
def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
|
||||
x = self.ar_text_embedding(x)
|
||||
@ -593,7 +600,19 @@ class Text2SemanticDecoder(nn.Module):
|
||||
repetition_penalty: float = 1.35,
|
||||
**kwargs,
|
||||
):
|
||||
requested_enable_mask_free_fastpath = bool(kwargs.get("enable_mask_free_fastpath", True))
|
||||
if prompts is None:
|
||||
self._set_last_infer_stats(
|
||||
{
|
||||
"infer_mode": "batch_infer_prompt_free_fallback",
|
||||
"requested_enable_mask_free_fastpath": requested_enable_mask_free_fastpath,
|
||||
"batch_size": int(len(x)),
|
||||
"prefill_after_mask_all_visible": None,
|
||||
"fastpath_hit": False,
|
||||
"generated_token_count": 0,
|
||||
"generated_token_count_list": [],
|
||||
}
|
||||
)
|
||||
print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
|
||||
return self.infer_panel_naive_batched(
|
||||
x,
|
||||
@ -608,6 +627,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
)
|
||||
|
||||
max_len = kwargs.get("max_len", x_lens.max())
|
||||
enable_mask_free_fastpath = requested_enable_mask_free_fastpath
|
||||
x_list = []
|
||||
for x_item, bert_item in zip(x, bert_feature):
|
||||
# max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
|
||||
@ -698,17 +718,30 @@ class Text2SemanticDecoder(nn.Module):
|
||||
y_list = [None] * y.shape[0]
|
||||
batch_idx_map = list(range(y.shape[0]))
|
||||
idx_list = [None] * y.shape[0]
|
||||
decode_attn_mask = attn_mask
|
||||
prefill_after_mask_all_visible = None
|
||||
fastpath_hit = False
|
||||
for idx in tqdm(range(1500)):
|
||||
if idx == 0:
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
|
||||
else:
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(
|
||||
xy_pos, k_cache, v_cache, decode_attn_mask
|
||||
)
|
||||
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||
|
||||
if idx == 0:
|
||||
attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
|
||||
prefill_after_mask_all_visible = not attn_mask.any().item()
|
||||
if enable_mask_free_fastpath and y.shape[0] == 1 and prefill_after_mask_all_visible:
|
||||
decode_attn_mask = None
|
||||
fastpath_hit = True
|
||||
else:
|
||||
decode_attn_mask = attn_mask
|
||||
else:
|
||||
attn_mask = F.pad(attn_mask, (0, 1), value=False)
|
||||
if decode_attn_mask is not None:
|
||||
attn_mask = F.pad(attn_mask, (0, 1), value=False)
|
||||
decode_attn_mask = attn_mask
|
||||
|
||||
if idx < 11: ###至少预测出10个token不然不给停止(0.4s)
|
||||
logits = logits[:, :-1]
|
||||
@ -740,7 +773,9 @@ class Text2SemanticDecoder(nn.Module):
|
||||
if reserved_idx_of_batch_for_y is not None:
|
||||
# index = torch.LongTensor(batch_idx_map).to(y.device)
|
||||
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
|
||||
attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
|
||||
if decode_attn_mask is not None:
|
||||
attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
|
||||
decode_attn_mask = attn_mask
|
||||
if k_cache is not None:
|
||||
for i in range(len(k_cache)):
|
||||
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
|
||||
@ -775,6 +810,18 @@ class Text2SemanticDecoder(nn.Module):
|
||||
if idx_list[i] is None:
|
||||
idx_list[i] = 1500 - 1 ###如果没有生成到EOS,就用最大长度代替
|
||||
|
||||
self._set_last_infer_stats(
|
||||
{
|
||||
"infer_mode": "batch_infer",
|
||||
"requested_enable_mask_free_fastpath": enable_mask_free_fastpath,
|
||||
"batch_size": int(len(x)),
|
||||
"prefill_after_mask_all_visible": prefill_after_mask_all_visible,
|
||||
"fastpath_hit": fastpath_hit,
|
||||
"generated_token_count": int(sum(idx_list)),
|
||||
"generated_token_count_list": [int(item) for item in idx_list],
|
||||
"max_len": int(max_len),
|
||||
}
|
||||
)
|
||||
if ref_free:
|
||||
return y_list, [0] * x.shape[0]
|
||||
# print(idx_list)
|
||||
@ -811,6 +858,17 @@ class Text2SemanticDecoder(nn.Module):
|
||||
y_list.append(y[0])
|
||||
idx_list.append(idx)
|
||||
|
||||
self._set_last_infer_stats(
|
||||
{
|
||||
"infer_mode": "naive_batched",
|
||||
"requested_enable_mask_free_fastpath": bool(kwargs.get("enable_mask_free_fastpath", True)),
|
||||
"batch_size": int(len(x)),
|
||||
"prefill_after_mask_all_visible": None,
|
||||
"fastpath_hit": False,
|
||||
"generated_token_count": int(sum(idx_list)),
|
||||
"generated_token_count_list": [int(item) for item in idx_list],
|
||||
}
|
||||
)
|
||||
return y_list, idx_list
|
||||
|
||||
def infer_panel_naive(
|
||||
@ -957,6 +1015,18 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
|
||||
if not streaming_mode:
|
||||
generated_token_count = max(int(y.shape[1] - prefix_len), 0)
|
||||
self._set_last_infer_stats(
|
||||
{
|
||||
"infer_mode": "naive",
|
||||
"requested_enable_mask_free_fastpath": bool(kwargs.get("enable_mask_free_fastpath", True)),
|
||||
"batch_size": int(x.shape[0]),
|
||||
"prefill_after_mask_all_visible": True if prompts is not None else None,
|
||||
"fastpath_hit": True if prompts is not None else False,
|
||||
"generated_token_count": generated_token_count,
|
||||
"generated_token_count_list": [generated_token_count],
|
||||
}
|
||||
)
|
||||
if ref_free:
|
||||
yield y, 0
|
||||
yield y, idx
|
||||
|
||||
@ -147,6 +147,7 @@ def multinomial_sample_one_no_sync(
|
||||
def logits_to_probs(
|
||||
logits,
|
||||
previous_tokens: Optional[torch.Tensor] = None,
|
||||
previous_token_mask: Optional[torch.Tensor] = None,
|
||||
temperature: float = 1.0,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
@ -158,13 +159,27 @@ def logits_to_probs(
|
||||
# pdb.set_trace()
|
||||
if previous_tokens is not None and repetition_penalty != 1.0:
|
||||
previous_tokens = previous_tokens.long()
|
||||
score = torch.gather(logits, dim=1, index=previous_tokens)
|
||||
score = torch.where(
|
||||
score < 0,
|
||||
score * repetition_penalty,
|
||||
score / repetition_penalty,
|
||||
)
|
||||
logits.scatter_(dim=1, index=previous_tokens, src=score)
|
||||
if previous_token_mask is None:
|
||||
score = torch.gather(logits, dim=1, index=previous_tokens)
|
||||
score = torch.where(
|
||||
score < 0,
|
||||
score * repetition_penalty,
|
||||
score / repetition_penalty,
|
||||
)
|
||||
logits.scatter_(dim=1, index=previous_tokens, src=score)
|
||||
else:
|
||||
previous_token_mask = previous_token_mask.to(dtype=torch.bool, device=logits.device)
|
||||
if previous_token_mask.any():
|
||||
batch_index = torch.arange(logits.size(0), device=logits.device).unsqueeze(1).expand_as(previous_tokens)
|
||||
valid_batch_index = batch_index[previous_token_mask]
|
||||
valid_token_index = previous_tokens[previous_token_mask]
|
||||
score = logits[valid_batch_index, valid_token_index]
|
||||
score = torch.where(
|
||||
score < 0,
|
||||
score * repetition_penalty,
|
||||
score / repetition_penalty,
|
||||
)
|
||||
logits[valid_batch_index, valid_token_index] = score
|
||||
|
||||
if top_p is not None and top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
@ -192,9 +207,15 @@ def logits_to_probs(
|
||||
def sample(
|
||||
logits,
|
||||
previous_tokens: Optional[torch.Tensor] = None,
|
||||
previous_token_mask: Optional[torch.Tensor] = None,
|
||||
**sampling_kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
probs = logits_to_probs(logits=logits, previous_tokens=previous_tokens, **sampling_kwargs)
|
||||
probs = logits_to_probs(
|
||||
logits=logits,
|
||||
previous_tokens=previous_tokens,
|
||||
previous_token_mask=previous_token_mask,
|
||||
**sampling_kwargs,
|
||||
)
|
||||
idx_next = multinomial_sample_one_no_sync(probs)
|
||||
return idx_next, probs
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,10 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
@ -11,11 +15,13 @@ import re
|
||||
import torch
|
||||
from text.LangSegmenter import LangSegmenter
|
||||
from text import chinese
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from text.cleaner import clean_text
|
||||
from text import cleaned_text_to_sequence
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
|
||||
from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker
|
||||
from TTS_infer_pack.text_cpu_preprocess import preprocess_text_segments_payload
|
||||
|
||||
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||
|
||||
@ -49,12 +55,80 @@ def merge_short_text_in_array(texts: str, threshold: int) -> list:
|
||||
return result
|
||||
|
||||
|
||||
class StageLimiter:
|
||||
def __init__(self, slots: int):
|
||||
self.slots = max(1, int(slots))
|
||||
self.semaphore = threading.BoundedSemaphore(self.slots)
|
||||
self.lock = threading.Lock()
|
||||
self.inflight = 0
|
||||
self.peak_inflight = 0
|
||||
|
||||
@contextmanager
|
||||
def enter(self):
|
||||
wait_start = time.perf_counter()
|
||||
self.semaphore.acquire()
|
||||
wait_ms = (time.perf_counter() - wait_start) * 1000.0
|
||||
with self.lock:
|
||||
self.inflight += 1
|
||||
current_inflight = self.inflight
|
||||
if current_inflight > self.peak_inflight:
|
||||
self.peak_inflight = current_inflight
|
||||
peak_inflight = self.peak_inflight
|
||||
try:
|
||||
yield {
|
||||
"wait_ms": wait_ms,
|
||||
"inflight": current_inflight,
|
||||
"peak_inflight": peak_inflight,
|
||||
"slots": self.slots,
|
||||
}
|
||||
finally:
|
||||
with self.lock:
|
||||
self.inflight = max(0, self.inflight - 1)
|
||||
self.semaphore.release()
|
||||
|
||||
def snapshot(self) -> Dict[str, int]:
|
||||
with self.lock:
|
||||
return {
|
||||
"slots": self.slots,
|
||||
"inflight": self.inflight,
|
||||
"peak_inflight": self.peak_inflight,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreparedTextSegment:
|
||||
language: str
|
||||
phones: List[int]
|
||||
word2ph: Optional[List[int]]
|
||||
norm_text: str
|
||||
needs_g2pw: bool = False
|
||||
|
||||
|
||||
class TextPreprocessor:
|
||||
def __init__(self, bert_model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, device: torch.device):
|
||||
def __init__(
|
||||
self,
|
||||
bert_model: AutoModelForMaskedLM,
|
||||
tokenizer: AutoTokenizer,
|
||||
device: torch.device,
|
||||
version: str = "v2",
|
||||
bert_stage_limiter: StageLimiter | None = None,
|
||||
bert_batch_worker: PrepareBertBatchWorker | None = None,
|
||||
):
|
||||
self.bert_model = bert_model
|
||||
self.tokenizer = tokenizer
|
||||
self.device = device
|
||||
self.bert_lock = threading.RLock()
|
||||
self.version = str(version)
|
||||
self.bert_stage_limiter = bert_stage_limiter
|
||||
self.bert_batch_worker = bert_batch_worker
|
||||
|
||||
def snapshot(self) -> Dict[str, object]:
|
||||
return {
|
||||
"device": str(self.device),
|
||||
"bert_stage_limiter": (
|
||||
None if self.bert_stage_limiter is None else dict(self.bert_stage_limiter.snapshot())
|
||||
),
|
||||
"bert_batch_worker": None if self.bert_batch_worker is None else dict(self.bert_batch_worker.snapshot()),
|
||||
}
|
||||
|
||||
def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
|
||||
print(f"############ {i18n('切分文本')} ############")
|
||||
@ -98,7 +172,7 @@ class TextPreprocessor:
|
||||
# 解决输入目标文本的空行导致报错的问题
|
||||
if len(text.strip()) == 0:
|
||||
continue
|
||||
if not re.sub("\W+", "", text):
|
||||
if not re.sub(r"\W+", "", text):
|
||||
# 检测一下,如果是纯符号,就跳过。
|
||||
continue
|
||||
if text[-1] not in splits:
|
||||
@ -115,86 +189,233 @@ class TextPreprocessor:
|
||||
return texts
|
||||
|
||||
def segment_and_extract_feature_for_text(
|
||||
self, text: str, language: str, version: str = "v1"
|
||||
self, text: str, language: str, version: str = "v1", profile: Dict | None = None
|
||||
) -> Tuple[list, torch.Tensor, str]:
|
||||
return self.get_phones_and_bert(text, language, version)
|
||||
prepared_segments = self.preprocess_text_segments(text, language, version)
|
||||
return self.build_phones_and_bert_from_segments(prepared_segments, profile=profile)
|
||||
|
||||
def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False):
|
||||
with self.bert_lock:
|
||||
text = re.sub(r' {2,}', ' ', text)
|
||||
textlist = []
|
||||
langlist = []
|
||||
if language == "all_zh":
|
||||
for tmp in LangSegmenter.getTexts(text,"zh"):
|
||||
def _split_text_by_language(self, text: str, language: str) -> Tuple[List[str], List[str]]:
|
||||
textlist = []
|
||||
langlist = []
|
||||
if language == "all_zh":
|
||||
for tmp in LangSegmenter.getTexts(text, "zh"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_yue":
|
||||
for tmp in LangSegmenter.getTexts(text, "zh"):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_ja":
|
||||
for tmp in LangSegmenter.getTexts(text, "ja"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_ko":
|
||||
for tmp in LangSegmenter.getTexts(text, "ko"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "en":
|
||||
langlist.append("en")
|
||||
textlist.append(text)
|
||||
elif language == "auto":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "auto_yue":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
else:
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if langlist:
|
||||
same_group = (tmp["lang"] == "en" and langlist[-1] == "en") or (
|
||||
tmp["lang"] != "en" and langlist[-1] != "en"
|
||||
)
|
||||
if same_group:
|
||||
textlist[-1] += tmp["text"]
|
||||
continue
|
||||
if tmp["lang"] == "en":
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_yue":
|
||||
for tmp in LangSegmenter.getTexts(text,"zh"):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_ja":
|
||||
for tmp in LangSegmenter.getTexts(text,"ja"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_ko":
|
||||
for tmp in LangSegmenter.getTexts(text,"ko"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "en":
|
||||
langlist.append("en")
|
||||
textlist.append(text)
|
||||
elif language == "auto":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "auto_yue":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
else:
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if langlist:
|
||||
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
|
||||
textlist[-1] += tmp["text"]
|
||||
continue
|
||||
if tmp["lang"] == "en":
|
||||
langlist.append(tmp["lang"])
|
||||
else:
|
||||
# 因无法区别中日韩文汉字,以用户输入为准
|
||||
langlist.append(language)
|
||||
textlist.append(tmp["text"])
|
||||
# print(textlist)
|
||||
# print(langlist)
|
||||
phones_list = []
|
||||
bert_list = []
|
||||
norm_text_list = []
|
||||
for i in range(len(textlist)):
|
||||
lang = langlist[i]
|
||||
phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
|
||||
bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
|
||||
phones_list.append(phones)
|
||||
norm_text_list.append(norm_text)
|
||||
bert_list.append(bert)
|
||||
bert = torch.cat(bert_list, dim=1)
|
||||
phones = sum(phones_list, [])
|
||||
norm_text = "".join(norm_text_list)
|
||||
else:
|
||||
langlist.append(language)
|
||||
textlist.append(tmp["text"])
|
||||
return textlist, langlist
|
||||
|
||||
if not final and len(phones) < 6:
|
||||
return self.get_phones_and_bert("." + text, language, version, final=True)
|
||||
def get_phones_and_bert(
|
||||
self, text: str, language: str, version: str, final: bool = False, profile: Dict | None = None
|
||||
):
|
||||
prepared_segments = self.preprocess_text_segments(text, language, version, final=final)
|
||||
return self.build_phones_and_bert_from_segments(prepared_segments, profile=profile)
|
||||
|
||||
return phones, bert, norm_text
|
||||
def preprocess_text_segments(
|
||||
self,
|
||||
text: str,
|
||||
language: str,
|
||||
version: str,
|
||||
final: bool = False,
|
||||
) -> List[PreparedTextSegment]:
|
||||
payloads = preprocess_text_segments_payload(text, language, version, final=final)
|
||||
return [
|
||||
PreparedTextSegment(
|
||||
language=str(payload["language"]),
|
||||
phones=list(payload["phones"]),
|
||||
word2ph=None if payload["word2ph"] is None else list(payload["word2ph"]),
|
||||
norm_text=str(payload["norm_text"]),
|
||||
needs_g2pw=bool(payload.get("needs_g2pw", False)),
|
||||
)
|
||||
for payload in payloads
|
||||
]
|
||||
|
||||
def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
|
||||
with torch.no_grad():
|
||||
inputs = self.tokenizer(text, return_tensors="pt")
|
||||
for i in inputs:
|
||||
inputs[i] = inputs[i].to(self.device)
|
||||
res = self.bert_model(**inputs, output_hidden_states=True)
|
||||
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
|
||||
def resolve_g2pw_segments(
|
||||
self,
|
||||
prepared_segments: List[PreparedTextSegment],
|
||||
profile: Dict | None = None,
|
||||
) -> List[PreparedTextSegment]:
|
||||
zh_indices = [index for index, segment in enumerate(prepared_segments) if bool(segment.needs_g2pw)]
|
||||
if not zh_indices:
|
||||
return prepared_segments
|
||||
from text import chinese2
|
||||
|
||||
normalized_segments = [prepared_segments[index].norm_text for index in zh_indices]
|
||||
resolved_segments, g2pw_profile = chinese2.g2p_segments(normalized_segments, return_profile=True)
|
||||
self._accumulate_profile(profile, "g2pw_prepare_ms", g2pw_profile.get("g2pw_prepare_ms", 0.0))
|
||||
self._accumulate_profile(profile, "g2pw_predict_ms", g2pw_profile.get("g2pw_predict_ms", 0.0))
|
||||
self._accumulate_profile(profile, "g2pw_post_ms", g2pw_profile.get("g2pw_post_ms", 0.0))
|
||||
self._accumulate_profile(profile, "g2pw_total_ms", g2pw_profile.get("g2pw_total_ms", 0.0))
|
||||
self._accumulate_profile(profile, "g2pw_runtime_total_ms", g2pw_profile.get("g2pw_runtime_total_ms", 0.0))
|
||||
self._accumulate_profile(profile, "g2pw_runtime_queue_wait_ms", g2pw_profile.get("g2pw_runtime_queue_wait_ms", 0.0))
|
||||
self._accumulate_profile(
|
||||
profile,
|
||||
"g2pw_runtime_collect_wait_ms",
|
||||
g2pw_profile.get("g2pw_runtime_collect_wait_ms", 0.0),
|
||||
)
|
||||
self._accumulate_profile(profile, "g2pw_runtime_run_ms", g2pw_profile.get("g2pw_runtime_run_ms", 0.0))
|
||||
self._update_profile_peak(
|
||||
profile,
|
||||
"g2pw_runtime_batch_rows_peak",
|
||||
g2pw_profile.get("g2pw_runtime_batch_rows", 0.0),
|
||||
)
|
||||
self._update_profile_peak(
|
||||
profile,
|
||||
"g2pw_runtime_batch_requests_peak",
|
||||
g2pw_profile.get("g2pw_runtime_batch_requests", 0.0),
|
||||
)
|
||||
self._update_profile_peak(
|
||||
profile,
|
||||
"g2pw_runtime_pool_workers",
|
||||
g2pw_profile.get("g2pw_runtime_pool_workers", 0.0),
|
||||
)
|
||||
for index, (phones, word2ph, norm_text) in zip(zh_indices, resolved_segments):
|
||||
prepared_segments[index] = PreparedTextSegment(
|
||||
language=prepared_segments[index].language,
|
||||
phones=list(cleaned_text_to_sequence(phones, self.version)),
|
||||
word2ph=None if word2ph is None else list(word2ph),
|
||||
norm_text=str(norm_text),
|
||||
needs_g2pw=False,
|
||||
)
|
||||
return prepared_segments
|
||||
|
||||
def build_phones_and_bert_from_segments(
|
||||
self,
|
||||
prepared_segments: List[PreparedTextSegment],
|
||||
profile: Dict | None = None,
|
||||
) -> Tuple[list, torch.Tensor, str]:
|
||||
prepared_segments = self.resolve_g2pw_segments(prepared_segments, profile=profile)
|
||||
phones_list: List[List[int]] = []
|
||||
bert_list: List[torch.Tensor] = []
|
||||
norm_text_list: List[str] = []
|
||||
for segment in prepared_segments:
|
||||
bert = self.get_bert_inf(
|
||||
segment.phones,
|
||||
segment.word2ph,
|
||||
segment.norm_text,
|
||||
segment.language,
|
||||
profile=profile,
|
||||
)
|
||||
phones_list.append(segment.phones)
|
||||
norm_text_list.append(segment.norm_text)
|
||||
bert_list.append(bert)
|
||||
bert = torch.cat(bert_list, dim=1)
|
||||
phones = sum(phones_list, [])
|
||||
norm_text = "".join(norm_text_list)
|
||||
return phones, bert, norm_text
|
||||
|
||||
def _accumulate_profile(self, profile: Dict | None, key: str, value: float) -> None:
|
||||
if profile is None:
|
||||
return
|
||||
profile[key] = float(profile.get(key, 0.0)) + float(value)
|
||||
|
||||
def _update_profile_peak(self, profile: Dict | None, key: str, value: float) -> None:
|
||||
if profile is None:
|
||||
return
|
||||
profile[key] = float(max(float(profile.get(key, 0.0)), float(value)))
|
||||
|
||||
def _merge_bert_worker_profile(self, profile: Dict | None, worker_profile: Dict[str, float]) -> None:
|
||||
self._accumulate_profile(profile, "bert_wait_ms", worker_profile.get("bert_wait_ms", 0.0))
|
||||
self._accumulate_profile(profile, "bert_admission_wait_ms", worker_profile.get("bert_admission_wait_ms", 0.0))
|
||||
self._accumulate_profile(profile, "bert_queue_wait_ms", worker_profile.get("bert_queue_wait_ms", 0.0))
|
||||
self._accumulate_profile(
|
||||
profile,
|
||||
"bert_batch_collect_wait_ms",
|
||||
worker_profile.get("bert_batch_collect_wait_ms", 0.0),
|
||||
)
|
||||
self._accumulate_profile(profile, "bert_forward_ms", worker_profile.get("bert_forward_ms", 0.0))
|
||||
self._accumulate_profile(profile, "bert_tokenize_ms", worker_profile.get("bert_tokenize_ms", 0.0))
|
||||
self._accumulate_profile(profile, "bert_scatter_ms", worker_profile.get("bert_scatter_ms", 0.0))
|
||||
self._accumulate_profile(profile, "bert_calls", worker_profile.get("bert_calls", 1.0))
|
||||
self._update_profile_peak(profile, "bert_stage_inflight_peak", worker_profile.get("bert_stage_inflight_peak", 0.0))
|
||||
self._update_profile_peak(profile, "bert_batch_size_peak", worker_profile.get("bert_batch_size", 0.0))
|
||||
self._update_profile_peak(profile, "bert_batch_tokens_peak", worker_profile.get("bert_batch_tokens", 0.0))
|
||||
self._update_profile_peak(
|
||||
profile,
|
||||
"bert_pending_depth_on_enqueue_peak",
|
||||
worker_profile.get("bert_pending_depth_on_enqueue", 0.0),
|
||||
)
|
||||
self._update_profile_peak(
|
||||
profile,
|
||||
"bert_pending_depth_on_collect_peak",
|
||||
worker_profile.get("bert_pending_depth_on_collect", 0.0),
|
||||
)
|
||||
self._update_profile_peak(profile, "bert_high_pressure_mode_peak", worker_profile.get("bert_high_pressure_mode", 0.0))
|
||||
if profile is not None:
|
||||
profile["bert_stage_slots"] = float(worker_profile.get("bert_stage_slots", 0.0))
|
||||
profile["bert_batch_window_ms"] = float(worker_profile.get("bert_batch_window_ms", 0.0))
|
||||
|
||||
def get_bert_feature(self, text: str, word2ph: list, profile: Dict | None = None) -> torch.Tensor:
|
||||
if self.bert_batch_worker is not None:
|
||||
feature, worker_profile = self.bert_batch_worker.submit(text, word2ph)
|
||||
self._merge_bert_worker_profile(profile, worker_profile)
|
||||
return feature
|
||||
|
||||
limiter_stats = {"wait_ms": 0.0, "inflight": 1, "peak_inflight": 1, "slots": 0}
|
||||
if self.bert_stage_limiter is None:
|
||||
forward_start = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
inputs = self.tokenizer(text, return_tensors="pt")
|
||||
for i in inputs:
|
||||
inputs[i] = inputs[i].to(self.device)
|
||||
res = self.bert_model(**inputs, output_hidden_states=True)
|
||||
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
|
||||
forward_ms = (time.perf_counter() - forward_start) * 1000.0
|
||||
else:
|
||||
with self.bert_stage_limiter.enter() as limiter_stats:
|
||||
forward_start = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
inputs = self.tokenizer(text, return_tensors="pt")
|
||||
for i in inputs:
|
||||
inputs[i] = inputs[i].to(self.device)
|
||||
res = self.bert_model(**inputs, output_hidden_states=True)
|
||||
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
|
||||
forward_ms = (time.perf_counter() - forward_start) * 1000.0
|
||||
self._accumulate_profile(profile, "bert_wait_ms", limiter_stats["wait_ms"])
|
||||
self._accumulate_profile(profile, "bert_forward_ms", forward_ms)
|
||||
self._accumulate_profile(profile, "bert_calls", 1.0)
|
||||
self._update_profile_peak(profile, "bert_stage_inflight_peak", limiter_stats["peak_inflight"])
|
||||
if profile is not None:
|
||||
profile["bert_stage_slots"] = float(limiter_stats["slots"])
|
||||
assert len(word2ph) == len(text)
|
||||
phone_level_feature = []
|
||||
for i in range(len(word2ph)):
|
||||
@ -209,10 +430,19 @@ class TextPreprocessor:
|
||||
phones = cleaned_text_to_sequence(phones, version)
|
||||
return phones, word2ph, norm_text
|
||||
|
||||
def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str):
|
||||
def get_bert_inf(
|
||||
self,
|
||||
phones: list,
|
||||
word2ph: Optional[list],
|
||||
norm_text: str,
|
||||
language: str,
|
||||
profile: Dict | None = None,
|
||||
):
|
||||
language = language.replace("all_", "")
|
||||
if language == "zh":
|
||||
feature = self.get_bert_feature(norm_text, word2ph).to(self.device)
|
||||
if word2ph is None:
|
||||
raise ValueError("中文文本缺少 word2ph,无法提取 BERT 特征")
|
||||
feature = self.get_bert_feature(norm_text, word2ph, profile=profile).to(self.device)
|
||||
else:
|
||||
feature = torch.zeros(
|
||||
(1024, len(phones)),
|
||||
@ -221,6 +451,115 @@ class TextPreprocessor:
|
||||
|
||||
return feature
|
||||
|
||||
async def build_phones_and_bert_from_segments_async(
|
||||
self,
|
||||
prepared_segments: List[PreparedTextSegment],
|
||||
profile: Dict | None = None,
|
||||
) -> Tuple[list, torch.Tensor, str]:
|
||||
prepared_segments = self.resolve_g2pw_segments(prepared_segments, profile=profile)
|
||||
segment_jobs = self._build_async_segment_jobs(prepared_segments, profile)
|
||||
pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = []
|
||||
for segment_index, segment in enumerate(prepared_segments):
|
||||
if segment.language.replace("all_", "") != "zh" or self.bert_batch_worker is None:
|
||||
continue
|
||||
if segment.word2ph is None:
|
||||
raise ValueError("中文文本缺少 word2ph,无法提取 BERT 特征")
|
||||
pending_items.append(
|
||||
(
|
||||
segment_jobs["bert_list"],
|
||||
segment_index,
|
||||
profile,
|
||||
self.bert_batch_worker.submit_async(segment.norm_text, segment.word2ph),
|
||||
)
|
||||
)
|
||||
|
||||
if pending_items:
|
||||
pending_results = await asyncio.gather(*[future for _, _, _, future in pending_items])
|
||||
for (bert_list, bert_index, item_profile, _), (feature, worker_profile) in zip(pending_items, pending_results):
|
||||
self._merge_bert_worker_profile(item_profile, worker_profile)
|
||||
bert_list[bert_index] = feature.to(self.device)
|
||||
|
||||
return self._finalize_async_segment_jobs(segment_jobs)
|
||||
|
||||
def _build_async_segment_jobs(
|
||||
self,
|
||||
prepared_segments: List[PreparedTextSegment],
|
||||
profile: Dict | None,
|
||||
) -> Dict[str, List]:
|
||||
phones_list: List[List[int]] = []
|
||||
bert_list: List[torch.Tensor | None] = []
|
||||
norm_text_list: List[str] = []
|
||||
|
||||
for segment in prepared_segments:
|
||||
phones_list.append(segment.phones)
|
||||
norm_text_list.append(segment.norm_text)
|
||||
segment_language = segment.language.replace("all_", "")
|
||||
if segment_language == "zh" and self.bert_batch_worker is not None:
|
||||
if segment.word2ph is None:
|
||||
raise ValueError("中文文本缺少 word2ph,无法提取 BERT 特征")
|
||||
bert_list.append(None)
|
||||
continue
|
||||
bert_list.append(
|
||||
self.get_bert_inf(
|
||||
segment.phones,
|
||||
segment.word2ph,
|
||||
segment.norm_text,
|
||||
segment.language,
|
||||
profile=profile,
|
||||
)
|
||||
)
|
||||
return {
|
||||
"phones_list": phones_list,
|
||||
"bert_list": bert_list,
|
||||
"norm_text_list": norm_text_list,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _finalize_async_segment_jobs(segment_jobs: Dict[str, List]) -> Tuple[list, torch.Tensor, str]:
|
||||
bert = torch.cat([feature for feature in segment_jobs["bert_list"] if feature is not None], dim=1)
|
||||
phones = sum(segment_jobs["phones_list"], [])
|
||||
norm_text = "".join(segment_jobs["norm_text_list"])
|
||||
return phones, bert, norm_text
|
||||
|
||||
async def build_phones_and_bert_pair_from_segments_async(
|
||||
self,
|
||||
prompt_segments: List[PreparedTextSegment],
|
||||
target_segments: List[PreparedTextSegment],
|
||||
prompt_profile: Dict | None = None,
|
||||
target_profile: Dict | None = None,
|
||||
) -> Tuple[Tuple[list, torch.Tensor, str], Tuple[list, torch.Tensor, str]]:
|
||||
prompt_segments = self.resolve_g2pw_segments(prompt_segments, profile=prompt_profile)
|
||||
target_segments = self.resolve_g2pw_segments(target_segments, profile=target_profile)
|
||||
prompt_jobs = self._build_async_segment_jobs(prompt_segments, prompt_profile)
|
||||
target_jobs = self._build_async_segment_jobs(target_segments, target_profile)
|
||||
pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = []
|
||||
|
||||
for segment_jobs, prepared_segments, profile in (
|
||||
(prompt_jobs, prompt_segments, prompt_profile),
|
||||
(target_jobs, target_segments, target_profile),
|
||||
):
|
||||
for segment_index, segment in enumerate(prepared_segments):
|
||||
if segment.language.replace("all_", "") != "zh" or self.bert_batch_worker is None:
|
||||
continue
|
||||
if segment.word2ph is None:
|
||||
raise ValueError("中文文本缺少 word2ph,无法提取 BERT 特征")
|
||||
pending_items.append(
|
||||
(
|
||||
segment_jobs["bert_list"],
|
||||
segment_index,
|
||||
profile,
|
||||
self.bert_batch_worker.submit_async(segment.norm_text, segment.word2ph),
|
||||
)
|
||||
)
|
||||
|
||||
if pending_items:
|
||||
pending_results = await asyncio.gather(*[future for _, _, _, future in pending_items])
|
||||
for (bert_list, bert_index, profile, _), (feature, worker_profile) in zip(pending_items, pending_results):
|
||||
self._merge_bert_worker_profile(profile, worker_profile)
|
||||
bert_list[bert_index] = feature.to(self.device)
|
||||
|
||||
return self._finalize_async_segment_jobs(prompt_jobs), self._finalize_async_segment_jobs(target_jobs)
|
||||
|
||||
def filter_text(self, texts):
|
||||
_text = []
|
||||
if all(text in [None, " ", "\n", ""] for text in texts):
|
||||
@ -236,4 +575,4 @@ class TextPreprocessor:
|
||||
punctuations = "".join(re.escape(p) for p in punctuation)
|
||||
pattern = f"([{punctuations}])([{punctuations}])+"
|
||||
result = re.sub(pattern, r"\1", text)
|
||||
return result
|
||||
return result
|
||||
|
||||
@ -1 +1,11 @@
|
||||
from . import TTS, text_segmentation_method
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
|
||||
__all__ = ["TTS", "TextPreprocessor", "text_segmentation_method", "t2s_scheduler"]
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
if name in __all__:
|
||||
return importlib.import_module(f"{__name__}.{name}")
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
346
GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py
Normal file
346
GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py
Normal file
@ -0,0 +1,346 @@
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Deque, Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class BertFeatureTask:
|
||||
norm_text: str
|
||||
word2ph: List[int]
|
||||
task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
|
||||
created_at: float = field(default_factory=time.perf_counter)
|
||||
enqueued_at: float = 0.0
|
||||
admission_wait_ms: float = 0.0
|
||||
pending_depth_on_enqueue: int = 0
|
||||
done_event: threading.Event = field(default_factory=threading.Event)
|
||||
done_loop: asyncio.AbstractEventLoop | None = None
|
||||
done_future: asyncio.Future | None = None
|
||||
result_feature: torch.Tensor | None = None
|
||||
error: Exception | None = None
|
||||
profile: Dict[str, float] = field(default_factory=dict)
|
||||
|
||||
|
||||
class PrepareBertBatchWorker:
|
||||
def __init__(
|
||||
self,
|
||||
bert_model,
|
||||
tokenizer,
|
||||
device,
|
||||
stage_limiter=None,
|
||||
batch_window_ms: int = 5,
|
||||
max_batch_items: int = 16,
|
||||
max_batch_tokens: int = 4096,
|
||||
max_pending_tasks: int = 0,
|
||||
admission_poll_ms: int = 1,
|
||||
high_pressure_pending_threshold: int = 0,
|
||||
high_pressure_batch_window_ms: int | None = None,
|
||||
high_pressure_max_batch_items: int | None = None,
|
||||
high_pressure_max_batch_tokens: int | None = None,
|
||||
):
|
||||
self.bert_model = bert_model
|
||||
self.tokenizer = tokenizer
|
||||
self.device = device
|
||||
self.stage_limiter = stage_limiter
|
||||
self.batch_window_ms = max(0, int(batch_window_ms))
|
||||
self.batch_window_s = float(self.batch_window_ms) / 1000.0
|
||||
self.max_batch_items = max(1, int(max_batch_items))
|
||||
self.max_batch_tokens = max(16, int(max_batch_tokens))
|
||||
self.max_pending_tasks = max(0, int(max_pending_tasks))
|
||||
self.admission_poll_s = max(0.0005, float(max(1, int(admission_poll_ms))) / 1000.0)
|
||||
|
||||
self.high_pressure_pending_threshold = max(
|
||||
0,
|
||||
int(high_pressure_pending_threshold)
|
||||
if int(high_pressure_pending_threshold) > 0
|
||||
else max(self.max_batch_items * 2, 32),
|
||||
)
|
||||
hp_window_ms = self.batch_window_ms if high_pressure_batch_window_ms is None else int(high_pressure_batch_window_ms)
|
||||
hp_items = self.max_batch_items if high_pressure_max_batch_items is None else int(high_pressure_max_batch_items)
|
||||
hp_tokens = self.max_batch_tokens if high_pressure_max_batch_tokens is None else int(high_pressure_max_batch_tokens)
|
||||
self.high_pressure_batch_window_ms = max(0, hp_window_ms)
|
||||
self.high_pressure_batch_window_s = float(self.high_pressure_batch_window_ms) / 1000.0
|
||||
self.high_pressure_max_batch_items = max(self.max_batch_items, hp_items)
|
||||
self.high_pressure_max_batch_tokens = max(self.max_batch_tokens, hp_tokens)
|
||||
|
||||
self.condition = threading.Condition()
|
||||
self.pending_tasks: Deque[BertFeatureTask] = deque()
|
||||
self.pending_peak = 0
|
||||
self.total_submitted = 0
|
||||
self.total_finished = 0
|
||||
self.total_batches = 0
|
||||
self.active_batch_size = 0
|
||||
self.active_batch_peak = 0
|
||||
self.active_batch_tokens = 0
|
||||
self.active_batch_tokens_peak = 0
|
||||
self.high_pressure_batches = 0
|
||||
self.admission_wait_total_ms = 0.0
|
||||
self.admission_wait_peak_ms = 0.0
|
||||
self.worker_thread = threading.Thread(target=self._run_loop, name="prepare-bert-batch-worker", daemon=True)
|
||||
self.worker_thread.start()
|
||||
|
||||
def _estimate_task_tokens(self, task: BertFeatureTask) -> int:
|
||||
return max(1, len(task.norm_text) + 2)
|
||||
|
||||
def _can_enqueue_locked(self) -> bool:
|
||||
if self.max_pending_tasks <= 0:
|
||||
return True
|
||||
return (len(self.pending_tasks) + self.active_batch_size) < self.max_pending_tasks
|
||||
|
||||
def _record_enqueue_locked(self, task: BertFeatureTask, admission_wait_ms: float) -> None:
|
||||
task.admission_wait_ms = float(max(0.0, admission_wait_ms))
|
||||
task.enqueued_at = time.perf_counter()
|
||||
task.pending_depth_on_enqueue = int(len(self.pending_tasks))
|
||||
self.pending_tasks.append(task)
|
||||
self.total_submitted += 1
|
||||
self.admission_wait_total_ms += task.admission_wait_ms
|
||||
self.admission_wait_peak_ms = max(self.admission_wait_peak_ms, task.admission_wait_ms)
|
||||
if len(self.pending_tasks) > self.pending_peak:
|
||||
self.pending_peak = len(self.pending_tasks)
|
||||
self.condition.notify_all()
|
||||
|
||||
def _enqueue_task(self, task: BertFeatureTask) -> None:
|
||||
admission_started = time.perf_counter()
|
||||
with self.condition:
|
||||
while not self._can_enqueue_locked():
|
||||
self.condition.wait(timeout=self.admission_poll_s)
|
||||
self._record_enqueue_locked(task, (time.perf_counter() - admission_started) * 1000.0)
|
||||
|
||||
async def _enqueue_task_async(self, task: BertFeatureTask) -> None:
|
||||
admission_started = time.perf_counter()
|
||||
while True:
|
||||
with self.condition:
|
||||
if self._can_enqueue_locked():
|
||||
self._record_enqueue_locked(task, (time.perf_counter() - admission_started) * 1000.0)
|
||||
return
|
||||
await asyncio.sleep(self.admission_poll_s)
|
||||
|
||||
def submit(self, norm_text: str, word2ph: List[int]) -> Tuple[torch.Tensor, Dict[str, float]]:
|
||||
task = BertFeatureTask(norm_text=str(norm_text), word2ph=list(word2ph))
|
||||
self._enqueue_task(task)
|
||||
task.done_event.wait()
|
||||
if task.error is not None:
|
||||
raise task.error
|
||||
assert task.result_feature is not None
|
||||
return task.result_feature, dict(task.profile)
|
||||
|
||||
async def submit_async(self, norm_text: str, word2ph: List[int]) -> Tuple[torch.Tensor, Dict[str, float]]:
|
||||
loop = asyncio.get_running_loop()
|
||||
task = BertFeatureTask(
|
||||
norm_text=str(norm_text),
|
||||
word2ph=list(word2ph),
|
||||
done_loop=loop,
|
||||
done_future=loop.create_future(),
|
||||
)
|
||||
await self._enqueue_task_async(task)
|
||||
return await task.done_future
|
||||
|
||||
def snapshot(self) -> Dict[str, int]:
|
||||
with self.condition:
|
||||
return {
|
||||
"pending": len(self.pending_tasks),
|
||||
"pending_peak": self.pending_peak,
|
||||
"total_submitted": self.total_submitted,
|
||||
"total_finished": self.total_finished,
|
||||
"total_batches": self.total_batches,
|
||||
"active_batch_size": self.active_batch_size,
|
||||
"active_batch_peak": self.active_batch_peak,
|
||||
"active_batch_tokens": self.active_batch_tokens,
|
||||
"active_batch_tokens_peak": self.active_batch_tokens_peak,
|
||||
"batch_window_ms": int(self.batch_window_s * 1000.0),
|
||||
"max_batch_items": self.max_batch_items,
|
||||
"max_batch_tokens": self.max_batch_tokens,
|
||||
"max_pending_tasks": self.max_pending_tasks,
|
||||
"high_pressure_pending_threshold": self.high_pressure_pending_threshold,
|
||||
"high_pressure_batch_window_ms": self.high_pressure_batch_window_ms,
|
||||
"high_pressure_max_batch_items": self.high_pressure_max_batch_items,
|
||||
"high_pressure_max_batch_tokens": self.high_pressure_max_batch_tokens,
|
||||
"high_pressure_batches": self.high_pressure_batches,
|
||||
"admission_wait_total_ms": self.admission_wait_total_ms,
|
||||
"admission_wait_peak_ms": self.admission_wait_peak_ms,
|
||||
}
|
||||
|
||||
def _select_batch_policy_locked(self) -> Tuple[float, int, int, bool, int]:
|
||||
pending_depth = len(self.pending_tasks)
|
||||
use_high_pressure = (
|
||||
self.high_pressure_pending_threshold > 0
|
||||
and pending_depth >= self.high_pressure_pending_threshold
|
||||
)
|
||||
if use_high_pressure:
|
||||
return (
|
||||
self.high_pressure_batch_window_s,
|
||||
self.high_pressure_max_batch_items,
|
||||
self.high_pressure_max_batch_tokens,
|
||||
True,
|
||||
pending_depth,
|
||||
)
|
||||
return (
|
||||
self.batch_window_s,
|
||||
self.max_batch_items,
|
||||
self.max_batch_tokens,
|
||||
False,
|
||||
pending_depth,
|
||||
)
|
||||
|
||||
def _collect_batch(self) -> Tuple[List[BertFeatureTask], Dict[str, float]]:
|
||||
with self.condition:
|
||||
while not self.pending_tasks:
|
||||
self.condition.wait()
|
||||
|
||||
collect_started = time.perf_counter()
|
||||
batch_window_s, max_batch_items, max_batch_tokens, use_high_pressure, pending_depth_on_collect = (
|
||||
self._select_batch_policy_locked()
|
||||
)
|
||||
batch: List[BertFeatureTask] = [self.pending_tasks.popleft()]
|
||||
batch_tokens = self._estimate_task_tokens(batch[0])
|
||||
deadline = time.perf_counter() + batch_window_s
|
||||
|
||||
while len(batch) < max_batch_items:
|
||||
remaining = deadline - time.perf_counter()
|
||||
if remaining <= 0:
|
||||
break
|
||||
if not self.pending_tasks:
|
||||
self.condition.wait(timeout=remaining)
|
||||
continue
|
||||
next_task = self.pending_tasks[0]
|
||||
next_tokens = self._estimate_task_tokens(next_task)
|
||||
if len(batch) >= max_batch_items or (batch_tokens + next_tokens) > max_batch_tokens:
|
||||
break
|
||||
batch.append(self.pending_tasks.popleft())
|
||||
batch_tokens += next_tokens
|
||||
|
||||
self.active_batch_size = len(batch)
|
||||
self.active_batch_tokens = batch_tokens
|
||||
if self.active_batch_size > self.active_batch_peak:
|
||||
self.active_batch_peak = self.active_batch_size
|
||||
if self.active_batch_tokens > self.active_batch_tokens_peak:
|
||||
self.active_batch_tokens_peak = self.active_batch_tokens
|
||||
if use_high_pressure:
|
||||
self.high_pressure_batches += 1
|
||||
return batch, {
|
||||
"collect_wait_ms": (time.perf_counter() - collect_started) * 1000.0,
|
||||
"batch_tokens": float(batch_tokens),
|
||||
"pending_depth_on_collect": float(pending_depth_on_collect),
|
||||
"high_pressure_mode": 1.0 if use_high_pressure else 0.0,
|
||||
"batch_window_ms": float(self.high_pressure_batch_window_ms if use_high_pressure else self.batch_window_ms),
|
||||
}
|
||||
|
||||
def _finalize_batch(self, batch: List[BertFeatureTask]) -> None:
|
||||
with self.condition:
|
||||
self.active_batch_size = 0
|
||||
self.active_batch_tokens = 0
|
||||
self.total_batches += 1
|
||||
self.total_finished += len(batch)
|
||||
self.condition.notify_all()
|
||||
|
||||
def _run_batch(self, batch: List[BertFeatureTask], batch_meta: Dict[str, float]) -> None:
|
||||
batch_started = time.perf_counter()
|
||||
texts = [task.norm_text for task in batch]
|
||||
batch_tokens = int(batch_meta["batch_tokens"])
|
||||
|
||||
limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0}
|
||||
if self.stage_limiter is None:
|
||||
tokenize_start = time.perf_counter()
|
||||
inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
|
||||
tokenize_ms = (time.perf_counter() - tokenize_start) * 1000.0
|
||||
attention_mask_cpu = inputs["attention_mask"].cpu()
|
||||
for key in inputs:
|
||||
inputs[key] = inputs[key].to(self.device)
|
||||
forward_start = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
outputs = self.bert_model(**inputs, output_hidden_states=True)
|
||||
forward_ms = (time.perf_counter() - forward_start) * 1000.0
|
||||
else:
|
||||
with self.stage_limiter.enter() as limiter_stats:
|
||||
tokenize_start = time.perf_counter()
|
||||
inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
|
||||
tokenize_ms = (time.perf_counter() - tokenize_start) * 1000.0
|
||||
attention_mask_cpu = inputs["attention_mask"].cpu()
|
||||
for key in inputs:
|
||||
inputs[key] = inputs[key].to(self.device)
|
||||
forward_start = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
outputs = self.bert_model(**inputs, output_hidden_states=True)
|
||||
forward_ms = (time.perf_counter() - forward_start) * 1000.0
|
||||
|
||||
hidden = outputs["hidden_states"][-3].detach().cpu()
|
||||
scatter_start = time.perf_counter()
|
||||
for batch_index, task in enumerate(batch):
|
||||
try:
|
||||
text_len = len(task.word2ph)
|
||||
if text_len != len(task.norm_text):
|
||||
raise AssertionError(
|
||||
f"word2ph/text length mismatch: task={task.task_id} word2ph={text_len} text={len(task.norm_text)}"
|
||||
)
|
||||
seq_len = int(attention_mask_cpu[batch_index].sum().item())
|
||||
char_features = hidden[batch_index, 1 : seq_len - 1]
|
||||
if char_features.shape[0] != text_len:
|
||||
raise AssertionError(
|
||||
f"bert token length mismatch: task={task.task_id} token_len={char_features.shape[0]} text_len={text_len}"
|
||||
)
|
||||
phone_level_feature = []
|
||||
for char_index, repeat_count in enumerate(task.word2ph):
|
||||
phone_level_feature.append(char_features[char_index].repeat(repeat_count, 1))
|
||||
task.result_feature = torch.cat(phone_level_feature, dim=0).T
|
||||
task.profile = {
|
||||
"bert_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]),
|
||||
"bert_admission_wait_ms": float(task.admission_wait_ms),
|
||||
"bert_queue_wait_ms": max(0.0, (batch_started - task.enqueued_at) * 1000.0),
|
||||
"bert_batch_collect_wait_ms": float(batch_meta["collect_wait_ms"]),
|
||||
"bert_forward_ms": float(forward_ms),
|
||||
"bert_tokenize_ms": float(tokenize_ms),
|
||||
"bert_scatter_ms": 0.0,
|
||||
"bert_calls": 1.0,
|
||||
"bert_stage_slots": float(limiter_stats["slots"]),
|
||||
"bert_stage_inflight_peak": float(limiter_stats["peak_inflight"]),
|
||||
"bert_batch_size": float(len(batch)),
|
||||
"bert_batch_tokens": float(batch_tokens),
|
||||
"bert_pending_depth_on_enqueue": float(task.pending_depth_on_enqueue),
|
||||
"bert_pending_depth_on_collect": float(batch_meta["pending_depth_on_collect"]),
|
||||
"bert_high_pressure_mode": float(batch_meta["high_pressure_mode"]),
|
||||
"bert_batch_window_ms": float(batch_meta["batch_window_ms"]),
|
||||
}
|
||||
except Exception as exc: # noqa: PERF203
|
||||
task.error = exc
|
||||
scatter_ms = (time.perf_counter() - scatter_start) * 1000.0
|
||||
for task in batch:
|
||||
if task.result_feature is not None:
|
||||
task.profile["bert_scatter_ms"] = float(scatter_ms)
|
||||
task.done_event.set()
|
||||
self._notify_done_future(task)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_done_future(task: BertFeatureTask) -> None:
|
||||
if task.done_future is None or task.done_future.done():
|
||||
return
|
||||
if task.error is not None:
|
||||
task.done_future.set_exception(task.error)
|
||||
return
|
||||
assert task.result_feature is not None
|
||||
task.done_future.set_result((task.result_feature, dict(task.profile)))
|
||||
|
||||
def _notify_done_future(self, task: BertFeatureTask) -> None:
|
||||
if task.done_loop is None or task.done_future is None:
|
||||
return
|
||||
try:
|
||||
task.done_loop.call_soon_threadsafe(self._resolve_done_future, task)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
def _run_loop(self) -> None:
|
||||
while True:
|
||||
batch, batch_meta = self._collect_batch()
|
||||
try:
|
||||
self._run_batch(batch, batch_meta)
|
||||
except Exception as exc: # noqa: PERF203
|
||||
for task in batch:
|
||||
task.error = exc
|
||||
task.done_event.set()
|
||||
self._notify_done_future(task)
|
||||
finally:
|
||||
self._finalize_batch(batch)
|
||||
1066
GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py
Normal file
1066
GPT_SoVITS/TTS_infer_pack/prepare_coordinator.py
Normal file
File diff suppressed because it is too large
Load Diff
382
GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py
Normal file
382
GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py
Normal file
@ -0,0 +1,382 @@
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Deque, Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
|
||||
REF_AUDIO_MIN_SAMPLES_16K = 48000
|
||||
REF_AUDIO_MAX_SAMPLES_16K = 160000
|
||||
_RESAMPLE_CACHE_LOCK = threading.Lock()
|
||||
_RESAMPLE_CACHE: Dict[Tuple[int, int, str], torchaudio.transforms.Resample] = {}
|
||||
_RESAMPLE_STREAM_CACHE: Dict[str, torch.cuda.Stream] = {}
|
||||
|
||||
|
||||
def _get_resampler(orig_sr: int, target_sr: int, device: str) -> torchaudio.transforms.Resample:
|
||||
device_key = str(device)
|
||||
key = (int(orig_sr), int(target_sr), device_key)
|
||||
with _RESAMPLE_CACHE_LOCK:
|
||||
transform = _RESAMPLE_CACHE.get(key)
|
||||
if transform is None:
|
||||
transform = torchaudio.transforms.Resample(orig_freq=int(orig_sr), new_freq=int(target_sr)).to(device_key)
|
||||
_RESAMPLE_CACHE[key] = transform
|
||||
return transform
|
||||
|
||||
|
||||
def _get_resample_stream(device: str) -> torch.cuda.Stream:
|
||||
device_key = str(device)
|
||||
with _RESAMPLE_CACHE_LOCK:
|
||||
stream = _RESAMPLE_STREAM_CACHE.get(device_key)
|
||||
if stream is None:
|
||||
stream = torch.cuda.Stream(device=device_key)
|
||||
_RESAMPLE_STREAM_CACHE[device_key] = stream
|
||||
return stream
|
||||
|
||||
|
||||
def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wav_samples: int) -> torch.Tensor:
|
||||
resample_device = os.environ.get("GPTSOVITS_PREPARE_REF_RESAMPLE_DEVICE", "cpu").strip().lower() or "cpu"
|
||||
if resample_device not in {"cpu", "cuda"}:
|
||||
resample_device = "cpu"
|
||||
if resample_device == "cuda" and not torch.cuda.is_available():
|
||||
resample_device = "cpu"
|
||||
wav_mono = raw_audio
|
||||
if wav_mono.dim() == 2 and wav_mono.shape[0] != 1:
|
||||
wav_mono = wav_mono.mean(0, keepdim=True)
|
||||
if resample_device == "cuda":
|
||||
stream = _get_resample_stream(resample_device)
|
||||
with torch.cuda.stream(stream):
|
||||
wav16k = wav_mono.to(dtype=torch.float32, device=resample_device)
|
||||
if raw_sr != 16000:
|
||||
wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k)
|
||||
wav16k = wav16k.squeeze(0).contiguous()
|
||||
stream.synchronize()
|
||||
wav16k = wav16k.detach().to(device="cpu", dtype=torch.float32).contiguous()
|
||||
else:
|
||||
wav16k = wav_mono.to(dtype=torch.float32, device=resample_device)
|
||||
if raw_sr != 16000:
|
||||
wav16k = _get_resampler(int(raw_sr), 16000, resample_device)(wav16k)
|
||||
wav16k = wav16k.squeeze(0).contiguous()
|
||||
if wav16k.shape[0] > REF_AUDIO_MAX_SAMPLES_16K or wav16k.shape[0] < REF_AUDIO_MIN_SAMPLES_16K:
|
||||
raise OSError("参考音频在3~10秒范围外,请更换!")
|
||||
if zero_wav_samples > 0:
|
||||
wav16k = torch.cat(
|
||||
[wav16k, torch.zeros(int(zero_wav_samples), dtype=torch.float32, device=wav16k.device)],
|
||||
dim=0,
|
||||
)
|
||||
return wav16k.contiguous()
|
||||
|
||||
|
||||
def conv1d_output_lengths(input_lengths: torch.Tensor, conv1d: torch.nn.Conv1d | None) -> torch.Tensor:
|
||||
if conv1d is None:
|
||||
return input_lengths.to(dtype=torch.long)
|
||||
kernel_size = int(conv1d.kernel_size[0])
|
||||
stride = int(conv1d.stride[0])
|
||||
padding = int(conv1d.padding[0])
|
||||
dilation = int(conv1d.dilation[0])
|
||||
output_lengths = torch.div(
|
||||
input_lengths + 2 * padding - dilation * (kernel_size - 1) - 1,
|
||||
stride,
|
||||
rounding_mode="floor",
|
||||
) + 1
|
||||
return torch.clamp(output_lengths, min=0).to(dtype=torch.long)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RefSemanticTask:
|
||||
raw_audio: torch.Tensor
|
||||
raw_sr: int
|
||||
task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
|
||||
created_at: float = field(default_factory=time.perf_counter)
|
||||
batch_popped_at: float = 0.0
|
||||
done_event: threading.Event = field(default_factory=threading.Event)
|
||||
done_loop: asyncio.AbstractEventLoop | None = None
|
||||
done_future: asyncio.Future | None = None
|
||||
result_prompt_semantic: torch.Tensor | None = None
|
||||
error: Exception | None = None
|
||||
profile: Dict[str, float] = field(default_factory=dict)
|
||||
|
||||
|
||||
class PrepareRefSemanticBatchWorker:
|
||||
def __init__(
|
||||
self,
|
||||
ssl_model,
|
||||
vits_model,
|
||||
device,
|
||||
is_half: bool,
|
||||
zero_wav_samples: int,
|
||||
stage_limiter=None,
|
||||
batch_window_ms: int = 5,
|
||||
max_batch_items: int = 8,
|
||||
max_batch_samples: int = 960000,
|
||||
):
|
||||
self.ssl_model = ssl_model
|
||||
self.vits_model = vits_model
|
||||
self.device = device
|
||||
self.is_half = bool(is_half)
|
||||
self.zero_wav_samples = max(0, int(zero_wav_samples))
|
||||
self.stage_limiter = stage_limiter
|
||||
self.batch_window_s = max(0.0, float(batch_window_ms) / 1000.0)
|
||||
self.max_batch_items = max(1, int(max_batch_items))
|
||||
self.max_batch_samples = max(REF_AUDIO_MIN_SAMPLES_16K + self.zero_wav_samples, int(max_batch_samples))
|
||||
|
||||
self.condition = threading.Condition()
|
||||
self.pending_tasks: Deque[RefSemanticTask] = deque()
|
||||
self.pending_peak = 0
|
||||
self.total_submitted = 0
|
||||
self.total_finished = 0
|
||||
self.total_batches = 0
|
||||
self.active_batch_size = 0
|
||||
self.active_batch_peak = 0
|
||||
self.active_batch_samples = 0
|
||||
self.active_batch_samples_peak = 0
|
||||
self.worker_thread = threading.Thread(
|
||||
target=self._run_loop,
|
||||
name="prepare-ref-semantic-batch-worker",
|
||||
daemon=True,
|
||||
)
|
||||
self.worker_thread.start()
|
||||
|
||||
def _estimate_task_samples(self, task: RefSemanticTask) -> int:
|
||||
raw_len = int(task.raw_audio.shape[-1]) if task.raw_audio.dim() > 0 else 0
|
||||
base = int(round(raw_len * 16000.0 / max(1, int(task.raw_sr))))
|
||||
return max(REF_AUDIO_MIN_SAMPLES_16K, base) + self.zero_wav_samples
|
||||
|
||||
def submit(self, raw_audio: torch.Tensor, raw_sr: int) -> Tuple[torch.Tensor, Dict[str, float]]:
|
||||
task = RefSemanticTask(raw_audio=raw_audio, raw_sr=int(raw_sr))
|
||||
with self.condition:
|
||||
self.pending_tasks.append(task)
|
||||
self.total_submitted += 1
|
||||
if len(self.pending_tasks) > self.pending_peak:
|
||||
self.pending_peak = len(self.pending_tasks)
|
||||
self.condition.notify_all()
|
||||
task.done_event.wait()
|
||||
if task.error is not None:
|
||||
raise task.error
|
||||
assert task.result_prompt_semantic is not None
|
||||
return task.result_prompt_semantic, dict(task.profile)
|
||||
|
||||
async def submit_async(self, raw_audio: torch.Tensor, raw_sr: int) -> Tuple[torch.Tensor, Dict[str, float]]:
|
||||
loop = asyncio.get_running_loop()
|
||||
task = RefSemanticTask(
|
||||
raw_audio=raw_audio,
|
||||
raw_sr=int(raw_sr),
|
||||
done_loop=loop,
|
||||
done_future=loop.create_future(),
|
||||
)
|
||||
with self.condition:
|
||||
self.pending_tasks.append(task)
|
||||
self.total_submitted += 1
|
||||
if len(self.pending_tasks) > self.pending_peak:
|
||||
self.pending_peak = len(self.pending_tasks)
|
||||
self.condition.notify_all()
|
||||
return await task.done_future
|
||||
|
||||
@staticmethod
|
||||
def _resolve_done_future(task: RefSemanticTask) -> None:
|
||||
if task.done_future is None or task.done_future.done():
|
||||
return
|
||||
if task.error is not None:
|
||||
task.done_future.set_exception(task.error)
|
||||
return
|
||||
assert task.result_prompt_semantic is not None
|
||||
task.done_future.set_result((task.result_prompt_semantic, dict(task.profile)))
|
||||
|
||||
def _notify_task_done(self, task: RefSemanticTask) -> None:
|
||||
task.done_event.set()
|
||||
if task.done_loop is None or task.done_future is None:
|
||||
return
|
||||
try:
|
||||
task.done_loop.call_soon_threadsafe(self._resolve_done_future, task)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
def snapshot(self) -> Dict[str, int]:
|
||||
with self.condition:
|
||||
return {
|
||||
"pending": len(self.pending_tasks),
|
||||
"pending_peak": self.pending_peak,
|
||||
"total_submitted": self.total_submitted,
|
||||
"total_finished": self.total_finished,
|
||||
"total_batches": self.total_batches,
|
||||
"active_batch_size": self.active_batch_size,
|
||||
"active_batch_peak": self.active_batch_peak,
|
||||
"active_batch_samples": self.active_batch_samples,
|
||||
"active_batch_samples_peak": self.active_batch_samples_peak,
|
||||
"batch_window_ms": int(self.batch_window_s * 1000.0),
|
||||
"max_batch_items": self.max_batch_items,
|
||||
"max_batch_samples": self.max_batch_samples,
|
||||
}
|
||||
|
||||
def _collect_batch(self) -> tuple[List[RefSemanticTask], float]:
|
||||
with self.condition:
|
||||
while not self.pending_tasks:
|
||||
self.condition.wait()
|
||||
|
||||
first_task = self.pending_tasks.popleft()
|
||||
first_task.batch_popped_at = time.perf_counter()
|
||||
batch: List[RefSemanticTask] = [first_task]
|
||||
batch_samples = self._estimate_task_samples(batch[0])
|
||||
deadline = time.perf_counter() + self.batch_window_s
|
||||
|
||||
while len(batch) < self.max_batch_items:
|
||||
remaining = deadline - time.perf_counter()
|
||||
if remaining <= 0:
|
||||
break
|
||||
if not self.pending_tasks:
|
||||
self.condition.wait(timeout=remaining)
|
||||
continue
|
||||
next_task = self.pending_tasks[0]
|
||||
next_samples = self._estimate_task_samples(next_task)
|
||||
if len(batch) >= self.max_batch_items or (batch_samples + next_samples) > self.max_batch_samples:
|
||||
break
|
||||
popped_task = self.pending_tasks.popleft()
|
||||
popped_task.batch_popped_at = time.perf_counter()
|
||||
batch.append(popped_task)
|
||||
batch_samples += next_samples
|
||||
|
||||
self.active_batch_size = len(batch)
|
||||
self.active_batch_samples = batch_samples
|
||||
if self.active_batch_size > self.active_batch_peak:
|
||||
self.active_batch_peak = self.active_batch_size
|
||||
if self.active_batch_samples > self.active_batch_samples_peak:
|
||||
self.active_batch_samples_peak = self.active_batch_samples
|
||||
return batch, time.perf_counter()
|
||||
|
||||
def _finalize_batch(self, batch: List[RefSemanticTask]) -> None:
|
||||
with self.condition:
|
||||
self.active_batch_size = 0
|
||||
self.active_batch_samples = 0
|
||||
self.total_batches += 1
|
||||
self.total_finished += len(batch)
|
||||
|
||||
def _get_hidden_lengths(self, attention_mask: torch.Tensor, hidden_length: int) -> torch.Tensor:
|
||||
model = self.ssl_model.model
|
||||
if hasattr(model, "_get_feature_vector_attention_mask"):
|
||||
feature_mask = model._get_feature_vector_attention_mask(hidden_length, attention_mask)
|
||||
return feature_mask.to(dtype=torch.long).sum(dim=1)
|
||||
raw_lengths = attention_mask.to(dtype=torch.long).sum(dim=1)
|
||||
if hasattr(model, "_get_feat_extract_output_lengths"):
|
||||
return model._get_feat_extract_output_lengths(raw_lengths).to(dtype=torch.long)
|
||||
return torch.full((attention_mask.shape[0],), int(hidden_length), dtype=torch.long, device=attention_mask.device)
|
||||
|
||||
@torch.inference_mode()
|
||||
def _run_batch(self, batch: List[RefSemanticTask], batch_collected_at: float) -> None:
|
||||
batch_started = time.perf_counter()
|
||||
prepared_start = time.perf_counter()
|
||||
prepared_wavs = [
|
||||
prepare_prompt_semantic_wav16k(task.raw_audio, int(task.raw_sr), self.zero_wav_samples) for task in batch
|
||||
]
|
||||
cpu_prepare_ms = (time.perf_counter() - prepared_start) * 1000.0
|
||||
wav_lengths = torch.tensor([int(wav.shape[0]) for wav in prepared_wavs], dtype=torch.long)
|
||||
batch_samples = int(wav_lengths.sum().item())
|
||||
max_wav_len = int(wav_lengths.max().item())
|
||||
|
||||
pack_start = time.perf_counter()
|
||||
input_values_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.float32)
|
||||
attention_mask_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.long)
|
||||
for batch_index, wav in enumerate(prepared_wavs):
|
||||
wav_len = int(wav.shape[0])
|
||||
input_values_cpu[batch_index, :wav_len] = wav
|
||||
attention_mask_cpu[batch_index, :wav_len] = 1
|
||||
pack_ms = (time.perf_counter() - pack_start) * 1000.0
|
||||
|
||||
limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0}
|
||||
h2d_ms = 0.0
|
||||
ssl_forward_ms = 0.0
|
||||
hidden_length_ms = 0.0
|
||||
extract_latent_ms = 0.0
|
||||
if self.stage_limiter is None:
|
||||
h2d_start = time.perf_counter()
|
||||
input_values = input_values_cpu.to(self.device)
|
||||
attention_mask = attention_mask_cpu.to(self.device)
|
||||
if self.is_half:
|
||||
input_values = input_values.half()
|
||||
h2d_ms = (time.perf_counter() - h2d_start) * 1000.0
|
||||
ssl_start = time.perf_counter()
|
||||
outputs = self.ssl_model.model(input_values, attention_mask=attention_mask)
|
||||
ssl_forward_ms = (time.perf_counter() - ssl_start) * 1000.0
|
||||
hubert_feature = outputs["last_hidden_state"].transpose(1, 2)
|
||||
hidden_length_start = time.perf_counter()
|
||||
hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1]))
|
||||
hidden_length_ms = (time.perf_counter() - hidden_length_start) * 1000.0
|
||||
latent_start = time.perf_counter()
|
||||
codes = self.vits_model.extract_latent(hubert_feature)
|
||||
extract_latent_ms = (time.perf_counter() - latent_start) * 1000.0
|
||||
else:
|
||||
with self.stage_limiter.enter() as limiter_stats:
|
||||
h2d_start = time.perf_counter()
|
||||
input_values = input_values_cpu.to(self.device)
|
||||
attention_mask = attention_mask_cpu.to(self.device)
|
||||
if self.is_half:
|
||||
input_values = input_values.half()
|
||||
h2d_ms = (time.perf_counter() - h2d_start) * 1000.0
|
||||
ssl_start = time.perf_counter()
|
||||
outputs = self.ssl_model.model(input_values, attention_mask=attention_mask)
|
||||
ssl_forward_ms = (time.perf_counter() - ssl_start) * 1000.0
|
||||
hubert_feature = outputs["last_hidden_state"].transpose(1, 2)
|
||||
hidden_length_start = time.perf_counter()
|
||||
hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1]))
|
||||
hidden_length_ms = (time.perf_counter() - hidden_length_start) * 1000.0
|
||||
latent_start = time.perf_counter()
|
||||
codes = self.vits_model.extract_latent(hubert_feature)
|
||||
extract_latent_ms = (time.perf_counter() - latent_start) * 1000.0
|
||||
forward_ms = float(h2d_ms + ssl_forward_ms + hidden_length_ms + extract_latent_ms)
|
||||
|
||||
code_lengths = conv1d_output_lengths(hidden_lengths.detach().cpu(), getattr(self.vits_model, "ssl_proj", None))
|
||||
scatter_start = time.perf_counter()
|
||||
for batch_index, task in enumerate(batch):
|
||||
try:
|
||||
code_len = int(code_lengths[batch_index].item())
|
||||
task.result_prompt_semantic = codes[batch_index, 0, :code_len].detach().clone()
|
||||
worker_queue_wait_ms = max(0.0, (float(task.batch_popped_at) - float(task.created_at)) * 1000.0)
|
||||
batch_collect_wait_ms = max(0.0, (float(batch_collected_at) - float(task.batch_popped_at)) * 1000.0)
|
||||
stage_limiter_wait_ms = float(limiter_stats["wait_ms"])
|
||||
task.profile = {
|
||||
"prompt_semantic_wait_ms": worker_queue_wait_ms
|
||||
+ batch_collect_wait_ms
|
||||
+ stage_limiter_wait_ms,
|
||||
"prompt_semantic_worker_queue_wait_ms": worker_queue_wait_ms,
|
||||
"prompt_semantic_batch_collect_wait_ms": batch_collect_wait_ms,
|
||||
"prompt_semantic_stage_limiter_wait_ms": stage_limiter_wait_ms,
|
||||
"prompt_semantic_batch_dispatch_delay_ms": max(
|
||||
0.0, (float(batch_started) - float(batch_collected_at)) * 1000.0
|
||||
),
|
||||
"prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms),
|
||||
"prompt_semantic_pack_ms": float(pack_ms),
|
||||
"prompt_semantic_h2d_ms": float(h2d_ms),
|
||||
"prompt_semantic_ssl_forward_ms": float(ssl_forward_ms),
|
||||
"prompt_semantic_hidden_length_ms": float(hidden_length_ms),
|
||||
"prompt_semantic_extract_latent_ms": float(extract_latent_ms),
|
||||
"prompt_semantic_forward_ms": float(forward_ms),
|
||||
"prompt_semantic_scatter_ms": 0.0,
|
||||
"prompt_semantic_calls": 1.0,
|
||||
"prompt_semantic_stage_slots": float(limiter_stats["slots"]),
|
||||
"prompt_semantic_stage_inflight_peak": float(limiter_stats["peak_inflight"]),
|
||||
"prompt_semantic_batch_size": float(len(batch)),
|
||||
"prompt_semantic_batch_samples": float(batch_samples),
|
||||
}
|
||||
except Exception as exc: # noqa: PERF203
|
||||
task.error = exc
|
||||
scatter_ms = (time.perf_counter() - scatter_start) * 1000.0
|
||||
for task in batch:
|
||||
if task.result_prompt_semantic is not None:
|
||||
task.profile["prompt_semantic_scatter_ms"] = float(scatter_ms)
|
||||
self._notify_task_done(task)
|
||||
|
||||
def _run_loop(self) -> None:
|
||||
while True:
|
||||
batch, batch_collected_at = self._collect_batch()
|
||||
try:
|
||||
self._run_batch(batch, batch_collected_at)
|
||||
except Exception as exc: # noqa: PERF203
|
||||
for task in batch:
|
||||
task.error = exc
|
||||
self._notify_task_done(task)
|
||||
finally:
|
||||
self._finalize_batch(batch)
|
||||
215
GPT_SoVITS/TTS_infer_pack/prepare_text_cpu_worker.py
Normal file
215
GPT_SoVITS/TTS_infer_pack/prepare_text_cpu_worker.py
Normal file
@ -0,0 +1,215 @@
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Deque, Dict, Tuple
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextCpuTask:
|
||||
text: str
|
||||
language: str
|
||||
task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
|
||||
created_at: float = field(default_factory=time.perf_counter)
|
||||
enqueued_at: float = 0.0
|
||||
admission_wait_ms: float = 0.0
|
||||
backpressure_wait_ms: float = 0.0
|
||||
capacity_wait_ms: float = 0.0
|
||||
pending_depth_on_enqueue: int = 0
|
||||
done_event: threading.Event = field(default_factory=threading.Event)
|
||||
done_loop: asyncio.AbstractEventLoop | None = None
|
||||
done_future: asyncio.Future | None = None
|
||||
result: Any = None
|
||||
error: Exception | None = None
|
||||
profile: Dict[str, float] = field(default_factory=dict)
|
||||
|
||||
|
||||
class PrepareTextCpuWorker:
|
||||
def __init__(
|
||||
self,
|
||||
process_fn: Callable[[str, str], Any],
|
||||
worker_count: int,
|
||||
max_pending_tasks: int = 0,
|
||||
admission_poll_ms: int = 1,
|
||||
admission_controller: Callable[[], Dict[str, float | int | bool]] | None = None,
|
||||
) -> None:
|
||||
self.process_fn = process_fn
|
||||
self.worker_count = max(1, int(worker_count))
|
||||
self.max_pending_tasks = max(0, int(max_pending_tasks))
|
||||
self.admission_poll_s = max(0.0005, float(max(1, int(admission_poll_ms))) / 1000.0)
|
||||
self.admission_controller = admission_controller
|
||||
|
||||
self.condition = threading.Condition()
|
||||
self.pending_tasks: Deque[TextCpuTask] = deque()
|
||||
self.pending_peak = 0
|
||||
self.total_submitted = 0
|
||||
self.total_finished = 0
|
||||
self.active_workers = 0
|
||||
self.active_workers_peak = 0
|
||||
self.admission_wait_total_ms = 0.0
|
||||
self.admission_wait_peak_ms = 0.0
|
||||
self.backpressure_wait_total_ms = 0.0
|
||||
self.backpressure_wait_peak_ms = 0.0
|
||||
self.capacity_wait_total_ms = 0.0
|
||||
self.capacity_wait_peak_ms = 0.0
|
||||
self.backpressure_blocked_total = 0
|
||||
|
||||
self.worker_threads = [
|
||||
threading.Thread(target=self._run_loop, name=f"prepare-text-cpu-worker-{index}", daemon=True)
|
||||
for index in range(self.worker_count)
|
||||
]
|
||||
for thread in self.worker_threads:
|
||||
thread.start()
|
||||
|
||||
def _can_enqueue_locked(self) -> bool:
|
||||
if self.max_pending_tasks <= 0:
|
||||
return True
|
||||
return (len(self.pending_tasks) + self.active_workers) < self.max_pending_tasks
|
||||
|
||||
def _get_admission_state(self) -> Dict[str, float | int | bool]:
|
||||
if self.admission_controller is None:
|
||||
return {"blocked": False}
|
||||
try:
|
||||
state = dict(self.admission_controller() or {})
|
||||
except Exception:
|
||||
return {"blocked": False}
|
||||
state["blocked"] = bool(state.get("blocked", False))
|
||||
return state
|
||||
|
||||
def _record_enqueue_locked(
|
||||
self,
|
||||
task: TextCpuTask,
|
||||
*,
|
||||
admission_wait_ms: float,
|
||||
backpressure_wait_ms: float,
|
||||
capacity_wait_ms: float,
|
||||
) -> None:
|
||||
task.admission_wait_ms = float(max(0.0, admission_wait_ms))
|
||||
task.backpressure_wait_ms = float(max(0.0, backpressure_wait_ms))
|
||||
task.capacity_wait_ms = float(max(0.0, capacity_wait_ms))
|
||||
task.enqueued_at = time.perf_counter()
|
||||
task.pending_depth_on_enqueue = int(len(self.pending_tasks))
|
||||
self.pending_tasks.append(task)
|
||||
self.total_submitted += 1
|
||||
self.admission_wait_total_ms += task.admission_wait_ms
|
||||
self.admission_wait_peak_ms = max(self.admission_wait_peak_ms, task.admission_wait_ms)
|
||||
self.backpressure_wait_total_ms += task.backpressure_wait_ms
|
||||
self.backpressure_wait_peak_ms = max(self.backpressure_wait_peak_ms, task.backpressure_wait_ms)
|
||||
self.capacity_wait_total_ms += task.capacity_wait_ms
|
||||
self.capacity_wait_peak_ms = max(self.capacity_wait_peak_ms, task.capacity_wait_ms)
|
||||
if task.backpressure_wait_ms > 0.0:
|
||||
self.backpressure_blocked_total += 1
|
||||
if len(self.pending_tasks) > self.pending_peak:
|
||||
self.pending_peak = len(self.pending_tasks)
|
||||
self.condition.notify_all()
|
||||
|
||||
async def _enqueue_task_async(self, task: TextCpuTask) -> None:
|
||||
admission_started = time.perf_counter()
|
||||
backpressure_wait_ms = 0.0
|
||||
capacity_wait_ms = 0.0
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
admission_state = self._get_admission_state()
|
||||
blocked = bool(admission_state.get("blocked", False))
|
||||
with self.condition:
|
||||
if not blocked and self._can_enqueue_locked():
|
||||
self._record_enqueue_locked(
|
||||
task,
|
||||
admission_wait_ms=(time.perf_counter() - admission_started) * 1000.0,
|
||||
backpressure_wait_ms=backpressure_wait_ms,
|
||||
capacity_wait_ms=capacity_wait_ms,
|
||||
)
|
||||
return
|
||||
await asyncio.sleep(self.admission_poll_s)
|
||||
waited_ms = (time.perf_counter() - loop_start) * 1000.0
|
||||
if blocked:
|
||||
backpressure_wait_ms += waited_ms
|
||||
else:
|
||||
capacity_wait_ms += waited_ms
|
||||
|
||||
def submit(self, text: str, language: str) -> Tuple[Any, Dict[str, float]]:
|
||||
task = TextCpuTask(text=str(text), language=str(language))
|
||||
asyncio.run(self._enqueue_task_async(task))
|
||||
task.done_event.wait()
|
||||
if task.error is not None:
|
||||
raise task.error
|
||||
return task.result, dict(task.profile)
|
||||
|
||||
async def submit_async(self, text: str, language: str) -> Tuple[Any, Dict[str, float]]:
|
||||
loop = asyncio.get_running_loop()
|
||||
task = TextCpuTask(
|
||||
text=str(text),
|
||||
language=str(language),
|
||||
done_loop=loop,
|
||||
done_future=loop.create_future(),
|
||||
)
|
||||
await self._enqueue_task_async(task)
|
||||
return await task.done_future
|
||||
|
||||
@staticmethod
|
||||
def _resolve_done_future(task: TextCpuTask) -> None:
|
||||
if task.done_future is None or task.done_future.done():
|
||||
return
|
||||
if task.error is not None:
|
||||
task.done_future.set_exception(task.error)
|
||||
return
|
||||
task.done_future.set_result((task.result, dict(task.profile)))
|
||||
|
||||
def _notify_task_done(self, task: TextCpuTask) -> None:
|
||||
task.done_event.set()
|
||||
if task.done_loop is None or task.done_future is None:
|
||||
return
|
||||
try:
|
||||
task.done_loop.call_soon_threadsafe(self._resolve_done_future, task)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
def snapshot(self) -> Dict[str, int | float]:
|
||||
with self.condition:
|
||||
return {
|
||||
"worker_count": int(self.worker_count),
|
||||
"pending": int(len(self.pending_tasks)),
|
||||
"pending_peak": int(self.pending_peak),
|
||||
"active_workers": int(self.active_workers),
|
||||
"active_workers_peak": int(self.active_workers_peak),
|
||||
"total_submitted": int(self.total_submitted),
|
||||
"total_finished": int(self.total_finished),
|
||||
"max_pending_tasks": int(self.max_pending_tasks),
|
||||
"admission_wait_total_ms": float(self.admission_wait_total_ms),
|
||||
"admission_wait_peak_ms": float(self.admission_wait_peak_ms),
|
||||
"backpressure_wait_total_ms": float(self.backpressure_wait_total_ms),
|
||||
"backpressure_wait_peak_ms": float(self.backpressure_wait_peak_ms),
|
||||
"capacity_wait_total_ms": float(self.capacity_wait_total_ms),
|
||||
"capacity_wait_peak_ms": float(self.capacity_wait_peak_ms),
|
||||
"backpressure_blocked_total": int(self.backpressure_blocked_total),
|
||||
}
|
||||
|
||||
def _run_loop(self) -> None:
|
||||
while True:
|
||||
with self.condition:
|
||||
while not self.pending_tasks:
|
||||
self.condition.wait()
|
||||
task = self.pending_tasks.popleft()
|
||||
self.active_workers += 1
|
||||
self.active_workers_peak = max(self.active_workers_peak, self.active_workers)
|
||||
started_at = time.perf_counter()
|
||||
try:
|
||||
task.result = self.process_fn(task.text, task.language)
|
||||
task.profile = {
|
||||
"text_cpu_admission_wait_ms": float(task.admission_wait_ms),
|
||||
"text_cpu_backpressure_wait_ms": float(task.backpressure_wait_ms),
|
||||
"text_cpu_capacity_wait_ms": float(task.capacity_wait_ms),
|
||||
"text_cpu_queue_wait_ms": max(0.0, (started_at - task.enqueued_at) * 1000.0),
|
||||
"text_cpu_pending_depth_on_enqueue": float(task.pending_depth_on_enqueue),
|
||||
"text_cpu_run_ms": max(0.0, (time.perf_counter() - started_at) * 1000.0),
|
||||
}
|
||||
except Exception as exc: # noqa: PERF203
|
||||
task.error = exc
|
||||
finally:
|
||||
with self.condition:
|
||||
self.active_workers = max(0, self.active_workers - 1)
|
||||
self.total_finished += 1
|
||||
self.condition.notify_all()
|
||||
self._notify_task_done(task)
|
||||
1285
GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py
Normal file
1285
GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py
Normal file
File diff suppressed because it is too large
Load Diff
112
GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py
Normal file
112
GPT_SoVITS/TTS_infer_pack/text_cpu_preprocess.py
Normal file
@ -0,0 +1,112 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
|
||||
from text.LangSegmenter import LangSegmenter
|
||||
from text import cleaned_text_to_sequence
|
||||
from text import chinese2
|
||||
from text.cleaner import clean_text
|
||||
|
||||
|
||||
PreparedTextSegmentPayload = Dict[str, object]
|
||||
|
||||
|
||||
def split_text_by_language(text: str, language: str) -> Tuple[List[str], List[str]]:
|
||||
textlist: List[str] = []
|
||||
langlist: List[str] = []
|
||||
if language == "all_zh":
|
||||
for tmp in LangSegmenter.getTexts(text, "zh"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_yue":
|
||||
for tmp in LangSegmenter.getTexts(text, "zh"):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_ja":
|
||||
for tmp in LangSegmenter.getTexts(text, "ja"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_ko":
|
||||
for tmp in LangSegmenter.getTexts(text, "ko"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "en":
|
||||
langlist.append("en")
|
||||
textlist.append(text)
|
||||
elif language == "auto":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "auto_yue":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
else:
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if langlist:
|
||||
same_group = (tmp["lang"] == "en" and langlist[-1] == "en") or (
|
||||
tmp["lang"] != "en" and langlist[-1] != "en"
|
||||
)
|
||||
if same_group:
|
||||
textlist[-1] += tmp["text"]
|
||||
continue
|
||||
if tmp["lang"] == "en":
|
||||
langlist.append(tmp["lang"])
|
||||
else:
|
||||
langlist.append(language)
|
||||
textlist.append(tmp["text"])
|
||||
return textlist, langlist
|
||||
|
||||
|
||||
def clean_text_segment(text: str, language: str, version: str) -> Tuple[List[int], Optional[List[int]], str]:
|
||||
normalized_language = language.replace("all_", "")
|
||||
phones, word2ph, norm_text = clean_text(text, normalized_language, version)
|
||||
phones = cleaned_text_to_sequence(phones, version)
|
||||
return list(phones), None if word2ph is None else list(word2ph), str(norm_text)
|
||||
|
||||
|
||||
def preprocess_text_segments_payload(
|
||||
text: str,
|
||||
language: str,
|
||||
version: str,
|
||||
final: bool = False,
|
||||
) -> List[PreparedTextSegmentPayload]:
|
||||
text = re.sub(r" {2,}", " ", text)
|
||||
textlist, langlist = split_text_by_language(text, language)
|
||||
payloads: List[PreparedTextSegmentPayload] = []
|
||||
total_phones_len = 0
|
||||
for segment_text, segment_lang in zip(textlist, langlist):
|
||||
normalized_language = segment_lang.replace("all_", "")
|
||||
if normalized_language == "zh":
|
||||
norm_text = chinese2.text_normalize(segment_text)
|
||||
phones = []
|
||||
word2ph = None
|
||||
needs_g2pw = True
|
||||
estimated_phones_len = max(0, len(norm_text) * 2)
|
||||
else:
|
||||
phones, word2ph, norm_text = clean_text_segment(segment_text, segment_lang, version)
|
||||
needs_g2pw = False
|
||||
estimated_phones_len = len(phones)
|
||||
payloads.append(
|
||||
{
|
||||
"language": normalized_language,
|
||||
"phones": phones,
|
||||
"word2ph": word2ph,
|
||||
"norm_text": norm_text,
|
||||
"needs_g2pw": needs_g2pw,
|
||||
}
|
||||
)
|
||||
total_phones_len += int(estimated_phones_len)
|
||||
|
||||
if not final and total_phones_len < 6:
|
||||
return preprocess_text_segments_payload("." + text, language, version, final=True)
|
||||
|
||||
return payloads
|
||||
46
GPT_SoVITS/TTS_infer_pack/unified_engine.py
Normal file
46
GPT_SoVITS/TTS_infer_pack/unified_engine.py
Normal file
@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Sequence
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_builder import EngineCompositionBuilder
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeControlCallbacks
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_delegates import EngineApiDelegates, EngineBridgeDelegates, EngineRuntimeDelegates
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_public import EngineCompatInterface, EnginePublicInterface
|
||||
|
||||
|
||||
class UnifiedTTSEngine(EnginePublicInterface, EngineCompatInterface, EngineBridgeDelegates, EngineApiDelegates, EngineRuntimeDelegates):
|
||||
@staticmethod
|
||||
def _env_flag(name: str, default: bool) -> bool:
|
||||
value = os.environ.get(name)
|
||||
if value is None:
|
||||
return bool(default)
|
||||
return str(value).strip().lower() not in {"0", "false", "no", "off", ""}
|
||||
|
||||
@staticmethod
|
||||
def _env_int(name: str, default: int) -> int:
|
||||
value = os.environ.get(name)
|
||||
if value in [None, ""]:
|
||||
return int(default)
|
||||
return int(value)
|
||||
|
||||
@staticmethod
|
||||
def _env_float(name: str, default: float) -> float:
|
||||
value = os.environ.get(name)
|
||||
if value in [None, ""]:
|
||||
return float(default)
|
||||
return float(value)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tts: TTS,
|
||||
cut_method_names: Sequence[str],
|
||||
control_callbacks: RuntimeControlCallbacks | None = None,
|
||||
max_steps: int = 1500,
|
||||
micro_batch_wait_ms: int = 5,
|
||||
) -> None:
|
||||
self.tts = tts
|
||||
self.cut_method_names = set(cut_method_names)
|
||||
self.control_callbacks = control_callbacks or RuntimeControlCallbacks()
|
||||
EngineCompositionBuilder(self).build(max_steps=max_steps, micro_batch_wait_ms=micro_batch_wait_ms)
|
||||
451
GPT_SoVITS/TTS_infer_pack/unified_engine_api.py
Normal file
451
GPT_SoVITS/TTS_infer_pack/unified_engine_api.py
Normal file
@ -0,0 +1,451 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_api_direct import EngineApiDirectFlow
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_api_profile import (
|
||||
aggregate_numeric_dicts,
|
||||
build_direct_scheduler_profile,
|
||||
build_direct_segment_trace,
|
||||
build_legacy_direct_profile,
|
||||
build_request_meta,
|
||||
build_scheduler_debug_batch_profile,
|
||||
build_scheduler_debug_request_profile,
|
||||
build_scheduler_submit_headers,
|
||||
build_scheduler_submit_profile,
|
||||
format_ms_header,
|
||||
sum_profile_field,
|
||||
)
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_api_request import (
|
||||
apply_default_reference,
|
||||
base_request_defaults,
|
||||
check_params,
|
||||
is_aux_ref_enabled,
|
||||
normalize_engine_request,
|
||||
normalize_lang,
|
||||
normalize_streaming_mode,
|
||||
select_direct_backend,
|
||||
)
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_api_scheduler import EngineApiSchedulerFlow
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SFinishedItem, T2SRequestState
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import (
|
||||
DirectTTSExecution,
|
||||
NormalizedEngineRequest,
|
||||
SchedulerDebugExecution,
|
||||
SchedulerSubmitExecution,
|
||||
)
|
||||
|
||||
|
||||
class EngineApiFacade:
|
||||
def __init__(self, owner: Any) -> None:
|
||||
self.owner = owner
|
||||
self.direct_flow = EngineApiDirectFlow(self)
|
||||
self.scheduler_flow = EngineApiSchedulerFlow(self)
|
||||
|
||||
@property
|
||||
def tts(self):
|
||||
return self.owner.tts
|
||||
|
||||
@property
|
||||
def cut_method_names(self):
|
||||
return self.owner.cut_method_names
|
||||
|
||||
@property
|
||||
def reference_registry(self):
|
||||
return self.owner.reference_registry
|
||||
|
||||
@property
|
||||
def direct_tts_lock(self):
|
||||
return self.owner.direct_tts_lock
|
||||
|
||||
@property
|
||||
def scheduler_worker(self):
|
||||
return self.owner.scheduler_worker
|
||||
|
||||
def _register_request_state(
|
||||
self,
|
||||
request_id: str,
|
||||
api_mode: str,
|
||||
backend: str,
|
||||
media_type: str,
|
||||
response_streaming: bool,
|
||||
deadline_ts: float | None = None,
|
||||
meta: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
return self.owner._register_request_state(
|
||||
request_id=request_id,
|
||||
api_mode=api_mode,
|
||||
backend=backend,
|
||||
media_type=media_type,
|
||||
response_streaming=response_streaming,
|
||||
deadline_ts=deadline_ts,
|
||||
meta=meta,
|
||||
)
|
||||
|
||||
def _update_request_state(
|
||||
self,
|
||||
request_id: str,
|
||||
status: str,
|
||||
extra: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
self.owner._update_request_state(request_id, status, extra)
|
||||
|
||||
def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
|
||||
self.owner._merge_request_state_profile(request_id, extra)
|
||||
|
||||
def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
|
||||
self.owner._complete_request_state(request_id, extra)
|
||||
|
||||
def _fail_request_state(self, request_id: str, error: str) -> None:
|
||||
self.owner._fail_request_state(request_id, error)
|
||||
|
||||
async def _prepare_state_via_engine_gpu_queue(
|
||||
self,
|
||||
*,
|
||||
spec: SchedulerRequestSpec,
|
||||
prepare_submit_at: float,
|
||||
engine_request_id: str | None,
|
||||
) -> tuple[T2SRequestState, float, float]:
|
||||
return await self.owner._prepare_state_via_engine_gpu_queue(
|
||||
spec=spec,
|
||||
prepare_submit_at=prepare_submit_at,
|
||||
engine_request_id=engine_request_id,
|
||||
)
|
||||
|
||||
async def _enqueue_prepared_state_for_dispatch(
|
||||
self,
|
||||
*,
|
||||
state: T2SRequestState,
|
||||
speed_factor: float,
|
||||
sample_steps: int,
|
||||
media_type: str,
|
||||
super_sampling: bool,
|
||||
prepare_wall_ms: float,
|
||||
prepare_profile_total_ms: float,
|
||||
done_loop: asyncio.AbstractEventLoop | None,
|
||||
done_future: asyncio.Future | None,
|
||||
engine_request_id: str | None,
|
||||
timeout_sec: float | None,
|
||||
):
|
||||
return await self.owner._enqueue_prepared_state_for_dispatch(
|
||||
state=state,
|
||||
speed_factor=speed_factor,
|
||||
sample_steps=sample_steps,
|
||||
media_type=media_type,
|
||||
super_sampling=super_sampling,
|
||||
prepare_wall_ms=prepare_wall_ms,
|
||||
prepare_profile_total_ms=prepare_profile_total_ms,
|
||||
done_loop=done_loop,
|
||||
done_future=done_future,
|
||||
engine_request_id=engine_request_id,
|
||||
timeout_sec=timeout_sec,
|
||||
)
|
||||
|
||||
def _collect_request_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]:
|
||||
return self.owner.request_registry.collect_summaries(request_ids)
|
||||
|
||||
def _has_active_request(self, request_id: str) -> bool:
|
||||
return self.owner.request_registry.has_active(request_id)
|
||||
|
||||
@staticmethod
|
||||
def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return build_request_meta(payload)
|
||||
|
||||
@staticmethod
|
||||
def _sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float:
|
||||
return sum_profile_field(items, key)
|
||||
|
||||
def _build_direct_segment_trace(
|
||||
self,
|
||||
segment_texts: Sequence[str],
|
||||
prepare_profiles: Sequence[Dict[str, Any]],
|
||||
worker_profiles: Sequence[Dict[str, Any]],
|
||||
) -> List[Dict[str, Any]]:
|
||||
return build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles)
|
||||
|
||||
def _build_direct_scheduler_profile(
|
||||
self,
|
||||
*,
|
||||
backend: str,
|
||||
request_start: float,
|
||||
response_ready_at: float,
|
||||
audio_bytes: int,
|
||||
sample_rate: int,
|
||||
segment_texts: Sequence[str],
|
||||
prepare_profiles: Sequence[Dict[str, Any]],
|
||||
worker_profiles: Sequence[Dict[str, Any]],
|
||||
pack_ms: float,
|
||||
response_overhead_ms: float,
|
||||
) -> Dict[str, Any]:
|
||||
return build_direct_scheduler_profile(
|
||||
backend=backend,
|
||||
request_start=request_start,
|
||||
response_ready_at=response_ready_at,
|
||||
audio_bytes=audio_bytes,
|
||||
sample_rate=sample_rate,
|
||||
segment_texts=segment_texts,
|
||||
prepare_profiles=prepare_profiles,
|
||||
worker_profiles=worker_profiles,
|
||||
pack_ms=pack_ms,
|
||||
response_overhead_ms=response_overhead_ms,
|
||||
)
|
||||
|
||||
def _build_legacy_direct_profile(
|
||||
self,
|
||||
*,
|
||||
backend: str,
|
||||
fallback_reason: str | None,
|
||||
request_start: float,
|
||||
finished_at: float,
|
||||
sample_rate: int | None = None,
|
||||
audio_bytes: int = 0,
|
||||
pack_ms: float = 0.0,
|
||||
chunk_count: int = 0,
|
||||
stream_total_bytes: int = 0,
|
||||
first_chunk_ms: float | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
return build_legacy_direct_profile(
|
||||
backend=backend,
|
||||
fallback_reason=fallback_reason,
|
||||
request_start=request_start,
|
||||
finished_at=finished_at,
|
||||
sample_rate=sample_rate,
|
||||
audio_bytes=audio_bytes,
|
||||
pack_ms=pack_ms,
|
||||
chunk_count=chunk_count,
|
||||
stream_total_bytes=stream_total_bytes,
|
||||
first_chunk_ms=first_chunk_ms,
|
||||
)
|
||||
|
||||
def _build_scheduler_submit_profile(
|
||||
self,
|
||||
*,
|
||||
backend: str,
|
||||
request_start: float,
|
||||
response_ready_at: float,
|
||||
audio_bytes: int,
|
||||
sample_rate: int,
|
||||
prepare_spec_build_ms: float,
|
||||
prepare_wall_ms: float,
|
||||
prepare_executor_queue_ms: float,
|
||||
prepare_executor_run_ms: float,
|
||||
prepare_profile_total_ms: float,
|
||||
prepare_profile_wall_ms: float,
|
||||
prepare_other_ms: float,
|
||||
engine_policy_wait_ms: float,
|
||||
api_after_prepare_ms: float,
|
||||
api_wait_result_ms: float,
|
||||
pack_ms: float,
|
||||
response_overhead_ms: float,
|
||||
worker_profile: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
return build_scheduler_submit_profile(
|
||||
backend=backend,
|
||||
request_start=request_start,
|
||||
response_ready_at=response_ready_at,
|
||||
audio_bytes=audio_bytes,
|
||||
sample_rate=sample_rate,
|
||||
prepare_spec_build_ms=prepare_spec_build_ms,
|
||||
prepare_wall_ms=prepare_wall_ms,
|
||||
prepare_executor_queue_ms=prepare_executor_queue_ms,
|
||||
prepare_executor_run_ms=prepare_executor_run_ms,
|
||||
prepare_profile_total_ms=prepare_profile_total_ms,
|
||||
prepare_profile_wall_ms=prepare_profile_wall_ms,
|
||||
prepare_other_ms=prepare_other_ms,
|
||||
engine_policy_wait_ms=engine_policy_wait_ms,
|
||||
api_after_prepare_ms=api_after_prepare_ms,
|
||||
api_wait_result_ms=api_wait_result_ms,
|
||||
pack_ms=pack_ms,
|
||||
response_overhead_ms=response_overhead_ms,
|
||||
worker_profile=worker_profile,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _format_ms_header(value: Any) -> str:
|
||||
return format_ms_header(value)
|
||||
|
||||
def _build_scheduler_submit_headers(
|
||||
self,
|
||||
*,
|
||||
request_id: str,
|
||||
media_type: str,
|
||||
sample_rate: int,
|
||||
profile: Dict[str, Any],
|
||||
) -> Dict[str, str]:
|
||||
return build_scheduler_submit_headers(
|
||||
request_id=request_id,
|
||||
media_type=media_type,
|
||||
sample_rate=sample_rate,
|
||||
profile=profile,
|
||||
)
|
||||
|
||||
def _build_scheduler_debug_request_profile(
|
||||
self,
|
||||
*,
|
||||
state: T2SRequestState,
|
||||
item: T2SFinishedItem,
|
||||
batch_request_count: int,
|
||||
prepare_batch_wall_ms: float,
|
||||
decode_batch_wall_ms: float,
|
||||
batch_request_total_ms: float,
|
||||
) -> Dict[str, Any]:
|
||||
return build_scheduler_debug_request_profile(
|
||||
state=state,
|
||||
item=item,
|
||||
batch_request_count=batch_request_count,
|
||||
prepare_batch_wall_ms=prepare_batch_wall_ms,
|
||||
decode_batch_wall_ms=decode_batch_wall_ms,
|
||||
batch_request_total_ms=batch_request_total_ms,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_scheduler_debug_batch_profile(
|
||||
*,
|
||||
request_count: int,
|
||||
max_steps: int,
|
||||
prepare_batch_wall_ms: float,
|
||||
decode_batch_wall_ms: float,
|
||||
request_total_ms: float,
|
||||
finished_items: Sequence[T2SFinishedItem],
|
||||
) -> Dict[str, Any]:
|
||||
return build_scheduler_debug_batch_profile(
|
||||
request_count=request_count,
|
||||
max_steps=max_steps,
|
||||
prepare_batch_wall_ms=prepare_batch_wall_ms,
|
||||
decode_batch_wall_ms=decode_batch_wall_ms,
|
||||
request_total_ms=request_total_ms,
|
||||
finished_items=finished_items,
|
||||
)
|
||||
|
||||
def _normalize_lang(self, value: str | None) -> str | None:
|
||||
return normalize_lang(value)
|
||||
|
||||
@staticmethod
|
||||
def _aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]:
|
||||
return aggregate_numeric_dicts(items)
|
||||
|
||||
def _apply_default_reference(self, req: dict) -> dict:
|
||||
return apply_default_reference(self.reference_registry, req)
|
||||
|
||||
def check_params(self, req: dict) -> Optional[str]:
|
||||
return check_params(self.tts, self.cut_method_names, req)
|
||||
|
||||
@staticmethod
|
||||
def _base_request_defaults() -> Dict[str, Any]:
|
||||
return base_request_defaults()
|
||||
|
||||
def _normalize_engine_request(
|
||||
self,
|
||||
payload: dict | NormalizedEngineRequest,
|
||||
*,
|
||||
request_id: str | None = None,
|
||||
normalize_streaming: bool = False,
|
||||
error_prefix: str = "request 参数非法: ",
|
||||
) -> NormalizedEngineRequest:
|
||||
return normalize_engine_request(
|
||||
tts=self.tts,
|
||||
cut_method_names=self.cut_method_names,
|
||||
reference_registry=self.reference_registry,
|
||||
payload=payload,
|
||||
request_id=request_id,
|
||||
normalize_streaming=normalize_streaming,
|
||||
error_prefix=error_prefix,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_streaming_mode(req: dict) -> dict:
|
||||
return normalize_streaming_mode(req)
|
||||
|
||||
@staticmethod
|
||||
def _is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool:
|
||||
return is_aux_ref_enabled(aux_ref_audio_paths)
|
||||
|
||||
def _select_direct_backend(self, normalized: NormalizedEngineRequest) -> Tuple[str, str | None]:
|
||||
return select_direct_backend(normalized)
|
||||
|
||||
def _iter_legacy_direct_tts_bytes(
|
||||
self,
|
||||
normalized: NormalizedEngineRequest,
|
||||
*,
|
||||
backend: str,
|
||||
fallback_reason: str | None,
|
||||
) -> Generator[bytes, None, None]:
|
||||
yield from self.direct_flow._iter_legacy_direct_tts_bytes(
|
||||
normalized,
|
||||
backend=backend,
|
||||
fallback_reason=fallback_reason,
|
||||
)
|
||||
|
||||
def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool:
|
||||
return self.direct_flow._should_use_scheduler_backend_for_direct(req)
|
||||
|
||||
def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]:
|
||||
return self.direct_flow._segment_direct_text(normalized)
|
||||
|
||||
def _build_segment_request(
|
||||
self,
|
||||
normalized: NormalizedEngineRequest,
|
||||
*,
|
||||
request_id: str,
|
||||
text: str,
|
||||
) -> NormalizedEngineRequest:
|
||||
return self.direct_flow._build_segment_request(
|
||||
normalized,
|
||||
request_id=request_id,
|
||||
text=text,
|
||||
)
|
||||
|
||||
async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution:
|
||||
return await self.direct_flow._run_direct_tts_via_scheduler(normalized)
|
||||
|
||||
def _run_legacy_direct_tts_blocking(
|
||||
self,
|
||||
normalized: NormalizedEngineRequest,
|
||||
*,
|
||||
backend: str,
|
||||
fallback_reason: str | None,
|
||||
) -> DirectTTSExecution:
|
||||
return self.direct_flow._run_legacy_direct_tts_blocking(
|
||||
normalized,
|
||||
backend=backend,
|
||||
fallback_reason=fallback_reason,
|
||||
)
|
||||
|
||||
async def _run_direct_tts_via_legacy_backend(
|
||||
self,
|
||||
normalized: NormalizedEngineRequest,
|
||||
*,
|
||||
backend: str,
|
||||
fallback_reason: str | None,
|
||||
) -> DirectTTSExecution:
|
||||
return await self.direct_flow._run_direct_tts_via_legacy_backend(
|
||||
normalized,
|
||||
backend=backend,
|
||||
fallback_reason=fallback_reason,
|
||||
)
|
||||
|
||||
async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution:
|
||||
return await self.direct_flow.run_direct_tts_async(req)
|
||||
|
||||
def run_direct_tts(self, req: dict) -> DirectTTSExecution:
|
||||
return self.direct_flow.run_direct_tts(req)
|
||||
|
||||
def _build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]:
|
||||
return self.scheduler_flow._build_scheduler_request_specs(request_items)
|
||||
|
||||
def _build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec:
|
||||
return self.scheduler_flow._build_scheduler_submit_spec(payload)
|
||||
|
||||
@staticmethod
|
||||
def _summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]:
|
||||
return EngineApiSchedulerFlow._summarize_scheduler_states(states)
|
||||
|
||||
@staticmethod
|
||||
def _summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]:
|
||||
return EngineApiSchedulerFlow._summarize_scheduler_finished(items)
|
||||
|
||||
async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution:
|
||||
return await self.scheduler_flow.run_scheduler_debug(request_items, max_steps, seed)
|
||||
|
||||
async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution:
|
||||
return await self.scheduler_flow.run_scheduler_submit(payload)
|
||||
165
GPT_SoVITS/TTS_infer_pack/unified_engine_api_delegates.py
Normal file
165
GPT_SoVITS/TTS_infer_pack/unified_engine_api_delegates.py
Normal file
@ -0,0 +1,165 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_api import EngineApiFacade
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, NormalizedEngineRequest
|
||||
|
||||
|
||||
class EngineApiDelegates:
|
||||
def _collect_request_summaries(self, request_ids: Sequence[str]) -> List[Dict[str, Any]]:
|
||||
return self.api_facade._collect_request_summaries(request_ids)
|
||||
|
||||
def _has_active_request(self, request_id: str) -> bool:
|
||||
return self.api_facade._has_active_request(request_id)
|
||||
|
||||
@staticmethod
|
||||
def _build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return EngineApiFacade._build_request_meta(payload)
|
||||
|
||||
@staticmethod
|
||||
def _sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float:
|
||||
return EngineApiFacade._sum_profile_field(items, key)
|
||||
|
||||
def _build_direct_segment_trace(
|
||||
self,
|
||||
segment_texts: Sequence[str],
|
||||
prepare_profiles: Sequence[Dict[str, Any]],
|
||||
worker_profiles: Sequence[Dict[str, Any]],
|
||||
) -> List[Dict[str, Any]]:
|
||||
return self.api_facade._build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles)
|
||||
|
||||
def _build_direct_scheduler_profile(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
return self.api_facade._build_direct_scheduler_profile(**kwargs)
|
||||
|
||||
def _build_legacy_direct_profile(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
return self.api_facade._build_legacy_direct_profile(**kwargs)
|
||||
|
||||
def _build_scheduler_submit_profile(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
return self.api_facade._build_scheduler_submit_profile(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _format_ms_header(value: Any) -> str:
|
||||
return EngineApiFacade._format_ms_header(value)
|
||||
|
||||
def _build_scheduler_submit_headers(
|
||||
self,
|
||||
*,
|
||||
request_id: str,
|
||||
media_type: str,
|
||||
sample_rate: int,
|
||||
profile: Dict[str, Any],
|
||||
) -> Dict[str, str]:
|
||||
return self.api_facade._build_scheduler_submit_headers(
|
||||
request_id=request_id,
|
||||
media_type=media_type,
|
||||
sample_rate=sample_rate,
|
||||
profile=profile,
|
||||
)
|
||||
|
||||
def _build_scheduler_debug_request_profile(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
return self.api_facade._build_scheduler_debug_request_profile(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _build_scheduler_debug_batch_profile(**kwargs: Any) -> Dict[str, Any]:
|
||||
return EngineApiFacade._build_scheduler_debug_batch_profile(**kwargs)
|
||||
|
||||
def _normalize_lang(self, value: str | None) -> str | None:
|
||||
return self.api_facade._normalize_lang(value)
|
||||
|
||||
@staticmethod
|
||||
def _aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]:
|
||||
return EngineApiFacade._aggregate_numeric_dicts(items)
|
||||
|
||||
def _apply_default_reference(self, req: dict) -> dict:
|
||||
return self.api_facade._apply_default_reference(req)
|
||||
|
||||
def check_params(self, req: dict) -> Optional[str]:
|
||||
return self.api_facade.check_params(req)
|
||||
|
||||
@staticmethod
|
||||
def _base_request_defaults() -> Dict[str, Any]:
|
||||
return EngineApiFacade._base_request_defaults()
|
||||
|
||||
def _normalize_engine_request(
|
||||
self,
|
||||
payload: dict | NormalizedEngineRequest,
|
||||
*,
|
||||
request_id: str | None = None,
|
||||
normalize_streaming: bool = False,
|
||||
error_prefix: str = "request 参数非法: ",
|
||||
) -> NormalizedEngineRequest:
|
||||
return self.api_facade._normalize_engine_request(
|
||||
payload,
|
||||
request_id=request_id,
|
||||
normalize_streaming=normalize_streaming,
|
||||
error_prefix=error_prefix,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_streaming_mode(req: dict) -> dict:
|
||||
return EngineApiFacade._normalize_streaming_mode(req)
|
||||
|
||||
@staticmethod
|
||||
def _is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool:
|
||||
return EngineApiFacade._is_aux_ref_enabled(aux_ref_audio_paths)
|
||||
|
||||
def _select_direct_backend(self, normalized: NormalizedEngineRequest) -> Tuple[str, str | None]:
|
||||
return self.api_facade._select_direct_backend(normalized)
|
||||
|
||||
def _iter_legacy_direct_tts_bytes(
|
||||
self,
|
||||
normalized: NormalizedEngineRequest,
|
||||
*,
|
||||
backend: str,
|
||||
fallback_reason: str | None,
|
||||
) -> Generator[bytes, None, None]:
|
||||
return self.api_facade._iter_legacy_direct_tts_bytes(
|
||||
normalized,
|
||||
backend=backend,
|
||||
fallback_reason=fallback_reason,
|
||||
)
|
||||
|
||||
def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool:
|
||||
return self.api_facade._should_use_scheduler_backend_for_direct(req)
|
||||
|
||||
def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]:
|
||||
return self.api_facade._segment_direct_text(normalized)
|
||||
|
||||
def _build_segment_request(
|
||||
self,
|
||||
normalized: NormalizedEngineRequest,
|
||||
*,
|
||||
request_id: str,
|
||||
text: str,
|
||||
) -> NormalizedEngineRequest:
|
||||
return self.api_facade._build_segment_request(normalized, request_id=request_id, text=text)
|
||||
|
||||
async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution:
|
||||
return await self.api_facade._run_direct_tts_via_scheduler(normalized)
|
||||
|
||||
def _run_legacy_direct_tts_blocking(
|
||||
self,
|
||||
normalized: NormalizedEngineRequest,
|
||||
*,
|
||||
backend: str,
|
||||
fallback_reason: str | None,
|
||||
) -> DirectTTSExecution:
|
||||
return self.api_facade._run_legacy_direct_tts_blocking(
|
||||
normalized,
|
||||
backend=backend,
|
||||
fallback_reason=fallback_reason,
|
||||
)
|
||||
|
||||
async def _run_direct_tts_via_legacy_backend(
|
||||
self,
|
||||
normalized: NormalizedEngineRequest,
|
||||
*,
|
||||
backend: str,
|
||||
fallback_reason: str | None,
|
||||
) -> DirectTTSExecution:
|
||||
return await self.api_facade._run_direct_tts_via_legacy_backend(
|
||||
normalized,
|
||||
backend=backend,
|
||||
fallback_reason=fallback_reason,
|
||||
)
|
||||
595
GPT_SoVITS/TTS_infer_pack/unified_engine_api_direct.py
Normal file
595
GPT_SoVITS/TTS_infer_pack/unified_engine_api_direct.py
Normal file
@ -0,0 +1,595 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_audio import pack_audio, wave_header_chunk
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, EngineStatus, NormalizedEngineRequest, SchedulerPendingJob
|
||||
|
||||
|
||||
class EngineApiDirectFlow:
|
||||
def __init__(self, api: Any) -> None:
|
||||
self.api = api
|
||||
|
||||
def _iter_legacy_direct_tts_bytes(
|
||||
self,
|
||||
normalized: NormalizedEngineRequest,
|
||||
*,
|
||||
backend: str,
|
||||
fallback_reason: str | None,
|
||||
) -> Generator[bytes, None, None]:
|
||||
payload = normalized.to_payload()
|
||||
media_type = normalized.media_type
|
||||
request_id = normalized.request_id
|
||||
request_start = time.perf_counter()
|
||||
chunk_count = 0
|
||||
stream_total_bytes = 0
|
||||
first_chunk_ms: float | None = None
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.ACTIVE_DECODE,
|
||||
{"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason},
|
||||
)
|
||||
try:
|
||||
with self.api.direct_tts_lock:
|
||||
tts_generator = self.api.tts.run(payload)
|
||||
first_chunk = True
|
||||
current_media_type = media_type
|
||||
for sr, chunk in tts_generator:
|
||||
if first_chunk:
|
||||
first_chunk_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0)
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.STREAMING,
|
||||
{
|
||||
"backend": backend,
|
||||
"backend_mode": backend,
|
||||
"fallback_reason": fallback_reason,
|
||||
"sample_rate": int(sr),
|
||||
},
|
||||
)
|
||||
if first_chunk and media_type == "wav":
|
||||
header = wave_header_chunk(sample_rate=sr)
|
||||
chunk_count += 1
|
||||
stream_total_bytes += len(header)
|
||||
yield header
|
||||
current_media_type = "raw"
|
||||
first_chunk = False
|
||||
elif first_chunk:
|
||||
first_chunk = False
|
||||
packed_chunk = pack_audio(BytesIO(), chunk, sr, current_media_type).getvalue()
|
||||
chunk_count += 1
|
||||
stream_total_bytes += len(packed_chunk)
|
||||
yield packed_chunk
|
||||
except Exception as exc:
|
||||
self.api._fail_request_state(request_id, str(exc))
|
||||
raise
|
||||
self.api._complete_request_state(
|
||||
request_id,
|
||||
dict(
|
||||
self.api._build_legacy_direct_profile(
|
||||
backend=backend,
|
||||
fallback_reason=fallback_reason,
|
||||
request_start=request_start,
|
||||
finished_at=time.perf_counter(),
|
||||
audio_bytes=stream_total_bytes,
|
||||
chunk_count=chunk_count,
|
||||
stream_total_bytes=stream_total_bytes,
|
||||
first_chunk_ms=first_chunk_ms,
|
||||
),
|
||||
streaming_completed=True,
|
||||
),
|
||||
)
|
||||
|
||||
def _should_use_scheduler_backend_for_direct(self, req: dict | NormalizedEngineRequest) -> bool:
|
||||
if isinstance(req, NormalizedEngineRequest):
|
||||
normalized = req
|
||||
else:
|
||||
normalized = self.api._normalize_engine_request(
|
||||
req,
|
||||
request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"),
|
||||
normalize_streaming=True,
|
||||
)
|
||||
backend, _ = self.api._select_direct_backend(normalized)
|
||||
return backend == "scheduler_v1_direct"
|
||||
|
||||
def _segment_direct_text(self, normalized: dict | NormalizedEngineRequest) -> List[str]:
|
||||
payload = normalized.to_payload() if isinstance(normalized, NormalizedEngineRequest) else normalized
|
||||
return self.api.tts.text_preprocessor.pre_seg_text(
|
||||
str(payload["text"]),
|
||||
str(payload["text_lang"]),
|
||||
str(payload.get("text_split_method", "cut5")),
|
||||
)
|
||||
|
||||
def _build_segment_request(
|
||||
self,
|
||||
normalized: NormalizedEngineRequest,
|
||||
*,
|
||||
request_id: str,
|
||||
text: str,
|
||||
) -> NormalizedEngineRequest:
|
||||
payload = normalized.to_payload()
|
||||
payload["request_id"] = request_id
|
||||
payload["text"] = text
|
||||
payload["streaming_mode"] = False
|
||||
payload["return_fragment"] = False
|
||||
payload["fixed_length_chunk"] = False
|
||||
payload["response_streaming"] = False
|
||||
return self.api._normalize_engine_request(payload, error_prefix="segment request 参数非法: ")
|
||||
|
||||
async def _execute_single_segment_scheduler_job(
|
||||
self,
|
||||
normalized: NormalizedEngineRequest,
|
||||
*,
|
||||
segment_request: NormalizedEngineRequest,
|
||||
) -> tuple[SchedulerPendingJob, Dict[str, Any]]:
|
||||
spec = self.api._build_scheduler_submit_spec(segment_request)
|
||||
state, prepare_exec_started_at, prepare_exec_finished_at = await self.api._prepare_state_via_engine_gpu_queue(
|
||||
spec=spec,
|
||||
prepare_submit_at=time.perf_counter(),
|
||||
engine_request_id=None,
|
||||
)
|
||||
prepare_wall_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0)
|
||||
prepare_profile_total_ms = float(state.prepare_profile.get("wall_total_ms", prepare_wall_ms))
|
||||
loop = asyncio.get_running_loop()
|
||||
done_future = loop.create_future()
|
||||
await self.api._enqueue_prepared_state_for_dispatch(
|
||||
state=state,
|
||||
speed_factor=float(normalized.speed_factor),
|
||||
sample_steps=int(normalized.sample_steps),
|
||||
media_type=normalized.media_type,
|
||||
super_sampling=bool(normalized.super_sampling),
|
||||
prepare_wall_ms=prepare_wall_ms,
|
||||
prepare_profile_total_ms=prepare_profile_total_ms,
|
||||
done_loop=loop,
|
||||
done_future=done_future,
|
||||
engine_request_id=None,
|
||||
timeout_sec=normalized.timeout_sec,
|
||||
)
|
||||
timeout_sec = float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0)
|
||||
job: SchedulerPendingJob = await asyncio.wait_for(done_future, timeout=timeout_sec)
|
||||
return job, {
|
||||
"request_id": spec.request_id,
|
||||
"prepare_wall_ms": prepare_wall_ms,
|
||||
"prepare_profile_total_ms": prepare_profile_total_ms,
|
||||
"prepare_profile": dict(state.prepare_profile),
|
||||
}
|
||||
|
||||
def _iter_scheduler_direct_tts_bytes(self, normalized: NormalizedEngineRequest) -> Generator[bytes, None, None]:
|
||||
request_start = time.perf_counter()
|
||||
request_id = normalized.request_id
|
||||
media_type = normalized.media_type
|
||||
segment_texts = self._segment_direct_text(normalized)
|
||||
if not segment_texts:
|
||||
raise ValueError("text preprocessing returned no valid segments")
|
||||
chunk_queue: queue.Queue[object] = queue.Queue(maxsize=8)
|
||||
done_marker = object()
|
||||
|
||||
async def _produce_chunks() -> None:
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.CPU_PREPARING,
|
||||
{"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_texts)},
|
||||
)
|
||||
sample_rate: int | None = None
|
||||
current_media_type = media_type
|
||||
chunk_count = 0
|
||||
stream_total_bytes = 0
|
||||
first_chunk_ms: float | None = None
|
||||
prepare_profiles: List[Dict[str, Any]] = []
|
||||
worker_profiles: List[Dict[str, Any]] = []
|
||||
try:
|
||||
for segment_index, segment_text in enumerate(segment_texts):
|
||||
segment_request = self._build_segment_request(
|
||||
normalized,
|
||||
request_id=f"{request_id}_seg_{segment_index:03d}",
|
||||
text=segment_text,
|
||||
)
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.READY_FOR_PREFILL,
|
||||
{
|
||||
"backend": "scheduler_v1_direct",
|
||||
"backend_mode": "scheduler_v1_direct",
|
||||
"segment_index": segment_index,
|
||||
"segment_count": len(segment_texts),
|
||||
},
|
||||
)
|
||||
job, prepare_profile = await self._execute_single_segment_scheduler_job(
|
||||
normalized,
|
||||
segment_request=segment_request,
|
||||
)
|
||||
prepare_profiles.append(prepare_profile)
|
||||
if job.error is not None:
|
||||
raise RuntimeError(job.error)
|
||||
if job.audio_data is None or job.sample_rate is None or job.result is None:
|
||||
raise RuntimeError(f"{job.request_id} finished without audio result")
|
||||
worker_profiles.append(dict(job.result))
|
||||
if sample_rate is None:
|
||||
sample_rate = int(job.sample_rate)
|
||||
first_chunk_ms = max(0.0, (time.perf_counter() - request_start) * 1000.0)
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.STREAMING,
|
||||
{
|
||||
"backend": "scheduler_v1_direct",
|
||||
"backend_mode": "scheduler_v1_direct",
|
||||
"sample_rate": int(sample_rate),
|
||||
},
|
||||
)
|
||||
if media_type == "wav":
|
||||
header = wave_header_chunk(sample_rate=int(sample_rate))
|
||||
chunk_count += 1
|
||||
stream_total_bytes += len(header)
|
||||
chunk_queue.put(header)
|
||||
current_media_type = "raw"
|
||||
packed_chunk = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), current_media_type).getvalue()
|
||||
chunk_count += 1
|
||||
stream_total_bytes += len(packed_chunk)
|
||||
chunk_queue.put(packed_chunk)
|
||||
if segment_index + 1 < len(segment_texts):
|
||||
silence_samples = int(float(normalized.fragment_interval) * float(job.sample_rate))
|
||||
if silence_samples > 0:
|
||||
silence_chunk = np.zeros(silence_samples, dtype=np.int16)
|
||||
packed_silence = pack_audio(
|
||||
BytesIO(), silence_chunk, int(job.sample_rate), current_media_type
|
||||
).getvalue()
|
||||
chunk_count += 1
|
||||
stream_total_bytes += len(packed_silence)
|
||||
chunk_queue.put(packed_silence)
|
||||
except Exception as exc:
|
||||
self.api._fail_request_state(request_id, str(exc))
|
||||
chunk_queue.put(exc)
|
||||
else:
|
||||
self.api._merge_request_state_profile(
|
||||
request_id,
|
||||
{
|
||||
"prepare_aggregate": self.api._aggregate_numeric_dicts(
|
||||
[item["prepare_profile"] for item in prepare_profiles]
|
||||
),
|
||||
"engine_policy_wait_ms": sum(
|
||||
float(item.get("engine_policy_wait_ms", 0.0)) for item in worker_profiles
|
||||
),
|
||||
"engine_dispatch_wait_ms": sum(
|
||||
float(item.get("engine_dispatch_wait_ms", 0.0)) for item in worker_profiles
|
||||
),
|
||||
},
|
||||
)
|
||||
direct_profile = self.api._build_direct_scheduler_profile(
|
||||
backend="scheduler_v1_direct",
|
||||
request_start=request_start,
|
||||
response_ready_at=time.perf_counter(),
|
||||
audio_bytes=stream_total_bytes,
|
||||
sample_rate=int(sample_rate or 0),
|
||||
segment_texts=segment_texts,
|
||||
prepare_profiles=prepare_profiles,
|
||||
worker_profiles=worker_profiles,
|
||||
pack_ms=0.0,
|
||||
response_overhead_ms=0.0,
|
||||
)
|
||||
self.api._complete_request_state(
|
||||
request_id,
|
||||
dict(direct_profile, streaming_completed=True, first_chunk_ms=first_chunk_ms),
|
||||
)
|
||||
finally:
|
||||
chunk_queue.put(done_marker)
|
||||
|
||||
producer_thread = threading.Thread(target=lambda: asyncio.run(_produce_chunks()), daemon=True)
|
||||
producer_thread.start()
|
||||
while True:
|
||||
item = chunk_queue.get()
|
||||
if item is done_marker:
|
||||
break
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
yield item
|
||||
|
||||
async def _run_direct_tts_via_scheduler(self, normalized: NormalizedEngineRequest) -> DirectTTSExecution:
|
||||
request_start = time.perf_counter()
|
||||
request_id = normalized.request_id
|
||||
media_type = normalized.media_type
|
||||
segment_texts = self._segment_direct_text(normalized)
|
||||
if not segment_texts:
|
||||
raise ValueError("text preprocessing returned no valid segments")
|
||||
if normalized.response_streaming:
|
||||
return DirectTTSExecution(
|
||||
media_type=media_type,
|
||||
streaming=True,
|
||||
audio_generator=self._iter_scheduler_direct_tts_bytes(normalized),
|
||||
request_id=request_id,
|
||||
)
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.CPU_PREPARING,
|
||||
{"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_texts)},
|
||||
)
|
||||
segment_requests = [
|
||||
self._build_segment_request(
|
||||
normalized,
|
||||
request_id=f"{request_id}_seg_{segment_index:03d}",
|
||||
text=segment_text,
|
||||
)
|
||||
for segment_index, segment_text in enumerate(segment_texts)
|
||||
]
|
||||
prepare_profiles: List[Dict[str, Any]] = []
|
||||
loop = asyncio.get_running_loop()
|
||||
done_futures: List[asyncio.Future] = []
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.READY_FOR_PREFILL,
|
||||
{"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct", "segment_count": len(segment_requests)},
|
||||
)
|
||||
prepared_items = await asyncio.gather(
|
||||
*[
|
||||
self._execute_single_segment_scheduler_job(
|
||||
normalized,
|
||||
segment_request=segment_request,
|
||||
)
|
||||
for segment_request in segment_requests
|
||||
]
|
||||
)
|
||||
for job, prepare_profile in prepared_items:
|
||||
prepare_profiles.append(prepare_profile)
|
||||
done_future = loop.create_future()
|
||||
done_future.set_result(job)
|
||||
done_futures.append(done_future)
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.ACTIVE_DECODE,
|
||||
{"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct"},
|
||||
)
|
||||
timeout_sec = float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0)
|
||||
jobs: List[SchedulerPendingJob] = list(await asyncio.wait_for(asyncio.gather(*done_futures), timeout=timeout_sec))
|
||||
for profile_item, job in zip(prepare_profiles, jobs):
|
||||
profile_item["engine_policy_wait_ms"] = float(job.engine_policy_wait_ms)
|
||||
profile_item["engine_dispatch_wait_ms"] = float(job.engine_dispatch_wait_ms)
|
||||
self.api._merge_request_state_profile(
|
||||
request_id,
|
||||
{
|
||||
"engine_policy_wait_ms": sum(float(job.engine_policy_wait_ms) for job in jobs),
|
||||
"engine_dispatch_wait_ms": sum(float(job.engine_dispatch_wait_ms) for job in jobs),
|
||||
"prepare_aggregate": self.api._aggregate_numeric_dicts([item["prepare_profile"] for item in prepare_profiles]),
|
||||
},
|
||||
)
|
||||
|
||||
sample_rate: int | None = None
|
||||
audio_parts: List[np.ndarray] = []
|
||||
worker_profiles: List[Dict[str, Any]] = []
|
||||
fragment_interval = float(normalized.fragment_interval)
|
||||
silence_chunk: Optional[np.ndarray] = None
|
||||
for job in jobs:
|
||||
if job.error is not None:
|
||||
raise RuntimeError(job.error)
|
||||
if job.audio_data is None or job.sample_rate is None or job.result is None:
|
||||
raise RuntimeError(f"{job.request_id} finished without audio result")
|
||||
if sample_rate is None:
|
||||
sample_rate = int(job.sample_rate)
|
||||
silence_samples = int(fragment_interval * float(sample_rate))
|
||||
if silence_samples > 0:
|
||||
silence_chunk = np.zeros(silence_samples, dtype=np.int16)
|
||||
elif int(job.sample_rate) != sample_rate:
|
||||
raise RuntimeError("segment sample rate mismatch")
|
||||
audio_parts.append(job.audio_data)
|
||||
if silence_chunk is not None:
|
||||
audio_parts.append(silence_chunk.copy())
|
||||
worker_profiles.append(dict(job.result))
|
||||
if sample_rate is None or not audio_parts:
|
||||
raise RuntimeError("direct scheduler backend produced no audio")
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.FINALIZING,
|
||||
{"backend": "scheduler_v1_direct", "backend_mode": "scheduler_v1_direct"},
|
||||
)
|
||||
merged_audio = np.concatenate(audio_parts, axis=0)
|
||||
pack_start = time.perf_counter()
|
||||
audio_bytes = pack_audio(BytesIO(), merged_audio, sample_rate, media_type).getvalue()
|
||||
pack_ms = max(0.0, (time.perf_counter() - pack_start) * 1000.0)
|
||||
direct_profile = self.api._build_direct_scheduler_profile(
|
||||
backend="scheduler_v1_direct",
|
||||
request_start=request_start,
|
||||
response_ready_at=time.perf_counter(),
|
||||
audio_bytes=len(audio_bytes),
|
||||
sample_rate=int(sample_rate),
|
||||
segment_texts=segment_texts,
|
||||
prepare_profiles=prepare_profiles,
|
||||
worker_profiles=worker_profiles,
|
||||
pack_ms=pack_ms,
|
||||
response_overhead_ms=0.0,
|
||||
)
|
||||
self.api._complete_request_state(
|
||||
request_id,
|
||||
dict(direct_profile, streaming_completed=False),
|
||||
)
|
||||
return DirectTTSExecution(
|
||||
media_type=media_type,
|
||||
streaming=False,
|
||||
audio_bytes=audio_bytes,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
def _run_legacy_direct_tts_blocking(
|
||||
self,
|
||||
normalized: NormalizedEngineRequest,
|
||||
*,
|
||||
backend: str,
|
||||
fallback_reason: str | None,
|
||||
) -> DirectTTSExecution:
|
||||
normalized_payload = normalized.to_payload()
|
||||
request_id = normalized.request_id
|
||||
media_type = normalized.media_type
|
||||
request_start = time.perf_counter()
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.ACTIVE_DECODE,
|
||||
{"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason},
|
||||
)
|
||||
with self.api.direct_tts_lock:
|
||||
tts_generator = self.api.tts.run(normalized_payload)
|
||||
try:
|
||||
sr, audio_data = next(tts_generator)
|
||||
except Exception as exc:
|
||||
self.api._fail_request_state(request_id, str(exc))
|
||||
raise
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.FINALIZING,
|
||||
{"backend": backend, "backend_mode": backend, "fallback_reason": fallback_reason},
|
||||
)
|
||||
pack_start = time.perf_counter()
|
||||
packed_audio = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
|
||||
pack_ms = max(0.0, (time.perf_counter() - pack_start) * 1000.0)
|
||||
self.api._complete_request_state(
|
||||
request_id,
|
||||
dict(
|
||||
self.api._build_legacy_direct_profile(
|
||||
backend=backend,
|
||||
fallback_reason=fallback_reason,
|
||||
request_start=request_start,
|
||||
finished_at=time.perf_counter(),
|
||||
sample_rate=int(sr),
|
||||
audio_bytes=len(packed_audio),
|
||||
pack_ms=pack_ms,
|
||||
),
|
||||
streaming_completed=False,
|
||||
),
|
||||
)
|
||||
return DirectTTSExecution(
|
||||
media_type=media_type,
|
||||
streaming=False,
|
||||
audio_bytes=packed_audio,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
async def _run_direct_tts_via_legacy_backend(
|
||||
self,
|
||||
normalized: NormalizedEngineRequest,
|
||||
*,
|
||||
backend: str,
|
||||
fallback_reason: str | None,
|
||||
) -> DirectTTSExecution:
|
||||
if normalized.response_streaming:
|
||||
return DirectTTSExecution(
|
||||
media_type=normalized.media_type,
|
||||
streaming=True,
|
||||
audio_generator=self._iter_legacy_direct_tts_bytes(
|
||||
normalized,
|
||||
backend=backend,
|
||||
fallback_reason=fallback_reason,
|
||||
),
|
||||
request_id=normalized.request_id,
|
||||
)
|
||||
return await asyncio.to_thread(
|
||||
self._run_legacy_direct_tts_blocking,
|
||||
normalized,
|
||||
backend=backend,
|
||||
fallback_reason=fallback_reason,
|
||||
)
|
||||
|
||||
async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution:
|
||||
normalized = self.api._normalize_engine_request(
|
||||
req,
|
||||
request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"),
|
||||
normalize_streaming=True,
|
||||
error_prefix="",
|
||||
)
|
||||
request_id = normalized.request_id
|
||||
media_type = normalized.media_type
|
||||
backend, fallback_reason = self.api._select_direct_backend(normalized)
|
||||
self.api._register_request_state(
|
||||
request_id=request_id,
|
||||
api_mode="tts",
|
||||
backend=backend,
|
||||
media_type=media_type,
|
||||
response_streaming=bool(normalized.response_streaming),
|
||||
deadline_ts=(time.perf_counter() + float(normalized.timeout_sec) if normalized.timeout_sec is not None else None),
|
||||
meta=self.api._build_request_meta(normalized.to_payload()),
|
||||
)
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.VALIDATED,
|
||||
{
|
||||
"request_source": "direct_tts",
|
||||
"selected_backend": backend,
|
||||
"fallback_reason": fallback_reason,
|
||||
},
|
||||
)
|
||||
if backend == "scheduler_v1_direct":
|
||||
try:
|
||||
return await self._run_direct_tts_via_scheduler(normalized)
|
||||
except Exception as exc:
|
||||
self.api._fail_request_state(request_id, str(exc))
|
||||
raise
|
||||
return await self._run_direct_tts_via_legacy_backend(
|
||||
normalized,
|
||||
backend=backend,
|
||||
fallback_reason=fallback_reason,
|
||||
)
|
||||
|
||||
def run_direct_tts(self, req: dict) -> DirectTTSExecution:
|
||||
normalized = self.api._normalize_engine_request(
|
||||
req,
|
||||
request_id=str(req.get("request_id") or f"direct_{uuid.uuid4().hex[:12]}"),
|
||||
normalize_streaming=True,
|
||||
error_prefix="",
|
||||
)
|
||||
request_id = normalized.request_id
|
||||
media_type = normalized.media_type
|
||||
backend, fallback_reason = self.api._select_direct_backend(normalized)
|
||||
if not self.api._has_active_request(request_id):
|
||||
self.api._register_request_state(
|
||||
request_id=request_id,
|
||||
api_mode="tts",
|
||||
backend=backend,
|
||||
media_type=media_type,
|
||||
response_streaming=bool(normalized.response_streaming),
|
||||
meta=self.api._build_request_meta(normalized.to_payload()),
|
||||
)
|
||||
self.api._update_request_state(
|
||||
request_id,
|
||||
EngineStatus.VALIDATED,
|
||||
{
|
||||
"request_source": "direct_tts",
|
||||
"selected_backend": backend,
|
||||
"fallback_reason": fallback_reason,
|
||||
},
|
||||
)
|
||||
if backend != "scheduler_v1_direct":
|
||||
if normalized.response_streaming:
|
||||
return DirectTTSExecution(
|
||||
media_type=media_type,
|
||||
streaming=True,
|
||||
audio_generator=self._iter_legacy_direct_tts_bytes(
|
||||
normalized,
|
||||
backend=backend,
|
||||
fallback_reason=fallback_reason,
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
return self._run_legacy_direct_tts_blocking(
|
||||
normalized,
|
||||
backend=backend,
|
||||
fallback_reason=fallback_reason,
|
||||
)
|
||||
if normalized.response_streaming:
|
||||
return DirectTTSExecution(
|
||||
media_type=media_type,
|
||||
streaming=True,
|
||||
audio_generator=self._iter_legacy_direct_tts_bytes(
|
||||
normalized,
|
||||
backend="legacy_direct_sync_compat",
|
||||
fallback_reason="sync_direct_compat",
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
return self._run_legacy_direct_tts_blocking(
|
||||
normalized,
|
||||
backend="legacy_direct_sync_compat",
|
||||
fallback_reason="sync_direct_compat",
|
||||
)
|
||||
388
GPT_SoVITS/TTS_infer_pack/unified_engine_api_profile.py
Normal file
388
GPT_SoVITS/TTS_infer_pack/unified_engine_api_profile.py
Normal file
@ -0,0 +1,388 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Sequence
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem, T2SRequestState
|
||||
|
||||
|
||||
def build_request_meta(payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
text = payload.get("text")
|
||||
prompt_text = payload.get("prompt_text")
|
||||
return {
|
||||
"text_len": 0 if text is None else len(str(text)),
|
||||
"prompt_text_len": 0 if prompt_text is None else len(str(prompt_text)),
|
||||
"text_lang": payload.get("text_lang"),
|
||||
"prompt_lang": payload.get("prompt_lang"),
|
||||
"ref_audio_path": payload.get("ref_audio_path"),
|
||||
}
|
||||
|
||||
|
||||
def sum_profile_field(items: Sequence[Dict[str, Any]], key: str) -> float:
|
||||
total = 0.0
|
||||
for item in items:
|
||||
value = item.get(key, 0.0)
|
||||
if isinstance(value, (int, float)):
|
||||
total += float(value)
|
||||
return total
|
||||
|
||||
|
||||
def aggregate_numeric_dicts(items: Sequence[Dict[str, Any]]) -> Dict[str, float]:
|
||||
totals: Dict[str, float] = {}
|
||||
for item in items:
|
||||
for key, value in item.items():
|
||||
if isinstance(value, (int, float)):
|
||||
totals[key] = totals.get(key, 0.0) + float(value)
|
||||
return totals
|
||||
|
||||
|
||||
def build_direct_segment_trace(
|
||||
segment_texts: Sequence[str],
|
||||
prepare_profiles: Sequence[Dict[str, Any]],
|
||||
worker_profiles: Sequence[Dict[str, Any]],
|
||||
) -> List[Dict[str, Any]]:
|
||||
results: List[Dict[str, Any]] = []
|
||||
for index, segment_text in enumerate(segment_texts):
|
||||
prepare_item = prepare_profiles[index] if index < len(prepare_profiles) else {}
|
||||
worker_item = worker_profiles[index] if index < len(worker_profiles) else {}
|
||||
prepare_profile = dict(prepare_item.get("prepare_profile", {}))
|
||||
results.append(
|
||||
{
|
||||
"segment_index": index,
|
||||
"request_id": prepare_item.get("request_id") or worker_item.get("request_id"),
|
||||
"text_len": len(str(segment_text)),
|
||||
"prepare_wall_ms": float(prepare_item.get("prepare_wall_ms", 0.0)),
|
||||
"prepare_profile_total_ms": float(prepare_item.get("prepare_profile_total_ms", 0.0)),
|
||||
"prepare_engine_gpu_queue_wait_ms": float(
|
||||
dict(prepare_item.get("prepare_profile", {})).get("engine_gpu_prepare_queue_wait_ms", 0.0)
|
||||
),
|
||||
"engine_policy_wait_ms": float(prepare_item.get("engine_policy_wait_ms", 0.0)),
|
||||
"engine_dispatch_wait_ms": float(prepare_item.get("engine_dispatch_wait_ms", 0.0)),
|
||||
"decode_admission_wait_ms": float(worker_item.get("decode_admission_wait_ms", 0.0)),
|
||||
"queue_wait_ms": float(worker_item.get("queue_wait_ms", 0.0)),
|
||||
"prefill_ms": float(worker_item.get("prefill_ms", 0.0)),
|
||||
"merge_ms": float(worker_item.get("merge_ms", 0.0)),
|
||||
"decode_ms": float(worker_item.get("decode_ms", 0.0)),
|
||||
"finalize_wait_ms": float(worker_item.get("finalize_wait_ms", 0.0)),
|
||||
"synth_ms": float(worker_item.get("synth_ms", 0.0)),
|
||||
"worker_total_ms": float(worker_item.get("worker_total_ms", 0.0)),
|
||||
"decode_steps": int(worker_item.get("decode_steps", 0)),
|
||||
"semantic_len": int(worker_item.get("semantic_len", 0)),
|
||||
"finish_reason": worker_item.get("finish_reason"),
|
||||
"norm_text": prepare_profile.get("norm_text"),
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def build_direct_scheduler_profile(
|
||||
*,
|
||||
backend: str,
|
||||
request_start: float,
|
||||
response_ready_at: float,
|
||||
audio_bytes: int,
|
||||
sample_rate: int,
|
||||
segment_texts: Sequence[str],
|
||||
prepare_profiles: Sequence[Dict[str, Any]],
|
||||
worker_profiles: Sequence[Dict[str, Any]],
|
||||
pack_ms: float,
|
||||
response_overhead_ms: float,
|
||||
) -> Dict[str, Any]:
|
||||
segment_trace = build_direct_segment_trace(segment_texts, prepare_profiles, worker_profiles)
|
||||
prepare_profile_dicts = [dict(item.get("prepare_profile", {})) for item in prepare_profiles]
|
||||
request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0)
|
||||
prepare_wall_ms = sum_profile_field(prepare_profiles, "prepare_wall_ms")
|
||||
prepare_profile_total_ms = sum_profile_field(prepare_profiles, "prepare_profile_total_ms")
|
||||
engine_policy_wait_ms = sum_profile_field(prepare_profiles, "engine_policy_wait_ms")
|
||||
engine_dispatch_wait_ms = sum_profile_field(prepare_profiles, "engine_dispatch_wait_ms")
|
||||
decode_admission_wait_ms = sum_profile_field(worker_profiles, "decode_admission_wait_ms")
|
||||
queue_wait_ms = sum_profile_field(worker_profiles, "queue_wait_ms")
|
||||
prefill_ms = sum_profile_field(worker_profiles, "prefill_ms")
|
||||
merge_ms = sum_profile_field(worker_profiles, "merge_ms")
|
||||
decode_ms = sum_profile_field(worker_profiles, "decode_ms")
|
||||
finalize_wait_ms = sum_profile_field(worker_profiles, "finalize_wait_ms")
|
||||
synth_ms = sum_profile_field(worker_profiles, "synth_ms")
|
||||
worker_total_ms = sum_profile_field(worker_profiles, "worker_total_ms")
|
||||
decode_steps = sum(int(item.get("decode_steps", 0)) for item in worker_profiles)
|
||||
semantic_len = sum(int(item.get("semantic_len", 0)) for item in worker_profiles)
|
||||
request_other_ms = max(
|
||||
0.0,
|
||||
request_total_ms - prepare_wall_ms - engine_policy_wait_ms - worker_total_ms - pack_ms - response_overhead_ms,
|
||||
)
|
||||
return {
|
||||
"backend": backend,
|
||||
"backend_mode": backend,
|
||||
"segment_count": len(segment_texts),
|
||||
"sample_rate": int(sample_rate),
|
||||
"audio_bytes": int(audio_bytes),
|
||||
"request_total_ms": request_total_ms,
|
||||
"prepare_ms": prepare_wall_ms,
|
||||
"prepare_wall_ms": prepare_wall_ms,
|
||||
"prepare_profile_total_ms": prepare_profile_total_ms,
|
||||
"engine_policy_wait_ms": engine_policy_wait_ms,
|
||||
"engine_dispatch_wait_ms": engine_dispatch_wait_ms,
|
||||
"decode_admission_wait_ms": decode_admission_wait_ms,
|
||||
"queue_wait_ms": queue_wait_ms,
|
||||
"prefill_ms": prefill_ms,
|
||||
"merge_ms": merge_ms,
|
||||
"decode_ms": decode_ms,
|
||||
"finalize_wait_ms": finalize_wait_ms,
|
||||
"synth_ms": synth_ms,
|
||||
"pack_ms": pack_ms,
|
||||
"response_overhead_ms": response_overhead_ms,
|
||||
"worker_total_ms": worker_total_ms,
|
||||
"request_other_ms": request_other_ms,
|
||||
"decode_steps": decode_steps,
|
||||
"semantic_len": semantic_len,
|
||||
"prepare_segments": list(prepare_profiles),
|
||||
"worker_segments": list(worker_profiles),
|
||||
"segment_trace": segment_trace,
|
||||
"prepare_aggregate": aggregate_numeric_dicts(prepare_profile_dicts),
|
||||
}
|
||||
|
||||
|
||||
def build_legacy_direct_profile(
|
||||
*,
|
||||
backend: str,
|
||||
fallback_reason: str | None,
|
||||
request_start: float,
|
||||
finished_at: float,
|
||||
sample_rate: int | None = None,
|
||||
audio_bytes: int = 0,
|
||||
pack_ms: float = 0.0,
|
||||
chunk_count: int = 0,
|
||||
stream_total_bytes: int = 0,
|
||||
first_chunk_ms: float | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
request_total_ms = max(0.0, (finished_at - request_start) * 1000.0)
|
||||
legacy_infer_ms = max(0.0, request_total_ms - pack_ms)
|
||||
return {
|
||||
"backend": backend,
|
||||
"backend_mode": backend,
|
||||
"fallback_reason": fallback_reason,
|
||||
"request_total_ms": request_total_ms,
|
||||
"prepare_ms": 0.0,
|
||||
"queue_wait_ms": 0.0,
|
||||
"prefill_ms": 0.0,
|
||||
"merge_ms": 0.0,
|
||||
"decode_ms": 0.0,
|
||||
"finalize_wait_ms": 0.0,
|
||||
"synth_ms": 0.0,
|
||||
"pack_ms": pack_ms,
|
||||
"worker_total_ms": legacy_infer_ms,
|
||||
"request_other_ms": 0.0,
|
||||
"legacy_infer_ms": legacy_infer_ms,
|
||||
"sample_rate": int(sample_rate) if sample_rate is not None else None,
|
||||
"audio_bytes": int(audio_bytes),
|
||||
"chunk_count": int(chunk_count),
|
||||
"stream_total_bytes": int(stream_total_bytes),
|
||||
"first_chunk_ms": None if first_chunk_ms is None else float(first_chunk_ms),
|
||||
}
|
||||
|
||||
|
||||
def build_scheduler_submit_profile(
|
||||
*,
|
||||
backend: str,
|
||||
request_start: float,
|
||||
response_ready_at: float,
|
||||
audio_bytes: int,
|
||||
sample_rate: int,
|
||||
prepare_spec_build_ms: float,
|
||||
prepare_wall_ms: float,
|
||||
prepare_executor_queue_ms: float,
|
||||
prepare_executor_run_ms: float,
|
||||
prepare_profile_total_ms: float,
|
||||
prepare_profile_wall_ms: float,
|
||||
prepare_other_ms: float,
|
||||
engine_policy_wait_ms: float,
|
||||
api_after_prepare_ms: float,
|
||||
api_wait_result_ms: float,
|
||||
pack_ms: float,
|
||||
response_overhead_ms: float,
|
||||
worker_profile: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
worker_total_ms = float(worker_profile.get("worker_total_ms", 0.0))
|
||||
request_total_ms = max(0.0, (response_ready_at - request_start) * 1000.0)
|
||||
request_other_ms = max(
|
||||
0.0,
|
||||
request_total_ms
|
||||
- prepare_wall_ms
|
||||
- engine_policy_wait_ms
|
||||
- api_after_prepare_ms
|
||||
- worker_total_ms
|
||||
- api_wait_result_ms
|
||||
- pack_ms,
|
||||
)
|
||||
result = {
|
||||
"backend": backend,
|
||||
"backend_mode": backend,
|
||||
"audio_bytes": int(audio_bytes),
|
||||
"sample_rate": int(sample_rate),
|
||||
"prepare_spec_build_ms": prepare_spec_build_ms,
|
||||
"prepare_ms": prepare_wall_ms,
|
||||
"prepare_wall_ms": prepare_wall_ms,
|
||||
"prepare_executor_queue_ms": prepare_executor_queue_ms,
|
||||
"prepare_executor_run_ms": prepare_executor_run_ms,
|
||||
"prepare_profile_total_ms": prepare_profile_total_ms,
|
||||
"prepare_profile_wall_ms": prepare_profile_wall_ms,
|
||||
"prepare_other_ms": prepare_other_ms,
|
||||
"engine_policy_wait_ms": float(engine_policy_wait_ms),
|
||||
"api_after_prepare_ms": api_after_prepare_ms,
|
||||
"api_wait_result_ms": api_wait_result_ms,
|
||||
"pack_ms": pack_ms,
|
||||
"response_overhead_ms": response_overhead_ms,
|
||||
"request_total_ms": request_total_ms,
|
||||
"request_other_ms": request_other_ms,
|
||||
}
|
||||
result.update({key: value for key, value in worker_profile.items()})
|
||||
return result
|
||||
|
||||
|
||||
def format_ms_header(value: Any) -> str:
|
||||
return f"{float(value):.3f}"
|
||||
|
||||
|
||||
def build_scheduler_submit_headers(
|
||||
*,
|
||||
request_id: str,
|
||||
media_type: str,
|
||||
sample_rate: int,
|
||||
profile: Dict[str, Any],
|
||||
) -> Dict[str, str]:
|
||||
prepare_profile = dict(profile.get("prepare_profile", {}))
|
||||
headers = {
|
||||
"X-Request-Id": request_id,
|
||||
"X-Semantic-Len": str(int(profile.get("semantic_len", 0))),
|
||||
"X-Finish-Reason": str(profile.get("finish_reason", "unknown")),
|
||||
"X-Queue-Wait-Ms": format_ms_header(profile.get("queue_wait_ms", 0.0)),
|
||||
"X-Decode-Admission-Wait-Ms": format_ms_header(profile.get("decode_admission_wait_ms", 0.0)),
|
||||
"X-Engine-Policy-Wait-Ms": format_ms_header(profile.get("engine_policy_wait_ms", 0.0)),
|
||||
"X-Engine-Dispatch-Wait-Ms": format_ms_header(profile.get("engine_dispatch_wait_ms", 0.0)),
|
||||
"X-Prepare-Ms": format_ms_header(profile.get("prepare_wall_ms", 0.0)),
|
||||
"X-Prepare-Wall-Ms": format_ms_header(profile.get("prepare_wall_ms", 0.0)),
|
||||
"X-Prepare-Spec-Build-Ms": format_ms_header(profile.get("prepare_spec_build_ms", 0.0)),
|
||||
"X-Prepare-Executor-Queue-Ms": format_ms_header(profile.get("prepare_executor_queue_ms", 0.0)),
|
||||
"X-Prepare-Admission-Wait-Ms": format_ms_header(prepare_profile.get("prepare_admission_wait_ms", 0.0)),
|
||||
"X-Prepare-Executor-Run-Ms": format_ms_header(profile.get("prepare_executor_run_ms", 0.0)),
|
||||
"X-Prepare-Profile-Total-Ms": format_ms_header(profile.get("prepare_profile_total_ms", 0.0)),
|
||||
"X-Prepare-Profile-Wall-Ms": format_ms_header(profile.get("prepare_profile_wall_ms", 0.0)),
|
||||
"X-Prepare-Other-Ms": format_ms_header(profile.get("prepare_other_ms", 0.0)),
|
||||
"X-Api-After-Prepare-Ms": format_ms_header(profile.get("api_after_prepare_ms", 0.0)),
|
||||
"X-Prefill-Ms": format_ms_header(profile.get("prefill_ms", 0.0)),
|
||||
"X-Merge-Ms": format_ms_header(profile.get("merge_ms", 0.0)),
|
||||
"X-Decode-Ms": format_ms_header(profile.get("decode_ms", 0.0)),
|
||||
"X-Finalize-Wait-Ms": format_ms_header(profile.get("finalize_wait_ms", 0.0)),
|
||||
"X-Synth-Ms": format_ms_header(profile.get("synth_ms", 0.0)),
|
||||
"X-Worker-Residual-Ms": format_ms_header(profile.get("worker_residual_ms", 0.0)),
|
||||
"X-Worker-Other-Ms": format_ms_header(profile.get("worker_other_ms", 0.0)),
|
||||
"X-Pack-Ms": format_ms_header(profile.get("pack_ms", 0.0)),
|
||||
"X-Worker-Total-Ms": format_ms_header(profile.get("worker_total_ms", 0.0)),
|
||||
"X-Api-Wait-Result-Ms": format_ms_header(profile.get("api_wait_result_ms", 0.0)),
|
||||
"X-Decode-Steps": str(int(profile.get("decode_steps", 0))),
|
||||
"X-Sample-Rate": str(int(sample_rate)),
|
||||
"X-Response-Overhead-Ms": format_ms_header(profile.get("response_overhead_ms", 0.0)),
|
||||
"X-Request-Other-Ms": format_ms_header(profile.get("request_other_ms", 0.0)),
|
||||
"X-Request-Total-Ms": format_ms_header(profile.get("request_total_ms", 0.0)),
|
||||
}
|
||||
headers.update(
|
||||
{
|
||||
"X-Prepare-Prompt-Text-Ms": format_ms_header(prepare_profile.get("prompt_text_features_ms", 0.0)),
|
||||
"X-Prepare-Target-Text-Ms": format_ms_header(prepare_profile.get("text_features_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Text-CPU-Preprocess-Ms": format_ms_header(prepare_profile.get("prompt_text_cpu_preprocess_ms", 0.0)),
|
||||
"X-Prepare-Target-Text-CPU-Preprocess-Ms": format_ms_header(prepare_profile.get("text_cpu_preprocess_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Text-CPU-Queue-Ms": format_ms_header(prepare_profile.get("prompt_text_cpu_queue_ms", 0.0)),
|
||||
"X-Prepare-Target-Text-CPU-Queue-Ms": format_ms_header(prepare_profile.get("text_cpu_queue_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Text-Feature-Queue-Ms": format_ms_header(prepare_profile.get("prompt_text_feature_queue_ms", 0.0)),
|
||||
"X-Prepare-Target-Text-Feature-Queue-Ms": format_ms_header(prepare_profile.get("text_feature_queue_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Bert-Wait-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_wait_ms", 0.0)),
|
||||
"X-Prepare-Target-Bert-Wait-Ms": format_ms_header(prepare_profile.get("text_bert_wait_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Bert-Admission-Wait-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_admission_wait_ms", 0.0)),
|
||||
"X-Prepare-Target-Bert-Admission-Wait-Ms": format_ms_header(prepare_profile.get("text_bert_admission_wait_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Bert-Queue-Wait-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_queue_wait_ms", 0.0)),
|
||||
"X-Prepare-Target-Bert-Queue-Wait-Ms": format_ms_header(prepare_profile.get("text_bert_queue_wait_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Bert-Batch-Collect-Wait-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_batch_collect_wait_ms", 0.0)),
|
||||
"X-Prepare-Target-Bert-Batch-Collect-Wait-Ms": format_ms_header(prepare_profile.get("text_bert_batch_collect_wait_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Bert-Forward-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_forward_ms", 0.0)),
|
||||
"X-Prepare-Target-Bert-Forward-Ms": format_ms_header(prepare_profile.get("text_bert_forward_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_enqueue_peak", 0.0))),
|
||||
"X-Prepare-Target-Bert-Pending-On-Enqueue-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_enqueue_peak", 0.0))),
|
||||
"X-Prepare-Prompt-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("prompt_text_bert_pending_depth_on_collect_peak", 0.0))),
|
||||
"X-Prepare-Target-Bert-Pending-On-Collect-Peak": str(int(prepare_profile.get("text_bert_pending_depth_on_collect_peak", 0.0))),
|
||||
"X-Prepare-Prompt-Bert-High-Pressure-Peak": str(int(prepare_profile.get("prompt_text_bert_high_pressure_mode_peak", 0.0))),
|
||||
"X-Prepare-Target-Bert-High-Pressure-Peak": str(int(prepare_profile.get("text_bert_high_pressure_mode_peak", 0.0))),
|
||||
"X-Prepare-Prompt-Bert-Batch-Window-Ms": format_ms_header(prepare_profile.get("prompt_text_bert_batch_window_ms", 0.0)),
|
||||
"X-Prepare-Target-Bert-Batch-Window-Ms": format_ms_header(prepare_profile.get("text_bert_batch_window_ms", 0.0)),
|
||||
"X-Prepare-Text-Pair-Wall-Ms": format_ms_header(prepare_profile.get("text_feature_pair_ms", 0.0)),
|
||||
"X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))),
|
||||
"X-Prepare-Engine-GPU-Queue-Wait-Ms": format_ms_header(prepare_profile.get("engine_gpu_prepare_queue_wait_ms", 0.0)),
|
||||
"X-Prepare-Engine-GPU-Batch-Size": str(int(prepare_profile.get("engine_gpu_prepare_batch_size", 0.0))),
|
||||
"X-Prepare-Audio-Load-Ms": format_ms_header(prepare_profile.get("audio_load_ms", 0.0)),
|
||||
"X-Prepare-Audio-Stage-Wait-Ms": format_ms_header(prepare_profile.get("audio_stage_wait_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Semantic-Ms": format_ms_header(prepare_profile.get("prompt_semantic_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Semantic-Wait-Ms": format_ms_header(prepare_profile.get("prompt_semantic_wait_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Semantic-CPU-Ms": format_ms_header(prepare_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)),
|
||||
"X-Prepare-Prompt-Semantic-Forward-Ms": format_ms_header(prepare_profile.get("prompt_semantic_forward_ms", 0.0)),
|
||||
"X-Prepare-Ref-Spec-Ms": format_ms_header(prepare_profile.get("ref_spec_ms", 0.0)),
|
||||
"X-Prepare-Ref-Spec-Wait-Ms": format_ms_header(prepare_profile.get("ref_spec_wait_ms", 0.0)),
|
||||
"X-Prepare-Ref-Bundle-Ms": format_ms_header(prepare_profile.get("ref_audio_bundle_ms", 0.0)),
|
||||
"X-Prepare-Tensorize-Ms": format_ms_header(prepare_profile.get("tensorize_ms", 0.0)),
|
||||
"X-Prepare-Inflight-On-Enter": str(int(prepare_profile.get("worker_prepare_inflight_on_enter", 0.0))),
|
||||
"X-Prepare-Inflight-Peak": str(int(prepare_profile.get("worker_prepare_peak_inflight", 0.0))),
|
||||
}
|
||||
)
|
||||
return headers
|
||||
|
||||
|
||||
def build_scheduler_debug_request_profile(
|
||||
*,
|
||||
state: T2SRequestState,
|
||||
item: T2SFinishedItem,
|
||||
batch_request_count: int,
|
||||
prepare_batch_wall_ms: float,
|
||||
decode_batch_wall_ms: float,
|
||||
batch_request_total_ms: float,
|
||||
) -> Dict[str, Any]:
|
||||
prepare_profile = dict(state.prepare_profile)
|
||||
prepare_wall_ms = float(prepare_profile.get("wall_total_ms", 0.0))
|
||||
return {
|
||||
"backend": "scheduler_debug",
|
||||
"backend_mode": "scheduler_debug",
|
||||
"batch_request_count": int(batch_request_count),
|
||||
"batch_prepare_wall_ms": float(prepare_batch_wall_ms),
|
||||
"batch_decode_wall_ms": float(decode_batch_wall_ms),
|
||||
"batch_request_total_ms": float(batch_request_total_ms),
|
||||
"prepare_ms": prepare_wall_ms,
|
||||
"prepare_wall_ms": prepare_wall_ms,
|
||||
"prepare_profile_total_ms": float(prepare_profile.get("wall_total_ms", prepare_wall_ms)),
|
||||
"prepare_profile": prepare_profile,
|
||||
"decode_steps": int(item.finish_idx),
|
||||
"finish_idx": int(item.finish_idx),
|
||||
"semantic_len": int(item.semantic_tokens.shape[0]),
|
||||
"finish_reason": item.finish_reason,
|
||||
"norm_text": state.norm_text,
|
||||
"norm_prompt_text": state.norm_prompt_text,
|
||||
}
|
||||
|
||||
|
||||
def build_scheduler_debug_batch_profile(
|
||||
*,
|
||||
request_count: int,
|
||||
max_steps: int,
|
||||
prepare_batch_wall_ms: float,
|
||||
decode_batch_wall_ms: float,
|
||||
request_total_ms: float,
|
||||
finished_items: Sequence[T2SFinishedItem],
|
||||
) -> Dict[str, Any]:
|
||||
finish_reason_counts: Dict[str, int] = {}
|
||||
total_semantic_len = 0
|
||||
for item in finished_items:
|
||||
finish_reason_counts[item.finish_reason] = finish_reason_counts.get(item.finish_reason, 0) + 1
|
||||
total_semantic_len += int(item.semantic_tokens.shape[0])
|
||||
return {
|
||||
"request_count": int(request_count),
|
||||
"max_steps": int(max_steps),
|
||||
"prepare_batch_wall_ms": float(prepare_batch_wall_ms),
|
||||
"decode_batch_wall_ms": float(decode_batch_wall_ms),
|
||||
"request_total_ms": float(request_total_ms),
|
||||
"total_semantic_len": int(total_semantic_len),
|
||||
"finish_reason_counts": finish_reason_counts,
|
||||
}
|
||||
189
GPT_SoVITS/TTS_infer_pack/unified_engine_api_request.py
Normal file
189
GPT_SoVITS/TTS_infer_pack/unified_engine_api_request.py
Normal file
@ -0,0 +1,189 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import NormalizedEngineRequest, ReferenceRegistry
|
||||
|
||||
|
||||
def normalize_lang(value: str | None) -> str | None:
|
||||
if value in [None, ""]:
|
||||
return value
|
||||
return str(value).lower()
|
||||
|
||||
|
||||
def apply_default_reference(reference_registry: ReferenceRegistry, req: dict) -> dict:
|
||||
normalized = dict(req)
|
||||
default_ref = reference_registry.get_default()
|
||||
if normalized.get("ref_audio_path") in [None, ""] and default_ref.ref_audio_path not in [None, ""]:
|
||||
normalized["ref_audio_path"] = default_ref.ref_audio_path
|
||||
if "text_lang" in normalized:
|
||||
normalized["text_lang"] = normalize_lang(normalized.get("text_lang"))
|
||||
if "prompt_lang" in normalized:
|
||||
normalized["prompt_lang"] = normalize_lang(normalized.get("prompt_lang"))
|
||||
return normalized
|
||||
|
||||
|
||||
def check_params(tts: TTS, cut_method_names: Sequence[str], req: dict) -> Optional[str]:
|
||||
text = req.get("text", "")
|
||||
text_lang = req.get("text_lang", "")
|
||||
ref_audio_path = req.get("ref_audio_path", "")
|
||||
media_type = req.get("media_type", "wav")
|
||||
prompt_lang = req.get("prompt_lang", "")
|
||||
text_split_method = req.get("text_split_method", "cut5")
|
||||
|
||||
if ref_audio_path in [None, ""]:
|
||||
return "ref_audio_path is required"
|
||||
if text in [None, ""]:
|
||||
return "text is required"
|
||||
if text_lang in [None, ""]:
|
||||
return "text_lang is required"
|
||||
if text_lang.lower() not in tts.configs.languages:
|
||||
return f"text_lang: {text_lang} is not supported in version {tts.configs.version}"
|
||||
if prompt_lang in [None, ""]:
|
||||
return "prompt_lang is required"
|
||||
if prompt_lang.lower() not in tts.configs.languages:
|
||||
return f"prompt_lang: {prompt_lang} is not supported in version {tts.configs.version}"
|
||||
if media_type not in ["wav", "raw", "ogg", "aac"]:
|
||||
return f"media_type: {media_type} is not supported"
|
||||
if text_split_method not in cut_method_names:
|
||||
return f"text_split_method:{text_split_method} is not supported"
|
||||
return None
|
||||
|
||||
|
||||
def base_request_defaults() -> Dict[str, Any]:
|
||||
return {
|
||||
"request_id": None,
|
||||
"text": None,
|
||||
"text_lang": None,
|
||||
"ref_audio_path": None,
|
||||
"aux_ref_audio_paths": None,
|
||||
"prompt_text": "",
|
||||
"prompt_lang": None,
|
||||
"top_k": 15,
|
||||
"top_p": 1.0,
|
||||
"temperature": 1.0,
|
||||
"text_split_method": "cut5",
|
||||
"batch_size": 1,
|
||||
"batch_threshold": 0.75,
|
||||
"speed_factor": 1.0,
|
||||
"split_bucket": False,
|
||||
"fragment_interval": 0.3,
|
||||
"seed": -1,
|
||||
"media_type": "wav",
|
||||
"streaming_mode": False,
|
||||
"return_fragment": False,
|
||||
"fixed_length_chunk": False,
|
||||
"response_streaming": False,
|
||||
"parallel_infer": False,
|
||||
"repetition_penalty": 1.35,
|
||||
"sample_steps": 32,
|
||||
"super_sampling": False,
|
||||
"overlap_length": 2,
|
||||
"min_chunk_length": 16,
|
||||
"early_stop_num": -1,
|
||||
"ready_step": 0,
|
||||
"timeout_sec": None,
|
||||
}
|
||||
|
||||
|
||||
def normalize_streaming_mode(req: dict) -> dict:
|
||||
normalized = dict(req)
|
||||
streaming_mode = normalized.get("streaming_mode", False)
|
||||
return_fragment = normalized.get("return_fragment", False)
|
||||
if streaming_mode is False:
|
||||
normalized["streaming_mode"] = False
|
||||
normalized["return_fragment"] = False
|
||||
normalized["fixed_length_chunk"] = False
|
||||
elif streaming_mode == 0:
|
||||
normalized["streaming_mode"] = False
|
||||
normalized["return_fragment"] = False
|
||||
normalized["fixed_length_chunk"] = False
|
||||
elif streaming_mode == 1 or streaming_mode is True:
|
||||
normalized["streaming_mode"] = False
|
||||
normalized["return_fragment"] = True
|
||||
normalized["fixed_length_chunk"] = False
|
||||
elif streaming_mode == 2:
|
||||
normalized["streaming_mode"] = True
|
||||
normalized["return_fragment"] = False
|
||||
normalized["fixed_length_chunk"] = False
|
||||
elif streaming_mode == 3:
|
||||
normalized["streaming_mode"] = True
|
||||
normalized["return_fragment"] = False
|
||||
normalized["fixed_length_chunk"] = True
|
||||
else:
|
||||
raise ValueError("the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)")
|
||||
normalized["response_streaming"] = bool(normalized["streaming_mode"] or normalized["return_fragment"] or return_fragment)
|
||||
return normalized
|
||||
|
||||
|
||||
def is_aux_ref_enabled(aux_ref_audio_paths: List[str] | None) -> bool:
|
||||
return aux_ref_audio_paths not in [None, [], ()]
|
||||
|
||||
|
||||
def select_direct_backend(normalized: NormalizedEngineRequest) -> Tuple[str, str | None]:
|
||||
return "scheduler_v1_direct", None
|
||||
|
||||
|
||||
def normalize_engine_request(
|
||||
*,
|
||||
tts: TTS,
|
||||
cut_method_names: Sequence[str],
|
||||
reference_registry: ReferenceRegistry,
|
||||
payload: dict | NormalizedEngineRequest,
|
||||
request_id: str | None = None,
|
||||
normalize_streaming: bool = False,
|
||||
error_prefix: str = "request 参数非法: ",
|
||||
) -> NormalizedEngineRequest:
|
||||
if isinstance(payload, NormalizedEngineRequest):
|
||||
normalized_payload = payload.to_payload()
|
||||
else:
|
||||
normalized_payload = base_request_defaults()
|
||||
normalized_payload.update(dict(payload))
|
||||
if request_id not in [None, ""]:
|
||||
normalized_payload["request_id"] = str(request_id)
|
||||
elif normalized_payload.get("request_id") in [None, ""]:
|
||||
raise ValueError("request_id is required after normalization")
|
||||
normalized_payload = apply_default_reference(reference_registry, normalized_payload)
|
||||
if normalize_streaming:
|
||||
normalized_payload = normalize_streaming_mode(normalized_payload)
|
||||
error = check_params(tts, cut_method_names, normalized_payload)
|
||||
if error is not None:
|
||||
raise ValueError(f"{error_prefix}{error}")
|
||||
timeout_sec = normalized_payload.get("timeout_sec")
|
||||
parsed_timeout = None if timeout_sec in [None, ""] else float(timeout_sec)
|
||||
aux_ref_audio_paths = normalized_payload.get("aux_ref_audio_paths")
|
||||
normalized_aux_ref_audio_paths = None if aux_ref_audio_paths in [None, "", []] else [str(item) for item in aux_ref_audio_paths]
|
||||
return NormalizedEngineRequest(
|
||||
request_id=str(normalized_payload["request_id"]),
|
||||
text=str(normalized_payload["text"]),
|
||||
text_lang=str(normalized_payload["text_lang"]),
|
||||
ref_audio_path=str(normalized_payload["ref_audio_path"]),
|
||||
prompt_lang=str(normalized_payload["prompt_lang"]),
|
||||
prompt_text="" if normalized_payload.get("prompt_text") is None else str(normalized_payload.get("prompt_text")),
|
||||
aux_ref_audio_paths=normalized_aux_ref_audio_paths,
|
||||
top_k=int(normalized_payload["top_k"]),
|
||||
top_p=float(normalized_payload["top_p"]),
|
||||
temperature=float(normalized_payload["temperature"]),
|
||||
repetition_penalty=float(normalized_payload["repetition_penalty"]),
|
||||
early_stop_num=int(normalized_payload.get("early_stop_num", -1)),
|
||||
ready_step=int(normalized_payload.get("ready_step", 0)),
|
||||
text_split_method=str(normalized_payload["text_split_method"]),
|
||||
batch_size=int(normalized_payload["batch_size"]),
|
||||
batch_threshold=float(normalized_payload["batch_threshold"]),
|
||||
split_bucket=bool(normalized_payload["split_bucket"]),
|
||||
speed_factor=float(normalized_payload["speed_factor"]),
|
||||
fragment_interval=float(normalized_payload["fragment_interval"]),
|
||||
seed=int(normalized_payload["seed"]),
|
||||
media_type=str(normalized_payload["media_type"]),
|
||||
streaming_mode=normalized_payload["streaming_mode"],
|
||||
return_fragment=bool(normalized_payload.get("return_fragment", False)),
|
||||
fixed_length_chunk=bool(normalized_payload.get("fixed_length_chunk", False)),
|
||||
response_streaming=bool(normalized_payload.get("response_streaming", False)),
|
||||
parallel_infer=bool(normalized_payload["parallel_infer"]),
|
||||
sample_steps=int(normalized_payload["sample_steps"]),
|
||||
super_sampling=bool(normalized_payload["super_sampling"]),
|
||||
overlap_length=int(normalized_payload["overlap_length"]),
|
||||
min_chunk_length=int(normalized_payload["min_chunk_length"]),
|
||||
timeout_sec=parsed_timeout,
|
||||
)
|
||||
340
GPT_SoVITS/TTS_infer_pack/unified_engine_api_scheduler.py
Normal file
340
GPT_SoVITS/TTS_infer_pack/unified_engine_api_scheduler.py
Normal file
@ -0,0 +1,340 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SFinishedItem, T2SRequestState
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_audio import pack_audio, set_scheduler_seed
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, NormalizedEngineRequest, SchedulerDebugExecution, SchedulerSubmitExecution
|
||||
|
||||
|
||||
class EngineApiSchedulerFlow:
|
||||
def __init__(self, api: Any) -> None:
|
||||
self.api = api
|
||||
|
||||
def _build_scheduler_request_specs(self, request_items: List[dict]) -> List[SchedulerRequestSpec]:
|
||||
specs: List[SchedulerRequestSpec] = []
|
||||
for index, payload in enumerate(request_items):
|
||||
normalized = self.api._normalize_engine_request(
|
||||
payload,
|
||||
request_id=str(payload.get("request_id") or f"req_{index:03d}"),
|
||||
error_prefix=f"request[{index}] 参数非法: ",
|
||||
)
|
||||
specs.append(normalized.to_scheduler_spec())
|
||||
return specs
|
||||
|
||||
def _build_scheduler_submit_spec(self, payload: dict | NormalizedEngineRequest) -> SchedulerRequestSpec:
|
||||
normalized = self.api._normalize_engine_request(
|
||||
payload,
|
||||
request_id=(
|
||||
payload.request_id
|
||||
if isinstance(payload, NormalizedEngineRequest)
|
||||
else str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}")
|
||||
),
|
||||
)
|
||||
return normalized.to_scheduler_spec()
|
||||
|
||||
@staticmethod
|
||||
def _summarize_scheduler_states(states: List[T2SRequestState]) -> List[dict]:
|
||||
return [
|
||||
{
|
||||
"request_id": state.request_id,
|
||||
"ready_step": int(state.ready_step),
|
||||
"ref_audio_path": str(state.ref_audio_path),
|
||||
"prompt_semantic_len": int(state.prompt_semantic.shape[0]),
|
||||
"all_phone_len": int(state.all_phones.shape[0]),
|
||||
"bert_len": int(state.all_bert_features.shape[-1]),
|
||||
"norm_text": state.norm_text,
|
||||
}
|
||||
for state in states
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _summarize_scheduler_finished(items: List[T2SFinishedItem]) -> List[dict]:
|
||||
return [
|
||||
{
|
||||
"request_id": item.request_id,
|
||||
"semantic_len": int(item.semantic_tokens.shape[0]),
|
||||
"finish_idx": int(item.finish_idx),
|
||||
"finish_reason": item.finish_reason,
|
||||
}
|
||||
for item in items
|
||||
]
|
||||
|
||||
async def run_scheduler_debug(self, request_items: List[dict], max_steps: int, seed: int) -> SchedulerDebugExecution:
|
||||
request_start = time.perf_counter()
|
||||
set_scheduler_seed(seed)
|
||||
normalized_requests: List[NormalizedEngineRequest] = []
|
||||
for index, payload in enumerate(request_items):
|
||||
normalized_requests.append(
|
||||
self.api._normalize_engine_request(
|
||||
payload,
|
||||
request_id=str(payload.get("request_id") or f"req_{index:03d}"),
|
||||
error_prefix=f"request[{index}] 参数非法: ",
|
||||
)
|
||||
)
|
||||
specs = [normalized.to_scheduler_spec() for normalized in normalized_requests]
|
||||
request_ids = [normalized.request_id for normalized in normalized_requests]
|
||||
for normalized, spec in zip(normalized_requests, specs):
|
||||
self.api._register_request_state(
|
||||
request_id=normalized.request_id,
|
||||
api_mode="scheduler_debug",
|
||||
backend="scheduler_debug",
|
||||
media_type=normalized.media_type,
|
||||
response_streaming=False,
|
||||
meta=self.api._build_request_meta(normalized.to_payload()),
|
||||
)
|
||||
self.api._update_request_state(normalized.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_debug"})
|
||||
self.api._update_request_state(normalized.request_id, EngineStatus.CPU_PREPARING, None)
|
||||
prepare_started_at = time.perf_counter()
|
||||
original_worker_max_steps = int(self.api.scheduler_worker.max_steps)
|
||||
original_decode_max_steps = int(self.api.scheduler_worker.decode_executor.max_steps)
|
||||
try:
|
||||
self.api.scheduler_worker.max_steps = int(max_steps)
|
||||
self.api.scheduler_worker.decode_executor.max_steps = int(max_steps)
|
||||
prepared_payloads = await asyncio.gather(
|
||||
*[
|
||||
self.api._prepare_state_via_engine_gpu_queue(
|
||||
spec=spec,
|
||||
prepare_submit_at=time.perf_counter(),
|
||||
engine_request_id=normalized.request_id,
|
||||
)
|
||||
for normalized, spec in zip(normalized_requests, specs)
|
||||
]
|
||||
)
|
||||
except Exception as exc:
|
||||
for request_id in request_ids:
|
||||
self.api._fail_request_state(request_id, str(exc))
|
||||
raise
|
||||
finally:
|
||||
self.api.scheduler_worker.max_steps = int(original_worker_max_steps)
|
||||
self.api.scheduler_worker.decode_executor.max_steps = int(original_decode_max_steps)
|
||||
prepare_finished_at = time.perf_counter()
|
||||
prepare_batch_wall_ms = max(0.0, (prepare_finished_at - prepare_started_at) * 1000.0)
|
||||
states = [payload[0] for payload in prepared_payloads]
|
||||
for state in states:
|
||||
self.api._update_request_state(
|
||||
state.request_id,
|
||||
EngineStatus.READY_FOR_PREFILL,
|
||||
{
|
||||
"prepare_profile": dict(state.prepare_profile),
|
||||
"norm_text": state.norm_text,
|
||||
"norm_prompt_text": state.norm_prompt_text,
|
||||
},
|
||||
)
|
||||
decode_started_at = time.perf_counter()
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
done_futures: List[asyncio.Future] = []
|
||||
for normalized, state in zip(normalized_requests, states):
|
||||
done_future = loop.create_future()
|
||||
done_futures.append(done_future)
|
||||
await self.api._enqueue_prepared_state_for_dispatch(
|
||||
state=state,
|
||||
speed_factor=float(normalized.speed_factor),
|
||||
sample_steps=int(normalized.sample_steps),
|
||||
media_type=normalized.media_type,
|
||||
super_sampling=bool(normalized.super_sampling),
|
||||
prepare_wall_ms=float(state.prepare_profile.get("wall_total_ms", 0.0)),
|
||||
prepare_profile_total_ms=float(state.prepare_profile.get("wall_total_ms", 0.0)),
|
||||
done_loop=loop,
|
||||
done_future=done_future,
|
||||
engine_request_id=normalized.request_id,
|
||||
timeout_sec=normalized.timeout_sec,
|
||||
)
|
||||
timeout_candidates = [float(item.timeout_sec) for item in normalized_requests if item.timeout_sec not in [None, ""]]
|
||||
timeout_sec = max(timeout_candidates) if timeout_candidates else 60.0
|
||||
jobs = list(await asyncio.wait_for(asyncio.gather(*done_futures), timeout=float(timeout_sec)))
|
||||
except Exception as exc:
|
||||
for request_id in request_ids:
|
||||
self.api._fail_request_state(request_id, str(exc))
|
||||
raise
|
||||
decode_finished_at = time.perf_counter()
|
||||
decode_batch_wall_ms = max(0.0, (decode_finished_at - decode_started_at) * 1000.0)
|
||||
request_total_ms = max(0.0, (decode_finished_at - request_start) * 1000.0)
|
||||
request_profiles: List[Dict[str, Any]] = []
|
||||
finished: List[Dict[str, Any]] = []
|
||||
finish_reason_counts: Dict[str, int] = {}
|
||||
total_semantic_len = 0
|
||||
for state, job in zip(states, jobs):
|
||||
if job.error is not None:
|
||||
self.api._fail_request_state(state.request_id, str(job.error))
|
||||
raise RuntimeError(str(job.error))
|
||||
if job.result is None:
|
||||
self.api._fail_request_state(state.request_id, "scheduler_debug finished without result")
|
||||
raise RuntimeError(f"{state.request_id} finished without result")
|
||||
job_result = dict(job.result)
|
||||
request_profile = {
|
||||
**job_result,
|
||||
"backend": "scheduler_debug",
|
||||
"backend_mode": "scheduler_debug",
|
||||
"batch_request_count": int(len(states)),
|
||||
"batch_prepare_wall_ms": float(prepare_batch_wall_ms),
|
||||
"batch_decode_wall_ms": float(decode_batch_wall_ms),
|
||||
"batch_request_total_ms": float(request_total_ms),
|
||||
"prepare_ms": float(state.prepare_profile.get("wall_total_ms", 0.0)),
|
||||
"prepare_wall_ms": float(state.prepare_profile.get("wall_total_ms", 0.0)),
|
||||
"prepare_profile_total_ms": float(state.prepare_profile.get("wall_total_ms", 0.0)),
|
||||
"prepare_profile": dict(state.prepare_profile),
|
||||
"norm_text": state.norm_text,
|
||||
"norm_prompt_text": state.norm_prompt_text,
|
||||
}
|
||||
request_profiles.append({"request_id": state.request_id, "profile": dict(request_profile)})
|
||||
self.api._merge_request_state_profile(state.request_id, request_profile)
|
||||
semantic_len = int(job_result.get("semantic_len", 0))
|
||||
finish_reason = str(job_result.get("finish_reason", "unknown"))
|
||||
finished.append(
|
||||
{
|
||||
"request_id": state.request_id,
|
||||
"semantic_len": semantic_len,
|
||||
"finish_idx": int(job_result.get("finish_idx", job_result.get("decode_steps", 0))),
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
)
|
||||
finish_reason_counts[finish_reason] = finish_reason_counts.get(finish_reason, 0) + 1
|
||||
total_semantic_len += semantic_len
|
||||
return SchedulerDebugExecution(
|
||||
payload={
|
||||
"message": "success",
|
||||
"request_count": len(states),
|
||||
"max_steps": int(max_steps),
|
||||
"batch_profile": {
|
||||
"request_count": int(len(states)),
|
||||
"max_steps": int(max_steps),
|
||||
"prepare_batch_wall_ms": float(prepare_batch_wall_ms),
|
||||
"decode_batch_wall_ms": float(decode_batch_wall_ms),
|
||||
"request_total_ms": float(request_total_ms),
|
||||
"total_semantic_len": int(total_semantic_len),
|
||||
"finish_reason_counts": finish_reason_counts,
|
||||
},
|
||||
"requests": self._summarize_scheduler_states(states),
|
||||
"finished": finished,
|
||||
"request_profiles": request_profiles,
|
||||
"request_traces": self.api._collect_request_summaries(request_ids),
|
||||
}
|
||||
)
|
||||
|
||||
async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution:
|
||||
request_start = time.perf_counter()
|
||||
prepare_start = request_start
|
||||
normalized = self.api._normalize_engine_request(
|
||||
payload,
|
||||
request_id=str(payload.get("request_id") or f"job_{uuid.uuid4().hex[:12]}"),
|
||||
)
|
||||
spec = self._build_scheduler_submit_spec(normalized)
|
||||
deadline_ts = None
|
||||
timeout_sec = normalized.timeout_sec
|
||||
if timeout_sec is not None:
|
||||
try:
|
||||
deadline_ts = request_start + float(timeout_sec)
|
||||
except Exception:
|
||||
deadline_ts = None
|
||||
self.api._register_request_state(
|
||||
request_id=spec.request_id,
|
||||
api_mode="scheduler_submit",
|
||||
backend="scheduler_v1",
|
||||
media_type=normalized.media_type,
|
||||
response_streaming=False,
|
||||
deadline_ts=deadline_ts,
|
||||
meta=self.api._build_request_meta(normalized.to_payload()),
|
||||
)
|
||||
self.api._update_request_state(spec.request_id, EngineStatus.VALIDATED, {"request_source": "scheduler_submit"})
|
||||
spec_ready_at = time.perf_counter()
|
||||
prepare_spec_build_ms = max(0.0, (spec_ready_at - prepare_start) * 1000.0)
|
||||
self.api._update_request_state(spec.request_id, EngineStatus.CPU_PREPARING, {"prepare_spec_build_ms": prepare_spec_build_ms})
|
||||
try:
|
||||
state, prepare_exec_started_at, prepare_exec_finished_at = await self.api._prepare_state_via_engine_gpu_queue(
|
||||
spec=spec,
|
||||
prepare_submit_at=spec_ready_at,
|
||||
engine_request_id=spec.request_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
self.api._fail_request_state(spec.request_id, str(exc))
|
||||
raise
|
||||
prepare_wall_ms = max(0.0, (prepare_exec_finished_at - spec_ready_at) * 1000.0)
|
||||
prepare_executor_queue_ms = max(0.0, (prepare_exec_started_at - spec_ready_at) * 1000.0)
|
||||
prepare_executor_run_ms = max(0.0, (prepare_exec_finished_at - prepare_exec_started_at) * 1000.0)
|
||||
prepare_profile = dict(state.prepare_profile)
|
||||
prepare_profile_total_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms))
|
||||
prepare_profile_wall_ms = float(prepare_profile.get("wall_total_ms", prepare_wall_ms))
|
||||
prepare_other_ms = max(0.0, prepare_wall_ms - prepare_spec_build_ms - prepare_executor_queue_ms - prepare_executor_run_ms)
|
||||
self.api._update_request_state(
|
||||
spec.request_id,
|
||||
EngineStatus.READY_FOR_PREFILL,
|
||||
{
|
||||
"prepare_wall_ms": prepare_wall_ms,
|
||||
"prepare_profile_total_ms": prepare_profile_total_ms,
|
||||
"prepare_profile": prepare_profile,
|
||||
},
|
||||
)
|
||||
api_after_prepare_start = time.perf_counter()
|
||||
loop = asyncio.get_running_loop()
|
||||
done_future = loop.create_future()
|
||||
await self.api._enqueue_prepared_state_for_dispatch(
|
||||
state=state,
|
||||
speed_factor=float(normalized.speed_factor),
|
||||
sample_steps=int(normalized.sample_steps),
|
||||
media_type=normalized.media_type,
|
||||
super_sampling=bool(normalized.super_sampling),
|
||||
prepare_wall_ms=prepare_wall_ms,
|
||||
prepare_profile_total_ms=prepare_profile_total_ms,
|
||||
done_loop=loop,
|
||||
done_future=done_future,
|
||||
engine_request_id=spec.request_id,
|
||||
timeout_sec=normalized.timeout_sec,
|
||||
)
|
||||
api_after_prepare_ms = max(0.0, (time.perf_counter() - api_after_prepare_start) * 1000.0)
|
||||
try:
|
||||
job = await asyncio.wait_for(done_future, timeout=float(normalized.timeout_sec if normalized.timeout_sec is not None else 30.0))
|
||||
except Exception as exc:
|
||||
self.api._fail_request_state(spec.request_id, str(exc))
|
||||
raise
|
||||
wait_return_at = time.perf_counter()
|
||||
if job.error is not None:
|
||||
raise RuntimeError(job.error)
|
||||
if job.audio_data is None or job.sample_rate is None or job.result is None:
|
||||
self.api._fail_request_state(spec.request_id, f"{job.request_id} finished without audio result")
|
||||
raise RuntimeError(f"{job.request_id} finished without audio result")
|
||||
pack_start = time.perf_counter()
|
||||
audio_data = pack_audio(BytesIO(), job.audio_data, int(job.sample_rate), job.media_type).getvalue()
|
||||
pack_end = time.perf_counter()
|
||||
pack_ms = (pack_end - pack_start) * 1000.0
|
||||
api_wait_result_ms = 0.0
|
||||
if job.result_ready_time is not None:
|
||||
api_wait_result_ms = max(0.0, (wait_return_at - job.result_ready_time) * 1000.0)
|
||||
response_ready_at = time.perf_counter()
|
||||
response_overhead_ms = max(0.0, (response_ready_at - pack_end) * 1000.0)
|
||||
submit_profile = self.api._build_scheduler_submit_profile(
|
||||
backend="scheduler_v1",
|
||||
request_start=request_start,
|
||||
response_ready_at=response_ready_at,
|
||||
audio_bytes=len(audio_data),
|
||||
sample_rate=int(job.sample_rate),
|
||||
prepare_spec_build_ms=prepare_spec_build_ms,
|
||||
prepare_wall_ms=prepare_wall_ms,
|
||||
prepare_executor_queue_ms=prepare_executor_queue_ms,
|
||||
prepare_executor_run_ms=prepare_executor_run_ms,
|
||||
prepare_profile_total_ms=prepare_profile_total_ms,
|
||||
prepare_profile_wall_ms=prepare_profile_wall_ms,
|
||||
prepare_other_ms=prepare_other_ms,
|
||||
engine_policy_wait_ms=float(job.result.get("engine_policy_wait_ms", 0.0)),
|
||||
api_after_prepare_ms=api_after_prepare_ms,
|
||||
api_wait_result_ms=api_wait_result_ms,
|
||||
pack_ms=pack_ms,
|
||||
response_overhead_ms=response_overhead_ms,
|
||||
worker_profile=dict(job.result or {}),
|
||||
)
|
||||
headers = self.api._build_scheduler_submit_headers(
|
||||
request_id=job.request_id,
|
||||
media_type=job.media_type,
|
||||
sample_rate=int(job.sample_rate),
|
||||
profile=submit_profile,
|
||||
)
|
||||
self.api._merge_request_state_profile(
|
||||
spec.request_id,
|
||||
dict(submit_profile, response_headers_emitted=True),
|
||||
)
|
||||
return SchedulerSubmitExecution(audio_bytes=audio_data, media_type=str(job.media_type), headers=headers)
|
||||
106
GPT_SoVITS/TTS_infer_pack/unified_engine_audio.py
Normal file
106
GPT_SoVITS/TTS_infer_pack/unified_engine_audio.py
Normal file
@ -0,0 +1,106 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import threading
|
||||
import wave
|
||||
from io import BytesIO
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
|
||||
|
||||
def set_scheduler_seed(seed: int):
|
||||
if seed in ["", None]:
|
||||
return
|
||||
seed = int(seed)
|
||||
if seed < 0:
|
||||
return
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
|
||||
def handle_pack_ogg():
|
||||
with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
|
||||
audio_file.write(data)
|
||||
|
||||
stack_size = 4096 * 4096
|
||||
try:
|
||||
threading.stack_size(stack_size)
|
||||
pack_ogg_thread = threading.Thread(target=handle_pack_ogg)
|
||||
pack_ogg_thread.start()
|
||||
pack_ogg_thread.join()
|
||||
except (RuntimeError, ValueError):
|
||||
handle_pack_ogg()
|
||||
return io_buffer
|
||||
|
||||
|
||||
def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int):
|
||||
io_buffer.write(data.tobytes())
|
||||
return io_buffer
|
||||
|
||||
|
||||
def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int):
|
||||
io_buffer = BytesIO()
|
||||
sf.write(io_buffer, data, rate, format="wav")
|
||||
return io_buffer
|
||||
|
||||
|
||||
def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
|
||||
process = subprocess.Popen(
|
||||
[
|
||||
"ffmpeg",
|
||||
"-f",
|
||||
"s16le",
|
||||
"-ar",
|
||||
str(rate),
|
||||
"-ac",
|
||||
"1",
|
||||
"-i",
|
||||
"pipe:0",
|
||||
"-c:a",
|
||||
"aac",
|
||||
"-b:a",
|
||||
"192k",
|
||||
"-vn",
|
||||
"-f",
|
||||
"adts",
|
||||
"pipe:1",
|
||||
],
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
out, _ = process.communicate(input=data.tobytes())
|
||||
io_buffer.write(out)
|
||||
return io_buffer
|
||||
|
||||
|
||||
def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str):
|
||||
if media_type == "ogg":
|
||||
io_buffer = pack_ogg(io_buffer, data, rate)
|
||||
elif media_type == "aac":
|
||||
io_buffer = pack_aac(io_buffer, data, rate)
|
||||
elif media_type == "wav":
|
||||
io_buffer = pack_wav(io_buffer, data, rate)
|
||||
else:
|
||||
io_buffer = pack_raw(io_buffer, data, rate)
|
||||
io_buffer.seek(0)
|
||||
return io_buffer
|
||||
|
||||
|
||||
def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
|
||||
wav_buf = BytesIO()
|
||||
with wave.open(wav_buf, "wb") as vfout:
|
||||
vfout.setnchannels(channels)
|
||||
vfout.setsampwidth(sample_width)
|
||||
vfout.setframerate(sample_rate)
|
||||
vfout.writeframes(frame_input)
|
||||
wav_buf.seek(0)
|
||||
return wav_buf.read()
|
||||
|
||||
|
||||
21
GPT_SoVITS/TTS_infer_pack/unified_engine_bridge.py
Normal file
21
GPT_SoVITS/TTS_infer_pack/unified_engine_bridge.py
Normal file
@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge_registry import EngineRegistryBridgeFacade
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge_runtime import EngineRuntimeBridgeFacade
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge_stage import EngineStageBridgeFacade
|
||||
|
||||
|
||||
class EngineBridgeFacade:
|
||||
def __init__(self, owner: Any) -> None:
|
||||
self.owner = owner
|
||||
self.registry_bridge = EngineRegistryBridgeFacade(owner)
|
||||
self.stage_bridge = EngineStageBridgeFacade(owner)
|
||||
self.runtime_bridge = EngineRuntimeBridgeFacade(owner)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
for component in (self.registry_bridge, self.stage_bridge, self.runtime_bridge):
|
||||
if hasattr(component, name):
|
||||
return getattr(component, name)
|
||||
raise AttributeError(name)
|
||||
202
GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_delegates.py
Normal file
202
GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_delegates.py
Normal file
@ -0,0 +1,202 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SActiveBatch, T2SFinishedItem, T2SRequestState
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge import EngineBridgeFacade
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDispatchTask, EngineRequestState, SchedulerFinalizeTask, SchedulerPendingJob
|
||||
|
||||
|
||||
class EngineBridgeDelegates:
|
||||
def _register_request_state(
|
||||
self,
|
||||
request_id: str,
|
||||
api_mode: str,
|
||||
backend: str,
|
||||
media_type: str,
|
||||
response_streaming: bool,
|
||||
deadline_ts: float | None = None,
|
||||
meta: Optional[Dict[str, Any]] = None,
|
||||
) -> EngineRequestState:
|
||||
return self.bridge_facade._register_request_state(
|
||||
request_id=request_id,
|
||||
api_mode=api_mode,
|
||||
backend=backend,
|
||||
media_type=media_type,
|
||||
response_streaming=response_streaming,
|
||||
deadline_ts=deadline_ts,
|
||||
meta=meta,
|
||||
)
|
||||
|
||||
def _update_request_state(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None:
|
||||
self.bridge_facade._update_request_state(request_id, status, extra)
|
||||
|
||||
def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
|
||||
self.bridge_facade._merge_request_state_profile(request_id, extra)
|
||||
|
||||
def _snapshot_engine_prepare_state(self) -> Dict[str, Any]:
|
||||
return self.bridge_facade._snapshot_engine_prepare_state()
|
||||
|
||||
def _snapshot_engine_finalize_state(self) -> Dict[str, Any]:
|
||||
return self.bridge_facade._snapshot_engine_finalize_state()
|
||||
|
||||
def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]:
|
||||
return self.bridge_facade._snapshot_engine_dispatch_state()
|
||||
|
||||
def _register_engine_job(self, job: SchedulerPendingJob) -> None:
|
||||
self.bridge_facade._register_engine_job(job)
|
||||
|
||||
def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
|
||||
return self.bridge_facade._get_engine_job(request_id)
|
||||
|
||||
def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
|
||||
return self.bridge_facade._pop_engine_job(request_id)
|
||||
|
||||
def _snapshot_engine_job_registry(self) -> Dict[str, Any]:
|
||||
return self.bridge_facade._snapshot_engine_job_registry()
|
||||
|
||||
def _is_engine_drained(self) -> bool:
|
||||
return self.bridge_facade._is_engine_drained()
|
||||
|
||||
def _record_engine_job_done(self, request_id: str) -> None:
|
||||
self.bridge_facade._record_engine_job_done(request_id)
|
||||
|
||||
def _complete_engine_job(
|
||||
self,
|
||||
job: SchedulerPendingJob,
|
||||
item: T2SFinishedItem,
|
||||
*,
|
||||
sample_rate: int,
|
||||
audio_data: np.ndarray,
|
||||
) -> None:
|
||||
self.bridge_facade._complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data)
|
||||
|
||||
def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None:
|
||||
self.bridge_facade._fail_engine_jobs(request_ids, error)
|
||||
|
||||
def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None:
|
||||
self.bridge_facade._add_engine_prefill_time(jobs, elapsed_s)
|
||||
|
||||
def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None:
|
||||
self.bridge_facade._add_engine_merge_time(request_ids, elapsed_s)
|
||||
|
||||
def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None:
|
||||
self.bridge_facade._add_engine_decode_time(request_ids, elapsed_s)
|
||||
|
||||
def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None:
|
||||
self.bridge_facade._enqueue_engine_finished_items(items)
|
||||
|
||||
def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]:
|
||||
return self.bridge_facade._snapshot_engine_decode_pending_queue_state()
|
||||
|
||||
@staticmethod
|
||||
def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]:
|
||||
return EngineBridgeFacade._summarize_active_batch(active_batch)
|
||||
|
||||
def _refresh_engine_decode_runtime_state(self, last_event: str) -> None:
|
||||
self.bridge_facade._refresh_engine_decode_runtime_state(last_event)
|
||||
|
||||
def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None:
|
||||
self.bridge_facade._update_engine_decode_runtime_state(snapshot)
|
||||
|
||||
def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]:
|
||||
return self.bridge_facade._snapshot_engine_decode_runtime_state()
|
||||
|
||||
def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]:
|
||||
return self.bridge_facade._snapshot_engine_arbiter_state()
|
||||
|
||||
def _notify_engine_arbiter(self) -> None:
|
||||
self.bridge_facade._notify_engine_arbiter()
|
||||
|
||||
def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None:
|
||||
self.bridge_facade._enqueue_engine_decode_pending_job(job)
|
||||
|
||||
def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
|
||||
return self.bridge_facade._take_engine_decode_pending_jobs_nonblocking(wait_for_batch)
|
||||
|
||||
def _peek_queue_age_ms(self, queue_name: str) -> float:
|
||||
return self.bridge_facade._peek_queue_age_ms(queue_name)
|
||||
|
||||
def _engine_has_pending_work(self) -> bool:
|
||||
return self.bridge_facade._engine_has_pending_work()
|
||||
|
||||
async def _prepare_state_via_engine_gpu_queue(
|
||||
self,
|
||||
*,
|
||||
spec: SchedulerRequestSpec,
|
||||
prepare_submit_at: float,
|
||||
engine_request_id: str | None,
|
||||
) -> tuple[T2SRequestState, float, float]:
|
||||
return await self.bridge_facade._prepare_state_via_engine_gpu_queue(
|
||||
spec=spec,
|
||||
prepare_submit_at=prepare_submit_at,
|
||||
engine_request_id=engine_request_id,
|
||||
)
|
||||
|
||||
def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None:
|
||||
self.bridge_facade._enqueue_worker_finished_for_finalize(tasks)
|
||||
|
||||
def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]:
|
||||
return self.bridge_facade._take_engine_finalize_batch_nonblocking()
|
||||
|
||||
async def _enqueue_prepared_state_for_dispatch(
|
||||
self,
|
||||
*,
|
||||
state: T2SRequestState,
|
||||
speed_factor: float,
|
||||
sample_steps: int,
|
||||
media_type: str,
|
||||
super_sampling: bool,
|
||||
prepare_wall_ms: float,
|
||||
prepare_profile_total_ms: float,
|
||||
done_loop: asyncio.AbstractEventLoop | None,
|
||||
done_future: asyncio.Future | None,
|
||||
engine_request_id: str | None,
|
||||
timeout_sec: float | None,
|
||||
) -> EngineDispatchTask:
|
||||
return await self.bridge_facade._enqueue_prepared_state_for_dispatch(
|
||||
state=state,
|
||||
speed_factor=speed_factor,
|
||||
sample_steps=sample_steps,
|
||||
media_type=media_type,
|
||||
super_sampling=super_sampling,
|
||||
prepare_wall_ms=prepare_wall_ms,
|
||||
prepare_profile_total_ms=prepare_profile_total_ms,
|
||||
done_loop=done_loop,
|
||||
done_future=done_future,
|
||||
engine_request_id=engine_request_id,
|
||||
timeout_sec=timeout_sec,
|
||||
)
|
||||
|
||||
def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None:
|
||||
self.bridge_facade._mark_arbiter_tick(stage=stage, reason=reason, policy_allowed=policy_allowed)
|
||||
|
||||
def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]:
|
||||
return self.bridge_facade._select_engine_stage()
|
||||
|
||||
def _run_engine_prepare_once(self) -> bool:
|
||||
return self.bridge_facade._run_engine_prepare_once()
|
||||
|
||||
def _run_engine_finalize_once(self) -> bool:
|
||||
return self.bridge_facade._run_engine_finalize_once()
|
||||
|
||||
def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool:
|
||||
return self.bridge_facade._run_engine_dispatch_once(policy_snapshot, worker_state)
|
||||
|
||||
def _run_engine_decode_runtime_once(self) -> bool:
|
||||
return self.bridge_facade._run_engine_decode_runtime_once()
|
||||
|
||||
def _run_engine_arbiter_loop(self) -> None:
|
||||
self.bridge_facade._run_engine_arbiter_loop()
|
||||
|
||||
def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
|
||||
self.bridge_facade._complete_request_state(request_id, extra)
|
||||
|
||||
def _fail_request_state(self, request_id: str, error: str) -> None:
|
||||
self.bridge_facade._fail_request_state(request_id, error)
|
||||
|
||||
def _snapshot_request_registry(self) -> Dict[str, Any]:
|
||||
return self.bridge_facade._snapshot_request_registry()
|
||||
231
GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py
Normal file
231
GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_registry.py
Normal file
@ -0,0 +1,231 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineRequestState, EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob
|
||||
|
||||
|
||||
class EngineRegistryBridgeFacade:
|
||||
def __init__(self, owner: Any) -> None:
|
||||
self.owner = owner
|
||||
|
||||
@property
|
||||
def request_registry(self):
|
||||
return self.owner.request_registry
|
||||
|
||||
@property
|
||||
def engine_prepare_queue_owner(self):
|
||||
return self.owner.engine_prepare_queue_owner
|
||||
|
||||
@property
|
||||
def engine_prepare_text_queue_owner(self):
|
||||
return self.owner.engine_prepare_text_queue_owner
|
||||
|
||||
@property
|
||||
def engine_prepare_ref_spec_queue_owner(self):
|
||||
return self.owner.engine_prepare_ref_spec_queue_owner
|
||||
|
||||
@property
|
||||
def engine_finalize_queue_owner(self):
|
||||
return self.owner.engine_finalize_queue_owner
|
||||
|
||||
@property
|
||||
def engine_dispatch_queue_owner(self):
|
||||
return self.owner.engine_dispatch_queue_owner
|
||||
|
||||
@property
|
||||
def engine_decode_runtime_owner(self):
|
||||
return self.owner.engine_decode_runtime_owner
|
||||
|
||||
@property
|
||||
def engine_job_registry(self):
|
||||
return self.owner.engine_job_registry
|
||||
|
||||
@property
|
||||
def scheduler_worker(self):
|
||||
return self.owner.scheduler_worker
|
||||
|
||||
def _register_request_state(
|
||||
self,
|
||||
request_id: str,
|
||||
api_mode: str,
|
||||
backend: str,
|
||||
media_type: str,
|
||||
response_streaming: bool,
|
||||
deadline_ts: float | None = None,
|
||||
meta: Optional[Dict[str, Any]] = None,
|
||||
) -> EngineRequestState:
|
||||
return self.request_registry.register(
|
||||
request_id=request_id,
|
||||
api_mode=api_mode,
|
||||
backend=backend,
|
||||
media_type=media_type,
|
||||
response_streaming=response_streaming,
|
||||
deadline_ts=deadline_ts,
|
||||
meta=meta,
|
||||
)
|
||||
|
||||
def _update_request_state(
|
||||
self,
|
||||
request_id: str,
|
||||
status: str,
|
||||
extra: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
self.request_registry.update(request_id, status, extra)
|
||||
|
||||
def _merge_request_state_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
|
||||
self.request_registry.merge_profile(request_id, extra)
|
||||
|
||||
def _complete_request_state(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
|
||||
self.request_registry.complete(request_id, extra)
|
||||
|
||||
def _fail_request_state(self, request_id: str, error: str) -> None:
|
||||
self.request_registry.fail(request_id, error)
|
||||
|
||||
def _snapshot_request_registry(self) -> Dict[str, Any]:
|
||||
return self.request_registry.snapshot()
|
||||
|
||||
def _snapshot_engine_prepare_state(self) -> Dict[str, Any]:
|
||||
audio_snapshot = self.engine_prepare_queue_owner.snapshot(max_request_ids=16)
|
||||
text_snapshot = self.engine_prepare_text_queue_owner.snapshot(max_request_ids=16)
|
||||
ref_spec_snapshot = self.engine_prepare_ref_spec_queue_owner.snapshot(max_request_ids=16)
|
||||
return {
|
||||
"waiting_count": int(audio_snapshot.get("waiting_count", 0))
|
||||
+ int(text_snapshot.get("waiting_count", 0))
|
||||
+ int(ref_spec_snapshot.get("waiting_count", 0)),
|
||||
"audio_waiting_count": int(audio_snapshot.get("waiting_count", 0)),
|
||||
"text_waiting_count": int(text_snapshot.get("waiting_count", 0)),
|
||||
"ref_spec_waiting_count": int(ref_spec_snapshot.get("waiting_count", 0)),
|
||||
"audio_waiting_request_ids": list(audio_snapshot.get("waiting_request_ids", [])),
|
||||
"text_waiting_request_ids": list(text_snapshot.get("waiting_request_ids", [])),
|
||||
"ref_spec_waiting_request_ids": list(ref_spec_snapshot.get("waiting_request_ids", [])),
|
||||
"peak_waiting": int(
|
||||
max(
|
||||
int(audio_snapshot.get("peak_waiting", 0)),
|
||||
int(text_snapshot.get("peak_waiting", 0)),
|
||||
int(ref_spec_snapshot.get("peak_waiting", 0)),
|
||||
)
|
||||
),
|
||||
"total_submitted": int(audio_snapshot.get("total_submitted", 0)),
|
||||
"total_completed": int(audio_snapshot.get("total_completed", 0)),
|
||||
"text_total_submitted": int(text_snapshot.get("total_submitted", 0)),
|
||||
"text_total_completed": int(text_snapshot.get("total_completed", 0)),
|
||||
"ref_spec_total_submitted": int(ref_spec_snapshot.get("total_submitted", 0)),
|
||||
"ref_spec_total_completed": int(ref_spec_snapshot.get("total_completed", 0)),
|
||||
}
|
||||
|
||||
def _snapshot_engine_finalize_state(self) -> Dict[str, Any]:
|
||||
return self.engine_finalize_queue_owner.snapshot(max_request_ids=16)
|
||||
|
||||
def _snapshot_engine_dispatch_state(self) -> Dict[str, Any]:
|
||||
return self.engine_dispatch_queue_owner.snapshot(
|
||||
max_request_ids=16,
|
||||
extra={"last_policy_snapshot": dict(self.owner.engine_dispatch_last_snapshot or {})},
|
||||
)
|
||||
|
||||
def _register_engine_job(self, job: SchedulerPendingJob) -> None:
|
||||
self.engine_job_registry.register(job, keep_job=True)
|
||||
|
||||
def _get_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
|
||||
return self.engine_job_registry.get(request_id)
|
||||
|
||||
def _pop_engine_job(self, request_id: str) -> SchedulerPendingJob | None:
|
||||
return self.engine_job_registry.pop(request_id)
|
||||
|
||||
def _snapshot_engine_job_registry(self) -> Dict[str, Any]:
|
||||
return self.engine_job_registry.snapshot(max_request_ids=32)
|
||||
|
||||
def _is_engine_drained(self) -> bool:
|
||||
prepare_empty = self.engine_prepare_queue_owner.is_drained()
|
||||
prepare_text_empty = self.engine_prepare_text_queue_owner.is_drained()
|
||||
prepare_ref_spec_empty = self.engine_prepare_ref_spec_queue_owner.is_drained()
|
||||
dispatch_empty = self.engine_dispatch_queue_owner.is_drained()
|
||||
finalize_empty = self.engine_finalize_queue_owner.is_drained()
|
||||
decode_pending_empty = not self.engine_decode_runtime_owner.has_pending_jobs()
|
||||
job_empty = self.engine_job_registry.is_empty()
|
||||
worker_state = self.scheduler_worker.snapshot()
|
||||
return bool(
|
||||
prepare_empty
|
||||
and prepare_text_empty
|
||||
and prepare_ref_spec_empty
|
||||
and dispatch_empty
|
||||
and finalize_empty
|
||||
and decode_pending_empty
|
||||
and job_empty
|
||||
and self.engine_decode_runtime_owner.get_active_batch() is None
|
||||
and int(worker_state.get("prepare_inflight", 0)) <= 0
|
||||
and int(worker_state.get("finalize_inflight", 0)) <= 0
|
||||
and int(worker_state.get("finalize_pending", 0)) <= 0
|
||||
)
|
||||
|
||||
def _record_engine_job_done(self, request_id: str) -> None:
|
||||
self.engine_job_registry.mark_finished_and_remove(request_id)
|
||||
self.scheduler_worker.record_external_job_done(request_id)
|
||||
|
||||
def _complete_engine_job(
|
||||
self,
|
||||
job: SchedulerPendingJob,
|
||||
item: T2SFinishedItem,
|
||||
*,
|
||||
sample_rate: int,
|
||||
audio_data: np.ndarray,
|
||||
) -> None:
|
||||
completion_bridge = self.scheduler_worker.completion_bridge
|
||||
completion_bridge.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data)
|
||||
completion_bridge.complete_job(
|
||||
job,
|
||||
runtime_request_id=job.engine_request_id,
|
||||
runtime_extra=completion_bridge.build_runtime_complete_payload(job, item, sample_rate=sample_rate),
|
||||
on_job_finished=lambda rid=item.request_id: self._record_engine_job_done(rid),
|
||||
)
|
||||
|
||||
def _fail_engine_jobs(self, request_ids: List[str], error: str) -> None:
|
||||
if not request_ids:
|
||||
return
|
||||
completion_bridge = self.scheduler_worker.completion_bridge
|
||||
for request_id in request_ids:
|
||||
job = self._get_engine_job(request_id)
|
||||
if job is None:
|
||||
continue
|
||||
completion_bridge.fail_job(
|
||||
job,
|
||||
error=error,
|
||||
on_job_finished=lambda rid=request_id: self._record_engine_job_done(rid),
|
||||
)
|
||||
|
||||
def _add_engine_prefill_time(self, jobs: List[SchedulerPendingJob], elapsed_s: float) -> None:
|
||||
delta_ms = float(elapsed_s) * 1000.0
|
||||
for job in jobs:
|
||||
job.prefill_ms += delta_ms
|
||||
|
||||
def _add_engine_merge_time(self, request_ids: List[str], elapsed_s: float) -> None:
|
||||
delta_ms = float(elapsed_s) * 1000.0
|
||||
for request_id in request_ids:
|
||||
job = self._get_engine_job(request_id)
|
||||
if job is not None:
|
||||
job.merge_ms += delta_ms
|
||||
|
||||
def _add_engine_decode_time(self, request_ids: List[str], elapsed_s: float) -> None:
|
||||
delta_ms = float(elapsed_s) * 1000.0
|
||||
activate_request_ids: List[str] = []
|
||||
for request_id in request_ids:
|
||||
job = self._get_engine_job(request_id)
|
||||
if job is None:
|
||||
continue
|
||||
if job.decode_steps == 0:
|
||||
activate_request_ids.append(job.engine_request_id)
|
||||
job.decode_ms += delta_ms
|
||||
job.decode_steps += 1
|
||||
for engine_request_id in activate_request_ids:
|
||||
self._update_request_state(engine_request_id, EngineStatus.ACTIVE_DECODE, None)
|
||||
|
||||
def _enqueue_engine_finished_items(self, items: List[T2SFinishedItem]) -> None:
|
||||
if not items:
|
||||
return
|
||||
enqueued_at = time.perf_counter()
|
||||
tasks = [SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at) for item in items]
|
||||
self.owner.engine_stage_coordinator.enqueue_worker_finished_for_finalize(tasks)
|
||||
33
GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_runtime.py
Normal file
33
GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_runtime.py
Normal file
@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SActiveBatch
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDecodeRuntimeOwner
|
||||
|
||||
|
||||
class EngineRuntimeBridgeFacade:
|
||||
def __init__(self, owner: Any) -> None:
|
||||
self.owner = owner
|
||||
|
||||
@property
|
||||
def engine_policy_arbiter(self):
|
||||
return self.owner.engine_policy_arbiter
|
||||
|
||||
@staticmethod
|
||||
def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]:
|
||||
return EngineDecodeRuntimeOwner.summarize_active_batch(active_batch)
|
||||
|
||||
def _snapshot_engine_arbiter_state(self) -> Dict[str, Any]:
|
||||
return self.engine_policy_arbiter.snapshot_state()
|
||||
|
||||
def _notify_engine_arbiter(self) -> None:
|
||||
self.engine_policy_arbiter.notify()
|
||||
|
||||
def _mark_arbiter_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None:
|
||||
self.engine_policy_arbiter.mark_tick(stage=stage, reason=reason, policy_allowed=policy_allowed)
|
||||
|
||||
def _select_engine_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]:
|
||||
stage, reason, policy_snapshot, worker_state = self.engine_policy_arbiter.select_stage()
|
||||
self.owner.engine_dispatch_last_snapshot = dict(policy_snapshot)
|
||||
return stage, reason, policy_snapshot, worker_state
|
||||
116
GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_stage.py
Normal file
116
GPT_SoVITS/TTS_infer_pack/unified_engine_bridge_stage.py
Normal file
@ -0,0 +1,116 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SRequestState
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDispatchTask, SchedulerFinalizeTask, SchedulerPendingJob
|
||||
|
||||
|
||||
class EngineStageBridgeFacade:
|
||||
def __init__(self, owner: Any) -> None:
|
||||
self.owner = owner
|
||||
|
||||
@property
|
||||
def engine_decode_runtime_owner(self):
|
||||
return self.owner.engine_decode_runtime_owner
|
||||
|
||||
@property
|
||||
def scheduler_worker(self):
|
||||
return self.owner.scheduler_worker
|
||||
|
||||
@property
|
||||
def engine_stage_coordinator(self):
|
||||
return self.owner.engine_stage_coordinator
|
||||
|
||||
def _snapshot_engine_decode_pending_queue_state(self) -> Dict[str, Any]:
|
||||
return self.engine_decode_runtime_owner.snapshot_pending_queue_state()
|
||||
|
||||
def _refresh_engine_decode_runtime_state(self, last_event: str) -> None:
|
||||
self.engine_decode_runtime_owner.refresh_state(last_event)
|
||||
|
||||
def _update_engine_decode_runtime_state(self, snapshot: Dict[str, Any]) -> None:
|
||||
if not snapshot:
|
||||
return
|
||||
if self.scheduler_worker.is_engine_decode_control_enabled():
|
||||
return
|
||||
self.engine_decode_runtime_owner.update_from_worker_snapshot(snapshot)
|
||||
|
||||
def _snapshot_engine_decode_runtime_state(self) -> Dict[str, Any]:
|
||||
return self.engine_decode_runtime_owner.snapshot_state()
|
||||
|
||||
def _enqueue_engine_decode_pending_job(self, job: SchedulerPendingJob) -> None:
|
||||
self.engine_decode_runtime_owner.enqueue_pending_job(job)
|
||||
self.owner.engine_policy_arbiter.notify()
|
||||
|
||||
def _take_engine_decode_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
|
||||
return self.engine_decode_runtime_owner.take_pending_jobs_nonblocking(wait_for_batch)
|
||||
|
||||
def _peek_queue_age_ms(self, queue_name: str) -> float:
|
||||
return self.engine_stage_coordinator.peek_queue_age_ms(queue_name)
|
||||
|
||||
def _engine_has_pending_work(self) -> bool:
|
||||
return self.engine_stage_coordinator.has_pending_work()
|
||||
|
||||
async def _prepare_state_via_engine_gpu_queue(
|
||||
self,
|
||||
*,
|
||||
spec: SchedulerRequestSpec,
|
||||
prepare_submit_at: float,
|
||||
engine_request_id: str | None,
|
||||
) -> tuple[T2SRequestState, float, float]:
|
||||
return await self.engine_stage_coordinator.prepare_state_via_engine_gpu_queue(
|
||||
spec=spec,
|
||||
prepare_submit_at=prepare_submit_at,
|
||||
engine_request_id=engine_request_id,
|
||||
)
|
||||
|
||||
def _enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None:
|
||||
self.engine_stage_coordinator.enqueue_worker_finished_for_finalize(tasks)
|
||||
|
||||
def _take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]:
|
||||
return self.engine_stage_coordinator.take_engine_finalize_batch_nonblocking()
|
||||
|
||||
async def _enqueue_prepared_state_for_dispatch(
|
||||
self,
|
||||
*,
|
||||
state: T2SRequestState,
|
||||
speed_factor: float,
|
||||
sample_steps: int,
|
||||
media_type: str,
|
||||
super_sampling: bool,
|
||||
prepare_wall_ms: float,
|
||||
prepare_profile_total_ms: float,
|
||||
done_loop: asyncio.AbstractEventLoop | None,
|
||||
done_future: asyncio.Future | None,
|
||||
engine_request_id: str | None,
|
||||
timeout_sec: float | None,
|
||||
) -> EngineDispatchTask:
|
||||
return await self.engine_stage_coordinator.enqueue_prepared_state_for_dispatch(
|
||||
state=state,
|
||||
speed_factor=speed_factor,
|
||||
sample_steps=sample_steps,
|
||||
media_type=media_type,
|
||||
super_sampling=super_sampling,
|
||||
prepare_wall_ms=prepare_wall_ms,
|
||||
prepare_profile_total_ms=prepare_profile_total_ms,
|
||||
done_loop=done_loop,
|
||||
done_future=done_future,
|
||||
engine_request_id=engine_request_id,
|
||||
timeout_sec=timeout_sec,
|
||||
)
|
||||
|
||||
def _run_engine_prepare_once(self) -> bool:
|
||||
return self.engine_stage_coordinator.run_engine_prepare_once()
|
||||
|
||||
def _run_engine_finalize_once(self) -> bool:
|
||||
return self.engine_stage_coordinator.run_engine_finalize_once()
|
||||
|
||||
def _run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool:
|
||||
return self.engine_stage_coordinator.run_engine_dispatch_once(policy_snapshot, worker_state)
|
||||
|
||||
def _run_engine_decode_runtime_once(self) -> bool:
|
||||
return self.engine_stage_coordinator.run_engine_decode_runtime_once()
|
||||
|
||||
def _run_engine_arbiter_loop(self) -> None:
|
||||
return self.engine_stage_coordinator.run_engine_arbiter_loop()
|
||||
183
GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py
Normal file
183
GPT_SoVITS/TTS_infer_pack/unified_engine_builder.py
Normal file
@ -0,0 +1,183 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_api import EngineApiFacade
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge import EngineBridgeFacade
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import (
|
||||
EngineArbiterConfig,
|
||||
EngineDecodeRuntimeOwner,
|
||||
EnginePolicyArbiterController,
|
||||
EnginePolicyConfig,
|
||||
EngineRequestRegistry,
|
||||
EngineTaskQueueOwner,
|
||||
ModelRegistry,
|
||||
ReferenceRegistry,
|
||||
RuntimeStateCallbacks,
|
||||
SchedulerJobRegistry,
|
||||
)
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime import EngineRuntimeFacade
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage import EngineStageCoordinator
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker
|
||||
|
||||
|
||||
class EngineCompositionBuilder:
|
||||
def __init__(self, owner: Any) -> None:
|
||||
self.owner = owner
|
||||
|
||||
def build(self, *, max_steps: int, micro_batch_wait_ms: int) -> None:
|
||||
self._init_registries_and_locks()
|
||||
self._init_worker(max_steps=max_steps, micro_batch_wait_ms=micro_batch_wait_ms)
|
||||
self._init_policy_configs(micro_batch_wait_ms=micro_batch_wait_ms)
|
||||
self._init_runtime_owners()
|
||||
self._init_stage_coordinator()
|
||||
self._init_arbiter()
|
||||
self._init_facades()
|
||||
self._start_arbiter_thread()
|
||||
|
||||
def _init_registries_and_locks(self) -> None:
|
||||
owner = self.owner
|
||||
owner.reference_registry = ReferenceRegistry()
|
||||
owner.model_registry = ModelRegistry(
|
||||
t2s_weights_path=str(owner.tts.configs.t2s_weights_path),
|
||||
vits_weights_path=str(owner.tts.configs.vits_weights_path),
|
||||
)
|
||||
owner.request_registry = EngineRequestRegistry(
|
||||
recent_limit=max(1, int(os.environ.get("GPTSOVITS_ENGINE_RECENT_REQUEST_LIMIT", "64")))
|
||||
)
|
||||
owner.engine_job_registry = SchedulerJobRegistry(threading.Lock())
|
||||
owner.direct_tts_lock = threading.RLock()
|
||||
owner.management_lock = threading.RLock()
|
||||
owner.engine_dispatch_last_snapshot = {}
|
||||
|
||||
def _init_worker(self, *, max_steps: int, micro_batch_wait_ms: int) -> None:
|
||||
owner = self.owner
|
||||
owner.scheduler_worker = UnifiedSchedulerWorker(
|
||||
owner.tts,
|
||||
max_steps=max_steps,
|
||||
micro_batch_wait_ms=micro_batch_wait_ms,
|
||||
runtime_callbacks=RuntimeStateCallbacks(
|
||||
update=owner._update_request_state,
|
||||
complete=owner._complete_request_state,
|
||||
fail=owner._fail_request_state,
|
||||
decode_runtime_update=owner._update_engine_decode_runtime_state,
|
||||
),
|
||||
external_finalize_submit=owner._enqueue_worker_finished_for_finalize,
|
||||
)
|
||||
|
||||
def _init_policy_configs(self, *, micro_batch_wait_ms: int) -> None:
|
||||
owner = self.owner
|
||||
worker_capacity_limits = owner.scheduler_worker.get_capacity_limits()
|
||||
prepare_max_inflight = int(owner.scheduler_worker.get_prepare_max_inflight())
|
||||
owner.engine_policy_config = EnginePolicyConfig(
|
||||
enabled=owner._env_flag("GPTSOVITS_ENGINE_POLICY_ENABLE", True),
|
||||
poll_wait_ms=max(1.0, owner._env_float("GPTSOVITS_ENGINE_POLICY_POLL_WAIT_MS", float(micro_batch_wait_ms))),
|
||||
decode_backlog_soft_max=max(
|
||||
0,
|
||||
owner._env_int(
|
||||
"GPTSOVITS_ENGINE_POLICY_DECODE_BACKLOG_SOFT_MAX",
|
||||
int(worker_capacity_limits["decode_backlog_max"]),
|
||||
),
|
||||
),
|
||||
finalize_pending_soft_max=max(
|
||||
0,
|
||||
owner._env_int(
|
||||
"GPTSOVITS_ENGINE_POLICY_FINALIZE_PENDING_SOFT_MAX",
|
||||
int(worker_capacity_limits["finalize_pending_max"]),
|
||||
),
|
||||
),
|
||||
prepare_inflight_soft_max=max(
|
||||
0,
|
||||
owner._env_int("GPTSOVITS_ENGINE_POLICY_PREPARE_INFLIGHT_SOFT_MAX", prepare_max_inflight),
|
||||
),
|
||||
active_decode_soft_max=max(0, owner._env_int("GPTSOVITS_ENGINE_POLICY_ACTIVE_DECODE_SOFT_MAX", 0)),
|
||||
ready_for_prefill_soft_max=max(0, owner._env_int("GPTSOVITS_ENGINE_POLICY_READY_FOR_PREFILL_SOFT_MAX", 0)),
|
||||
active_request_soft_max=max(0, owner._env_int("GPTSOVITS_ENGINE_POLICY_ACTIVE_REQUEST_SOFT_MAX", 0)),
|
||||
)
|
||||
owner.engine_arbiter_config = EngineArbiterConfig(
|
||||
poll_wait_ms=max(1.0, owner._env_float("GPTSOVITS_ENGINE_ARBITER_POLL_WAIT_MS", float(micro_batch_wait_ms))),
|
||||
decode_burst=max(1, owner._env_int("GPTSOVITS_ENGINE_ARBITER_DECODE_BURST", 4)),
|
||||
prepare_aging_ms=max(0.0, owner._env_float("GPTSOVITS_ENGINE_ARBITER_PREPARE_AGING_MS", 10.0)),
|
||||
finalize_aging_ms=max(0.0, owner._env_float("GPTSOVITS_ENGINE_ARBITER_FINALIZE_AGING_MS", 10.0)),
|
||||
)
|
||||
|
||||
def _init_runtime_owners(self) -> None:
|
||||
owner = self.owner
|
||||
owner.engine_decode_runtime_owner = EngineDecodeRuntimeOwner(
|
||||
get_decode_runtime_counters=owner.scheduler_worker.get_decode_runtime_counters,
|
||||
get_micro_batch_wait_s=owner.scheduler_worker.get_micro_batch_wait_s,
|
||||
)
|
||||
owner.engine_prepare_queue_owner = EngineTaskQueueOwner(completion_key="total_completed")
|
||||
owner.engine_prepare_text_queue_owner = EngineTaskQueueOwner(completion_key="total_completed")
|
||||
owner.engine_prepare_ref_spec_queue_owner = EngineTaskQueueOwner(completion_key="total_completed")
|
||||
owner.engine_finalize_queue_owner = EngineTaskQueueOwner(completion_key="total_completed")
|
||||
owner.engine_dispatch_queue_owner = EngineTaskQueueOwner(completion_key="total_dispatched")
|
||||
|
||||
def _init_stage_coordinator(self) -> None:
|
||||
owner = self.owner
|
||||
owner.engine_stage_coordinator = EngineStageCoordinator(
|
||||
tts=owner.tts,
|
||||
scheduler_worker=owner.scheduler_worker,
|
||||
prepare_queue_owner=owner.engine_prepare_queue_owner,
|
||||
prepare_text_queue_owner=owner.engine_prepare_text_queue_owner,
|
||||
prepare_ref_spec_queue_owner=owner.engine_prepare_ref_spec_queue_owner,
|
||||
finalize_queue_owner=owner.engine_finalize_queue_owner,
|
||||
dispatch_queue_owner=owner.engine_dispatch_queue_owner,
|
||||
decode_runtime_owner=owner.engine_decode_runtime_owner,
|
||||
update_request_state=owner._update_request_state,
|
||||
merge_request_state_profile=owner._merge_request_state_profile,
|
||||
fail_request_state=owner._fail_request_state,
|
||||
get_engine_job=owner._get_engine_job,
|
||||
register_engine_job=owner._register_engine_job,
|
||||
fail_engine_jobs=owner._fail_engine_jobs,
|
||||
complete_engine_job=owner._complete_engine_job,
|
||||
add_engine_prefill_time=owner._add_engine_prefill_time,
|
||||
add_engine_merge_time=owner._add_engine_merge_time,
|
||||
add_engine_decode_time=owner._add_engine_decode_time,
|
||||
enqueue_engine_finished_items=owner._enqueue_engine_finished_items,
|
||||
snapshot_engine_dispatch_state=owner._snapshot_engine_dispatch_state,
|
||||
snapshot_engine_decode_runtime_state=owner._snapshot_engine_decode_runtime_state,
|
||||
)
|
||||
|
||||
def _init_arbiter(self) -> None:
|
||||
owner = self.owner
|
||||
owner.engine_policy_arbiter = EnginePolicyArbiterController(
|
||||
policy_config=owner.engine_policy_config,
|
||||
arbiter_config=owner.engine_arbiter_config,
|
||||
snapshot_request_registry=owner._snapshot_request_registry,
|
||||
get_worker_state=owner.scheduler_worker.snapshot,
|
||||
snapshot_prepare_state=owner._snapshot_engine_prepare_state,
|
||||
snapshot_finalize_state=owner._snapshot_engine_finalize_state,
|
||||
snapshot_dispatch_state=owner._snapshot_engine_dispatch_state,
|
||||
snapshot_decode_runtime_state=owner._snapshot_engine_decode_runtime_state,
|
||||
snapshot_job_registry=owner._snapshot_engine_job_registry,
|
||||
peek_queue_age_ms=owner.engine_stage_coordinator.peek_queue_age_ms,
|
||||
merge_request_state_profile=owner._merge_request_state_profile,
|
||||
)
|
||||
owner.engine_stage_coordinator.bind_arbiter(
|
||||
notify_arbiter=owner._notify_engine_arbiter,
|
||||
select_stage=owner._select_engine_stage,
|
||||
mark_arbiter_tick=lambda stage, reason, policy_allowed: owner._mark_arbiter_tick(
|
||||
stage=stage,
|
||||
reason=reason,
|
||||
policy_allowed=policy_allowed,
|
||||
),
|
||||
wait_arbiter=owner.engine_policy_arbiter.wait,
|
||||
)
|
||||
|
||||
def _init_facades(self) -> None:
|
||||
owner = self.owner
|
||||
owner.bridge_facade = EngineBridgeFacade(owner)
|
||||
owner.api_facade = EngineApiFacade(owner)
|
||||
owner.runtime_facade = EngineRuntimeFacade(owner)
|
||||
|
||||
def _start_arbiter_thread(self) -> None:
|
||||
owner = self.owner
|
||||
owner.engine_arbiter_thread = threading.Thread(
|
||||
target=owner._run_engine_arbiter_loop,
|
||||
name="unified-engine-arbiter",
|
||||
daemon=True,
|
||||
)
|
||||
owner.engine_arbiter_thread.start()
|
||||
121
GPT_SoVITS/TTS_infer_pack/unified_engine_component_models.py
Normal file
121
GPT_SoVITS/TTS_infer_pack/unified_engine_component_models.py
Normal file
@ -0,0 +1,121 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeControlCallbacks:
|
||||
restart: Callable[[], None] | None = None
|
||||
exit: Callable[[], None] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DirectTTSExecution:
|
||||
media_type: str
|
||||
streaming: bool
|
||||
audio_generator: Optional[Generator[bytes, None, None]] = None
|
||||
audio_bytes: Optional[bytes] = None
|
||||
request_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class NormalizedEngineRequest:
|
||||
request_id: str
|
||||
text: str
|
||||
text_lang: str
|
||||
ref_audio_path: str
|
||||
prompt_lang: str
|
||||
prompt_text: str = ""
|
||||
aux_ref_audio_paths: List[str] | None = None
|
||||
top_k: int = 15
|
||||
top_p: float = 1.0
|
||||
temperature: float = 1.0
|
||||
repetition_penalty: float = 1.35
|
||||
early_stop_num: int = -1
|
||||
ready_step: int = 0
|
||||
text_split_method: str = "cut5"
|
||||
batch_size: int = 1
|
||||
batch_threshold: float = 0.75
|
||||
split_bucket: bool = False
|
||||
speed_factor: float = 1.0
|
||||
fragment_interval: float = 0.3
|
||||
seed: int = -1
|
||||
media_type: str = "wav"
|
||||
streaming_mode: bool | int = False
|
||||
return_fragment: bool = False
|
||||
fixed_length_chunk: bool = False
|
||||
response_streaming: bool = False
|
||||
parallel_infer: bool = False
|
||||
sample_steps: int = 32
|
||||
super_sampling: bool = False
|
||||
overlap_length: int = 2
|
||||
min_chunk_length: int = 16
|
||||
timeout_sec: float | None = None
|
||||
|
||||
def to_payload(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"request_id": self.request_id,
|
||||
"text": self.text,
|
||||
"text_lang": self.text_lang,
|
||||
"ref_audio_path": self.ref_audio_path,
|
||||
"aux_ref_audio_paths": list(self.aux_ref_audio_paths) if self.aux_ref_audio_paths else None,
|
||||
"prompt_text": self.prompt_text,
|
||||
"prompt_lang": self.prompt_lang,
|
||||
"top_k": self.top_k,
|
||||
"top_p": self.top_p,
|
||||
"temperature": self.temperature,
|
||||
"text_split_method": self.text_split_method,
|
||||
"batch_size": self.batch_size,
|
||||
"batch_threshold": self.batch_threshold,
|
||||
"speed_factor": self.speed_factor,
|
||||
"split_bucket": self.split_bucket,
|
||||
"fragment_interval": self.fragment_interval,
|
||||
"seed": self.seed,
|
||||
"media_type": self.media_type,
|
||||
"streaming_mode": self.streaming_mode,
|
||||
"return_fragment": self.return_fragment,
|
||||
"fixed_length_chunk": self.fixed_length_chunk,
|
||||
"response_streaming": self.response_streaming,
|
||||
"parallel_infer": self.parallel_infer,
|
||||
"repetition_penalty": self.repetition_penalty,
|
||||
"sample_steps": self.sample_steps,
|
||||
"super_sampling": self.super_sampling,
|
||||
"overlap_length": self.overlap_length,
|
||||
"min_chunk_length": self.min_chunk_length,
|
||||
"early_stop_num": self.early_stop_num,
|
||||
"ready_step": self.ready_step,
|
||||
"timeout_sec": self.timeout_sec,
|
||||
}
|
||||
|
||||
def to_scheduler_spec(self) -> SchedulerRequestSpec:
|
||||
return SchedulerRequestSpec(
|
||||
request_id=self.request_id,
|
||||
ref_audio_path=Path(self.ref_audio_path),
|
||||
prompt_text=self.prompt_text,
|
||||
prompt_lang=self.prompt_lang,
|
||||
text=self.text,
|
||||
text_lang=self.text_lang,
|
||||
top_k=self.top_k,
|
||||
top_p=self.top_p,
|
||||
temperature=self.temperature,
|
||||
repetition_penalty=self.repetition_penalty,
|
||||
early_stop_num=self.early_stop_num,
|
||||
aux_ref_audio_paths=list(self.aux_ref_audio_paths or []),
|
||||
ready_step=self.ready_step,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SchedulerDebugExecution:
|
||||
payload: Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SchedulerSubmitExecution:
|
||||
audio_bytes: bytes
|
||||
media_type: str
|
||||
headers: Dict[str, str]
|
||||
363
GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py
Normal file
363
GPT_SoVITS/TTS_infer_pack/unified_engine_component_policy.py
Normal file
@ -0,0 +1,363 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_component_registry import EngineStatus
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnginePolicyConfig:
|
||||
enabled: bool = True
|
||||
poll_wait_ms: float = 5.0
|
||||
decode_backlog_soft_max: int = 0
|
||||
finalize_pending_soft_max: int = 0
|
||||
prepare_inflight_soft_max: int = 0
|
||||
active_decode_soft_max: int = 0
|
||||
ready_for_prefill_soft_max: int = 0
|
||||
active_request_soft_max: int = 0
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"enabled": bool(self.enabled),
|
||||
"poll_wait_ms": float(self.poll_wait_ms),
|
||||
"decode_backlog_soft_max": int(self.decode_backlog_soft_max),
|
||||
"finalize_pending_soft_max": int(self.finalize_pending_soft_max),
|
||||
"prepare_inflight_soft_max": int(self.prepare_inflight_soft_max),
|
||||
"active_decode_soft_max": int(self.active_decode_soft_max),
|
||||
"ready_for_prefill_soft_max": int(self.ready_for_prefill_soft_max),
|
||||
"active_request_soft_max": int(self.active_request_soft_max),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class EngineArbiterConfig:
|
||||
poll_wait_ms: float = 5.0
|
||||
decode_burst: int = 4
|
||||
prepare_aging_ms: float = 10.0
|
||||
finalize_aging_ms: float = 10.0
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"poll_wait_ms": float(self.poll_wait_ms),
|
||||
"decode_burst": int(self.decode_burst),
|
||||
"prepare_aging_ms": float(self.prepare_aging_ms),
|
||||
"finalize_aging_ms": float(self.finalize_aging_ms),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class EngineArbiterState:
|
||||
total_ticks: int = 0
|
||||
total_idle_ticks: int = 0
|
||||
total_prepare_dispatches: int = 0
|
||||
total_decode_dispatches: int = 0
|
||||
total_decode_runtime_ticks: int = 0
|
||||
total_finalize_dispatches: int = 0
|
||||
decode_budget_remaining: int = 0
|
||||
last_stage: str = "idle"
|
||||
last_reason: str = "init"
|
||||
last_observed_at: float = 0.0
|
||||
last_policy_allowed: bool = True
|
||||
|
||||
|
||||
class EnginePolicyArbiterController:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
policy_config: EnginePolicyConfig,
|
||||
arbiter_config: EngineArbiterConfig,
|
||||
snapshot_request_registry: Callable[[], Dict[str, Any]],
|
||||
get_worker_state: Callable[[], Dict[str, Any]],
|
||||
snapshot_prepare_state: Callable[[], Dict[str, Any]],
|
||||
snapshot_finalize_state: Callable[[], Dict[str, Any]],
|
||||
snapshot_dispatch_state: Callable[[], Dict[str, Any]],
|
||||
snapshot_decode_runtime_state: Callable[[], Dict[str, Any]],
|
||||
snapshot_job_registry: Callable[[], Dict[str, Any]],
|
||||
peek_queue_age_ms: Callable[[str], float],
|
||||
merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None],
|
||||
) -> None:
|
||||
self.policy_config = policy_config
|
||||
self.policy_poll_s = max(0.001, float(self.policy_config.poll_wait_ms) / 1000.0)
|
||||
self.arbiter_config = arbiter_config
|
||||
self.arbiter_poll_s = max(0.001, float(self.arbiter_config.poll_wait_ms) / 1000.0)
|
||||
self.condition = threading.Condition()
|
||||
self.state = EngineArbiterState(
|
||||
decode_budget_remaining=int(self.arbiter_config.decode_burst),
|
||||
last_observed_at=time.perf_counter(),
|
||||
)
|
||||
self.snapshot_request_registry = snapshot_request_registry
|
||||
self.get_worker_state = get_worker_state
|
||||
self.snapshot_prepare_state = snapshot_prepare_state
|
||||
self.snapshot_finalize_state = snapshot_finalize_state
|
||||
self.snapshot_dispatch_state = snapshot_dispatch_state
|
||||
self.snapshot_decode_runtime_state = snapshot_decode_runtime_state
|
||||
self.snapshot_job_registry = snapshot_job_registry
|
||||
self.peek_queue_age_ms = peek_queue_age_ms
|
||||
self.merge_request_state_profile = merge_request_state_profile
|
||||
|
||||
def snapshot_state(self) -> Dict[str, Any]:
|
||||
with self.condition:
|
||||
return {
|
||||
"config": self.arbiter_config.to_dict(),
|
||||
"total_ticks": int(self.state.total_ticks),
|
||||
"total_idle_ticks": int(self.state.total_idle_ticks),
|
||||
"total_prepare_dispatches": int(self.state.total_prepare_dispatches),
|
||||
"total_decode_dispatches": int(self.state.total_decode_dispatches),
|
||||
"total_decode_runtime_ticks": int(self.state.total_decode_runtime_ticks),
|
||||
"total_finalize_dispatches": int(self.state.total_finalize_dispatches),
|
||||
"decode_budget_remaining": int(self.state.decode_budget_remaining),
|
||||
"last_stage": str(self.state.last_stage),
|
||||
"last_reason": str(self.state.last_reason),
|
||||
"last_policy_allowed": bool(self.state.last_policy_allowed),
|
||||
"last_observed_at": float(self.state.last_observed_at),
|
||||
}
|
||||
|
||||
def notify(self) -> None:
|
||||
with self.condition:
|
||||
self.condition.notify_all()
|
||||
|
||||
def wait(self) -> None:
|
||||
with self.condition:
|
||||
self.condition.wait(timeout=self.arbiter_poll_s)
|
||||
|
||||
def mark_tick(self, *, stage: str, reason: str, policy_allowed: bool) -> None:
|
||||
with self.condition:
|
||||
self.state.total_ticks += 1
|
||||
if stage == "idle":
|
||||
self.state.total_idle_ticks += 1
|
||||
elif stage in {"prepare", "prepare_audio", "prepare_text", "prepare_ref_spec"}:
|
||||
self.state.total_prepare_dispatches += 1
|
||||
self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst)
|
||||
elif stage == "finalize":
|
||||
self.state.total_finalize_dispatches += 1
|
||||
self.state.decode_budget_remaining = int(self.arbiter_config.decode_burst)
|
||||
elif stage == "decode_dispatch":
|
||||
self.state.total_decode_dispatches += 1
|
||||
elif stage == "decode_runtime":
|
||||
self.state.total_decode_runtime_ticks += 1
|
||||
self.state.decode_budget_remaining = max(0, int(self.state.decode_budget_remaining) - 1)
|
||||
self.state.last_stage = str(stage)
|
||||
self.state.last_reason = str(reason)
|
||||
self.state.last_policy_allowed = bool(policy_allowed)
|
||||
self.state.last_observed_at = time.perf_counter()
|
||||
|
||||
def build_stage_counters(
|
||||
self,
|
||||
request_registry: Dict[str, Any],
|
||||
worker_state: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
prepare_dispatcher_state = self.snapshot_prepare_state()
|
||||
finalize_dispatcher_state = self.snapshot_finalize_state()
|
||||
dispatcher_state = self.snapshot_dispatch_state()
|
||||
active_requests = list(request_registry.get("active_requests", []))
|
||||
status_counts: Dict[str, int] = {}
|
||||
for item in active_requests:
|
||||
status = str(item.get("status", "UNKNOWN"))
|
||||
status_counts[status] = status_counts.get(status, 0) + 1
|
||||
|
||||
worker_pending_jobs = int(worker_state.get("pending_jobs", 0))
|
||||
worker_decode_active_size = int(worker_state.get("running_requests", 0))
|
||||
worker_prepare_inflight = int(worker_state.get("prepare_inflight", 0))
|
||||
worker_finalize_pending = int(worker_state.get("finalize_pending", 0))
|
||||
worker_finalize_inflight = int(worker_state.get("finalize_inflight", 0))
|
||||
engine_decode_runtime_state = self.snapshot_decode_runtime_state()
|
||||
engine_job_registry = self.snapshot_job_registry()
|
||||
decode_runtime_pending_jobs = int(engine_decode_runtime_state.get("pending_jobs", 0))
|
||||
decode_runtime_active_size = int(engine_decode_runtime_state.get("active_request_count", 0))
|
||||
return {
|
||||
"active_request_count": int(len(active_requests)),
|
||||
"status_counts": status_counts,
|
||||
"queued_request_count": int(status_counts.get(EngineStatus.QUEUED, 0)),
|
||||
"cpu_prepare_request_count": int(status_counts.get(EngineStatus.CPU_PREPARING, 0)),
|
||||
"gpu_prepare_request_count": int(status_counts.get(EngineStatus.GPU_PREPARING, 0)),
|
||||
"ready_for_prefill_request_count": int(status_counts.get(EngineStatus.READY_FOR_PREFILL, 0)),
|
||||
"active_decode_request_count": int(status_counts.get(EngineStatus.ACTIVE_DECODE, 0)),
|
||||
"ready_for_finalize_request_count": int(status_counts.get(EngineStatus.READY_FOR_FINALIZE, 0)),
|
||||
"finalizing_request_count": int(status_counts.get(EngineStatus.FINALIZING, 0)),
|
||||
"streaming_request_count": int(status_counts.get(EngineStatus.STREAMING, 0)),
|
||||
"worker_pending_jobs": worker_pending_jobs,
|
||||
"worker_decode_active_size": worker_decode_active_size,
|
||||
"worker_decode_control_enabled": bool(worker_state.get("engine_decode_control_enabled", False)),
|
||||
"worker_decode_runtime_has_work": bool(worker_state.get("decode_runtime_has_work", False)),
|
||||
"engine_decode_runtime_pending_jobs": decode_runtime_pending_jobs,
|
||||
"engine_decode_runtime_active_request_count": decode_runtime_active_size,
|
||||
"engine_decode_runtime_has_work": bool(engine_decode_runtime_state.get("has_work", False)),
|
||||
"engine_job_registry_count": int(engine_job_registry.get("job_count", 0)),
|
||||
"worker_prepare_inflight": worker_prepare_inflight,
|
||||
"worker_finalize_pending": worker_finalize_pending,
|
||||
"worker_finalize_inflight": worker_finalize_inflight,
|
||||
"engine_gpu_prepare_queue_count": int(prepare_dispatcher_state.get("waiting_count", 0)),
|
||||
"engine_finalize_queue_count": int(finalize_dispatcher_state.get("waiting_count", 0)),
|
||||
"engine_decode_waiting_queue_count": int(dispatcher_state.get("waiting_count", 0)),
|
||||
"decode_backlog": int(
|
||||
decode_runtime_pending_jobs + decode_runtime_active_size
|
||||
if bool(worker_state.get("engine_decode_control_enabled", False))
|
||||
else worker_pending_jobs + worker_decode_active_size
|
||||
),
|
||||
}
|
||||
|
||||
def build_policy_snapshot(
|
||||
self,
|
||||
request_registry: Dict[str, Any],
|
||||
worker_state: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
counters = self.build_stage_counters(request_registry, worker_state)
|
||||
config = self.policy_config.to_dict()
|
||||
blocked_reasons: List[Dict[str, Any]] = []
|
||||
finalize_pending_total = int(counters["worker_finalize_pending"]) + int(counters.get("engine_finalize_queue_count", 0))
|
||||
limit_checks = [
|
||||
("decode_backlog", counters["decode_backlog"], int(config["decode_backlog_soft_max"])),
|
||||
("finalize_pending", finalize_pending_total, int(config["finalize_pending_soft_max"])),
|
||||
("prepare_inflight", counters["worker_prepare_inflight"], int(config["prepare_inflight_soft_max"])),
|
||||
("active_decode_requests", counters["active_decode_request_count"], int(config["active_decode_soft_max"])),
|
||||
("ready_for_prefill_requests", counters["ready_for_prefill_request_count"], int(config["ready_for_prefill_soft_max"])),
|
||||
("active_requests", counters["active_request_count"], int(config["active_request_soft_max"])),
|
||||
]
|
||||
if bool(config["enabled"]):
|
||||
for name, value, limit in limit_checks:
|
||||
if limit > 0 and int(value) >= int(limit):
|
||||
blocked_reasons.append({"metric": name, "value": int(value), "limit": int(limit)})
|
||||
return {
|
||||
"enabled": bool(config["enabled"]),
|
||||
"allowed": (not bool(config["enabled"])) or not blocked_reasons,
|
||||
"blocked_reasons": blocked_reasons,
|
||||
"config": config,
|
||||
"metrics": {
|
||||
"active_request_count": int(counters["active_request_count"]),
|
||||
"queued_request_count": int(counters["queued_request_count"]),
|
||||
"ready_for_prefill_request_count": int(counters["ready_for_prefill_request_count"]),
|
||||
"active_decode_request_count": int(counters["active_decode_request_count"]),
|
||||
"engine_gpu_prepare_queue_count": int(counters["engine_gpu_prepare_queue_count"]),
|
||||
"engine_decode_waiting_queue_count": int(counters["engine_decode_waiting_queue_count"]),
|
||||
"decode_backlog": int(counters["decode_backlog"]),
|
||||
"prepare_inflight": int(counters["worker_prepare_inflight"]),
|
||||
"finalize_pending": int(finalize_pending_total),
|
||||
"engine_finalize_queue_count": int(counters.get("engine_finalize_queue_count", 0)),
|
||||
"finalize_inflight": int(counters["worker_finalize_inflight"]),
|
||||
},
|
||||
"observed_at": time.perf_counter(),
|
||||
}
|
||||
|
||||
async def wait_for_policy_admission(
|
||||
self,
|
||||
*,
|
||||
request_id: str | None,
|
||||
timeout_sec: float | None,
|
||||
) -> tuple[float, Dict[str, Any]]:
|
||||
request_registry = self.snapshot_request_registry()
|
||||
worker_state = self.get_worker_state()
|
||||
snapshot = self.build_policy_snapshot(request_registry, worker_state)
|
||||
if not self.policy_config.enabled:
|
||||
return 0.0, snapshot
|
||||
start = time.perf_counter()
|
||||
deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec)))
|
||||
while True:
|
||||
request_registry = self.snapshot_request_registry()
|
||||
worker_state = self.get_worker_state()
|
||||
snapshot = self.build_policy_snapshot(request_registry, worker_state)
|
||||
if snapshot["allowed"]:
|
||||
wait_ms = max(0.0, (time.perf_counter() - start) * 1000.0)
|
||||
if request_id not in [None, ""]:
|
||||
self.merge_request_state_profile(
|
||||
str(request_id),
|
||||
{
|
||||
"engine_policy_wait_ms": float(wait_ms),
|
||||
"engine_policy_snapshot": snapshot,
|
||||
},
|
||||
)
|
||||
return wait_ms, snapshot
|
||||
now = time.perf_counter()
|
||||
if deadline is not None and now >= deadline:
|
||||
blocked_summary = ", ".join(
|
||||
f"{item['metric']}={item['value']}/{item['limit']}" for item in snapshot.get("blocked_reasons", [])
|
||||
)
|
||||
raise TimeoutError(f"engine policy admission timeout ({blocked_summary})")
|
||||
await asyncio.sleep(self.policy_poll_s)
|
||||
|
||||
def select_stage(self) -> tuple[str, str, Dict[str, Any], Dict[str, Any]]:
|
||||
request_registry = self.snapshot_request_registry()
|
||||
worker_state = self.get_worker_state()
|
||||
policy_snapshot = self.build_policy_snapshot(request_registry, worker_state)
|
||||
prepare_state = self.snapshot_prepare_state()
|
||||
prepare_waiting = int(prepare_state.get("waiting_count", 0))
|
||||
prepare_audio_waiting = int(prepare_state.get("audio_waiting_count", 0))
|
||||
prepare_text_waiting = int(prepare_state.get("text_waiting_count", 0))
|
||||
prepare_ref_spec_waiting = int(prepare_state.get("ref_spec_waiting_count", 0))
|
||||
finalize_waiting = int(self.snapshot_finalize_state().get("waiting_count", 0))
|
||||
decode_waiting = int(self.snapshot_dispatch_state().get("waiting_count", 0))
|
||||
decode_runtime_state = self.snapshot_decode_runtime_state()
|
||||
worker_decode_has_work = bool(decode_runtime_state.get("has_work", False))
|
||||
worker_decode_control_enabled = bool(worker_state.get("engine_decode_control_enabled", False))
|
||||
worker_pending_jobs = int(decode_runtime_state.get("pending_jobs", 0))
|
||||
worker_running_requests = int(decode_runtime_state.get("active_request_count", 0))
|
||||
prepare_age_ms = float(self.peek_queue_age_ms("prepare"))
|
||||
prepare_audio_age_ms = float(self.peek_queue_age_ms("prepare_audio"))
|
||||
prepare_text_age_ms = float(self.peek_queue_age_ms("prepare_text"))
|
||||
prepare_ref_spec_age_ms = float(self.peek_queue_age_ms("prepare_ref_spec"))
|
||||
finalize_age_ms = float(self.peek_queue_age_ms("finalize"))
|
||||
decode_runtime_pending_age_ms = float(self.peek_queue_age_ms("decode_runtime_pending"))
|
||||
decode_budget_remaining = int(self.snapshot_state().get("decode_budget_remaining", 0))
|
||||
policy_allowed = bool(policy_snapshot.get("allowed", True))
|
||||
if (
|
||||
worker_decode_control_enabled
|
||||
and worker_decode_has_work
|
||||
and policy_allowed
|
||||
and decode_budget_remaining > 0
|
||||
and (worker_running_requests > 0 or worker_pending_jobs > 0)
|
||||
):
|
||||
return "decode_runtime", "worker_active_batch_progress", policy_snapshot, worker_state
|
||||
if (
|
||||
worker_decode_control_enabled
|
||||
and worker_pending_jobs > 0
|
||||
and policy_allowed
|
||||
and decode_runtime_pending_age_ms >= float(self.arbiter_config.prepare_aging_ms)
|
||||
):
|
||||
return "decode_runtime", "decode_runtime_pending_aging", policy_snapshot, worker_state
|
||||
if (
|
||||
decode_waiting > 0
|
||||
and policy_allowed
|
||||
and (not worker_decode_control_enabled or not worker_decode_has_work or worker_pending_jobs <= 0)
|
||||
):
|
||||
return "decode_dispatch", "dispatch_prepared_state", policy_snapshot, worker_state
|
||||
if (
|
||||
finalize_waiting > 0
|
||||
and prepare_ref_spec_waiting > 0
|
||||
and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0)
|
||||
):
|
||||
return "prepare_ref_spec", "finalize_waiting_for_ref_spec", policy_snapshot, worker_state
|
||||
if finalize_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0):
|
||||
return "finalize", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
|
||||
if finalize_waiting > 0 and finalize_age_ms >= float(self.arbiter_config.finalize_aging_ms):
|
||||
return "finalize", "finalize_aging", policy_snapshot, worker_state
|
||||
if prepare_waiting > 0 and (decode_waiting <= 0 or not policy_allowed or decode_budget_remaining <= 0):
|
||||
if prepare_text_waiting > 0 and (prepare_audio_waiting <= 0 or prepare_text_age_ms >= prepare_audio_age_ms):
|
||||
return "prepare_text", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
|
||||
if prepare_ref_spec_waiting > 0 and prepare_audio_waiting <= 0 and prepare_text_waiting <= 0:
|
||||
return "prepare_ref_spec", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
|
||||
return "prepare_audio", "decode_blocked_or_budget_exhausted", policy_snapshot, worker_state
|
||||
if prepare_waiting > 0 and prepare_age_ms >= float(self.arbiter_config.prepare_aging_ms):
|
||||
if prepare_text_waiting > 0 and prepare_text_age_ms >= max(prepare_audio_age_ms, prepare_age_ms - 1e-6):
|
||||
return "prepare_text", "prepare_aging", policy_snapshot, worker_state
|
||||
if (
|
||||
prepare_ref_spec_waiting > 0
|
||||
and prepare_ref_spec_age_ms >= max(prepare_audio_age_ms, prepare_text_age_ms, prepare_age_ms - 1e-6)
|
||||
):
|
||||
return "prepare_ref_spec", "prepare_aging", policy_snapshot, worker_state
|
||||
return "prepare_audio", "prepare_aging", policy_snapshot, worker_state
|
||||
if worker_decode_control_enabled and worker_decode_has_work and policy_allowed:
|
||||
return "decode_runtime", "worker_active_batch_progress_fallback", policy_snapshot, worker_state
|
||||
if decode_waiting > 0 and policy_allowed:
|
||||
return "decode_dispatch", "decode_priority_fallback", policy_snapshot, worker_state
|
||||
if finalize_waiting > 0:
|
||||
return "finalize", "finalize_fallback", policy_snapshot, worker_state
|
||||
if prepare_waiting > 0:
|
||||
if prepare_text_waiting > 0 and (prepare_audio_waiting <= 0 or prepare_text_age_ms >= prepare_audio_age_ms):
|
||||
return "prepare_text", "prepare_fallback", policy_snapshot, worker_state
|
||||
if prepare_ref_spec_waiting > 0 and prepare_audio_waiting <= 0:
|
||||
return "prepare_ref_spec", "prepare_fallback", policy_snapshot, worker_state
|
||||
return "prepare_audio", "prepare_fallback", policy_snapshot, worker_state
|
||||
return "idle", "no_pending_work", policy_snapshot, worker_state
|
||||
382
GPT_SoVITS/TTS_infer_pack/unified_engine_component_registry.py
Normal file
382
GPT_SoVITS/TTS_infer_pack/unified_engine_component_registry.py
Normal file
@ -0,0 +1,382 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Deque, Dict, Optional, Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SRequestState
|
||||
|
||||
|
||||
@dataclass
|
||||
class DefaultReferenceState:
|
||||
ref_audio_path: str | None = None
|
||||
updated_at: float = 0.0
|
||||
|
||||
|
||||
class ReferenceRegistry:
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.Lock()
|
||||
self._state = DefaultReferenceState()
|
||||
|
||||
def set_default(self, ref_audio_path: str) -> DefaultReferenceState:
|
||||
with self._lock:
|
||||
self._state = DefaultReferenceState(ref_audio_path=str(ref_audio_path), updated_at=time.time())
|
||||
return self._state
|
||||
|
||||
def clear(self) -> DefaultReferenceState:
|
||||
with self._lock:
|
||||
self._state = DefaultReferenceState()
|
||||
return self._state
|
||||
|
||||
def get_default(self) -> DefaultReferenceState:
|
||||
with self._lock:
|
||||
return DefaultReferenceState(
|
||||
ref_audio_path=self._state.ref_audio_path,
|
||||
updated_at=self._state.updated_at,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelRegistryState:
|
||||
t2s_weights_path: str
|
||||
vits_weights_path: str
|
||||
generation: int = 0
|
||||
t2s_generation: int = 0
|
||||
vits_generation: int = 0
|
||||
updated_at: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
class ModelRegistry:
|
||||
def __init__(self, t2s_weights_path: str, vits_weights_path: str) -> None:
|
||||
self._lock = threading.Lock()
|
||||
self._state = ModelRegistryState(
|
||||
t2s_weights_path=str(t2s_weights_path),
|
||||
vits_weights_path=str(vits_weights_path),
|
||||
)
|
||||
|
||||
def snapshot(self) -> ModelRegistryState:
|
||||
with self._lock:
|
||||
return ModelRegistryState(
|
||||
t2s_weights_path=self._state.t2s_weights_path,
|
||||
vits_weights_path=self._state.vits_weights_path,
|
||||
generation=self._state.generation,
|
||||
t2s_generation=self._state.t2s_generation,
|
||||
vits_generation=self._state.vits_generation,
|
||||
updated_at=self._state.updated_at,
|
||||
)
|
||||
|
||||
def mark_t2s_reload(self, weights_path: str) -> ModelRegistryState:
|
||||
with self._lock:
|
||||
self._state.t2s_weights_path = str(weights_path)
|
||||
self._state.generation += 1
|
||||
self._state.t2s_generation += 1
|
||||
self._state.updated_at = time.time()
|
||||
return ModelRegistryState(
|
||||
t2s_weights_path=self._state.t2s_weights_path,
|
||||
vits_weights_path=self._state.vits_weights_path,
|
||||
generation=self._state.generation,
|
||||
t2s_generation=self._state.t2s_generation,
|
||||
vits_generation=self._state.vits_generation,
|
||||
updated_at=self._state.updated_at,
|
||||
)
|
||||
|
||||
def mark_vits_reload(self, weights_path: str) -> ModelRegistryState:
|
||||
with self._lock:
|
||||
self._state.vits_weights_path = str(weights_path)
|
||||
self._state.generation += 1
|
||||
self._state.vits_generation += 1
|
||||
self._state.updated_at = time.time()
|
||||
return ModelRegistryState(
|
||||
t2s_weights_path=self._state.t2s_weights_path,
|
||||
vits_weights_path=self._state.vits_weights_path,
|
||||
generation=self._state.generation,
|
||||
t2s_generation=self._state.t2s_generation,
|
||||
vits_generation=self._state.vits_generation,
|
||||
updated_at=self._state.updated_at,
|
||||
)
|
||||
|
||||
|
||||
class EngineStatus:
|
||||
NEW = "NEW"
|
||||
QUEUED = "QUEUED"
|
||||
VALIDATED = "VALIDATED"
|
||||
CPU_PREPARING = "CPU_PREPARING"
|
||||
GPU_PREPARING = "GPU_PREPARING"
|
||||
READY_FOR_PREFILL = "READY_FOR_PREFILL"
|
||||
ACTIVE_DECODE = "ACTIVE_DECODE"
|
||||
READY_FOR_FINALIZE = "READY_FOR_FINALIZE"
|
||||
FINALIZING = "FINALIZING"
|
||||
STREAMING = "STREAMING"
|
||||
COMPLETED = "COMPLETED"
|
||||
FAILED = "FAILED"
|
||||
|
||||
|
||||
@dataclass
|
||||
class EngineRequestState:
|
||||
request_id: str
|
||||
api_mode: str
|
||||
backend: str
|
||||
media_type: str
|
||||
response_streaming: bool
|
||||
submit_ts: float
|
||||
deadline_ts: float | None = None
|
||||
status: str = EngineStatus.NEW
|
||||
updated_ts: float = 0.0
|
||||
error: str | None = None
|
||||
finish_reason: str | None = None
|
||||
meta: Dict[str, Any] = field(default_factory=dict)
|
||||
profile: Dict[str, Any] = field(default_factory=dict)
|
||||
lifecycle_timestamps: Dict[str, float] = field(default_factory=dict)
|
||||
|
||||
def to_summary(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"request_id": self.request_id,
|
||||
"api_mode": self.api_mode,
|
||||
"backend": self.backend,
|
||||
"media_type": self.media_type,
|
||||
"response_streaming": self.response_streaming,
|
||||
"status": self.status,
|
||||
"submit_ts": self.submit_ts,
|
||||
"updated_ts": self.updated_ts,
|
||||
"deadline_ts": self.deadline_ts,
|
||||
"error": self.error,
|
||||
"finish_reason": self.finish_reason,
|
||||
"meta": dict(self.meta),
|
||||
"profile": dict(self.profile),
|
||||
"lifecycle_timestamps": dict(self.lifecycle_timestamps),
|
||||
}
|
||||
|
||||
|
||||
class EngineRequestRegistry:
|
||||
def __init__(self, recent_limit: int) -> None:
|
||||
self.lock = threading.Lock()
|
||||
self.active_requests: Dict[str, EngineRequestState] = {}
|
||||
self.recent_requests: Deque[EngineRequestState] = deque()
|
||||
self.recent_limit = max(1, int(recent_limit))
|
||||
|
||||
def register(
|
||||
self,
|
||||
*,
|
||||
request_id: str,
|
||||
api_mode: str,
|
||||
backend: str,
|
||||
media_type: str,
|
||||
response_streaming: bool,
|
||||
deadline_ts: float | None = None,
|
||||
meta: Optional[Dict[str, Any]] = None,
|
||||
) -> EngineRequestState:
|
||||
now = time.perf_counter()
|
||||
state = EngineRequestState(
|
||||
request_id=request_id,
|
||||
api_mode=api_mode,
|
||||
backend=backend,
|
||||
media_type=media_type,
|
||||
response_streaming=bool(response_streaming),
|
||||
submit_ts=now,
|
||||
deadline_ts=deadline_ts,
|
||||
updated_ts=now,
|
||||
meta=dict(meta or {}),
|
||||
lifecycle_timestamps={EngineStatus.NEW: now},
|
||||
)
|
||||
with self.lock:
|
||||
self.active_requests[request_id] = state
|
||||
return state
|
||||
|
||||
def _move_to_recent_locked(self, state: EngineRequestState) -> None:
|
||||
self.recent_requests.appendleft(state)
|
||||
while len(self.recent_requests) > self.recent_limit:
|
||||
self.recent_requests.pop()
|
||||
|
||||
@staticmethod
|
||||
def _apply_state_extra(state: EngineRequestState, extra: Optional[Dict[str, Any]]) -> None:
|
||||
if not extra:
|
||||
return
|
||||
payload = dict(extra)
|
||||
backend = payload.pop("backend", None)
|
||||
if backend is not None:
|
||||
state.backend = str(backend)
|
||||
finish_reason = payload.pop("finish_reason", None)
|
||||
if finish_reason is not None:
|
||||
state.finish_reason = str(finish_reason)
|
||||
error = payload.pop("error", None)
|
||||
if error is not None:
|
||||
state.error = str(error)
|
||||
state.profile.update(payload)
|
||||
|
||||
def update(self, request_id: str, status: str, extra: Optional[Dict[str, Any]] = None) -> None:
|
||||
now = time.perf_counter()
|
||||
with self.lock:
|
||||
state = self.active_requests.get(request_id)
|
||||
if state is None:
|
||||
return
|
||||
state.status = str(status)
|
||||
state.updated_ts = now
|
||||
state.lifecycle_timestamps[str(status)] = now
|
||||
self._apply_state_extra(state, extra)
|
||||
|
||||
def merge_profile(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
|
||||
if not extra:
|
||||
return
|
||||
now = time.perf_counter()
|
||||
with self.lock:
|
||||
state = self.active_requests.get(request_id)
|
||||
if state is None:
|
||||
for recent_state in self.recent_requests:
|
||||
if recent_state.request_id == request_id:
|
||||
state = recent_state
|
||||
break
|
||||
if state is None:
|
||||
return
|
||||
state.updated_ts = now
|
||||
self._apply_state_extra(state, extra)
|
||||
|
||||
def complete(self, request_id: str, extra: Optional[Dict[str, Any]] = None) -> None:
|
||||
now = time.perf_counter()
|
||||
with self.lock:
|
||||
state = self.active_requests.pop(request_id, None)
|
||||
if state is None:
|
||||
return
|
||||
state.status = EngineStatus.COMPLETED
|
||||
state.updated_ts = now
|
||||
state.lifecycle_timestamps[EngineStatus.COMPLETED] = now
|
||||
self._apply_state_extra(state, extra)
|
||||
self._move_to_recent_locked(state)
|
||||
|
||||
def fail(self, request_id: str, error: str) -> None:
|
||||
now = time.perf_counter()
|
||||
with self.lock:
|
||||
state = self.active_requests.pop(request_id, None)
|
||||
if state is None:
|
||||
return
|
||||
state.status = EngineStatus.FAILED
|
||||
state.updated_ts = now
|
||||
state.error = str(error)
|
||||
state.lifecycle_timestamps[EngineStatus.FAILED] = now
|
||||
self._move_to_recent_locked(state)
|
||||
|
||||
def snapshot(self) -> Dict[str, Any]:
|
||||
with self.lock:
|
||||
active = [state.to_summary() for state in self.active_requests.values()]
|
||||
recent = [state.to_summary() for state in list(self.recent_requests)]
|
||||
recent_limit = self.recent_limit
|
||||
active.sort(key=lambda item: item["submit_ts"])
|
||||
return {
|
||||
"active_count": len(active),
|
||||
"recent_count": len(recent),
|
||||
"recent_limit": recent_limit,
|
||||
"active_requests": active,
|
||||
"recent_requests": recent,
|
||||
}
|
||||
|
||||
def collect_summaries(self, request_ids: Sequence[str]) -> list[Dict[str, Any]]:
|
||||
requested = set(request_ids)
|
||||
results: list[Dict[str, Any]] = []
|
||||
with self.lock:
|
||||
for state in self.active_requests.values():
|
||||
if state.request_id in requested:
|
||||
results.append(state.to_summary())
|
||||
existing_ids = {item["request_id"] for item in results}
|
||||
for state in self.recent_requests:
|
||||
if state.request_id in requested and state.request_id not in existing_ids:
|
||||
results.append(state.to_summary())
|
||||
results.sort(key=lambda item: item["request_id"])
|
||||
return results
|
||||
|
||||
def has_active(self, request_id: str) -> bool:
|
||||
with self.lock:
|
||||
return request_id in self.active_requests
|
||||
|
||||
|
||||
@dataclass
|
||||
class SchedulerPendingJob:
|
||||
request_id: str
|
||||
state: T2SRequestState
|
||||
done_event: threading.Event
|
||||
done_loop: asyncio.AbstractEventLoop | None
|
||||
done_future: asyncio.Future | None
|
||||
enqueue_time: float
|
||||
speed_factor: float
|
||||
sample_steps: int
|
||||
media_type: str
|
||||
super_sampling: bool = False
|
||||
admission_wait_ms: float = 0.0
|
||||
engine_policy_wait_ms: float = 0.0
|
||||
engine_dispatch_wait_ms: float = 0.0
|
||||
prepare_wall_ms: float = 0.0
|
||||
prepare_profile_total_ms: float = 0.0
|
||||
first_schedule_time: float | None = None
|
||||
prefill_ms: float = 0.0
|
||||
merge_ms: float = 0.0
|
||||
decode_ms: float = 0.0
|
||||
finalize_wait_ms: float = 0.0
|
||||
synth_ms: float = 0.0
|
||||
pack_ms: float = 0.0
|
||||
decode_steps: int = 0
|
||||
result_ready_time: float | None = None
|
||||
result: dict | None = None
|
||||
sample_rate: int | None = None
|
||||
audio_data: np.ndarray | None = None
|
||||
error: str | None = None
|
||||
engine_request_id: str | None = None
|
||||
|
||||
|
||||
class SchedulerJobRegistry:
|
||||
def __init__(self, lock: threading.Lock | threading.RLock | threading.Condition) -> None:
|
||||
self._lock = lock
|
||||
self._job_map: Dict[str, SchedulerPendingJob] = {}
|
||||
self._total_submitted = 0
|
||||
self._total_finished = 0
|
||||
|
||||
def register(self, job: SchedulerPendingJob, *, keep_job: bool = True) -> None:
|
||||
with self._lock:
|
||||
if keep_job:
|
||||
self._job_map[job.request_id] = job
|
||||
self._total_submitted += 1
|
||||
|
||||
def get(self, request_id: str) -> SchedulerPendingJob | None:
|
||||
with self._lock:
|
||||
return self._job_map.get(request_id)
|
||||
|
||||
def pop(self, request_id: str) -> SchedulerPendingJob | None:
|
||||
with self._lock:
|
||||
return self._job_map.pop(request_id, None)
|
||||
|
||||
def remove(self, request_id: str) -> None:
|
||||
with self._lock:
|
||||
self._job_map.pop(request_id, None)
|
||||
|
||||
def mark_finished(self) -> None:
|
||||
with self._lock:
|
||||
self._total_finished += 1
|
||||
|
||||
def mark_finished_and_remove(self, request_id: str) -> None:
|
||||
with self._lock:
|
||||
self._job_map.pop(request_id, None)
|
||||
self._total_finished += 1
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
with self._lock:
|
||||
return not self._job_map
|
||||
|
||||
def submitted_count(self) -> int:
|
||||
with self._lock:
|
||||
return int(self._total_submitted)
|
||||
|
||||
def finished_count(self) -> int:
|
||||
with self._lock:
|
||||
return int(self._total_finished)
|
||||
|
||||
def snapshot(self, max_request_ids: int = 32) -> Dict[str, Any]:
|
||||
with self._lock:
|
||||
request_ids = list(self._job_map.keys())
|
||||
return {
|
||||
"job_count": int(len(request_ids)),
|
||||
"request_ids": request_ids[: max(0, int(max_request_ids))],
|
||||
"total_submitted": int(self._total_submitted),
|
||||
"total_finished": int(self._total_finished),
|
||||
}
|
||||
362
GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py
Normal file
362
GPT_SoVITS/TTS_infer_pack/unified_engine_component_runtime.py
Normal file
@ -0,0 +1,362 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Deque, Dict, List, Optional, Sequence
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PreparedCpuStage
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SActiveBatch, T2SFinishedItem, T2SRequestState
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_component_registry import SchedulerPendingJob
|
||||
|
||||
|
||||
class EngineTaskQueueOwner:
|
||||
def __init__(self, completion_key: str = "total_completed") -> None:
|
||||
self.condition = threading.Condition()
|
||||
self.queue: Deque[Any] = deque()
|
||||
self.total_submitted = 0
|
||||
self.total_completed = 0
|
||||
self.peak_waiting = 0
|
||||
self.completion_key = str(completion_key)
|
||||
|
||||
def enqueue(self, item: Any) -> None:
|
||||
with self.condition:
|
||||
self.queue.append(item)
|
||||
self.total_submitted += 1
|
||||
self.peak_waiting = max(self.peak_waiting, len(self.queue))
|
||||
self.condition.notify_all()
|
||||
|
||||
def enqueue_many(self, items: Sequence[Any]) -> None:
|
||||
if not items:
|
||||
return
|
||||
with self.condition:
|
||||
for item in items:
|
||||
self.queue.append(item)
|
||||
self.total_submitted += len(items)
|
||||
self.peak_waiting = max(self.peak_waiting, len(self.queue))
|
||||
self.condition.notify_all()
|
||||
|
||||
def pop_left(self) -> Any | None:
|
||||
with self.condition:
|
||||
if not self.queue:
|
||||
return None
|
||||
return self.queue.popleft()
|
||||
|
||||
def pop_left_many(self, max_items: int) -> List[Any]:
|
||||
limit = max(1, int(max_items))
|
||||
with self.condition:
|
||||
if not self.queue:
|
||||
return []
|
||||
selected: List[Any] = []
|
||||
while self.queue and len(selected) < limit:
|
||||
selected.append(self.queue.popleft())
|
||||
return selected
|
||||
|
||||
def mark_completed(self, count: int = 1, *, notify: bool = False) -> None:
|
||||
if count <= 0:
|
||||
return
|
||||
with self.condition:
|
||||
self.total_completed += int(count)
|
||||
if notify:
|
||||
self.condition.notify_all()
|
||||
|
||||
def has_items(self) -> bool:
|
||||
with self.condition:
|
||||
return bool(self.queue)
|
||||
|
||||
def waiting_count(self) -> int:
|
||||
with self.condition:
|
||||
return int(len(self.queue))
|
||||
|
||||
def snapshot(self, *, max_request_ids: int = 16, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
with self.condition:
|
||||
waiting_items = list(self.queue)[: max(0, int(max_request_ids))]
|
||||
snapshot = {
|
||||
"waiting_count": int(len(self.queue)),
|
||||
"waiting_request_ids": [str(getattr(item, "request_id", "")) for item in waiting_items],
|
||||
"peak_waiting": int(self.peak_waiting),
|
||||
"total_submitted": int(self.total_submitted),
|
||||
self.completion_key: int(self.total_completed),
|
||||
}
|
||||
if extra:
|
||||
snapshot.update(dict(extra))
|
||||
return snapshot
|
||||
|
||||
def peek_oldest_age_ms(self, timestamp_attr: str) -> float:
|
||||
with self.condition:
|
||||
if not self.queue:
|
||||
return 0.0
|
||||
enqueue_time = float(getattr(self.queue[0], timestamp_attr))
|
||||
return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0)
|
||||
|
||||
def is_drained(self) -> bool:
|
||||
with self.condition:
|
||||
return not self.queue and self.total_submitted == self.total_completed
|
||||
|
||||
def take_finalize_batch(
|
||||
self,
|
||||
*,
|
||||
finalize_mode: str,
|
||||
batch_max_items: int,
|
||||
batch_wait_s: float,
|
||||
use_vocoder: bool,
|
||||
) -> List[SchedulerFinalizeTask]:
|
||||
with self.condition:
|
||||
if not self.queue:
|
||||
return []
|
||||
selected_tasks = [self.queue.popleft()]
|
||||
if finalize_mode == "sync" or use_vocoder:
|
||||
return selected_tasks
|
||||
if batch_max_items <= 1:
|
||||
return selected_tasks
|
||||
first_task = selected_tasks[0]
|
||||
oldest_age_s = max(0.0, time.perf_counter() - first_task.enqueued_time)
|
||||
if len(self.queue) + 1 < batch_max_items and oldest_age_s < batch_wait_s:
|
||||
self.queue.appendleft(first_task)
|
||||
return []
|
||||
while len(selected_tasks) < batch_max_items:
|
||||
if not self.queue:
|
||||
break
|
||||
matched_index = None
|
||||
for index, task in enumerate(self.queue):
|
||||
if abs(task.enqueued_time - first_task.enqueued_time) < 1.0:
|
||||
matched_index = index
|
||||
break
|
||||
if matched_index is None:
|
||||
break
|
||||
selected_tasks.append(self.queue[matched_index])
|
||||
del self.queue[matched_index]
|
||||
return selected_tasks
|
||||
|
||||
|
||||
@dataclass
|
||||
class EngineDecodeRuntimeState:
|
||||
pending_jobs: int = 0
|
||||
pending_request_ids: List[str] = field(default_factory=list)
|
||||
active_request_count: int = 0
|
||||
active_request_ids: List[str] = field(default_factory=list)
|
||||
prefill_done: bool = False
|
||||
decode_step_index_max: int = 0
|
||||
total_cycles: int = 0
|
||||
prefill_cycles: int = 0
|
||||
step_cycles: int = 0
|
||||
has_work: bool = False
|
||||
last_event: str = "init"
|
||||
updated_at: float = 0.0
|
||||
|
||||
|
||||
class EngineDecodeRuntimeOwner:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
get_decode_runtime_counters: Callable[[], Dict[str, int]],
|
||||
get_micro_batch_wait_s: Callable[[], float],
|
||||
) -> None:
|
||||
self.get_decode_runtime_counters = get_decode_runtime_counters
|
||||
self.get_micro_batch_wait_s = get_micro_batch_wait_s
|
||||
self.condition = threading.Condition()
|
||||
self.pending_jobs: Deque[SchedulerPendingJob] = deque()
|
||||
self.active_batch: T2SActiveBatch | None = None
|
||||
self.state_lock = threading.Lock()
|
||||
self.state = EngineDecodeRuntimeState(updated_at=time.perf_counter())
|
||||
|
||||
@staticmethod
|
||||
def summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any]:
|
||||
if active_batch is None:
|
||||
return {}
|
||||
decode_step_index_max = 0
|
||||
if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0:
|
||||
decode_step_index_max = int(active_batch.step_indices.max().item())
|
||||
return {
|
||||
"request_count": int(len(active_batch.request_ids)),
|
||||
"request_ids": list(active_batch.request_ids),
|
||||
"prefill_done": bool(active_batch.prefill_done),
|
||||
"decode_step_index_max": int(decode_step_index_max),
|
||||
}
|
||||
|
||||
def snapshot_pending_queue_state(self) -> Dict[str, Any]:
|
||||
with self.condition:
|
||||
return {
|
||||
"pending_jobs": int(len(self.pending_jobs)),
|
||||
"pending_request_ids": [job.request_id for job in list(self.pending_jobs)[:32]],
|
||||
}
|
||||
|
||||
def enqueue_pending_job(self, job: SchedulerPendingJob) -> None:
|
||||
with self.condition:
|
||||
self.pending_jobs.append(job)
|
||||
self.condition.notify_all()
|
||||
self.refresh_state("engine_decode_pending_enqueue")
|
||||
|
||||
def take_pending_jobs_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
|
||||
with self.condition:
|
||||
if not self.pending_jobs:
|
||||
return []
|
||||
if wait_for_batch:
|
||||
oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time)
|
||||
if (time.perf_counter() - oldest_enqueue_time) < self.get_micro_batch_wait_s():
|
||||
return []
|
||||
pending_jobs = list(self.pending_jobs)
|
||||
self.pending_jobs.clear()
|
||||
self.refresh_state("engine_decode_pending_dequeue")
|
||||
return pending_jobs
|
||||
|
||||
def pending_age_ms(self) -> float:
|
||||
with self.condition:
|
||||
if not self.pending_jobs:
|
||||
return 0.0
|
||||
enqueue_time = float(self.pending_jobs[0].enqueue_time)
|
||||
return max(0.0, (time.perf_counter() - enqueue_time) * 1000.0)
|
||||
|
||||
def has_pending_jobs(self) -> bool:
|
||||
with self.condition:
|
||||
return bool(self.pending_jobs)
|
||||
|
||||
def get_active_batch(self) -> T2SActiveBatch | None:
|
||||
return self.active_batch
|
||||
|
||||
def set_active_batch(self, active_batch: T2SActiveBatch | None) -> None:
|
||||
self.active_batch = active_batch
|
||||
|
||||
def active_batch_summary(self) -> Dict[str, Any]:
|
||||
return self.summarize_active_batch(self.active_batch)
|
||||
|
||||
def refresh_state(self, last_event: str) -> None:
|
||||
pending_state = self.snapshot_pending_queue_state()
|
||||
active_batch_summary = self.active_batch_summary()
|
||||
worker_decode_counters = self.get_decode_runtime_counters()
|
||||
with self.state_lock:
|
||||
self.state.pending_jobs = int(pending_state.get("pending_jobs", 0))
|
||||
self.state.pending_request_ids = list(pending_state.get("pending_request_ids", []))
|
||||
self.state.active_request_count = int(active_batch_summary.get("request_count", 0))
|
||||
self.state.active_request_ids = list(active_batch_summary.get("request_ids", []))[:32]
|
||||
self.state.prefill_done = bool(active_batch_summary.get("prefill_done", False))
|
||||
self.state.decode_step_index_max = int(active_batch_summary.get("decode_step_index_max", 0))
|
||||
self.state.total_cycles = int(worker_decode_counters.get("total_cycles", 0))
|
||||
self.state.prefill_cycles = int(worker_decode_counters.get("prefill_cycles", 0))
|
||||
self.state.step_cycles = int(worker_decode_counters.get("step_cycles", 0))
|
||||
self.state.has_work = bool(pending_state.get("pending_jobs", 0) or active_batch_summary.get("request_count", 0))
|
||||
self.state.last_event = str(last_event)
|
||||
self.state.updated_at = float(time.perf_counter())
|
||||
|
||||
def update_from_worker_snapshot(self, snapshot: Dict[str, Any]) -> None:
|
||||
if not snapshot:
|
||||
return
|
||||
pending_state = self.snapshot_pending_queue_state()
|
||||
with self.state_lock:
|
||||
self.state.pending_jobs = int(pending_state.get("pending_jobs", 0))
|
||||
self.state.pending_request_ids = list(pending_state.get("pending_request_ids", []))
|
||||
self.state.active_request_count = int(snapshot.get("active_request_count", 0))
|
||||
self.state.active_request_ids = list(snapshot.get("active_request_ids", []))[:32]
|
||||
self.state.prefill_done = bool(snapshot.get("prefill_done", False))
|
||||
self.state.decode_step_index_max = int(snapshot.get("decode_step_index_max", 0))
|
||||
self.state.total_cycles = int(snapshot.get("total_cycles", 0))
|
||||
self.state.prefill_cycles = int(snapshot.get("prefill_cycles", 0))
|
||||
self.state.step_cycles = int(snapshot.get("step_cycles", 0))
|
||||
self.state.has_work = bool(
|
||||
pending_state.get("pending_jobs", 0)
|
||||
or snapshot.get("active_request_count", 0)
|
||||
or snapshot.get("has_work", False)
|
||||
)
|
||||
self.state.last_event = str(snapshot.get("last_event", "unknown"))
|
||||
self.state.updated_at = float(snapshot.get("updated_at", time.perf_counter()))
|
||||
|
||||
def snapshot_state(self) -> Dict[str, Any]:
|
||||
pending_state = self.snapshot_pending_queue_state()
|
||||
active_batch_summary = self.active_batch_summary()
|
||||
worker_decode_counters = self.get_decode_runtime_counters()
|
||||
with self.state_lock:
|
||||
return {
|
||||
"pending_jobs": int(pending_state.get("pending_jobs", self.state.pending_jobs)),
|
||||
"pending_request_ids": list(pending_state.get("pending_request_ids", self.state.pending_request_ids)),
|
||||
"active_request_count": int(active_batch_summary.get("request_count", self.state.active_request_count)),
|
||||
"active_request_ids": list(active_batch_summary.get("request_ids", self.state.active_request_ids)),
|
||||
"prefill_done": bool(active_batch_summary.get("prefill_done", self.state.prefill_done)),
|
||||
"decode_step_index_max": int(active_batch_summary.get("decode_step_index_max", self.state.decode_step_index_max)),
|
||||
"total_cycles": int(worker_decode_counters.get("total_cycles", 0)),
|
||||
"prefill_cycles": int(worker_decode_counters.get("prefill_cycles", 0)),
|
||||
"step_cycles": int(worker_decode_counters.get("step_cycles", 0)),
|
||||
"has_work": bool(
|
||||
pending_state.get("pending_jobs", 0)
|
||||
or active_batch_summary.get("request_count", self.state.active_request_count)
|
||||
or self.state.has_work
|
||||
),
|
||||
"last_event": str(self.state.last_event),
|
||||
"updated_at": float(self.state.updated_at),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SchedulerFinalizeTask:
|
||||
request_id: str
|
||||
item: T2SFinishedItem
|
||||
enqueued_time: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class EngineDispatchTask:
|
||||
request_id: str
|
||||
state: T2SRequestState
|
||||
speed_factor: float
|
||||
sample_steps: int
|
||||
media_type: str
|
||||
super_sampling: bool
|
||||
prepare_wall_ms: float
|
||||
prepare_profile_total_ms: float
|
||||
done_loop: asyncio.AbstractEventLoop | None
|
||||
done_future: asyncio.Future | None
|
||||
engine_request_id: str | None
|
||||
timeout_sec: float | None
|
||||
enqueue_time: float
|
||||
worker_job: SchedulerPendingJob | None = None
|
||||
engine_policy_wait_ms: float = 0.0
|
||||
engine_dispatch_wait_ms: float = 0.0
|
||||
engine_policy_snapshot: Dict[str, Any] | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EngineGpuPrepareTask:
|
||||
request_id: str
|
||||
cpu_stage: PreparedCpuStage
|
||||
done_loop: asyncio.AbstractEventLoop | None
|
||||
done_future: asyncio.Future | None
|
||||
engine_request_id: str | None
|
||||
enqueue_time: float
|
||||
phase: str = "audio"
|
||||
audio_enqueue_time: float = 0.0
|
||||
audio_start_time: float = 0.0
|
||||
audio_end_time: float = 0.0
|
||||
text_enqueue_time: float = 0.0
|
||||
text_start_time: float = 0.0
|
||||
text_end_time: float = 0.0
|
||||
ref_spec_enqueue_time: float = 0.0
|
||||
ref_spec_start_time: float = 0.0
|
||||
ref_spec_end_time: float = 0.0
|
||||
audio_queue_wait_ms: float = 0.0
|
||||
text_queue_wait_ms: float = 0.0
|
||||
ref_spec_queue_wait_ms: float = 0.0
|
||||
admission_wait_ms: float = 0.0
|
||||
phase_one: Dict[str, Any] | None = None
|
||||
ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]] | None = None
|
||||
state_result: T2SRequestState | None = None
|
||||
cancelled: bool = False
|
||||
error: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EngineFinalizeQueueState:
|
||||
waiting_count: int
|
||||
waiting_request_ids: List[str]
|
||||
peak_waiting: int
|
||||
total_submitted: int
|
||||
total_completed: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeStateCallbacks:
|
||||
update: Callable[[str, str, Optional[Dict[str, Any]]], None] | None = None
|
||||
complete: Callable[[str, Optional[Dict[str, Any]]], None] | None = None
|
||||
fail: Callable[[str, str], None] | None = None
|
||||
decode_runtime_update: Callable[[Dict[str, Any]], None] | None = None
|
||||
63
GPT_SoVITS/TTS_infer_pack/unified_engine_components.py
Normal file
63
GPT_SoVITS/TTS_infer_pack/unified_engine_components.py
Normal file
@ -0,0 +1,63 @@
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_component_models import (
|
||||
DirectTTSExecution,
|
||||
NormalizedEngineRequest,
|
||||
RuntimeControlCallbacks,
|
||||
SchedulerDebugExecution,
|
||||
SchedulerSubmitExecution,
|
||||
)
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_component_policy import (
|
||||
EngineArbiterConfig,
|
||||
EngineArbiterState,
|
||||
EnginePolicyArbiterController,
|
||||
EnginePolicyConfig,
|
||||
)
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_component_registry import (
|
||||
DefaultReferenceState,
|
||||
EngineRequestRegistry,
|
||||
EngineRequestState,
|
||||
EngineStatus,
|
||||
ModelRegistry,
|
||||
ModelRegistryState,
|
||||
ReferenceRegistry,
|
||||
SchedulerJobRegistry,
|
||||
SchedulerPendingJob,
|
||||
)
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_component_runtime import (
|
||||
EngineDecodeRuntimeOwner,
|
||||
EngineDecodeRuntimeState,
|
||||
EngineDispatchTask,
|
||||
EngineFinalizeQueueState,
|
||||
EngineGpuPrepareTask,
|
||||
EngineTaskQueueOwner,
|
||||
RuntimeStateCallbacks,
|
||||
SchedulerFinalizeTask,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DefaultReferenceState",
|
||||
"DirectTTSExecution",
|
||||
"EngineArbiterConfig",
|
||||
"EngineArbiterState",
|
||||
"EngineDecodeRuntimeOwner",
|
||||
"EngineDecodeRuntimeState",
|
||||
"EngineDispatchTask",
|
||||
"EngineFinalizeQueueState",
|
||||
"EngineGpuPrepareTask",
|
||||
"EnginePolicyArbiterController",
|
||||
"EnginePolicyConfig",
|
||||
"EngineRequestRegistry",
|
||||
"EngineRequestState",
|
||||
"EngineStatus",
|
||||
"EngineTaskQueueOwner",
|
||||
"ModelRegistry",
|
||||
"ModelRegistryState",
|
||||
"NormalizedEngineRequest",
|
||||
"ReferenceRegistry",
|
||||
"RuntimeControlCallbacks",
|
||||
"RuntimeStateCallbacks",
|
||||
"SchedulerDebugExecution",
|
||||
"SchedulerFinalizeTask",
|
||||
"SchedulerJobRegistry",
|
||||
"SchedulerPendingJob",
|
||||
"SchedulerSubmitExecution",
|
||||
]
|
||||
9
GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py
Normal file
9
GPT_SoVITS/TTS_infer_pack/unified_engine_delegates.py
Normal file
@ -0,0 +1,9 @@
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_api_delegates import EngineApiDelegates
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_bridge_delegates import EngineBridgeDelegates
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime_delegates import EngineRuntimeDelegates
|
||||
|
||||
__all__ = [
|
||||
"EngineApiDelegates",
|
||||
"EngineBridgeDelegates",
|
||||
"EngineRuntimeDelegates",
|
||||
]
|
||||
116
GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py
Normal file
116
GPT_SoVITS/TTS_infer_pack/unified_engine_orchestration.py
Normal file
@ -0,0 +1,116 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDecodeRuntimeOwner, EngineTaskQueueOwner
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_executor import EngineStageExecutor
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker
|
||||
|
||||
|
||||
class EngineStageOrchestrator:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
executor: EngineStageExecutor,
|
||||
scheduler_worker: UnifiedSchedulerWorker,
|
||||
prepare_queue_owner: EngineTaskQueueOwner,
|
||||
prepare_text_queue_owner: EngineTaskQueueOwner,
|
||||
prepare_ref_spec_queue_owner: EngineTaskQueueOwner,
|
||||
finalize_queue_owner: EngineTaskQueueOwner,
|
||||
dispatch_queue_owner: EngineTaskQueueOwner,
|
||||
decode_runtime_owner: EngineDecodeRuntimeOwner,
|
||||
snapshot_engine_decode_runtime_state: Callable[[], Dict[str, Any]],
|
||||
) -> None:
|
||||
self.executor = executor
|
||||
self.scheduler_worker = scheduler_worker
|
||||
self.prepare_queue_owner = prepare_queue_owner
|
||||
self.prepare_text_queue_owner = prepare_text_queue_owner
|
||||
self.prepare_ref_spec_queue_owner = prepare_ref_spec_queue_owner
|
||||
self.finalize_queue_owner = finalize_queue_owner
|
||||
self.dispatch_queue_owner = dispatch_queue_owner
|
||||
self.decode_runtime_owner = decode_runtime_owner
|
||||
self.snapshot_engine_decode_runtime_state = snapshot_engine_decode_runtime_state
|
||||
self._select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]] | None = None
|
||||
self._mark_arbiter_tick: Callable[[str, str, bool], None] | None = None
|
||||
self._wait_arbiter: Callable[[], None] | None = None
|
||||
|
||||
def bind_arbiter(
|
||||
self,
|
||||
*,
|
||||
notify_arbiter: Callable[[], None],
|
||||
select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]],
|
||||
mark_arbiter_tick: Callable[[str, str, bool], None],
|
||||
wait_arbiter: Callable[[], None],
|
||||
) -> None:
|
||||
self.executor.bind_notify_arbiter(notify_arbiter)
|
||||
self._select_stage = select_stage
|
||||
self._mark_arbiter_tick = mark_arbiter_tick
|
||||
self._wait_arbiter = wait_arbiter
|
||||
|
||||
def peek_queue_age_ms(self, queue_name: str) -> float:
|
||||
if queue_name == "prepare":
|
||||
return max(
|
||||
self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time"),
|
||||
self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time"),
|
||||
self.prepare_ref_spec_queue_owner.peek_oldest_age_ms("enqueue_time"),
|
||||
)
|
||||
if queue_name == "prepare_audio":
|
||||
return self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time")
|
||||
if queue_name == "prepare_text":
|
||||
return self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time")
|
||||
if queue_name == "prepare_ref_spec":
|
||||
return self.prepare_ref_spec_queue_owner.peek_oldest_age_ms("enqueue_time")
|
||||
if queue_name == "finalize":
|
||||
return self.finalize_queue_owner.peek_oldest_age_ms("enqueued_time")
|
||||
if queue_name == "decode_runtime_pending":
|
||||
return self.decode_runtime_owner.pending_age_ms()
|
||||
return self.dispatch_queue_owner.peek_oldest_age_ms("enqueue_time")
|
||||
|
||||
def has_pending_work(self) -> bool:
|
||||
if self.scheduler_worker.is_engine_decode_control_enabled():
|
||||
if self.decode_runtime_owner.has_pending_jobs():
|
||||
return True
|
||||
if self.scheduler_worker.is_engine_decode_control_enabled() and self.snapshot_engine_decode_runtime_state().get(
|
||||
"active_request_count", 0
|
||||
) > 0:
|
||||
return True
|
||||
if self.prepare_queue_owner.has_items():
|
||||
return True
|
||||
if self.prepare_text_queue_owner.has_items():
|
||||
return True
|
||||
if self.prepare_ref_spec_queue_owner.has_items():
|
||||
return True
|
||||
if self.finalize_queue_owner.has_items():
|
||||
return True
|
||||
return self.dispatch_queue_owner.has_items()
|
||||
|
||||
def run_engine_arbiter_loop(self) -> None:
|
||||
if self._select_stage is None or self._mark_arbiter_tick is None or self._wait_arbiter is None:
|
||||
raise RuntimeError("arbiter callbacks are not bound")
|
||||
while True:
|
||||
if not self.has_pending_work():
|
||||
self._mark_arbiter_tick("idle", "no_pending_work", True)
|
||||
self._wait_arbiter()
|
||||
continue
|
||||
stage, reason, policy_snapshot, worker_state = self._select_stage()
|
||||
policy_allowed = bool(policy_snapshot.get("allowed", True))
|
||||
executed = False
|
||||
if stage == "prepare":
|
||||
executed = self.executor.run_engine_prepare_once()
|
||||
elif stage == "prepare_audio":
|
||||
executed = self.executor.run_engine_prepare_audio_once()
|
||||
elif stage == "prepare_text":
|
||||
executed = self.executor.run_engine_prepare_text_once()
|
||||
elif stage == "prepare_ref_spec":
|
||||
executed = self.executor.run_engine_prepare_ref_spec_once()
|
||||
elif stage == "finalize":
|
||||
executed = self.executor.run_engine_finalize_once()
|
||||
elif stage == "decode_dispatch":
|
||||
executed = self.executor.run_engine_dispatch_once(policy_snapshot, worker_state)
|
||||
elif stage == "decode_runtime":
|
||||
executed = self.executor.run_engine_decode_runtime_once()
|
||||
if not executed:
|
||||
self._mark_arbiter_tick("idle", f"{stage}_not_ready", policy_allowed)
|
||||
self._wait_arbiter()
|
||||
continue
|
||||
self._mark_arbiter_tick(stage, reason, policy_allowed)
|
||||
53
GPT_SoVITS/TTS_infer_pack/unified_engine_public.py
Normal file
53
GPT_SoVITS/TTS_infer_pack/unified_engine_public.py
Normal file
@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import DirectTTSExecution, SchedulerDebugExecution, SchedulerSubmitExecution
|
||||
|
||||
|
||||
class EnginePublicInterface:
|
||||
PUBLIC_API_METHODS = (
|
||||
"run_direct_tts_async",
|
||||
"run_scheduler_submit",
|
||||
"run_scheduler_debug",
|
||||
"get_runtime_state",
|
||||
"set_refer_audio",
|
||||
"set_gpt_weights",
|
||||
"set_sovits_weights",
|
||||
"handle_control",
|
||||
)
|
||||
|
||||
async def run_direct_tts_async(self, req: dict) -> DirectTTSExecution:
|
||||
return await self.api_facade.run_direct_tts_async(req)
|
||||
|
||||
async def run_scheduler_debug(self, request_items: list[dict], max_steps: int, seed: int) -> SchedulerDebugExecution:
|
||||
return await self.api_facade.run_scheduler_debug(request_items, max_steps, seed)
|
||||
|
||||
async def run_scheduler_submit(self, payload: dict) -> SchedulerSubmitExecution:
|
||||
return await self.api_facade.run_scheduler_submit(payload)
|
||||
|
||||
def get_runtime_state(self) -> dict:
|
||||
return self.runtime_facade.get_runtime_state()
|
||||
|
||||
def set_refer_audio(self, refer_audio_path: str | None) -> dict:
|
||||
return self.runtime_facade.set_refer_audio(refer_audio_path)
|
||||
|
||||
def set_gpt_weights(self, weights_path: str) -> dict:
|
||||
return self.runtime_facade.set_gpt_weights(weights_path)
|
||||
|
||||
def set_sovits_weights(self, weights_path: str) -> dict:
|
||||
return self.runtime_facade.set_sovits_weights(weights_path)
|
||||
|
||||
def handle_control(self, command: str) -> None:
|
||||
self.runtime_facade.handle_control(command)
|
||||
|
||||
|
||||
class EngineCompatInterface:
|
||||
COMPAT_API_METHODS = (
|
||||
"run_direct_tts",
|
||||
"get_scheduler_state",
|
||||
)
|
||||
|
||||
def run_direct_tts(self, req: dict) -> DirectTTSExecution:
|
||||
return self.api_facade.run_direct_tts(req)
|
||||
|
||||
def get_scheduler_state(self) -> dict:
|
||||
return self.runtime_facade.get_scheduler_state()
|
||||
198
GPT_SoVITS/TTS_infer_pack/unified_engine_runtime.py
Normal file
198
GPT_SoVITS/TTS_infer_pack/unified_engine_runtime.py
Normal file
@ -0,0 +1,198 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class EngineRuntimeFacade:
|
||||
def __init__(self, owner: Any) -> None:
|
||||
self.owner = owner
|
||||
|
||||
@property
|
||||
def tts(self):
|
||||
return self.owner.tts
|
||||
|
||||
@property
|
||||
def reference_registry(self):
|
||||
return self.owner.reference_registry
|
||||
|
||||
@property
|
||||
def model_registry(self):
|
||||
return self.owner.model_registry
|
||||
|
||||
@property
|
||||
def scheduler_worker(self):
|
||||
return self.owner.scheduler_worker
|
||||
|
||||
@property
|
||||
def engine_decode_runtime_owner(self):
|
||||
return self.owner.engine_decode_runtime_owner
|
||||
|
||||
@property
|
||||
def engine_policy_arbiter(self):
|
||||
return self.owner.engine_policy_arbiter
|
||||
|
||||
@property
|
||||
def management_lock(self):
|
||||
return self.owner.management_lock
|
||||
|
||||
@property
|
||||
def direct_tts_lock(self):
|
||||
return self.owner.direct_tts_lock
|
||||
|
||||
@property
|
||||
def control_callbacks(self):
|
||||
return self.owner.control_callbacks
|
||||
|
||||
@staticmethod
|
||||
def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None:
|
||||
if component is None or not hasattr(component, "snapshot"):
|
||||
return None
|
||||
try:
|
||||
return dict(component.snapshot())
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _build_stage_counters(
|
||||
self,
|
||||
request_registry: Dict[str, Any],
|
||||
worker_state: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
return self.engine_policy_arbiter.build_stage_counters(request_registry, worker_state)
|
||||
|
||||
def _build_engine_policy_snapshot(
|
||||
self,
|
||||
request_registry: Dict[str, Any],
|
||||
worker_state: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
return self.engine_policy_arbiter.build_policy_snapshot(request_registry, worker_state)
|
||||
|
||||
def _build_stage_summary(
|
||||
self,
|
||||
request_registry: Dict[str, Any],
|
||||
worker_state: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
counters = self._build_stage_counters(request_registry, worker_state)
|
||||
bert_worker_state = self._safe_component_snapshot(getattr(self.tts, "prepare_bert_batch_worker", None))
|
||||
ref_semantic_worker_state = self._safe_component_snapshot(getattr(self.tts, "prepare_ref_semantic_batch_worker", None))
|
||||
text_preprocessor_state = self._safe_component_snapshot(getattr(self.tts, "text_preprocessor", None))
|
||||
|
||||
return {
|
||||
**counters,
|
||||
"engine_drained": bool(self.owner._is_engine_drained()),
|
||||
"admission_config": {
|
||||
"decode_backlog_max": int(worker_state.get("decode_backlog_max", 0)),
|
||||
"finalize_pending_max": int(worker_state.get("finalize_pending_max", 0)),
|
||||
},
|
||||
"engine_policy": self._build_engine_policy_snapshot(request_registry, worker_state),
|
||||
"engine_arbiter_state": self.owner._snapshot_engine_arbiter_state(),
|
||||
"engine_decode_runtime_state": self.owner._snapshot_engine_decode_runtime_state(),
|
||||
"engine_job_registry": self.owner._snapshot_engine_job_registry(),
|
||||
"engine_active_batch_state": self.engine_decode_runtime_owner.active_batch_summary(),
|
||||
"engine_prepare_state": self.owner._snapshot_engine_prepare_state(),
|
||||
"engine_finalize_state": self.owner._snapshot_engine_finalize_state(),
|
||||
"engine_dispatcher_state": self.owner._snapshot_engine_dispatch_state(),
|
||||
"active_batch": dict(worker_state.get("active_batch") or {}),
|
||||
"prepare_state": dict(worker_state.get("prepare_state") or {}),
|
||||
"bert_batch_worker_state": bert_worker_state,
|
||||
"ref_semantic_worker_state": ref_semantic_worker_state,
|
||||
"text_preprocessor_state": text_preprocessor_state,
|
||||
}
|
||||
|
||||
def get_scheduler_state(self) -> dict:
|
||||
return self.scheduler_worker.snapshot()
|
||||
|
||||
def get_runtime_state(self) -> dict:
|
||||
model_state = self.model_registry.snapshot()
|
||||
default_ref = self.reference_registry.get_default()
|
||||
scheduler_state = self.get_scheduler_state()
|
||||
request_registry = self.owner._snapshot_request_registry()
|
||||
engine_policy = self._build_engine_policy_snapshot(request_registry, scheduler_state)
|
||||
engine_arbiter_state = self.owner._snapshot_engine_arbiter_state()
|
||||
engine_decode_runtime_state = self.owner._snapshot_engine_decode_runtime_state()
|
||||
engine_job_registry = self.owner._snapshot_engine_job_registry()
|
||||
engine_prepare_state = self.owner._snapshot_engine_prepare_state()
|
||||
engine_finalize_state = self.owner._snapshot_engine_finalize_state()
|
||||
engine_dispatcher_state = self.owner._snapshot_engine_dispatch_state()
|
||||
engine_drained = self.owner._is_engine_drained()
|
||||
return {
|
||||
"message": "success",
|
||||
"default_reference": {
|
||||
"ref_audio_path": default_ref.ref_audio_path,
|
||||
"updated_at": default_ref.updated_at,
|
||||
},
|
||||
"model_registry": {
|
||||
"generation": model_state.generation,
|
||||
"t2s_generation": model_state.t2s_generation,
|
||||
"vits_generation": model_state.vits_generation,
|
||||
"t2s_weights_path": model_state.t2s_weights_path,
|
||||
"vits_weights_path": model_state.vits_weights_path,
|
||||
"updated_at": model_state.updated_at,
|
||||
},
|
||||
"worker_state": scheduler_state,
|
||||
"engine_policy": engine_policy,
|
||||
"engine_arbiter_state": engine_arbiter_state,
|
||||
"engine_decode_runtime_state": engine_decode_runtime_state,
|
||||
"engine_job_registry": engine_job_registry,
|
||||
"engine_active_batch_state": self.engine_decode_runtime_owner.active_batch_summary(),
|
||||
"engine_prepare_state": engine_prepare_state,
|
||||
"engine_finalize_state": engine_finalize_state,
|
||||
"engine_dispatcher_state": engine_dispatcher_state,
|
||||
"engine_drained": bool(engine_drained),
|
||||
"request_registry": request_registry,
|
||||
"stage_summary": self._build_stage_summary(request_registry, scheduler_state),
|
||||
}
|
||||
|
||||
def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None:
|
||||
if not self.scheduler_worker.wait_until_idle(timeout_sec=timeout_sec):
|
||||
raise TimeoutError("scheduler worker did not drain before model reload")
|
||||
|
||||
def set_refer_audio(self, refer_audio_path: str | None) -> dict:
|
||||
if refer_audio_path in [None, ""]:
|
||||
state = self.reference_registry.clear()
|
||||
return {"message": "success", "default_ref_audio_path": state.ref_audio_path}
|
||||
if not os.path.exists(str(refer_audio_path)):
|
||||
raise FileNotFoundError(f"{refer_audio_path} not exists")
|
||||
with self.management_lock:
|
||||
with self.direct_tts_lock:
|
||||
self.tts.set_ref_audio(str(refer_audio_path))
|
||||
state = self.reference_registry.set_default(str(refer_audio_path))
|
||||
return {"message": "success", "default_ref_audio_path": state.ref_audio_path}
|
||||
|
||||
def set_gpt_weights(self, weights_path: str) -> dict:
|
||||
if weights_path in ["", None]:
|
||||
raise ValueError("gpt weight path is required")
|
||||
with self.management_lock:
|
||||
self._wait_for_safe_reload()
|
||||
with self.direct_tts_lock:
|
||||
self.tts.init_t2s_weights(weights_path)
|
||||
self.tts.refresh_runtime_components()
|
||||
state = self.model_registry.mark_t2s_reload(str(weights_path))
|
||||
return {"message": "success", "t2s_generation": state.t2s_generation, "generation": state.generation}
|
||||
|
||||
def set_sovits_weights(self, weights_path: str) -> dict:
|
||||
if weights_path in ["", None]:
|
||||
raise ValueError("sovits weight path is required")
|
||||
with self.management_lock:
|
||||
self._wait_for_safe_reload()
|
||||
with self.direct_tts_lock:
|
||||
self.tts.init_vits_weights(weights_path)
|
||||
self.tts.refresh_runtime_components()
|
||||
state = self.model_registry.mark_vits_reload(str(weights_path))
|
||||
return {"message": "success", "vits_generation": state.vits_generation, "generation": state.generation}
|
||||
|
||||
def handle_control(self, command: str) -> None:
|
||||
if command == "restart":
|
||||
if self.control_callbacks.restart is None:
|
||||
os.execl(sys.executable, sys.executable, *sys.argv)
|
||||
self.control_callbacks.restart()
|
||||
return
|
||||
if command == "exit":
|
||||
if self.control_callbacks.exit is None:
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
return
|
||||
self.control_callbacks.exit()
|
||||
return
|
||||
raise ValueError(f"unsupported command: {command}")
|
||||
@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_runtime import EngineRuntimeFacade
|
||||
|
||||
|
||||
class EngineRuntimeDelegates:
|
||||
@staticmethod
|
||||
def _safe_component_snapshot(component: Any) -> Dict[str, Any] | None:
|
||||
return EngineRuntimeFacade._safe_component_snapshot(component)
|
||||
|
||||
def _build_stage_counters(
|
||||
self,
|
||||
request_registry: Dict[str, Any],
|
||||
worker_state: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
return self.runtime_facade._build_stage_counters(request_registry, worker_state)
|
||||
|
||||
def _build_engine_policy_snapshot(
|
||||
self,
|
||||
request_registry: Dict[str, Any],
|
||||
worker_state: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
return self.runtime_facade._build_engine_policy_snapshot(request_registry, worker_state)
|
||||
|
||||
async def _wait_for_engine_policy_admission(
|
||||
self,
|
||||
*,
|
||||
request_id: str | None,
|
||||
timeout_sec: float | None,
|
||||
) -> tuple[float, Dict[str, Any]]:
|
||||
return await self.engine_policy_arbiter.wait_for_policy_admission(
|
||||
request_id=request_id,
|
||||
timeout_sec=timeout_sec,
|
||||
)
|
||||
|
||||
def _build_stage_summary(
|
||||
self,
|
||||
request_registry: Dict[str, Any],
|
||||
worker_state: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
return self.runtime_facade._build_stage_summary(request_registry, worker_state)
|
||||
|
||||
def _wait_for_safe_reload(self, timeout_sec: float = 300.0) -> None:
|
||||
self.runtime_facade._wait_for_safe_reload(timeout_sec=timeout_sec)
|
||||
172
GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py
Normal file
172
GPT_SoVITS/TTS_infer_pack/unified_engine_stage.py
Normal file
@ -0,0 +1,172 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem, T2SRequestState
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import (
|
||||
EngineDecodeRuntimeOwner,
|
||||
EngineDispatchTask,
|
||||
EngineTaskQueueOwner,
|
||||
SchedulerFinalizeTask,
|
||||
SchedulerPendingJob,
|
||||
)
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_orchestration import EngineStageOrchestrator
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_executor import EngineStageExecutor
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker
|
||||
|
||||
|
||||
class EngineStageCoordinator:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tts: TTS,
|
||||
scheduler_worker: UnifiedSchedulerWorker,
|
||||
prepare_queue_owner: EngineTaskQueueOwner,
|
||||
prepare_text_queue_owner: EngineTaskQueueOwner,
|
||||
prepare_ref_spec_queue_owner: EngineTaskQueueOwner,
|
||||
finalize_queue_owner: EngineTaskQueueOwner,
|
||||
dispatch_queue_owner: EngineTaskQueueOwner,
|
||||
decode_runtime_owner: EngineDecodeRuntimeOwner,
|
||||
update_request_state: Callable[[str, str, Optional[Dict[str, Any]]], None],
|
||||
merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None],
|
||||
fail_request_state: Callable[[str, str], None],
|
||||
get_engine_job: Callable[[str], SchedulerPendingJob | None],
|
||||
register_engine_job: Callable[[SchedulerPendingJob], None],
|
||||
fail_engine_jobs: Callable[[List[str], str], None],
|
||||
complete_engine_job: Callable[..., None],
|
||||
add_engine_prefill_time: Callable[[List[SchedulerPendingJob], float], None],
|
||||
add_engine_merge_time: Callable[[List[str], float], None],
|
||||
add_engine_decode_time: Callable[[List[str], float], None],
|
||||
enqueue_engine_finished_items: Callable[[List[T2SFinishedItem]], None],
|
||||
snapshot_engine_dispatch_state: Callable[[], Dict[str, Any]],
|
||||
snapshot_engine_decode_runtime_state: Callable[[], Dict[str, Any]],
|
||||
) -> None:
|
||||
self.executor = EngineStageExecutor(
|
||||
tts=tts,
|
||||
scheduler_worker=scheduler_worker,
|
||||
prepare_queue_owner=prepare_queue_owner,
|
||||
prepare_text_queue_owner=prepare_text_queue_owner,
|
||||
prepare_ref_spec_queue_owner=prepare_ref_spec_queue_owner,
|
||||
finalize_queue_owner=finalize_queue_owner,
|
||||
dispatch_queue_owner=dispatch_queue_owner,
|
||||
decode_runtime_owner=decode_runtime_owner,
|
||||
update_request_state=update_request_state,
|
||||
merge_request_state_profile=merge_request_state_profile,
|
||||
fail_request_state=fail_request_state,
|
||||
get_engine_job=get_engine_job,
|
||||
register_engine_job=register_engine_job,
|
||||
fail_engine_jobs=fail_engine_jobs,
|
||||
complete_engine_job=complete_engine_job,
|
||||
add_engine_prefill_time=add_engine_prefill_time,
|
||||
add_engine_merge_time=add_engine_merge_time,
|
||||
add_engine_decode_time=add_engine_decode_time,
|
||||
enqueue_engine_finished_items=enqueue_engine_finished_items,
|
||||
snapshot_engine_dispatch_state=snapshot_engine_dispatch_state,
|
||||
snapshot_engine_decode_runtime_state=snapshot_engine_decode_runtime_state,
|
||||
)
|
||||
self.orchestrator = EngineStageOrchestrator(
|
||||
executor=self.executor,
|
||||
scheduler_worker=scheduler_worker,
|
||||
prepare_queue_owner=prepare_queue_owner,
|
||||
prepare_text_queue_owner=prepare_text_queue_owner,
|
||||
prepare_ref_spec_queue_owner=prepare_ref_spec_queue_owner,
|
||||
finalize_queue_owner=finalize_queue_owner,
|
||||
dispatch_queue_owner=dispatch_queue_owner,
|
||||
decode_runtime_owner=decode_runtime_owner,
|
||||
snapshot_engine_decode_runtime_state=snapshot_engine_decode_runtime_state,
|
||||
)
|
||||
|
||||
def bind_arbiter(
|
||||
self,
|
||||
*,
|
||||
notify_arbiter: Callable[[], None],
|
||||
select_stage: Callable[[], tuple[str, str, Dict[str, Any], Dict[str, Any]]],
|
||||
mark_arbiter_tick: Callable[[str, str, bool], None],
|
||||
wait_arbiter: Callable[[], None],
|
||||
) -> None:
|
||||
self.orchestrator.bind_arbiter(
|
||||
notify_arbiter=notify_arbiter,
|
||||
select_stage=select_stage,
|
||||
mark_arbiter_tick=mark_arbiter_tick,
|
||||
wait_arbiter=wait_arbiter,
|
||||
)
|
||||
|
||||
async def prepare_state_via_engine_gpu_queue(
|
||||
self,
|
||||
*,
|
||||
spec,
|
||||
prepare_submit_at: float,
|
||||
engine_request_id: str | None,
|
||||
) -> tuple[T2SRequestState, float, float]:
|
||||
return await self.executor.prepare_state_via_engine_gpu_queue(
|
||||
spec=spec,
|
||||
prepare_submit_at=prepare_submit_at,
|
||||
engine_request_id=engine_request_id,
|
||||
)
|
||||
|
||||
def enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None:
|
||||
self.executor.enqueue_worker_finished_for_finalize(tasks)
|
||||
|
||||
def take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]:
|
||||
return self.executor.take_engine_finalize_batch_nonblocking()
|
||||
|
||||
async def enqueue_prepared_state_for_dispatch(
|
||||
self,
|
||||
*,
|
||||
state: T2SRequestState,
|
||||
speed_factor: float,
|
||||
sample_steps: int,
|
||||
media_type: str,
|
||||
super_sampling: bool,
|
||||
prepare_wall_ms: float,
|
||||
prepare_profile_total_ms: float,
|
||||
done_loop: asyncio.AbstractEventLoop | None,
|
||||
done_future: asyncio.Future | None,
|
||||
engine_request_id: str | None,
|
||||
timeout_sec: float | None,
|
||||
) -> EngineDispatchTask:
|
||||
return await self.executor.enqueue_prepared_state_for_dispatch(
|
||||
state=state,
|
||||
speed_factor=speed_factor,
|
||||
sample_steps=sample_steps,
|
||||
media_type=media_type,
|
||||
super_sampling=super_sampling,
|
||||
prepare_wall_ms=prepare_wall_ms,
|
||||
prepare_profile_total_ms=prepare_profile_total_ms,
|
||||
done_loop=done_loop,
|
||||
done_future=done_future,
|
||||
engine_request_id=engine_request_id,
|
||||
timeout_sec=timeout_sec,
|
||||
)
|
||||
|
||||
def peek_queue_age_ms(self, queue_name: str) -> float:
|
||||
return self.orchestrator.peek_queue_age_ms(queue_name)
|
||||
|
||||
def has_pending_work(self) -> bool:
|
||||
return self.orchestrator.has_pending_work()
|
||||
|
||||
def run_engine_prepare_once(self) -> bool:
|
||||
return self.executor.run_engine_prepare_once()
|
||||
|
||||
def run_engine_prepare_audio_once(self) -> bool:
|
||||
return self.executor.run_engine_prepare_audio_once()
|
||||
|
||||
def run_engine_prepare_text_once(self) -> bool:
|
||||
return self.executor.run_engine_prepare_text_once()
|
||||
|
||||
def run_engine_prepare_ref_spec_once(self) -> bool:
|
||||
return self.executor.run_engine_prepare_ref_spec_once()
|
||||
|
||||
def run_engine_finalize_once(self) -> bool:
|
||||
return self.executor.run_engine_finalize_once()
|
||||
|
||||
def run_engine_dispatch_once(self, policy_snapshot: Dict[str, Any], worker_state: Dict[str, Any]) -> bool:
|
||||
return self.executor.run_engine_dispatch_once(policy_snapshot, worker_state)
|
||||
|
||||
def run_engine_decode_runtime_once(self) -> bool:
|
||||
return self.executor.run_engine_decode_runtime_once()
|
||||
|
||||
def run_engine_arbiter_loop(self) -> None:
|
||||
self.orchestrator.run_engine_arbiter_loop()
|
||||
40
GPT_SoVITS/TTS_infer_pack/unified_engine_stage_decode.py
Normal file
40
GPT_SoVITS/TTS_infer_pack/unified_engine_stage_decode.py
Normal file
@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class EngineDecodeStageMixin:
|
||||
def run_engine_decode_runtime_once(self) -> bool:
|
||||
if not self.scheduler_worker.is_engine_decode_control_enabled():
|
||||
return False
|
||||
runtime_state = self.snapshot_engine_decode_runtime_state()
|
||||
pending_jobs = self.decode_runtime_owner.take_pending_jobs_nonblocking(
|
||||
wait_for_batch=int(runtime_state.get("active_request_count", 0)) <= 0
|
||||
)
|
||||
result = self.scheduler_worker.execute_decode_cycle(
|
||||
pending_jobs=pending_jobs,
|
||||
active_batch=self.decode_runtime_owner.get_active_batch(),
|
||||
external_bookkeeping=True,
|
||||
)
|
||||
prefill_phase = dict(result.get("prefill_phase") or {})
|
||||
if prefill_phase.get("error"):
|
||||
self.fail_engine_jobs(list(prefill_phase.get("error_request_ids") or []), str(prefill_phase.get("error")))
|
||||
else:
|
||||
prefill_jobs = list(prefill_phase.get("pending_jobs") or [])
|
||||
self.add_engine_prefill_time(prefill_jobs, float(prefill_phase.get("prefill_elapsed_s", 0.0)))
|
||||
self.add_engine_merge_time(
|
||||
[] if result.get("active_batch") is None else list(result["active_batch"].request_ids),
|
||||
float(prefill_phase.get("merge_elapsed_s", 0.0)),
|
||||
)
|
||||
self.enqueue_engine_finished_items(list(prefill_phase.get("finished_items") or []))
|
||||
decode_phase = dict(result.get("decode_phase") or {})
|
||||
if decode_phase.get("error"):
|
||||
self.fail_engine_jobs(list(decode_phase.get("error_request_ids") or []), str(decode_phase.get("error")))
|
||||
else:
|
||||
self.add_engine_decode_time(
|
||||
list(decode_phase.get("request_ids") or []),
|
||||
float(decode_phase.get("decode_elapsed_s", 0.0)),
|
||||
)
|
||||
self.enqueue_engine_finished_items(list(decode_phase.get("finished_items") or []))
|
||||
self.decode_runtime_owner.set_active_batch(result.get("active_batch"))
|
||||
if result.get("executed", False):
|
||||
self.decode_runtime_owner.refresh_state("engine_decode_cycle")
|
||||
return bool(result.get("executed", False))
|
||||
100
GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py
Normal file
100
GPT_SoVITS/TTS_infer_pack/unified_engine_stage_dispatch.py
Normal file
@ -0,0 +1,100 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SRequestState
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDispatchTask
|
||||
|
||||
|
||||
class EngineDispatchStageMixin:
|
||||
async def enqueue_prepared_state_for_dispatch(
|
||||
self,
|
||||
*,
|
||||
state: T2SRequestState,
|
||||
speed_factor: float,
|
||||
sample_steps: int,
|
||||
media_type: str,
|
||||
super_sampling: bool,
|
||||
prepare_wall_ms: float,
|
||||
prepare_profile_total_ms: float,
|
||||
done_loop: asyncio.AbstractEventLoop | None,
|
||||
done_future: asyncio.Future | None,
|
||||
engine_request_id: str | None,
|
||||
timeout_sec: float | None,
|
||||
) -> EngineDispatchTask:
|
||||
if float(state.prepare_profile.get("ref_spec_async_failed", 0.0) or 0.0) > 0.0:
|
||||
error = RuntimeError("ref_spec async stage failed before dispatch")
|
||||
self.fail_request_state(engine_request_id or state.request_id, str(error))
|
||||
raise error
|
||||
task = EngineDispatchTask(
|
||||
request_id=state.request_id,
|
||||
state=state,
|
||||
speed_factor=float(speed_factor),
|
||||
sample_steps=int(sample_steps),
|
||||
media_type=media_type,
|
||||
super_sampling=bool(super_sampling),
|
||||
prepare_wall_ms=float(prepare_wall_ms),
|
||||
prepare_profile_total_ms=float(prepare_profile_total_ms),
|
||||
done_loop=done_loop,
|
||||
done_future=done_future,
|
||||
engine_request_id=engine_request_id or state.request_id,
|
||||
timeout_sec=timeout_sec,
|
||||
enqueue_time=time.perf_counter(),
|
||||
)
|
||||
self.dispatch_queue_owner.enqueue(task)
|
||||
self.notify_arbiter()
|
||||
self.merge_request_state_profile(
|
||||
task.engine_request_id or task.request_id,
|
||||
{
|
||||
"engine_dispatch_queue_depth_on_enqueue": int(
|
||||
self.snapshot_engine_dispatch_state()["waiting_count"]
|
||||
),
|
||||
},
|
||||
)
|
||||
return task
|
||||
|
||||
def run_engine_dispatch_once(self, policy_snapshot: Dict[str, object], worker_state: Dict[str, object]) -> bool:
|
||||
if not bool(policy_snapshot.get("allowed", True)):
|
||||
return False
|
||||
dispatch_task = self.dispatch_queue_owner.pop_left()
|
||||
if dispatch_task is None:
|
||||
return False
|
||||
dispatched_at = time.perf_counter()
|
||||
dispatch_wait_ms = max(0.0, (dispatched_at - dispatch_task.enqueue_time) * 1000.0)
|
||||
dispatch_task.engine_policy_wait_ms = float(dispatch_wait_ms)
|
||||
dispatch_task.engine_dispatch_wait_ms = float(dispatch_wait_ms)
|
||||
dispatch_task.engine_policy_snapshot = dict(policy_snapshot)
|
||||
try:
|
||||
worker_job = self.scheduler_worker.submit(
|
||||
state=dispatch_task.state,
|
||||
speed_factor=dispatch_task.speed_factor,
|
||||
sample_steps=dispatch_task.sample_steps,
|
||||
media_type=dispatch_task.media_type,
|
||||
super_sampling=dispatch_task.super_sampling,
|
||||
prepare_wall_ms=dispatch_task.prepare_wall_ms,
|
||||
prepare_profile_total_ms=dispatch_task.prepare_profile_total_ms,
|
||||
done_loop=dispatch_task.done_loop,
|
||||
done_future=dispatch_task.done_future,
|
||||
engine_request_id=dispatch_task.engine_request_id,
|
||||
timeout_sec=dispatch_task.timeout_sec,
|
||||
skip_capacity_wait=True,
|
||||
admission_wait_ms_override=0.0,
|
||||
admission_snapshot_override=dict(worker_state),
|
||||
engine_policy_wait_ms=dispatch_task.engine_policy_wait_ms,
|
||||
engine_dispatch_wait_ms=dispatch_task.engine_dispatch_wait_ms,
|
||||
enqueue_pending=not self.scheduler_worker.is_engine_decode_control_enabled(),
|
||||
)
|
||||
dispatch_task.worker_job = worker_job
|
||||
self.register_engine_job(worker_job)
|
||||
if self.scheduler_worker.is_engine_decode_control_enabled():
|
||||
self.decode_runtime_owner.enqueue_pending_job(worker_job)
|
||||
self.notify_arbiter()
|
||||
self.dispatch_queue_owner.mark_completed(1)
|
||||
return True
|
||||
except Exception as exc:
|
||||
dispatch_task.error = str(exc)
|
||||
self.fail_request_state(dispatch_task.engine_request_id or dispatch_task.request_id, str(exc))
|
||||
self._notify_dispatch_error(dispatch_task, exc)
|
||||
return True
|
||||
74
GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py
Normal file
74
GPT_SoVITS/TTS_infer_pack/unified_engine_stage_executor.py
Normal file
@ -0,0 +1,74 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import (
|
||||
EngineDecodeRuntimeOwner,
|
||||
EngineTaskQueueOwner,
|
||||
SchedulerFinalizeTask,
|
||||
SchedulerPendingJob,
|
||||
)
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_decode import EngineDecodeStageMixin
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_dispatch import EngineDispatchStageMixin
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_finalize import EngineFinalizeStageMixin
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_futures import EngineStageFutureMixin
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_stage_prepare import EnginePrepareStageMixin
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker import UnifiedSchedulerWorker
|
||||
|
||||
|
||||
class EngineStageExecutor(
|
||||
EngineStageFutureMixin,
|
||||
EnginePrepareStageMixin,
|
||||
EngineFinalizeStageMixin,
|
||||
EngineDispatchStageMixin,
|
||||
EngineDecodeStageMixin,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tts: TTS,
|
||||
scheduler_worker: UnifiedSchedulerWorker,
|
||||
prepare_queue_owner: EngineTaskQueueOwner,
|
||||
prepare_text_queue_owner: EngineTaskQueueOwner,
|
||||
prepare_ref_spec_queue_owner: EngineTaskQueueOwner,
|
||||
finalize_queue_owner: EngineTaskQueueOwner,
|
||||
dispatch_queue_owner: EngineTaskQueueOwner,
|
||||
decode_runtime_owner: EngineDecodeRuntimeOwner,
|
||||
update_request_state: Callable[[str, str, Optional[Dict[str, Any]]], None],
|
||||
merge_request_state_profile: Callable[[str, Optional[Dict[str, Any]]], None],
|
||||
fail_request_state: Callable[[str, str], None],
|
||||
get_engine_job: Callable[[str], SchedulerPendingJob | None],
|
||||
register_engine_job: Callable[[SchedulerPendingJob], None],
|
||||
fail_engine_jobs: Callable[[List[str], str], None],
|
||||
complete_engine_job: Callable[..., None],
|
||||
add_engine_prefill_time: Callable[[List[SchedulerPendingJob], float], None],
|
||||
add_engine_merge_time: Callable[[List[str], float], None],
|
||||
add_engine_decode_time: Callable[[List[str], float], None],
|
||||
enqueue_engine_finished_items: Callable[[List[T2SFinishedItem]], None],
|
||||
snapshot_engine_dispatch_state: Callable[[], Dict[str, Any]],
|
||||
snapshot_engine_decode_runtime_state: Callable[[], Dict[str, Any]],
|
||||
) -> None:
|
||||
self.tts = tts
|
||||
self.scheduler_worker = scheduler_worker
|
||||
self.prepare_queue_owner = prepare_queue_owner
|
||||
self.prepare_text_queue_owner = prepare_text_queue_owner
|
||||
self.prepare_ref_spec_queue_owner = prepare_ref_spec_queue_owner
|
||||
self.finalize_queue_owner = finalize_queue_owner
|
||||
self.dispatch_queue_owner = dispatch_queue_owner
|
||||
self.decode_runtime_owner = decode_runtime_owner
|
||||
self.update_request_state = update_request_state
|
||||
self.merge_request_state_profile = merge_request_state_profile
|
||||
self.fail_request_state = fail_request_state
|
||||
self.get_engine_job = get_engine_job
|
||||
self.register_engine_job = register_engine_job
|
||||
self.fail_engine_jobs = fail_engine_jobs
|
||||
self.complete_engine_job = complete_engine_job
|
||||
self.add_engine_prefill_time = add_engine_prefill_time
|
||||
self.add_engine_merge_time = add_engine_merge_time
|
||||
self.add_engine_decode_time = add_engine_decode_time
|
||||
self.enqueue_engine_finished_items = enqueue_engine_finished_items
|
||||
self.snapshot_engine_dispatch_state = snapshot_engine_dispatch_state
|
||||
self.snapshot_engine_decode_runtime_state = snapshot_engine_decode_runtime_state
|
||||
self._notify_arbiter: Callable[[], None] | None = None
|
||||
103
GPT_SoVITS/TTS_infer_pack/unified_engine_stage_finalize.py
Normal file
103
GPT_SoVITS/TTS_infer_pack/unified_engine_stage_finalize.py
Normal file
@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob
|
||||
|
||||
|
||||
class EngineFinalizeStageMixin:
|
||||
def enqueue_worker_finished_for_finalize(self, tasks: List[SchedulerFinalizeTask]) -> None:
|
||||
if not tasks:
|
||||
return
|
||||
for task in tasks:
|
||||
job = self.get_engine_job(task.request_id)
|
||||
if job is not None:
|
||||
self.update_request_state(
|
||||
job.engine_request_id,
|
||||
EngineStatus.READY_FOR_FINALIZE,
|
||||
{
|
||||
"finish_reason": task.item.finish_reason,
|
||||
"semantic_len": int(task.item.semantic_tokens.shape[0]),
|
||||
"finish_idx": int(task.item.finish_idx),
|
||||
},
|
||||
)
|
||||
self.finalize_queue_owner.enqueue_many(tasks)
|
||||
self.notify_arbiter()
|
||||
|
||||
def take_engine_finalize_batch_nonblocking(self) -> List[SchedulerFinalizeTask]:
|
||||
finalize_policy = self.scheduler_worker.get_finalize_batch_policy()
|
||||
return self.finalize_queue_owner.take_finalize_batch(
|
||||
finalize_mode=str(finalize_policy.get("finalize_mode", "async")),
|
||||
batch_max_items=int(finalize_policy.get("finalize_batch_max_items", 1)),
|
||||
batch_wait_s=float(finalize_policy.get("finalize_batch_wait_s", 0.0)),
|
||||
use_vocoder=bool(self.tts.configs.use_vocoder),
|
||||
)
|
||||
|
||||
def run_engine_finalize_once(self) -> bool:
|
||||
tasks = self.take_engine_finalize_batch_nonblocking()
|
||||
if not tasks:
|
||||
return False
|
||||
ready_tasks: List[SchedulerFinalizeTask] = []
|
||||
failed_tasks: List[SchedulerFinalizeTask] = []
|
||||
deferred_tasks: List[SchedulerFinalizeTask] = []
|
||||
for task in tasks:
|
||||
job = self.get_engine_job(task.request_id)
|
||||
if job is None:
|
||||
continue
|
||||
if float(job.state.prepare_profile.get("ref_spec_async_failed", 0.0) or 0.0) > 0.0:
|
||||
failed_tasks.append(task)
|
||||
continue
|
||||
if job.state.refer_spec is None:
|
||||
deferred_tasks.append(task)
|
||||
self.merge_request_state_profile(
|
||||
job.engine_request_id or job.request_id,
|
||||
{
|
||||
"engine_finalize_ref_spec_blocked": 1.0,
|
||||
},
|
||||
)
|
||||
continue
|
||||
ready_tasks.append(task)
|
||||
if deferred_tasks:
|
||||
self.finalize_queue_owner.enqueue_many(deferred_tasks)
|
||||
if failed_tasks:
|
||||
self.fail_engine_jobs([task.request_id for task in failed_tasks], "ref_spec async stage failed")
|
||||
if not ready_tasks:
|
||||
self.finalize_queue_owner.mark_completed(len(failed_tasks), notify=True)
|
||||
return False
|
||||
self.scheduler_worker.begin_finalize_execution(len(ready_tasks))
|
||||
try:
|
||||
jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = []
|
||||
for task in ready_tasks:
|
||||
job = self.get_engine_job(task.request_id)
|
||||
if job is None:
|
||||
continue
|
||||
jobs_and_items.append((job, task.item))
|
||||
if not jobs_and_items:
|
||||
return False
|
||||
now = time.perf_counter()
|
||||
for task in ready_tasks:
|
||||
job = self.get_engine_job(task.request_id)
|
||||
if job is not None:
|
||||
job.finalize_wait_ms += max(0.0, (now - task.enqueued_time) * 1000.0)
|
||||
for job, item in jobs_and_items:
|
||||
self.update_request_state(
|
||||
job.engine_request_id,
|
||||
EngineStatus.FINALIZING,
|
||||
{
|
||||
"finish_reason": item.finish_reason,
|
||||
"semantic_len": int(item.semantic_tokens.shape[0]),
|
||||
},
|
||||
)
|
||||
synth_ms, batch_results = self.scheduler_worker.synthesize_finalize_jobs(jobs_and_items)
|
||||
for job, _ in jobs_and_items:
|
||||
job.synth_ms += float(synth_ms)
|
||||
for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results):
|
||||
self.complete_engine_job(job, item, sample_rate=sample_rate, audio_data=audio_data)
|
||||
except Exception as exc:
|
||||
self.fail_engine_jobs([task.request_id for task in ready_tasks], str(exc))
|
||||
finally:
|
||||
self.scheduler_worker.end_finalize_execution(len(ready_tasks))
|
||||
self.finalize_queue_owner.mark_completed(len(ready_tasks) + len(failed_tasks), notify=True)
|
||||
return True
|
||||
59
GPT_SoVITS/TTS_infer_pack/unified_engine_stage_futures.py
Normal file
59
GPT_SoVITS/TTS_infer_pack/unified_engine_stage_futures.py
Normal file
@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Callable
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SRequestState
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineDispatchTask, EngineGpuPrepareTask
|
||||
|
||||
|
||||
class EngineStageFutureMixin:
|
||||
def bind_notify_arbiter(self, notify_arbiter: Callable[[], None]) -> None:
|
||||
self._notify_arbiter = notify_arbiter
|
||||
|
||||
def notify_arbiter(self) -> None:
|
||||
if self._notify_arbiter is not None:
|
||||
self._notify_arbiter()
|
||||
|
||||
@staticmethod
|
||||
def _resolve_dispatch_error_future(future: asyncio.Future, error: Exception) -> None:
|
||||
if future.done():
|
||||
return
|
||||
future.set_exception(error)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_prepare_future(
|
||||
future: asyncio.Future,
|
||||
payload: tuple[T2SRequestState, float, float],
|
||||
) -> None:
|
||||
if future.done():
|
||||
return
|
||||
future.set_result(payload)
|
||||
|
||||
def _notify_dispatch_error(self, task: EngineDispatchTask, error: Exception) -> None:
|
||||
if task.done_loop is None or task.done_future is None:
|
||||
return
|
||||
try:
|
||||
task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
def _notify_prepare_error(self, task: EngineGpuPrepareTask, error: Exception) -> None:
|
||||
if task.done_loop is None or task.done_future is None:
|
||||
return
|
||||
try:
|
||||
task.done_loop.call_soon_threadsafe(self._resolve_dispatch_error_future, task.done_future, error)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
def _notify_prepare_result(
|
||||
self,
|
||||
task: EngineGpuPrepareTask,
|
||||
payload: tuple[T2SRequestState, float, float],
|
||||
) -> None:
|
||||
if task.done_loop is None or task.done_future is None:
|
||||
return
|
||||
try:
|
||||
task.done_loop.call_soon_threadsafe(self._resolve_prepare_future, task.done_future, payload)
|
||||
except RuntimeError:
|
||||
pass
|
||||
306
GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py
Normal file
306
GPT_SoVITS/TTS_infer_pack/unified_engine_stage_prepare.py
Normal file
@ -0,0 +1,306 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SRequestState
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineGpuPrepareTask, EngineStatus
|
||||
|
||||
|
||||
class EnginePrepareStageMixin:
|
||||
def _prepare_waiting_total(self) -> int:
|
||||
return (
|
||||
int(self.prepare_queue_owner.waiting_count())
|
||||
+ int(self.prepare_text_queue_owner.waiting_count())
|
||||
+ int(self.prepare_ref_spec_queue_owner.waiting_count())
|
||||
)
|
||||
|
||||
async def _wait_prepare_queue_admission(self) -> float:
|
||||
soft_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_SOFT_MAX", "0")))
|
||||
if soft_max <= 0:
|
||||
return 0.0
|
||||
poll_s = max(
|
||||
0.0005,
|
||||
float(max(1, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_QUEUE_ADMISSION_POLL_MS", "1")))) / 1000.0,
|
||||
)
|
||||
wait_start = time.perf_counter()
|
||||
while self._prepare_waiting_total() >= soft_max:
|
||||
await asyncio.sleep(poll_s)
|
||||
return max(0.0, (time.perf_counter() - wait_start) * 1000.0)
|
||||
|
||||
async def prepare_state_via_engine_gpu_queue(
|
||||
self,
|
||||
*,
|
||||
spec: Any,
|
||||
prepare_submit_at: float,
|
||||
engine_request_id: str | None,
|
||||
) -> tuple[T2SRequestState, float, float]:
|
||||
prepare_queue_admission_wait_ms = await self._wait_prepare_queue_admission()
|
||||
cpu_stage = await self.scheduler_worker.prepare_cpu_stage_profiled_async(spec, prepare_submit_at)
|
||||
if engine_request_id not in [None, ""]:
|
||||
self.update_request_state(
|
||||
str(engine_request_id),
|
||||
EngineStatus.GPU_PREPARING,
|
||||
{
|
||||
"engine_prepare_queue_admission_wait_ms": float(prepare_queue_admission_wait_ms),
|
||||
"prompt_text_cpu_queue_ms": float(cpu_stage.prompt_cpu_profiled.queue_ms),
|
||||
"prompt_text_cpu_run_ms": float(cpu_stage.prompt_cpu_profiled.run_ms),
|
||||
"text_cpu_queue_ms": float(cpu_stage.target_cpu_profiled.queue_ms),
|
||||
"text_cpu_run_ms": float(cpu_stage.target_cpu_profiled.run_ms),
|
||||
},
|
||||
)
|
||||
loop = asyncio.get_running_loop()
|
||||
done_future = loop.create_future()
|
||||
task = EngineGpuPrepareTask(
|
||||
request_id=spec.request_id,
|
||||
cpu_stage=cpu_stage,
|
||||
done_loop=loop,
|
||||
done_future=done_future,
|
||||
engine_request_id=engine_request_id or spec.request_id,
|
||||
enqueue_time=time.perf_counter(),
|
||||
phase="audio",
|
||||
audio_enqueue_time=time.perf_counter(),
|
||||
admission_wait_ms=float(prepare_queue_admission_wait_ms),
|
||||
)
|
||||
self.prepare_queue_owner.enqueue(task)
|
||||
self.notify_arbiter()
|
||||
return await done_future
|
||||
|
||||
def _should_chain_prepare_text_after_audio(self) -> bool:
|
||||
if str(os.environ.get("GPTSOVITS_ENGINE_PREPARE_CHAIN_TEXT", "1")).strip().lower() in {"0", "false", "no", "off"}:
|
||||
return False
|
||||
if self.finalize_queue_owner.has_items() or self.dispatch_queue_owner.has_items():
|
||||
return False
|
||||
decode_runtime_state = self.snapshot_engine_decode_runtime_state()
|
||||
if bool(decode_runtime_state.get("has_work", False)):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _maybe_apply_ref_spec_to_state(self, task: EngineGpuPrepareTask) -> None:
|
||||
if task.state_result is None or task.ref_spec_result is None:
|
||||
return
|
||||
self.scheduler_worker.apply_ref_spec_result_to_state(task.state_result, task.ref_spec_result)
|
||||
if task.engine_request_id not in [None, ""]:
|
||||
self.merge_request_state_profile(
|
||||
str(task.engine_request_id),
|
||||
{
|
||||
"engine_prepare_ref_spec_queue_wait_ms": float(task.ref_spec_queue_wait_ms),
|
||||
"ref_spec_wait_ms": float(task.ref_spec_result[1].get("ref_spec_wait_ms", 0.0)),
|
||||
"ref_spec_ms": float(task.ref_spec_result[1].get("ref_spec_ms", 0.0)),
|
||||
"ref_spec_to_device_ms": float(task.ref_spec_result[1].get("ref_spec_to_device_ms", 0.0)),
|
||||
"ref_spec_main_resample_ms": float(task.ref_spec_result[1].get("ref_spec_main_resample_ms", 0.0)),
|
||||
"ref_spec_norm_ms": float(task.ref_spec_result[1].get("ref_spec_norm_ms", 0.0)),
|
||||
"ref_spec_spectrogram_ms": float(task.ref_spec_result[1].get("ref_spec_spectrogram_ms", 0.0)),
|
||||
"ref_spec_post_resample_ms": float(task.ref_spec_result[1].get("ref_spec_post_resample_ms", 0.0)),
|
||||
},
|
||||
)
|
||||
|
||||
def _mark_ref_spec_async_failed(
|
||||
self,
|
||||
task: EngineGpuPrepareTask,
|
||||
error: Exception,
|
||||
*,
|
||||
queue_wait_ms: float,
|
||||
) -> None:
|
||||
task.error = str(error)
|
||||
task.cancelled = True
|
||||
if task.state_result is not None:
|
||||
task.state_result.prepare_profile["ref_spec_async_failed"] = 1.0
|
||||
task.state_result.prepare_profile["engine_prepare_ref_spec_queue_wait_ms"] = float(queue_wait_ms)
|
||||
if task.engine_request_id not in [None, ""]:
|
||||
self.merge_request_state_profile(
|
||||
str(task.engine_request_id),
|
||||
{
|
||||
"ref_spec_async_failed": 1.0,
|
||||
"engine_prepare_ref_spec_queue_wait_ms": float(queue_wait_ms),
|
||||
},
|
||||
)
|
||||
self.fail_request_state(task.engine_request_id or task.request_id, str(error))
|
||||
self.fail_engine_jobs([task.request_id], str(error))
|
||||
self.notify_arbiter()
|
||||
|
||||
def _run_engine_prepare_audio_once(self, batch_max_items: int) -> bool:
|
||||
tasks = self.prepare_queue_owner.pop_left_many(batch_max_items)
|
||||
if not tasks:
|
||||
return False
|
||||
now = time.perf_counter()
|
||||
queue_wait_ms_list = [max(0.0, (now - task.enqueue_time) * 1000.0) for task in tasks]
|
||||
for task in tasks:
|
||||
task.audio_start_time = float(now)
|
||||
batch_results = asyncio.run(self.scheduler_worker.prepare_gpu_audio_phases_async([task.cpu_stage for task in tasks]))
|
||||
completed_count = 0
|
||||
for task, queue_wait_ms, result in zip(tasks, queue_wait_ms_list, batch_results):
|
||||
task.audio_end_time = time.perf_counter()
|
||||
if isinstance(result, Exception):
|
||||
task.error = str(result)
|
||||
self.fail_request_state(task.engine_request_id or task.request_id, str(result))
|
||||
self._notify_prepare_error(task, result)
|
||||
completed_count += 1
|
||||
continue
|
||||
task.audio_queue_wait_ms = float(queue_wait_ms)
|
||||
task.phase_one = result
|
||||
task.phase = "text"
|
||||
task.enqueue_time = time.perf_counter()
|
||||
task.text_enqueue_time = float(task.enqueue_time)
|
||||
task.ref_spec_enqueue_time = float(task.enqueue_time)
|
||||
self.prepare_text_queue_owner.enqueue(task)
|
||||
self.prepare_ref_spec_queue_owner.enqueue(task)
|
||||
if task.engine_request_id not in [None, ""]:
|
||||
self.merge_request_state_profile(
|
||||
str(task.engine_request_id),
|
||||
{
|
||||
"engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms),
|
||||
"engine_prepare_audio_queue_wait_ms": float(queue_wait_ms),
|
||||
"engine_prepare_audio_batch_size": float(len(tasks)),
|
||||
"engine_prepare_audio_phase_wall_ms": float(result.get("phase_wall_ms", 0.0)),
|
||||
"engine_prepare_audio_enqueue_ts": float(task.audio_enqueue_time),
|
||||
"engine_prepare_audio_start_ts": float(task.audio_start_time),
|
||||
"engine_prepare_audio_end_ts": float(task.audio_end_time),
|
||||
"engine_prepare_text_enqueue_ts": float(task.text_enqueue_time),
|
||||
"engine_prepare_ref_spec_enqueue_ts": float(task.ref_spec_enqueue_time),
|
||||
},
|
||||
)
|
||||
completed_count += 1
|
||||
self.prepare_queue_owner.mark_completed(completed_count)
|
||||
if completed_count > 0 and self._should_chain_prepare_text_after_audio():
|
||||
self._run_engine_prepare_text_once(min(batch_max_items, completed_count))
|
||||
return True
|
||||
if completed_count > 0:
|
||||
self.notify_arbiter()
|
||||
return True
|
||||
|
||||
def _run_engine_prepare_text_once(self, batch_max_items: int) -> bool:
|
||||
tasks = self.prepare_text_queue_owner.pop_left_many(batch_max_items)
|
||||
if not tasks:
|
||||
return False
|
||||
now = time.perf_counter()
|
||||
queue_wait_ms_list = [max(0.0, (now - task.enqueue_time) * 1000.0) for task in tasks]
|
||||
for task in tasks:
|
||||
task.text_start_time = float(now)
|
||||
items = [(task.cpu_stage, task.phase_one) for task in tasks if task.phase_one is not None]
|
||||
batch_results = asyncio.run(self.scheduler_worker.prepare_gpu_text_phases_async(items))
|
||||
completed_count = 0
|
||||
for task, queue_wait_ms, result in zip(tasks, queue_wait_ms_list, batch_results):
|
||||
task.text_end_time = time.perf_counter()
|
||||
if isinstance(result, Exception):
|
||||
task.error = str(result)
|
||||
task.cancelled = True
|
||||
self.fail_request_state(task.engine_request_id or task.request_id, str(result))
|
||||
self._notify_prepare_error(task, result)
|
||||
completed_count += 1
|
||||
continue
|
||||
task.text_queue_wait_ms = float(queue_wait_ms)
|
||||
state, prepare_exec_started_at, prepare_exec_finished_at = self.scheduler_worker.build_gpu_prepare_result_from_phases(
|
||||
task.cpu_stage,
|
||||
task.phase_one or {},
|
||||
result,
|
||||
extra_profile={
|
||||
"engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms),
|
||||
"engine_prepare_audio_queue_wait_ms": float(task.audio_queue_wait_ms),
|
||||
"engine_prepare_text_queue_wait_ms": float(task.text_queue_wait_ms),
|
||||
"engine_gpu_prepare_queue_wait_ms": float(task.audio_queue_wait_ms + task.text_queue_wait_ms),
|
||||
"engine_prepare_audio_batch_size": float(len(tasks)),
|
||||
"engine_prepare_text_batch_size": float(len(tasks)),
|
||||
"engine_prepare_audio_phase_mode": 2.0,
|
||||
"engine_prepare_audio_phase_wall_ms": float((task.phase_one or {}).get("phase_wall_ms", 0.0)),
|
||||
"engine_prepare_text_phase_wall_ms": float(result.get("phase_wall_ms", 0.0)),
|
||||
"engine_prepare_text_phase_batch_size": float(len(tasks)),
|
||||
"engine_prepare_audio_enqueue_ts": float(task.audio_enqueue_time),
|
||||
"engine_prepare_audio_start_ts": float(task.audio_start_time),
|
||||
"engine_prepare_audio_end_ts": float(task.audio_end_time),
|
||||
"engine_prepare_text_enqueue_ts": float(task.text_enqueue_time),
|
||||
"engine_prepare_text_start_ts": float(task.text_start_time),
|
||||
"engine_prepare_text_end_ts": float(task.text_end_time),
|
||||
"engine_prepare_ref_spec_enqueue_ts": float(task.ref_spec_enqueue_time),
|
||||
},
|
||||
)
|
||||
task.state_result = state
|
||||
self._maybe_apply_ref_spec_to_state(task)
|
||||
state.prepare_profile["engine_gpu_prepare_batch_size"] = float(len(tasks))
|
||||
if task.engine_request_id not in [None, ""]:
|
||||
self.merge_request_state_profile(
|
||||
str(task.engine_request_id),
|
||||
{
|
||||
"engine_prepare_queue_admission_wait_ms": float(task.admission_wait_ms),
|
||||
"engine_prepare_audio_queue_wait_ms": float(task.audio_queue_wait_ms),
|
||||
"engine_prepare_text_queue_wait_ms": float(task.text_queue_wait_ms),
|
||||
"engine_gpu_prepare_queue_wait_ms": float(task.audio_queue_wait_ms + task.text_queue_wait_ms),
|
||||
"engine_gpu_prepare_batch_size": float(len(tasks)),
|
||||
},
|
||||
)
|
||||
self._notify_prepare_result(task, (state, prepare_exec_started_at, prepare_exec_finished_at))
|
||||
completed_count += 1
|
||||
self.prepare_text_queue_owner.mark_completed(completed_count)
|
||||
return True
|
||||
|
||||
def _run_engine_prepare_ref_spec_once(self, batch_max_items: int) -> bool:
|
||||
tasks = self.prepare_ref_spec_queue_owner.pop_left_many(batch_max_items)
|
||||
if not tasks:
|
||||
return False
|
||||
now = time.perf_counter()
|
||||
runnable_tasks: list[EngineGpuPrepareTask] = []
|
||||
queue_wait_ms_list: list[float] = []
|
||||
completed_count = 0
|
||||
for task in tasks:
|
||||
if task.cancelled or task.phase_one is None:
|
||||
completed_count += 1
|
||||
continue
|
||||
task.ref_spec_start_time = float(now)
|
||||
runnable_tasks.append(task)
|
||||
queue_wait_ms_list.append(max(0.0, (now - task.ref_spec_enqueue_time) * 1000.0))
|
||||
if not runnable_tasks:
|
||||
self.prepare_ref_spec_queue_owner.mark_completed(completed_count)
|
||||
return True
|
||||
batch_results = asyncio.run(
|
||||
self.scheduler_worker.prepare_ref_spec_stages_async([task.phase_one or {} for task in runnable_tasks])
|
||||
)
|
||||
for task, queue_wait_ms, result in zip(runnable_tasks, queue_wait_ms_list, batch_results):
|
||||
task.ref_spec_end_time = time.perf_counter()
|
||||
task.ref_spec_queue_wait_ms = float(queue_wait_ms)
|
||||
if isinstance(result, Exception):
|
||||
self._mark_ref_spec_async_failed(task, result, queue_wait_ms=float(queue_wait_ms))
|
||||
completed_count += 1
|
||||
continue
|
||||
task.ref_spec_result = result
|
||||
self._maybe_apply_ref_spec_to_state(task)
|
||||
if task.state_result is not None:
|
||||
task.state_result.prepare_profile["engine_prepare_ref_spec_queue_wait_ms"] = float(queue_wait_ms)
|
||||
task.state_result.prepare_profile["engine_prepare_ref_spec_enqueue_ts"] = float(task.ref_spec_enqueue_time)
|
||||
task.state_result.prepare_profile["engine_prepare_ref_spec_start_ts"] = float(task.ref_spec_start_time)
|
||||
task.state_result.prepare_profile["engine_prepare_ref_spec_end_ts"] = float(task.ref_spec_end_time)
|
||||
completed_count += 1
|
||||
self.prepare_ref_spec_queue_owner.mark_completed(completed_count)
|
||||
return True
|
||||
|
||||
def run_engine_prepare_once(self) -> bool:
|
||||
prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
|
||||
batch_max_items = int(prepare_batch_policy.get("prepare_batch_max_items", 1))
|
||||
audio_age_ms = self.prepare_queue_owner.peek_oldest_age_ms("enqueue_time")
|
||||
text_age_ms = self.prepare_text_queue_owner.peek_oldest_age_ms("enqueue_time")
|
||||
if self.prepare_text_queue_owner.has_items() and (
|
||||
not self.prepare_queue_owner.has_items() or text_age_ms >= audio_age_ms
|
||||
):
|
||||
return self._run_engine_prepare_text_once(batch_max_items)
|
||||
if self.prepare_queue_owner.has_items():
|
||||
return self._run_engine_prepare_audio_once(batch_max_items)
|
||||
if self.prepare_ref_spec_queue_owner.has_items():
|
||||
return self._run_engine_prepare_ref_spec_once(batch_max_items)
|
||||
if self.prepare_text_queue_owner.has_items():
|
||||
return self._run_engine_prepare_text_once(batch_max_items)
|
||||
if self.prepare_ref_spec_queue_owner.has_items():
|
||||
return self._run_engine_prepare_ref_spec_once(batch_max_items)
|
||||
return False
|
||||
|
||||
def run_engine_prepare_audio_once(self) -> bool:
|
||||
prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
|
||||
return self._run_engine_prepare_audio_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1)))
|
||||
|
||||
def run_engine_prepare_text_once(self) -> bool:
|
||||
prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
|
||||
return self._run_engine_prepare_text_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1)))
|
||||
|
||||
def run_engine_prepare_ref_spec_once(self) -> bool:
|
||||
prepare_batch_policy = self.scheduler_worker.get_prepare_batch_policy()
|
||||
return self._run_engine_prepare_ref_spec_once(int(prepare_batch_policy.get("prepare_batch_max_items", 1)))
|
||||
71
GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py
Normal file
71
GPT_SoVITS/TTS_infer_pack/unified_engine_worker.py
Normal file
@ -0,0 +1,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import threading
|
||||
from typing import Callable, List
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeStateCallbacks, SchedulerFinalizeTask, SchedulerJobRegistry
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_completion import WorkerCompletionBridge
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_decode import WorkerDecodeExecutor, WorkerDecodeLegacyShell, WorkerDecodeRuntimeTracker
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_execution import WorkerExecutionMixin
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_finalize import WorkerFinalizeExecutor
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_prepare import WorkerPrepareExecutor
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_runtime import WorkerRuntimeBookkeepingMixin
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_worker_submit import WorkerSubmitLifecycleMixin
|
||||
|
||||
|
||||
class UnifiedSchedulerWorker(
|
||||
WorkerSubmitLifecycleMixin,
|
||||
WorkerRuntimeBookkeepingMixin,
|
||||
WorkerExecutionMixin,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
tts: TTS,
|
||||
max_steps: int = 1500,
|
||||
micro_batch_wait_ms: int = 5,
|
||||
runtime_callbacks: RuntimeStateCallbacks | None = None,
|
||||
external_finalize_submit: Callable[[List[SchedulerFinalizeTask]], None] | None = None,
|
||||
):
|
||||
self.tts = tts
|
||||
self.max_steps = int(max_steps)
|
||||
self.micro_batch_wait_s = float(micro_batch_wait_ms) / 1000.0
|
||||
self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks()
|
||||
self.condition = threading.Condition()
|
||||
self.completion_bridge = WorkerCompletionBridge(self.runtime_callbacks)
|
||||
self.decode_executor = WorkerDecodeExecutor(tts, max_steps=max_steps)
|
||||
self.decode_legacy_shell = WorkerDecodeLegacyShell(self.condition, self.micro_batch_wait_s)
|
||||
self.decode_runtime_tracker = WorkerDecodeRuntimeTracker(self.runtime_callbacks)
|
||||
self.prepare_executor = WorkerPrepareExecutor(tts, on_state_change=self._notify_worker_state_change)
|
||||
self.finalize_executor = WorkerFinalizeExecutor(
|
||||
tts,
|
||||
on_state_change=self._notify_worker_state_change,
|
||||
external_submit=external_finalize_submit,
|
||||
)
|
||||
self.decode_backlog_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_DECODE_BACKLOG_MAX", "0")))
|
||||
self.finalize_pending_max = max(0, int(os.environ.get("GPTSOVITS_ENGINE_FINALIZE_PENDING_MAX", "0")))
|
||||
self.engine_decode_control_enabled = (
|
||||
str(os.environ.get("GPTSOVITS_ENGINE_DRIVE_DECODE", "1")).strip().lower() in {"1", "true", "yes", "on"}
|
||||
)
|
||||
self.job_registry = SchedulerJobRegistry(self.condition)
|
||||
self.worker_thread: threading.Thread | None = None
|
||||
if not self.engine_decode_control_enabled:
|
||||
self.worker_thread = threading.Thread(target=self._run_loop, name="unified-t2s-scheduler-worker", daemon=True)
|
||||
self.worker_thread.start()
|
||||
self.finalize_threads = []
|
||||
if external_finalize_submit is None:
|
||||
self.finalize_threads = [
|
||||
threading.Thread(
|
||||
target=self._run_finalize_loop,
|
||||
name=f"unified-t2s-finalize-{worker_index}",
|
||||
daemon=True,
|
||||
)
|
||||
for worker_index in range(self.finalize_executor.get_worker_count())
|
||||
]
|
||||
for finalize_thread in self.finalize_threads:
|
||||
finalize_thread.start()
|
||||
|
||||
def _notify_worker_state_change(self) -> None:
|
||||
with self.condition:
|
||||
self.condition.notify_all()
|
||||
198
GPT_SoVITS/TTS_infer_pack/unified_engine_worker_completion.py
Normal file
198
GPT_SoVITS/TTS_infer_pack/unified_engine_worker_completion.py
Normal file
@ -0,0 +1,198 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeStateCallbacks, SchedulerJobRegistry, SchedulerPendingJob
|
||||
|
||||
|
||||
class WorkerCompletionBridge:
|
||||
def __init__(self, runtime_callbacks: RuntimeStateCallbacks | None = None) -> None:
|
||||
self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks()
|
||||
|
||||
@staticmethod
|
||||
def _resolve_done_future(job: SchedulerPendingJob) -> None:
|
||||
future = job.done_future
|
||||
if future is None or future.done():
|
||||
return
|
||||
future.set_result(job)
|
||||
|
||||
def notify_done_future(self, job: SchedulerPendingJob) -> None:
|
||||
if job.done_loop is None or job.done_future is None:
|
||||
return
|
||||
try:
|
||||
job.done_loop.call_soon_threadsafe(self._resolve_done_future, job)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
def runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None:
|
||||
if request_id is None or self.runtime_callbacks.complete is None:
|
||||
return
|
||||
self.runtime_callbacks.complete(request_id, extra)
|
||||
|
||||
def runtime_fail(self, request_id: str | None, error: str) -> None:
|
||||
if request_id is None or self.runtime_callbacks.fail is None:
|
||||
return
|
||||
self.runtime_callbacks.fail(request_id, error)
|
||||
|
||||
@staticmethod
|
||||
def build_completed_job_result(
|
||||
job: SchedulerPendingJob,
|
||||
item: T2SFinishedItem,
|
||||
*,
|
||||
sample_rate: int,
|
||||
audio_data: np.ndarray,
|
||||
finished_at: float | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
finished_at = float(time.perf_counter() if finished_at is None else finished_at)
|
||||
queue_wait_ms = 0.0
|
||||
if job.first_schedule_time is not None:
|
||||
queue_wait_ms = max(0.0, (job.first_schedule_time - job.enqueue_time) * 1000.0)
|
||||
worker_total_ms = max(0.0, (finished_at - job.enqueue_time) * 1000.0)
|
||||
worker_residual_ms = max(
|
||||
0.0,
|
||||
worker_total_ms
|
||||
- queue_wait_ms
|
||||
- job.prefill_ms
|
||||
- job.merge_ms
|
||||
- job.decode_ms
|
||||
- job.finalize_wait_ms
|
||||
- job.synth_ms,
|
||||
)
|
||||
worker_other_ms = max(0.0, job.merge_ms + job.finalize_wait_ms + worker_residual_ms)
|
||||
job.sample_rate = int(sample_rate)
|
||||
job.audio_data = audio_data
|
||||
job.result_ready_time = finished_at
|
||||
prepare_profile = dict(job.state.prepare_profile)
|
||||
result = {
|
||||
"request_id": item.request_id,
|
||||
"semantic_len": int(item.semantic_tokens.shape[0]),
|
||||
"finish_idx": int(item.finish_idx),
|
||||
"finish_reason": item.finish_reason,
|
||||
"decode_admission_wait_ms": float(job.admission_wait_ms),
|
||||
"engine_policy_wait_ms": float(job.engine_policy_wait_ms),
|
||||
"engine_dispatch_wait_ms": float(job.engine_dispatch_wait_ms),
|
||||
"prepare_ms": job.prepare_wall_ms,
|
||||
"prepare_wall_ms": job.prepare_wall_ms,
|
||||
"prepare_profile_total_ms": job.prepare_profile_total_ms,
|
||||
"prepare_profile": prepare_profile,
|
||||
"queue_wait_ms": queue_wait_ms,
|
||||
"prefill_ms": job.prefill_ms,
|
||||
"merge_ms": job.merge_ms,
|
||||
"decode_ms": job.decode_ms,
|
||||
"finalize_wait_ms": job.finalize_wait_ms,
|
||||
"synth_ms": job.synth_ms,
|
||||
"worker_residual_ms": worker_residual_ms,
|
||||
"worker_other_ms": worker_other_ms,
|
||||
"worker_total_ms": worker_total_ms,
|
||||
"decode_steps": int(job.decode_steps),
|
||||
"sample_rate": int(sample_rate),
|
||||
"media_type": job.media_type,
|
||||
}
|
||||
job.result = result
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def build_runtime_complete_payload(
|
||||
job: SchedulerPendingJob,
|
||||
item: T2SFinishedItem,
|
||||
*,
|
||||
sample_rate: int,
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"finish_reason": item.finish_reason,
|
||||
"semantic_len": int(item.semantic_tokens.shape[0]),
|
||||
"finish_idx": int(item.finish_idx),
|
||||
"sample_rate": int(sample_rate),
|
||||
"worker_profile": dict(job.result or {}),
|
||||
}
|
||||
|
||||
def complete_job(
|
||||
self,
|
||||
job: SchedulerPendingJob,
|
||||
*,
|
||||
runtime_request_id: str | None,
|
||||
runtime_extra: Optional[Dict[str, Any]] = None,
|
||||
remove_job: Callable[[], None] | None = None,
|
||||
on_job_finished: Callable[[], None] | None = None,
|
||||
notify_waiters: Callable[[], None] | None = None,
|
||||
) -> None:
|
||||
job.done_event.set()
|
||||
self.notify_done_future(job)
|
||||
if remove_job is not None:
|
||||
remove_job()
|
||||
if on_job_finished is not None:
|
||||
on_job_finished()
|
||||
if notify_waiters is not None:
|
||||
notify_waiters()
|
||||
self.runtime_complete(runtime_request_id, runtime_extra)
|
||||
|
||||
def fail_job(
|
||||
self,
|
||||
job: SchedulerPendingJob,
|
||||
*,
|
||||
error: str,
|
||||
remove_job: Callable[[], None] | None = None,
|
||||
on_job_finished: Callable[[], None] | None = None,
|
||||
notify_waiters: Callable[[], None] | None = None,
|
||||
) -> None:
|
||||
job.error = str(error)
|
||||
job.done_event.set()
|
||||
self.notify_done_future(job)
|
||||
if remove_job is not None:
|
||||
remove_job()
|
||||
if on_job_finished is not None:
|
||||
on_job_finished()
|
||||
if notify_waiters is not None:
|
||||
notify_waiters()
|
||||
self.runtime_fail(job.engine_request_id, str(error))
|
||||
|
||||
def complete_finalize_task(
|
||||
self,
|
||||
*,
|
||||
condition: threading.Condition,
|
||||
job_registry: SchedulerJobRegistry,
|
||||
job: SchedulerPendingJob,
|
||||
item: T2SFinishedItem,
|
||||
sample_rate: int,
|
||||
audio_data: np.ndarray,
|
||||
) -> None:
|
||||
runtime_extra: Optional[Dict[str, Any]] = None
|
||||
with condition:
|
||||
if job_registry.get(item.request_id) is not job:
|
||||
return
|
||||
self.build_completed_job_result(job, item, sample_rate=sample_rate, audio_data=audio_data)
|
||||
runtime_extra = self.build_runtime_complete_payload(job, item, sample_rate=sample_rate)
|
||||
self.complete_job(
|
||||
job,
|
||||
runtime_request_id=job.engine_request_id,
|
||||
runtime_extra=runtime_extra,
|
||||
on_job_finished=lambda: job_registry.mark_finished_and_remove(item.request_id),
|
||||
notify_waiters=condition.notify_all,
|
||||
)
|
||||
|
||||
def fail_jobs(
|
||||
self,
|
||||
*,
|
||||
condition: threading.Condition,
|
||||
job_registry: SchedulerJobRegistry,
|
||||
request_ids: List[str],
|
||||
error: str,
|
||||
) -> None:
|
||||
if not request_ids:
|
||||
return
|
||||
with condition:
|
||||
for request_id in request_ids:
|
||||
job = job_registry.get(request_id)
|
||||
if job is None:
|
||||
continue
|
||||
self.fail_job(
|
||||
job,
|
||||
error=error,
|
||||
on_job_finished=lambda rid=request_id: job_registry.mark_finished_and_remove(rid),
|
||||
)
|
||||
condition.notify_all()
|
||||
430
GPT_SoVITS/TTS_infer_pack/unified_engine_worker_decode.py
Normal file
430
GPT_SoVITS/TTS_infer_pack/unified_engine_worker_decode.py
Normal file
@ -0,0 +1,430 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import (
|
||||
T2SActiveBatch,
|
||||
T2SFinishedItem,
|
||||
decode_one_step,
|
||||
merge_active_batches,
|
||||
run_prefill_active_batch,
|
||||
)
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import RuntimeStateCallbacks, SchedulerPendingJob
|
||||
|
||||
|
||||
class WorkerDecodeExecutor:
|
||||
def __init__(self, tts: TTS, max_steps: int) -> None:
|
||||
self.tts = tts
|
||||
self.max_steps = int(max_steps)
|
||||
|
||||
def _sync_device(self) -> None:
|
||||
try:
|
||||
device_str = str(self.tts.configs.device)
|
||||
if device_str.startswith("cuda") and torch.cuda.is_available():
|
||||
torch.cuda.synchronize(self.tts.configs.device)
|
||||
elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"):
|
||||
torch.mps.synchronize()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def execute_prefill_merge(
|
||||
self,
|
||||
*,
|
||||
pending_jobs: List[SchedulerPendingJob],
|
||||
active_batch: Optional[T2SActiveBatch],
|
||||
mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None],
|
||||
add_prefill_time: Callable[[List[str], float], None] | None,
|
||||
add_merge_time: Callable[[List[str], float], None] | None,
|
||||
enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None,
|
||||
finalize_error: Callable[[List[str], str], None] | None,
|
||||
) -> Dict[str, Any]:
|
||||
if not pending_jobs:
|
||||
return {
|
||||
"executed": False,
|
||||
"active_batch": active_batch,
|
||||
"pending_jobs": [],
|
||||
"prefill_elapsed_s": 0.0,
|
||||
"merge_elapsed_s": 0.0,
|
||||
"finished_items": [],
|
||||
"error": None,
|
||||
"error_request_ids": [],
|
||||
}
|
||||
admitted_finished: List[T2SFinishedItem] = []
|
||||
prefill_elapsed_s = 0.0
|
||||
merge_elapsed_s = 0.0
|
||||
error: str | None = None
|
||||
error_request_ids: List[str] = []
|
||||
try:
|
||||
self._sync_device()
|
||||
prefill_start = time.perf_counter()
|
||||
mark_prefill_started(pending_jobs, prefill_start)
|
||||
admitted_active_batch, admitted_finished = run_prefill_active_batch(
|
||||
self.tts.t2s_model.model,
|
||||
[job.state for job in pending_jobs],
|
||||
max_steps=self.max_steps,
|
||||
)
|
||||
self._sync_device()
|
||||
prefill_elapsed_s = time.perf_counter() - prefill_start
|
||||
if add_prefill_time is not None:
|
||||
add_prefill_time([job.request_id for job in pending_jobs], prefill_elapsed_s)
|
||||
if enqueue_finished is not None:
|
||||
enqueue_finished(admitted_finished)
|
||||
merge_start = time.perf_counter()
|
||||
active_batch = merge_active_batches(
|
||||
self.tts.t2s_model.model,
|
||||
active_batch,
|
||||
admitted_active_batch,
|
||||
)
|
||||
merge_elapsed_s = time.perf_counter() - merge_start
|
||||
if add_merge_time is not None:
|
||||
add_merge_time(
|
||||
[] if active_batch is None else list(active_batch.request_ids),
|
||||
merge_elapsed_s,
|
||||
)
|
||||
except Exception as exc:
|
||||
error = str(exc)
|
||||
error_request_ids = [job.request_id for job in pending_jobs]
|
||||
if finalize_error is not None:
|
||||
finalize_error(error_request_ids, error)
|
||||
return {
|
||||
"executed": True,
|
||||
"active_batch": active_batch,
|
||||
"pending_jobs": list(pending_jobs),
|
||||
"prefill_elapsed_s": float(prefill_elapsed_s),
|
||||
"merge_elapsed_s": float(merge_elapsed_s),
|
||||
"finished_items": list(admitted_finished),
|
||||
"error": error,
|
||||
"error_request_ids": error_request_ids,
|
||||
}
|
||||
|
||||
def execute_decode_step(
|
||||
self,
|
||||
*,
|
||||
active_batch: Optional[T2SActiveBatch],
|
||||
add_decode_time: Callable[[List[str], float], None] | None,
|
||||
enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None,
|
||||
finalize_error: Callable[[List[str], str], None] | None,
|
||||
) -> Dict[str, Any]:
|
||||
if active_batch is None:
|
||||
return {
|
||||
"executed": False,
|
||||
"active_batch": None,
|
||||
"request_ids": [],
|
||||
"decode_elapsed_s": 0.0,
|
||||
"finished_items": [],
|
||||
"error": None,
|
||||
"error_request_ids": [],
|
||||
}
|
||||
active_request_ids: List[str] = []
|
||||
step_finished: List[T2SFinishedItem] = []
|
||||
decode_elapsed_s = 0.0
|
||||
error: str | None = None
|
||||
error_request_ids: List[str] = []
|
||||
try:
|
||||
active_request_ids = [state.request_id for state in active_batch.states]
|
||||
self._sync_device()
|
||||
decode_start = time.perf_counter()
|
||||
active_batch, step_finished = decode_one_step(
|
||||
self.tts.t2s_model.model,
|
||||
active_batch,
|
||||
max_steps=self.max_steps,
|
||||
)
|
||||
self._sync_device()
|
||||
decode_elapsed_s = time.perf_counter() - decode_start
|
||||
if add_decode_time is not None:
|
||||
add_decode_time(active_request_ids, decode_elapsed_s)
|
||||
if enqueue_finished is not None:
|
||||
enqueue_finished(step_finished)
|
||||
except Exception as exc:
|
||||
error = str(exc)
|
||||
error_request_ids = list(active_request_ids)
|
||||
if finalize_error is not None:
|
||||
finalize_error(error_request_ids, error)
|
||||
active_batch = None
|
||||
return {
|
||||
"executed": True,
|
||||
"active_batch": active_batch,
|
||||
"request_ids": active_request_ids,
|
||||
"decode_elapsed_s": float(decode_elapsed_s),
|
||||
"finished_items": list(step_finished),
|
||||
"error": error,
|
||||
"error_request_ids": error_request_ids,
|
||||
}
|
||||
|
||||
def execute_decode_cycle(
|
||||
self,
|
||||
*,
|
||||
pending_jobs: List[SchedulerPendingJob],
|
||||
active_batch: Optional[T2SActiveBatch],
|
||||
mark_prefill_started: Callable[[List[SchedulerPendingJob], float], None],
|
||||
add_prefill_time: Callable[[List[str], float], None] | None,
|
||||
add_merge_time: Callable[[List[str], float], None] | None,
|
||||
add_decode_time: Callable[[List[str], float], None] | None,
|
||||
enqueue_finished: Callable[[List[T2SFinishedItem]], None] | None,
|
||||
finalize_error: Callable[[List[str], str], None] | None,
|
||||
) -> Dict[str, Any]:
|
||||
result = {
|
||||
"executed": False,
|
||||
"prefill_merge_executed": False,
|
||||
"decode_step_executed": False,
|
||||
"active_batch": active_batch,
|
||||
"prefill_phase": {},
|
||||
"decode_phase": {},
|
||||
}
|
||||
prefill_phase = self.execute_prefill_merge(
|
||||
pending_jobs=list(pending_jobs),
|
||||
active_batch=result["active_batch"],
|
||||
mark_prefill_started=mark_prefill_started,
|
||||
add_prefill_time=add_prefill_time,
|
||||
add_merge_time=add_merge_time,
|
||||
enqueue_finished=enqueue_finished,
|
||||
finalize_error=finalize_error,
|
||||
)
|
||||
prefill_executed = bool(prefill_phase.get("executed", False))
|
||||
result["prefill_phase"] = prefill_phase
|
||||
result["active_batch"] = prefill_phase.get("active_batch")
|
||||
if prefill_executed:
|
||||
result["executed"] = True
|
||||
result["prefill_merge_executed"] = True
|
||||
decode_phase = self.execute_decode_step(
|
||||
active_batch=result["active_batch"],
|
||||
add_decode_time=add_decode_time,
|
||||
enqueue_finished=enqueue_finished,
|
||||
finalize_error=finalize_error,
|
||||
)
|
||||
decode_executed = bool(decode_phase.get("executed", False))
|
||||
result["decode_phase"] = decode_phase
|
||||
result["active_batch"] = decode_phase.get("active_batch")
|
||||
if decode_executed:
|
||||
result["executed"] = True
|
||||
result["decode_step_executed"] = True
|
||||
return result
|
||||
|
||||
|
||||
class WorkerDecodeLegacyShell:
|
||||
def __init__(self, condition: threading.Condition, micro_batch_wait_s: float) -> None:
|
||||
self.condition = condition
|
||||
self.micro_batch_wait_s = float(micro_batch_wait_s)
|
||||
self.pending_jobs: List[SchedulerPendingJob] = []
|
||||
self.active_batch: T2SActiveBatch | None = None
|
||||
|
||||
@staticmethod
|
||||
def _summarize_active_batch(active_batch: T2SActiveBatch | None) -> Dict[str, Any] | None:
|
||||
if active_batch is None:
|
||||
return None
|
||||
return {
|
||||
"request_count": int(len(active_batch.request_ids)),
|
||||
"request_ids": list(active_batch.request_ids),
|
||||
"prefill_done": bool(active_batch.prefill_done),
|
||||
"decode_step_index_max": (
|
||||
int(active_batch.step_indices.max().item())
|
||||
if active_batch.step_indices is not None and active_batch.step_indices.numel() > 0
|
||||
else 0
|
||||
),
|
||||
}
|
||||
|
||||
def current_backlog_locked(self) -> int:
|
||||
running_requests = 0 if self.active_batch is None else len(self.active_batch.request_ids)
|
||||
return int(len(self.pending_jobs) + running_requests)
|
||||
|
||||
def enqueue_pending_job_locked(self, job: SchedulerPendingJob) -> None:
|
||||
self.pending_jobs.append(job)
|
||||
|
||||
def snapshot_locked(self) -> Dict[str, Any]:
|
||||
active_batch_summary = self._summarize_active_batch(self.active_batch)
|
||||
executor_local_pending_jobs = int(len(self.pending_jobs))
|
||||
executor_local_running_requests = 0 if self.active_batch is None else int(len(self.active_batch.request_ids))
|
||||
executor_local_has_work = bool(self.pending_jobs or self.active_batch is not None)
|
||||
return {
|
||||
"executor_local_pending_jobs": executor_local_pending_jobs,
|
||||
"executor_local_running_requests": executor_local_running_requests,
|
||||
"executor_local_has_work": executor_local_has_work,
|
||||
"executor_local_active_batch": active_batch_summary,
|
||||
}
|
||||
|
||||
def is_idle_locked(self) -> bool:
|
||||
return self.active_batch is None and not self.pending_jobs
|
||||
|
||||
def take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
|
||||
with self.condition:
|
||||
if not self.pending_jobs and self.active_batch is None:
|
||||
self.condition.wait(timeout=self.micro_batch_wait_s)
|
||||
elif wait_for_batch and self.pending_jobs:
|
||||
self.condition.wait(timeout=self.micro_batch_wait_s)
|
||||
if not self.pending_jobs:
|
||||
return []
|
||||
pending = list(self.pending_jobs)
|
||||
self.pending_jobs.clear()
|
||||
return pending
|
||||
|
||||
def take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
|
||||
with self.condition:
|
||||
if not self.pending_jobs:
|
||||
return []
|
||||
if wait_for_batch:
|
||||
oldest_enqueue_time = float(self.pending_jobs[0].enqueue_time)
|
||||
if (time.perf_counter() - oldest_enqueue_time) < self.micro_batch_wait_s:
|
||||
return []
|
||||
pending = list(self.pending_jobs)
|
||||
self.pending_jobs.clear()
|
||||
return pending
|
||||
|
||||
def has_decode_runtime_work(self) -> bool:
|
||||
with self.condition:
|
||||
return bool(self.pending_jobs or self.active_batch is not None)
|
||||
|
||||
def build_runtime_summary_locked(self, *, total_cycles: int, prefill_cycles: int, step_cycles: int, last_event: str) -> Dict[str, Any]:
|
||||
active_request_ids = [] if self.active_batch is None else list(self.active_batch.request_ids)
|
||||
decode_step_index_max = 0
|
||||
prefill_done = False
|
||||
if self.active_batch is not None:
|
||||
prefill_done = bool(self.active_batch.prefill_done)
|
||||
if self.active_batch.step_indices is not None and self.active_batch.step_indices.numel() > 0:
|
||||
decode_step_index_max = int(self.active_batch.step_indices.max().item())
|
||||
return {
|
||||
"pending_jobs": int(len(self.pending_jobs)),
|
||||
"active_request_count": int(len(active_request_ids)),
|
||||
"active_request_ids": active_request_ids[:32],
|
||||
"prefill_done": bool(prefill_done),
|
||||
"decode_step_index_max": int(decode_step_index_max),
|
||||
"total_cycles": int(total_cycles),
|
||||
"prefill_cycles": int(prefill_cycles),
|
||||
"step_cycles": int(step_cycles),
|
||||
"has_work": bool(self.pending_jobs or self.active_batch is not None),
|
||||
"last_event": str(last_event),
|
||||
"updated_at": float(time.perf_counter()),
|
||||
}
|
||||
|
||||
def run_prefill_merge_once_nonblocking(
|
||||
self,
|
||||
*,
|
||||
external_pending_jobs: Optional[List[SchedulerPendingJob]],
|
||||
external_active_batch: Optional[T2SActiveBatch],
|
||||
execute_prefill_merge: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]],
|
||||
) -> Dict[str, Any]:
|
||||
pending_jobs = (
|
||||
list(external_pending_jobs)
|
||||
if external_pending_jobs is not None
|
||||
else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None)
|
||||
)
|
||||
active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch
|
||||
result = execute_prefill_merge(pending_jobs, active_batch)
|
||||
if external_pending_jobs is None:
|
||||
with self.condition:
|
||||
self.active_batch = result.get("active_batch")
|
||||
self.condition.notify_all()
|
||||
return result
|
||||
|
||||
def run_decode_step_once_nonblocking(
|
||||
self,
|
||||
*,
|
||||
external_active_batch: Optional[T2SActiveBatch],
|
||||
execute_decode_step: Callable[[Optional[T2SActiveBatch]], Dict[str, Any]],
|
||||
) -> Dict[str, Any]:
|
||||
active_batch = self.active_batch if external_active_batch is None else external_active_batch
|
||||
result = execute_decode_step(active_batch)
|
||||
if external_active_batch is None:
|
||||
with self.condition:
|
||||
self.active_batch = result.get("active_batch")
|
||||
self.condition.notify_all()
|
||||
return result
|
||||
|
||||
def run_decode_cycle_nonblocking(
|
||||
self,
|
||||
*,
|
||||
external_pending_jobs: Optional[List[SchedulerPendingJob]],
|
||||
external_active_batch: Optional[T2SActiveBatch],
|
||||
execute_decode_cycle: Callable[[List[SchedulerPendingJob], Optional[T2SActiveBatch]], Dict[str, Any]],
|
||||
on_cycle_executed: Callable[[Dict[str, Any]], None] | None,
|
||||
) -> Dict[str, Any]:
|
||||
pending_jobs = (
|
||||
list(external_pending_jobs)
|
||||
if external_pending_jobs is not None
|
||||
else self.take_pending_snapshot_nonblocking(wait_for_batch=self.active_batch is None)
|
||||
)
|
||||
active_batch = external_active_batch if external_pending_jobs is not None else self.active_batch
|
||||
result = execute_decode_cycle(pending_jobs, active_batch)
|
||||
if external_pending_jobs is None:
|
||||
with self.condition:
|
||||
self.active_batch = result.get("active_batch")
|
||||
self.condition.notify_all()
|
||||
if result.get("executed") and on_cycle_executed is not None:
|
||||
on_cycle_executed(result)
|
||||
return result
|
||||
|
||||
def run_loop(
|
||||
self,
|
||||
*,
|
||||
run_decode_cycle_nonblocking: Callable[[], Dict[str, Any]],
|
||||
) -> None:
|
||||
while True:
|
||||
executed = run_decode_cycle_nonblocking()
|
||||
if executed.get("executed"):
|
||||
continue
|
||||
wait_for_batch = self.active_batch is None
|
||||
pending_jobs = self.take_pending_snapshot(wait_for_batch=wait_for_batch)
|
||||
if pending_jobs:
|
||||
with self.condition:
|
||||
self.pending_jobs = pending_jobs + self.pending_jobs
|
||||
self.condition.notify_all()
|
||||
continue
|
||||
time.sleep(self.micro_batch_wait_s)
|
||||
|
||||
|
||||
class WorkerDecodeRuntimeTracker:
|
||||
def __init__(
|
||||
self,
|
||||
runtime_callbacks: RuntimeStateCallbacks | None = None,
|
||||
) -> None:
|
||||
self.runtime_callbacks = runtime_callbacks or RuntimeStateCallbacks()
|
||||
self.total_cycles = 0
|
||||
self.prefill_cycles = 0
|
||||
self.step_cycles = 0
|
||||
|
||||
def get_counters(self) -> Dict[str, int]:
|
||||
return {
|
||||
"total_cycles": int(self.total_cycles),
|
||||
"prefill_cycles": int(self.prefill_cycles),
|
||||
"step_cycles": int(self.step_cycles),
|
||||
}
|
||||
|
||||
def record_cycle(self, result: Dict[str, Any]) -> None:
|
||||
if not bool(result.get("executed")):
|
||||
return
|
||||
self.total_cycles += 1
|
||||
if bool(result.get("prefill_merge_executed")):
|
||||
self.prefill_cycles += 1
|
||||
if bool(result.get("decode_step_executed")):
|
||||
self.step_cycles += 1
|
||||
|
||||
def build_runtime_summary_locked(
|
||||
self,
|
||||
*,
|
||||
legacy_shell: WorkerDecodeLegacyShell,
|
||||
last_event: str,
|
||||
) -> Dict[str, Any]:
|
||||
return legacy_shell.build_runtime_summary_locked(
|
||||
total_cycles=int(self.total_cycles),
|
||||
prefill_cycles=int(self.prefill_cycles),
|
||||
step_cycles=int(self.step_cycles),
|
||||
last_event=str(last_event),
|
||||
)
|
||||
|
||||
def notify_runtime_update_locked(
|
||||
self,
|
||||
*,
|
||||
legacy_shell: WorkerDecodeLegacyShell,
|
||||
last_event: str,
|
||||
) -> None:
|
||||
if self.runtime_callbacks.decode_runtime_update is None:
|
||||
return
|
||||
snapshot = self.build_runtime_summary_locked(
|
||||
legacy_shell=legacy_shell,
|
||||
last_event=last_event,
|
||||
)
|
||||
self.runtime_callbacks.decode_runtime_update(snapshot)
|
||||
164
GPT_SoVITS/TTS_infer_pack/unified_engine_worker_execution.py
Normal file
164
GPT_SoVITS/TTS_infer_pack/unified_engine_worker_execution.py
Normal file
@ -0,0 +1,164 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SActiveBatch, T2SFinishedItem
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob
|
||||
|
||||
|
||||
class WorkerExecutionMixin:
|
||||
def execute_prefill_merge(
|
||||
self,
|
||||
pending_jobs: List[SchedulerPendingJob],
|
||||
active_batch: Optional[T2SActiveBatch],
|
||||
external_bookkeeping: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
return self.decode_executor.execute_prefill_merge(
|
||||
pending_jobs=pending_jobs,
|
||||
active_batch=active_batch,
|
||||
mark_prefill_started=self._mark_prefill_started,
|
||||
add_prefill_time=None if external_bookkeeping else self._add_prefill_time,
|
||||
add_merge_time=None if external_bookkeeping else self._add_merge_time,
|
||||
enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished,
|
||||
finalize_error=None if external_bookkeeping else self._finalize_error,
|
||||
)
|
||||
|
||||
def execute_decode_step(
|
||||
self,
|
||||
active_batch: Optional[T2SActiveBatch],
|
||||
external_bookkeeping: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
return self.decode_executor.execute_decode_step(
|
||||
active_batch=active_batch,
|
||||
add_decode_time=None if external_bookkeeping else self._add_decode_time,
|
||||
enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished,
|
||||
finalize_error=None if external_bookkeeping else self._finalize_error,
|
||||
)
|
||||
|
||||
def execute_decode_cycle(
|
||||
self,
|
||||
pending_jobs: List[SchedulerPendingJob],
|
||||
active_batch: Optional[T2SActiveBatch],
|
||||
external_bookkeeping: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
result = self.decode_executor.execute_decode_cycle(
|
||||
pending_jobs=pending_jobs,
|
||||
active_batch=active_batch,
|
||||
mark_prefill_started=self._mark_prefill_started,
|
||||
add_prefill_time=None if external_bookkeeping else self._add_prefill_time,
|
||||
add_merge_time=None if external_bookkeeping else self._add_merge_time,
|
||||
add_decode_time=None if external_bookkeeping else self._add_decode_time,
|
||||
enqueue_finished=None if external_bookkeeping else self._enqueue_finalize_finished,
|
||||
finalize_error=None if external_bookkeeping else self._finalize_error,
|
||||
)
|
||||
self._record_decode_runtime_cycle(result)
|
||||
return result
|
||||
|
||||
def run_prefill_merge_once_nonblocking(
|
||||
self,
|
||||
external_pending_jobs: Optional[List[SchedulerPendingJob]] = None,
|
||||
external_active_batch: Optional[T2SActiveBatch] = None,
|
||||
emit_runtime_state: bool = True,
|
||||
external_bookkeeping: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
result = self.decode_legacy_shell.run_prefill_merge_once_nonblocking(
|
||||
external_pending_jobs=external_pending_jobs,
|
||||
external_active_batch=external_active_batch,
|
||||
execute_prefill_merge=lambda batch_jobs, batch_state: self.execute_prefill_merge(
|
||||
pending_jobs=batch_jobs,
|
||||
active_batch=batch_state,
|
||||
external_bookkeeping=external_bookkeeping,
|
||||
),
|
||||
)
|
||||
if emit_runtime_state:
|
||||
self._notify_decode_runtime_state("prefill_merge")
|
||||
return result
|
||||
|
||||
def run_decode_step_once_nonblocking(
|
||||
self,
|
||||
external_active_batch: Optional[T2SActiveBatch] = None,
|
||||
emit_runtime_state: bool = True,
|
||||
external_bookkeeping: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
result = self.decode_legacy_shell.run_decode_step_once_nonblocking(
|
||||
external_active_batch=external_active_batch,
|
||||
execute_decode_step=lambda batch_state: self.execute_decode_step(
|
||||
active_batch=batch_state,
|
||||
external_bookkeeping=external_bookkeeping,
|
||||
),
|
||||
)
|
||||
if emit_runtime_state:
|
||||
self._notify_decode_runtime_state("decode_step")
|
||||
return result
|
||||
|
||||
def run_decode_cycle_nonblocking(
|
||||
self,
|
||||
external_pending_jobs: Optional[List[SchedulerPendingJob]] = None,
|
||||
external_active_batch: Optional[T2SActiveBatch] = None,
|
||||
emit_runtime_state: bool = True,
|
||||
external_bookkeeping: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
result = self.decode_legacy_shell.run_decode_cycle_nonblocking(
|
||||
external_pending_jobs=external_pending_jobs,
|
||||
external_active_batch=external_active_batch,
|
||||
execute_decode_cycle=lambda batch_jobs, batch_state: self.execute_decode_cycle(
|
||||
pending_jobs=batch_jobs,
|
||||
active_batch=batch_state,
|
||||
external_bookkeeping=external_bookkeeping,
|
||||
),
|
||||
on_cycle_executed=None,
|
||||
)
|
||||
if result.get("executed") and emit_runtime_state:
|
||||
self._notify_decode_runtime_state("decode_cycle")
|
||||
return result
|
||||
|
||||
def execute_finalize_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None:
|
||||
if not tasks:
|
||||
return
|
||||
try:
|
||||
jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]] = []
|
||||
with self.condition:
|
||||
for task in tasks:
|
||||
job = self.job_registry.get(task.request_id)
|
||||
if job is None:
|
||||
continue
|
||||
jobs_and_items.append((job, task.item))
|
||||
if not jobs_and_items:
|
||||
return
|
||||
now = time.perf_counter()
|
||||
for task in tasks:
|
||||
self._add_finalize_wait_ms([task.request_id], max(0.0, (now - task.enqueued_time) * 1000.0))
|
||||
for job, item in jobs_and_items:
|
||||
self._runtime_update(
|
||||
job.engine_request_id,
|
||||
EngineStatus.FINALIZING,
|
||||
{
|
||||
"finish_reason": item.finish_reason,
|
||||
"semantic_len": int(item.semantic_tokens.shape[0]),
|
||||
},
|
||||
)
|
||||
synth_ms, batch_results = self.synthesize_finalize_jobs(jobs_and_items)
|
||||
with self.condition:
|
||||
for job, _ in jobs_and_items:
|
||||
tracked_job = self.job_registry.get(job.request_id)
|
||||
if tracked_job is not None:
|
||||
tracked_job.synth_ms += synth_ms
|
||||
for (job, item), (sample_rate, audio_data) in zip(jobs_and_items, batch_results):
|
||||
self._complete_finalize_task(job, item, sample_rate=sample_rate, audio_data=audio_data)
|
||||
except Exception as exc:
|
||||
self._finalize_error([task.request_id for task in tasks], str(exc))
|
||||
finally:
|
||||
self.finalize_executor.end_execution(len(tasks))
|
||||
|
||||
def _run_finalize_loop(self) -> None:
|
||||
while True:
|
||||
tasks = self.finalize_executor.take_task_batch_blocking()
|
||||
self.execute_finalize_tasks(tasks)
|
||||
|
||||
def _run_loop(self) -> None:
|
||||
self.decode_legacy_shell.run_loop(
|
||||
run_decode_cycle_nonblocking=lambda: self.run_decode_cycle_nonblocking()
|
||||
)
|
||||
251
GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py
Normal file
251
GPT_SoVITS/TTS_infer_pack/unified_engine_worker_finalize.py
Normal file
@ -0,0 +1,251 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Any, Callable, Deque, Dict, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import SchedulerFinalizeTask, SchedulerPendingJob
|
||||
|
||||
|
||||
class WorkerFinalizeExecutor:
|
||||
def __init__(
|
||||
self,
|
||||
tts: TTS,
|
||||
on_state_change: Callable[[], None] | None = None,
|
||||
external_submit: Callable[[List[SchedulerFinalizeTask]], None] | None = None,
|
||||
) -> None:
|
||||
self.tts = tts
|
||||
self.on_state_change = on_state_change
|
||||
self.external_submit = external_submit
|
||||
self.condition = threading.Condition()
|
||||
self.pending_tasks: Deque[SchedulerFinalizeTask] = deque()
|
||||
self.pending_peak = 0
|
||||
self.inflight = 0
|
||||
self.inflight_peak = 0
|
||||
self.worker_count = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_WORKERS", 1)))
|
||||
self.finalize_mode = os.environ.get("GPTSOVITS_FINALIZE_MODE", "async").strip().lower()
|
||||
self.batch_max_items = max(1, int(os.environ.get("GPTSOVITS_FINALIZE_BATCH_MAX_ITEMS", 16)))
|
||||
self.batch_wait_s = max(0.0, float(os.environ.get("GPTSOVITS_FINALIZE_BATCH_WAIT_MS", "2")) / 1000.0)
|
||||
|
||||
def _notify_state_change(self) -> None:
|
||||
if self.on_state_change is None:
|
||||
return
|
||||
try:
|
||||
self.on_state_change()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def get_worker_count(self) -> int:
|
||||
return int(self.worker_count)
|
||||
|
||||
def get_batch_policy(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"finalize_mode": str(self.finalize_mode),
|
||||
"finalize_batch_max_items": int(self.batch_max_items),
|
||||
"finalize_batch_wait_s": float(self.batch_wait_s),
|
||||
}
|
||||
|
||||
def get_pending_count(self) -> int:
|
||||
with self.condition:
|
||||
return int(len(self.pending_tasks))
|
||||
|
||||
def snapshot(self) -> Dict[str, Any]:
|
||||
with self.condition:
|
||||
return {
|
||||
"finalize_pending": int(len(self.pending_tasks)),
|
||||
"finalize_pending_peak": int(self.pending_peak),
|
||||
"finalize_inflight": int(self.inflight),
|
||||
"finalize_inflight_peak": int(self.inflight_peak),
|
||||
"finalize_workers": int(self.worker_count),
|
||||
"finalize_mode": str(self.finalize_mode),
|
||||
"finalize_batch_max_items": int(self.batch_max_items),
|
||||
"finalize_batch_wait_ms": float(self.batch_wait_s * 1000.0),
|
||||
}
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
with self.condition:
|
||||
return self.inflight <= 0 and not self.pending_tasks
|
||||
|
||||
def enqueue_tasks(self, tasks: List[SchedulerFinalizeTask]) -> None:
|
||||
if not tasks:
|
||||
return
|
||||
if self.external_submit is not None:
|
||||
self.external_submit(tasks)
|
||||
self._notify_state_change()
|
||||
return
|
||||
with self.condition:
|
||||
for task in tasks:
|
||||
self.pending_tasks.append(task)
|
||||
self.pending_peak = max(self.pending_peak, len(self.pending_tasks))
|
||||
self.condition.notify_all()
|
||||
self._notify_state_change()
|
||||
|
||||
def begin_execution(self, task_count: int) -> None:
|
||||
if task_count <= 0:
|
||||
return
|
||||
with self.condition:
|
||||
self.inflight += int(task_count)
|
||||
self.inflight_peak = max(self.inflight_peak, self.inflight)
|
||||
self.condition.notify_all()
|
||||
self._notify_state_change()
|
||||
|
||||
def end_execution(self, task_count: int) -> None:
|
||||
with self.condition:
|
||||
self.inflight = max(0, self.inflight - int(task_count))
|
||||
self.condition.notify_all()
|
||||
self._notify_state_change()
|
||||
|
||||
def take_task_batch_blocking(self) -> List[SchedulerFinalizeTask]:
|
||||
with self.condition:
|
||||
while not self.pending_tasks:
|
||||
self.condition.wait()
|
||||
selected_tasks = [self.pending_tasks.popleft()]
|
||||
if self.finalize_mode == "sync" or self.tts.configs.use_vocoder:
|
||||
self.inflight += len(selected_tasks)
|
||||
self.inflight_peak = max(self.inflight_peak, self.inflight)
|
||||
self._notify_state_change()
|
||||
return selected_tasks
|
||||
batch_deadline = time.perf_counter() + self.batch_wait_s
|
||||
while len(selected_tasks) < self.batch_max_items:
|
||||
if not self.pending_tasks:
|
||||
remaining = batch_deadline - time.perf_counter()
|
||||
if remaining <= 0:
|
||||
break
|
||||
self.condition.wait(timeout=remaining)
|
||||
continue
|
||||
first_task = selected_tasks[0]
|
||||
matched_index = None
|
||||
for index, task in enumerate(self.pending_tasks):
|
||||
if abs(task.enqueued_time - first_task.enqueued_time) < 1.0:
|
||||
matched_index = index
|
||||
break
|
||||
if matched_index is not None:
|
||||
selected_tasks.append(self.pending_tasks[matched_index])
|
||||
del self.pending_tasks[matched_index]
|
||||
continue
|
||||
remaining = batch_deadline - time.perf_counter()
|
||||
if remaining <= 0:
|
||||
break
|
||||
self.condition.wait(timeout=remaining)
|
||||
self.inflight += len(selected_tasks)
|
||||
self.inflight_peak = max(self.inflight_peak, self.inflight)
|
||||
self._notify_state_change()
|
||||
return selected_tasks
|
||||
|
||||
def _sync_device(self) -> None:
|
||||
try:
|
||||
device_str = str(self.tts.configs.device)
|
||||
if device_str.startswith("cuda") and torch.cuda.is_available():
|
||||
torch.cuda.synchronize(self.tts.configs.device)
|
||||
elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"):
|
||||
torch.mps.synchronize()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _collect_job_refer_specs(job: SchedulerPendingJob) -> List[tuple]:
|
||||
refer_specs = []
|
||||
if job.state.refer_spec is not None:
|
||||
refer_specs.append(job.state.refer_spec)
|
||||
refer_specs.extend(list(getattr(job.state, "aux_refer_specs", []) or []))
|
||||
return refer_specs
|
||||
|
||||
def _synthesize_finished_audio(self, job: SchedulerPendingJob, item: T2SFinishedItem) -> tuple[int, np.ndarray]:
|
||||
audio_fragment = self.tts.synthesize_audio_request_local(
|
||||
semantic_tokens=item.semantic_tokens.detach().clone().unsqueeze(0).unsqueeze(0),
|
||||
phones=job.state.phones.detach().clone().unsqueeze(0),
|
||||
prompt_semantic=job.state.prompt_semantic.detach().clone(),
|
||||
prompt_phones=job.state.prompt_phones.detach().clone(),
|
||||
refer_spec=[
|
||||
(
|
||||
refer_spec_item[0].detach().clone(),
|
||||
None if refer_spec_item[1] is None else refer_spec_item[1].detach().clone(),
|
||||
)
|
||||
for refer_spec_item in self._collect_job_refer_specs(job)
|
||||
],
|
||||
raw_audio=job.state.raw_audio.detach().clone(),
|
||||
raw_sr=int(job.state.raw_sr),
|
||||
speed=float(job.speed_factor),
|
||||
sample_steps=int(job.sample_steps),
|
||||
)
|
||||
output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"]
|
||||
return self.tts.audio_postprocess(
|
||||
audio=[[audio_fragment]],
|
||||
sr=int(output_sr),
|
||||
batch_index_list=None,
|
||||
speed_factor=float(job.speed_factor),
|
||||
split_bucket=False,
|
||||
fragment_interval=0.0,
|
||||
super_sampling=bool(job.super_sampling),
|
||||
)
|
||||
|
||||
def _synthesize_finished_audio_batch(
|
||||
self,
|
||||
jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]],
|
||||
) -> List[tuple[int, np.ndarray]]:
|
||||
semantic_tokens_list = [item.semantic_tokens.detach().clone() for _, item in jobs_and_items]
|
||||
phones_list = [job.state.phones.detach().clone() for job, _ in jobs_and_items]
|
||||
refer_specs = []
|
||||
speeds = []
|
||||
sample_steps_list = []
|
||||
for job, _ in jobs_and_items:
|
||||
refer_spec_group = self._collect_job_refer_specs(job)
|
||||
if len(refer_spec_group) != 1:
|
||||
raise ValueError("batched finalize 暂不支持单请求多参考音频")
|
||||
refer_specs.append(
|
||||
[(
|
||||
refer_spec_group[0][0].detach().clone(),
|
||||
None if refer_spec_group[0][1] is None else refer_spec_group[0][1].detach().clone(),
|
||||
)]
|
||||
)
|
||||
speeds.append(float(job.speed_factor))
|
||||
sample_steps_list.append(int(job.sample_steps))
|
||||
audio_fragments = self.tts.synthesize_audio_requests_local_batched(
|
||||
semantic_tokens_list=semantic_tokens_list,
|
||||
phones_list=phones_list,
|
||||
refer_specs=refer_specs,
|
||||
speeds=speeds,
|
||||
sample_steps_list=sample_steps_list,
|
||||
)
|
||||
output_sr = self.tts.configs.sampling_rate if not self.tts.configs.use_vocoder else self.tts.vocoder_configs["sr"]
|
||||
results: List[tuple[int, np.ndarray]] = []
|
||||
for (job, _), audio_fragment in zip(jobs_and_items, audio_fragments):
|
||||
results.append(
|
||||
self.tts.audio_postprocess(
|
||||
audio=[[audio_fragment]],
|
||||
sr=int(output_sr),
|
||||
batch_index_list=None,
|
||||
speed_factor=float(job.speed_factor),
|
||||
split_bucket=False,
|
||||
fragment_interval=0.0,
|
||||
super_sampling=bool(job.super_sampling),
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
def synthesize_finalize_jobs(
|
||||
self,
|
||||
jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]],
|
||||
) -> tuple[float, List[tuple[int, np.ndarray]]]:
|
||||
if not jobs_and_items:
|
||||
return 0.0, []
|
||||
self._sync_device()
|
||||
synth_start = time.perf_counter()
|
||||
if (
|
||||
len(jobs_and_items) == 1
|
||||
or self.tts.configs.use_vocoder
|
||||
or any(len(self._collect_job_refer_specs(job)) != 1 for job, _ in jobs_and_items)
|
||||
):
|
||||
batch_results = [self._synthesize_finished_audio(job, item) for job, item in jobs_and_items]
|
||||
else:
|
||||
batch_results = self._synthesize_finished_audio_batch(jobs_and_items)
|
||||
self._sync_device()
|
||||
synth_ms = (time.perf_counter() - synth_start) * 1000.0
|
||||
return float(synth_ms), batch_results
|
||||
140
GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py
Normal file
140
GPT_SoVITS/TTS_infer_pack/unified_engine_worker_prepare.py
Normal file
@ -0,0 +1,140 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS
|
||||
from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PrepareCoordinator, PreparedCpuStage
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SRequestState
|
||||
|
||||
|
||||
class WorkerPrepareExecutor:
|
||||
def __init__(
|
||||
self,
|
||||
tts: TTS,
|
||||
on_state_change: Callable[[], None] | None = None,
|
||||
) -> None:
|
||||
self.coordinator = PrepareCoordinator(tts)
|
||||
self.on_state_change = on_state_change
|
||||
|
||||
def _notify_state_change(self) -> None:
|
||||
if self.on_state_change is None:
|
||||
return
|
||||
try:
|
||||
self.on_state_change()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def snapshot(self) -> Dict[str, int]:
|
||||
return dict(self.coordinator.snapshot())
|
||||
|
||||
def get_max_inflight(self) -> int:
|
||||
return int(self.coordinator.snapshot().get("max_inflight", 0))
|
||||
|
||||
def get_batch_policy(self) -> Dict[str, int]:
|
||||
return {
|
||||
"prepare_batch_max_items": max(1, int(os.environ.get("GPTSOVITS_ENGINE_PREPARE_BATCH_MAX_ITEMS", 8))),
|
||||
}
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
return int(self.coordinator.snapshot().get("inflight", 0)) <= 0
|
||||
|
||||
async def prepare_state_profiled_async(
|
||||
self,
|
||||
spec: SchedulerRequestSpec,
|
||||
prepare_submit_at: float,
|
||||
) -> tuple[T2SRequestState, float, float]:
|
||||
try:
|
||||
return await self.coordinator.prepare_state_profiled_async(spec, prepare_submit_at)
|
||||
finally:
|
||||
self._notify_state_change()
|
||||
|
||||
async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]:
|
||||
results = await asyncio.gather(
|
||||
*[self.prepare_state_profiled_async(spec, time.perf_counter()) for spec in specs]
|
||||
)
|
||||
return [state for state, _, _ in results]
|
||||
|
||||
async def prepare_cpu_stage_profiled_async(
|
||||
self,
|
||||
spec: SchedulerRequestSpec,
|
||||
prepare_submit_at: float,
|
||||
) -> PreparedCpuStage:
|
||||
try:
|
||||
return await self.coordinator.prepare_cpu_stage_profiled_async(spec, prepare_submit_at)
|
||||
finally:
|
||||
self._notify_state_change()
|
||||
|
||||
async def prepare_gpu_stage_profiled_async(
|
||||
self,
|
||||
cpu_stage: PreparedCpuStage,
|
||||
) -> tuple[T2SRequestState, float, float]:
|
||||
try:
|
||||
return await self.coordinator.prepare_gpu_stage_profiled_async(cpu_stage)
|
||||
finally:
|
||||
self._notify_state_change()
|
||||
|
||||
async def prepare_gpu_stages_profiled_async(
|
||||
self,
|
||||
cpu_stages: List[PreparedCpuStage],
|
||||
) -> List[tuple[T2SRequestState, float, float] | Exception]:
|
||||
try:
|
||||
return await self.coordinator.prepare_gpu_stages_profiled_async(cpu_stages)
|
||||
finally:
|
||||
self._notify_state_change()
|
||||
|
||||
async def prepare_gpu_audio_phases_async(
|
||||
self,
|
||||
cpu_stages: List[PreparedCpuStage],
|
||||
) -> List[Dict[str, Any] | Exception]:
|
||||
try:
|
||||
return await self.coordinator.prepare_gpu_audio_phases_async(cpu_stages)
|
||||
finally:
|
||||
self._notify_state_change()
|
||||
|
||||
async def prepare_gpu_text_phases_async(
|
||||
self,
|
||||
items: List[tuple[PreparedCpuStage, Dict[str, Any]]],
|
||||
) -> List[Dict[str, Any] | Exception]:
|
||||
try:
|
||||
return await self.coordinator.prepare_gpu_text_phases_async(items)
|
||||
finally:
|
||||
self._notify_state_change()
|
||||
|
||||
def build_gpu_prepare_result_from_phases(
|
||||
self,
|
||||
cpu_stage: PreparedCpuStage,
|
||||
phase_one: Dict[str, Any],
|
||||
phase_two: Dict[str, Any],
|
||||
extra_profile: Dict[str, float] | None = None,
|
||||
) -> tuple[T2SRequestState, float, float]:
|
||||
try:
|
||||
return self.coordinator.build_gpu_prepare_result_from_phases(
|
||||
cpu_stage,
|
||||
phase_one,
|
||||
phase_two,
|
||||
extra_profile=extra_profile,
|
||||
)
|
||||
finally:
|
||||
self._notify_state_change()
|
||||
|
||||
async def prepare_ref_spec_stages_async(
|
||||
self,
|
||||
phase_ones: List[Dict[str, Any]],
|
||||
) -> List[tuple[tuple[Any, Any], Dict[str, float]] | Exception]:
|
||||
try:
|
||||
return await self.coordinator.prepare_ref_spec_stages_async(phase_ones)
|
||||
finally:
|
||||
self._notify_state_change()
|
||||
|
||||
def apply_ref_spec_result_to_state(
|
||||
self,
|
||||
state: T2SRequestState,
|
||||
ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]],
|
||||
) -> None:
|
||||
try:
|
||||
self.coordinator.apply_ref_spec_result_to_state(state, ref_spec_result)
|
||||
finally:
|
||||
self._notify_state_change()
|
||||
170
GPT_SoVITS/TTS_infer_pack/unified_engine_worker_runtime.py
Normal file
170
GPT_SoVITS/TTS_infer_pack/unified_engine_worker_runtime.py
Normal file
@ -0,0 +1,170 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import T2SFinishedItem
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerFinalizeTask, SchedulerPendingJob
|
||||
|
||||
|
||||
class WorkerRuntimeBookkeepingMixin:
|
||||
def _mark_prefill_started(self, pending_jobs: List[SchedulerPendingJob], started_at: float) -> None:
|
||||
with self.condition:
|
||||
for job in pending_jobs:
|
||||
job.first_schedule_time = float(started_at)
|
||||
self._runtime_update(
|
||||
job.engine_request_id,
|
||||
EngineStatus.GPU_PREPARING,
|
||||
{"scheduler_request_id": job.request_id, "prefill_started_at": float(started_at)},
|
||||
)
|
||||
|
||||
def _add_prefill_time(self, request_ids: List[str], elapsed_s: float) -> None:
|
||||
delta_ms = float(elapsed_s) * 1000.0
|
||||
if not request_ids:
|
||||
return
|
||||
with self.condition:
|
||||
for request_id in request_ids:
|
||||
job = self.job_registry.get(request_id)
|
||||
if job is not None:
|
||||
job.prefill_ms += delta_ms
|
||||
|
||||
def _add_merge_time(self, request_ids: List[str], elapsed_s: float) -> None:
|
||||
delta_ms = float(elapsed_s) * 1000.0
|
||||
if not request_ids:
|
||||
return
|
||||
with self.condition:
|
||||
for request_id in request_ids:
|
||||
job = self.job_registry.get(request_id)
|
||||
if job is not None:
|
||||
job.merge_ms += delta_ms
|
||||
|
||||
def _add_decode_time(self, request_ids: List[str], elapsed_s: float) -> None:
|
||||
delta_ms = float(elapsed_s) * 1000.0
|
||||
if not request_ids:
|
||||
return
|
||||
activate_request_ids: List[str] = []
|
||||
with self.condition:
|
||||
for request_id in request_ids:
|
||||
job = self.job_registry.get(request_id)
|
||||
if job is not None:
|
||||
if job.decode_steps == 0:
|
||||
activate_request_ids.append(job.engine_request_id)
|
||||
job.decode_ms += delta_ms
|
||||
job.decode_steps += 1
|
||||
for engine_request_id in activate_request_ids:
|
||||
self._runtime_update(engine_request_id, EngineStatus.ACTIVE_DECODE, None)
|
||||
|
||||
def _add_finalize_wait_ms(self, request_ids: List[str], delta_ms: float) -> None:
|
||||
if not request_ids:
|
||||
return
|
||||
with self.condition:
|
||||
for request_id in request_ids:
|
||||
job = self.job_registry.get(request_id)
|
||||
if job is not None:
|
||||
job.finalize_wait_ms += float(delta_ms)
|
||||
|
||||
def _enqueue_finalize_finished(self, items: List[T2SFinishedItem]) -> None:
|
||||
if not items:
|
||||
return
|
||||
enqueued_at = time.perf_counter()
|
||||
tasks: List[SchedulerFinalizeTask] = []
|
||||
with self.condition:
|
||||
for item in items:
|
||||
job = self.job_registry.get(item.request_id)
|
||||
if job is not None:
|
||||
self._runtime_update(
|
||||
job.engine_request_id,
|
||||
EngineStatus.READY_FOR_FINALIZE,
|
||||
{
|
||||
"finish_reason": item.finish_reason,
|
||||
"semantic_len": int(item.semantic_tokens.shape[0]),
|
||||
"finish_idx": int(item.finish_idx),
|
||||
},
|
||||
)
|
||||
tasks.append(SchedulerFinalizeTask(request_id=item.request_id, item=item, enqueued_time=enqueued_at))
|
||||
self.finalize_executor.enqueue_tasks(tasks)
|
||||
|
||||
def begin_finalize_execution(self, task_count: int) -> None:
|
||||
self.finalize_executor.begin_execution(task_count)
|
||||
|
||||
def end_finalize_execution(self, task_count: int) -> None:
|
||||
self.finalize_executor.end_execution(task_count)
|
||||
|
||||
def record_external_job_done(self, request_id: str) -> None:
|
||||
with self.condition:
|
||||
self.job_registry.mark_finished_and_remove(request_id)
|
||||
self.condition.notify_all()
|
||||
|
||||
def synthesize_finalize_jobs(
|
||||
self,
|
||||
jobs_and_items: List[tuple[SchedulerPendingJob, T2SFinishedItem]],
|
||||
) -> tuple[float, List[tuple[int, np.ndarray]]]:
|
||||
return self.finalize_executor.synthesize_finalize_jobs(jobs_and_items)
|
||||
|
||||
def _complete_finalize_task(self, job: SchedulerPendingJob, item: T2SFinishedItem, sample_rate: int, audio_data: np.ndarray) -> None:
|
||||
self.completion_bridge.complete_finalize_task(
|
||||
condition=self.condition,
|
||||
job_registry=self.job_registry,
|
||||
job=job,
|
||||
item=item,
|
||||
sample_rate=sample_rate,
|
||||
audio_data=audio_data,
|
||||
)
|
||||
|
||||
def _finalize_error(self, request_ids: List[str], error: str) -> None:
|
||||
self.completion_bridge.fail_jobs(
|
||||
condition=self.condition,
|
||||
job_registry=self.job_registry,
|
||||
request_ids=request_ids,
|
||||
error=error,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_done_future(job: SchedulerPendingJob) -> None:
|
||||
future = job.done_future
|
||||
if future is None or future.done():
|
||||
return
|
||||
future.set_result(job)
|
||||
|
||||
def _notify_done_future(self, job: SchedulerPendingJob) -> None:
|
||||
self.completion_bridge.notify_done_future(job)
|
||||
|
||||
def _runtime_update(self, request_id: str | None, status: str, extra: Optional[Dict[str, Any]] = None) -> None:
|
||||
if request_id is None or self.runtime_callbacks.update is None:
|
||||
return
|
||||
self.runtime_callbacks.update(request_id, status, extra)
|
||||
|
||||
def _runtime_complete(self, request_id: str | None, extra: Optional[Dict[str, Any]] = None) -> None:
|
||||
self.completion_bridge.runtime_complete(request_id, extra)
|
||||
|
||||
def _runtime_fail(self, request_id: str | None, error: str) -> None:
|
||||
self.completion_bridge.runtime_fail(request_id, error)
|
||||
|
||||
def _build_decode_runtime_summary_locked(self, last_event: str) -> Dict[str, Any]:
|
||||
return self.decode_runtime_tracker.build_runtime_summary_locked(
|
||||
legacy_shell=self.decode_legacy_shell,
|
||||
last_event=str(last_event),
|
||||
)
|
||||
|
||||
def _notify_decode_runtime_state(self, last_event: str) -> None:
|
||||
with self.condition:
|
||||
self.decode_runtime_tracker.notify_runtime_update_locked(
|
||||
legacy_shell=self.decode_legacy_shell,
|
||||
last_event=str(last_event),
|
||||
)
|
||||
|
||||
def _record_decode_runtime_cycle(self, result: Dict[str, Any]) -> None:
|
||||
with self.condition:
|
||||
self.decode_runtime_tracker.record_cycle(result)
|
||||
|
||||
def _take_pending_snapshot(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
|
||||
return self.decode_legacy_shell.take_pending_snapshot(wait_for_batch)
|
||||
|
||||
def _take_pending_snapshot_nonblocking(self, wait_for_batch: bool) -> List[SchedulerPendingJob]:
|
||||
return self.decode_legacy_shell.take_pending_snapshot_nonblocking(wait_for_batch)
|
||||
|
||||
def has_decode_runtime_work(self) -> bool:
|
||||
return self.decode_legacy_shell.has_decode_runtime_work()
|
||||
308
GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py
Normal file
308
GPT_SoVITS/TTS_infer_pack/unified_engine_worker_submit.py
Normal file
@ -0,0 +1,308 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.prepare_coordinator import PreparedCpuStage
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import SchedulerRequestSpec, T2SRequestState
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine_components import EngineStatus, SchedulerPendingJob
|
||||
|
||||
|
||||
class WorkerSubmitLifecycleMixin:
|
||||
def _current_decode_backlog_locked(self) -> int:
|
||||
return self.decode_legacy_shell.current_backlog_locked()
|
||||
|
||||
def get_micro_batch_wait_s(self) -> float:
|
||||
return float(self.micro_batch_wait_s)
|
||||
|
||||
def is_engine_decode_control_enabled(self) -> bool:
|
||||
return bool(self.engine_decode_control_enabled)
|
||||
|
||||
def get_prepare_max_inflight(self) -> int:
|
||||
return int(self.prepare_executor.get_max_inflight())
|
||||
|
||||
def get_capacity_limits(self) -> Dict[str, int]:
|
||||
return {
|
||||
"decode_backlog_max": int(self.decode_backlog_max),
|
||||
"finalize_pending_max": int(self.finalize_pending_max),
|
||||
}
|
||||
|
||||
def get_finalize_batch_policy(self) -> Dict[str, Any]:
|
||||
return dict(self.finalize_executor.get_batch_policy())
|
||||
|
||||
def get_prepare_batch_policy(self) -> Dict[str, int]:
|
||||
return dict(self.prepare_executor.get_batch_policy())
|
||||
|
||||
def get_decode_runtime_counters(self) -> Dict[str, int]:
|
||||
with self.condition:
|
||||
return self.decode_runtime_tracker.get_counters()
|
||||
|
||||
def _can_accept_submit_locked(self) -> tuple[bool, Dict[str, int]]:
|
||||
decode_backlog = self._current_decode_backlog_locked()
|
||||
finalize_pending = int(self.finalize_executor.get_pending_count())
|
||||
prepare_inflight = int(self.prepare_executor.snapshot()["inflight"])
|
||||
blocked_decode = self.decode_backlog_max > 0 and decode_backlog >= self.decode_backlog_max
|
||||
blocked_finalize = self.finalize_pending_max > 0 and finalize_pending >= self.finalize_pending_max
|
||||
return (
|
||||
not blocked_decode and not blocked_finalize,
|
||||
{
|
||||
"decode_backlog": decode_backlog,
|
||||
"finalize_pending": finalize_pending,
|
||||
"prepare_inflight": prepare_inflight,
|
||||
"decode_backlog_max": int(self.decode_backlog_max),
|
||||
"finalize_pending_max": int(self.finalize_pending_max),
|
||||
},
|
||||
)
|
||||
|
||||
def wait_for_submit_capacity_blocking(self, timeout_sec: float | None = None) -> tuple[float, Dict[str, int]]:
|
||||
start = time.perf_counter()
|
||||
deadline = None if timeout_sec in [None, ""] else (start + max(0.0, float(timeout_sec)))
|
||||
while True:
|
||||
with self.condition:
|
||||
allowed, snapshot = self._can_accept_submit_locked()
|
||||
if allowed:
|
||||
return max(0.0, (time.perf_counter() - start) * 1000.0), snapshot
|
||||
if deadline is not None and time.perf_counter() >= deadline:
|
||||
raise TimeoutError(
|
||||
"scheduler submit admission timeout "
|
||||
f"(decode_backlog={snapshot['decode_backlog']}, finalize_pending={snapshot['finalize_pending']})"
|
||||
)
|
||||
self.condition.wait(timeout=self.micro_batch_wait_s)
|
||||
|
||||
def _admission_snapshot_locked(self) -> Dict[str, int]:
|
||||
_, snapshot = self._can_accept_submit_locked()
|
||||
return snapshot
|
||||
|
||||
async def submit_async(
|
||||
self,
|
||||
state: T2SRequestState,
|
||||
speed_factor: float,
|
||||
sample_steps: int,
|
||||
media_type: str,
|
||||
super_sampling: bool,
|
||||
prepare_wall_ms: float,
|
||||
prepare_profile_total_ms: float,
|
||||
done_loop: asyncio.AbstractEventLoop | None = None,
|
||||
done_future: asyncio.Future | None = None,
|
||||
engine_request_id: str | None = None,
|
||||
timeout_sec: float | None = None,
|
||||
skip_capacity_wait: bool = False,
|
||||
admission_wait_ms_override: float | None = None,
|
||||
admission_snapshot_override: Dict[str, Any] | None = None,
|
||||
engine_policy_wait_ms: float = 0.0,
|
||||
engine_dispatch_wait_ms: float = 0.0,
|
||||
enqueue_pending: bool = True,
|
||||
) -> SchedulerPendingJob:
|
||||
return await asyncio.to_thread(
|
||||
self.submit,
|
||||
state,
|
||||
speed_factor,
|
||||
sample_steps,
|
||||
media_type,
|
||||
super_sampling,
|
||||
prepare_wall_ms,
|
||||
prepare_profile_total_ms,
|
||||
done_loop,
|
||||
done_future,
|
||||
engine_request_id,
|
||||
timeout_sec,
|
||||
skip_capacity_wait,
|
||||
admission_wait_ms_override,
|
||||
admission_snapshot_override,
|
||||
engine_policy_wait_ms,
|
||||
engine_dispatch_wait_ms,
|
||||
enqueue_pending,
|
||||
)
|
||||
|
||||
def snapshot(self) -> dict:
|
||||
with self.condition:
|
||||
prepare_state = self.prepare_executor.snapshot()
|
||||
finalize_state = self.finalize_executor.snapshot()
|
||||
shell_state = self.decode_legacy_shell.snapshot_locked()
|
||||
decode_runtime_counters = self.decode_runtime_tracker.get_counters()
|
||||
engine_owned_decode_state = bool(self.engine_decode_control_enabled)
|
||||
active_batch_summary = shell_state.get("executor_local_active_batch")
|
||||
executor_local_pending_jobs = int(shell_state.get("executor_local_pending_jobs", 0))
|
||||
executor_local_running_requests = int(shell_state.get("executor_local_running_requests", 0))
|
||||
executor_local_has_work = bool(shell_state.get("executor_local_has_work", False))
|
||||
return {
|
||||
"pending_jobs": 0 if engine_owned_decode_state else executor_local_pending_jobs,
|
||||
"running_requests": 0 if engine_owned_decode_state else executor_local_running_requests,
|
||||
"engine_decode_control_enabled": bool(self.engine_decode_control_enabled),
|
||||
"legacy_state_owner_mode": not engine_owned_decode_state,
|
||||
"decode_state_owner": "engine" if engine_owned_decode_state else "worker",
|
||||
"decode_runtime_has_work": False if engine_owned_decode_state else executor_local_has_work,
|
||||
"executor_local_pending_jobs": executor_local_pending_jobs,
|
||||
"executor_local_running_requests": executor_local_running_requests,
|
||||
"executor_local_has_work": executor_local_has_work,
|
||||
"decode_runtime_total_cycles": int(decode_runtime_counters.get("total_cycles", 0)),
|
||||
"decode_runtime_prefill_cycles": int(decode_runtime_counters.get("prefill_cycles", 0)),
|
||||
"decode_runtime_step_cycles": int(decode_runtime_counters.get("step_cycles", 0)),
|
||||
"prepare_inflight": prepare_state["inflight"],
|
||||
"prepare_peak_inflight": prepare_state["peak_inflight"],
|
||||
"prepare_max_inflight": prepare_state.get("max_inflight", 0),
|
||||
"prepare_state": dict(prepare_state),
|
||||
**finalize_state,
|
||||
"decode_backlog_max": self.decode_backlog_max,
|
||||
"finalize_pending_max": self.finalize_pending_max,
|
||||
"active_batch": {} if engine_owned_decode_state else active_batch_summary,
|
||||
"executor_local_active_batch": active_batch_summary if engine_owned_decode_state else None,
|
||||
"total_submitted": self.job_registry.submitted_count(),
|
||||
"total_finished": self.job_registry.finished_count(),
|
||||
"drained": self.is_drained(),
|
||||
}
|
||||
|
||||
def is_drained(self) -> bool:
|
||||
with self.condition:
|
||||
return (
|
||||
self.decode_legacy_shell.is_idle_locked()
|
||||
and self.job_registry.is_empty()
|
||||
and self.prepare_executor.is_idle()
|
||||
and self.finalize_executor.is_idle()
|
||||
)
|
||||
|
||||
def wait_until_idle(self, timeout_sec: float = 60.0, poll_interval_sec: float = 0.01) -> bool:
|
||||
deadline = time.perf_counter() + max(0.0, timeout_sec)
|
||||
while time.perf_counter() < deadline:
|
||||
if self.is_drained():
|
||||
return True
|
||||
time.sleep(poll_interval_sec)
|
||||
return self.is_drained()
|
||||
|
||||
def submit(
|
||||
self,
|
||||
state: T2SRequestState,
|
||||
speed_factor: float,
|
||||
sample_steps: int,
|
||||
media_type: str,
|
||||
super_sampling: bool,
|
||||
prepare_wall_ms: float,
|
||||
prepare_profile_total_ms: float,
|
||||
done_loop: asyncio.AbstractEventLoop | None = None,
|
||||
done_future: asyncio.Future | None = None,
|
||||
engine_request_id: str | None = None,
|
||||
timeout_sec: float | None = None,
|
||||
skip_capacity_wait: bool = False,
|
||||
admission_wait_ms_override: float | None = None,
|
||||
admission_snapshot_override: Dict[str, Any] | None = None,
|
||||
engine_policy_wait_ms: float = 0.0,
|
||||
engine_dispatch_wait_ms: float = 0.0,
|
||||
enqueue_pending: bool = True,
|
||||
) -> SchedulerPendingJob:
|
||||
if skip_capacity_wait:
|
||||
with self.condition:
|
||||
admission_snapshot = (
|
||||
dict(admission_snapshot_override)
|
||||
if admission_snapshot_override is not None
|
||||
else dict(self._admission_snapshot_locked())
|
||||
)
|
||||
admission_wait_ms = 0.0 if admission_wait_ms_override is None else float(admission_wait_ms_override)
|
||||
else:
|
||||
admission_wait_ms, admission_snapshot = self.wait_for_submit_capacity_blocking(timeout_sec=timeout_sec)
|
||||
job = SchedulerPendingJob(
|
||||
request_id=state.request_id,
|
||||
state=state,
|
||||
done_event=threading.Event(),
|
||||
done_loop=done_loop,
|
||||
done_future=done_future,
|
||||
enqueue_time=time.perf_counter(),
|
||||
speed_factor=float(speed_factor),
|
||||
sample_steps=int(sample_steps),
|
||||
media_type=media_type,
|
||||
super_sampling=bool(super_sampling),
|
||||
admission_wait_ms=float(admission_wait_ms),
|
||||
engine_policy_wait_ms=float(engine_policy_wait_ms),
|
||||
engine_dispatch_wait_ms=float(engine_dispatch_wait_ms),
|
||||
prepare_wall_ms=float(prepare_wall_ms),
|
||||
prepare_profile_total_ms=float(prepare_profile_total_ms),
|
||||
engine_request_id=engine_request_id or state.request_id,
|
||||
)
|
||||
with self.condition:
|
||||
self.job_registry.register(job, keep_job=not self.engine_decode_control_enabled)
|
||||
if enqueue_pending:
|
||||
self.decode_legacy_shell.enqueue_pending_job_locked(job)
|
||||
self.condition.notify_all()
|
||||
if enqueue_pending:
|
||||
self._notify_decode_runtime_state("submit")
|
||||
self._runtime_update(
|
||||
job.engine_request_id,
|
||||
EngineStatus.QUEUED,
|
||||
{
|
||||
"scheduler_request_id": job.request_id,
|
||||
"decode_admission_wait_ms": float(admission_wait_ms),
|
||||
"engine_policy_wait_ms": float(engine_policy_wait_ms),
|
||||
"engine_dispatch_wait_ms": float(engine_dispatch_wait_ms),
|
||||
"admission_snapshot": dict(admission_snapshot),
|
||||
},
|
||||
)
|
||||
return job
|
||||
|
||||
async def prepare_state_profiled_async(
|
||||
self,
|
||||
spec: SchedulerRequestSpec,
|
||||
prepare_submit_at: float,
|
||||
) -> tuple[T2SRequestState, float, float]:
|
||||
return await self.prepare_executor.prepare_state_profiled_async(spec, prepare_submit_at)
|
||||
|
||||
async def prepare_states_batch_async(self, specs: List[SchedulerRequestSpec]) -> List[T2SRequestState]:
|
||||
return await self.prepare_executor.prepare_states_batch_async(specs)
|
||||
|
||||
async def prepare_cpu_stage_profiled_async(
|
||||
self,
|
||||
spec: SchedulerRequestSpec,
|
||||
prepare_submit_at: float,
|
||||
) -> PreparedCpuStage:
|
||||
return await self.prepare_executor.prepare_cpu_stage_profiled_async(spec, prepare_submit_at)
|
||||
|
||||
async def prepare_gpu_stage_profiled_async(
|
||||
self,
|
||||
cpu_stage: PreparedCpuStage,
|
||||
) -> tuple[T2SRequestState, float, float]:
|
||||
return await self.prepare_executor.prepare_gpu_stage_profiled_async(cpu_stage)
|
||||
|
||||
async def prepare_gpu_stages_profiled_async(
|
||||
self,
|
||||
cpu_stages: List[PreparedCpuStage],
|
||||
) -> List[tuple[T2SRequestState, float, float] | Exception]:
|
||||
return await self.prepare_executor.prepare_gpu_stages_profiled_async(cpu_stages)
|
||||
|
||||
async def prepare_gpu_audio_phases_async(
|
||||
self,
|
||||
cpu_stages: List[PreparedCpuStage],
|
||||
) -> List[Dict[str, Any] | Exception]:
|
||||
return await self.prepare_executor.prepare_gpu_audio_phases_async(cpu_stages)
|
||||
|
||||
async def prepare_gpu_text_phases_async(
|
||||
self,
|
||||
items: List[tuple[PreparedCpuStage, Dict[str, Any]]],
|
||||
) -> List[Dict[str, Any] | Exception]:
|
||||
return await self.prepare_executor.prepare_gpu_text_phases_async(items)
|
||||
|
||||
def build_gpu_prepare_result_from_phases(
|
||||
self,
|
||||
cpu_stage: PreparedCpuStage,
|
||||
phase_one: Dict[str, Any],
|
||||
phase_two: Dict[str, Any],
|
||||
extra_profile: Dict[str, float] | None = None,
|
||||
) -> tuple[T2SRequestState, float, float]:
|
||||
return self.prepare_executor.build_gpu_prepare_result_from_phases(
|
||||
cpu_stage,
|
||||
phase_one,
|
||||
phase_two,
|
||||
extra_profile=extra_profile,
|
||||
)
|
||||
|
||||
async def prepare_ref_spec_stages_async(
|
||||
self,
|
||||
phase_ones: List[Dict[str, Any]],
|
||||
) -> List[tuple[tuple[Any, Any], Dict[str, float]] | Exception]:
|
||||
return await self.prepare_executor.prepare_ref_spec_stages_async(phase_ones)
|
||||
|
||||
def apply_ref_spec_result_to_state(
|
||||
self,
|
||||
state: T2SRequestState,
|
||||
ref_spec_result: tuple[tuple[Any, Any], Dict[str, float]],
|
||||
) -> None:
|
||||
self.prepare_executor.apply_ref_spec_result_to_state(state, ref_spec_result)
|
||||
@ -2,6 +2,7 @@ import warnings
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -1038,6 +1039,67 @@ class SynthesizerTrn(nn.Module):
|
||||
o = self.dec((z * y_mask)[:, :, :], g=ge)
|
||||
return o
|
||||
|
||||
@torch.no_grad()
|
||||
def decode_batched_request_local(
|
||||
self,
|
||||
codes: torch.Tensor,
|
||||
code_lengths: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
refer_list: List[torch.Tensor],
|
||||
noise_scale: float = 0.5,
|
||||
speed: float = 1,
|
||||
sv_emb: torch.Tensor | None = None,
|
||||
):
|
||||
batch_size = int(codes.size(1))
|
||||
if batch_size <= 0:
|
||||
raise ValueError("decode_batched_request_local 收到空 batch")
|
||||
if len(refer_list) != batch_size:
|
||||
raise ValueError("refer_list 数量与 batch size 不一致")
|
||||
|
||||
refer_lengths = torch.LongTensor([int(item.size(2)) for item in refer_list]).to(codes.device)
|
||||
max_refer_len = int(refer_lengths.max().item())
|
||||
refer_batch = torch.zeros(
|
||||
(batch_size, int(refer_list[0].size(1)), max_refer_len),
|
||||
dtype=refer_list[0].dtype,
|
||||
device=codes.device,
|
||||
)
|
||||
for batch_index, refer in enumerate(refer_list):
|
||||
refer_batch[batch_index, :, : int(refer.size(2))] = refer.squeeze(0)
|
||||
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, max_refer_len), 1).to(refer_batch.dtype)
|
||||
if self.version == "v1":
|
||||
ge = self.ref_enc(refer_batch * refer_mask, refer_mask)
|
||||
else:
|
||||
ge = self.ref_enc(refer_batch[:, :704] * refer_mask, refer_mask)
|
||||
if self.is_v2pro:
|
||||
if sv_emb is None:
|
||||
raise ValueError("v2Pro batched request-local synthesis 缺少 sv_emb")
|
||||
ge = ge + self.sv_emb(sv_emb).unsqueeze(-1)
|
||||
ge = self.prelu(ge)
|
||||
|
||||
quantized = self.quantizer.decode(codes)
|
||||
if self.semantic_frame_rate == "25hz":
|
||||
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")
|
||||
y_lengths = code_lengths.to(device=codes.device, dtype=torch.long) * 2
|
||||
text_lengths = text_lengths.to(device=text.device, dtype=torch.long)
|
||||
x, m_p, logs_p, y_mask, _, _ = self.enc_p(
|
||||
quantized,
|
||||
y_lengths,
|
||||
text,
|
||||
text_lengths,
|
||||
self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
|
||||
speed,
|
||||
)
|
||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
||||
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
||||
audio = self.dec((z * y_mask)[:, :, :], g=ge)
|
||||
upsample_factor = 1
|
||||
for up_layer in self.dec.ups:
|
||||
stride = up_layer.stride[0] if isinstance(up_layer.stride, tuple) else int(up_layer.stride)
|
||||
upsample_factor *= int(stride)
|
||||
audio_lengths = y_mask.squeeze(1).sum(dim=1).to(dtype=torch.long) * int(upsample_factor)
|
||||
return audio, audio_lengths
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def decode_streaming(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None):
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
|
||||
import cn2an
|
||||
from pypinyin import lazy_pinyin, Style
|
||||
@ -77,6 +78,205 @@ def g2p(text):
|
||||
return phones, word2ph
|
||||
|
||||
|
||||
def _prepare_g2p_segments(segments):
|
||||
prepared_segments = []
|
||||
batch_inputs = []
|
||||
for segment in segments:
|
||||
processed_segment = re.sub("[a-zA-Z]+", "", segment)
|
||||
seg_cut = psg.lcut(processed_segment)
|
||||
seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
|
||||
prepared_segments.append(
|
||||
{
|
||||
"segment": processed_segment,
|
||||
"seg_cut": seg_cut,
|
||||
}
|
||||
)
|
||||
if processed_segment:
|
||||
batch_inputs.append(processed_segment)
|
||||
return prepared_segments, batch_inputs
|
||||
|
||||
|
||||
def _build_segment_from_g2pw(segment: str, seg_cut, pinyins):
|
||||
phones_list = []
|
||||
word2ph = []
|
||||
initials = []
|
||||
finals = []
|
||||
pre_word_length = 0
|
||||
for word, pos in seg_cut:
|
||||
sub_initials = []
|
||||
sub_finals = []
|
||||
now_word_length = pre_word_length + len(word)
|
||||
|
||||
if pos == "eng":
|
||||
pre_word_length = now_word_length
|
||||
continue
|
||||
|
||||
word_pinyins = pinyins[pre_word_length:now_word_length]
|
||||
word_pinyins = correct_pronunciation(word, word_pinyins)
|
||||
|
||||
for pinyin in word_pinyins:
|
||||
if pinyin[0].isalpha():
|
||||
sub_initials.append(to_initials(pinyin))
|
||||
sub_finals.append(to_finals_tone3(pinyin, neutral_tone_with_five=True))
|
||||
else:
|
||||
sub_initials.append(pinyin)
|
||||
sub_finals.append(pinyin)
|
||||
|
||||
pre_word_length = now_word_length
|
||||
sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
|
||||
sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos)
|
||||
initials.append(sub_initials)
|
||||
finals.append(sub_finals)
|
||||
|
||||
initials = sum(initials, [])
|
||||
finals = sum(finals, [])
|
||||
for c, v in zip(initials, finals):
|
||||
raw_pinyin = c + v
|
||||
if c == v:
|
||||
assert c in punctuation
|
||||
phone = [c]
|
||||
word2ph.append(1)
|
||||
else:
|
||||
v_without_tone = v[:-1]
|
||||
tone = v[-1]
|
||||
|
||||
pinyin = c + v_without_tone
|
||||
assert tone in "12345"
|
||||
|
||||
if c:
|
||||
v_rep_map = {
|
||||
"uei": "ui",
|
||||
"iou": "iu",
|
||||
"uen": "un",
|
||||
}
|
||||
if v_without_tone in v_rep_map.keys():
|
||||
pinyin = c + v_rep_map[v_without_tone]
|
||||
else:
|
||||
pinyin_rep_map = {
|
||||
"ing": "ying",
|
||||
"i": "yi",
|
||||
"in": "yin",
|
||||
"u": "wu",
|
||||
}
|
||||
if pinyin in pinyin_rep_map.keys():
|
||||
pinyin = pinyin_rep_map[pinyin]
|
||||
else:
|
||||
single_rep_map = {
|
||||
"v": "yu",
|
||||
"e": "e",
|
||||
"i": "y",
|
||||
"u": "w",
|
||||
}
|
||||
if pinyin[0] in single_rep_map.keys():
|
||||
pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
|
||||
|
||||
assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, segment, raw_pinyin)
|
||||
new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ")
|
||||
new_v = new_v + tone
|
||||
phone = [new_c, new_v]
|
||||
word2ph.append(len(phone))
|
||||
|
||||
phones_list += phone
|
||||
return phones_list, word2ph
|
||||
|
||||
|
||||
def _build_segment_without_g2pw(segment: str, seg_cut):
|
||||
initials = []
|
||||
finals = []
|
||||
for word, pos in seg_cut:
|
||||
if pos == "eng":
|
||||
continue
|
||||
sub_initials, sub_finals = _get_initials_finals(word)
|
||||
sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
|
||||
sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos)
|
||||
initials.append(sub_initials)
|
||||
finals.append(sub_finals)
|
||||
phones_list = []
|
||||
word2ph = []
|
||||
for c, v in zip(sum(initials, []), sum(finals, [])):
|
||||
raw_pinyin = c + v
|
||||
if c == v:
|
||||
assert c in punctuation
|
||||
phone = [c]
|
||||
word2ph.append(1)
|
||||
else:
|
||||
v_without_tone = v[:-1]
|
||||
tone = v[-1]
|
||||
pinyin = c + v_without_tone
|
||||
assert tone in "12345"
|
||||
if c:
|
||||
v_rep_map = {"uei": "ui", "iou": "iu", "uen": "un"}
|
||||
if v_without_tone in v_rep_map:
|
||||
pinyin = c + v_rep_map[v_without_tone]
|
||||
else:
|
||||
pinyin_rep_map = {"ing": "ying", "i": "yi", "in": "yin", "u": "wu"}
|
||||
if pinyin in pinyin_rep_map:
|
||||
pinyin = pinyin_rep_map[pinyin]
|
||||
else:
|
||||
single_rep_map = {"v": "yu", "e": "e", "i": "y", "u": "w"}
|
||||
if pinyin[0] in single_rep_map:
|
||||
pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
|
||||
assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, segment, raw_pinyin)
|
||||
new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ")
|
||||
new_v = new_v + tone
|
||||
phone = [new_c, new_v]
|
||||
word2ph.append(len(phone))
|
||||
phones_list += phone
|
||||
return phones_list, word2ph
|
||||
|
||||
|
||||
def g2p_segments(segments, return_profile: bool = False):
|
||||
prepare_start = time.perf_counter()
|
||||
prepared_segments, batch_inputs = _prepare_g2p_segments(segments)
|
||||
profile = {
|
||||
"g2pw_prepare_ms": 0.0,
|
||||
"g2pw_predict_ms": 0.0,
|
||||
"g2pw_post_ms": 0.0,
|
||||
"g2pw_runtime_total_ms": 0.0,
|
||||
"g2pw_runtime_queue_wait_ms": 0.0,
|
||||
"g2pw_runtime_collect_wait_ms": 0.0,
|
||||
"g2pw_runtime_run_ms": 0.0,
|
||||
"g2pw_runtime_batch_rows": 0.0,
|
||||
"g2pw_runtime_batch_requests": 0.0,
|
||||
"g2pw_runtime_pool_workers": 0.0,
|
||||
"g2pw_runtime_shard_index": 0.0,
|
||||
}
|
||||
profile["g2pw_prepare_ms"] = float((time.perf_counter() - prepare_start) * 1000.0)
|
||||
if is_g2pw and batch_inputs:
|
||||
converter = g2pw._g2pw
|
||||
if hasattr(converter, "predict_sentences_with_profile"):
|
||||
g2pw_batch_results, predict_profile = converter.predict_sentences_with_profile(batch_inputs)
|
||||
for key, value in dict(predict_profile or {}).items():
|
||||
profile[key] = float(value)
|
||||
else:
|
||||
predict_start = time.perf_counter()
|
||||
g2pw_batch_results = converter(batch_inputs)
|
||||
profile["g2pw_predict_ms"] = float((time.perf_counter() - predict_start) * 1000.0)
|
||||
else:
|
||||
g2pw_batch_results = []
|
||||
post_start = time.perf_counter()
|
||||
results = []
|
||||
batch_cursor = 0
|
||||
for item in prepared_segments:
|
||||
segment = item["segment"]
|
||||
if not segment:
|
||||
results.append(([], [], segment))
|
||||
continue
|
||||
if not is_g2pw:
|
||||
phones, word2ph = _build_segment_without_g2pw(segment, item["seg_cut"])
|
||||
results.append((phones, word2ph, segment))
|
||||
continue
|
||||
pinyins = g2pw_batch_results[batch_cursor]
|
||||
batch_cursor += 1
|
||||
phones, word2ph = _build_segment_from_g2pw(segment, item["seg_cut"], pinyins)
|
||||
results.append((phones, word2ph, segment))
|
||||
profile["g2pw_post_ms"] = float((time.perf_counter() - post_start) * 1000.0)
|
||||
profile["g2pw_total_ms"] = float(profile["g2pw_prepare_ms"] + profile["g2pw_predict_ms"] + profile["g2pw_post_ms"])
|
||||
if return_profile:
|
||||
return results, profile
|
||||
return results
|
||||
|
||||
|
||||
def _get_initials_finals(word):
|
||||
initials = []
|
||||
finals = []
|
||||
@ -180,118 +380,9 @@ def _merge_erhua(initials: list[str], finals: list[str], word: str, pos: str) ->
|
||||
def _g2p(segments):
|
||||
phones_list = []
|
||||
word2ph = []
|
||||
for seg in segments:
|
||||
pinyins = []
|
||||
# Replace all English words in the sentence
|
||||
seg = re.sub("[a-zA-Z]+", "", seg)
|
||||
seg_cut = psg.lcut(seg)
|
||||
seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
|
||||
initials = []
|
||||
finals = []
|
||||
|
||||
if not is_g2pw:
|
||||
for word, pos in seg_cut:
|
||||
if pos == "eng":
|
||||
continue
|
||||
sub_initials, sub_finals = _get_initials_finals(word)
|
||||
sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
|
||||
# 儿化
|
||||
sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos)
|
||||
initials.append(sub_initials)
|
||||
finals.append(sub_finals)
|
||||
# assert len(sub_initials) == len(sub_finals) == len(word)
|
||||
initials = sum(initials, [])
|
||||
finals = sum(finals, [])
|
||||
print("pypinyin结果", initials, finals)
|
||||
else:
|
||||
# g2pw采用整句推理
|
||||
pinyins = g2pw.lazy_pinyin(seg, neutral_tone_with_five=True, style=Style.TONE3)
|
||||
|
||||
pre_word_length = 0
|
||||
for word, pos in seg_cut:
|
||||
sub_initials = []
|
||||
sub_finals = []
|
||||
now_word_length = pre_word_length + len(word)
|
||||
|
||||
if pos == "eng":
|
||||
pre_word_length = now_word_length
|
||||
continue
|
||||
|
||||
word_pinyins = pinyins[pre_word_length:now_word_length]
|
||||
|
||||
# 多音字消歧
|
||||
word_pinyins = correct_pronunciation(word, word_pinyins)
|
||||
|
||||
for pinyin in word_pinyins:
|
||||
if pinyin[0].isalpha():
|
||||
sub_initials.append(to_initials(pinyin))
|
||||
sub_finals.append(to_finals_tone3(pinyin, neutral_tone_with_five=True))
|
||||
else:
|
||||
sub_initials.append(pinyin)
|
||||
sub_finals.append(pinyin)
|
||||
|
||||
pre_word_length = now_word_length
|
||||
sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
|
||||
# 儿化
|
||||
sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos)
|
||||
initials.append(sub_initials)
|
||||
finals.append(sub_finals)
|
||||
|
||||
initials = sum(initials, [])
|
||||
finals = sum(finals, [])
|
||||
# print("g2pw结果",initials,finals)
|
||||
|
||||
for c, v in zip(initials, finals):
|
||||
raw_pinyin = c + v
|
||||
# NOTE: post process for pypinyin outputs
|
||||
# we discriminate i, ii and iii
|
||||
if c == v:
|
||||
assert c in punctuation
|
||||
phone = [c]
|
||||
word2ph.append(1)
|
||||
else:
|
||||
v_without_tone = v[:-1]
|
||||
tone = v[-1]
|
||||
|
||||
pinyin = c + v_without_tone
|
||||
assert tone in "12345"
|
||||
|
||||
if c:
|
||||
# 多音节
|
||||
v_rep_map = {
|
||||
"uei": "ui",
|
||||
"iou": "iu",
|
||||
"uen": "un",
|
||||
}
|
||||
if v_without_tone in v_rep_map.keys():
|
||||
pinyin = c + v_rep_map[v_without_tone]
|
||||
else:
|
||||
# 单音节
|
||||
pinyin_rep_map = {
|
||||
"ing": "ying",
|
||||
"i": "yi",
|
||||
"in": "yin",
|
||||
"u": "wu",
|
||||
}
|
||||
if pinyin in pinyin_rep_map.keys():
|
||||
pinyin = pinyin_rep_map[pinyin]
|
||||
else:
|
||||
single_rep_map = {
|
||||
"v": "yu",
|
||||
"e": "e",
|
||||
"i": "y",
|
||||
"u": "w",
|
||||
}
|
||||
if pinyin[0] in single_rep_map.keys():
|
||||
pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
|
||||
|
||||
assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
|
||||
new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ")
|
||||
new_v = new_v + tone
|
||||
phone = [new_c, new_v]
|
||||
word2ph.append(len(phone))
|
||||
|
||||
phones_list += phone
|
||||
for phones, item_word2ph, _segment in g2p_segments(segments):
|
||||
phones_list += phones
|
||||
word2ph += item_word2ph
|
||||
return phones_list, word2ph
|
||||
|
||||
|
||||
|
||||
685
GPT_SoVITS/text/g2pw/cuda_api.py
Normal file
685
GPT_SoVITS/text/g2pw/cuda_api.py
Normal file
@ -0,0 +1,685 @@
|
||||
import ctypes
|
||||
import fcntl
|
||||
import os
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Deque, Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .onnx_api import _G2PWBaseOnnxConverter
|
||||
|
||||
|
||||
class G2PWCudaError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class G2PWBatchTask:
|
||||
model_input: Dict[str, np.ndarray]
|
||||
created_at: float = field(default_factory=time.perf_counter)
|
||||
enqueued_at: float = 0.0
|
||||
done_event: threading.Event = field(default_factory=threading.Event)
|
||||
output: np.ndarray | None = None
|
||||
profile: Dict[str, float] = field(default_factory=dict)
|
||||
error: Exception | None = None
|
||||
|
||||
|
||||
_ROOT_DIR = Path(__file__).resolve().parents[3]
|
||||
_PACKAGE_DIR = Path(__file__).resolve().parent
|
||||
_OUTPUT_DIR = _ROOT_DIR / "outputs" / "g2pw_cuda_bridge"
|
||||
_WRAPPER_SOURCE = _PACKAGE_DIR / "g2pw_cuda_bridge.cpp"
|
||||
_LOCK_PATH = _OUTPUT_DIR / "build.lock"
|
||||
|
||||
|
||||
def _env_flag(name: str, default: bool) -> int:
|
||||
raw = os.environ.get(name)
|
||||
if raw is None:
|
||||
return 1 if default else 0
|
||||
return 0 if raw.strip().lower() in {"0", "false", "no", "off"} else 1
|
||||
|
||||
|
||||
def _env_int(name: str, default: int) -> int:
|
||||
raw = os.environ.get(name)
|
||||
if raw is None or raw.strip() == "":
|
||||
return int(default)
|
||||
return int(raw)
|
||||
|
||||
|
||||
def _resolve_cuda_root() -> Path:
|
||||
env_root = os.environ.get("GPTSOVITS_G2PW_CUDA_ROOT", "").strip()
|
||||
candidates = [
|
||||
env_root,
|
||||
_ROOT_DIR / "third_party" / "g2pw-cu",
|
||||
]
|
||||
for candidate in candidates:
|
||||
if not candidate:
|
||||
continue
|
||||
path = Path(candidate).expanduser().resolve()
|
||||
if path.exists():
|
||||
return path
|
||||
checked = [
|
||||
str(Path(candidate).expanduser().resolve())
|
||||
for candidate in candidates
|
||||
if str(candidate).strip() != ""
|
||||
]
|
||||
raise G2PWCudaError(
|
||||
"Cannot locate g2pw-cu root. "
|
||||
"Expected one of: "
|
||||
f"{checked}. "
|
||||
"Recommended: clone https://github.com/baicai-1145/g2pw-cu.git into "
|
||||
f"{(_ROOT_DIR / 'third_party' / 'g2pw-cu').as_posix()} "
|
||||
"or set GPTSOVITS_G2PW_CUDA_ROOT explicitly."
|
||||
)
|
||||
|
||||
|
||||
def _resolve_runtime_paths() -> tuple[Path, Path, Path]:
|
||||
cuda_root = _resolve_cuda_root()
|
||||
runtime_lib = Path(
|
||||
os.environ.get("GPTSOVITS_G2PW_CUDA_RUNTIME_LIB", str(cuda_root / "build" / "libg2pw_runtime.so"))
|
||||
).expanduser()
|
||||
manifest_path = Path(
|
||||
os.environ.get("GPTSOVITS_G2PW_CUDA_MANIFEST", str(cuda_root / "artifacts" / "model" / "manifest.txt"))
|
||||
).expanduser()
|
||||
weights_path = Path(
|
||||
os.environ.get("GPTSOVITS_G2PW_CUDA_WEIGHTS", str(cuda_root / "artifacts" / "model" / "weights.bin"))
|
||||
).expanduser()
|
||||
for path in (runtime_lib, manifest_path, weights_path):
|
||||
if not path.exists():
|
||||
raise G2PWCudaError(f"Missing g2pw-cu artifact: {path}")
|
||||
return runtime_lib.resolve(), manifest_path.resolve(), weights_path.resolve()
|
||||
|
||||
|
||||
def _build_bridge(wrapper_output: Path, runtime_lib: Path) -> None:
|
||||
_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
compile_cmd = [
|
||||
os.environ.get("CXX", "g++"),
|
||||
"-O3",
|
||||
"-std=c++17",
|
||||
"-shared",
|
||||
"-fPIC",
|
||||
str(_WRAPPER_SOURCE),
|
||||
"-I",
|
||||
str(runtime_lib.parent.parent / "include"),
|
||||
"-L",
|
||||
str(runtime_lib.parent),
|
||||
"-lg2pw_runtime",
|
||||
f"-Wl,-rpath,{runtime_lib.parent}",
|
||||
"-o",
|
||||
str(wrapper_output),
|
||||
]
|
||||
result = subprocess.run(compile_cmd, capture_output=True, text=True, check=False)
|
||||
if result.returncode != 0:
|
||||
raise G2PWCudaError(
|
||||
"Failed to build g2pw-cu bridge:\n"
|
||||
f"cmd={' '.join(compile_cmd)}\n"
|
||||
f"stdout={result.stdout}\n"
|
||||
f"stderr={result.stderr}"
|
||||
)
|
||||
|
||||
|
||||
def _ensure_bridge_built(runtime_lib: Path) -> Path:
|
||||
wrapper_output = _OUTPUT_DIR / "g2pw_cuda_bridge.so"
|
||||
_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
with _LOCK_PATH.open("w", encoding="utf-8") as lock_file:
|
||||
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
|
||||
needs_build = not wrapper_output.exists()
|
||||
if not needs_build:
|
||||
so_mtime = wrapper_output.stat().st_mtime
|
||||
needs_build = so_mtime < _WRAPPER_SOURCE.stat().st_mtime or so_mtime < runtime_lib.stat().st_mtime
|
||||
if needs_build:
|
||||
tmp_output = wrapper_output.with_suffix(".tmp.so")
|
||||
if tmp_output.exists():
|
||||
tmp_output.unlink()
|
||||
_build_bridge(tmp_output, runtime_lib)
|
||||
tmp_output.replace(wrapper_output)
|
||||
return wrapper_output
|
||||
|
||||
|
||||
def _load_bridge():
|
||||
runtime_lib, manifest_path, weights_path = _resolve_runtime_paths()
|
||||
bridge_path = _ensure_bridge_built(runtime_lib)
|
||||
global_mode = getattr(ctypes, "RTLD_GLOBAL", getattr(os, "RTLD_GLOBAL", 0))
|
||||
ctypes.CDLL(str(runtime_lib), mode=global_mode)
|
||||
lib = ctypes.CDLL(str(bridge_path))
|
||||
lib.g2pw_runtime_create.argtypes = [
|
||||
ctypes.c_char_p,
|
||||
ctypes.c_char_p,
|
||||
ctypes.c_int,
|
||||
ctypes.c_int,
|
||||
ctypes.c_int,
|
||||
ctypes.c_int,
|
||||
ctypes.c_int,
|
||||
ctypes.c_int,
|
||||
ctypes.c_int,
|
||||
ctypes.c_int,
|
||||
ctypes.c_int,
|
||||
ctypes.c_int,
|
||||
ctypes.c_int,
|
||||
]
|
||||
lib.g2pw_runtime_create.restype = ctypes.c_void_p
|
||||
lib.g2pw_runtime_destroy.argtypes = [ctypes.c_void_p]
|
||||
lib.g2pw_runtime_destroy.restype = None
|
||||
lib.g2pw_runtime_last_error.argtypes = [ctypes.c_void_p]
|
||||
lib.g2pw_runtime_last_error.restype = ctypes.c_char_p
|
||||
lib.g2pw_runtime_num_labels.argtypes = [ctypes.c_void_p]
|
||||
lib.g2pw_runtime_num_labels.restype = ctypes.c_int
|
||||
lib.g2pw_runtime_run.argtypes = [
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_int32,
|
||||
ctypes.c_int32,
|
||||
ctypes.c_void_p,
|
||||
]
|
||||
lib.g2pw_runtime_run.restype = ctypes.c_int
|
||||
return lib, manifest_path, weights_path, runtime_lib
|
||||
|
||||
|
||||
def _gemm_precision_value() -> int:
|
||||
precision = os.environ.get("GPTSOVITS_G2PW_CUDA_GEMM_PRECISION", "fp32").strip().lower()
|
||||
if precision == "fp16":
|
||||
return 1
|
||||
if precision == "bf16":
|
||||
return 2
|
||||
return 0
|
||||
|
||||
|
||||
class G2PWRuntimeWrapper:
|
||||
def __init__(self, shard_index: int = 0) -> None:
|
||||
self.lib, self.manifest_path, self.weights_path, self.runtime_lib = _load_bridge()
|
||||
self.shard_index = int(shard_index)
|
||||
self.device_ordinal = _env_int("GPTSOVITS_G2PW_CUDA_DEVICE", 0)
|
||||
self.allow_tensor_cores = _env_flag("GPTSOVITS_G2PW_CUDA_ALLOW_TENSOR_CORES", False)
|
||||
self.use_cublaslt_bias_epilogue = _env_flag("GPTSOVITS_G2PW_CUDA_USE_CUBLASLT_BIAS_EPILOGUE", False)
|
||||
self.enable_profiling = _env_flag("GPTSOVITS_G2PW_CUDA_ENABLE_PROFILE", False)
|
||||
self.enable_cuda_graph = _env_flag("GPTSOVITS_G2PW_CUDA_ENABLE_GRAPH", True)
|
||||
self.dump_graph_cache_stats = _env_flag("GPTSOVITS_G2PW_CUDA_DUMP_GRAPH_CACHE_STATS", False)
|
||||
self.full_graph_cache_limit = _env_int("GPTSOVITS_G2PW_CUDA_FULL_GRAPH_CACHE_LIMIT", 0)
|
||||
self.tail_graph_cache_limit = _env_int("GPTSOVITS_G2PW_CUDA_TAIL_GRAPH_CACHE_LIMIT", 0)
|
||||
self.gemm_precision = _gemm_precision_value()
|
||||
self.lock = threading.Lock()
|
||||
self.handle = None
|
||||
self.max_batch_size = 0
|
||||
self.max_seq_len = 0
|
||||
self.num_labels = 0
|
||||
self.batch_enabled = _env_flag("GPTSOVITS_G2PW_CUDA_BATCHING", True) != 0
|
||||
self.batch_window_s = max(0.0, float(_env_int("GPTSOVITS_G2PW_CUDA_BATCH_WINDOW_MS", 1)) / 1000.0)
|
||||
self.batch_max_requests = max(1, _env_int("GPTSOVITS_G2PW_CUDA_BATCH_MAX_REQUESTS", 64))
|
||||
self.batch_max_rows = max(1, _env_int("GPTSOVITS_G2PW_CUDA_BATCH_MAX_ROWS", 96))
|
||||
self.batch_max_tokens = max(1, _env_int("GPTSOVITS_G2PW_CUDA_BATCH_MAX_TOKENS", 4096))
|
||||
self.batch_condition = threading.Condition()
|
||||
self.pending_tasks: Deque[G2PWBatchTask] = deque()
|
||||
self.batch_total_tasks = 0
|
||||
self.batch_total_batches = 0
|
||||
self.batch_total_rows = 0
|
||||
self.batch_total_queue_wait_ms = 0.0
|
||||
self.batch_queue_wait_peak_ms = 0.0
|
||||
self.batch_total_collect_wait_ms = 0.0
|
||||
self.batch_collect_wait_peak_ms = 0.0
|
||||
self.batch_total_run_ms = 0.0
|
||||
self.batch_run_peak_ms = 0.0
|
||||
self.batch_rows_peak = 0
|
||||
self.batch_requests_peak = 0
|
||||
self.batch_pending_peak = 0
|
||||
self.closed = False
|
||||
self._ensure_capacity(
|
||||
batch_size=max(1, _env_int("GPTSOVITS_G2PW_CUDA_MAX_BATCH_SIZE", 256)),
|
||||
seq_len=max(1, _env_int("GPTSOVITS_G2PW_CUDA_MAX_SEQ_LEN", 128)),
|
||||
)
|
||||
self.batch_worker = None
|
||||
if self.batch_enabled:
|
||||
self.batch_worker = threading.Thread(
|
||||
target=self._batch_loop,
|
||||
name=f"g2pw-cuda-batch-worker-{self.shard_index}",
|
||||
daemon=True,
|
||||
)
|
||||
self.batch_worker.start()
|
||||
|
||||
def _sync_runtime_env_overrides(self) -> None:
|
||||
os.environ["G2PW_ENABLE_CUDA_GRAPH"] = "1" if self.enable_cuda_graph else "0"
|
||||
os.environ["G2PW_ENABLE_PROFILE"] = "1" if self.enable_profiling else "0"
|
||||
os.environ["G2PW_DUMP_GRAPH_CACHE_STATS"] = "1" if self.dump_graph_cache_stats else "0"
|
||||
os.environ["G2PW_FULL_GRAPH_CACHE_LIMIT"] = str(int(self.full_graph_cache_limit))
|
||||
os.environ["G2PW_TAIL_GRAPH_CACHE_LIMIT"] = str(int(self.tail_graph_cache_limit))
|
||||
os.environ["G2PW_ALLOW_TENSOR_CORES"] = "1" if self.allow_tensor_cores else "0"
|
||||
os.environ["G2PW_USE_CUBLASLT_BIAS_EPILOGUE"] = "1" if self.use_cublaslt_bias_epilogue else "0"
|
||||
os.environ["G2PW_GEMM_PRECISION"] = {0: "fp32", 1: "fp16", 2: "bf16"}.get(int(self.gemm_precision), "fp32")
|
||||
|
||||
def _destroy_handle(self) -> None:
|
||||
if self.handle:
|
||||
self.lib.g2pw_runtime_destroy(self.handle)
|
||||
self.handle = None
|
||||
|
||||
def close(self) -> None:
|
||||
with self.batch_condition:
|
||||
self.closed = True
|
||||
self.batch_condition.notify_all()
|
||||
self._destroy_handle()
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
self.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _last_error(self) -> str:
|
||||
if not self.handle:
|
||||
return "uninitialized runtime"
|
||||
message = self.lib.g2pw_runtime_last_error(self.handle)
|
||||
return "" if not message else message.decode("utf-8", errors="replace")
|
||||
|
||||
def _create_handle(self, batch_size: int, seq_len: int) -> None:
|
||||
self._sync_runtime_env_overrides()
|
||||
new_handle = self.lib.g2pw_runtime_create(
|
||||
str(self.manifest_path).encode("utf-8"),
|
||||
str(self.weights_path).encode("utf-8"),
|
||||
int(self.device_ordinal),
|
||||
int(batch_size),
|
||||
int(seq_len),
|
||||
int(self.full_graph_cache_limit),
|
||||
int(self.tail_graph_cache_limit),
|
||||
int(self.allow_tensor_cores),
|
||||
int(self.use_cublaslt_bias_epilogue),
|
||||
int(self.enable_profiling),
|
||||
int(self.enable_cuda_graph),
|
||||
int(self.dump_graph_cache_stats),
|
||||
int(self.gemm_precision),
|
||||
)
|
||||
if not new_handle:
|
||||
raise G2PWCudaError("g2pw-cu returned null runtime handle")
|
||||
self.handle = new_handle
|
||||
self.max_batch_size = int(batch_size)
|
||||
self.max_seq_len = int(seq_len)
|
||||
self.num_labels = int(self.lib.g2pw_runtime_num_labels(self.handle))
|
||||
last_error = self._last_error()
|
||||
if self.num_labels <= 0 or last_error:
|
||||
self.close()
|
||||
raise G2PWCudaError(f"Failed to initialize g2pw-cu runtime: {last_error or 'num_labels <= 0'}")
|
||||
|
||||
def _ensure_capacity(self, batch_size: int, seq_len: int) -> None:
|
||||
target_batch = max(1, int(batch_size))
|
||||
target_seq = max(1, int(seq_len))
|
||||
if self.handle and target_batch <= self.max_batch_size and target_seq <= self.max_seq_len:
|
||||
return
|
||||
next_batch = max(target_batch, self.max_batch_size * 2 if self.max_batch_size else 0)
|
||||
next_seq = max(target_seq, self.max_seq_len * 2 if self.max_seq_len else 0)
|
||||
self._destroy_handle()
|
||||
self._create_handle(batch_size=next_batch, seq_len=next_seq)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_model_input(model_input: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
|
||||
input_ids = np.ascontiguousarray(model_input["input_ids"], dtype=np.int64)
|
||||
token_type_ids = np.ascontiguousarray(model_input["token_type_ids"], dtype=np.int64)
|
||||
attention_masks = np.ascontiguousarray(model_input["attention_masks"], dtype=np.int64)
|
||||
phoneme_masks = np.ascontiguousarray(model_input["phoneme_masks"], dtype=np.float32)
|
||||
char_ids = np.ascontiguousarray(model_input["char_ids"], dtype=np.int64)
|
||||
position_ids = np.ascontiguousarray(model_input["position_ids"], dtype=np.int64)
|
||||
batch_size = int(char_ids.shape[0])
|
||||
if input_ids.shape[0] == 1 and batch_size > 1:
|
||||
input_ids = np.ascontiguousarray(np.repeat(input_ids, batch_size, axis=0), dtype=np.int64)
|
||||
token_type_ids = np.ascontiguousarray(np.repeat(token_type_ids, batch_size, axis=0), dtype=np.int64)
|
||||
attention_masks = np.ascontiguousarray(np.repeat(attention_masks, batch_size, axis=0), dtype=np.int64)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"token_type_ids": token_type_ids,
|
||||
"attention_masks": attention_masks,
|
||||
"phoneme_masks": phoneme_masks,
|
||||
"char_ids": char_ids,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
|
||||
def _run_direct(self, model_input: Dict[str, np.ndarray]) -> np.ndarray:
|
||||
normalized = self._normalize_model_input(model_input)
|
||||
input_ids = normalized["input_ids"]
|
||||
token_type_ids = normalized["token_type_ids"]
|
||||
attention_masks = normalized["attention_masks"]
|
||||
phoneme_masks = normalized["phoneme_masks"]
|
||||
char_ids = normalized["char_ids"]
|
||||
position_ids = normalized["position_ids"]
|
||||
batch_size = int(char_ids.shape[0])
|
||||
seq_len = int(input_ids.shape[1])
|
||||
probs = np.empty((batch_size, self.num_labels), dtype=np.float32)
|
||||
with self.lock:
|
||||
self._ensure_capacity(batch_size=batch_size, seq_len=seq_len)
|
||||
status = self.lib.g2pw_runtime_run(
|
||||
self.handle,
|
||||
input_ids.ctypes.data_as(ctypes.c_void_p),
|
||||
token_type_ids.ctypes.data_as(ctypes.c_void_p),
|
||||
attention_masks.ctypes.data_as(ctypes.c_void_p),
|
||||
phoneme_masks.ctypes.data_as(ctypes.c_void_p),
|
||||
char_ids.ctypes.data_as(ctypes.c_void_p),
|
||||
position_ids.ctypes.data_as(ctypes.c_void_p),
|
||||
batch_size,
|
||||
seq_len,
|
||||
probs.ctypes.data_as(ctypes.c_void_p),
|
||||
)
|
||||
if int(status) != 0:
|
||||
raise G2PWCudaError(f"g2pw-cu inference failed: {self._last_error()}")
|
||||
return probs
|
||||
|
||||
def _can_append_task(self, tasks: List[G2PWBatchTask], candidate: G2PWBatchTask) -> bool:
|
||||
request_count = len(tasks) + 1
|
||||
if request_count > self.batch_max_requests:
|
||||
return False
|
||||
total_rows = sum(int(item.model_input["char_ids"].shape[0]) for item in tasks) + int(
|
||||
candidate.model_input["char_ids"].shape[0]
|
||||
)
|
||||
if total_rows > self.batch_max_rows:
|
||||
return False
|
||||
total_tokens = sum(
|
||||
int(item.model_input["char_ids"].shape[0]) * int(item.model_input["input_ids"].shape[1]) for item in tasks
|
||||
) + int(candidate.model_input["char_ids"].shape[0]) * int(candidate.model_input["input_ids"].shape[1])
|
||||
return total_tokens <= self.batch_max_tokens
|
||||
|
||||
def _merge_batch_inputs(self, tasks: List[G2PWBatchTask]) -> Tuple[Dict[str, np.ndarray], List[Tuple[int, int]]]:
|
||||
normalized_inputs = [self._normalize_model_input(task.model_input) for task in tasks]
|
||||
total_rows = sum(int(item["char_ids"].shape[0]) for item in normalized_inputs)
|
||||
max_seq_len = max(int(item["input_ids"].shape[1]) for item in normalized_inputs)
|
||||
input_ids = np.zeros((total_rows, max_seq_len), dtype=np.int64)
|
||||
token_type_ids = np.zeros((total_rows, max_seq_len), dtype=np.int64)
|
||||
attention_masks = np.zeros((total_rows, max_seq_len), dtype=np.int64)
|
||||
phoneme_masks = np.zeros((total_rows, normalized_inputs[0]["phoneme_masks"].shape[1]), dtype=np.float32)
|
||||
char_ids = np.zeros((total_rows,), dtype=np.int64)
|
||||
position_ids = np.zeros((total_rows,), dtype=np.int64)
|
||||
slices: List[Tuple[int, int]] = []
|
||||
cursor = 0
|
||||
for item in normalized_inputs:
|
||||
rows = int(item["char_ids"].shape[0])
|
||||
seq_len = int(item["input_ids"].shape[1])
|
||||
next_cursor = cursor + rows
|
||||
input_ids[cursor:next_cursor, :seq_len] = item["input_ids"]
|
||||
token_type_ids[cursor:next_cursor, :seq_len] = item["token_type_ids"]
|
||||
attention_masks[cursor:next_cursor, :seq_len] = item["attention_masks"]
|
||||
phoneme_masks[cursor:next_cursor] = item["phoneme_masks"]
|
||||
char_ids[cursor:next_cursor] = item["char_ids"]
|
||||
position_ids[cursor:next_cursor] = item["position_ids"]
|
||||
slices.append((cursor, next_cursor))
|
||||
cursor = next_cursor
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"token_type_ids": token_type_ids,
|
||||
"attention_masks": attention_masks,
|
||||
"phoneme_masks": phoneme_masks,
|
||||
"char_ids": char_ids,
|
||||
"position_ids": position_ids,
|
||||
}, slices
|
||||
|
||||
def _finish_task(
|
||||
self,
|
||||
task: G2PWBatchTask,
|
||||
output: np.ndarray | None = None,
|
||||
profile: Dict[str, float] | None = None,
|
||||
error: Exception | None = None,
|
||||
) -> None:
|
||||
task.output = output
|
||||
task.profile = dict(profile or {})
|
||||
task.error = error
|
||||
task.done_event.set()
|
||||
|
||||
def _batch_loop(self) -> None:
|
||||
while True:
|
||||
with self.batch_condition:
|
||||
while not self.pending_tasks and not self.closed:
|
||||
self.batch_condition.wait()
|
||||
if self.closed and not self.pending_tasks:
|
||||
return
|
||||
first_task = self.pending_tasks.popleft()
|
||||
batch_tasks = [first_task]
|
||||
collect_started = time.perf_counter()
|
||||
deadline = collect_started + self.batch_window_s
|
||||
while True:
|
||||
if len(batch_tasks) >= self.batch_max_requests:
|
||||
break
|
||||
remaining = deadline - time.perf_counter()
|
||||
if remaining <= 0.0:
|
||||
break
|
||||
if not self.pending_tasks:
|
||||
self.batch_condition.wait(timeout=remaining)
|
||||
continue
|
||||
candidate = self.pending_tasks[0]
|
||||
if not self._can_append_task(batch_tasks, candidate):
|
||||
break
|
||||
batch_tasks.append(self.pending_tasks.popleft())
|
||||
collect_wait_ms = max(0.0, (time.perf_counter() - collect_started) * 1000.0)
|
||||
|
||||
now = time.perf_counter()
|
||||
queue_wait_values = [max(0.0, (now - task.enqueued_at) * 1000.0) for task in batch_tasks]
|
||||
try:
|
||||
merged_input, row_slices = self._merge_batch_inputs(batch_tasks)
|
||||
run_started = time.perf_counter()
|
||||
merged_output = self._run_direct(merged_input)
|
||||
run_ms = max(0.0, (time.perf_counter() - run_started) * 1000.0)
|
||||
for task, (start, end) in zip(batch_tasks, row_slices):
|
||||
task_rows = int(task.model_input["char_ids"].shape[0])
|
||||
task_seq_len = int(task.model_input["input_ids"].shape[1])
|
||||
self._finish_task(
|
||||
task,
|
||||
output=np.ascontiguousarray(merged_output[start:end]),
|
||||
profile={
|
||||
"g2pw_runtime_queue_wait_ms": float(max(0.0, (run_started - task.enqueued_at) * 1000.0)),
|
||||
"g2pw_runtime_collect_wait_ms": float(collect_wait_ms),
|
||||
"g2pw_runtime_run_ms": float(run_ms),
|
||||
"g2pw_runtime_batch_rows": float(sum(int(item.model_input["char_ids"].shape[0]) for item in batch_tasks)),
|
||||
"g2pw_runtime_batch_requests": float(len(batch_tasks)),
|
||||
"g2pw_runtime_task_rows": float(task_rows),
|
||||
"g2pw_runtime_task_seq_len": float(task_seq_len),
|
||||
"g2pw_runtime_shard_index": float(self.shard_index),
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
run_ms = 0.0
|
||||
for task in batch_tasks:
|
||||
self._finish_task(task, error=exc)
|
||||
finally:
|
||||
with self.batch_condition:
|
||||
self.batch_total_batches += 1
|
||||
self.batch_total_tasks += len(batch_tasks)
|
||||
self.batch_total_rows += sum(int(task.model_input["char_ids"].shape[0]) for task in batch_tasks)
|
||||
self.batch_total_queue_wait_ms += float(sum(queue_wait_values))
|
||||
self.batch_queue_wait_peak_ms = max(self.batch_queue_wait_peak_ms, max(queue_wait_values or [0.0]))
|
||||
self.batch_total_collect_wait_ms += float(collect_wait_ms) * float(len(batch_tasks))
|
||||
self.batch_collect_wait_peak_ms = max(self.batch_collect_wait_peak_ms, float(collect_wait_ms))
|
||||
self.batch_total_run_ms += float(run_ms)
|
||||
self.batch_run_peak_ms = max(self.batch_run_peak_ms, float(run_ms))
|
||||
self.batch_rows_peak = max(
|
||||
self.batch_rows_peak, sum(int(task.model_input["char_ids"].shape[0]) for task in batch_tasks)
|
||||
)
|
||||
self.batch_requests_peak = max(self.batch_requests_peak, len(batch_tasks))
|
||||
|
||||
def _submit_batched(self, model_input: Dict[str, np.ndarray]) -> tuple[np.ndarray, Dict[str, float]]:
|
||||
task = G2PWBatchTask(model_input=model_input)
|
||||
with self.batch_condition:
|
||||
if self.closed:
|
||||
raise G2PWCudaError("g2pw-cu batch worker already closed")
|
||||
task.enqueued_at = time.perf_counter()
|
||||
self.pending_tasks.append(task)
|
||||
self.batch_pending_peak = max(self.batch_pending_peak, len(self.pending_tasks))
|
||||
self.batch_condition.notify_all()
|
||||
task.done_event.wait()
|
||||
if task.error is not None:
|
||||
raise task.error
|
||||
assert task.output is not None
|
||||
return task.output, dict(task.profile)
|
||||
|
||||
def snapshot(self) -> Dict[str, float | int | bool]:
|
||||
with self.batch_condition:
|
||||
average_tasks_per_batch = (
|
||||
float(self.batch_total_tasks) / float(self.batch_total_batches) if self.batch_total_batches > 0 else 0.0
|
||||
)
|
||||
average_rows_per_batch = (
|
||||
float(self.batch_total_rows) / float(self.batch_total_batches) if self.batch_total_batches > 0 else 0.0
|
||||
)
|
||||
average_queue_wait_ms = (
|
||||
float(self.batch_total_queue_wait_ms) / float(self.batch_total_tasks) if self.batch_total_tasks > 0 else 0.0
|
||||
)
|
||||
average_collect_wait_ms = (
|
||||
float(self.batch_total_collect_wait_ms) / float(self.batch_total_tasks)
|
||||
if self.batch_total_tasks > 0
|
||||
else 0.0
|
||||
)
|
||||
return {
|
||||
"shard_index": int(self.shard_index),
|
||||
"enabled": bool(self.batch_enabled),
|
||||
"enable_cuda_graph": bool(self.enable_cuda_graph),
|
||||
"enable_profiling": bool(self.enable_profiling),
|
||||
"full_graph_cache_limit": int(self.full_graph_cache_limit),
|
||||
"tail_graph_cache_limit": int(self.tail_graph_cache_limit),
|
||||
"window_ms": float(self.batch_window_s * 1000.0),
|
||||
"max_requests": int(self.batch_max_requests),
|
||||
"max_rows": int(self.batch_max_rows),
|
||||
"max_tokens": int(self.batch_max_tokens),
|
||||
"pending": int(len(self.pending_tasks)),
|
||||
"pending_peak": int(self.batch_pending_peak),
|
||||
"total_batches": int(self.batch_total_batches),
|
||||
"total_tasks": int(self.batch_total_tasks),
|
||||
"total_rows": int(self.batch_total_rows),
|
||||
"avg_tasks_per_batch": float(average_tasks_per_batch),
|
||||
"avg_rows_per_batch": float(average_rows_per_batch),
|
||||
"avg_queue_wait_ms": float(average_queue_wait_ms),
|
||||
"queue_wait_peak_ms": float(self.batch_queue_wait_peak_ms),
|
||||
"avg_collect_wait_ms": float(average_collect_wait_ms),
|
||||
"collect_wait_peak_ms": float(self.batch_collect_wait_peak_ms),
|
||||
"run_total_ms": float(self.batch_total_run_ms),
|
||||
"run_peak_ms": float(self.batch_run_peak_ms),
|
||||
"batch_rows_peak": int(self.batch_rows_peak),
|
||||
"batch_requests_peak": int(self.batch_requests_peak),
|
||||
}
|
||||
|
||||
def pending_rows(self) -> int:
|
||||
with self.batch_condition:
|
||||
return int(sum(int(task.model_input["char_ids"].shape[0]) for task in self.pending_tasks))
|
||||
|
||||
def pending_count(self) -> int:
|
||||
with self.batch_condition:
|
||||
return int(len(self.pending_tasks))
|
||||
|
||||
def run_with_profile(self, model_input: Dict[str, np.ndarray]) -> tuple[np.ndarray, Dict[str, float]]:
|
||||
if not self.batch_enabled:
|
||||
started = time.perf_counter()
|
||||
output = self._run_direct(model_input)
|
||||
return output, {
|
||||
"g2pw_runtime_queue_wait_ms": 0.0,
|
||||
"g2pw_runtime_collect_wait_ms": 0.0,
|
||||
"g2pw_runtime_run_ms": float((time.perf_counter() - started) * 1000.0),
|
||||
"g2pw_runtime_batch_rows": float(model_input["char_ids"].shape[0]),
|
||||
"g2pw_runtime_batch_requests": 1.0,
|
||||
"g2pw_runtime_task_rows": float(model_input["char_ids"].shape[0]),
|
||||
"g2pw_runtime_task_seq_len": float(model_input["input_ids"].shape[1]),
|
||||
"g2pw_runtime_shard_index": float(self.shard_index),
|
||||
}
|
||||
return self._submit_batched(model_input)
|
||||
|
||||
def run(self, model_input: Dict[str, np.ndarray]) -> np.ndarray:
|
||||
output, _profile = self.run_with_profile(model_input)
|
||||
return output
|
||||
|
||||
|
||||
class G2PWRuntimePool:
|
||||
def __init__(self) -> None:
|
||||
self.worker_count = max(1, _env_int("GPTSOVITS_G2PW_CUDA_WORKERS", 2))
|
||||
self.shards = [G2PWRuntimeWrapper(shard_index=index) for index in range(self.worker_count)]
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def _pick_shard(self) -> G2PWRuntimeWrapper:
|
||||
with self.lock:
|
||||
return min(
|
||||
self.shards,
|
||||
key=lambda shard: (
|
||||
shard.pending_rows(),
|
||||
shard.pending_count(),
|
||||
shard.snapshot().get("avg_queue_wait_ms", 0.0),
|
||||
),
|
||||
)
|
||||
|
||||
def run_with_profile(self, model_input: Dict[str, np.ndarray]) -> tuple[np.ndarray, Dict[str, float]]:
|
||||
shard = self._pick_shard()
|
||||
output, profile = shard.run_with_profile(model_input)
|
||||
profile["g2pw_runtime_pool_workers"] = float(self.worker_count)
|
||||
return output, profile
|
||||
|
||||
def run(self, model_input: Dict[str, np.ndarray]) -> np.ndarray:
|
||||
output, _profile = self.run_with_profile(model_input)
|
||||
return output
|
||||
|
||||
def snapshot(self) -> Dict[str, float | int | bool | List[Dict[str, float | int | bool]]]:
|
||||
shard_snapshots = [dict(shard.snapshot()) for shard in self.shards]
|
||||
avg_queue_wait_ms = 0.0
|
||||
total_tasks = 0.0
|
||||
pending = 0
|
||||
pending_peak = 0
|
||||
total_batches = 0
|
||||
total_rows = 0
|
||||
batch_rows_peak = 0
|
||||
batch_requests_peak = 0
|
||||
for snapshot in shard_snapshots:
|
||||
tasks = float(snapshot.get("total_tasks", 0.0))
|
||||
avg_queue_wait_ms += float(snapshot.get("avg_queue_wait_ms", 0.0)) * tasks
|
||||
total_tasks += tasks
|
||||
pending += int(snapshot.get("pending", 0))
|
||||
pending_peak = max(pending_peak, int(snapshot.get("pending_peak", 0)))
|
||||
total_batches += int(snapshot.get("total_batches", 0))
|
||||
total_rows += int(snapshot.get("total_rows", 0))
|
||||
batch_rows_peak = max(batch_rows_peak, int(snapshot.get("batch_rows_peak", 0)))
|
||||
batch_requests_peak = max(batch_requests_peak, int(snapshot.get("batch_requests_peak", 0)))
|
||||
return {
|
||||
"worker_count": int(self.worker_count),
|
||||
"pending": int(pending),
|
||||
"pending_peak": int(pending_peak),
|
||||
"total_batches": int(total_batches),
|
||||
"total_tasks": int(total_tasks),
|
||||
"total_rows": int(total_rows),
|
||||
"avg_queue_wait_ms": float(avg_queue_wait_ms / total_tasks) if total_tasks > 0 else 0.0,
|
||||
"batch_rows_peak": int(batch_rows_peak),
|
||||
"batch_requests_peak": int(batch_requests_peak),
|
||||
"shards": shard_snapshots,
|
||||
}
|
||||
|
||||
|
||||
class G2PWCudaConverter(_G2PWBaseOnnxConverter):
|
||||
def __init__(
|
||||
self,
|
||||
model_dir: str = "G2PWModel/",
|
||||
style: str = "bopomofo",
|
||||
model_source: str = None,
|
||||
enable_non_tradional_chinese: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
model_dir=model_dir,
|
||||
style=style,
|
||||
model_source=model_source,
|
||||
enable_non_tradional_chinese=enable_non_tradional_chinese,
|
||||
)
|
||||
self.runtime = G2PWRuntimePool()
|
||||
self.backend = "cuda"
|
||||
primary_runtime = self.runtime.shards[0]
|
||||
self.device = f"cuda:{primary_runtime.device_ordinal}"
|
||||
self.checkpoint_path = str(primary_runtime.weights_path)
|
||||
self.providers = ["g2pw-cu"]
|
||||
|
||||
def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]:
|
||||
probs = self.runtime.run(model_input)
|
||||
preds = np.argmax(probs, axis=1).tolist()
|
||||
confidences = probs[np.arange(len(preds)), preds].astype(np.float32, copy=False).tolist()
|
||||
return [self.labels[pred] for pred in preds], confidences
|
||||
|
||||
def _predict_with_profile(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float], Dict[str, float]]:
|
||||
started = time.perf_counter()
|
||||
probs, runtime_profile = self.runtime.run_with_profile(model_input)
|
||||
preds = np.argmax(probs, axis=1).tolist()
|
||||
confidences = probs[np.arange(len(preds)), preds].astype(np.float32, copy=False).tolist()
|
||||
profile = dict(runtime_profile)
|
||||
profile["g2pw_runtime_total_ms"] = float((time.perf_counter() - started) * 1000.0)
|
||||
profile["g2pw_predict_ms"] = float(profile["g2pw_runtime_total_ms"])
|
||||
return [self.labels[pred] for pred in preds], confidences, profile
|
||||
|
||||
def snapshot(self) -> Dict[str, float | int | bool]:
|
||||
return dict(self.runtime.snapshot())
|
||||
@ -18,6 +18,7 @@ Credits
|
||||
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
@ -37,6 +38,8 @@ def prepare_onnx_input(
|
||||
use_mask: bool = False,
|
||||
window_size: int = None,
|
||||
max_len: int = 512,
|
||||
char2id: Optional[Dict[str, int]] = None,
|
||||
char_phoneme_masks: Optional[Dict[str, List[int]]] = None,
|
||||
) -> Dict[str, np.array]:
|
||||
if window_size is not None:
|
||||
truncated_texts, truncated_query_ids = _truncate_texts(
|
||||
@ -48,33 +51,88 @@ def prepare_onnx_input(
|
||||
phoneme_masks = []
|
||||
char_ids = []
|
||||
position_ids = []
|
||||
tokenized_cache = {}
|
||||
|
||||
if char2id is None:
|
||||
char2id = {char: idx for idx, char in enumerate(chars)}
|
||||
if use_mask:
|
||||
if char_phoneme_masks is None:
|
||||
char_phoneme_masks = {
|
||||
char: [1 if i in char2phonemes[char] else 0 for i in range(len(labels))]
|
||||
for char in char2phonemes
|
||||
}
|
||||
else:
|
||||
full_phoneme_mask = [1] * len(labels)
|
||||
|
||||
for idx in range(len(texts)):
|
||||
text = (truncated_texts if window_size else texts)[idx].lower()
|
||||
query_id = (truncated_query_ids if window_size else query_ids)[idx]
|
||||
|
||||
try:
|
||||
tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text)
|
||||
except Exception:
|
||||
print(f'warning: text "{text}" is invalid')
|
||||
return {}
|
||||
cached = tokenized_cache.get(text)
|
||||
if cached is None:
|
||||
try:
|
||||
tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text)
|
||||
except Exception:
|
||||
print(f'warning: text "{text}" is invalid')
|
||||
return {}
|
||||
|
||||
text, query_id, tokens, text2token, token2text = _truncate(
|
||||
max_len=max_len, text=text, query_id=query_id, tokens=tokens, text2token=text2token, token2text=token2text
|
||||
)
|
||||
if len(tokens) <= max_len - 2:
|
||||
processed_tokens = ["[CLS]"] + tokens + ["[SEP]"]
|
||||
shared_input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
|
||||
shared_token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
|
||||
shared_attention_mask = list(np.ones((len(processed_tokens),), dtype=int))
|
||||
cached = {
|
||||
"is_short": True,
|
||||
"tokens": tokens,
|
||||
"text2token": text2token,
|
||||
"token2text": token2text,
|
||||
"input_id": shared_input_id,
|
||||
"token_type_id": shared_token_type_id,
|
||||
"attention_mask": shared_attention_mask,
|
||||
}
|
||||
else:
|
||||
cached = {
|
||||
"is_short": False,
|
||||
"tokens": tokens,
|
||||
"text2token": text2token,
|
||||
"token2text": token2text,
|
||||
}
|
||||
tokenized_cache[text] = cached
|
||||
|
||||
processed_tokens = ["[CLS]"] + tokens + ["[SEP]"]
|
||||
if cached["is_short"]:
|
||||
text_for_query = text
|
||||
query_id_for_query = query_id
|
||||
text2token_for_query = cached["text2token"]
|
||||
input_id = cached["input_id"]
|
||||
token_type_id = cached["token_type_id"]
|
||||
attention_mask = cached["attention_mask"]
|
||||
else:
|
||||
(
|
||||
text_for_query,
|
||||
query_id_for_query,
|
||||
tokens_for_query,
|
||||
text2token_for_query,
|
||||
_token2text_for_query,
|
||||
) = _truncate(
|
||||
max_len=max_len,
|
||||
text=text,
|
||||
query_id=query_id,
|
||||
tokens=cached["tokens"],
|
||||
text2token=cached["text2token"],
|
||||
token2text=cached["token2text"],
|
||||
)
|
||||
processed_tokens = ["[CLS]"] + tokens_for_query + ["[SEP]"]
|
||||
input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
|
||||
token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
|
||||
attention_mask = list(np.ones((len(processed_tokens),), dtype=int))
|
||||
|
||||
input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
|
||||
token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
|
||||
attention_mask = list(np.ones((len(processed_tokens),), dtype=int))
|
||||
|
||||
query_char = text[query_id]
|
||||
phoneme_mask = (
|
||||
[1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] if use_mask else [1] * len(labels)
|
||||
)
|
||||
char_id = chars.index(query_char)
|
||||
position_id = text2token[query_id] + 1 # [CLS] token locate at first place
|
||||
query_char = text_for_query[query_id_for_query]
|
||||
if use_mask:
|
||||
phoneme_mask = char_phoneme_masks[query_char]
|
||||
else:
|
||||
phoneme_mask = full_phoneme_mask
|
||||
char_id = char2id[query_char]
|
||||
position_id = text2token_for_query[query_id_for_query] + 1 # [CLS] token locate at first place
|
||||
|
||||
input_ids.append(input_id)
|
||||
token_type_ids.append(token_type_id)
|
||||
@ -83,10 +141,15 @@ def prepare_onnx_input(
|
||||
char_ids.append(char_id)
|
||||
position_ids.append(position_id)
|
||||
|
||||
max_token_length = max(len(seq) for seq in input_ids)
|
||||
|
||||
def _pad_sequences(sequences, pad_value=0):
|
||||
return [seq + [pad_value] * (max_token_length - len(seq)) for seq in sequences]
|
||||
|
||||
outputs = {
|
||||
"input_ids": np.array(input_ids).astype(np.int64),
|
||||
"token_type_ids": np.array(token_type_ids).astype(np.int64),
|
||||
"attention_masks": np.array(attention_masks).astype(np.int64),
|
||||
"input_ids": np.array(_pad_sequences(input_ids, pad_value=0)).astype(np.int64),
|
||||
"token_type_ids": np.array(_pad_sequences(token_type_ids, pad_value=0)).astype(np.int64),
|
||||
"attention_masks": np.array(_pad_sequences(attention_masks, pad_value=0)).astype(np.int64),
|
||||
"phoneme_masks": np.array(phoneme_masks).astype(np.float32),
|
||||
"char_ids": np.array(char_ids).astype(np.int64),
|
||||
"position_ids": np.array(position_ids).astype(np.int64),
|
||||
|
||||
@ -8,6 +8,7 @@ from pypinyin.core import Pinyin, Style
|
||||
from pypinyin.seg.simpleseg import simple_seg
|
||||
from pypinyin.converter import UltimateConverter
|
||||
from pypinyin.contrib.tone_convert import to_tone
|
||||
from .cuda_api import G2PWCudaConverter
|
||||
from .onnx_api import G2PWOnnxConverter
|
||||
|
||||
current_file_path = os.path.dirname(__file__)
|
||||
@ -27,12 +28,36 @@ class G2PWPinyin(Pinyin):
|
||||
tone_sandhi=False,
|
||||
**kwargs,
|
||||
):
|
||||
self._g2pw = G2PWOnnxConverter(
|
||||
model_dir=model_dir,
|
||||
style="pinyin",
|
||||
model_source=model_source,
|
||||
enable_non_tradional_chinese=enable_non_tradional_chinese,
|
||||
)
|
||||
backend = os.environ.get("GPTSOVITS_G2PW_BACKEND", "cuda").strip().lower()
|
||||
last_error = None
|
||||
self._g2pw = None
|
||||
if backend in {"cuda", "auto"}:
|
||||
try:
|
||||
self._g2pw = G2PWCudaConverter(
|
||||
model_dir=model_dir,
|
||||
style="pinyin",
|
||||
model_source=model_source,
|
||||
enable_non_tradional_chinese=enable_non_tradional_chinese,
|
||||
)
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
strict_mode = os.environ.get("GPTSOVITS_G2PW_CUDA_STRICT", "0").strip().lower() in {
|
||||
"1",
|
||||
"true",
|
||||
"yes",
|
||||
"on",
|
||||
}
|
||||
if backend == "cuda" and strict_mode:
|
||||
raise
|
||||
if self._g2pw is None:
|
||||
self._g2pw = G2PWOnnxConverter(
|
||||
model_dir=model_dir,
|
||||
style="pinyin",
|
||||
model_source=model_source,
|
||||
enable_non_tradional_chinese=enable_non_tradional_chinese,
|
||||
)
|
||||
if last_error is not None:
|
||||
print(f"[g2pw] cuda backend unavailable, fallback to onnx: {last_error}")
|
||||
self._converter = Converter(
|
||||
self._g2pw,
|
||||
v_to_u=v_to_u,
|
||||
|
||||
183
GPT_SoVITS/text/g2pw/g2pw_cuda_bridge.cpp
Normal file
183
GPT_SoVITS/text/g2pw/g2pw_cuda_bridge.cpp
Normal file
@ -0,0 +1,183 @@
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "g2pw/runtime.h"
|
||||
|
||||
namespace {
|
||||
|
||||
struct G2PWRuntimeHandle {
|
||||
std::unique_ptr<g2pw::Runtime> runtime;
|
||||
std::string last_error;
|
||||
int num_labels = 0;
|
||||
};
|
||||
|
||||
void SetError(G2PWRuntimeHandle* handle, const g2pw::Status& status) {
|
||||
if (handle == nullptr) {
|
||||
return;
|
||||
}
|
||||
handle->last_error = status.message;
|
||||
}
|
||||
|
||||
g2pw::RuntimeConfig BuildConfig(
|
||||
int device_ordinal,
|
||||
int max_batch_size,
|
||||
int max_seq_len,
|
||||
int full_graph_cache_limit,
|
||||
int tail_graph_cache_limit,
|
||||
int allow_tensor_cores,
|
||||
int use_cublaslt_bias_epilogue,
|
||||
int enable_profiling,
|
||||
int enable_cuda_graph,
|
||||
int dump_graph_cache_stats,
|
||||
int gemm_precision) {
|
||||
g2pw::RuntimeConfig config{};
|
||||
config.device_ordinal = device_ordinal;
|
||||
config.max_batch_size = max_batch_size;
|
||||
config.max_seq_len = max_seq_len;
|
||||
config.full_graph_cache_limit = full_graph_cache_limit;
|
||||
config.tail_graph_cache_limit = tail_graph_cache_limit;
|
||||
config.allow_tensor_cores = allow_tensor_cores != 0;
|
||||
config.use_cublaslt_bias_epilogue = use_cublaslt_bias_epilogue != 0;
|
||||
config.enable_profiling = enable_profiling != 0;
|
||||
config.enable_cuda_graph = enable_cuda_graph != 0;
|
||||
config.dump_graph_cache_stats = dump_graph_cache_stats != 0;
|
||||
switch (gemm_precision) {
|
||||
case 1:
|
||||
config.gemm_precision = g2pw::GemmPrecision::kFp16;
|
||||
break;
|
||||
case 2:
|
||||
config.gemm_precision = g2pw::GemmPrecision::kBf16;
|
||||
break;
|
||||
default:
|
||||
config.gemm_precision = g2pw::GemmPrecision::kFp32;
|
||||
break;
|
||||
}
|
||||
return config;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
|
||||
void* g2pw_runtime_create(
|
||||
const char* manifest_path,
|
||||
const char* binary_path,
|
||||
int device_ordinal,
|
||||
int max_batch_size,
|
||||
int max_seq_len,
|
||||
int full_graph_cache_limit,
|
||||
int tail_graph_cache_limit,
|
||||
int allow_tensor_cores,
|
||||
int use_cublaslt_bias_epilogue,
|
||||
int enable_profiling,
|
||||
int enable_cuda_graph,
|
||||
int dump_graph_cache_stats,
|
||||
int gemm_precision) {
|
||||
auto* handle = new G2PWRuntimeHandle();
|
||||
try {
|
||||
if (manifest_path == nullptr || binary_path == nullptr) {
|
||||
handle->last_error = "manifest_path and binary_path must be non-null";
|
||||
return handle;
|
||||
}
|
||||
g2pw::RuntimeConfig config = BuildConfig(
|
||||
device_ordinal,
|
||||
max_batch_size,
|
||||
max_seq_len,
|
||||
full_graph_cache_limit,
|
||||
tail_graph_cache_limit,
|
||||
allow_tensor_cores,
|
||||
use_cublaslt_bias_epilogue,
|
||||
enable_profiling,
|
||||
enable_cuda_graph,
|
||||
dump_graph_cache_stats,
|
||||
gemm_precision);
|
||||
g2pw::Status status = g2pw::Runtime::Create(
|
||||
config,
|
||||
std::string(manifest_path),
|
||||
std::string(binary_path),
|
||||
&handle->runtime);
|
||||
if (!status.ok()) {
|
||||
SetError(handle, status);
|
||||
return handle;
|
||||
}
|
||||
handle->num_labels = handle->runtime != nullptr ? handle->runtime->weights().manifest().num_labels : 0;
|
||||
handle->last_error.clear();
|
||||
return handle;
|
||||
} catch (const std::exception& exc) {
|
||||
handle->last_error = exc.what();
|
||||
return handle;
|
||||
} catch (...) {
|
||||
handle->last_error = "unknown exception";
|
||||
return handle;
|
||||
}
|
||||
}
|
||||
|
||||
void g2pw_runtime_destroy(void* raw_handle) {
|
||||
auto* handle = static_cast<G2PWRuntimeHandle*>(raw_handle);
|
||||
delete handle;
|
||||
}
|
||||
|
||||
const char* g2pw_runtime_last_error(void* raw_handle) {
|
||||
auto* handle = static_cast<G2PWRuntimeHandle*>(raw_handle);
|
||||
if (handle == nullptr) {
|
||||
return "invalid runtime handle";
|
||||
}
|
||||
return handle->last_error.c_str();
|
||||
}
|
||||
|
||||
int g2pw_runtime_num_labels(void* raw_handle) {
|
||||
auto* handle = static_cast<G2PWRuntimeHandle*>(raw_handle);
|
||||
if (handle == nullptr || handle->runtime == nullptr) {
|
||||
return 0;
|
||||
}
|
||||
return handle->num_labels;
|
||||
}
|
||||
|
||||
int g2pw_runtime_run(
|
||||
void* raw_handle,
|
||||
const std::int64_t* input_ids,
|
||||
const std::int64_t* token_type_ids,
|
||||
const std::int64_t* attention_mask,
|
||||
const float* phoneme_mask,
|
||||
const std::int64_t* char_ids,
|
||||
const std::int64_t* position_ids,
|
||||
std::int32_t batch_size,
|
||||
std::int32_t seq_len,
|
||||
float* probs) {
|
||||
auto* handle = static_cast<G2PWRuntimeHandle*>(raw_handle);
|
||||
if (handle == nullptr || handle->runtime == nullptr) {
|
||||
return static_cast<int>(g2pw::StatusCode::kInvalidArgument);
|
||||
}
|
||||
try {
|
||||
g2pw::InferenceInputs inputs{};
|
||||
inputs.input_ids = input_ids;
|
||||
inputs.token_type_ids = token_type_ids;
|
||||
inputs.attention_mask = attention_mask;
|
||||
inputs.phoneme_mask = phoneme_mask;
|
||||
inputs.char_ids = char_ids;
|
||||
inputs.position_ids = position_ids;
|
||||
inputs.batch_size = batch_size;
|
||||
inputs.seq_len = seq_len;
|
||||
|
||||
g2pw::InferenceOutputs outputs{};
|
||||
outputs.probs = probs;
|
||||
|
||||
const g2pw::Status status = handle->runtime->Run(inputs, outputs);
|
||||
if (!status.ok()) {
|
||||
SetError(handle, status);
|
||||
return static_cast<int>(status.code);
|
||||
}
|
||||
handle->last_error.clear();
|
||||
return static_cast<int>(g2pw::StatusCode::kOk);
|
||||
} catch (const std::exception& exc) {
|
||||
handle->last_error = exc.what();
|
||||
return static_cast<int>(g2pw::StatusCode::kInternalError);
|
||||
} catch (...) {
|
||||
handle->last_error = "unknown exception";
|
||||
return static_cast<int>(g2pw::StatusCode::kInternalError);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import warnings
|
||||
import zipfile
|
||||
from typing import Any, Dict, List, Tuple
|
||||
@ -10,7 +11,6 @@ from typing import Any, Dict, List, Tuple
|
||||
import numpy as np
|
||||
import onnxruntime
|
||||
import requests
|
||||
import torch
|
||||
from opencc import OpenCC
|
||||
from pypinyin import Style, pinyin
|
||||
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
||||
@ -22,9 +22,8 @@ from .utils import load_config
|
||||
onnxruntime.set_default_logger_severity(3)
|
||||
try:
|
||||
onnxruntime.preload_dlls()
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
# traceback.print_exc()
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
model_version = "1.1"
|
||||
@ -55,6 +54,41 @@ def predict(session, onnx_input: Dict[str, Any], labels: List[str]) -> Tuple[Lis
|
||||
return all_preds, all_confidences
|
||||
|
||||
|
||||
def _load_json_from_candidates(filename: str, candidate_dirs: List[str]) -> Dict[str, Any]:
|
||||
for candidate_dir in candidate_dirs:
|
||||
if not candidate_dir:
|
||||
continue
|
||||
json_path = os.path.join(candidate_dir, filename)
|
||||
if os.path.exists(json_path):
|
||||
with open(json_path, "r", encoding="utf-8") as fr:
|
||||
return json.load(fr)
|
||||
raise FileNotFoundError(f"Cannot locate {filename} in candidate dirs: {candidate_dirs}")
|
||||
|
||||
|
||||
def _find_first_existing_file(*paths: str) -> str:
|
||||
for path in paths:
|
||||
if path and os.path.exists(path):
|
||||
return path
|
||||
raise FileNotFoundError(f"Files not found: {paths}")
|
||||
|
||||
|
||||
def _resolve_tokenizer_source(model_source: str | None) -> str:
|
||||
candidate_paths = []
|
||||
if model_source:
|
||||
candidate_paths.append(model_source)
|
||||
repo_root = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
candidate_paths.extend(
|
||||
[
|
||||
os.path.join(repo_root, "pretrained_models", "g2pw-chinese"),
|
||||
os.path.join(repo_root, "pretrained_models", "chinese-roberta-wwm-ext-large"),
|
||||
]
|
||||
)
|
||||
for candidate in candidate_paths:
|
||||
if candidate and os.path.exists(candidate):
|
||||
return candidate
|
||||
return model_source or "bert-base-chinese"
|
||||
|
||||
|
||||
def download_and_decompress(model_dir: str = "G2PWModel/"):
|
||||
if not os.path.exists(model_dir):
|
||||
parent_directory = os.path.dirname(model_dir)
|
||||
@ -62,7 +96,7 @@ def download_and_decompress(model_dir: str = "G2PWModel/"):
|
||||
extract_dir = os.path.join(parent_directory, "G2PWModel_1.1")
|
||||
extract_dir_new = os.path.join(parent_directory, "G2PWModel")
|
||||
print("Downloading g2pw model...")
|
||||
modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip" # "https://paddlespeech.cdn.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip"
|
||||
modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip"
|
||||
with requests.get(modelscope_url, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
with open(zip_dir, "wb") as f:
|
||||
@ -79,7 +113,7 @@ def download_and_decompress(model_dir: str = "G2PWModel/"):
|
||||
return model_dir
|
||||
|
||||
|
||||
class G2PWOnnxConverter:
|
||||
class _G2PWBaseOnnxConverter:
|
||||
def __init__(
|
||||
self,
|
||||
model_dir: str = "G2PWModel/",
|
||||
@ -87,33 +121,16 @@ class G2PWOnnxConverter:
|
||||
model_source: str = None,
|
||||
enable_non_tradional_chinese: bool = False,
|
||||
):
|
||||
uncompress_path = download_and_decompress(model_dir)
|
||||
self.model_dir = download_and_decompress(model_dir)
|
||||
self.config = load_config(config_path=os.path.join(self.model_dir, "config.py"), use_default=True)
|
||||
|
||||
sess_options = onnxruntime.SessionOptions()
|
||||
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
|
||||
sess_options.intra_op_num_threads = 2 if torch.cuda.is_available() else 0
|
||||
if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
|
||||
self.session_g2pW = onnxruntime.InferenceSession(
|
||||
os.path.join(uncompress_path, "g2pW.onnx"),
|
||||
sess_options=sess_options,
|
||||
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
|
||||
)
|
||||
else:
|
||||
self.session_g2pW = onnxruntime.InferenceSession(
|
||||
os.path.join(uncompress_path, "g2pW.onnx"),
|
||||
sess_options=sess_options,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
self.config = load_config(config_path=os.path.join(uncompress_path, "config.py"), use_default=True)
|
||||
|
||||
self.model_source = model_source if model_source else self.config.model_source
|
||||
self.model_source = _resolve_tokenizer_source(model_source if model_source else self.config.model_source)
|
||||
self.enable_opencc = enable_non_tradional_chinese
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_source, local_files_only=True)
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)
|
||||
polyphonic_chars_path = os.path.join(self.model_dir, "POLYPHONIC_CHARS.txt")
|
||||
monophonic_chars_path = os.path.join(self.model_dir, "MONOPHONIC_CHARS.txt")
|
||||
|
||||
polyphonic_chars_path = os.path.join(uncompress_path, "POLYPHONIC_CHARS.txt")
|
||||
monophonic_chars_path = os.path.join(uncompress_path, "MONOPHONIC_CHARS.txt")
|
||||
self.polyphonic_chars = [
|
||||
line.split("\t") for line in open(polyphonic_chars_path, encoding="utf-8").read().strip().split("\n")
|
||||
]
|
||||
@ -149,31 +166,47 @@ class G2PWOnnxConverter:
|
||||
)
|
||||
|
||||
self.chars = sorted(list(self.char2phonemes.keys()))
|
||||
self.char2id = {char: idx for idx, char in enumerate(self.chars)}
|
||||
self.char_phoneme_masks = (
|
||||
{
|
||||
char: [1 if i in self.char2phonemes[char] else 0 for i in range(len(self.labels))]
|
||||
for char in self.char2phonemes
|
||||
}
|
||||
if self.config.use_mask
|
||||
else None
|
||||
)
|
||||
|
||||
self.polyphonic_chars_new = set(self.chars)
|
||||
for char in self.non_polyphonic:
|
||||
if char in self.polyphonic_chars_new:
|
||||
self.polyphonic_chars_new.remove(char)
|
||||
self.polyphonic_chars_new.discard(char)
|
||||
|
||||
self.monophonic_chars_dict = {char: phoneme for char, phoneme in self.monophonic_chars}
|
||||
for char in self.non_monophonic:
|
||||
if char in self.monophonic_chars_dict:
|
||||
self.monophonic_chars_dict.pop(char)
|
||||
self.monophonic_chars_dict.pop(char, None)
|
||||
|
||||
self.pos_tags = ["UNK", "A", "C", "D", "I", "N", "P", "T", "V", "DE", "SHI"]
|
||||
default_asset_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "G2PWModel"))
|
||||
candidate_asset_dirs = [self.model_dir, default_asset_dir]
|
||||
self.bopomofo_convert_dict = _load_json_from_candidates(
|
||||
"bopomofo_to_pinyin_wo_tune_dict.json", candidate_asset_dirs
|
||||
)
|
||||
self.char_bopomofo_dict = _load_json_from_candidates("char_bopomofo_dict.json", candidate_asset_dirs)
|
||||
|
||||
with open(os.path.join(uncompress_path, "bopomofo_to_pinyin_wo_tune_dict.json"), "r", encoding="utf-8") as fr:
|
||||
self.bopomofo_convert_dict = json.load(fr)
|
||||
self.style_convert_func = {
|
||||
"bopomofo": lambda x: x,
|
||||
"pinyin": self._convert_bopomofo_to_pinyin,
|
||||
}[style]
|
||||
|
||||
with open(os.path.join(uncompress_path, "char_bopomofo_dict.json"), "r", encoding="utf-8") as fr:
|
||||
self.char_bopomofo_dict = json.load(fr)
|
||||
|
||||
if self.enable_opencc:
|
||||
self.cc = OpenCC("s2tw")
|
||||
self.enable_sentence_dedup = os.getenv("g2pw_sentence_dedup", "true").strip().lower() in {
|
||||
"1",
|
||||
"true",
|
||||
"yes",
|
||||
"y",
|
||||
"on",
|
||||
}
|
||||
# 聚焦到多音字附近上下文,默认左右各16字;设为0表示关闭裁剪(整句)。
|
||||
self.polyphonic_context_chars = max(0, int(os.getenv("g2pw_polyphonic_context_chars", "16")))
|
||||
|
||||
def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
|
||||
tone = bopomofo[-1]
|
||||
@ -181,11 +214,14 @@ class G2PWOnnxConverter:
|
||||
component = self.bopomofo_convert_dict.get(bopomofo[:-1])
|
||||
if component:
|
||||
return component + tone
|
||||
else:
|
||||
print(f'Warning: "{bopomofo}" cannot convert to pinyin')
|
||||
return None
|
||||
print(f'Warning: "{bopomofo}" cannot convert to pinyin')
|
||||
return None
|
||||
|
||||
def __call__(self, sentences: List[str]) -> List[List[str]]:
|
||||
results, _profile = self.predict_sentences_with_profile(sentences)
|
||||
return results
|
||||
|
||||
def predict_sentences_with_profile(self, sentences: List[str]) -> Tuple[List[List[str]], Dict[str, float]]:
|
||||
if isinstance(sentences, str):
|
||||
sentences = [sentences]
|
||||
|
||||
@ -197,51 +233,202 @@ class G2PWOnnxConverter:
|
||||
translated_sentences.append(translated_sent)
|
||||
sentences = translated_sentences
|
||||
|
||||
texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences)
|
||||
texts, model_query_ids, result_query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences)
|
||||
if len(texts) == 0:
|
||||
# sentences no polyphonic words
|
||||
return partial_results
|
||||
return partial_results, {}
|
||||
|
||||
onnx_input = prepare_onnx_input(
|
||||
model_input = prepare_onnx_input(
|
||||
tokenizer=self.tokenizer,
|
||||
labels=self.labels,
|
||||
char2phonemes=self.char2phonemes,
|
||||
chars=self.chars,
|
||||
texts=texts,
|
||||
query_ids=query_ids,
|
||||
query_ids=model_query_ids,
|
||||
use_mask=self.config.use_mask,
|
||||
window_size=None,
|
||||
char2id=self.char2id,
|
||||
char_phoneme_masks=self.char_phoneme_masks,
|
||||
)
|
||||
|
||||
preds, confidences = predict(session=self.session_g2pW, onnx_input=onnx_input, labels=self.labels)
|
||||
if not model_input:
|
||||
return partial_results, {}
|
||||
|
||||
predict_profile: Dict[str, float] = {}
|
||||
if self.enable_sentence_dedup:
|
||||
preds, _confidences, predict_profile = self._predict_with_sentence_dedup_profiled(
|
||||
model_input=model_input,
|
||||
texts=texts,
|
||||
)
|
||||
else:
|
||||
if hasattr(self, "_predict_with_profile"):
|
||||
preds, _confidences, predict_profile = self._predict_with_profile(model_input=model_input)
|
||||
else:
|
||||
predict_started = time.perf_counter()
|
||||
preds, _confidences = self._predict(model_input=model_input)
|
||||
predict_profile["g2pw_predict_ms"] = float((time.perf_counter() - predict_started) * 1000.0)
|
||||
|
||||
if self.config.use_char_phoneme:
|
||||
preds = [pred.split(" ")[1] for pred in preds]
|
||||
|
||||
results = partial_results
|
||||
for sent_id, query_id, pred in zip(sent_ids, query_ids, preds):
|
||||
for sent_id, query_id, pred in zip(sent_ids, result_query_ids, preds):
|
||||
results[sent_id][query_id] = self.style_convert_func(pred)
|
||||
|
||||
return results
|
||||
return results, predict_profile
|
||||
|
||||
def _prepare_data(self, sentences: List[str]) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
|
||||
texts, query_ids, sent_ids, partial_results = [], [], [], []
|
||||
def _prepare_data(
|
||||
self, sentences: List[str]
|
||||
) -> Tuple[List[str], List[int], List[int], List[int], List[List[str]]]:
|
||||
texts, model_query_ids, result_query_ids, sent_ids, partial_results = [], [], [], [], []
|
||||
for sent_id, sent in enumerate(sentences):
|
||||
# pypinyin works well for Simplified Chinese than Traditional Chinese
|
||||
sent_s = tranditional_to_simplified(sent)
|
||||
pypinyin_result = pinyin(sent_s, neutral_tone_with_five=True, style=Style.TONE3)
|
||||
partial_result = [None] * len(sent)
|
||||
polyphonic_indices: List[int] = []
|
||||
for i, char in enumerate(sent):
|
||||
if char in self.polyphonic_chars_new:
|
||||
texts.append(sent)
|
||||
query_ids.append(i)
|
||||
sent_ids.append(sent_id)
|
||||
polyphonic_indices.append(i)
|
||||
elif char in self.monophonic_chars_dict:
|
||||
partial_result[i] = self.style_convert_func(self.monophonic_chars_dict[char])
|
||||
elif char in self.char_bopomofo_dict:
|
||||
partial_result[i] = pypinyin_result[i][0]
|
||||
# partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])
|
||||
else:
|
||||
partial_result[i] = pypinyin_result[i][0]
|
||||
|
||||
if polyphonic_indices:
|
||||
if self.polyphonic_context_chars > 0:
|
||||
left = max(0, polyphonic_indices[0] - self.polyphonic_context_chars)
|
||||
right = min(len(sent), polyphonic_indices[-1] + self.polyphonic_context_chars + 1)
|
||||
sent_for_predict = sent[left:right]
|
||||
query_offset = left
|
||||
else:
|
||||
sent_for_predict = sent
|
||||
query_offset = 0
|
||||
|
||||
for index in polyphonic_indices:
|
||||
texts.append(sent_for_predict)
|
||||
model_query_ids.append(index - query_offset)
|
||||
result_query_ids.append(index)
|
||||
sent_ids.append(sent_id)
|
||||
|
||||
partial_results.append(partial_result)
|
||||
return texts, query_ids, sent_ids, partial_results
|
||||
return texts, model_query_ids, result_query_ids, sent_ids, partial_results
|
||||
|
||||
def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _predict_with_sentence_dedup(
|
||||
self, model_input: Dict[str, Any], texts: List[str]
|
||||
) -> Tuple[List[str], List[float]]:
|
||||
if len(texts) <= 1:
|
||||
return self._predict(model_input=model_input)
|
||||
|
||||
grouped_indices: Dict[str, List[int]] = {}
|
||||
for idx, text in enumerate(texts):
|
||||
grouped_indices.setdefault(text, []).append(idx)
|
||||
|
||||
if all(len(indices) == 1 for indices in grouped_indices.values()):
|
||||
return self._predict(model_input=model_input)
|
||||
|
||||
preds: List[str] = [""] * len(texts)
|
||||
confidences: List[float] = [0.0] * len(texts)
|
||||
for indices in grouped_indices.values():
|
||||
group_input = {name: value[indices] for name, value in model_input.items()}
|
||||
if len(indices) > 1:
|
||||
for name in ("input_ids", "token_type_ids", "attention_masks"):
|
||||
group_input[name] = group_input[name][:1]
|
||||
|
||||
group_preds, group_confidences = self._predict(model_input=group_input)
|
||||
for output_idx, pred, confidence in zip(indices, group_preds, group_confidences):
|
||||
preds[output_idx] = pred
|
||||
confidences[output_idx] = confidence
|
||||
|
||||
return preds, confidences
|
||||
|
||||
def _predict_with_sentence_dedup_profiled(
|
||||
self,
|
||||
model_input: Dict[str, Any],
|
||||
texts: List[str],
|
||||
) -> Tuple[List[str], List[float], Dict[str, float]]:
|
||||
if len(texts) <= 1:
|
||||
if hasattr(self, "_predict_with_profile"):
|
||||
return self._predict_with_profile(model_input=model_input)
|
||||
predict_started = time.perf_counter()
|
||||
preds, confidences = self._predict(model_input=model_input)
|
||||
return preds, confidences, {"g2pw_predict_ms": float((time.perf_counter() - predict_started) * 1000.0)}
|
||||
|
||||
grouped_indices: Dict[str, List[int]] = {}
|
||||
for idx, text in enumerate(texts):
|
||||
grouped_indices.setdefault(text, []).append(idx)
|
||||
|
||||
if all(len(indices) == 1 for indices in grouped_indices.values()):
|
||||
if hasattr(self, "_predict_with_profile"):
|
||||
return self._predict_with_profile(model_input=model_input)
|
||||
predict_started = time.perf_counter()
|
||||
preds, confidences = self._predict(model_input=model_input)
|
||||
return preds, confidences, {"g2pw_predict_ms": float((time.perf_counter() - predict_started) * 1000.0)}
|
||||
|
||||
preds: List[str] = [""] * len(texts)
|
||||
confidences: List[float] = [0.0] * len(texts)
|
||||
merged_profile: Dict[str, float] = {}
|
||||
for indices in grouped_indices.values():
|
||||
group_input = {name: value[indices] for name, value in model_input.items()}
|
||||
if len(indices) > 1:
|
||||
for name in ("input_ids", "token_type_ids", "attention_masks"):
|
||||
group_input[name] = group_input[name][:1]
|
||||
if hasattr(self, "_predict_with_profile"):
|
||||
group_preds, group_confidences, group_profile = self._predict_with_profile(model_input=group_input)
|
||||
for key, value in dict(group_profile or {}).items():
|
||||
merged_profile[key] = float(merged_profile.get(key, 0.0)) + float(value)
|
||||
else:
|
||||
predict_started = time.perf_counter()
|
||||
group_preds, group_confidences = self._predict(model_input=group_input)
|
||||
merged_profile["g2pw_predict_ms"] = float(
|
||||
merged_profile.get("g2pw_predict_ms", 0.0) + (time.perf_counter() - predict_started) * 1000.0
|
||||
)
|
||||
for output_idx, pred, confidence in zip(indices, group_preds, group_confidences):
|
||||
preds[output_idx] = pred
|
||||
confidences[output_idx] = confidence
|
||||
return preds, confidences, merged_profile
|
||||
|
||||
|
||||
class G2PWOnnxConverter(_G2PWBaseOnnxConverter):
|
||||
def __init__(
|
||||
self,
|
||||
model_dir: str = "G2PWModel/",
|
||||
style: str = "bopomofo",
|
||||
model_source: str = None,
|
||||
enable_non_tradional_chinese: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
model_dir=model_dir,
|
||||
style=style,
|
||||
model_source=model_source,
|
||||
enable_non_tradional_chinese=enable_non_tradional_chinese,
|
||||
)
|
||||
|
||||
sess_options = onnxruntime.SessionOptions()
|
||||
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
|
||||
sess_options.intra_op_num_threads = 2
|
||||
|
||||
onnx_path = _find_first_existing_file(
|
||||
os.path.join(self.model_dir, "g2pW.onnx"),
|
||||
os.path.join(self.model_dir, "g2pw.onnx"),
|
||||
)
|
||||
|
||||
if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
|
||||
self.session_g2pw = onnxruntime.InferenceSession(
|
||||
onnx_path,
|
||||
sess_options=sess_options,
|
||||
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
|
||||
)
|
||||
else:
|
||||
self.session_g2pw = onnxruntime.InferenceSession(
|
||||
onnx_path,
|
||||
sess_options=sess_options,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]:
|
||||
return predict(session=self.session_g2pw, onnx_input=model_input, labels=self.labels)
|
||||
|
||||
285
api_v2.py
285
api_v2.py
@ -39,8 +39,8 @@ POST:
|
||||
"seed": -1, # int. random seed for reproducibility.
|
||||
"parallel_infer": True, # bool. whether to use parallel inference.
|
||||
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
|
||||
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
|
||||
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
|
||||
"sample_steps": 32, # int. 仅 v3/v4 vocoder 路径使用;当前 v2/v2ProPlus 主线可忽略。
|
||||
"super_sampling": False, # bool. 仅 v3/v4 路径使用;不属于当前 v2/v2ProPlus 正式支持目标。
|
||||
"streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed )
|
||||
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
|
||||
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
|
||||
@ -79,7 +79,7 @@ endpoint: `/set_gpt_weights`
|
||||
|
||||
GET:
|
||||
```
|
||||
http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
|
||||
http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1v3.ckpt
|
||||
```
|
||||
RESP:
|
||||
成功: 返回"success", http code 200
|
||||
@ -92,7 +92,7 @@ endpoint: `/set_sovits_weights`
|
||||
|
||||
GET:
|
||||
```
|
||||
http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/s2G488k.pth
|
||||
http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth
|
||||
```
|
||||
|
||||
RESP:
|
||||
@ -104,27 +104,22 @@ RESP:
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Generator, Union
|
||||
from typing import Union
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
sys.path.append("%s/GPT_SoVITS" % (now_dir))
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
import wave
|
||||
import signal
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
from fastapi import FastAPI, Response
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
import uvicorn
|
||||
from io import BytesIO
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
|
||||
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine import RuntimeControlCallbacks, UnifiedTTSEngine
|
||||
from pydantic import BaseModel
|
||||
import threading
|
||||
|
||||
# print(sys.path)
|
||||
i18n = I18nAuto()
|
||||
@ -147,6 +142,14 @@ if config_path in [None, ""]:
|
||||
tts_config = TTS_Config(config_path)
|
||||
print(tts_config)
|
||||
tts_pipeline = TTS(tts_config)
|
||||
tts_engine = UnifiedTTSEngine(
|
||||
tts_pipeline,
|
||||
cut_method_names=cut_method_names,
|
||||
control_callbacks=RuntimeControlCallbacks(
|
||||
restart=lambda: os.execl(sys.executable, sys.executable, *argv),
|
||||
exit=lambda: os.kill(os.getpid(), signal.SIGTERM),
|
||||
),
|
||||
)
|
||||
|
||||
APP = FastAPI()
|
||||
|
||||
@ -178,168 +181,8 @@ class TTS_Request(BaseModel):
|
||||
min_chunk_length: int = 16
|
||||
|
||||
|
||||
def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
|
||||
# Author: AkagawaTsurunaki
|
||||
# Issue:
|
||||
# Stack overflow probabilistically occurs
|
||||
# when the function `sf_writef_short` of `libsndfile_64bit.dll` is called
|
||||
# using the Python library `soundfile`
|
||||
# Note:
|
||||
# This is an issue related to `libsndfile`, not this project itself.
|
||||
# It happens when you generate a large audio tensor (about 499804 frames in my PC)
|
||||
# and try to convert it to an ogg file.
|
||||
# Related:
|
||||
# https://github.com/RVC-Boss/GPT-SoVITS/issues/1199
|
||||
# https://github.com/libsndfile/libsndfile/issues/1023
|
||||
# https://github.com/bastibe/python-soundfile/issues/396
|
||||
# Suggestion:
|
||||
# Or split the whole audio data into smaller audio segment to avoid stack overflow?
|
||||
|
||||
def handle_pack_ogg():
|
||||
with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file:
|
||||
audio_file.write(data)
|
||||
|
||||
|
||||
|
||||
# See: https://docs.python.org/3/library/threading.html
|
||||
# The stack size of this thread is at least 32768
|
||||
# If stack overflow error still occurs, just modify the `stack_size`.
|
||||
# stack_size = n * 4096, where n should be a positive integer.
|
||||
# Here we chose n = 4096.
|
||||
stack_size = 4096 * 4096
|
||||
try:
|
||||
threading.stack_size(stack_size)
|
||||
pack_ogg_thread = threading.Thread(target=handle_pack_ogg)
|
||||
pack_ogg_thread.start()
|
||||
pack_ogg_thread.join()
|
||||
except RuntimeError as e:
|
||||
# If changing the thread stack size is unsupported, a RuntimeError is raised.
|
||||
print("RuntimeError: {}".format(e))
|
||||
print("Changing the thread stack size is unsupported.")
|
||||
except ValueError as e:
|
||||
# If the specified stack size is invalid, a ValueError is raised and the stack size is unmodified.
|
||||
print("ValueError: {}".format(e))
|
||||
print("The specified stack size is invalid.")
|
||||
|
||||
return io_buffer
|
||||
|
||||
|
||||
def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int):
|
||||
io_buffer.write(data.tobytes())
|
||||
return io_buffer
|
||||
|
||||
|
||||
def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int):
|
||||
io_buffer = BytesIO()
|
||||
sf.write(io_buffer, data, rate, format="wav")
|
||||
return io_buffer
|
||||
|
||||
|
||||
def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int):
|
||||
process = subprocess.Popen(
|
||||
[
|
||||
"ffmpeg",
|
||||
"-f",
|
||||
"s16le", # 输入16位有符号小端整数PCM
|
||||
"-ar",
|
||||
str(rate), # 设置采样率
|
||||
"-ac",
|
||||
"1", # 单声道
|
||||
"-i",
|
||||
"pipe:0", # 从管道读取输入
|
||||
"-c:a",
|
||||
"aac", # 音频编码器为AAC
|
||||
"-b:a",
|
||||
"192k", # 比特率
|
||||
"-vn", # 不包含视频
|
||||
"-f",
|
||||
"adts", # 输出AAC数据流格式
|
||||
"pipe:1", # 将输出写入管道
|
||||
],
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
out, _ = process.communicate(input=data.tobytes())
|
||||
io_buffer.write(out)
|
||||
return io_buffer
|
||||
|
||||
|
||||
def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str):
|
||||
if media_type == "ogg":
|
||||
io_buffer = pack_ogg(io_buffer, data, rate)
|
||||
elif media_type == "aac":
|
||||
io_buffer = pack_aac(io_buffer, data, rate)
|
||||
elif media_type == "wav":
|
||||
io_buffer = pack_wav(io_buffer, data, rate)
|
||||
else:
|
||||
io_buffer = pack_raw(io_buffer, data, rate)
|
||||
io_buffer.seek(0)
|
||||
return io_buffer
|
||||
|
||||
|
||||
# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py
|
||||
def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
|
||||
# This will create a wave header then append the frame input
|
||||
# It should be first on a streaming wav file
|
||||
# Other frames better should not have it (else you will hear some artifacts each chunk start)
|
||||
wav_buf = BytesIO()
|
||||
with wave.open(wav_buf, "wb") as vfout:
|
||||
vfout.setnchannels(channels)
|
||||
vfout.setsampwidth(sample_width)
|
||||
vfout.setframerate(sample_rate)
|
||||
vfout.writeframes(frame_input)
|
||||
|
||||
wav_buf.seek(0)
|
||||
return wav_buf.read()
|
||||
|
||||
|
||||
def handle_control(command: str):
|
||||
if command == "restart":
|
||||
os.execl(sys.executable, sys.executable, *argv)
|
||||
elif command == "exit":
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
exit(0)
|
||||
|
||||
|
||||
def check_params(req: dict):
|
||||
text: str = req.get("text", "")
|
||||
text_lang: str = req.get("text_lang", "")
|
||||
ref_audio_path: str = req.get("ref_audio_path", "")
|
||||
streaming_mode: bool = req.get("streaming_mode", False)
|
||||
media_type: str = req.get("media_type", "wav")
|
||||
prompt_lang: str = req.get("prompt_lang", "")
|
||||
text_split_method: str = req.get("text_split_method", "cut5")
|
||||
|
||||
if ref_audio_path in [None, ""]:
|
||||
return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
|
||||
if text in [None, ""]:
|
||||
return JSONResponse(status_code=400, content={"message": "text is required"})
|
||||
if text_lang in [None, ""]:
|
||||
return JSONResponse(status_code=400, content={"message": "text_lang is required"})
|
||||
elif text_lang.lower() not in tts_config.languages:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"message": f"text_lang: {text_lang} is not supported in version {tts_config.version}"},
|
||||
)
|
||||
if prompt_lang in [None, ""]:
|
||||
return JSONResponse(status_code=400, content={"message": "prompt_lang is required"})
|
||||
elif prompt_lang.lower() not in tts_config.languages:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"message": f"prompt_lang: {prompt_lang} is not supported in version {tts_config.version}"},
|
||||
)
|
||||
if media_type not in ["wav", "raw", "ogg", "aac"]:
|
||||
return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"})
|
||||
# elif media_type == "ogg" and not streaming_mode:
|
||||
# return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
|
||||
|
||||
if text_split_method not in cut_method_names:
|
||||
return JSONResponse(
|
||||
status_code=400, content={"message": f"text_split_method:{text_split_method} is not supported"}
|
||||
)
|
||||
|
||||
return None
|
||||
def _lower_or_none(value: str | None) -> str | None:
|
||||
return value.lower() if isinstance(value, str) else value
|
||||
|
||||
|
||||
async def tts_handle(req: dict):
|
||||
@ -368,7 +211,7 @@ async def tts_handle(req: dict):
|
||||
"parallel_infer": True, # bool. whether to use parallel inference.
|
||||
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
|
||||
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
|
||||
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
|
||||
"super_sampling": False, # bool. only for v3/v4; not part of current v2/v2ProPlus mainline.
|
||||
"streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed )
|
||||
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
|
||||
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
|
||||
@ -377,70 +220,11 @@ async def tts_handle(req: dict):
|
||||
StreamingResponse: audio stream response.
|
||||
"""
|
||||
|
||||
streaming_mode = req.get("streaming_mode", False)
|
||||
return_fragment = req.get("return_fragment", False)
|
||||
media_type = req.get("media_type", "wav")
|
||||
|
||||
check_res = check_params(req)
|
||||
if check_res is not None:
|
||||
return check_res
|
||||
|
||||
if streaming_mode == 0:
|
||||
streaming_mode = False
|
||||
return_fragment = False
|
||||
fixed_length_chunk = False
|
||||
elif streaming_mode == 1:
|
||||
streaming_mode = False
|
||||
return_fragment = True
|
||||
fixed_length_chunk = False
|
||||
elif streaming_mode == 2:
|
||||
streaming_mode = True
|
||||
return_fragment = False
|
||||
fixed_length_chunk = False
|
||||
elif streaming_mode == 3:
|
||||
streaming_mode = True
|
||||
return_fragment = False
|
||||
fixed_length_chunk = True
|
||||
|
||||
else:
|
||||
return JSONResponse(status_code=400, content={"message": f"the value of streaming_mode must be 0, 1, 2, 3(int) or true/false(bool)"})
|
||||
|
||||
req["streaming_mode"] = streaming_mode
|
||||
req["return_fragment"] = return_fragment
|
||||
req["fixed_length_chunk"] = fixed_length_chunk
|
||||
|
||||
print(f"{streaming_mode} {return_fragment} {fixed_length_chunk}")
|
||||
|
||||
streaming_mode = streaming_mode or return_fragment
|
||||
|
||||
|
||||
try:
|
||||
tts_generator = tts_pipeline.run(req)
|
||||
|
||||
if streaming_mode:
|
||||
|
||||
def streaming_generator(tts_generator: Generator, media_type: str):
|
||||
if_frist_chunk = True
|
||||
for sr, chunk in tts_generator:
|
||||
if if_frist_chunk and media_type == "wav":
|
||||
yield wave_header_chunk(sample_rate=sr)
|
||||
media_type = "raw"
|
||||
if_frist_chunk = False
|
||||
yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
|
||||
|
||||
# _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
|
||||
return StreamingResponse(
|
||||
streaming_generator(
|
||||
tts_generator,
|
||||
media_type,
|
||||
),
|
||||
media_type=f"audio/{media_type}",
|
||||
)
|
||||
|
||||
else:
|
||||
sr, audio_data = next(tts_generator)
|
||||
audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
|
||||
return Response(audio_data, media_type=f"audio/{media_type}")
|
||||
result = await tts_engine.run_direct_tts_async(req)
|
||||
if result.streaming:
|
||||
return StreamingResponse(result.audio_generator, media_type=f"audio/{result.media_type}")
|
||||
return Response(result.audio_bytes, media_type=f"audio/{result.media_type}")
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(e)})
|
||||
|
||||
@ -449,7 +233,11 @@ async def tts_handle(req: dict):
|
||||
async def control(command: str = None):
|
||||
if command is None:
|
||||
return JSONResponse(status_code=400, content={"message": "command is required"})
|
||||
handle_control(command)
|
||||
try:
|
||||
tts_engine.handle_control(command)
|
||||
return JSONResponse(status_code=200, content={"message": "success"})
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=400, content={"message": "control failed", "Exception": str(e)})
|
||||
|
||||
|
||||
@APP.get("/tts")
|
||||
@ -481,11 +269,11 @@ async def tts_get_endpoint(
|
||||
):
|
||||
req = {
|
||||
"text": text,
|
||||
"text_lang": text_lang.lower(),
|
||||
"text_lang": _lower_or_none(text_lang),
|
||||
"ref_audio_path": ref_audio_path,
|
||||
"aux_ref_audio_paths": aux_ref_audio_paths,
|
||||
"prompt_text": prompt_text,
|
||||
"prompt_lang": prompt_lang.lower(),
|
||||
"prompt_lang": _lower_or_none(prompt_lang),
|
||||
"top_k": top_k,
|
||||
"top_p": top_p,
|
||||
"temperature": temperature,
|
||||
@ -517,10 +305,10 @@ async def tts_post_endpoint(request: TTS_Request):
|
||||
@APP.get("/set_refer_audio")
|
||||
async def set_refer_aduio(refer_audio_path: str = None):
|
||||
try:
|
||||
tts_pipeline.set_ref_audio(refer_audio_path)
|
||||
payload = tts_engine.set_refer_audio(refer_audio_path)
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=400, content={"message": "set refer audio failed", "Exception": str(e)})
|
||||
return JSONResponse(status_code=200, content={"message": "success"})
|
||||
return JSONResponse(status_code=200, content=payload)
|
||||
|
||||
|
||||
# @APP.post("/set_refer_audio")
|
||||
@ -545,24 +333,19 @@ async def set_refer_aduio(refer_audio_path: str = None):
|
||||
@APP.get("/set_gpt_weights")
|
||||
async def set_gpt_weights(weights_path: str = None):
|
||||
try:
|
||||
if weights_path in ["", None]:
|
||||
return JSONResponse(status_code=400, content={"message": "gpt weight path is required"})
|
||||
tts_pipeline.init_t2s_weights(weights_path)
|
||||
payload = tts_engine.set_gpt_weights(weights_path)
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=400, content={"message": "change gpt weight failed", "Exception": str(e)})
|
||||
|
||||
return JSONResponse(status_code=200, content={"message": "success"})
|
||||
return JSONResponse(status_code=200, content=payload)
|
||||
|
||||
|
||||
@APP.get("/set_sovits_weights")
|
||||
async def set_sovits_weights(weights_path: str = None):
|
||||
try:
|
||||
if weights_path in ["", None]:
|
||||
return JSONResponse(status_code=400, content={"message": "sovits weight path is required"})
|
||||
tts_pipeline.init_vits_weights(weights_path)
|
||||
payload = tts_engine.set_sovits_weights(weights_path)
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=400, content={"message": "change sovits weight failed", "Exception": str(e)})
|
||||
return JSONResponse(status_code=200, content={"message": "success"})
|
||||
return JSONResponse(status_code=200, content=payload)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
443
api_v3.py
Normal file
443
api_v3.py
Normal file
@ -0,0 +1,443 @@
|
||||
"""
|
||||
# WebAPI文档
|
||||
|
||||
` python api_v2.py -a 127.0.0.1 -p 9880 -c GPT_SoVITS/configs/tts_infer.yaml `
|
||||
|
||||
## 执行参数:
|
||||
`-a` - `绑定地址, 默认"127.0.0.1"`
|
||||
`-p` - `绑定端口, 默认9880`
|
||||
`-c` - `TTS配置文件路径, 默认"GPT_SoVITS/configs/tts_infer.yaml"`
|
||||
|
||||
## 调用:
|
||||
|
||||
### 推理
|
||||
|
||||
endpoint: `/tts`
|
||||
GET:
|
||||
```
|
||||
http://127.0.0.1:9880/tts?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_lang=zh&ref_audio_path=archive_jingyuan_1.wav&prompt_lang=zh&prompt_text=我是「罗浮」云骑将军景元。不必拘谨,「将军」只是一时的身份,你称呼我景元便可&text_split_method=cut5&batch_size=1&media_type=wav&streaming_mode=true
|
||||
```
|
||||
|
||||
POST:
|
||||
```json
|
||||
{
|
||||
"text": "", # str.(required) text to be synthesized
|
||||
"text_lang: "", # str.(required) language of the text to be synthesized
|
||||
"ref_audio_path": "", # str.(required) reference audio path
|
||||
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
|
||||
"prompt_text": "", # str.(optional) prompt text for the reference audio
|
||||
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
|
||||
"top_k": 15, # int. top k sampling
|
||||
"top_p": 1, # float. top p sampling
|
||||
"temperature": 1, # float. temperature for sampling
|
||||
"text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
|
||||
"batch_size": 1, # int. batch size for inference
|
||||
"batch_threshold": 0.75, # float. threshold for batch splitting.
|
||||
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
|
||||
"speed_factor":1.0, # float. control the speed of the synthesized audio.
|
||||
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
|
||||
"seed": -1, # int. random seed for reproducibility.
|
||||
"parallel_infer": True, # bool. whether to use parallel inference.
|
||||
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
|
||||
"sample_steps": 32, # int. 仅 v3/v4 vocoder 路径使用;当前 v2/v2ProPlus 主线可忽略。
|
||||
"super_sampling": False, # bool. 仅 v3/v4 路径使用;不属于当前 v2/v2ProPlus 正式支持目标。
|
||||
"streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed )
|
||||
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
|
||||
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
|
||||
}
|
||||
```
|
||||
|
||||
RESP:
|
||||
成功: 直接返回 wav 音频流, http code 200
|
||||
失败: 返回包含错误信息的 json, http code 400
|
||||
|
||||
### 命令控制
|
||||
|
||||
endpoint: `/control`
|
||||
|
||||
command:
|
||||
"restart": 重新运行
|
||||
"exit": 结束运行
|
||||
|
||||
GET:
|
||||
```
|
||||
http://127.0.0.1:9880/control?command=restart
|
||||
```
|
||||
POST:
|
||||
```json
|
||||
{
|
||||
"command": "restart"
|
||||
}
|
||||
```
|
||||
|
||||
RESP: 无
|
||||
|
||||
|
||||
### 切换GPT模型
|
||||
|
||||
endpoint: `/set_gpt_weights`
|
||||
|
||||
GET:
|
||||
```
|
||||
http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1v3.ckpt
|
||||
```
|
||||
RESP:
|
||||
成功: 返回"success", http code 200
|
||||
失败: 返回包含错误信息的 json, http code 400
|
||||
|
||||
|
||||
### 切换Sovits模型
|
||||
|
||||
endpoint: `/set_sovits_weights`
|
||||
|
||||
GET:
|
||||
```
|
||||
http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth
|
||||
```
|
||||
|
||||
RESP:
|
||||
成功: 返回"success", http code 200
|
||||
失败: 返回包含错误信息的 json, http code 400
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from typing import List, Union
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
sys.path.append("%s/GPT_SoVITS" % (now_dir))
|
||||
|
||||
from runtime_preload import preload_text_runtime_deps
|
||||
|
||||
preload_text_runtime_deps()
|
||||
|
||||
import argparse
|
||||
import signal
|
||||
from fastapi import FastAPI, Response
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
import uvicorn
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
|
||||
from GPT_SoVITS.TTS_infer_pack.unified_engine import RuntimeControlCallbacks, UnifiedTTSEngine
|
||||
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
|
||||
from pydantic import BaseModel
|
||||
|
||||
# print(sys.path)
|
||||
i18n = I18nAuto()
|
||||
cut_method_names = get_cut_method_names()
|
||||
|
||||
parser = argparse.ArgumentParser(description="GPT-SoVITS api")
|
||||
parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
|
||||
parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
|
||||
parser.add_argument("-p", "--port", type=int, default="9880", help="default: 9880")
|
||||
args = parser.parse_args()
|
||||
config_path = args.tts_config
|
||||
# device = args.device
|
||||
port = args.port
|
||||
host = args.bind_addr
|
||||
argv = sys.argv
|
||||
|
||||
if config_path in [None, ""]:
|
||||
config_path = "GPT-SoVITS/configs/tts_infer.yaml"
|
||||
|
||||
tts_config = TTS_Config(config_path)
|
||||
print(tts_config)
|
||||
tts_pipeline = TTS(tts_config)
|
||||
tts_engine = UnifiedTTSEngine(
|
||||
tts_pipeline,
|
||||
cut_method_names=cut_method_names,
|
||||
control_callbacks=RuntimeControlCallbacks(
|
||||
restart=lambda: os.execl(sys.executable, sys.executable, *argv),
|
||||
exit=lambda: os.kill(os.getpid(), signal.SIGTERM),
|
||||
),
|
||||
)
|
||||
|
||||
APP = FastAPI()
|
||||
|
||||
|
||||
class TTS_Request(BaseModel):
|
||||
text: str = None
|
||||
text_lang: str = None
|
||||
ref_audio_path: str = None
|
||||
aux_ref_audio_paths: list = None
|
||||
prompt_lang: str = None
|
||||
prompt_text: str = ""
|
||||
top_k: int = 15
|
||||
top_p: float = 1
|
||||
temperature: float = 1
|
||||
text_split_method: str = "cut5"
|
||||
batch_size: int = 1
|
||||
batch_threshold: float = 0.75
|
||||
split_bucket: bool = True
|
||||
speed_factor: float = 1.0
|
||||
fragment_interval: float = 0.3
|
||||
seed: int = -1
|
||||
media_type: str = "wav"
|
||||
streaming_mode: Union[bool, int] = False
|
||||
parallel_infer: bool = True
|
||||
repetition_penalty: float = 1.35
|
||||
sample_steps: int = 32
|
||||
super_sampling: bool = False
|
||||
overlap_length: int = 2
|
||||
min_chunk_length: int = 16
|
||||
|
||||
|
||||
class Scheduler_Debug_Request_Item(BaseModel):
|
||||
request_id: str | None = None
|
||||
text: str
|
||||
text_lang: str
|
||||
ref_audio_path: str
|
||||
prompt_lang: str
|
||||
prompt_text: str = ""
|
||||
top_k: int = 15
|
||||
top_p: float = 1
|
||||
temperature: float = 1
|
||||
repetition_penalty: float = 1.35
|
||||
early_stop_num: int = -1
|
||||
ready_step: int = 0
|
||||
|
||||
|
||||
class Scheduler_Debug_Request(BaseModel):
|
||||
requests: List[Scheduler_Debug_Request_Item]
|
||||
max_steps: int = 1500
|
||||
seed: int = -1
|
||||
|
||||
|
||||
class Scheduler_Submit_Request(BaseModel):
|
||||
request_id: str | None = None
|
||||
text: str
|
||||
text_lang: str
|
||||
ref_audio_path: str
|
||||
prompt_lang: str
|
||||
prompt_text: str = ""
|
||||
top_k: int = 15
|
||||
top_p: float = 1
|
||||
temperature: float = 1
|
||||
repetition_penalty: float = 1.35
|
||||
early_stop_num: int = -1
|
||||
speed_factor: float = 1.0
|
||||
sample_steps: int = 32
|
||||
media_type: str = "wav"
|
||||
timeout_sec: float = 30.0
|
||||
|
||||
|
||||
def _lower_or_none(value: str | None) -> str | None:
|
||||
return value.lower() if isinstance(value, str) else value
|
||||
|
||||
|
||||
async def tts_scheduler_debug_handle(request: Scheduler_Debug_Request):
|
||||
try:
|
||||
result = await tts_engine.run_scheduler_debug(
|
||||
request_items=[item.dict() for item in request.requests],
|
||||
max_steps=int(request.max_steps),
|
||||
seed=int(request.seed),
|
||||
)
|
||||
return JSONResponse(status_code=200, content=result.payload)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"message": "scheduler debug failed", "Exception": str(e)},
|
||||
)
|
||||
|
||||
|
||||
async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request):
|
||||
try:
|
||||
result = await tts_engine.run_scheduler_submit(request.dict())
|
||||
return Response(result.audio_bytes, media_type=result.media_type, headers=result.headers)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"message": "scheduler submit failed", "Exception": str(e)},
|
||||
)
|
||||
|
||||
|
||||
async def tts_handle(req: dict):
|
||||
"""
|
||||
Text to speech handler.
|
||||
|
||||
Args:
|
||||
req (dict):
|
||||
{
|
||||
"text": "", # str.(required) text to be synthesized
|
||||
"text_lang: "", # str.(required) language of the text to be synthesized
|
||||
"ref_audio_path": "", # str.(required) reference audio path
|
||||
"aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
|
||||
"prompt_text": "", # str.(optional) prompt text for the reference audio
|
||||
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
|
||||
"top_k": 15, # int. top k sampling
|
||||
"top_p": 1, # float. top p sampling
|
||||
"temperature": 1, # float. temperature for sampling
|
||||
"text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
|
||||
"batch_size": 1, # int. batch size for inference
|
||||
"batch_threshold": 0.75, # float. threshold for batch splitting.
|
||||
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
|
||||
"speed_factor":1.0, # float. control the speed of the synthesized audio.
|
||||
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
|
||||
"seed": -1, # int. random seed for reproducibility.
|
||||
"parallel_infer": True, # bool. whether to use parallel inference.
|
||||
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
|
||||
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
|
||||
"super_sampling": False, # bool. only for v3/v4; not part of current v2/v2ProPlus mainline.
|
||||
"streaming_mode": False, # bool or int. return audio chunk by chunk.T he available options are: 0,1,2,3 or True/False (0/False: Disabled | 1/True: Best Quality, Slowest response speed (old version streaming_mode) | 2: Medium Quality, Slow response speed | 3: Lower Quality, Faster response speed )
|
||||
"overlap_length": 2, # int. overlap length of semantic tokens for streaming mode.
|
||||
"min_chunk_length": 16, # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
|
||||
}
|
||||
returns:
|
||||
StreamingResponse: audio stream response.
|
||||
"""
|
||||
|
||||
try:
|
||||
result = await tts_engine.run_direct_tts_async(req)
|
||||
if result.streaming:
|
||||
return StreamingResponse(result.audio_generator, media_type=f"audio/{result.media_type}")
|
||||
return Response(result.audio_bytes, media_type=f"audio/{result.media_type}")
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(e)})
|
||||
|
||||
|
||||
@APP.get("/control")
|
||||
async def control(command: str = None):
|
||||
if command is None:
|
||||
return JSONResponse(status_code=400, content={"message": "command is required"})
|
||||
try:
|
||||
tts_engine.handle_control(command)
|
||||
return JSONResponse(status_code=200, content={"message": "success"})
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=400, content={"message": "control failed", "Exception": str(e)})
|
||||
|
||||
|
||||
@APP.get("/tts")
|
||||
async def tts_get_endpoint(
|
||||
text: str = None,
|
||||
text_lang: str = None,
|
||||
ref_audio_path: str = None,
|
||||
aux_ref_audio_paths: list = None,
|
||||
prompt_lang: str = None,
|
||||
prompt_text: str = "",
|
||||
top_k: int = 15,
|
||||
top_p: float = 1,
|
||||
temperature: float = 1,
|
||||
text_split_method: str = "cut5",
|
||||
batch_size: int = 1,
|
||||
batch_threshold: float = 0.75,
|
||||
split_bucket: bool = True,
|
||||
speed_factor: float = 1.0,
|
||||
fragment_interval: float = 0.3,
|
||||
seed: int = -1,
|
||||
media_type: str = "wav",
|
||||
parallel_infer: bool = True,
|
||||
repetition_penalty: float = 1.35,
|
||||
sample_steps: int = 32,
|
||||
super_sampling: bool = False,
|
||||
streaming_mode: Union[bool, int] = False,
|
||||
overlap_length: int = 2,
|
||||
min_chunk_length: int = 16,
|
||||
):
|
||||
req = {
|
||||
"text": text,
|
||||
"text_lang": _lower_or_none(text_lang),
|
||||
"ref_audio_path": ref_audio_path,
|
||||
"aux_ref_audio_paths": aux_ref_audio_paths,
|
||||
"prompt_text": prompt_text,
|
||||
"prompt_lang": _lower_or_none(prompt_lang),
|
||||
"top_k": top_k,
|
||||
"top_p": top_p,
|
||||
"temperature": temperature,
|
||||
"text_split_method": text_split_method,
|
||||
"batch_size": int(batch_size),
|
||||
"batch_threshold": float(batch_threshold),
|
||||
"speed_factor": float(speed_factor),
|
||||
"split_bucket": split_bucket,
|
||||
"fragment_interval": fragment_interval,
|
||||
"seed": seed,
|
||||
"media_type": media_type,
|
||||
"streaming_mode": streaming_mode,
|
||||
"parallel_infer": parallel_infer,
|
||||
"repetition_penalty": float(repetition_penalty),
|
||||
"sample_steps": int(sample_steps),
|
||||
"super_sampling": super_sampling,
|
||||
"overlap_length": int(overlap_length),
|
||||
"min_chunk_length": int(min_chunk_length),
|
||||
}
|
||||
return await tts_handle(req)
|
||||
|
||||
|
||||
@APP.post("/tts")
|
||||
async def tts_post_endpoint(request: TTS_Request):
|
||||
req = request.dict()
|
||||
return await tts_handle(req)
|
||||
|
||||
|
||||
@APP.post("/tts_scheduler_debug")
|
||||
async def tts_scheduler_debug_endpoint(request: Scheduler_Debug_Request):
|
||||
return await tts_scheduler_debug_handle(request)
|
||||
|
||||
|
||||
@APP.post("/tts_scheduler_submit")
|
||||
async def tts_scheduler_submit_endpoint(request: Scheduler_Submit_Request):
|
||||
return await tts_scheduler_submit_handle(request)
|
||||
|
||||
|
||||
@APP.get("/tts_scheduler_state")
|
||||
async def tts_scheduler_state_endpoint():
|
||||
return JSONResponse(status_code=200, content=tts_engine.get_runtime_state())
|
||||
|
||||
|
||||
@APP.get("/set_refer_audio")
|
||||
async def set_refer_aduio(refer_audio_path: str = None):
|
||||
try:
|
||||
payload = tts_engine.set_refer_audio(refer_audio_path)
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=400, content={"message": "set refer audio failed", "Exception": str(e)})
|
||||
return JSONResponse(status_code=200, content=payload)
|
||||
|
||||
|
||||
# @APP.post("/set_refer_audio")
|
||||
# async def set_refer_aduio_post(audio_file: UploadFile = File(...)):
|
||||
# try:
|
||||
# # 检查文件类型,确保是音频文件
|
||||
# if not audio_file.content_type.startswith("audio/"):
|
||||
# return JSONResponse(status_code=400, content={"message": "file type is not supported"})
|
||||
|
||||
# os.makedirs("uploaded_audio", exist_ok=True)
|
||||
# save_path = os.path.join("uploaded_audio", audio_file.filename)
|
||||
# # 保存音频文件到服务器上的一个目录
|
||||
# with open(save_path , "wb") as buffer:
|
||||
# buffer.write(await audio_file.read())
|
||||
|
||||
# tts_pipeline.set_ref_audio(save_path)
|
||||
# except Exception as e:
|
||||
# return JSONResponse(status_code=400, content={"message": f"set refer audio failed", "Exception": str(e)})
|
||||
# return JSONResponse(status_code=200, content={"message": "success"})
|
||||
|
||||
|
||||
@APP.get("/set_gpt_weights")
|
||||
async def set_gpt_weights(weights_path: str = None):
|
||||
try:
|
||||
payload = tts_engine.set_gpt_weights(weights_path)
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=400, content={"message": "change gpt weight failed", "Exception": str(e)})
|
||||
return JSONResponse(status_code=200, content=payload)
|
||||
|
||||
|
||||
@APP.get("/set_sovits_weights")
|
||||
async def set_sovits_weights(weights_path: str = None):
|
||||
try:
|
||||
payload = tts_engine.set_sovits_weights(weights_path)
|
||||
except Exception as e:
|
||||
return JSONResponse(status_code=400, content={"message": "change sovits weight failed", "Exception": str(e)})
|
||||
return JSONResponse(status_code=200, content=payload)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
if host == "None": # 在调用时使用 -a None 参数,可以让api监听双栈
|
||||
host = None
|
||||
uvicorn.run(app=APP, host=host, port=port, workers=1)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
exit(0)
|
||||
1
third_party/g2pw-cu
vendored
Submodule
1
third_party/g2pw-cu
vendored
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit a53cf4eed5759f7b5d4563ce6e4b13557e054d98
|
||||
250
tools/bench_api_v3_scheduler_submit.py
Normal file
250
tools/bench_api_v3_scheduler_submit.py
Normal file
@ -0,0 +1,250 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
ROOT_DIR = Path(__file__).resolve().parents[1]
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Benchmark api_v3 /tts_scheduler_submit concurrency and GPU memory.")
|
||||
parser.add_argument("--base-url", type=str, default="http://127.0.0.1:9880")
|
||||
parser.add_argument("--endpoint", type=str, default="/tts_scheduler_submit")
|
||||
parser.add_argument("--concurrency", type=int, required=True)
|
||||
parser.add_argument("--timeout-sec", type=float, default=120.0)
|
||||
parser.add_argument("--server-pid", type=int, default=None)
|
||||
parser.add_argument("--poll-interval-sec", type=float, default=0.1)
|
||||
parser.add_argument("--text-lang", type=str, default="zh")
|
||||
parser.add_argument("--prompt-lang", type=str, default="zh")
|
||||
parser.add_argument("--media-type", type=str, default="wav")
|
||||
parser.add_argument("--top-k", type=int, default=15)
|
||||
parser.add_argument("--top-p", type=float, default=1.0)
|
||||
parser.add_argument("--temperature", type=float, default=1.0)
|
||||
parser.add_argument("--repetition-penalty", type=float, default=1.35)
|
||||
parser.add_argument("--sample-steps", type=int, default=32)
|
||||
parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_cn.txt")
|
||||
parser.add_argument("--wav-dir", type=Path, default=ROOT_DIR / "testwav")
|
||||
parser.add_argument("--output-dir", type=Path, default=ROOT_DIR / "TEMP/api_v3_bench")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_requests(args: argparse.Namespace) -> List[Dict[str, Any]]:
|
||||
wav_paths_all = sorted(args.wav_dir.glob("*.wav"))
|
||||
wav_paths: List[Path] = []
|
||||
for wav_path in wav_paths_all:
|
||||
with wave.open(str(wav_path), "rb") as handle:
|
||||
duration = handle.getnframes() / float(handle.getframerate())
|
||||
if 3.0 <= duration <= 10.0:
|
||||
wav_paths.append(wav_path)
|
||||
if not wav_paths:
|
||||
raise FileNotFoundError(f"没有找到 3-10 秒合法 wav: {args.wav_dir}")
|
||||
text_lines = [line.strip() for line in args.text_file.read_text(encoding="utf-8").splitlines() if line.strip()]
|
||||
if not text_lines:
|
||||
raise ValueError(f"没有找到有效文本行: {args.text_file}")
|
||||
|
||||
requests: List[Dict[str, Any]] = []
|
||||
for index in range(args.concurrency):
|
||||
wav_path = wav_paths[index % len(wav_paths)]
|
||||
lab_path = wav_path.with_suffix(".lab")
|
||||
if not lab_path.exists():
|
||||
raise FileNotFoundError(f"缺少参考文本: {lab_path}")
|
||||
requests.append(
|
||||
{
|
||||
"request_id": f"bench_{args.concurrency:03d}_{index:03d}",
|
||||
"text": text_lines[index % len(text_lines)],
|
||||
"text_lang": args.text_lang,
|
||||
"ref_audio_path": str(wav_path),
|
||||
"prompt_lang": args.prompt_lang,
|
||||
"prompt_text": lab_path.read_text(encoding="utf-8").strip(),
|
||||
"top_k": int(args.top_k),
|
||||
"top_p": float(args.top_p),
|
||||
"temperature": float(args.temperature),
|
||||
"repetition_penalty": float(args.repetition_penalty),
|
||||
"sample_steps": int(args.sample_steps),
|
||||
"media_type": args.media_type,
|
||||
"timeout_sec": float(args.timeout_sec),
|
||||
}
|
||||
)
|
||||
return requests
|
||||
|
||||
|
||||
class GpuMemoryPoller:
|
||||
def __init__(self, server_pid: Optional[int], interval_sec: float):
|
||||
self.server_pid = server_pid
|
||||
self.interval_sec = interval_sec
|
||||
self._stop = threading.Event()
|
||||
self.samples: List[Dict[str, Any]] = []
|
||||
self.thread: Optional[threading.Thread] = None
|
||||
|
||||
def _query_memory_mb(self) -> Optional[int]:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[
|
||||
"nvidia-smi",
|
||||
"--query-compute-apps=pid,used_gpu_memory",
|
||||
"--format=csv,noheader,nounits",
|
||||
],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
total = 0
|
||||
found = False
|
||||
for line in result.stdout.splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
parts = [item.strip() for item in line.split(",")]
|
||||
if len(parts) != 2:
|
||||
continue
|
||||
try:
|
||||
pid = int(parts[0])
|
||||
used_mb = int(parts[1])
|
||||
except ValueError:
|
||||
continue
|
||||
if self.server_pid is None or pid == self.server_pid:
|
||||
total += used_mb
|
||||
found = True
|
||||
if self.server_pid is None:
|
||||
return total
|
||||
return total if found else 0
|
||||
|
||||
def _run(self) -> None:
|
||||
while not self._stop.is_set():
|
||||
used_mb = self._query_memory_mb()
|
||||
self.samples.append({"ts": time.time(), "used_mb": used_mb})
|
||||
self._stop.wait(self.interval_sec)
|
||||
|
||||
def start(self) -> None:
|
||||
self.thread = threading.Thread(target=self._run, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
self._stop.set()
|
||||
if self.thread is not None:
|
||||
self.thread.join(timeout=2.0)
|
||||
|
||||
def summary(self) -> Dict[str, Any]:
|
||||
valid = [item for item in self.samples if item["used_mb"] is not None]
|
||||
peak = max(valid, key=lambda item: item["used_mb"]) if valid else None
|
||||
first = valid[0] if valid else None
|
||||
last = valid[-1] if valid else None
|
||||
return {
|
||||
"server_pid": self.server_pid,
|
||||
"sample_count": int(len(self.samples)),
|
||||
"start_used_mb": None if first is None else int(first["used_mb"]),
|
||||
"peak_used_mb": None if peak is None else int(peak["used_mb"]),
|
||||
"peak_delta_mb": None if peak is None or first is None else int(peak["used_mb"] - first["used_mb"]),
|
||||
"end_used_mb": None if last is None else int(last["used_mb"]),
|
||||
"peak_ts": None if peak is None else float(peak["ts"]),
|
||||
"samples": self.samples,
|
||||
}
|
||||
|
||||
|
||||
async def submit_one(client: httpx.AsyncClient, url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
started = time.perf_counter()
|
||||
try:
|
||||
response = await client.post(url, json=payload)
|
||||
elapsed_ms = (time.perf_counter() - started) * 1000.0
|
||||
item = {
|
||||
"request_id": payload["request_id"],
|
||||
"status_code": int(response.status_code),
|
||||
"elapsed_ms": float(elapsed_ms),
|
||||
"content_type": response.headers.get("content-type"),
|
||||
"audio_bytes": int(len(response.content)),
|
||||
"headers": {key: value for key, value in response.headers.items() if key.lower().startswith("x-")},
|
||||
}
|
||||
if response.status_code != 200:
|
||||
try:
|
||||
item["error_body"] = response.json()
|
||||
except Exception:
|
||||
item["error_body"] = response.text
|
||||
return item
|
||||
except Exception as exc:
|
||||
return {
|
||||
"request_id": payload["request_id"],
|
||||
"status_code": -1,
|
||||
"elapsed_ms": float((time.perf_counter() - started) * 1000.0),
|
||||
"exception": repr(exc),
|
||||
}
|
||||
|
||||
|
||||
async def run_benchmark(args: argparse.Namespace) -> Dict[str, Any]:
|
||||
payloads = load_requests(args)
|
||||
url = args.base_url.rstrip("/") + args.endpoint
|
||||
poller = GpuMemoryPoller(server_pid=args.server_pid, interval_sec=args.poll_interval_sec)
|
||||
|
||||
limits = httpx.Limits(max_connections=args.concurrency, max_keepalive_connections=args.concurrency)
|
||||
timeout = httpx.Timeout(connect=10.0, read=args.timeout_sec + 10.0, write=10.0, pool=10.0)
|
||||
|
||||
started = time.perf_counter()
|
||||
poller.start()
|
||||
try:
|
||||
async with httpx.AsyncClient(limits=limits, timeout=timeout) as client:
|
||||
results = await asyncio.gather(*[submit_one(client, url, payload) for payload in payloads])
|
||||
finally:
|
||||
poller.stop()
|
||||
wall_ms = (time.perf_counter() - started) * 1000.0
|
||||
|
||||
ok_results = [item for item in results if item["status_code"] == 200]
|
||||
failed_results = [item for item in results if item["status_code"] != 200]
|
||||
request_total_ms = []
|
||||
worker_total_ms = []
|
||||
for item in ok_results:
|
||||
headers = item.get("headers", {})
|
||||
if "x-request-total-ms" in headers:
|
||||
request_total_ms.append(float(headers["x-request-total-ms"]))
|
||||
if "x-worker-total-ms" in headers:
|
||||
worker_total_ms.append(float(headers["x-worker-total-ms"]))
|
||||
|
||||
return {
|
||||
"concurrency": int(args.concurrency),
|
||||
"server_pid": args.server_pid,
|
||||
"request_count": int(len(payloads)),
|
||||
"wall_ms": float(wall_ms),
|
||||
"success_count": int(len(ok_results)),
|
||||
"failure_count": int(len(failed_results)),
|
||||
"request_total_ms_avg": float(sum(request_total_ms) / len(request_total_ms)) if request_total_ms else None,
|
||||
"request_total_ms_max": float(max(request_total_ms)) if request_total_ms else None,
|
||||
"worker_total_ms_avg": float(sum(worker_total_ms) / len(worker_total_ms)) if worker_total_ms else None,
|
||||
"worker_total_ms_max": float(max(worker_total_ms)) if worker_total_ms else None,
|
||||
"gpu_memory": poller.summary(),
|
||||
"results": results,
|
||||
}
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
output_dir = args.output_dir / f"concurrency_{args.concurrency:02d}"
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
summary = asyncio.run(run_benchmark(args))
|
||||
summary_path = output_dir / "summary.json"
|
||||
summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
print(json.dumps({
|
||||
"concurrency": summary["concurrency"],
|
||||
"success_count": summary["success_count"],
|
||||
"failure_count": summary["failure_count"],
|
||||
"wall_ms": summary["wall_ms"],
|
||||
"gpu_peak_used_mb": summary["gpu_memory"]["peak_used_mb"],
|
||||
"request_total_ms_avg": summary["request_total_ms_avg"],
|
||||
"request_total_ms_max": summary["request_total_ms_max"],
|
||||
"summary_path": str(summary_path),
|
||||
}, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
887
tools/t2s_memory_breakdown.py
Normal file
887
tools/t2s_memory_breakdown.py
Normal file
@ -0,0 +1,887 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import contextlib
|
||||
import json
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
ROOT_DIR = Path(__file__).resolve().parents[1]
|
||||
if str(ROOT_DIR) not in sys.path:
|
||||
sys.path.append(str(ROOT_DIR))
|
||||
gpt_sovits_dir = ROOT_DIR / "GPT_SoVITS"
|
||||
if str(gpt_sovits_dir) not in sys.path:
|
||||
sys.path.append(str(gpt_sovits_dir))
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config # noqa: E402
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( # noqa: E402
|
||||
SchedulerRequestSpec,
|
||||
T2SRequestState,
|
||||
T2SRunningRequest,
|
||||
_build_decode_batch_from_running,
|
||||
build_prefill_batch,
|
||||
prepare_request_state,
|
||||
run_decode_step_for_running,
|
||||
run_prefill_step,
|
||||
)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Break down T2S CUDA memory by stage and tensor groups.")
|
||||
parser.add_argument("--config", type=Path, default=ROOT_DIR / "GPT_SoVITS/configs/tts_infer.yaml")
|
||||
parser.add_argument("--request-manifest", type=Path, default=None)
|
||||
parser.add_argument("--scenario", type=str, default="auto4", choices=["auto4", "single"])
|
||||
parser.add_argument("--auto-count", type=int, default=4)
|
||||
parser.add_argument("--auto-wav-dir", type=Path, default=ROOT_DIR / "testwav")
|
||||
parser.add_argument("--auto-text-file", type=Path, default=ROOT_DIR / "test_cn.txt")
|
||||
parser.add_argument("--ref-audio", type=Path, default=ROOT_DIR / "test.wav")
|
||||
parser.add_argument("--prompt-text", type=str, default="是啊,主要是因为有调研需求的学者少了。")
|
||||
parser.add_argument("--prompt-lang", type=str, default="zh")
|
||||
parser.add_argument("--text", type=str, default=None)
|
||||
parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_en.txt")
|
||||
parser.add_argument("--text-lang", type=str, default="zh")
|
||||
parser.add_argument("--top-k", type=int, default=15)
|
||||
parser.add_argument("--top-p", type=float, default=1.0)
|
||||
parser.add_argument("--temperature", type=float, default=1.0)
|
||||
parser.add_argument("--repetition-penalty", type=float, default=1.35)
|
||||
parser.add_argument("--early-stop-num", type=int, default=-1)
|
||||
parser.add_argument("--max-steps", type=int, default=1500)
|
||||
parser.add_argument("--seed", type=int, default=1234)
|
||||
parser.add_argument("--warmup", action="store_true", default=False)
|
||||
parser.add_argument("--worker-rounds", type=int, default=1)
|
||||
parser.add_argument("--worker-grad-mode", type=str, default="default", choices=["default", "inference_mode"])
|
||||
parser.add_argument("--compare-worker-grad-modes", action="store_true", default=False)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=ROOT_DIR / "TEMP/t2s_memory_breakdown/run1",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def set_seed(seed: int, use_cuda: bool) -> None:
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if use_cuda and torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def _sync_device(device: Any) -> None:
|
||||
try:
|
||||
device_str = str(device)
|
||||
if device_str.startswith("cuda") and torch.cuda.is_available():
|
||||
torch.cuda.synchronize(device)
|
||||
elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"):
|
||||
torch.mps.synchronize()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def bytes_to_mb(num_bytes: int) -> float:
|
||||
return float(num_bytes) / (1024.0 * 1024.0)
|
||||
|
||||
|
||||
def tensor_nbytes(tensor: Optional[torch.Tensor]) -> int:
|
||||
if tensor is None:
|
||||
return 0
|
||||
return int(tensor.numel() * tensor.element_size())
|
||||
|
||||
|
||||
def tensor_list_nbytes(items: Sequence[torch.Tensor]) -> int:
|
||||
return int(sum(tensor_nbytes(item) for item in items))
|
||||
|
||||
|
||||
def model_nbytes(module: torch.nn.Module) -> int:
|
||||
total = 0
|
||||
for parameter in module.parameters():
|
||||
total += tensor_nbytes(parameter)
|
||||
for buffer in module.buffers():
|
||||
total += tensor_nbytes(buffer)
|
||||
return int(total)
|
||||
|
||||
|
||||
def build_module_weight_summary(tts: TTS) -> Dict[str, Any]:
|
||||
modules = {
|
||||
"t2s_model": tts.t2s_model,
|
||||
"t2s_core": tts.t2s_model.model if tts.t2s_model is not None else None,
|
||||
"vits_model": tts.vits_model,
|
||||
"bert_model": tts.bert_model,
|
||||
"cnhuhbert_model": tts.cnhuhbert_model,
|
||||
"vocoder": tts.vocoder,
|
||||
"sv_model": tts.sv_model,
|
||||
}
|
||||
by_module = {}
|
||||
total_bytes = 0
|
||||
for name, module in modules.items():
|
||||
module_bytes = model_nbytes(module) if module is not None else 0
|
||||
by_module[name] = {
|
||||
"bytes": int(module_bytes),
|
||||
"mb": bytes_to_mb(module_bytes),
|
||||
}
|
||||
total_bytes += module_bytes
|
||||
return {
|
||||
"by_module": by_module,
|
||||
"total_bytes": int(total_bytes),
|
||||
"total_mb": bytes_to_mb(total_bytes),
|
||||
}
|
||||
|
||||
|
||||
def snapshot_live_cuda_tensors(top_k: int = 40) -> Dict[str, Any]:
|
||||
storages: Dict[int, Dict[str, Any]] = {}
|
||||
tensor_views: List[Dict[str, Any]] = []
|
||||
for obj in gc.get_objects():
|
||||
try:
|
||||
tensor = None
|
||||
if torch.is_tensor(obj):
|
||||
tensor = obj
|
||||
elif hasattr(obj, "data") and torch.is_tensor(obj.data):
|
||||
tensor = obj.data
|
||||
if tensor is None or not tensor.is_cuda:
|
||||
continue
|
||||
storage = tensor.untyped_storage()
|
||||
storage_ptr = int(storage.data_ptr())
|
||||
if storage_ptr not in storages:
|
||||
storages[storage_ptr] = {
|
||||
"storage_ptr": storage_ptr,
|
||||
"storage_bytes": int(storage.nbytes()),
|
||||
"dtype": str(tensor.dtype),
|
||||
"shape": list(tensor.shape),
|
||||
"device": str(tensor.device),
|
||||
}
|
||||
tensor_views.append(
|
||||
{
|
||||
"shape": list(tensor.shape),
|
||||
"dtype": str(tensor.dtype),
|
||||
"bytes": tensor_nbytes(tensor),
|
||||
"device": str(tensor.device),
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
storage_list = sorted(storages.values(), key=lambda item: item["storage_bytes"], reverse=True)
|
||||
tensor_views.sort(key=lambda item: item["bytes"], reverse=True)
|
||||
return {
|
||||
"unique_storage_count": int(len(storage_list)),
|
||||
"unique_storage_total_bytes": int(sum(item["storage_bytes"] for item in storage_list)),
|
||||
"unique_storage_total_mb": bytes_to_mb(sum(item["storage_bytes"] for item in storage_list)),
|
||||
"top_storages": storage_list[:top_k],
|
||||
"top_tensor_views": tensor_views[:top_k],
|
||||
}
|
||||
|
||||
|
||||
def build_single_spec(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
|
||||
text = args.text if args.text is not None else args.text_file.read_text(encoding="utf-8").strip()
|
||||
return [
|
||||
SchedulerRequestSpec(
|
||||
request_id="req_000",
|
||||
ref_audio_path=args.ref_audio,
|
||||
prompt_text=args.prompt_text,
|
||||
prompt_lang=args.prompt_lang,
|
||||
text=text,
|
||||
text_lang=args.text_lang,
|
||||
top_k=args.top_k,
|
||||
top_p=args.top_p,
|
||||
temperature=args.temperature,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
early_stop_num=args.early_stop_num,
|
||||
ready_step=0,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def build_auto_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
|
||||
wav_paths = sorted(args.auto_wav_dir.glob("*.wav"))[: args.auto_count]
|
||||
if len(wav_paths) < args.auto_count:
|
||||
raise ValueError(f"auto wav count不足,目录 {args.auto_wav_dir} 只有 {len(wav_paths)} 条 wav")
|
||||
text_lines = [line.strip() for line in args.auto_text_file.read_text(encoding="utf-8").splitlines() if line.strip()]
|
||||
if len(text_lines) < args.auto_count:
|
||||
raise ValueError(f"auto text lines不足,文件 {args.auto_text_file} 只有 {len(text_lines)} 行有效文本")
|
||||
specs: List[SchedulerRequestSpec] = []
|
||||
for index, wav_path in enumerate(wav_paths):
|
||||
lab_path = wav_path.with_suffix(".lab")
|
||||
if not lab_path.exists():
|
||||
raise FileNotFoundError(f"找不到参考文本 {lab_path}")
|
||||
specs.append(
|
||||
SchedulerRequestSpec(
|
||||
request_id=f"req_{index:03d}",
|
||||
ref_audio_path=wav_path,
|
||||
prompt_text=lab_path.read_text(encoding="utf-8").strip(),
|
||||
prompt_lang="zh",
|
||||
text=text_lines[index],
|
||||
text_lang=args.text_lang,
|
||||
top_k=args.top_k,
|
||||
top_p=args.top_p,
|
||||
temperature=args.temperature,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
early_stop_num=args.early_stop_num,
|
||||
ready_step=0,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
|
||||
|
||||
def load_request_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
|
||||
if args.request_manifest is not None:
|
||||
payload = json.loads(args.request_manifest.read_text(encoding="utf-8"))
|
||||
raw_requests = payload["requests"] if isinstance(payload, dict) else payload
|
||||
specs: List[SchedulerRequestSpec] = []
|
||||
for index, item in enumerate(raw_requests):
|
||||
text = item.get("text")
|
||||
text_file = item.get("text_file")
|
||||
if text is None and text_file is None:
|
||||
raise ValueError(f"request[{index}] must provide text or text_file")
|
||||
if text is None:
|
||||
text = Path(text_file).read_text(encoding="utf-8").strip()
|
||||
specs.append(
|
||||
SchedulerRequestSpec(
|
||||
request_id=item.get("request_id", f"req_{index:03d}"),
|
||||
ref_audio_path=Path(item["ref_audio_path"]),
|
||||
prompt_text=item["prompt_text"],
|
||||
prompt_lang=item.get("prompt_lang", "zh"),
|
||||
text=text,
|
||||
text_lang=item.get("text_lang", "zh"),
|
||||
top_k=int(item.get("top_k", args.top_k)),
|
||||
top_p=float(item.get("top_p", args.top_p)),
|
||||
temperature=float(item.get("temperature", args.temperature)),
|
||||
repetition_penalty=float(item.get("repetition_penalty", args.repetition_penalty)),
|
||||
early_stop_num=int(item.get("early_stop_num", args.early_stop_num)),
|
||||
ready_step=int(item.get("ready_step", 0)),
|
||||
)
|
||||
)
|
||||
return specs
|
||||
if args.scenario == "single":
|
||||
return build_single_spec(args)
|
||||
return build_auto_specs(args)
|
||||
|
||||
|
||||
def load_pipeline(config_path: Path) -> TTS:
|
||||
tts_config = TTS_Config(str(config_path))
|
||||
print(tts_config)
|
||||
return TTS(tts_config)
|
||||
|
||||
|
||||
def cuda_mem_snapshot(device: Any) -> Dict[str, float]:
|
||||
if not (str(device).startswith("cuda") and torch.cuda.is_available()):
|
||||
return {
|
||||
"allocated_mb": 0.0,
|
||||
"reserved_mb": 0.0,
|
||||
"max_allocated_mb": 0.0,
|
||||
"max_reserved_mb": 0.0,
|
||||
}
|
||||
_sync_device(device)
|
||||
return {
|
||||
"allocated_mb": bytes_to_mb(torch.cuda.memory_allocated(device)),
|
||||
"reserved_mb": bytes_to_mb(torch.cuda.memory_reserved(device)),
|
||||
"max_allocated_mb": bytes_to_mb(torch.cuda.max_memory_allocated(device)),
|
||||
"max_reserved_mb": bytes_to_mb(torch.cuda.max_memory_reserved(device)),
|
||||
}
|
||||
|
||||
|
||||
def stage_run(device: Any, fn) -> Tuple[Any, Dict[str, float]]:
|
||||
if str(device).startswith("cuda") and torch.cuda.is_available():
|
||||
gc.collect()
|
||||
_sync_device(device)
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
before = cuda_mem_snapshot(device)
|
||||
started = time.perf_counter()
|
||||
result = fn()
|
||||
_sync_device(device)
|
||||
elapsed_ms = (time.perf_counter() - started) * 1000.0
|
||||
after = cuda_mem_snapshot(device)
|
||||
after["elapsed_ms"] = float(elapsed_ms)
|
||||
after["delta_allocated_mb"] = float(after["allocated_mb"] - before["allocated_mb"])
|
||||
after["delta_reserved_mb"] = float(after["reserved_mb"] - before["reserved_mb"])
|
||||
after["stage_peak_over_before_mb"] = float(max(after["max_allocated_mb"] - before["allocated_mb"], 0.0))
|
||||
return result, after
|
||||
|
||||
|
||||
class GlobalPeakRecorder:
|
||||
def __init__(self, device: Any):
|
||||
self.device = device
|
||||
self.checkpoints: List[Dict[str, Any]] = []
|
||||
if str(device).startswith("cuda") and torch.cuda.is_available():
|
||||
gc.collect()
|
||||
_sync_device(device)
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
|
||||
def record(self, label: str, **extra: Any) -> None:
|
||||
snapshot = cuda_mem_snapshot(self.device)
|
||||
snapshot["label"] = label
|
||||
snapshot.update(extra)
|
||||
self.checkpoints.append(snapshot)
|
||||
|
||||
def summary(self) -> Dict[str, Any]:
|
||||
peak = max(self.checkpoints, key=lambda item: item["max_allocated_mb"]) if self.checkpoints else None
|
||||
return {
|
||||
"peak_allocated_mb": 0.0 if peak is None else float(peak["max_allocated_mb"]),
|
||||
"peak_reserved_mb": 0.0 if peak is None else float(peak["max_reserved_mb"]),
|
||||
"peak_label": None if peak is None else peak["label"],
|
||||
"checkpoints": self.checkpoints,
|
||||
}
|
||||
|
||||
|
||||
def summarise_state_tensors(states: Sequence[T2SRequestState]) -> Dict[str, Any]:
|
||||
per_request = []
|
||||
total = {
|
||||
"phones_bytes": 0,
|
||||
"prompt_phones_bytes": 0,
|
||||
"all_phones_bytes": 0,
|
||||
"all_bert_features_bytes": 0,
|
||||
"prompt_semantic_bytes": 0,
|
||||
"refer_spec_bytes": 0,
|
||||
"raw_audio_bytes": 0,
|
||||
"audio_16k_bytes": 0,
|
||||
}
|
||||
for state in states:
|
||||
spec_audio, audio_16k = state.refer_spec
|
||||
item = {
|
||||
"request_id": state.request_id,
|
||||
"prompt_semantic_len": int(state.prompt_semantic.shape[0]),
|
||||
"phones_len": int(state.phones.shape[0]),
|
||||
"all_phones_len": int(state.all_phones.shape[0]),
|
||||
"bert_frames": int(state.all_bert_features.shape[-1]),
|
||||
"phones_bytes": tensor_nbytes(state.phones),
|
||||
"prompt_phones_bytes": tensor_nbytes(state.prompt_phones),
|
||||
"all_phones_bytes": tensor_nbytes(state.all_phones),
|
||||
"all_bert_features_bytes": tensor_nbytes(state.all_bert_features),
|
||||
"prompt_semantic_bytes": tensor_nbytes(state.prompt_semantic),
|
||||
"refer_spec_bytes": tensor_nbytes(spec_audio),
|
||||
"audio_16k_bytes": tensor_nbytes(audio_16k),
|
||||
"raw_audio_bytes": tensor_nbytes(state.raw_audio),
|
||||
}
|
||||
for key in total:
|
||||
total[key] += int(item[key])
|
||||
per_request.append(item)
|
||||
total["total_bytes"] = int(sum(total.values()))
|
||||
total["total_mb"] = bytes_to_mb(total["total_bytes"])
|
||||
return {"per_request": per_request, "total": total}
|
||||
|
||||
|
||||
def summarise_prefill_batch(active_batch: Any) -> Dict[str, Any]:
|
||||
y_sequence_bytes = int(sum(tensor_nbytes(item) for item in active_batch.y_sequences))
|
||||
fields = {
|
||||
"x_bytes": tensor_nbytes(active_batch.x),
|
||||
"x_lens_bytes": tensor_nbytes(active_batch.x_lens),
|
||||
"prefix_lens_bytes": tensor_nbytes(active_batch.prefix_lens),
|
||||
"xy_pos_bytes": tensor_nbytes(active_batch.xy_pos),
|
||||
"key_padding_mask_bytes": tensor_nbytes(active_batch.key_padding_mask),
|
||||
"prefill_attn_mask_bytes": tensor_nbytes(active_batch.prefill_attn_mask),
|
||||
"y_sequence_bytes": y_sequence_bytes,
|
||||
}
|
||||
fields["total_bytes"] = int(sum(fields.values()))
|
||||
fields["total_mb"] = bytes_to_mb(fields["total_bytes"])
|
||||
fields["batch_size"] = int(len(active_batch.states))
|
||||
fields["max_x_len"] = int(active_batch.x.shape[1])
|
||||
fields["src_len"] = int(active_batch.xy_pos.shape[1])
|
||||
fields["prefill_attn_mask_shape"] = list(active_batch.prefill_attn_mask.shape)
|
||||
return fields
|
||||
|
||||
|
||||
def summarise_running_requests(running_requests: Sequence[T2SRunningRequest]) -> Dict[str, Any]:
|
||||
per_request = []
|
||||
total_private_k_bytes = 0
|
||||
total_private_v_bytes = 0
|
||||
total_decode_mask_bytes = 0
|
||||
total_y_sequence_bytes = 0
|
||||
for item in running_requests:
|
||||
k_bytes = tensor_list_nbytes(item.k_cache)
|
||||
v_bytes = tensor_list_nbytes(item.v_cache)
|
||||
mask_bytes = tensor_nbytes(item.decode_attn_mask)
|
||||
y_bytes = tensor_nbytes(item.y_sequence)
|
||||
total_private_k_bytes += k_bytes
|
||||
total_private_v_bytes += v_bytes
|
||||
total_decode_mask_bytes += mask_bytes
|
||||
total_y_sequence_bytes += y_bytes
|
||||
per_request.append(
|
||||
{
|
||||
"request_id": item.state.request_id,
|
||||
"step_idx": int(item.step_idx),
|
||||
"prefix_len": int(item.prefix_len),
|
||||
"history_len": int(item.y_sequence.shape[0]),
|
||||
"kv_len": int(item.k_cache[0].shape[1]),
|
||||
"k_cache_bytes": k_bytes,
|
||||
"v_cache_bytes": v_bytes,
|
||||
"decode_mask_bytes": mask_bytes,
|
||||
"y_sequence_bytes": y_bytes,
|
||||
}
|
||||
)
|
||||
total_bytes = total_private_k_bytes + total_private_v_bytes + total_decode_mask_bytes + total_y_sequence_bytes
|
||||
return {
|
||||
"per_request": per_request,
|
||||
"totals": {
|
||||
"private_k_cache_bytes": int(total_private_k_bytes),
|
||||
"private_v_cache_bytes": int(total_private_v_bytes),
|
||||
"private_kv_cache_bytes": int(total_private_k_bytes + total_private_v_bytes),
|
||||
"decode_mask_bytes": int(total_decode_mask_bytes),
|
||||
"y_sequence_bytes": int(total_y_sequence_bytes),
|
||||
"total_bytes": int(total_bytes),
|
||||
"total_mb": bytes_to_mb(total_bytes),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def summarise_decode_batch(
|
||||
xy_pos: torch.Tensor,
|
||||
batched_k_cache: Sequence[torch.Tensor],
|
||||
batched_v_cache: Sequence[torch.Tensor],
|
||||
batched_decode_attn_mask: Optional[torch.Tensor],
|
||||
running_requests: Sequence[T2SRunningRequest],
|
||||
) -> Dict[str, Any]:
|
||||
private_k_bytes = int(sum(tensor_list_nbytes(item.k_cache) for item in running_requests))
|
||||
private_v_bytes = int(sum(tensor_list_nbytes(item.v_cache) for item in running_requests))
|
||||
batched_k_bytes = tensor_list_nbytes(batched_k_cache)
|
||||
batched_v_bytes = tensor_list_nbytes(batched_v_cache)
|
||||
batched_mask_bytes = tensor_nbytes(batched_decode_attn_mask)
|
||||
xy_pos_bytes = tensor_nbytes(xy_pos)
|
||||
total_bytes = batched_k_bytes + batched_v_bytes + batched_mask_bytes + xy_pos_bytes
|
||||
return {
|
||||
"batch_size": int(len(running_requests)),
|
||||
"xy_pos_bytes": int(xy_pos_bytes),
|
||||
"batched_k_cache_bytes": int(batched_k_bytes),
|
||||
"batched_v_cache_bytes": int(batched_v_bytes),
|
||||
"batched_kv_cache_bytes": int(batched_k_bytes + batched_v_bytes),
|
||||
"batched_decode_mask_bytes": int(batched_mask_bytes),
|
||||
"private_kv_cache_bytes_reference": int(private_k_bytes + private_v_bytes),
|
||||
"kv_padding_overhead_bytes": int((batched_k_bytes + batched_v_bytes) - (private_k_bytes + private_v_bytes)),
|
||||
"total_bytes": int(total_bytes),
|
||||
"total_mb": bytes_to_mb(total_bytes),
|
||||
"xy_pos_shape": list(xy_pos.shape),
|
||||
"batched_decode_mask_shape": None if batched_decode_attn_mask is None else list(batched_decode_attn_mask.shape),
|
||||
"layer_k_cache_shape": list(batched_k_cache[0].shape),
|
||||
}
|
||||
|
||||
|
||||
def summarise_decode_outputs(
|
||||
xy_dec: torch.Tensor,
|
||||
next_k_cache: Sequence[torch.Tensor],
|
||||
next_v_cache: Sequence[torch.Tensor],
|
||||
) -> Dict[str, Any]:
|
||||
xy_dec_bytes = tensor_nbytes(xy_dec)
|
||||
next_k_bytes = tensor_list_nbytes(next_k_cache)
|
||||
next_v_bytes = tensor_list_nbytes(next_v_cache)
|
||||
total_bytes = xy_dec_bytes + next_k_bytes + next_v_bytes
|
||||
return {
|
||||
"xy_dec_bytes": int(xy_dec_bytes),
|
||||
"next_k_cache_bytes": int(next_k_bytes),
|
||||
"next_v_cache_bytes": int(next_v_bytes),
|
||||
"next_kv_cache_bytes": int(next_k_bytes + next_v_bytes),
|
||||
"total_bytes": int(total_bytes),
|
||||
"total_mb": bytes_to_mb(total_bytes),
|
||||
"xy_dec_shape": list(xy_dec.shape),
|
||||
"layer_next_k_cache_shape": list(next_k_cache[0].shape),
|
||||
}
|
||||
|
||||
|
||||
def top_rankings(summary: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
ranking = [
|
||||
("request_state_total", summary["prepare_stage"]["request_state"]["total"]["total_bytes"]),
|
||||
("prefill_batch_total", summary["prefill_batch"]["tensor_bytes"]["total_bytes"]),
|
||||
("running_private_kv", summary["prefill_step"]["running_requests"]["totals"]["private_kv_cache_bytes"]),
|
||||
("decode_batched_kv", summary["decode_batch"]["tensor_bytes"]["batched_kv_cache_bytes"]),
|
||||
("decode_kv_padding_overhead", summary["decode_batch"]["tensor_bytes"]["kv_padding_overhead_bytes"]),
|
||||
("decode_outputs_next_kv", summary["decode_outputs"]["tensor_bytes"]["next_kv_cache_bytes"]),
|
||||
("prefill_attn_mask", summary["prefill_batch"]["tensor_bytes"]["prefill_attn_mask_bytes"]),
|
||||
]
|
||||
ranking.sort(key=lambda item: item[1], reverse=True)
|
||||
return [{"name": name, "bytes": int(value), "mb": bytes_to_mb(int(value))} for name, value in ranking]
|
||||
|
||||
|
||||
def synthesize_finished_item(tts: TTS, state: T2SRequestState, semantic_tokens: torch.Tensor) -> Tuple[int, np.ndarray]:
|
||||
semantic_tokens = semantic_tokens.unsqueeze(0).unsqueeze(0).to(tts.configs.device)
|
||||
phones = state.phones.unsqueeze(0).to(tts.configs.device)
|
||||
audio_fragment = tts.synthesize_audio_request_local(
|
||||
semantic_tokens=semantic_tokens,
|
||||
phones=phones,
|
||||
prompt_semantic=state.prompt_semantic,
|
||||
prompt_phones=state.prompt_phones,
|
||||
refer_spec=state.refer_spec,
|
||||
raw_audio=state.raw_audio,
|
||||
raw_sr=state.raw_sr,
|
||||
speed=1.0,
|
||||
sample_steps=32,
|
||||
)
|
||||
output_sr = tts.configs.sampling_rate if not tts.configs.use_vocoder else tts.vocoder_configs["sr"]
|
||||
return tts.audio_postprocess(
|
||||
audio=[[audio_fragment]],
|
||||
sr=int(output_sr),
|
||||
batch_index_list=None,
|
||||
speed_factor=1.0,
|
||||
split_bucket=False,
|
||||
fragment_interval=0.0,
|
||||
super_sampling=False,
|
||||
)
|
||||
|
||||
|
||||
def simulate_worker_end_to_end(
|
||||
tts: TTS,
|
||||
specs: Sequence[SchedulerRequestSpec],
|
||||
max_steps: int,
|
||||
rounds: int,
|
||||
grad_mode: str = "default",
|
||||
) -> Dict[str, Any]:
|
||||
device = tts.configs.device
|
||||
recorder = GlobalPeakRecorder(device)
|
||||
recorder.record("after_model_load")
|
||||
|
||||
state_map: Dict[str, T2SRequestState] = {}
|
||||
per_round: List[Dict[str, Any]] = []
|
||||
|
||||
for round_index in range(rounds):
|
||||
grad_context = torch.inference_mode if grad_mode == "inference_mode" else contextlib.nullcontext
|
||||
with grad_context():
|
||||
states = [prepare_request_state(tts, spec) for spec in specs]
|
||||
state_map = {state.request_id: state for state in states}
|
||||
recorder.record(
|
||||
"after_prepare_states",
|
||||
round_index=int(round_index),
|
||||
request_count=int(len(states)),
|
||||
grad_mode=grad_mode,
|
||||
)
|
||||
|
||||
pending = list(states)
|
||||
running_requests: List[T2SRunningRequest] = []
|
||||
round_events: List[Dict[str, Any]] = []
|
||||
current_tick = 0
|
||||
|
||||
while pending or running_requests:
|
||||
admitted = pending
|
||||
pending = []
|
||||
|
||||
if admitted:
|
||||
recorder.record(
|
||||
"before_prefill",
|
||||
round_index=int(round_index),
|
||||
tick=int(current_tick),
|
||||
admitted_count=int(len(admitted)),
|
||||
running_count=int(len(running_requests)),
|
||||
grad_mode=grad_mode,
|
||||
)
|
||||
with grad_context():
|
||||
admitted_running, admitted_finished = run_prefill_step(tts.t2s_model.model, admitted, max_steps=max_steps)
|
||||
recorder.record(
|
||||
"after_prefill",
|
||||
round_index=int(round_index),
|
||||
tick=int(current_tick),
|
||||
admitted_running_count=int(len(admitted_running)),
|
||||
admitted_finished_count=int(len(admitted_finished)),
|
||||
running_count=int(len(running_requests)),
|
||||
grad_mode=grad_mode,
|
||||
)
|
||||
round_events.append(
|
||||
{
|
||||
"tick": int(current_tick),
|
||||
"event": "prefill",
|
||||
"admitted_count": int(len(admitted)),
|
||||
"admitted_running_count": int(len(admitted_running)),
|
||||
"admitted_finished_count": int(len(admitted_finished)),
|
||||
}
|
||||
)
|
||||
for item in admitted_finished:
|
||||
recorder.record(
|
||||
"before_synth_prefill_finished",
|
||||
round_index=int(round_index),
|
||||
tick=int(current_tick),
|
||||
running_count=int(len(running_requests)),
|
||||
finished_request_id=item.request_id,
|
||||
semantic_len=int(item.semantic_tokens.shape[0]),
|
||||
grad_mode=grad_mode,
|
||||
)
|
||||
with grad_context():
|
||||
sample_rate, audio_data = synthesize_finished_item(tts, state_map[item.request_id], item.semantic_tokens)
|
||||
recorder.record(
|
||||
"after_synth_prefill_finished",
|
||||
round_index=int(round_index),
|
||||
tick=int(current_tick),
|
||||
running_count=int(len(running_requests)),
|
||||
finished_request_id=item.request_id,
|
||||
sample_rate=int(sample_rate),
|
||||
audio_samples=int(audio_data.shape[0]),
|
||||
grad_mode=grad_mode,
|
||||
)
|
||||
running_requests.extend(admitted_running)
|
||||
recorder.record(
|
||||
"after_extend_running",
|
||||
round_index=int(round_index),
|
||||
tick=int(current_tick),
|
||||
running_count=int(len(running_requests)),
|
||||
grad_mode=grad_mode,
|
||||
)
|
||||
|
||||
if running_requests:
|
||||
recorder.record(
|
||||
"before_decode",
|
||||
round_index=int(round_index),
|
||||
tick=int(current_tick),
|
||||
running_count=int(len(running_requests)),
|
||||
grad_mode=grad_mode,
|
||||
)
|
||||
with grad_context():
|
||||
running_requests, step_finished = run_decode_step_for_running(
|
||||
tts.t2s_model.model,
|
||||
running_requests,
|
||||
max_steps=max_steps,
|
||||
)
|
||||
recorder.record(
|
||||
"after_decode",
|
||||
round_index=int(round_index),
|
||||
tick=int(current_tick),
|
||||
running_count=int(len(running_requests)),
|
||||
finished_count=int(len(step_finished)),
|
||||
grad_mode=grad_mode,
|
||||
)
|
||||
round_events.append(
|
||||
{
|
||||
"tick": int(current_tick),
|
||||
"event": "decode",
|
||||
"running_count_after_decode": int(len(running_requests)),
|
||||
"finished_count": int(len(step_finished)),
|
||||
}
|
||||
)
|
||||
for item in step_finished:
|
||||
recorder.record(
|
||||
"before_synth_decode_finished",
|
||||
round_index=int(round_index),
|
||||
tick=int(current_tick),
|
||||
running_count=int(len(running_requests)),
|
||||
finished_request_id=item.request_id,
|
||||
semantic_len=int(item.semantic_tokens.shape[0]),
|
||||
grad_mode=grad_mode,
|
||||
)
|
||||
with grad_context():
|
||||
sample_rate, audio_data = synthesize_finished_item(tts, state_map[item.request_id], item.semantic_tokens)
|
||||
recorder.record(
|
||||
"after_synth_decode_finished",
|
||||
round_index=int(round_index),
|
||||
tick=int(current_tick),
|
||||
running_count=int(len(running_requests)),
|
||||
finished_request_id=item.request_id,
|
||||
sample_rate=int(sample_rate),
|
||||
audio_samples=int(audio_data.shape[0]),
|
||||
grad_mode=grad_mode,
|
||||
)
|
||||
current_tick += 1
|
||||
|
||||
recorder.record(
|
||||
"after_round_complete",
|
||||
round_index=int(round_index),
|
||||
running_count=0,
|
||||
grad_mode=grad_mode,
|
||||
)
|
||||
per_round.append(
|
||||
{
|
||||
"round_index": int(round_index),
|
||||
"events": round_events,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"grad_mode": grad_mode,
|
||||
"rounds": per_round,
|
||||
"timeline": recorder.summary(),
|
||||
}
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
args.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
tts = load_pipeline(args.config)
|
||||
model = tts.t2s_model.model
|
||||
device = tts.configs.device
|
||||
use_cuda = str(device).startswith("cuda") and torch.cuda.is_available()
|
||||
set_seed(args.seed, use_cuda)
|
||||
|
||||
specs = load_request_specs(args)
|
||||
if args.early_stop_num == -1:
|
||||
for spec in specs:
|
||||
spec.early_stop_num = int(tts.configs.hz * tts.configs.max_sec)
|
||||
|
||||
if args.warmup and specs:
|
||||
warmup_spec = specs[:1]
|
||||
_ = [prepare_request_state(tts, spec) for spec in warmup_spec]
|
||||
gc.collect()
|
||||
if use_cuda:
|
||||
torch.cuda.empty_cache()
|
||||
_sync_device(device)
|
||||
|
||||
states, prepare_mem = stage_run(device, lambda: [prepare_request_state(tts, spec) for spec in specs])
|
||||
request_state_summary = summarise_state_tensors(states)
|
||||
|
||||
active_batch, prefill_batch_mem = stage_run(device, lambda: build_prefill_batch(model, states))
|
||||
prefill_batch_tensor_summary = summarise_prefill_batch(active_batch)
|
||||
|
||||
prefill_result, prefill_step_mem = stage_run(device, lambda: run_prefill_step(model, states, max_steps=args.max_steps))
|
||||
running_requests, finished_items = prefill_result
|
||||
running_requests_summary = summarise_running_requests(running_requests)
|
||||
finished_after_prefill_summary = [
|
||||
{
|
||||
"request_id": item.request_id,
|
||||
"finish_idx": int(item.finish_idx),
|
||||
"finish_reason": item.finish_reason,
|
||||
"semantic_len": int(item.semantic_tokens.shape[0]),
|
||||
}
|
||||
for item in finished_items
|
||||
]
|
||||
|
||||
if not running_requests:
|
||||
raise RuntimeError(f"prefill 后没有 running requests,全部在首步结束: {[item.request_id for item in finished_items]}")
|
||||
|
||||
decode_batch_result, decode_batch_mem = stage_run(
|
||||
device,
|
||||
lambda: _build_decode_batch_from_running(model, running_requests),
|
||||
)
|
||||
xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask = decode_batch_result
|
||||
decode_batch_tensor_summary = summarise_decode_batch(
|
||||
xy_pos,
|
||||
batched_k_cache,
|
||||
batched_v_cache,
|
||||
batched_decode_attn_mask,
|
||||
running_requests,
|
||||
)
|
||||
|
||||
decode_out_result, decode_step_mem = stage_run(
|
||||
device,
|
||||
lambda: model.t2s_transformer.decode_next_token(
|
||||
xy_pos,
|
||||
batched_k_cache,
|
||||
batched_v_cache,
|
||||
batched_decode_attn_mask,
|
||||
),
|
||||
)
|
||||
xy_dec, next_k_cache, next_v_cache = decode_out_result
|
||||
decode_output_tensor_summary = summarise_decode_outputs(xy_dec, next_k_cache, next_v_cache)
|
||||
del active_batch
|
||||
del running_requests
|
||||
del finished_items
|
||||
del xy_pos
|
||||
del batched_k_cache
|
||||
del batched_v_cache
|
||||
del batched_decode_attn_mask
|
||||
del xy_dec
|
||||
del next_k_cache
|
||||
del next_v_cache
|
||||
gc.collect()
|
||||
if use_cuda:
|
||||
_sync_device(device)
|
||||
torch.cuda.empty_cache()
|
||||
end_to_end_worker = simulate_worker_end_to_end(
|
||||
tts=tts,
|
||||
specs=specs,
|
||||
max_steps=args.max_steps,
|
||||
rounds=args.worker_rounds,
|
||||
grad_mode=args.worker_grad_mode,
|
||||
)
|
||||
live_cuda_tensors_after_worker = snapshot_live_cuda_tensors()
|
||||
worker_inference_mode = None
|
||||
if args.compare_worker_grad_modes:
|
||||
gc.collect()
|
||||
if use_cuda:
|
||||
_sync_device(device)
|
||||
torch.cuda.empty_cache()
|
||||
worker_inference_mode = simulate_worker_end_to_end(
|
||||
tts=tts,
|
||||
specs=specs,
|
||||
max_steps=args.max_steps,
|
||||
rounds=args.worker_rounds,
|
||||
grad_mode="inference_mode",
|
||||
)
|
||||
|
||||
summary = {
|
||||
"meta": {
|
||||
"scenario": args.scenario if args.request_manifest is None else "manifest",
|
||||
"seed": int(args.seed),
|
||||
"device": str(device),
|
||||
"dtype": str(next(model.parameters()).dtype),
|
||||
"request_count": int(len(specs)),
|
||||
"num_layers": int(model.num_layers),
|
||||
"num_heads": int(model.num_head),
|
||||
"model_dim": int(model.model_dim),
|
||||
"model_weights_mb": bytes_to_mb(model_nbytes(model)),
|
||||
},
|
||||
"loaded_module_weights": build_module_weight_summary(tts),
|
||||
"requests": [
|
||||
{
|
||||
"request_id": spec.request_id,
|
||||
"ref_audio_path": str(spec.ref_audio_path),
|
||||
"prompt_text": spec.prompt_text,
|
||||
"text": spec.text,
|
||||
}
|
||||
for spec in specs
|
||||
],
|
||||
"prepare_stage": {
|
||||
"memory": prepare_mem,
|
||||
"request_state": request_state_summary,
|
||||
},
|
||||
"prefill_batch": {
|
||||
"memory": prefill_batch_mem,
|
||||
"tensor_bytes": prefill_batch_tensor_summary,
|
||||
},
|
||||
"prefill_step": {
|
||||
"memory": prefill_step_mem,
|
||||
"running_requests": running_requests_summary,
|
||||
"finished_after_prefill": finished_after_prefill_summary,
|
||||
},
|
||||
"decode_batch": {
|
||||
"memory": decode_batch_mem,
|
||||
"tensor_bytes": decode_batch_tensor_summary,
|
||||
},
|
||||
"decode_outputs": {
|
||||
"memory": decode_step_mem,
|
||||
"tensor_bytes": decode_output_tensor_summary,
|
||||
},
|
||||
"end_to_end_worker": end_to_end_worker,
|
||||
"live_cuda_tensors_after_worker": live_cuda_tensors_after_worker,
|
||||
"end_to_end_worker_inference_mode": worker_inference_mode,
|
||||
}
|
||||
summary["top_rankings"] = top_rankings(summary)
|
||||
|
||||
summary_path = args.output_dir / "t2s_memory_breakdown_summary.json"
|
||||
summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
|
||||
print(json.dumps(summary["meta"], ensure_ascii=False, indent=2))
|
||||
print("[top_rankings]")
|
||||
for item in summary["top_rankings"]:
|
||||
print(f"- {item['name']}: {item['mb']:.3f} MB")
|
||||
print("[worker_peak]")
|
||||
print(
|
||||
json.dumps(
|
||||
{
|
||||
"peak_label": summary["end_to_end_worker"]["timeline"]["peak_label"],
|
||||
"peak_allocated_mb": summary["end_to_end_worker"]["timeline"]["peak_allocated_mb"],
|
||||
"peak_reserved_mb": summary["end_to_end_worker"]["timeline"]["peak_reserved_mb"],
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
)
|
||||
if worker_inference_mode is not None:
|
||||
print("[worker_peak_inference_mode]")
|
||||
print(
|
||||
json.dumps(
|
||||
{
|
||||
"peak_label": worker_inference_mode["timeline"]["peak_label"],
|
||||
"peak_allocated_mb": worker_inference_mode["timeline"]["peak_allocated_mb"],
|
||||
"peak_reserved_mb": worker_inference_mode["timeline"]["peak_reserved_mb"],
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
)
|
||||
print(f"[summary] {summary_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
180
tools/t2s_scheduler_prototype.py
Normal file
180
tools/t2s_scheduler_prototype.py
Normal file
@ -0,0 +1,180 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import random
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
ROOT_DIR = Path(__file__).resolve().parents[1]
|
||||
if str(ROOT_DIR) not in sys.path:
|
||||
sys.path.append(str(ROOT_DIR))
|
||||
gpt_sovits_dir = ROOT_DIR / "GPT_SoVITS"
|
||||
if str(gpt_sovits_dir) not in sys.path:
|
||||
sys.path.append(str(gpt_sovits_dir))
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( # noqa: E402
|
||||
SchedulerRequestSpec,
|
||||
T2SFinishedItem,
|
||||
T2SRequestState,
|
||||
prepare_request_state,
|
||||
run_scheduler_continuous,
|
||||
)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="T2S request-local scheduler prototype.")
|
||||
parser.add_argument("--config", type=Path, default=ROOT_DIR / "GPT_SoVITS/configs/tts_infer.yaml")
|
||||
parser.add_argument("--request-manifest", type=Path, default=None)
|
||||
parser.add_argument("--ref-audio", type=Path, default=ROOT_DIR / "test.wav")
|
||||
parser.add_argument("--prompt-text", type=str, default="是啊,主要是因为有调研需求的学者少了。")
|
||||
parser.add_argument("--prompt-lang", type=str, default="zh")
|
||||
parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_en.txt")
|
||||
parser.add_argument("--text", type=str, default=None)
|
||||
parser.add_argument("--text-lang", type=str, default="en")
|
||||
parser.add_argument("--top-k", type=int, default=15)
|
||||
parser.add_argument("--top-p", type=float, default=1.0)
|
||||
parser.add_argument("--temperature", type=float, default=1.0)
|
||||
parser.add_argument("--repetition-penalty", type=float, default=1.35)
|
||||
parser.add_argument("--early-stop-num", type=int, default=-1)
|
||||
parser.add_argument("--max-steps", type=int, default=1500)
|
||||
parser.add_argument("--seed", type=int, default=1234)
|
||||
parser.add_argument("--output-dir", type=Path, default=ROOT_DIR / "TEMP/t2s_scheduler/output_run")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def set_seed(seed: int, use_cuda: bool) -> None:
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if use_cuda and torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def load_pipeline(config_path: Path):
|
||||
try:
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ModuleNotFoundError(
|
||||
"缺少运行依赖,请先在 GPT-SoVITS 推理环境中安装 requirements 后再运行该脚本。"
|
||||
) from exc
|
||||
tts_config = TTS_Config(str(config_path))
|
||||
print(tts_config)
|
||||
return TTS(tts_config)
|
||||
|
||||
|
||||
def load_request_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
|
||||
if args.request_manifest is not None:
|
||||
payload = json.loads(args.request_manifest.read_text(encoding="utf-8"))
|
||||
raw_requests = payload["requests"] if isinstance(payload, dict) else payload
|
||||
specs: List[SchedulerRequestSpec] = []
|
||||
for index, item in enumerate(raw_requests):
|
||||
text = item.get("text")
|
||||
text_file = item.get("text_file")
|
||||
if text is None and text_file is None:
|
||||
raise ValueError(f"request[{index}] must provide text or text_file")
|
||||
if text is None:
|
||||
text = Path(text_file).read_text(encoding="utf-8")
|
||||
specs.append(
|
||||
SchedulerRequestSpec(
|
||||
request_id=item.get("request_id", f"req_{index:03d}"),
|
||||
ref_audio_path=Path(item["ref_audio_path"]),
|
||||
prompt_text=item["prompt_text"],
|
||||
prompt_lang=item.get("prompt_lang", "zh"),
|
||||
text=text,
|
||||
text_lang=item.get("text_lang", "zh"),
|
||||
top_k=int(item.get("top_k", args.top_k)),
|
||||
top_p=float(item.get("top_p", args.top_p)),
|
||||
temperature=float(item.get("temperature", args.temperature)),
|
||||
repetition_penalty=float(item.get("repetition_penalty", args.repetition_penalty)),
|
||||
early_stop_num=int(item.get("early_stop_num", args.early_stop_num)),
|
||||
ready_step=int(item.get("ready_step", 0)),
|
||||
)
|
||||
)
|
||||
return specs
|
||||
|
||||
text = args.text if args.text is not None else args.text_file.read_text(encoding="utf-8")
|
||||
return [
|
||||
SchedulerRequestSpec(
|
||||
request_id="req_000",
|
||||
ref_audio_path=args.ref_audio,
|
||||
prompt_text=args.prompt_text,
|
||||
prompt_lang=args.prompt_lang,
|
||||
text=text,
|
||||
text_lang=args.text_lang,
|
||||
top_k=args.top_k,
|
||||
top_p=args.top_p,
|
||||
temperature=args.temperature,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
early_stop_num=args.early_stop_num,
|
||||
ready_step=0,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def summarise_requests(states: List[T2SRequestState]) -> List[Dict[str, Any]]:
|
||||
return [
|
||||
{
|
||||
"request_id": state.request_id,
|
||||
"ready_step": int(state.ready_step),
|
||||
"ref_audio_path": str(state.ref_audio_path),
|
||||
"prompt_semantic_len": int(state.prompt_semantic.shape[0]),
|
||||
"all_phone_len": int(state.all_phones.shape[0]),
|
||||
"bert_len": int(state.all_bert_features.shape[-1]),
|
||||
"norm_text": state.norm_text,
|
||||
}
|
||||
for state in states
|
||||
]
|
||||
|
||||
|
||||
def summarise_finished(items: List[T2SFinishedItem]) -> List[Dict[str, Any]]:
|
||||
return [
|
||||
{
|
||||
"request_id": item.request_id,
|
||||
"semantic_len": int(item.semantic_tokens.shape[0]),
|
||||
"finish_idx": int(item.finish_idx),
|
||||
"finish_reason": item.finish_reason,
|
||||
}
|
||||
for item in items
|
||||
]
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
args.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
tts = load_pipeline(args.config)
|
||||
model = tts.t2s_model.model
|
||||
use_cuda = str(tts.configs.device).startswith("cuda")
|
||||
set_seed(args.seed, use_cuda)
|
||||
|
||||
request_specs = load_request_specs(args)
|
||||
states = [prepare_request_state(tts, spec) for spec in request_specs]
|
||||
finished = run_scheduler_continuous(model, states, max_steps=args.max_steps)
|
||||
|
||||
summary = {
|
||||
"request_count": len(states),
|
||||
"max_steps": args.max_steps,
|
||||
"requests": summarise_requests(states),
|
||||
"finished": summarise_finished(finished),
|
||||
}
|
||||
output_path = args.output_dir / "scheduler_prototype_summary.json"
|
||||
output_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
print(json.dumps(summary, ensure_ascii=False, indent=2))
|
||||
print(f"[saved] {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except ModuleNotFoundError as exc:
|
||||
print(f"[error] {exc}")
|
||||
raise SystemExit(1) from None
|
||||
Loading…
x
Reference in New Issue
Block a user