mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-05-12 21:08:11 +08:00
Implement batch processing for BERT and reference semantic tasks in TTS. Introduce StageLimiter for managing concurrent processing and enhance the TTS class with new methods for handling audio and semantic extraction. Update profiling metrics for better performance tracking during inference.
This commit is contained in:
parent
d245eb169c
commit
845b181360
@ -5,6 +5,7 @@ import random
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from copy import deepcopy
|
||||
|
||||
import torchaudio
|
||||
@ -33,7 +34,12 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
from tools.audio_sr import AP_BWE
|
||||
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||
from TTS_infer_pack.text_segmentation_method import splits
|
||||
from TTS_infer_pack.TextPreprocessor import TextPreprocessor
|
||||
from TTS_infer_pack.TextPreprocessor import TextPreprocessor, StageLimiter
|
||||
from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker
|
||||
from TTS_infer_pack.prepare_ref_semantic_batch_worker import (
|
||||
PrepareRefSemanticBatchWorker,
|
||||
prepare_prompt_semantic_wav16k,
|
||||
)
|
||||
from sv import SV
|
||||
|
||||
resample_transform_dict = {}
|
||||
@ -442,11 +448,56 @@ class TTS:
|
||||
"upsample_rate": None,
|
||||
"overlapped_len": None,
|
||||
}
|
||||
self.prepare_bert_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_BERT_SLOTS", "1")))
|
||||
self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "2")))
|
||||
self.prepare_bert_batch_worker = None
|
||||
self.prepare_ref_semantic_batch_worker = None
|
||||
default_text_cpu_workers = 16
|
||||
self.prepare_text_cpu_workers = max(
|
||||
0,
|
||||
int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_WORKERS", str(default_text_cpu_workers))),
|
||||
)
|
||||
self.prepare_text_cpu_executor = None
|
||||
if self.prepare_text_cpu_workers > 0:
|
||||
self.prepare_text_cpu_executor = ThreadPoolExecutor(
|
||||
max_workers=self.prepare_text_cpu_workers,
|
||||
thread_name_prefix="prepare-text-cpu",
|
||||
)
|
||||
|
||||
self._init_models()
|
||||
|
||||
if os.environ.get("GPTSOVITS_PREPARE_BERT_BATCHING", "1") != "0":
|
||||
self.prepare_bert_batch_worker = PrepareBertBatchWorker(
|
||||
bert_model=self.bert_model,
|
||||
tokenizer=self.bert_tokenizer,
|
||||
device=self.configs.device,
|
||||
stage_limiter=self.prepare_bert_stage_limiter,
|
||||
batch_window_ms=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_WINDOW_MS", "5")),
|
||||
max_batch_items=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_ITEMS", "16")),
|
||||
max_batch_tokens=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_TOKENS", "4096")),
|
||||
)
|
||||
if os.environ.get("GPTSOVITS_PREPARE_REF_BATCHING", "0") != "0":
|
||||
ref_max_batch_samples = os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_SAMPLES")
|
||||
if ref_max_batch_samples is None:
|
||||
ref_max_batch_samples = os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_FRAMES", "960000")
|
||||
self.prepare_ref_semantic_batch_worker = PrepareRefSemanticBatchWorker(
|
||||
ssl_model=self.cnhuhbert_model,
|
||||
vits_model=self.vits_model,
|
||||
device=self.configs.device,
|
||||
is_half=self.configs.is_half,
|
||||
zero_wav_samples=int(self.configs.sampling_rate * 0.3),
|
||||
stage_limiter=self.prepare_ref_audio_stage_limiter,
|
||||
batch_window_ms=int(os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_WINDOW_MS", "5")),
|
||||
max_batch_items=int(os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_ITEMS", "8")),
|
||||
max_batch_samples=int(ref_max_batch_samples),
|
||||
)
|
||||
|
||||
self.text_preprocessor: TextPreprocessor = TextPreprocessor(
|
||||
self.bert_model, self.bert_tokenizer, self.configs.device
|
||||
self.bert_model,
|
||||
self.bert_tokenizer,
|
||||
self.configs.device,
|
||||
bert_stage_limiter=self.prepare_bert_stage_limiter,
|
||||
bert_batch_worker=self.prepare_bert_batch_worker,
|
||||
)
|
||||
|
||||
self.prompt_cache: dict = {
|
||||
@ -755,47 +806,52 @@ class TTS:
|
||||
Args:
|
||||
ref_audio_path: str, the path of the reference audio.
|
||||
"""
|
||||
self._set_prompt_semantic(ref_audio_path)
|
||||
self._set_ref_spec(ref_audio_path)
|
||||
bundle = self.extract_ref_audio_bundle(ref_audio_path)
|
||||
if self.prompt_cache["refer_spec"] in [[], None]:
|
||||
self.prompt_cache["refer_spec"] = [bundle["refer_spec"]]
|
||||
else:
|
||||
self.prompt_cache["refer_spec"][0] = bundle["refer_spec"]
|
||||
self.prompt_cache["prompt_semantic"] = bundle["prompt_semantic"]
|
||||
self.prompt_cache["raw_audio"] = bundle["raw_audio"]
|
||||
self.prompt_cache["raw_sr"] = bundle["raw_sr"]
|
||||
self._set_ref_audio_path(ref_audio_path)
|
||||
|
||||
def extract_prompt_semantic(self, ref_wav_path: str):
|
||||
zero_wav = np.zeros(
|
||||
int(self.configs.sampling_rate * 0.3),
|
||||
dtype=np.float16 if self.configs.is_half else np.float32,
|
||||
)
|
||||
with torch.no_grad():
|
||||
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
|
||||
if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
|
||||
raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
|
||||
wav16k = torch.from_numpy(wav16k)
|
||||
zero_wav_torch = torch.from_numpy(zero_wav)
|
||||
wav16k = wav16k.to(self.configs.device)
|
||||
zero_wav_torch = zero_wav_torch.to(self.configs.device)
|
||||
if self.configs.is_half:
|
||||
wav16k = wav16k.half()
|
||||
zero_wav_torch = zero_wav_torch.half()
|
||||
|
||||
wav16k = torch.cat([wav16k, zero_wav_torch])
|
||||
hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(
|
||||
1, 2
|
||||
) # .float()
|
||||
codes = self.vits_model.extract_latent(hubert_feature)
|
||||
|
||||
prompt_semantic = codes[0, 0].to(self.configs.device)
|
||||
return prompt_semantic
|
||||
|
||||
def extract_ref_spec(self, ref_audio_path: str):
|
||||
def _load_ref_audio_raw(self, ref_audio_path: str):
|
||||
raw_audio, raw_sr = torchaudio.load(ref_audio_path)
|
||||
raw_audio = raw_audio.to(self.configs.device).float()
|
||||
return raw_audio.float(), int(raw_sr)
|
||||
|
||||
@torch.inference_mode()
|
||||
def _extract_prompt_semantic_from_prepared_wav16k(self, wav16k: torch.Tensor):
|
||||
wav16k = wav16k.to(self.configs.device)
|
||||
if self.configs.is_half:
|
||||
wav16k = wav16k.half()
|
||||
hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)
|
||||
codes = self.vits_model.extract_latent(hubert_feature)
|
||||
return codes[0, 0].to(self.configs.device)
|
||||
|
||||
@torch.inference_mode()
|
||||
def _extract_prompt_semantic_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
|
||||
wav16k = prepare_prompt_semantic_wav16k(
|
||||
raw_audio=raw_audio,
|
||||
raw_sr=raw_sr,
|
||||
zero_wav_samples=int(self.configs.sampling_rate * 0.3),
|
||||
)
|
||||
return self._extract_prompt_semantic_from_prepared_wav16k(wav16k)
|
||||
|
||||
def extract_prompt_semantic(self, ref_wav_path: str):
|
||||
raw_audio, raw_sr = self._load_ref_audio_raw(ref_wav_path)
|
||||
return self._extract_prompt_semantic_from_raw(raw_audio, raw_sr)
|
||||
|
||||
def _extract_ref_spec_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
|
||||
raw_audio_device = raw_audio.to(self.configs.device).float()
|
||||
|
||||
if raw_sr != self.configs.sampling_rate:
|
||||
audio = raw_audio.to(self.configs.device)
|
||||
audio = raw_audio_device
|
||||
if audio.shape[0] == 2:
|
||||
audio = audio.mean(0).unsqueeze(0)
|
||||
audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device)
|
||||
else:
|
||||
audio = raw_audio.to(self.configs.device)
|
||||
audio = raw_audio_device
|
||||
if audio.shape[0] == 2:
|
||||
audio = audio.mean(0).unsqueeze(0)
|
||||
|
||||
@ -820,8 +876,141 @@ class TTS:
|
||||
audio = None
|
||||
return spec, audio, raw_audio, raw_sr
|
||||
|
||||
def extract_text_features(self, text: str, language: str):
|
||||
return self.text_preprocessor.segment_and_extract_feature_for_text(text, language, self.configs.version)
|
||||
def extract_ref_spec(self, ref_audio_path: str):
|
||||
raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path)
|
||||
return self._extract_ref_spec_from_raw(raw_audio, raw_sr)
|
||||
|
||||
def extract_ref_audio_bundle(self, ref_audio_path: str):
|
||||
load_start = time.perf_counter()
|
||||
raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path)
|
||||
load_ms = (time.perf_counter() - load_start) * 1000.0
|
||||
if self.prepare_ref_semantic_batch_worker is None:
|
||||
with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats:
|
||||
prompt_semantic_start = time.perf_counter()
|
||||
prompt_semantic = self._extract_prompt_semantic_from_raw(raw_audio, raw_sr)
|
||||
prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0
|
||||
ref_spec_start = time.perf_counter()
|
||||
refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2]
|
||||
ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0
|
||||
audio_stage_wait_ms = float(limiter_stats["wait_ms"])
|
||||
audio_stage_slots = float(limiter_stats["slots"])
|
||||
audio_stage_inflight_peak = float(limiter_stats["peak_inflight"])
|
||||
prompt_semantic_profile = {
|
||||
"prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]),
|
||||
"prompt_semantic_cpu_prepare_ms": 0.0,
|
||||
"prompt_semantic_forward_ms": prompt_semantic_ms,
|
||||
"prompt_semantic_scatter_ms": 0.0,
|
||||
"prompt_semantic_stage_slots": float(limiter_stats["slots"]),
|
||||
"prompt_semantic_stage_inflight_peak": float(limiter_stats["peak_inflight"]),
|
||||
"prompt_semantic_batch_size": 1.0,
|
||||
"prompt_semantic_batch_samples": 0.0,
|
||||
}
|
||||
ref_spec_wait_ms = 0.0
|
||||
return {
|
||||
"prompt_semantic": prompt_semantic,
|
||||
"refer_spec": refer_spec,
|
||||
"raw_audio": raw_audio,
|
||||
"raw_sr": raw_sr,
|
||||
"profile": {
|
||||
"audio_load_ms": load_ms,
|
||||
"audio_stage_wait_ms": audio_stage_wait_ms,
|
||||
"audio_stage_slots": audio_stage_slots,
|
||||
"audio_stage_inflight_peak": audio_stage_inflight_peak,
|
||||
"prompt_semantic_ms": prompt_semantic_ms,
|
||||
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
|
||||
"prompt_semantic_cpu_prepare_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_forward_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_scatter_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_stage_slots": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)
|
||||
),
|
||||
"prompt_semantic_stage_inflight_peak": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)
|
||||
),
|
||||
"prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)),
|
||||
"prompt_semantic_batch_samples": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0)
|
||||
),
|
||||
"ref_spec_wait_ms": ref_spec_wait_ms,
|
||||
"ref_spec_ms": ref_spec_ms,
|
||||
"bundle_total_ms": load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_ms,
|
||||
},
|
||||
}
|
||||
|
||||
prompt_semantic_profile = {
|
||||
"prompt_semantic_wait_ms": 0.0,
|
||||
"prompt_semantic_cpu_prepare_ms": 0.0,
|
||||
"prompt_semantic_forward_ms": 0.0,
|
||||
"prompt_semantic_scatter_ms": 0.0,
|
||||
"prompt_semantic_stage_slots": 0.0,
|
||||
"prompt_semantic_stage_inflight_peak": 0.0,
|
||||
"prompt_semantic_batch_size": 1.0,
|
||||
"prompt_semantic_batch_samples": 0.0,
|
||||
}
|
||||
if self.prepare_ref_semantic_batch_worker is not None:
|
||||
prompt_semantic, worker_profile = self.prepare_ref_semantic_batch_worker.submit(raw_audio, raw_sr)
|
||||
prompt_semantic_profile.update(worker_profile)
|
||||
prompt_semantic_ms = (
|
||||
float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0))
|
||||
+ float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0))
|
||||
+ float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0))
|
||||
)
|
||||
with self.prepare_ref_audio_stage_limiter.enter() as ref_spec_limiter_stats:
|
||||
ref_spec_start = time.perf_counter()
|
||||
refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2]
|
||||
ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0
|
||||
audio_stage_wait_ms = float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)) + float(
|
||||
ref_spec_limiter_stats["wait_ms"]
|
||||
)
|
||||
audio_stage_slots = max(
|
||||
float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)),
|
||||
float(ref_spec_limiter_stats["slots"]),
|
||||
)
|
||||
audio_stage_inflight_peak = max(
|
||||
float(prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)),
|
||||
float(ref_spec_limiter_stats["peak_inflight"]),
|
||||
)
|
||||
return {
|
||||
"prompt_semantic": prompt_semantic,
|
||||
"refer_spec": refer_spec,
|
||||
"raw_audio": raw_audio,
|
||||
"raw_sr": raw_sr,
|
||||
"profile": {
|
||||
"audio_load_ms": load_ms,
|
||||
"audio_stage_wait_ms": audio_stage_wait_ms,
|
||||
"audio_stage_slots": audio_stage_slots,
|
||||
"audio_stage_inflight_peak": audio_stage_inflight_peak,
|
||||
"prompt_semantic_ms": prompt_semantic_ms,
|
||||
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
|
||||
"prompt_semantic_cpu_prepare_ms": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
|
||||
),
|
||||
"prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)),
|
||||
"prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)),
|
||||
"prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)),
|
||||
"prompt_semantic_stage_inflight_peak": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)
|
||||
),
|
||||
"prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)),
|
||||
"prompt_semantic_batch_samples": float(
|
||||
prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0)
|
||||
),
|
||||
"ref_spec_wait_ms": float(ref_spec_limiter_stats["wait_ms"]),
|
||||
"ref_spec_ms": ref_spec_ms,
|
||||
"bundle_total_ms": load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_ms,
|
||||
},
|
||||
}
|
||||
|
||||
def extract_text_features(self, text: str, language: str, profile: dict | None = None):
|
||||
return self.text_preprocessor.segment_and_extract_feature_for_text(
|
||||
text, language, self.configs.version, profile=profile
|
||||
)
|
||||
|
||||
def _set_ref_audio_path(self, ref_audio_path):
|
||||
self.prompt_cache["ref_audio_path"] = ref_audio_path
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
@ -16,6 +18,7 @@ from text.cleaner import clean_text
|
||||
from text import cleaned_text_to_sequence
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
|
||||
from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker
|
||||
|
||||
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||
|
||||
@ -49,12 +52,60 @@ def merge_short_text_in_array(texts: str, threshold: int) -> list:
|
||||
return result
|
||||
|
||||
|
||||
class StageLimiter:
|
||||
def __init__(self, slots: int):
|
||||
self.slots = max(1, int(slots))
|
||||
self.semaphore = threading.BoundedSemaphore(self.slots)
|
||||
self.lock = threading.Lock()
|
||||
self.inflight = 0
|
||||
self.peak_inflight = 0
|
||||
|
||||
@contextmanager
|
||||
def enter(self):
|
||||
wait_start = time.perf_counter()
|
||||
self.semaphore.acquire()
|
||||
wait_ms = (time.perf_counter() - wait_start) * 1000.0
|
||||
with self.lock:
|
||||
self.inflight += 1
|
||||
current_inflight = self.inflight
|
||||
if current_inflight > self.peak_inflight:
|
||||
self.peak_inflight = current_inflight
|
||||
peak_inflight = self.peak_inflight
|
||||
try:
|
||||
yield {
|
||||
"wait_ms": wait_ms,
|
||||
"inflight": current_inflight,
|
||||
"peak_inflight": peak_inflight,
|
||||
"slots": self.slots,
|
||||
}
|
||||
finally:
|
||||
with self.lock:
|
||||
self.inflight = max(0, self.inflight - 1)
|
||||
self.semaphore.release()
|
||||
|
||||
def snapshot(self) -> Dict[str, int]:
|
||||
with self.lock:
|
||||
return {
|
||||
"slots": self.slots,
|
||||
"inflight": self.inflight,
|
||||
"peak_inflight": self.peak_inflight,
|
||||
}
|
||||
|
||||
|
||||
class TextPreprocessor:
|
||||
def __init__(self, bert_model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, device: torch.device):
|
||||
def __init__(
|
||||
self,
|
||||
bert_model: AutoModelForMaskedLM,
|
||||
tokenizer: AutoTokenizer,
|
||||
device: torch.device,
|
||||
bert_stage_limiter: StageLimiter | None = None,
|
||||
bert_batch_worker: PrepareBertBatchWorker | None = None,
|
||||
):
|
||||
self.bert_model = bert_model
|
||||
self.tokenizer = tokenizer
|
||||
self.device = device
|
||||
self.bert_lock = threading.RLock()
|
||||
self.bert_stage_limiter = bert_stage_limiter
|
||||
self.bert_batch_worker = bert_batch_worker
|
||||
|
||||
def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
|
||||
print(f"############ {i18n('切分文本')} ############")
|
||||
@ -115,86 +166,136 @@ class TextPreprocessor:
|
||||
return texts
|
||||
|
||||
def segment_and_extract_feature_for_text(
|
||||
self, text: str, language: str, version: str = "v1"
|
||||
self, text: str, language: str, version: str = "v1", profile: Dict | None = None
|
||||
) -> Tuple[list, torch.Tensor, str]:
|
||||
return self.get_phones_and_bert(text, language, version)
|
||||
return self.get_phones_and_bert(text, language, version, profile=profile)
|
||||
|
||||
def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False):
|
||||
with self.bert_lock:
|
||||
text = re.sub(r' {2,}', ' ', text)
|
||||
textlist = []
|
||||
langlist = []
|
||||
if language == "all_zh":
|
||||
for tmp in LangSegmenter.getTexts(text,"zh"):
|
||||
def _split_text_by_language(self, text: str, language: str) -> Tuple[List[str], List[str]]:
|
||||
textlist = []
|
||||
langlist = []
|
||||
if language == "all_zh":
|
||||
for tmp in LangSegmenter.getTexts(text, "zh"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_yue":
|
||||
for tmp in LangSegmenter.getTexts(text, "zh"):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_ja":
|
||||
for tmp in LangSegmenter.getTexts(text, "ja"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_ko":
|
||||
for tmp in LangSegmenter.getTexts(text, "ko"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "en":
|
||||
langlist.append("en")
|
||||
textlist.append(text)
|
||||
elif language == "auto":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "auto_yue":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
else:
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if langlist:
|
||||
same_group = (tmp["lang"] == "en" and langlist[-1] == "en") or (
|
||||
tmp["lang"] != "en" and langlist[-1] != "en"
|
||||
)
|
||||
if same_group:
|
||||
textlist[-1] += tmp["text"]
|
||||
continue
|
||||
if tmp["lang"] == "en":
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_yue":
|
||||
for tmp in LangSegmenter.getTexts(text,"zh"):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_ja":
|
||||
for tmp in LangSegmenter.getTexts(text,"ja"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_ko":
|
||||
for tmp in LangSegmenter.getTexts(text,"ko"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "en":
|
||||
langlist.append("en")
|
||||
textlist.append(text)
|
||||
elif language == "auto":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "auto_yue":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
else:
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if langlist:
|
||||
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
|
||||
textlist[-1] += tmp["text"]
|
||||
continue
|
||||
if tmp["lang"] == "en":
|
||||
langlist.append(tmp["lang"])
|
||||
else:
|
||||
# 因无法区别中日韩文汉字,以用户输入为准
|
||||
langlist.append(language)
|
||||
textlist.append(tmp["text"])
|
||||
# print(textlist)
|
||||
# print(langlist)
|
||||
phones_list = []
|
||||
bert_list = []
|
||||
norm_text_list = []
|
||||
for i in range(len(textlist)):
|
||||
lang = langlist[i]
|
||||
phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
|
||||
bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
|
||||
phones_list.append(phones)
|
||||
norm_text_list.append(norm_text)
|
||||
bert_list.append(bert)
|
||||
bert = torch.cat(bert_list, dim=1)
|
||||
phones = sum(phones_list, [])
|
||||
norm_text = "".join(norm_text_list)
|
||||
else:
|
||||
langlist.append(language)
|
||||
textlist.append(tmp["text"])
|
||||
return textlist, langlist
|
||||
|
||||
if not final and len(phones) < 6:
|
||||
return self.get_phones_and_bert("." + text, language, version, final=True)
|
||||
def get_phones_and_bert(
|
||||
self, text: str, language: str, version: str, final: bool = False, profile: Dict | None = None
|
||||
):
|
||||
text = re.sub(r' {2,}', ' ', text)
|
||||
textlist, langlist = self._split_text_by_language(text, language)
|
||||
phones_list = []
|
||||
bert_list = []
|
||||
norm_text_list = []
|
||||
for segment_text, segment_lang in zip(textlist, langlist):
|
||||
phones, word2ph, norm_text = self.clean_text_inf(segment_text, segment_lang, version)
|
||||
bert = self.get_bert_inf(phones, word2ph, norm_text, segment_lang, profile=profile)
|
||||
phones_list.append(phones)
|
||||
norm_text_list.append(norm_text)
|
||||
bert_list.append(bert)
|
||||
bert = torch.cat(bert_list, dim=1)
|
||||
phones = sum(phones_list, [])
|
||||
norm_text = "".join(norm_text_list)
|
||||
|
||||
return phones, bert, norm_text
|
||||
if not final and len(phones) < 6:
|
||||
return self.get_phones_and_bert("." + text, language, version, final=True, profile=profile)
|
||||
|
||||
def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
|
||||
with torch.no_grad():
|
||||
inputs = self.tokenizer(text, return_tensors="pt")
|
||||
for i in inputs:
|
||||
inputs[i] = inputs[i].to(self.device)
|
||||
res = self.bert_model(**inputs, output_hidden_states=True)
|
||||
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
|
||||
return phones, bert, norm_text
|
||||
|
||||
def _accumulate_profile(self, profile: Dict | None, key: str, value: float) -> None:
|
||||
if profile is None:
|
||||
return
|
||||
profile[key] = float(profile.get(key, 0.0)) + float(value)
|
||||
|
||||
def _update_profile_peak(self, profile: Dict | None, key: str, value: float) -> None:
|
||||
if profile is None:
|
||||
return
|
||||
profile[key] = float(max(float(profile.get(key, 0.0)), float(value)))
|
||||
|
||||
def get_bert_feature(self, text: str, word2ph: list, profile: Dict | None = None) -> torch.Tensor:
|
||||
if self.bert_batch_worker is not None:
|
||||
feature, worker_profile = self.bert_batch_worker.submit(text, word2ph)
|
||||
self._accumulate_profile(profile, "bert_wait_ms", worker_profile.get("bert_wait_ms", 0.0))
|
||||
self._accumulate_profile(profile, "bert_forward_ms", worker_profile.get("bert_forward_ms", 0.0))
|
||||
self._accumulate_profile(profile, "bert_tokenize_ms", worker_profile.get("bert_tokenize_ms", 0.0))
|
||||
self._accumulate_profile(profile, "bert_scatter_ms", worker_profile.get("bert_scatter_ms", 0.0))
|
||||
self._accumulate_profile(profile, "bert_calls", worker_profile.get("bert_calls", 1.0))
|
||||
self._update_profile_peak(
|
||||
profile, "bert_stage_inflight_peak", worker_profile.get("bert_stage_inflight_peak", 0.0)
|
||||
)
|
||||
self._update_profile_peak(profile, "bert_batch_size_peak", worker_profile.get("bert_batch_size", 0.0))
|
||||
self._update_profile_peak(profile, "bert_batch_tokens_peak", worker_profile.get("bert_batch_tokens", 0.0))
|
||||
if profile is not None:
|
||||
profile["bert_stage_slots"] = float(worker_profile.get("bert_stage_slots", 0.0))
|
||||
return feature
|
||||
|
||||
limiter_stats = {"wait_ms": 0.0, "inflight": 1, "peak_inflight": 1, "slots": 0}
|
||||
if self.bert_stage_limiter is None:
|
||||
forward_start = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
inputs = self.tokenizer(text, return_tensors="pt")
|
||||
for i in inputs:
|
||||
inputs[i] = inputs[i].to(self.device)
|
||||
res = self.bert_model(**inputs, output_hidden_states=True)
|
||||
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
|
||||
forward_ms = (time.perf_counter() - forward_start) * 1000.0
|
||||
else:
|
||||
with self.bert_stage_limiter.enter() as limiter_stats:
|
||||
forward_start = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
inputs = self.tokenizer(text, return_tensors="pt")
|
||||
for i in inputs:
|
||||
inputs[i] = inputs[i].to(self.device)
|
||||
res = self.bert_model(**inputs, output_hidden_states=True)
|
||||
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
|
||||
forward_ms = (time.perf_counter() - forward_start) * 1000.0
|
||||
self._accumulate_profile(profile, "bert_wait_ms", limiter_stats["wait_ms"])
|
||||
self._accumulate_profile(profile, "bert_forward_ms", forward_ms)
|
||||
self._accumulate_profile(profile, "bert_calls", 1.0)
|
||||
self._update_profile_peak(profile, "bert_stage_inflight_peak", limiter_stats["peak_inflight"])
|
||||
if profile is not None:
|
||||
profile["bert_stage_slots"] = float(limiter_stats["slots"])
|
||||
assert len(word2ph) == len(text)
|
||||
phone_level_feature = []
|
||||
for i in range(len(word2ph)):
|
||||
@ -209,10 +310,10 @@ class TextPreprocessor:
|
||||
phones = cleaned_text_to_sequence(phones, version)
|
||||
return phones, word2ph, norm_text
|
||||
|
||||
def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str):
|
||||
def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str, profile: Dict | None = None):
|
||||
language = language.replace("all_", "")
|
||||
if language == "zh":
|
||||
feature = self.get_bert_feature(norm_text, word2ph).to(self.device)
|
||||
feature = self.get_bert_feature(norm_text, word2ph, profile=profile).to(self.device)
|
||||
else:
|
||||
feature = torch.zeros(
|
||||
(1024, len(phones)),
|
||||
@ -236,4 +337,4 @@ class TextPreprocessor:
|
||||
punctuations = "".join(re.escape(p) for p in punctuation)
|
||||
pattern = f"([{punctuations}])([{punctuations}])+"
|
||||
result = re.sub(pattern, r"\1", text)
|
||||
return result
|
||||
return result
|
||||
|
||||
197
GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py
Normal file
197
GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py
Normal file
@ -0,0 +1,197 @@
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Deque, Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class BertFeatureTask:
|
||||
norm_text: str
|
||||
word2ph: List[int]
|
||||
task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
|
||||
created_at: float = field(default_factory=time.perf_counter)
|
||||
done_event: threading.Event = field(default_factory=threading.Event)
|
||||
result_feature: torch.Tensor | None = None
|
||||
error: Exception | None = None
|
||||
profile: Dict[str, float] = field(default_factory=dict)
|
||||
|
||||
|
||||
class PrepareBertBatchWorker:
|
||||
def __init__(
|
||||
self,
|
||||
bert_model,
|
||||
tokenizer,
|
||||
device,
|
||||
stage_limiter=None,
|
||||
batch_window_ms: int = 5,
|
||||
max_batch_items: int = 16,
|
||||
max_batch_tokens: int = 4096,
|
||||
):
|
||||
self.bert_model = bert_model
|
||||
self.tokenizer = tokenizer
|
||||
self.device = device
|
||||
self.stage_limiter = stage_limiter
|
||||
self.batch_window_s = max(0.0, float(batch_window_ms) / 1000.0)
|
||||
self.max_batch_items = max(1, int(max_batch_items))
|
||||
self.max_batch_tokens = max(16, int(max_batch_tokens))
|
||||
|
||||
self.condition = threading.Condition()
|
||||
self.pending_tasks: Deque[BertFeatureTask] = deque()
|
||||
self.pending_peak = 0
|
||||
self.total_submitted = 0
|
||||
self.total_finished = 0
|
||||
self.total_batches = 0
|
||||
self.active_batch_size = 0
|
||||
self.active_batch_peak = 0
|
||||
self.worker_thread = threading.Thread(target=self._run_loop, name="prepare-bert-batch-worker", daemon=True)
|
||||
self.worker_thread.start()
|
||||
|
||||
def _estimate_task_tokens(self, task: BertFeatureTask) -> int:
|
||||
return max(1, len(task.norm_text) + 2)
|
||||
|
||||
def submit(self, norm_text: str, word2ph: List[int]) -> Tuple[torch.Tensor, Dict[str, float]]:
|
||||
task = BertFeatureTask(norm_text=str(norm_text), word2ph=list(word2ph))
|
||||
with self.condition:
|
||||
self.pending_tasks.append(task)
|
||||
self.total_submitted += 1
|
||||
if len(self.pending_tasks) > self.pending_peak:
|
||||
self.pending_peak = len(self.pending_tasks)
|
||||
self.condition.notify_all()
|
||||
task.done_event.wait()
|
||||
if task.error is not None:
|
||||
raise task.error
|
||||
assert task.result_feature is not None
|
||||
return task.result_feature, dict(task.profile)
|
||||
|
||||
def snapshot(self) -> Dict[str, int]:
|
||||
with self.condition:
|
||||
return {
|
||||
"pending": len(self.pending_tasks),
|
||||
"pending_peak": self.pending_peak,
|
||||
"total_submitted": self.total_submitted,
|
||||
"total_finished": self.total_finished,
|
||||
"total_batches": self.total_batches,
|
||||
"active_batch_size": self.active_batch_size,
|
||||
"active_batch_peak": self.active_batch_peak,
|
||||
"batch_window_ms": int(self.batch_window_s * 1000.0),
|
||||
"max_batch_items": self.max_batch_items,
|
||||
"max_batch_tokens": self.max_batch_tokens,
|
||||
}
|
||||
|
||||
def _collect_batch(self) -> List[BertFeatureTask]:
|
||||
with self.condition:
|
||||
while not self.pending_tasks:
|
||||
self.condition.wait()
|
||||
|
||||
batch: List[BertFeatureTask] = [self.pending_tasks.popleft()]
|
||||
batch_tokens = self._estimate_task_tokens(batch[0])
|
||||
deadline = time.perf_counter() + self.batch_window_s
|
||||
|
||||
while len(batch) < self.max_batch_items:
|
||||
remaining = deadline - time.perf_counter()
|
||||
if remaining <= 0:
|
||||
break
|
||||
if not self.pending_tasks:
|
||||
self.condition.wait(timeout=remaining)
|
||||
continue
|
||||
next_task = self.pending_tasks[0]
|
||||
next_tokens = self._estimate_task_tokens(next_task)
|
||||
if len(batch) >= self.max_batch_items or (batch_tokens + next_tokens) > self.max_batch_tokens:
|
||||
break
|
||||
batch.append(self.pending_tasks.popleft())
|
||||
batch_tokens += next_tokens
|
||||
|
||||
self.active_batch_size = len(batch)
|
||||
if self.active_batch_size > self.active_batch_peak:
|
||||
self.active_batch_peak = self.active_batch_size
|
||||
return batch
|
||||
|
||||
def _finalize_batch(self, batch: List[BertFeatureTask]) -> None:
|
||||
with self.condition:
|
||||
self.active_batch_size = 0
|
||||
self.total_batches += 1
|
||||
self.total_finished += len(batch)
|
||||
|
||||
def _run_batch(self, batch: List[BertFeatureTask]) -> None:
|
||||
batch_started = time.perf_counter()
|
||||
texts = [task.norm_text for task in batch]
|
||||
batch_tokens = sum(self._estimate_task_tokens(task) for task in batch)
|
||||
|
||||
limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0}
|
||||
if self.stage_limiter is None:
|
||||
tokenize_start = time.perf_counter()
|
||||
inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
|
||||
tokenize_ms = (time.perf_counter() - tokenize_start) * 1000.0
|
||||
attention_mask_cpu = inputs["attention_mask"].cpu()
|
||||
for key in inputs:
|
||||
inputs[key] = inputs[key].to(self.device)
|
||||
forward_start = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
outputs = self.bert_model(**inputs, output_hidden_states=True)
|
||||
forward_ms = (time.perf_counter() - forward_start) * 1000.0
|
||||
else:
|
||||
with self.stage_limiter.enter() as limiter_stats:
|
||||
tokenize_start = time.perf_counter()
|
||||
inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
|
||||
tokenize_ms = (time.perf_counter() - tokenize_start) * 1000.0
|
||||
attention_mask_cpu = inputs["attention_mask"].cpu()
|
||||
for key in inputs:
|
||||
inputs[key] = inputs[key].to(self.device)
|
||||
forward_start = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
outputs = self.bert_model(**inputs, output_hidden_states=True)
|
||||
forward_ms = (time.perf_counter() - forward_start) * 1000.0
|
||||
|
||||
hidden = outputs["hidden_states"][-3].detach().cpu()
|
||||
scatter_start = time.perf_counter()
|
||||
for batch_index, task in enumerate(batch):
|
||||
try:
|
||||
text_len = len(task.word2ph)
|
||||
if text_len != len(task.norm_text):
|
||||
raise AssertionError(
|
||||
f"word2ph/text length mismatch: task={task.task_id} word2ph={text_len} text={len(task.norm_text)}"
|
||||
)
|
||||
seq_len = int(attention_mask_cpu[batch_index].sum().item())
|
||||
char_features = hidden[batch_index, 1 : seq_len - 1]
|
||||
if char_features.shape[0] != text_len:
|
||||
raise AssertionError(
|
||||
f"bert token length mismatch: task={task.task_id} token_len={char_features.shape[0]} text_len={text_len}"
|
||||
)
|
||||
phone_level_feature = []
|
||||
for char_index, repeat_count in enumerate(task.word2ph):
|
||||
phone_level_feature.append(char_features[char_index].repeat(repeat_count, 1))
|
||||
task.result_feature = torch.cat(phone_level_feature, dim=0).T
|
||||
task.profile = {
|
||||
"bert_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]),
|
||||
"bert_forward_ms": float(forward_ms),
|
||||
"bert_tokenize_ms": float(tokenize_ms),
|
||||
"bert_scatter_ms": 0.0,
|
||||
"bert_calls": 1.0,
|
||||
"bert_stage_slots": float(limiter_stats["slots"]),
|
||||
"bert_stage_inflight_peak": float(limiter_stats["peak_inflight"]),
|
||||
"bert_batch_size": float(len(batch)),
|
||||
"bert_batch_tokens": float(batch_tokens),
|
||||
}
|
||||
except Exception as exc: # noqa: PERF203
|
||||
task.error = exc
|
||||
scatter_ms = (time.perf_counter() - scatter_start) * 1000.0
|
||||
for task in batch:
|
||||
if task.result_feature is not None:
|
||||
task.profile["bert_scatter_ms"] = float(scatter_ms)
|
||||
task.done_event.set()
|
||||
|
||||
def _run_loop(self) -> None:
|
||||
while True:
|
||||
batch = self._collect_batch()
|
||||
try:
|
||||
self._run_batch(batch)
|
||||
except Exception as exc: # noqa: PERF203
|
||||
for task in batch:
|
||||
task.error = exc
|
||||
task.done_event.set()
|
||||
finally:
|
||||
self._finalize_batch(batch)
|
||||
262
GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py
Normal file
262
GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py
Normal file
@ -0,0 +1,262 @@
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Deque, Dict, List, Tuple
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
REF_AUDIO_MIN_SAMPLES_16K = 48000
|
||||
REF_AUDIO_MAX_SAMPLES_16K = 160000
|
||||
|
||||
|
||||
def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wav_samples: int) -> torch.Tensor:
|
||||
wav_mono = raw_audio
|
||||
if wav_mono.dim() == 2 and wav_mono.shape[0] != 1:
|
||||
wav_mono = wav_mono.mean(0, keepdim=True)
|
||||
wav16k = wav_mono.squeeze(0).cpu().numpy()
|
||||
if raw_sr != 16000:
|
||||
wav16k = librosa.resample(wav16k, orig_sr=raw_sr, target_sr=16000)
|
||||
if wav16k.shape[0] > REF_AUDIO_MAX_SAMPLES_16K or wav16k.shape[0] < REF_AUDIO_MIN_SAMPLES_16K:
|
||||
raise OSError("参考音频在3~10秒范围外,请更换!")
|
||||
wav16k = np.ascontiguousarray(wav16k, dtype=np.float32)
|
||||
if zero_wav_samples > 0:
|
||||
wav16k = np.concatenate([wav16k, np.zeros(int(zero_wav_samples), dtype=np.float32)], axis=0)
|
||||
return torch.from_numpy(wav16k)
|
||||
|
||||
|
||||
def conv1d_output_lengths(input_lengths: torch.Tensor, conv1d: torch.nn.Conv1d | None) -> torch.Tensor:
|
||||
if conv1d is None:
|
||||
return input_lengths.to(dtype=torch.long)
|
||||
kernel_size = int(conv1d.kernel_size[0])
|
||||
stride = int(conv1d.stride[0])
|
||||
padding = int(conv1d.padding[0])
|
||||
dilation = int(conv1d.dilation[0])
|
||||
output_lengths = torch.div(
|
||||
input_lengths + 2 * padding - dilation * (kernel_size - 1) - 1,
|
||||
stride,
|
||||
rounding_mode="floor",
|
||||
) + 1
|
||||
return torch.clamp(output_lengths, min=0).to(dtype=torch.long)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RefSemanticTask:
|
||||
raw_audio: torch.Tensor
|
||||
raw_sr: int
|
||||
task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
|
||||
created_at: float = field(default_factory=time.perf_counter)
|
||||
done_event: threading.Event = field(default_factory=threading.Event)
|
||||
result_prompt_semantic: torch.Tensor | None = None
|
||||
error: Exception | None = None
|
||||
profile: Dict[str, float] = field(default_factory=dict)
|
||||
|
||||
|
||||
class PrepareRefSemanticBatchWorker:
|
||||
def __init__(
|
||||
self,
|
||||
ssl_model,
|
||||
vits_model,
|
||||
device,
|
||||
is_half: bool,
|
||||
zero_wav_samples: int,
|
||||
stage_limiter=None,
|
||||
batch_window_ms: int = 5,
|
||||
max_batch_items: int = 8,
|
||||
max_batch_samples: int = 960000,
|
||||
):
|
||||
self.ssl_model = ssl_model
|
||||
self.vits_model = vits_model
|
||||
self.device = device
|
||||
self.is_half = bool(is_half)
|
||||
self.zero_wav_samples = max(0, int(zero_wav_samples))
|
||||
self.stage_limiter = stage_limiter
|
||||
self.batch_window_s = max(0.0, float(batch_window_ms) / 1000.0)
|
||||
self.max_batch_items = max(1, int(max_batch_items))
|
||||
self.max_batch_samples = max(REF_AUDIO_MIN_SAMPLES_16K + self.zero_wav_samples, int(max_batch_samples))
|
||||
|
||||
self.condition = threading.Condition()
|
||||
self.pending_tasks: Deque[RefSemanticTask] = deque()
|
||||
self.pending_peak = 0
|
||||
self.total_submitted = 0
|
||||
self.total_finished = 0
|
||||
self.total_batches = 0
|
||||
self.active_batch_size = 0
|
||||
self.active_batch_peak = 0
|
||||
self.active_batch_samples = 0
|
||||
self.active_batch_samples_peak = 0
|
||||
self.worker_thread = threading.Thread(
|
||||
target=self._run_loop,
|
||||
name="prepare-ref-semantic-batch-worker",
|
||||
daemon=True,
|
||||
)
|
||||
self.worker_thread.start()
|
||||
|
||||
def _estimate_task_samples(self, task: RefSemanticTask) -> int:
|
||||
raw_len = int(task.raw_audio.shape[-1]) if task.raw_audio.dim() > 0 else 0
|
||||
base = int(round(raw_len * 16000.0 / max(1, int(task.raw_sr))))
|
||||
return max(REF_AUDIO_MIN_SAMPLES_16K, base) + self.zero_wav_samples
|
||||
|
||||
def submit(self, raw_audio: torch.Tensor, raw_sr: int) -> Tuple[torch.Tensor, Dict[str, float]]:
|
||||
task = RefSemanticTask(raw_audio=raw_audio, raw_sr=int(raw_sr))
|
||||
with self.condition:
|
||||
self.pending_tasks.append(task)
|
||||
self.total_submitted += 1
|
||||
if len(self.pending_tasks) > self.pending_peak:
|
||||
self.pending_peak = len(self.pending_tasks)
|
||||
self.condition.notify_all()
|
||||
task.done_event.wait()
|
||||
if task.error is not None:
|
||||
raise task.error
|
||||
assert task.result_prompt_semantic is not None
|
||||
return task.result_prompt_semantic, dict(task.profile)
|
||||
|
||||
def snapshot(self) -> Dict[str, int]:
|
||||
with self.condition:
|
||||
return {
|
||||
"pending": len(self.pending_tasks),
|
||||
"pending_peak": self.pending_peak,
|
||||
"total_submitted": self.total_submitted,
|
||||
"total_finished": self.total_finished,
|
||||
"total_batches": self.total_batches,
|
||||
"active_batch_size": self.active_batch_size,
|
||||
"active_batch_peak": self.active_batch_peak,
|
||||
"active_batch_samples": self.active_batch_samples,
|
||||
"active_batch_samples_peak": self.active_batch_samples_peak,
|
||||
"batch_window_ms": int(self.batch_window_s * 1000.0),
|
||||
"max_batch_items": self.max_batch_items,
|
||||
"max_batch_samples": self.max_batch_samples,
|
||||
}
|
||||
|
||||
def _collect_batch(self) -> List[RefSemanticTask]:
|
||||
with self.condition:
|
||||
while not self.pending_tasks:
|
||||
self.condition.wait()
|
||||
|
||||
batch: List[RefSemanticTask] = [self.pending_tasks.popleft()]
|
||||
batch_samples = self._estimate_task_samples(batch[0])
|
||||
deadline = time.perf_counter() + self.batch_window_s
|
||||
|
||||
while len(batch) < self.max_batch_items:
|
||||
remaining = deadline - time.perf_counter()
|
||||
if remaining <= 0:
|
||||
break
|
||||
if not self.pending_tasks:
|
||||
self.condition.wait(timeout=remaining)
|
||||
continue
|
||||
next_task = self.pending_tasks[0]
|
||||
next_samples = self._estimate_task_samples(next_task)
|
||||
if len(batch) >= self.max_batch_items or (batch_samples + next_samples) > self.max_batch_samples:
|
||||
break
|
||||
batch.append(self.pending_tasks.popleft())
|
||||
batch_samples += next_samples
|
||||
|
||||
self.active_batch_size = len(batch)
|
||||
self.active_batch_samples = batch_samples
|
||||
if self.active_batch_size > self.active_batch_peak:
|
||||
self.active_batch_peak = self.active_batch_size
|
||||
if self.active_batch_samples > self.active_batch_samples_peak:
|
||||
self.active_batch_samples_peak = self.active_batch_samples
|
||||
return batch
|
||||
|
||||
def _finalize_batch(self, batch: List[RefSemanticTask]) -> None:
|
||||
with self.condition:
|
||||
self.active_batch_size = 0
|
||||
self.active_batch_samples = 0
|
||||
self.total_batches += 1
|
||||
self.total_finished += len(batch)
|
||||
|
||||
def _get_hidden_lengths(self, attention_mask: torch.Tensor, hidden_length: int) -> torch.Tensor:
|
||||
model = self.ssl_model.model
|
||||
if hasattr(model, "_get_feature_vector_attention_mask"):
|
||||
feature_mask = model._get_feature_vector_attention_mask(hidden_length, attention_mask)
|
||||
return feature_mask.to(dtype=torch.long).sum(dim=1)
|
||||
raw_lengths = attention_mask.to(dtype=torch.long).sum(dim=1)
|
||||
if hasattr(model, "_get_feat_extract_output_lengths"):
|
||||
return model._get_feat_extract_output_lengths(raw_lengths).to(dtype=torch.long)
|
||||
return torch.full((attention_mask.shape[0],), int(hidden_length), dtype=torch.long, device=attention_mask.device)
|
||||
|
||||
@torch.inference_mode()
|
||||
def _run_batch(self, batch: List[RefSemanticTask]) -> None:
|
||||
batch_started = time.perf_counter()
|
||||
prepared_start = time.perf_counter()
|
||||
prepared_wavs = [
|
||||
prepare_prompt_semantic_wav16k(task.raw_audio, int(task.raw_sr), self.zero_wav_samples) for task in batch
|
||||
]
|
||||
cpu_prepare_ms = (time.perf_counter() - prepared_start) * 1000.0
|
||||
wav_lengths = torch.tensor([int(wav.shape[0]) for wav in prepared_wavs], dtype=torch.long)
|
||||
batch_samples = int(wav_lengths.sum().item())
|
||||
max_wav_len = int(wav_lengths.max().item())
|
||||
|
||||
input_values_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.float32)
|
||||
attention_mask_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.long)
|
||||
for batch_index, wav in enumerate(prepared_wavs):
|
||||
wav_len = int(wav.shape[0])
|
||||
input_values_cpu[batch_index, :wav_len] = wav
|
||||
attention_mask_cpu[batch_index, :wav_len] = 1
|
||||
|
||||
limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0}
|
||||
if self.stage_limiter is None:
|
||||
input_values = input_values_cpu.to(self.device)
|
||||
attention_mask = attention_mask_cpu.to(self.device)
|
||||
if self.is_half:
|
||||
input_values = input_values.half()
|
||||
forward_start = time.perf_counter()
|
||||
outputs = self.ssl_model.model(input_values, attention_mask=attention_mask)
|
||||
hubert_feature = outputs["last_hidden_state"].transpose(1, 2)
|
||||
hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1]))
|
||||
codes = self.vits_model.extract_latent(hubert_feature)
|
||||
forward_ms = (time.perf_counter() - forward_start) * 1000.0
|
||||
else:
|
||||
with self.stage_limiter.enter() as limiter_stats:
|
||||
input_values = input_values_cpu.to(self.device)
|
||||
attention_mask = attention_mask_cpu.to(self.device)
|
||||
if self.is_half:
|
||||
input_values = input_values.half()
|
||||
forward_start = time.perf_counter()
|
||||
outputs = self.ssl_model.model(input_values, attention_mask=attention_mask)
|
||||
hubert_feature = outputs["last_hidden_state"].transpose(1, 2)
|
||||
hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1]))
|
||||
codes = self.vits_model.extract_latent(hubert_feature)
|
||||
forward_ms = (time.perf_counter() - forward_start) * 1000.0
|
||||
|
||||
code_lengths = conv1d_output_lengths(hidden_lengths.detach().cpu(), getattr(self.vits_model, "ssl_proj", None))
|
||||
scatter_start = time.perf_counter()
|
||||
for batch_index, task in enumerate(batch):
|
||||
try:
|
||||
code_len = int(code_lengths[batch_index].item())
|
||||
task.result_prompt_semantic = codes[batch_index, 0, :code_len].detach().clone()
|
||||
task.profile = {
|
||||
"prompt_semantic_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]),
|
||||
"prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms),
|
||||
"prompt_semantic_forward_ms": float(forward_ms),
|
||||
"prompt_semantic_scatter_ms": 0.0,
|
||||
"prompt_semantic_calls": 1.0,
|
||||
"prompt_semantic_stage_slots": float(limiter_stats["slots"]),
|
||||
"prompt_semantic_stage_inflight_peak": float(limiter_stats["peak_inflight"]),
|
||||
"prompt_semantic_batch_size": float(len(batch)),
|
||||
"prompt_semantic_batch_samples": float(batch_samples),
|
||||
}
|
||||
except Exception as exc: # noqa: PERF203
|
||||
task.error = exc
|
||||
scatter_ms = (time.perf_counter() - scatter_start) * 1000.0
|
||||
for task in batch:
|
||||
if task.result_prompt_semantic is not None:
|
||||
task.profile["prompt_semantic_scatter_ms"] = float(scatter_ms)
|
||||
task.done_event.set()
|
||||
|
||||
def _run_loop(self) -> None:
|
||||
while True:
|
||||
batch = self._collect_batch()
|
||||
try:
|
||||
self._run_batch(batch)
|
||||
except Exception as exc: # noqa: PERF203
|
||||
for task in batch:
|
||||
task.error = exc
|
||||
task.done_event.set()
|
||||
finally:
|
||||
self._finalize_batch(batch)
|
||||
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import Future
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
import time
|
||||
@ -123,31 +124,58 @@ def prepare_request_state(
|
||||
prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang)
|
||||
text = spec.text.strip("\n")
|
||||
|
||||
_sync_device(device)
|
||||
prompt_text_features_start = time.perf_counter()
|
||||
prompt_phones, prompt_bert_features, prompt_norm_text = tts.extract_text_features(prompt_text, spec.prompt_lang)
|
||||
_sync_device(device)
|
||||
prompt_text_features_ms = (time.perf_counter() - prompt_text_features_start) * 1000.0
|
||||
prompt_text_profile: Dict[str, float] = {}
|
||||
text_features_profile: Dict[str, float] = {}
|
||||
text_feature_pair_start = time.perf_counter()
|
||||
prompt_future: Future | None = None
|
||||
|
||||
def _extract_prompt_features():
|
||||
_sync_device(device)
|
||||
prompt_start = time.perf_counter()
|
||||
result = tts.extract_text_features(prompt_text, spec.prompt_lang, profile=prompt_text_profile)
|
||||
_sync_device(device)
|
||||
return result, (time.perf_counter() - prompt_start) * 1000.0
|
||||
|
||||
if getattr(tts, "prepare_text_cpu_executor", None) is not None:
|
||||
prompt_future = tts.prepare_text_cpu_executor.submit(_extract_prompt_features)
|
||||
|
||||
_sync_device(device)
|
||||
text_features_start = time.perf_counter()
|
||||
phones, bert_features, norm_text = tts.extract_text_features(text, spec.text_lang)
|
||||
phones, bert_features, norm_text = tts.extract_text_features(text, spec.text_lang, profile=text_features_profile)
|
||||
_sync_device(device)
|
||||
text_features_ms = (time.perf_counter() - text_features_start) * 1000.0
|
||||
|
||||
if prompt_future is None:
|
||||
_sync_device(device)
|
||||
prompt_text_features_start = time.perf_counter()
|
||||
prompt_phones, prompt_bert_features, prompt_norm_text = tts.extract_text_features(
|
||||
prompt_text, spec.prompt_lang, profile=prompt_text_profile
|
||||
)
|
||||
_sync_device(device)
|
||||
prompt_text_features_ms = (time.perf_counter() - prompt_text_features_start) * 1000.0
|
||||
prompt_text_profile["parallel_future_wait_ms"] = 0.0
|
||||
else:
|
||||
prompt_wait_start = time.perf_counter()
|
||||
(prompt_phones, prompt_bert_features, prompt_norm_text), prompt_text_features_ms = prompt_future.result()
|
||||
prompt_text_profile["parallel_future_wait_ms"] = (time.perf_counter() - prompt_wait_start) * 1000.0
|
||||
|
||||
text_feature_pair_ms = (time.perf_counter() - text_feature_pair_start) * 1000.0
|
||||
if phones is None:
|
||||
raise ValueError(f"{spec.request_id} text preprocessing returned no phones")
|
||||
|
||||
_sync_device(device)
|
||||
prompt_semantic_start = time.perf_counter()
|
||||
prompt_semantic = tts.extract_prompt_semantic(str(spec.ref_audio_path)).long()
|
||||
ref_audio_bundle_start = time.perf_counter()
|
||||
ref_audio_bundle = tts.extract_ref_audio_bundle(str(spec.ref_audio_path))
|
||||
prompt_semantic = ref_audio_bundle["prompt_semantic"].long()
|
||||
spec_audio, audio_16k = ref_audio_bundle["refer_spec"]
|
||||
raw_audio = ref_audio_bundle["raw_audio"]
|
||||
raw_sr = int(ref_audio_bundle["raw_sr"])
|
||||
_sync_device(device)
|
||||
prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0
|
||||
|
||||
_sync_device(device)
|
||||
ref_spec_start = time.perf_counter()
|
||||
spec_audio, audio_16k, raw_audio, raw_sr = tts.extract_ref_spec(str(spec.ref_audio_path))
|
||||
_sync_device(device)
|
||||
ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0
|
||||
ref_audio_bundle_ms = (time.perf_counter() - ref_audio_bundle_start) * 1000.0
|
||||
bundle_profile = ref_audio_bundle.get("profile", {})
|
||||
prompt_semantic_ms = float(bundle_profile.get("prompt_semantic_ms", ref_audio_bundle_ms))
|
||||
ref_spec_ms = float(bundle_profile.get("ref_spec_ms", 0.0))
|
||||
audio_load_ms = float(bundle_profile.get("audio_load_ms", 0.0))
|
||||
|
||||
_sync_device(device)
|
||||
tensorize_start = time.perf_counter()
|
||||
@ -164,8 +192,43 @@ def prepare_request_state(
|
||||
prepare_profile = {
|
||||
"prompt_text_features_ms": prompt_text_features_ms,
|
||||
"text_features_ms": text_features_ms,
|
||||
"prompt_text_bert_wait_ms": float(prompt_text_profile.get("bert_wait_ms", 0.0)),
|
||||
"prompt_text_bert_forward_ms": float(prompt_text_profile.get("bert_forward_ms", 0.0)),
|
||||
"prompt_text_bert_tokenize_ms": float(prompt_text_profile.get("bert_tokenize_ms", 0.0)),
|
||||
"prompt_text_bert_scatter_ms": float(prompt_text_profile.get("bert_scatter_ms", 0.0)),
|
||||
"prompt_text_bert_calls": float(prompt_text_profile.get("bert_calls", 0.0)),
|
||||
"prompt_text_bert_stage_slots": float(prompt_text_profile.get("bert_stage_slots", 0.0)),
|
||||
"prompt_text_bert_stage_inflight_peak": float(prompt_text_profile.get("bert_stage_inflight_peak", 0.0)),
|
||||
"prompt_text_bert_batch_size_peak": float(prompt_text_profile.get("bert_batch_size_peak", 0.0)),
|
||||
"prompt_text_bert_batch_tokens_peak": float(prompt_text_profile.get("bert_batch_tokens_peak", 0.0)),
|
||||
"prompt_text_parallel_future_wait_ms": float(prompt_text_profile.get("parallel_future_wait_ms", 0.0)),
|
||||
"text_bert_wait_ms": float(text_features_profile.get("bert_wait_ms", 0.0)),
|
||||
"text_bert_forward_ms": float(text_features_profile.get("bert_forward_ms", 0.0)),
|
||||
"text_bert_tokenize_ms": float(text_features_profile.get("bert_tokenize_ms", 0.0)),
|
||||
"text_bert_scatter_ms": float(text_features_profile.get("bert_scatter_ms", 0.0)),
|
||||
"text_bert_calls": float(text_features_profile.get("bert_calls", 0.0)),
|
||||
"text_bert_stage_slots": float(text_features_profile.get("bert_stage_slots", 0.0)),
|
||||
"text_bert_stage_inflight_peak": float(text_features_profile.get("bert_stage_inflight_peak", 0.0)),
|
||||
"text_bert_batch_size_peak": float(text_features_profile.get("bert_batch_size_peak", 0.0)),
|
||||
"text_bert_batch_tokens_peak": float(text_features_profile.get("bert_batch_tokens_peak", 0.0)),
|
||||
"text_feature_pair_ms": text_feature_pair_ms,
|
||||
"text_cpu_parallel_workers": float(getattr(tts, "prepare_text_cpu_workers", 0)),
|
||||
"audio_load_ms": audio_load_ms,
|
||||
"audio_stage_wait_ms": float(bundle_profile.get("audio_stage_wait_ms", 0.0)),
|
||||
"audio_stage_slots": float(bundle_profile.get("audio_stage_slots", 0.0)),
|
||||
"audio_stage_inflight_peak": float(bundle_profile.get("audio_stage_inflight_peak", 0.0)),
|
||||
"prompt_semantic_ms": prompt_semantic_ms,
|
||||
"prompt_semantic_wait_ms": float(bundle_profile.get("prompt_semantic_wait_ms", 0.0)),
|
||||
"prompt_semantic_cpu_prepare_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)),
|
||||
"prompt_semantic_forward_ms": float(bundle_profile.get("prompt_semantic_forward_ms", 0.0)),
|
||||
"prompt_semantic_scatter_ms": float(bundle_profile.get("prompt_semantic_scatter_ms", 0.0)),
|
||||
"prompt_semantic_stage_slots": float(bundle_profile.get("prompt_semantic_stage_slots", 0.0)),
|
||||
"prompt_semantic_stage_inflight_peak": float(bundle_profile.get("prompt_semantic_stage_inflight_peak", 0.0)),
|
||||
"prompt_semantic_batch_size": float(bundle_profile.get("prompt_semantic_batch_size", 0.0)),
|
||||
"prompt_semantic_batch_samples": float(bundle_profile.get("prompt_semantic_batch_samples", 0.0)),
|
||||
"ref_spec_wait_ms": float(bundle_profile.get("ref_spec_wait_ms", 0.0)),
|
||||
"ref_spec_ms": ref_spec_ms,
|
||||
"ref_audio_bundle_ms": ref_audio_bundle_ms,
|
||||
"tensorize_ms": tensorize_ms,
|
||||
"total_ms": (time.perf_counter() - prepare_sync_start) * 1000.0,
|
||||
"wall_total_ms": (time.perf_counter() - prepare_start) * 1000.0,
|
||||
@ -186,7 +249,7 @@ def prepare_request_state(
|
||||
prompt_semantic=prompt_semantic,
|
||||
refer_spec=(spec_audio, audio_16k),
|
||||
raw_audio=raw_audio,
|
||||
raw_sr=int(raw_sr),
|
||||
raw_sr=raw_sr,
|
||||
top_k=spec.top_k,
|
||||
top_p=spec.top_p,
|
||||
temperature=spec.temperature,
|
||||
@ -409,9 +472,18 @@ def _pad_decode_mask_left(mask: torch.Tensor, target_len: int) -> torch.Tensor:
|
||||
return F.pad(mask, (pad_len, 0), value=True)
|
||||
|
||||
|
||||
def _fit_decode_mask_length(mask: torch.Tensor, target_len: int) -> torch.Tensor:
|
||||
if mask.shape[-1] > target_len:
|
||||
return mask[:, :, :, -target_len:]
|
||||
if mask.shape[-1] < target_len:
|
||||
return _pad_decode_mask_left(mask, target_len)
|
||||
return mask
|
||||
|
||||
|
||||
def _materialize_decode_mask_for_request(running_request: T2SRunningRequest) -> torch.Tensor:
|
||||
expected_mask_len = running_request.k_cache[0].shape[1] + 1
|
||||
if running_request.decode_attn_mask is not None:
|
||||
return running_request.decode_attn_mask
|
||||
return _fit_decode_mask_length(running_request.decode_attn_mask, expected_mask_len)
|
||||
current_mask_len = running_request.k_cache[0].shape[1] + 1
|
||||
return torch.zeros(
|
||||
(1, 1, 1, current_mask_len),
|
||||
@ -481,17 +553,19 @@ def run_prefill_step(
|
||||
real_kv_len = int(active_batch.x_lens[batch_index].item()) + prefix_len
|
||||
request_k_cache = [layer[batch_index : batch_index + 1, -real_kv_len:, :].clone() for layer in k_cache]
|
||||
request_v_cache = [layer[batch_index : batch_index + 1, -real_kv_len:, :].clone() for layer in v_cache]
|
||||
request_decode_attn_mask = None
|
||||
if decode_attn_mask is not None:
|
||||
request_decode_attn_mask = decode_attn_mask[batch_index : batch_index + 1].clone()
|
||||
request_decode_attn_mask = _fit_decode_mask_length(request_decode_attn_mask, real_kv_len + 1)
|
||||
if not request_decode_attn_mask.any().item():
|
||||
request_decode_attn_mask = None
|
||||
|
||||
running_requests.append(
|
||||
T2SRunningRequest(
|
||||
state=state,
|
||||
y_sequence=new_history,
|
||||
prefix_len=prefix_len,
|
||||
decode_attn_mask=(
|
||||
None
|
||||
if decode_attn_mask is None
|
||||
else decode_attn_mask[batch_index : batch_index + 1].clone()
|
||||
),
|
||||
decode_attn_mask=request_decode_attn_mask,
|
||||
k_cache=request_k_cache,
|
||||
v_cache=request_v_cache,
|
||||
step_idx=1,
|
||||
@ -603,6 +677,9 @@ def run_decode_step_for_running(
|
||||
batch_index : batch_index + 1, :, :, -current_decode_mask_len:
|
||||
]
|
||||
next_decode_attn_mask = F.pad(current_decode_attn_mask, (0, 1), value=False)
|
||||
next_decode_attn_mask = _fit_decode_mask_length(next_decode_attn_mask, real_next_kv_len + 1)
|
||||
if not next_decode_attn_mask.any().item():
|
||||
next_decode_attn_mask = None
|
||||
next_running.append(
|
||||
T2SRunningRequest(
|
||||
state=running_request.state,
|
||||
|
||||
64
api_v3.py
64
api_v3.py
@ -261,8 +261,9 @@ class SchedulerDebugWorker:
|
||||
self.tts = tts
|
||||
self.max_steps = max_steps
|
||||
self.micro_batch_wait_s = micro_batch_wait_ms / 1000.0
|
||||
self.prepare_lock = threading.Lock()
|
||||
self.condition = threading.Condition()
|
||||
self.prepare_inflight = 0
|
||||
self.prepare_peak_inflight = 0
|
||||
self.pending_jobs: List[SchedulerPendingJob] = []
|
||||
self.running_requests: List[T2SRunningRequest] = []
|
||||
self.job_map: dict[str, SchedulerPendingJob] = {}
|
||||
@ -282,8 +283,20 @@ class SchedulerDebugWorker:
|
||||
pass
|
||||
|
||||
def prepare_state(self, spec: SchedulerRequestSpec) -> T2SRequestState:
|
||||
with self.prepare_lock:
|
||||
return prepare_request_state(self.tts, spec)
|
||||
with self.condition:
|
||||
self.prepare_inflight += 1
|
||||
prepare_inflight_on_enter = self.prepare_inflight
|
||||
if self.prepare_inflight > self.prepare_peak_inflight:
|
||||
self.prepare_peak_inflight = self.prepare_inflight
|
||||
prepare_peak_inflight = self.prepare_peak_inflight
|
||||
try:
|
||||
state = prepare_request_state(self.tts, spec)
|
||||
state.prepare_profile["worker_prepare_inflight_on_enter"] = float(prepare_inflight_on_enter)
|
||||
state.prepare_profile["worker_prepare_peak_inflight"] = float(prepare_peak_inflight)
|
||||
return state
|
||||
finally:
|
||||
with self.condition:
|
||||
self.prepare_inflight = max(0, self.prepare_inflight - 1)
|
||||
|
||||
def submit(
|
||||
self,
|
||||
@ -363,9 +376,28 @@ class SchedulerDebugWorker:
|
||||
|
||||
def get_state(self) -> dict:
|
||||
with self.condition:
|
||||
bert_stage = self.tts.prepare_bert_stage_limiter.snapshot()
|
||||
ref_audio_stage = self.tts.prepare_ref_audio_stage_limiter.snapshot()
|
||||
bert_batch_worker = (
|
||||
None
|
||||
if self.tts.prepare_bert_batch_worker is None
|
||||
else self.tts.prepare_bert_batch_worker.snapshot()
|
||||
)
|
||||
ref_semantic_batch_worker = (
|
||||
None
|
||||
if self.tts.prepare_ref_semantic_batch_worker is None
|
||||
else self.tts.prepare_ref_semantic_batch_worker.snapshot()
|
||||
)
|
||||
return {
|
||||
"pending_jobs": len(self.pending_jobs),
|
||||
"running_requests": len(self.running_requests),
|
||||
"prepare_inflight": self.prepare_inflight,
|
||||
"prepare_peak_inflight": self.prepare_peak_inflight,
|
||||
"prepare_text_cpu_workers": int(getattr(self.tts, "prepare_text_cpu_workers", 0)),
|
||||
"prepare_bert_stage": bert_stage,
|
||||
"prepare_bert_batch_worker": bert_batch_worker,
|
||||
"prepare_ref_audio_stage": ref_audio_stage,
|
||||
"prepare_ref_semantic_batch_worker": ref_semantic_batch_worker,
|
||||
"tracked_jobs": len(self.job_map),
|
||||
"total_submitted": self.total_submitted,
|
||||
"total_finished": self.total_finished,
|
||||
@ -907,10 +939,36 @@ async def tts_scheduler_submit_handle(request: Scheduler_Submit_Request):
|
||||
{
|
||||
"X-Prepare-Prompt-Text-Ms": f"{float(prepare_profile.get('prompt_text_features_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Target-Text-Ms": f"{float(prepare_profile.get('text_features_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Prompt-Bert-Wait-Ms": f"{float(prepare_profile.get('prompt_text_bert_wait_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Target-Bert-Wait-Ms": f"{float(prepare_profile.get('text_bert_wait_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Prompt-Bert-Forward-Ms": f"{float(prepare_profile.get('prompt_text_bert_forward_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Target-Bert-Forward-Ms": f"{float(prepare_profile.get('text_bert_forward_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Prompt-Bert-Batch-Size-Peak": str(
|
||||
int(prepare_profile.get("prompt_text_bert_batch_size_peak", 0.0))
|
||||
),
|
||||
"X-Prepare-Target-Bert-Batch-Size-Peak": str(
|
||||
int(prepare_profile.get("text_bert_batch_size_peak", 0.0))
|
||||
),
|
||||
"X-Prepare-Text-Pair-Wall-Ms": f"{float(prepare_profile.get('text_feature_pair_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Text-CPU-Workers": str(int(prepare_profile.get("text_cpu_parallel_workers", 0.0))),
|
||||
"X-Prepare-Audio-Load-Ms": f"{float(prepare_profile.get('audio_load_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Audio-Stage-Wait-Ms": f"{float(prepare_profile.get('audio_stage_wait_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Prompt-Semantic-Ms": f"{float(prepare_profile.get('prompt_semantic_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Prompt-Semantic-Wait-Ms": f"{float(prepare_profile.get('prompt_semantic_wait_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Prompt-Semantic-CPU-Ms": f"{float(prepare_profile.get('prompt_semantic_cpu_prepare_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Prompt-Semantic-Forward-Ms": f"{float(prepare_profile.get('prompt_semantic_forward_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Prompt-Semantic-Batch-Size": str(
|
||||
int(prepare_profile.get("prompt_semantic_batch_size", 0.0))
|
||||
),
|
||||
"X-Prepare-Ref-Spec-Ms": f"{float(prepare_profile.get('ref_spec_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Ref-Spec-Wait-Ms": f"{float(prepare_profile.get('ref_spec_wait_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Ref-Bundle-Ms": f"{float(prepare_profile.get('ref_audio_bundle_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Tensorize-Ms": f"{float(prepare_profile.get('tensorize_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Profile-Wall-Ms": f"{float(prepare_profile.get('wall_total_ms', 0.0)):.3f}",
|
||||
"X-Prepare-Inflight-On-Enter": str(
|
||||
int(prepare_profile.get("worker_prepare_inflight_on_enter", 0.0))
|
||||
),
|
||||
"X-Prepare-Inflight-Peak": str(int(prepare_profile.get("worker_prepare_peak_inflight", 0.0))),
|
||||
}
|
||||
)
|
||||
return Response(audio_data, media_type=f"audio/{job.media_type}", headers=headers)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user