Merge 845b181360b5c5f4a5ed827ca01cfd5dfd8f58b1 into 2d9193b0d3c0eae0c3a14d8c68a839f1bae157dc

This commit is contained in:
白菜工厂1145号员工 2026-03-08 21:55:59 +00:00 committed by GitHub
commit 1c07b3339e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 4655 additions and 208 deletions

View File

@ -351,6 +351,13 @@ class Text2SemanticDecoder(nn.Module):
blocks.append(block)
self.t2s_transformer = T2STransformer(self.num_layers, blocks)
self.last_infer_stats = {}
def _set_last_infer_stats(self, stats):
self.last_infer_stats = stats
def get_last_infer_stats(self):
return dict(self.last_infer_stats)
def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
x = self.ar_text_embedding(x)
@ -593,7 +600,19 @@ class Text2SemanticDecoder(nn.Module):
repetition_penalty: float = 1.35,
**kwargs,
):
requested_enable_mask_free_fastpath = bool(kwargs.get("enable_mask_free_fastpath", True))
if prompts is None:
self._set_last_infer_stats(
{
"infer_mode": "batch_infer_prompt_free_fallback",
"requested_enable_mask_free_fastpath": requested_enable_mask_free_fastpath,
"batch_size": int(len(x)),
"prefill_after_mask_all_visible": None,
"fastpath_hit": False,
"generated_token_count": 0,
"generated_token_count_list": [],
}
)
print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
return self.infer_panel_naive_batched(
x,
@ -608,6 +627,7 @@ class Text2SemanticDecoder(nn.Module):
)
max_len = kwargs.get("max_len", x_lens.max())
enable_mask_free_fastpath = requested_enable_mask_free_fastpath
x_list = []
for x_item, bert_item in zip(x, bert_feature):
# max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
@ -698,17 +718,30 @@ class Text2SemanticDecoder(nn.Module):
y_list = [None] * y.shape[0]
batch_idx_map = list(range(y.shape[0]))
idx_list = [None] * y.shape[0]
decode_attn_mask = attn_mask
prefill_after_mask_all_visible = None
fastpath_hit = False
for idx in tqdm(range(1500)):
if idx == 0:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(
xy_pos, k_cache, v_cache, decode_attn_mask
)
logits = self.ar_predict_layer(xy_dec[:, -1])
if idx == 0:
attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
prefill_after_mask_all_visible = not attn_mask.any().item()
if enable_mask_free_fastpath and y.shape[0] == 1 and prefill_after_mask_all_visible:
decode_attn_mask = None
fastpath_hit = True
else:
decode_attn_mask = attn_mask
else:
attn_mask = F.pad(attn_mask, (0, 1), value=False)
if decode_attn_mask is not None:
attn_mask = F.pad(attn_mask, (0, 1), value=False)
decode_attn_mask = attn_mask
if idx < 11: ###至少预测出10个token不然不给停止0.4s
logits = logits[:, :-1]
@ -740,7 +773,9 @@ class Text2SemanticDecoder(nn.Module):
if reserved_idx_of_batch_for_y is not None:
# index = torch.LongTensor(batch_idx_map).to(y.device)
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
if decode_attn_mask is not None:
attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
decode_attn_mask = attn_mask
if k_cache is not None:
for i in range(len(k_cache)):
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
@ -775,6 +810,18 @@ class Text2SemanticDecoder(nn.Module):
if idx_list[i] is None:
idx_list[i] = 1500 - 1 ###如果没有生成到EOS就用最大长度代替
self._set_last_infer_stats(
{
"infer_mode": "batch_infer",
"requested_enable_mask_free_fastpath": enable_mask_free_fastpath,
"batch_size": int(len(x)),
"prefill_after_mask_all_visible": prefill_after_mask_all_visible,
"fastpath_hit": fastpath_hit,
"generated_token_count": int(sum(idx_list)),
"generated_token_count_list": [int(item) for item in idx_list],
"max_len": int(max_len),
}
)
if ref_free:
return y_list, [0] * x.shape[0]
# print(idx_list)
@ -811,6 +858,17 @@ class Text2SemanticDecoder(nn.Module):
y_list.append(y[0])
idx_list.append(idx)
self._set_last_infer_stats(
{
"infer_mode": "naive_batched",
"requested_enable_mask_free_fastpath": bool(kwargs.get("enable_mask_free_fastpath", True)),
"batch_size": int(len(x)),
"prefill_after_mask_all_visible": None,
"fastpath_hit": False,
"generated_token_count": int(sum(idx_list)),
"generated_token_count_list": [int(item) for item in idx_list],
}
)
return y_list, idx_list
def infer_panel_naive(
@ -957,6 +1015,18 @@ class Text2SemanticDecoder(nn.Module):
if not streaming_mode:
generated_token_count = max(int(y.shape[1] - prefix_len), 0)
self._set_last_infer_stats(
{
"infer_mode": "naive",
"requested_enable_mask_free_fastpath": bool(kwargs.get("enable_mask_free_fastpath", True)),
"batch_size": int(x.shape[0]),
"prefill_after_mask_all_visible": True if prompts is not None else None,
"fastpath_hit": True if prompts is not None else False,
"generated_token_count": generated_token_count,
"generated_token_count_list": [generated_token_count],
}
)
if ref_free:
yield y, 0
yield y, idx

View File

@ -5,6 +5,7 @@ import random
import sys
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
import torchaudio
@ -33,7 +34,12 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
from tools.audio_sr import AP_BWE
from tools.i18n.i18n import I18nAuto, scan_language_list
from TTS_infer_pack.text_segmentation_method import splits
from TTS_infer_pack.TextPreprocessor import TextPreprocessor
from TTS_infer_pack.TextPreprocessor import TextPreprocessor, StageLimiter
from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker
from TTS_infer_pack.prepare_ref_semantic_batch_worker import (
PrepareRefSemanticBatchWorker,
prepare_prompt_semantic_wav16k,
)
from sv import SV
resample_transform_dict = {}
@ -442,11 +448,56 @@ class TTS:
"upsample_rate": None,
"overlapped_len": None,
}
self.prepare_bert_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_BERT_SLOTS", "1")))
self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "2")))
self.prepare_bert_batch_worker = None
self.prepare_ref_semantic_batch_worker = None
default_text_cpu_workers = 16
self.prepare_text_cpu_workers = max(
0,
int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_WORKERS", str(default_text_cpu_workers))),
)
self.prepare_text_cpu_executor = None
if self.prepare_text_cpu_workers > 0:
self.prepare_text_cpu_executor = ThreadPoolExecutor(
max_workers=self.prepare_text_cpu_workers,
thread_name_prefix="prepare-text-cpu",
)
self._init_models()
if os.environ.get("GPTSOVITS_PREPARE_BERT_BATCHING", "1") != "0":
self.prepare_bert_batch_worker = PrepareBertBatchWorker(
bert_model=self.bert_model,
tokenizer=self.bert_tokenizer,
device=self.configs.device,
stage_limiter=self.prepare_bert_stage_limiter,
batch_window_ms=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_WINDOW_MS", "5")),
max_batch_items=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_ITEMS", "16")),
max_batch_tokens=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_TOKENS", "4096")),
)
if os.environ.get("GPTSOVITS_PREPARE_REF_BATCHING", "0") != "0":
ref_max_batch_samples = os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_SAMPLES")
if ref_max_batch_samples is None:
ref_max_batch_samples = os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_FRAMES", "960000")
self.prepare_ref_semantic_batch_worker = PrepareRefSemanticBatchWorker(
ssl_model=self.cnhuhbert_model,
vits_model=self.vits_model,
device=self.configs.device,
is_half=self.configs.is_half,
zero_wav_samples=int(self.configs.sampling_rate * 0.3),
stage_limiter=self.prepare_ref_audio_stage_limiter,
batch_window_ms=int(os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_WINDOW_MS", "5")),
max_batch_items=int(os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_ITEMS", "8")),
max_batch_samples=int(ref_max_batch_samples),
)
self.text_preprocessor: TextPreprocessor = TextPreprocessor(
self.bert_model, self.bert_tokenizer, self.configs.device
self.bert_model,
self.bert_tokenizer,
self.configs.device,
bert_stage_limiter=self.prepare_bert_stage_limiter,
bert_batch_worker=self.prepare_bert_batch_worker,
)
self.prompt_cache: dict = {
@ -755,33 +806,52 @@ class TTS:
Args:
ref_audio_path: str, the path of the reference audio.
"""
self._set_prompt_semantic(ref_audio_path)
self._set_ref_spec(ref_audio_path)
bundle = self.extract_ref_audio_bundle(ref_audio_path)
if self.prompt_cache["refer_spec"] in [[], None]:
self.prompt_cache["refer_spec"] = [bundle["refer_spec"]]
else:
self.prompt_cache["refer_spec"][0] = bundle["refer_spec"]
self.prompt_cache["prompt_semantic"] = bundle["prompt_semantic"]
self.prompt_cache["raw_audio"] = bundle["raw_audio"]
self.prompt_cache["raw_sr"] = bundle["raw_sr"]
self._set_ref_audio_path(ref_audio_path)
def _set_ref_audio_path(self, ref_audio_path):
self.prompt_cache["ref_audio_path"] = ref_audio_path
def _set_ref_spec(self, ref_audio_path):
spec_audio = self._get_ref_spec(ref_audio_path)
if self.prompt_cache["refer_spec"] in [[], None]:
self.prompt_cache["refer_spec"] = [spec_audio]
else:
self.prompt_cache["refer_spec"][0] = spec_audio
def _get_ref_spec(self, ref_audio_path):
def _load_ref_audio_raw(self, ref_audio_path: str):
raw_audio, raw_sr = torchaudio.load(ref_audio_path)
raw_audio = raw_audio.to(self.configs.device).float()
self.prompt_cache["raw_audio"] = raw_audio
self.prompt_cache["raw_sr"] = raw_sr
return raw_audio.float(), int(raw_sr)
@torch.inference_mode()
def _extract_prompt_semantic_from_prepared_wav16k(self, wav16k: torch.Tensor):
wav16k = wav16k.to(self.configs.device)
if self.configs.is_half:
wav16k = wav16k.half()
hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)
codes = self.vits_model.extract_latent(hubert_feature)
return codes[0, 0].to(self.configs.device)
@torch.inference_mode()
def _extract_prompt_semantic_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
wav16k = prepare_prompt_semantic_wav16k(
raw_audio=raw_audio,
raw_sr=raw_sr,
zero_wav_samples=int(self.configs.sampling_rate * 0.3),
)
return self._extract_prompt_semantic_from_prepared_wav16k(wav16k)
def extract_prompt_semantic(self, ref_wav_path: str):
raw_audio, raw_sr = self._load_ref_audio_raw(ref_wav_path)
return self._extract_prompt_semantic_from_raw(raw_audio, raw_sr)
def _extract_ref_spec_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
raw_audio_device = raw_audio.to(self.configs.device).float()
if raw_sr != self.configs.sampling_rate:
audio = raw_audio.to(self.configs.device)
audio = raw_audio_device
if audio.shape[0] == 2:
audio = audio.mean(0).unsqueeze(0)
audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device)
else:
audio = raw_audio.to(self.configs.device)
audio = raw_audio_device
if audio.shape[0] == 2:
audio = audio.mean(0).unsqueeze(0)
@ -804,33 +874,163 @@ class TTS:
audio = audio.half()
else:
audio = None
return spec, audio, raw_audio, raw_sr
def extract_ref_spec(self, ref_audio_path: str):
raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path)
return self._extract_ref_spec_from_raw(raw_audio, raw_sr)
def extract_ref_audio_bundle(self, ref_audio_path: str):
load_start = time.perf_counter()
raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path)
load_ms = (time.perf_counter() - load_start) * 1000.0
if self.prepare_ref_semantic_batch_worker is None:
with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats:
prompt_semantic_start = time.perf_counter()
prompt_semantic = self._extract_prompt_semantic_from_raw(raw_audio, raw_sr)
prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0
ref_spec_start = time.perf_counter()
refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2]
ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0
audio_stage_wait_ms = float(limiter_stats["wait_ms"])
audio_stage_slots = float(limiter_stats["slots"])
audio_stage_inflight_peak = float(limiter_stats["peak_inflight"])
prompt_semantic_profile = {
"prompt_semantic_wait_ms": float(limiter_stats["wait_ms"]),
"prompt_semantic_cpu_prepare_ms": 0.0,
"prompt_semantic_forward_ms": prompt_semantic_ms,
"prompt_semantic_scatter_ms": 0.0,
"prompt_semantic_stage_slots": float(limiter_stats["slots"]),
"prompt_semantic_stage_inflight_peak": float(limiter_stats["peak_inflight"]),
"prompt_semantic_batch_size": 1.0,
"prompt_semantic_batch_samples": 0.0,
}
ref_spec_wait_ms = 0.0
return {
"prompt_semantic": prompt_semantic,
"refer_spec": refer_spec,
"raw_audio": raw_audio,
"raw_sr": raw_sr,
"profile": {
"audio_load_ms": load_ms,
"audio_stage_wait_ms": audio_stage_wait_ms,
"audio_stage_slots": audio_stage_slots,
"audio_stage_inflight_peak": audio_stage_inflight_peak,
"prompt_semantic_ms": prompt_semantic_ms,
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
"prompt_semantic_cpu_prepare_ms": float(
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
),
"prompt_semantic_forward_ms": float(
prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)
),
"prompt_semantic_scatter_ms": float(
prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)
),
"prompt_semantic_stage_slots": float(
prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)
),
"prompt_semantic_stage_inflight_peak": float(
prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)
),
"prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)),
"prompt_semantic_batch_samples": float(
prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0)
),
"ref_spec_wait_ms": ref_spec_wait_ms,
"ref_spec_ms": ref_spec_ms,
"bundle_total_ms": load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_ms,
},
}
prompt_semantic_profile = {
"prompt_semantic_wait_ms": 0.0,
"prompt_semantic_cpu_prepare_ms": 0.0,
"prompt_semantic_forward_ms": 0.0,
"prompt_semantic_scatter_ms": 0.0,
"prompt_semantic_stage_slots": 0.0,
"prompt_semantic_stage_inflight_peak": 0.0,
"prompt_semantic_batch_size": 1.0,
"prompt_semantic_batch_samples": 0.0,
}
if self.prepare_ref_semantic_batch_worker is not None:
prompt_semantic, worker_profile = self.prepare_ref_semantic_batch_worker.submit(raw_audio, raw_sr)
prompt_semantic_profile.update(worker_profile)
prompt_semantic_ms = (
float(prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0))
+ float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0))
+ float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0))
)
with self.prepare_ref_audio_stage_limiter.enter() as ref_spec_limiter_stats:
ref_spec_start = time.perf_counter()
refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2]
ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0
audio_stage_wait_ms = float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)) + float(
ref_spec_limiter_stats["wait_ms"]
)
audio_stage_slots = max(
float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)),
float(ref_spec_limiter_stats["slots"]),
)
audio_stage_inflight_peak = max(
float(prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)),
float(ref_spec_limiter_stats["peak_inflight"]),
)
return {
"prompt_semantic": prompt_semantic,
"refer_spec": refer_spec,
"raw_audio": raw_audio,
"raw_sr": raw_sr,
"profile": {
"audio_load_ms": load_ms,
"audio_stage_wait_ms": audio_stage_wait_ms,
"audio_stage_slots": audio_stage_slots,
"audio_stage_inflight_peak": audio_stage_inflight_peak,
"prompt_semantic_ms": prompt_semantic_ms,
"prompt_semantic_wait_ms": float(prompt_semantic_profile.get("prompt_semantic_wait_ms", 0.0)),
"prompt_semantic_cpu_prepare_ms": float(
prompt_semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)
),
"prompt_semantic_forward_ms": float(prompt_semantic_profile.get("prompt_semantic_forward_ms", 0.0)),
"prompt_semantic_scatter_ms": float(prompt_semantic_profile.get("prompt_semantic_scatter_ms", 0.0)),
"prompt_semantic_stage_slots": float(prompt_semantic_profile.get("prompt_semantic_stage_slots", 0.0)),
"prompt_semantic_stage_inflight_peak": float(
prompt_semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)
),
"prompt_semantic_batch_size": float(prompt_semantic_profile.get("prompt_semantic_batch_size", 1.0)),
"prompt_semantic_batch_samples": float(
prompt_semantic_profile.get("prompt_semantic_batch_samples", 0.0)
),
"ref_spec_wait_ms": float(ref_spec_limiter_stats["wait_ms"]),
"ref_spec_ms": ref_spec_ms,
"bundle_total_ms": load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_ms,
},
}
def extract_text_features(self, text: str, language: str, profile: dict | None = None):
return self.text_preprocessor.segment_and_extract_feature_for_text(
text, language, self.configs.version, profile=profile
)
def _set_ref_audio_path(self, ref_audio_path):
self.prompt_cache["ref_audio_path"] = ref_audio_path
def _set_ref_spec(self, ref_audio_path):
spec_audio = self._get_ref_spec(ref_audio_path)
if self.prompt_cache["refer_spec"] in [[], None]:
self.prompt_cache["refer_spec"] = [spec_audio]
else:
self.prompt_cache["refer_spec"][0] = spec_audio
def _get_ref_spec(self, ref_audio_path):
spec, audio, raw_audio, raw_sr = self.extract_ref_spec(ref_audio_path)
self.prompt_cache["raw_audio"] = raw_audio
self.prompt_cache["raw_sr"] = raw_sr
return spec, audio
def _set_prompt_semantic(self, ref_wav_path: str):
zero_wav = np.zeros(
int(self.configs.sampling_rate * 0.3),
dtype=np.float16 if self.configs.is_half else np.float32,
)
with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
raise OSError(i18n("参考音频在3~10秒范围外请更换"))
wav16k = torch.from_numpy(wav16k)
zero_wav_torch = torch.from_numpy(zero_wav)
wav16k = wav16k.to(self.configs.device)
zero_wav_torch = zero_wav_torch.to(self.configs.device)
if self.configs.is_half:
wav16k = wav16k.half()
zero_wav_torch = zero_wav_torch.half()
wav16k = torch.cat([wav16k, zero_wav_torch])
hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(
1, 2
) # .float()
codes = self.vits_model.extract_latent(hubert_feature)
prompt_semantic = codes[0, 0].to(self.configs.device)
self.prompt_cache["prompt_semantic"] = prompt_semantic
prompt_semantic = self.extract_prompt_semantic(ref_wav_path)
self.prompt_cache["prompt_semantic"] = prompt_semantic
def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length: int = None):
seq = sequences[0]
@ -1227,6 +1427,9 @@ class TTS:
###### inference ######
t_34 = 0.0
t_45 = 0.0
t2s_observe_batch_count = 0
t2s_observe_fastpath_hits = 0
t2s_observe_generated_tokens = 0
audio = []
is_first_package = True
output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"]
@ -1280,6 +1483,29 @@ class TTS:
)
t4 = time.perf_counter()
t_34 += t4 - t3
if hasattr(self.t2s_model.model, "get_last_infer_stats"):
t2s_stats = self.t2s_model.model.get_last_infer_stats()
if t2s_stats:
generated_token_count = int(t2s_stats.get("generated_token_count", 0))
t2s_total_ms = (t4 - t3) * 1000.0
avg_decode_ms_per_token = (
t2s_total_ms / generated_token_count if generated_token_count > 0 else 0.0
)
t2s_observe_batch_count += 1
t2s_observe_generated_tokens += generated_token_count
if bool(t2s_stats.get("fastpath_hit", False)):
t2s_observe_fastpath_hits += 1
print(
"[t2s_observe] "
f"mode={t2s_stats.get('infer_mode')} "
f"batch_size={t2s_stats.get('batch_size')} "
f"tokens={generated_token_count} "
f"t2s_ms={t2s_total_ms:.3f} "
f"avg_decode_ms_per_token={avg_decode_ms_per_token:.3f} "
f"requested_fastpath={t2s_stats.get('requested_enable_mask_free_fastpath')} "
f"prefill_all_visible={t2s_stats.get('prefill_after_mask_all_visible')} "
f"fastpath_hit={t2s_stats.get('fastpath_hit')}"
)
batch_audio_fragment = []
@ -1500,6 +1726,18 @@ class TTS:
if not (return_fragment or streaming_mode):
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
if t2s_observe_batch_count > 0:
request_avg_decode_ms_per_token = (
(t_34 * 1000.0) / t2s_observe_generated_tokens if t2s_observe_generated_tokens > 0 else 0.0
)
print(
"[t2s_request_observe] "
f"batches={t2s_observe_batch_count} "
f"fastpath_hits={t2s_observe_fastpath_hits} "
f"generated_tokens={t2s_observe_generated_tokens} "
f"t2s_total_ms={t_34 * 1000.0:.3f} "
f"avg_decode_ms_per_token={request_avg_decode_ms_per_token:.3f}"
)
if len(audio) == 0:
yield output_sr, np.zeros(int(output_sr), dtype=np.int16)
return
@ -1663,6 +1901,116 @@ class TTS:
return audio
def using_vocoder_synthesis_request_local(
self,
semantic_tokens: torch.Tensor,
phones: torch.Tensor,
prompt_semantic: torch.Tensor,
prompt_phones: torch.Tensor,
refer_audio_spec: torch.Tensor,
raw_audio: torch.Tensor,
raw_sr: int,
speed: float = 1.0,
sample_steps: int = 32,
):
prompt_semantic_tokens = prompt_semantic.unsqueeze(0).unsqueeze(0).to(self.configs.device)
prompt_phones = prompt_phones.unsqueeze(0).to(self.configs.device)
refer_audio_spec = refer_audio_spec.to(dtype=self.precision, device=self.configs.device)
fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
ref_audio = raw_audio.to(self.configs.device).float()
if ref_audio.shape[0] == 2:
ref_audio = ref_audio.mean(0).unsqueeze(0)
tgt_sr = 24000 if self.configs.version == "v3" else 32000
if raw_sr != tgt_sr:
ref_audio = resample(ref_audio, raw_sr, tgt_sr, self.configs.device)
mel_spec_fn = mel_fn if self.configs.version == "v3" else mel_fn_v4
mel2 = mel_spec_fn(ref_audio)
mel2 = norm_spec(mel2)
T_min = min(mel2.shape[2], fea_ref.shape[2])
mel2 = mel2[:, :, :T_min]
fea_ref = fea_ref[:, :, :T_min]
T_ref = self.vocoder_configs["T_ref"]
T_chunk = self.vocoder_configs["T_chunk"]
if T_min > T_ref:
mel2 = mel2[:, :, -T_ref:]
fea_ref = fea_ref[:, :, -T_ref:]
T_min = T_ref
chunk_len = T_chunk - T_min
mel2 = mel2.to(self.precision)
fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
cfm_resss = []
idx = 0
while 1:
fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
if fea_todo_chunk.shape[-1] == 0:
break
idx += chunk_len
fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
cfm_res = self.vits_model.cfm.inference(
fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0
)
cfm_res = cfm_res[:, :, mel2.shape[2] :]
mel2 = cfm_res[:, :, -T_min:]
fea_ref = fea_todo_chunk[:, :, -T_min:]
cfm_resss.append(cfm_res)
cfm_res = torch.cat(cfm_resss, 2)
cfm_res = denorm_spec(cfm_res)
with torch.inference_mode():
wav_gen = self.vocoder(cfm_res)
audio = wav_gen[0][0]
return audio
@torch.inference_mode()
def synthesize_audio_request_local(
self,
semantic_tokens: torch.Tensor,
phones: torch.Tensor,
prompt_semantic: torch.Tensor,
prompt_phones: torch.Tensor,
refer_spec: tuple,
raw_audio: torch.Tensor,
raw_sr: int,
speed: float = 1.0,
sample_steps: int = 32,
):
refer_audio_spec, audio_tensor = refer_spec
if not self.configs.use_vocoder:
refer_audio_spec_list = [refer_audio_spec.to(dtype=self.precision, device=self.configs.device)]
sv_emb = None
if self.is_v2pro:
if audio_tensor is None:
raise ValueError(i18n("v2Pro request-local synthesis 缺少 16k 参考音频"))
sv_emb = self.sv_model.compute_embedding3(audio_tensor).to(self.configs.device)
return self.vits_model.decode(
semantic_tokens,
phones,
refer_audio_spec_list,
speed=speed,
sv_emb=sv_emb,
).detach()[0, 0, :]
return self.using_vocoder_synthesis_request_local(
semantic_tokens=semantic_tokens,
phones=phones,
prompt_semantic=prompt_semantic,
prompt_phones=prompt_phones,
refer_audio_spec=refer_audio_spec,
raw_audio=raw_audio,
raw_sr=raw_sr,
speed=speed,
sample_steps=sample_steps,
)
def using_vocoder_synthesis_batched_infer(
self,
idx_list: List[int],

View File

@ -1,6 +1,8 @@
import os
import sys
import threading
import time
from contextlib import contextmanager
from tqdm import tqdm
@ -16,6 +18,7 @@ from text.cleaner import clean_text
from text import cleaned_text_to_sequence
from transformers import AutoModelForMaskedLM, AutoTokenizer
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker
from tools.i18n.i18n import I18nAuto, scan_language_list
@ -49,12 +52,60 @@ def merge_short_text_in_array(texts: str, threshold: int) -> list:
return result
class StageLimiter:
def __init__(self, slots: int):
self.slots = max(1, int(slots))
self.semaphore = threading.BoundedSemaphore(self.slots)
self.lock = threading.Lock()
self.inflight = 0
self.peak_inflight = 0
@contextmanager
def enter(self):
wait_start = time.perf_counter()
self.semaphore.acquire()
wait_ms = (time.perf_counter() - wait_start) * 1000.0
with self.lock:
self.inflight += 1
current_inflight = self.inflight
if current_inflight > self.peak_inflight:
self.peak_inflight = current_inflight
peak_inflight = self.peak_inflight
try:
yield {
"wait_ms": wait_ms,
"inflight": current_inflight,
"peak_inflight": peak_inflight,
"slots": self.slots,
}
finally:
with self.lock:
self.inflight = max(0, self.inflight - 1)
self.semaphore.release()
def snapshot(self) -> Dict[str, int]:
with self.lock:
return {
"slots": self.slots,
"inflight": self.inflight,
"peak_inflight": self.peak_inflight,
}
class TextPreprocessor:
def __init__(self, bert_model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, device: torch.device):
def __init__(
self,
bert_model: AutoModelForMaskedLM,
tokenizer: AutoTokenizer,
device: torch.device,
bert_stage_limiter: StageLimiter | None = None,
bert_batch_worker: PrepareBertBatchWorker | None = None,
):
self.bert_model = bert_model
self.tokenizer = tokenizer
self.device = device
self.bert_lock = threading.RLock()
self.bert_stage_limiter = bert_stage_limiter
self.bert_batch_worker = bert_batch_worker
def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
print(f"############ {i18n('切分文本')} ############")
@ -115,86 +166,136 @@ class TextPreprocessor:
return texts
def segment_and_extract_feature_for_text(
self, text: str, language: str, version: str = "v1"
self, text: str, language: str, version: str = "v1", profile: Dict | None = None
) -> Tuple[list, torch.Tensor, str]:
return self.get_phones_and_bert(text, language, version)
return self.get_phones_and_bert(text, language, version, profile=profile)
def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False):
with self.bert_lock:
text = re.sub(r' {2,}', ' ', text)
textlist = []
langlist = []
if language == "all_zh":
for tmp in LangSegmenter.getTexts(text,"zh"):
def _split_text_by_language(self, text: str, language: str) -> Tuple[List[str], List[str]]:
textlist = []
langlist = []
if language == "all_zh":
for tmp in LangSegmenter.getTexts(text, "zh"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_yue":
for tmp in LangSegmenter.getTexts(text, "zh"):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ja":
for tmp in LangSegmenter.getTexts(text, "ja"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ko":
for tmp in LangSegmenter.getTexts(text, "ko"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "en":
langlist.append("en")
textlist.append(text)
elif language == "auto":
for tmp in LangSegmenter.getTexts(text):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "auto_yue":
for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(text):
if langlist:
same_group = (tmp["lang"] == "en" and langlist[-1] == "en") or (
tmp["lang"] != "en" and langlist[-1] != "en"
)
if same_group:
textlist[-1] += tmp["text"]
continue
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_yue":
for tmp in LangSegmenter.getTexts(text,"zh"):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ja":
for tmp in LangSegmenter.getTexts(text,"ja"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ko":
for tmp in LangSegmenter.getTexts(text,"ko"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "en":
langlist.append("en")
textlist.append(text)
elif language == "auto":
for tmp in LangSegmenter.getTexts(text):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "auto_yue":
for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(text):
if langlist:
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
textlist[-1] += tmp["text"]
continue
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
# 因无法区别中日韩文汉字,以用户输入为准
langlist.append(language)
textlist.append(tmp["text"])
# print(textlist)
# print(langlist)
phones_list = []
bert_list = []
norm_text_list = []
for i in range(len(textlist)):
lang = langlist[i]
phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
phones_list.append(phones)
norm_text_list.append(norm_text)
bert_list.append(bert)
bert = torch.cat(bert_list, dim=1)
phones = sum(phones_list, [])
norm_text = "".join(norm_text_list)
else:
langlist.append(language)
textlist.append(tmp["text"])
return textlist, langlist
if not final and len(phones) < 6:
return self.get_phones_and_bert("." + text, language, version, final=True)
def get_phones_and_bert(
self, text: str, language: str, version: str, final: bool = False, profile: Dict | None = None
):
text = re.sub(r' {2,}', ' ', text)
textlist, langlist = self._split_text_by_language(text, language)
phones_list = []
bert_list = []
norm_text_list = []
for segment_text, segment_lang in zip(textlist, langlist):
phones, word2ph, norm_text = self.clean_text_inf(segment_text, segment_lang, version)
bert = self.get_bert_inf(phones, word2ph, norm_text, segment_lang, profile=profile)
phones_list.append(phones)
norm_text_list.append(norm_text)
bert_list.append(bert)
bert = torch.cat(bert_list, dim=1)
phones = sum(phones_list, [])
norm_text = "".join(norm_text_list)
return phones, bert, norm_text
if not final and len(phones) < 6:
return self.get_phones_and_bert("." + text, language, version, final=True, profile=profile)
def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
with torch.no_grad():
inputs = self.tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(self.device)
res = self.bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
return phones, bert, norm_text
def _accumulate_profile(self, profile: Dict | None, key: str, value: float) -> None:
if profile is None:
return
profile[key] = float(profile.get(key, 0.0)) + float(value)
def _update_profile_peak(self, profile: Dict | None, key: str, value: float) -> None:
if profile is None:
return
profile[key] = float(max(float(profile.get(key, 0.0)), float(value)))
def get_bert_feature(self, text: str, word2ph: list, profile: Dict | None = None) -> torch.Tensor:
if self.bert_batch_worker is not None:
feature, worker_profile = self.bert_batch_worker.submit(text, word2ph)
self._accumulate_profile(profile, "bert_wait_ms", worker_profile.get("bert_wait_ms", 0.0))
self._accumulate_profile(profile, "bert_forward_ms", worker_profile.get("bert_forward_ms", 0.0))
self._accumulate_profile(profile, "bert_tokenize_ms", worker_profile.get("bert_tokenize_ms", 0.0))
self._accumulate_profile(profile, "bert_scatter_ms", worker_profile.get("bert_scatter_ms", 0.0))
self._accumulate_profile(profile, "bert_calls", worker_profile.get("bert_calls", 1.0))
self._update_profile_peak(
profile, "bert_stage_inflight_peak", worker_profile.get("bert_stage_inflight_peak", 0.0)
)
self._update_profile_peak(profile, "bert_batch_size_peak", worker_profile.get("bert_batch_size", 0.0))
self._update_profile_peak(profile, "bert_batch_tokens_peak", worker_profile.get("bert_batch_tokens", 0.0))
if profile is not None:
profile["bert_stage_slots"] = float(worker_profile.get("bert_stage_slots", 0.0))
return feature
limiter_stats = {"wait_ms": 0.0, "inflight": 1, "peak_inflight": 1, "slots": 0}
if self.bert_stage_limiter is None:
forward_start = time.perf_counter()
with torch.no_grad():
inputs = self.tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(self.device)
res = self.bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
forward_ms = (time.perf_counter() - forward_start) * 1000.0
else:
with self.bert_stage_limiter.enter() as limiter_stats:
forward_start = time.perf_counter()
with torch.no_grad():
inputs = self.tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(self.device)
res = self.bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
forward_ms = (time.perf_counter() - forward_start) * 1000.0
self._accumulate_profile(profile, "bert_wait_ms", limiter_stats["wait_ms"])
self._accumulate_profile(profile, "bert_forward_ms", forward_ms)
self._accumulate_profile(profile, "bert_calls", 1.0)
self._update_profile_peak(profile, "bert_stage_inflight_peak", limiter_stats["peak_inflight"])
if profile is not None:
profile["bert_stage_slots"] = float(limiter_stats["slots"])
assert len(word2ph) == len(text)
phone_level_feature = []
for i in range(len(word2ph)):
@ -209,10 +310,10 @@ class TextPreprocessor:
phones = cleaned_text_to_sequence(phones, version)
return phones, word2ph, norm_text
def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str):
def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str, profile: Dict | None = None):
language = language.replace("all_", "")
if language == "zh":
feature = self.get_bert_feature(norm_text, word2ph).to(self.device)
feature = self.get_bert_feature(norm_text, word2ph, profile=profile).to(self.device)
else:
feature = torch.zeros(
(1024, len(phones)),
@ -236,4 +337,4 @@ class TextPreprocessor:
punctuations = "".join(re.escape(p) for p in punctuation)
pattern = f"([{punctuations}])([{punctuations}])+"
result = re.sub(pattern, r"\1", text)
return result
return result

View File

@ -1 +1,11 @@
from . import TTS, text_segmentation_method
from __future__ import annotations
import importlib
__all__ = ["TTS", "TextPreprocessor", "text_segmentation_method", "t2s_scheduler"]
def __getattr__(name: str):
if name in __all__:
return importlib.import_module(f"{__name__}.{name}")
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@ -0,0 +1,197 @@
import threading
import time
import uuid
from collections import deque
from dataclasses import dataclass, field
from typing import Deque, Dict, List, Tuple
import torch
@dataclass
class BertFeatureTask:
norm_text: str
word2ph: List[int]
task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
created_at: float = field(default_factory=time.perf_counter)
done_event: threading.Event = field(default_factory=threading.Event)
result_feature: torch.Tensor | None = None
error: Exception | None = None
profile: Dict[str, float] = field(default_factory=dict)
class PrepareBertBatchWorker:
def __init__(
self,
bert_model,
tokenizer,
device,
stage_limiter=None,
batch_window_ms: int = 5,
max_batch_items: int = 16,
max_batch_tokens: int = 4096,
):
self.bert_model = bert_model
self.tokenizer = tokenizer
self.device = device
self.stage_limiter = stage_limiter
self.batch_window_s = max(0.0, float(batch_window_ms) / 1000.0)
self.max_batch_items = max(1, int(max_batch_items))
self.max_batch_tokens = max(16, int(max_batch_tokens))
self.condition = threading.Condition()
self.pending_tasks: Deque[BertFeatureTask] = deque()
self.pending_peak = 0
self.total_submitted = 0
self.total_finished = 0
self.total_batches = 0
self.active_batch_size = 0
self.active_batch_peak = 0
self.worker_thread = threading.Thread(target=self._run_loop, name="prepare-bert-batch-worker", daemon=True)
self.worker_thread.start()
def _estimate_task_tokens(self, task: BertFeatureTask) -> int:
return max(1, len(task.norm_text) + 2)
def submit(self, norm_text: str, word2ph: List[int]) -> Tuple[torch.Tensor, Dict[str, float]]:
task = BertFeatureTask(norm_text=str(norm_text), word2ph=list(word2ph))
with self.condition:
self.pending_tasks.append(task)
self.total_submitted += 1
if len(self.pending_tasks) > self.pending_peak:
self.pending_peak = len(self.pending_tasks)
self.condition.notify_all()
task.done_event.wait()
if task.error is not None:
raise task.error
assert task.result_feature is not None
return task.result_feature, dict(task.profile)
def snapshot(self) -> Dict[str, int]:
with self.condition:
return {
"pending": len(self.pending_tasks),
"pending_peak": self.pending_peak,
"total_submitted": self.total_submitted,
"total_finished": self.total_finished,
"total_batches": self.total_batches,
"active_batch_size": self.active_batch_size,
"active_batch_peak": self.active_batch_peak,
"batch_window_ms": int(self.batch_window_s * 1000.0),
"max_batch_items": self.max_batch_items,
"max_batch_tokens": self.max_batch_tokens,
}
def _collect_batch(self) -> List[BertFeatureTask]:
with self.condition:
while not self.pending_tasks:
self.condition.wait()
batch: List[BertFeatureTask] = [self.pending_tasks.popleft()]
batch_tokens = self._estimate_task_tokens(batch[0])
deadline = time.perf_counter() + self.batch_window_s
while len(batch) < self.max_batch_items:
remaining = deadline - time.perf_counter()
if remaining <= 0:
break
if not self.pending_tasks:
self.condition.wait(timeout=remaining)
continue
next_task = self.pending_tasks[0]
next_tokens = self._estimate_task_tokens(next_task)
if len(batch) >= self.max_batch_items or (batch_tokens + next_tokens) > self.max_batch_tokens:
break
batch.append(self.pending_tasks.popleft())
batch_tokens += next_tokens
self.active_batch_size = len(batch)
if self.active_batch_size > self.active_batch_peak:
self.active_batch_peak = self.active_batch_size
return batch
def _finalize_batch(self, batch: List[BertFeatureTask]) -> None:
with self.condition:
self.active_batch_size = 0
self.total_batches += 1
self.total_finished += len(batch)
def _run_batch(self, batch: List[BertFeatureTask]) -> None:
batch_started = time.perf_counter()
texts = [task.norm_text for task in batch]
batch_tokens = sum(self._estimate_task_tokens(task) for task in batch)
limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0}
if self.stage_limiter is None:
tokenize_start = time.perf_counter()
inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
tokenize_ms = (time.perf_counter() - tokenize_start) * 1000.0
attention_mask_cpu = inputs["attention_mask"].cpu()
for key in inputs:
inputs[key] = inputs[key].to(self.device)
forward_start = time.perf_counter()
with torch.no_grad():
outputs = self.bert_model(**inputs, output_hidden_states=True)
forward_ms = (time.perf_counter() - forward_start) * 1000.0
else:
with self.stage_limiter.enter() as limiter_stats:
tokenize_start = time.perf_counter()
inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
tokenize_ms = (time.perf_counter() - tokenize_start) * 1000.0
attention_mask_cpu = inputs["attention_mask"].cpu()
for key in inputs:
inputs[key] = inputs[key].to(self.device)
forward_start = time.perf_counter()
with torch.no_grad():
outputs = self.bert_model(**inputs, output_hidden_states=True)
forward_ms = (time.perf_counter() - forward_start) * 1000.0
hidden = outputs["hidden_states"][-3].detach().cpu()
scatter_start = time.perf_counter()
for batch_index, task in enumerate(batch):
try:
text_len = len(task.word2ph)
if text_len != len(task.norm_text):
raise AssertionError(
f"word2ph/text length mismatch: task={task.task_id} word2ph={text_len} text={len(task.norm_text)}"
)
seq_len = int(attention_mask_cpu[batch_index].sum().item())
char_features = hidden[batch_index, 1 : seq_len - 1]
if char_features.shape[0] != text_len:
raise AssertionError(
f"bert token length mismatch: task={task.task_id} token_len={char_features.shape[0]} text_len={text_len}"
)
phone_level_feature = []
for char_index, repeat_count in enumerate(task.word2ph):
phone_level_feature.append(char_features[char_index].repeat(repeat_count, 1))
task.result_feature = torch.cat(phone_level_feature, dim=0).T
task.profile = {
"bert_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]),
"bert_forward_ms": float(forward_ms),
"bert_tokenize_ms": float(tokenize_ms),
"bert_scatter_ms": 0.0,
"bert_calls": 1.0,
"bert_stage_slots": float(limiter_stats["slots"]),
"bert_stage_inflight_peak": float(limiter_stats["peak_inflight"]),
"bert_batch_size": float(len(batch)),
"bert_batch_tokens": float(batch_tokens),
}
except Exception as exc: # noqa: PERF203
task.error = exc
scatter_ms = (time.perf_counter() - scatter_start) * 1000.0
for task in batch:
if task.result_feature is not None:
task.profile["bert_scatter_ms"] = float(scatter_ms)
task.done_event.set()
def _run_loop(self) -> None:
while True:
batch = self._collect_batch()
try:
self._run_batch(batch)
except Exception as exc: # noqa: PERF203
for task in batch:
task.error = exc
task.done_event.set()
finally:
self._finalize_batch(batch)

View File

@ -0,0 +1,262 @@
import threading
import time
import uuid
from collections import deque
from dataclasses import dataclass, field
from typing import Deque, Dict, List, Tuple
import librosa
import numpy as np
import torch
REF_AUDIO_MIN_SAMPLES_16K = 48000
REF_AUDIO_MAX_SAMPLES_16K = 160000
def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wav_samples: int) -> torch.Tensor:
wav_mono = raw_audio
if wav_mono.dim() == 2 and wav_mono.shape[0] != 1:
wav_mono = wav_mono.mean(0, keepdim=True)
wav16k = wav_mono.squeeze(0).cpu().numpy()
if raw_sr != 16000:
wav16k = librosa.resample(wav16k, orig_sr=raw_sr, target_sr=16000)
if wav16k.shape[0] > REF_AUDIO_MAX_SAMPLES_16K or wav16k.shape[0] < REF_AUDIO_MIN_SAMPLES_16K:
raise OSError("参考音频在3~10秒范围外请更换")
wav16k = np.ascontiguousarray(wav16k, dtype=np.float32)
if zero_wav_samples > 0:
wav16k = np.concatenate([wav16k, np.zeros(int(zero_wav_samples), dtype=np.float32)], axis=0)
return torch.from_numpy(wav16k)
def conv1d_output_lengths(input_lengths: torch.Tensor, conv1d: torch.nn.Conv1d | None) -> torch.Tensor:
if conv1d is None:
return input_lengths.to(dtype=torch.long)
kernel_size = int(conv1d.kernel_size[0])
stride = int(conv1d.stride[0])
padding = int(conv1d.padding[0])
dilation = int(conv1d.dilation[0])
output_lengths = torch.div(
input_lengths + 2 * padding - dilation * (kernel_size - 1) - 1,
stride,
rounding_mode="floor",
) + 1
return torch.clamp(output_lengths, min=0).to(dtype=torch.long)
@dataclass
class RefSemanticTask:
raw_audio: torch.Tensor
raw_sr: int
task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
created_at: float = field(default_factory=time.perf_counter)
done_event: threading.Event = field(default_factory=threading.Event)
result_prompt_semantic: torch.Tensor | None = None
error: Exception | None = None
profile: Dict[str, float] = field(default_factory=dict)
class PrepareRefSemanticBatchWorker:
def __init__(
self,
ssl_model,
vits_model,
device,
is_half: bool,
zero_wav_samples: int,
stage_limiter=None,
batch_window_ms: int = 5,
max_batch_items: int = 8,
max_batch_samples: int = 960000,
):
self.ssl_model = ssl_model
self.vits_model = vits_model
self.device = device
self.is_half = bool(is_half)
self.zero_wav_samples = max(0, int(zero_wav_samples))
self.stage_limiter = stage_limiter
self.batch_window_s = max(0.0, float(batch_window_ms) / 1000.0)
self.max_batch_items = max(1, int(max_batch_items))
self.max_batch_samples = max(REF_AUDIO_MIN_SAMPLES_16K + self.zero_wav_samples, int(max_batch_samples))
self.condition = threading.Condition()
self.pending_tasks: Deque[RefSemanticTask] = deque()
self.pending_peak = 0
self.total_submitted = 0
self.total_finished = 0
self.total_batches = 0
self.active_batch_size = 0
self.active_batch_peak = 0
self.active_batch_samples = 0
self.active_batch_samples_peak = 0
self.worker_thread = threading.Thread(
target=self._run_loop,
name="prepare-ref-semantic-batch-worker",
daemon=True,
)
self.worker_thread.start()
def _estimate_task_samples(self, task: RefSemanticTask) -> int:
raw_len = int(task.raw_audio.shape[-1]) if task.raw_audio.dim() > 0 else 0
base = int(round(raw_len * 16000.0 / max(1, int(task.raw_sr))))
return max(REF_AUDIO_MIN_SAMPLES_16K, base) + self.zero_wav_samples
def submit(self, raw_audio: torch.Tensor, raw_sr: int) -> Tuple[torch.Tensor, Dict[str, float]]:
task = RefSemanticTask(raw_audio=raw_audio, raw_sr=int(raw_sr))
with self.condition:
self.pending_tasks.append(task)
self.total_submitted += 1
if len(self.pending_tasks) > self.pending_peak:
self.pending_peak = len(self.pending_tasks)
self.condition.notify_all()
task.done_event.wait()
if task.error is not None:
raise task.error
assert task.result_prompt_semantic is not None
return task.result_prompt_semantic, dict(task.profile)
def snapshot(self) -> Dict[str, int]:
with self.condition:
return {
"pending": len(self.pending_tasks),
"pending_peak": self.pending_peak,
"total_submitted": self.total_submitted,
"total_finished": self.total_finished,
"total_batches": self.total_batches,
"active_batch_size": self.active_batch_size,
"active_batch_peak": self.active_batch_peak,
"active_batch_samples": self.active_batch_samples,
"active_batch_samples_peak": self.active_batch_samples_peak,
"batch_window_ms": int(self.batch_window_s * 1000.0),
"max_batch_items": self.max_batch_items,
"max_batch_samples": self.max_batch_samples,
}
def _collect_batch(self) -> List[RefSemanticTask]:
with self.condition:
while not self.pending_tasks:
self.condition.wait()
batch: List[RefSemanticTask] = [self.pending_tasks.popleft()]
batch_samples = self._estimate_task_samples(batch[0])
deadline = time.perf_counter() + self.batch_window_s
while len(batch) < self.max_batch_items:
remaining = deadline - time.perf_counter()
if remaining <= 0:
break
if not self.pending_tasks:
self.condition.wait(timeout=remaining)
continue
next_task = self.pending_tasks[0]
next_samples = self._estimate_task_samples(next_task)
if len(batch) >= self.max_batch_items or (batch_samples + next_samples) > self.max_batch_samples:
break
batch.append(self.pending_tasks.popleft())
batch_samples += next_samples
self.active_batch_size = len(batch)
self.active_batch_samples = batch_samples
if self.active_batch_size > self.active_batch_peak:
self.active_batch_peak = self.active_batch_size
if self.active_batch_samples > self.active_batch_samples_peak:
self.active_batch_samples_peak = self.active_batch_samples
return batch
def _finalize_batch(self, batch: List[RefSemanticTask]) -> None:
with self.condition:
self.active_batch_size = 0
self.active_batch_samples = 0
self.total_batches += 1
self.total_finished += len(batch)
def _get_hidden_lengths(self, attention_mask: torch.Tensor, hidden_length: int) -> torch.Tensor:
model = self.ssl_model.model
if hasattr(model, "_get_feature_vector_attention_mask"):
feature_mask = model._get_feature_vector_attention_mask(hidden_length, attention_mask)
return feature_mask.to(dtype=torch.long).sum(dim=1)
raw_lengths = attention_mask.to(dtype=torch.long).sum(dim=1)
if hasattr(model, "_get_feat_extract_output_lengths"):
return model._get_feat_extract_output_lengths(raw_lengths).to(dtype=torch.long)
return torch.full((attention_mask.shape[0],), int(hidden_length), dtype=torch.long, device=attention_mask.device)
@torch.inference_mode()
def _run_batch(self, batch: List[RefSemanticTask]) -> None:
batch_started = time.perf_counter()
prepared_start = time.perf_counter()
prepared_wavs = [
prepare_prompt_semantic_wav16k(task.raw_audio, int(task.raw_sr), self.zero_wav_samples) for task in batch
]
cpu_prepare_ms = (time.perf_counter() - prepared_start) * 1000.0
wav_lengths = torch.tensor([int(wav.shape[0]) for wav in prepared_wavs], dtype=torch.long)
batch_samples = int(wav_lengths.sum().item())
max_wav_len = int(wav_lengths.max().item())
input_values_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.float32)
attention_mask_cpu = torch.zeros((len(batch), max_wav_len), dtype=torch.long)
for batch_index, wav in enumerate(prepared_wavs):
wav_len = int(wav.shape[0])
input_values_cpu[batch_index, :wav_len] = wav
attention_mask_cpu[batch_index, :wav_len] = 1
limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0}
if self.stage_limiter is None:
input_values = input_values_cpu.to(self.device)
attention_mask = attention_mask_cpu.to(self.device)
if self.is_half:
input_values = input_values.half()
forward_start = time.perf_counter()
outputs = self.ssl_model.model(input_values, attention_mask=attention_mask)
hubert_feature = outputs["last_hidden_state"].transpose(1, 2)
hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1]))
codes = self.vits_model.extract_latent(hubert_feature)
forward_ms = (time.perf_counter() - forward_start) * 1000.0
else:
with self.stage_limiter.enter() as limiter_stats:
input_values = input_values_cpu.to(self.device)
attention_mask = attention_mask_cpu.to(self.device)
if self.is_half:
input_values = input_values.half()
forward_start = time.perf_counter()
outputs = self.ssl_model.model(input_values, attention_mask=attention_mask)
hubert_feature = outputs["last_hidden_state"].transpose(1, 2)
hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1]))
codes = self.vits_model.extract_latent(hubert_feature)
forward_ms = (time.perf_counter() - forward_start) * 1000.0
code_lengths = conv1d_output_lengths(hidden_lengths.detach().cpu(), getattr(self.vits_model, "ssl_proj", None))
scatter_start = time.perf_counter()
for batch_index, task in enumerate(batch):
try:
code_len = int(code_lengths[batch_index].item())
task.result_prompt_semantic = codes[batch_index, 0, :code_len].detach().clone()
task.profile = {
"prompt_semantic_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]),
"prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms),
"prompt_semantic_forward_ms": float(forward_ms),
"prompt_semantic_scatter_ms": 0.0,
"prompt_semantic_calls": 1.0,
"prompt_semantic_stage_slots": float(limiter_stats["slots"]),
"prompt_semantic_stage_inflight_peak": float(limiter_stats["peak_inflight"]),
"prompt_semantic_batch_size": float(len(batch)),
"prompt_semantic_batch_samples": float(batch_samples),
}
except Exception as exc: # noqa: PERF203
task.error = exc
scatter_ms = (time.perf_counter() - scatter_start) * 1000.0
for task in batch:
if task.result_prompt_semantic is not None:
task.profile["prompt_semantic_scatter_ms"] = float(scatter_ms)
task.done_event.set()
def _run_loop(self) -> None:
while True:
batch = self._collect_batch()
try:
self._run_batch(batch)
except Exception as exc: # noqa: PERF203
for task in batch:
task.error = exc
task.done_event.set()
finally:
self._finalize_batch(batch)

View File

@ -0,0 +1,734 @@
from __future__ import annotations
from concurrent.futures import Future
from dataclasses import dataclass
from pathlib import Path
import time
from typing import Any, Dict, List, Optional, Sequence, Tuple
import torch
import torch.nn.functional as F
from AR.models.utils import make_pad_mask_left, sample
def _sync_device(device: Any) -> None:
try:
device_str = str(device)
if device_str.startswith("cuda") and torch.cuda.is_available():
torch.cuda.synchronize(device)
elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"):
torch.mps.synchronize()
except Exception:
pass
@dataclass
class SchedulerRequestSpec:
request_id: str
ref_audio_path: Path
prompt_text: str
prompt_lang: str
text: str
text_lang: str
top_k: int
top_p: float
temperature: float
repetition_penalty: float
early_stop_num: int
ready_step: int = 0
@dataclass
class T2SRequestState:
request_id: str
ref_audio_path: Path
prompt_text: str
prompt_lang: str
text: str
text_lang: str
norm_prompt_text: str
norm_text: str
phones: torch.LongTensor
prompt_phones: torch.LongTensor
all_phones: torch.LongTensor
all_bert_features: torch.Tensor
prompt_semantic: torch.LongTensor
refer_spec: Tuple[torch.Tensor, Optional[torch.Tensor]]
raw_audio: torch.Tensor
raw_sr: int
top_k: int
top_p: float
temperature: float
repetition_penalty: float
early_stop_num: int
ready_step: int
prepare_profile: Dict[str, float]
@dataclass
class T2SRunningRequest:
state: T2SRequestState
y_sequence: torch.LongTensor
prefix_len: int
decode_attn_mask: Optional[torch.Tensor]
k_cache: List[torch.Tensor]
v_cache: List[torch.Tensor]
step_idx: int
@dataclass
class T2SFinishedItem:
request_id: str
semantic_tokens: torch.LongTensor
finish_idx: int
finish_reason: str
@dataclass
class T2SActiveBatch:
request_ids: List[str]
states: List[T2SRequestState]
x: torch.Tensor
x_lens: torch.LongTensor
y_sequences: List[torch.LongTensor]
prefix_lens: torch.LongTensor
xy_pos: torch.Tensor
key_padding_mask: torch.Tensor
prefill_attn_mask: torch.Tensor
decode_attn_mask: Optional[torch.Tensor]
k_cache: Optional[List[torch.Tensor]]
v_cache: Optional[List[torch.Tensor]]
step_idx: int
prefill_done: bool
def normalize_sentence(text: str, language: str) -> str:
text = text.strip("\n").strip()
if not text:
return text
if text[-1] not in {",", ".", "?", "!", "", "", "", "", "", "", ";", ":"}:
text += "" if language != "en" else "."
return text
@torch.inference_mode()
def prepare_request_state(
tts: Any,
spec: SchedulerRequestSpec,
) -> T2SRequestState:
device = tts.configs.device
prepare_start = time.perf_counter()
_sync_device(device)
prepare_sync_start = time.perf_counter()
prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang)
text = spec.text.strip("\n")
prompt_text_profile: Dict[str, float] = {}
text_features_profile: Dict[str, float] = {}
text_feature_pair_start = time.perf_counter()
prompt_future: Future | None = None
def _extract_prompt_features():
_sync_device(device)
prompt_start = time.perf_counter()
result = tts.extract_text_features(prompt_text, spec.prompt_lang, profile=prompt_text_profile)
_sync_device(device)
return result, (time.perf_counter() - prompt_start) * 1000.0
if getattr(tts, "prepare_text_cpu_executor", None) is not None:
prompt_future = tts.prepare_text_cpu_executor.submit(_extract_prompt_features)
_sync_device(device)
text_features_start = time.perf_counter()
phones, bert_features, norm_text = tts.extract_text_features(text, spec.text_lang, profile=text_features_profile)
_sync_device(device)
text_features_ms = (time.perf_counter() - text_features_start) * 1000.0
if prompt_future is None:
_sync_device(device)
prompt_text_features_start = time.perf_counter()
prompt_phones, prompt_bert_features, prompt_norm_text = tts.extract_text_features(
prompt_text, spec.prompt_lang, profile=prompt_text_profile
)
_sync_device(device)
prompt_text_features_ms = (time.perf_counter() - prompt_text_features_start) * 1000.0
prompt_text_profile["parallel_future_wait_ms"] = 0.0
else:
prompt_wait_start = time.perf_counter()
(prompt_phones, prompt_bert_features, prompt_norm_text), prompt_text_features_ms = prompt_future.result()
prompt_text_profile["parallel_future_wait_ms"] = (time.perf_counter() - prompt_wait_start) * 1000.0
text_feature_pair_ms = (time.perf_counter() - text_feature_pair_start) * 1000.0
if phones is None:
raise ValueError(f"{spec.request_id} text preprocessing returned no phones")
_sync_device(device)
ref_audio_bundle_start = time.perf_counter()
ref_audio_bundle = tts.extract_ref_audio_bundle(str(spec.ref_audio_path))
prompt_semantic = ref_audio_bundle["prompt_semantic"].long()
spec_audio, audio_16k = ref_audio_bundle["refer_spec"]
raw_audio = ref_audio_bundle["raw_audio"]
raw_sr = int(ref_audio_bundle["raw_sr"])
_sync_device(device)
ref_audio_bundle_ms = (time.perf_counter() - ref_audio_bundle_start) * 1000.0
bundle_profile = ref_audio_bundle.get("profile", {})
prompt_semantic_ms = float(bundle_profile.get("prompt_semantic_ms", ref_audio_bundle_ms))
ref_spec_ms = float(bundle_profile.get("ref_spec_ms", 0.0))
audio_load_ms = float(bundle_profile.get("audio_load_ms", 0.0))
_sync_device(device)
tensorize_start = time.perf_counter()
phones_tensor = torch.LongTensor(phones).to(tts.configs.device)
prompt_phones_tensor = torch.LongTensor(prompt_phones).to(tts.configs.device)
all_phones = torch.LongTensor(prompt_phones + phones).to(tts.configs.device)
all_bert_features = torch.cat([prompt_bert_features, bert_features], dim=1).to(
dtype=tts.precision, device=tts.configs.device
)
_sync_device(device)
tensorize_ms = (time.perf_counter() - tensorize_start) * 1000.0
_sync_device(device)
prepare_profile = {
"prompt_text_features_ms": prompt_text_features_ms,
"text_features_ms": text_features_ms,
"prompt_text_bert_wait_ms": float(prompt_text_profile.get("bert_wait_ms", 0.0)),
"prompt_text_bert_forward_ms": float(prompt_text_profile.get("bert_forward_ms", 0.0)),
"prompt_text_bert_tokenize_ms": float(prompt_text_profile.get("bert_tokenize_ms", 0.0)),
"prompt_text_bert_scatter_ms": float(prompt_text_profile.get("bert_scatter_ms", 0.0)),
"prompt_text_bert_calls": float(prompt_text_profile.get("bert_calls", 0.0)),
"prompt_text_bert_stage_slots": float(prompt_text_profile.get("bert_stage_slots", 0.0)),
"prompt_text_bert_stage_inflight_peak": float(prompt_text_profile.get("bert_stage_inflight_peak", 0.0)),
"prompt_text_bert_batch_size_peak": float(prompt_text_profile.get("bert_batch_size_peak", 0.0)),
"prompt_text_bert_batch_tokens_peak": float(prompt_text_profile.get("bert_batch_tokens_peak", 0.0)),
"prompt_text_parallel_future_wait_ms": float(prompt_text_profile.get("parallel_future_wait_ms", 0.0)),
"text_bert_wait_ms": float(text_features_profile.get("bert_wait_ms", 0.0)),
"text_bert_forward_ms": float(text_features_profile.get("bert_forward_ms", 0.0)),
"text_bert_tokenize_ms": float(text_features_profile.get("bert_tokenize_ms", 0.0)),
"text_bert_scatter_ms": float(text_features_profile.get("bert_scatter_ms", 0.0)),
"text_bert_calls": float(text_features_profile.get("bert_calls", 0.0)),
"text_bert_stage_slots": float(text_features_profile.get("bert_stage_slots", 0.0)),
"text_bert_stage_inflight_peak": float(text_features_profile.get("bert_stage_inflight_peak", 0.0)),
"text_bert_batch_size_peak": float(text_features_profile.get("bert_batch_size_peak", 0.0)),
"text_bert_batch_tokens_peak": float(text_features_profile.get("bert_batch_tokens_peak", 0.0)),
"text_feature_pair_ms": text_feature_pair_ms,
"text_cpu_parallel_workers": float(getattr(tts, "prepare_text_cpu_workers", 0)),
"audio_load_ms": audio_load_ms,
"audio_stage_wait_ms": float(bundle_profile.get("audio_stage_wait_ms", 0.0)),
"audio_stage_slots": float(bundle_profile.get("audio_stage_slots", 0.0)),
"audio_stage_inflight_peak": float(bundle_profile.get("audio_stage_inflight_peak", 0.0)),
"prompt_semantic_ms": prompt_semantic_ms,
"prompt_semantic_wait_ms": float(bundle_profile.get("prompt_semantic_wait_ms", 0.0)),
"prompt_semantic_cpu_prepare_ms": float(bundle_profile.get("prompt_semantic_cpu_prepare_ms", 0.0)),
"prompt_semantic_forward_ms": float(bundle_profile.get("prompt_semantic_forward_ms", 0.0)),
"prompt_semantic_scatter_ms": float(bundle_profile.get("prompt_semantic_scatter_ms", 0.0)),
"prompt_semantic_stage_slots": float(bundle_profile.get("prompt_semantic_stage_slots", 0.0)),
"prompt_semantic_stage_inflight_peak": float(bundle_profile.get("prompt_semantic_stage_inflight_peak", 0.0)),
"prompt_semantic_batch_size": float(bundle_profile.get("prompt_semantic_batch_size", 0.0)),
"prompt_semantic_batch_samples": float(bundle_profile.get("prompt_semantic_batch_samples", 0.0)),
"ref_spec_wait_ms": float(bundle_profile.get("ref_spec_wait_ms", 0.0)),
"ref_spec_ms": ref_spec_ms,
"ref_audio_bundle_ms": ref_audio_bundle_ms,
"tensorize_ms": tensorize_ms,
"total_ms": (time.perf_counter() - prepare_sync_start) * 1000.0,
"wall_total_ms": (time.perf_counter() - prepare_start) * 1000.0,
}
return T2SRequestState(
request_id=spec.request_id,
ref_audio_path=spec.ref_audio_path,
prompt_text=prompt_text,
prompt_lang=spec.prompt_lang,
text=text,
text_lang=spec.text_lang,
norm_prompt_text=prompt_norm_text,
norm_text=norm_text,
phones=phones_tensor,
prompt_phones=prompt_phones_tensor,
all_phones=all_phones,
all_bert_features=all_bert_features,
prompt_semantic=prompt_semantic,
refer_spec=(spec_audio, audio_16k),
raw_audio=raw_audio,
raw_sr=raw_sr,
top_k=spec.top_k,
top_p=spec.top_p,
temperature=spec.temperature,
repetition_penalty=spec.repetition_penalty,
early_stop_num=spec.early_stop_num,
ready_step=spec.ready_step,
prepare_profile=prepare_profile,
)
def _left_pad_hidden(hidden: torch.Tensor, target_len: int) -> torch.Tensor:
if hidden.shape[0] >= target_len:
return hidden
return F.pad(hidden, (0, 0, target_len - hidden.shape[0], 0), value=0)
def _ensure_audio_pe(model: Any, max_position: int, dtype: torch.dtype, device: torch.device) -> None:
required_len = max_position + 1
if model.ar_audio_position.pe is not None and model.ar_audio_position.pe.size(1) >= required_len:
if model.ar_audio_position.pe.dtype != dtype or model.ar_audio_position.pe.device != device:
model.ar_audio_position.pe = model.ar_audio_position.pe.to(dtype=dtype, device=device)
return
model.ar_audio_position.extend_pe(
torch.zeros(1, required_len, model.ar_audio_position.embedding_dim, device=device, dtype=dtype)
)
@torch.inference_mode()
def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SActiveBatch:
x_items: List[torch.Tensor] = []
y_pos_items: List[torch.Tensor] = []
x_lens: List[int] = []
prefix_lens: List[int] = []
y_sequences: List[torch.LongTensor] = []
for state in states:
text_emb = model.ar_text_embedding(state.all_phones.unsqueeze(0))
bert_proj = model.bert_proj(state.all_bert_features.transpose(0, 1).unsqueeze(0))
x_pos = model.ar_text_position(text_emb + bert_proj).squeeze(0)
y_emb = model.ar_audio_embedding(state.prompt_semantic.unsqueeze(0))
y_pos = model.ar_audio_position(y_emb).squeeze(0)
x_items.append(x_pos)
y_pos_items.append(y_pos)
x_lens.append(x_pos.shape[0])
prefix_lens.append(y_pos.shape[0])
y_sequences.append(state.prompt_semantic.clone())
max_x_len = max(x_lens)
max_prefix_len = max(prefix_lens)
x_batch = torch.stack([_left_pad_hidden(item, max_x_len) for item in x_items], dim=0)
y_pos_batch = torch.stack([_left_pad_hidden(item, max_prefix_len) for item in y_pos_items], dim=0)
xy_pos = torch.cat([x_batch, y_pos_batch], dim=1)
device = x_batch.device
x_lens_tensor = torch.LongTensor(x_lens).to(device)
prefix_lens_tensor = torch.LongTensor(prefix_lens).to(device)
src_len = max_x_len + max_prefix_len
x_padding_mask = make_pad_mask_left(x_lens_tensor, max_x_len)
y_padding_mask = make_pad_mask_left(prefix_lens_tensor, max_prefix_len)
key_padding_mask = torch.cat([x_padding_mask, y_padding_mask], dim=1).bool()
x_mask = F.pad(torch.zeros(max_x_len, max_x_len, dtype=torch.bool, device=device), (0, max_prefix_len), value=True)
y_mask = F.pad(
torch.triu(torch.ones(max_prefix_len, max_prefix_len, dtype=torch.bool, device=device), diagonal=1),
(max_x_len, 0),
value=False,
)
causal_mask = torch.cat([x_mask, y_mask], dim=0).unsqueeze(0)
attn_mask = causal_mask.logical_or(key_padding_mask.unsqueeze(1)).unsqueeze(1)
return T2SActiveBatch(
request_ids=[state.request_id for state in states],
states=list(states),
x=x_batch,
x_lens=x_lens_tensor,
y_sequences=y_sequences,
prefix_lens=prefix_lens_tensor,
xy_pos=xy_pos,
key_padding_mask=key_padding_mask,
prefill_attn_mask=attn_mask,
decode_attn_mask=None,
k_cache=None,
v_cache=None,
step_idx=0,
prefill_done=False,
)
def build_next_xy_pos(model: Any, y_sequences: Sequence[torch.LongTensor]) -> torch.Tensor:
last_tokens = torch.stack([seq[-1:] for seq in y_sequences], dim=0)
y_emb = model.ar_audio_embedding(last_tokens)
position_ids = torch.LongTensor([int(seq.shape[0] - 1) for seq in y_sequences]).to(y_emb.device)
_ensure_audio_pe(model, int(position_ids.max().item()), y_emb.dtype, y_emb.device)
pos_emb = model.ar_audio_position.pe[0].index_select(0, position_ids).unsqueeze(1)
return y_emb * model.ar_audio_position.x_scale + model.ar_audio_position.alpha * pos_emb.to(
dtype=y_emb.dtype, device=y_emb.device
)
def _sample_per_request(
model: Any,
active_batch: T2SActiveBatch,
logits: torch.Tensor,
max_steps: int,
) -> Tuple[List[T2SFinishedItem], List[int], List[torch.LongTensor]]:
finished_items: List[T2SFinishedItem] = []
keep_indices: List[int] = []
updated_sequences: List[torch.LongTensor] = []
step_idx = active_batch.step_idx
for batch_index, state in enumerate(active_batch.states):
logits_i = logits[batch_index : batch_index + 1].clone()
current_history = active_batch.y_sequences[batch_index]
sampled = sample(
logits_i,
current_history.unsqueeze(0),
top_k=state.top_k,
top_p=state.top_p,
repetition_penalty=state.repetition_penalty,
temperature=state.temperature,
)[0]
sampled_token = int(sampled[0, 0].item())
argmax_token = int(torch.argmax(logits[batch_index], dim=-1).item())
new_history = torch.cat([current_history, sampled.view(-1)], dim=0)
finish_reason: Optional[str] = None
if state.early_stop_num != -1 and (new_history.shape[0] - int(active_batch.prefix_lens[batch_index].item())) > state.early_stop_num:
finish_reason = "early_stop"
elif step_idx + 1 >= max_steps:
finish_reason = "max_step"
elif sampled_token == model.EOS:
finish_reason = "eos_sample"
elif argmax_token == model.EOS:
finish_reason = "eos_argmax"
if finish_reason is not None:
prefix_len = int(active_batch.prefix_lens[batch_index].item())
finished_items.append(
T2SFinishedItem(
request_id=state.request_id,
semantic_tokens=new_history[prefix_len:-1].clone(),
finish_idx=step_idx,
finish_reason=finish_reason,
)
)
else:
keep_indices.append(batch_index)
updated_sequences.append(new_history)
return finished_items, keep_indices, updated_sequences
def decode_one_step(
model: Any,
active_batch: T2SActiveBatch,
max_steps: int,
) -> Tuple[Optional[T2SActiveBatch], List[T2SFinishedItem]]:
if not active_batch.prefill_done:
xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.process_prompt(
active_batch.xy_pos, active_batch.prefill_attn_mask, None
)
active_batch.decode_attn_mask = F.pad(active_batch.key_padding_mask.unsqueeze(1).unsqueeze(1), (0, 1), value=False)
active_batch.prefill_done = True
else:
xy_dec, active_batch.k_cache, active_batch.v_cache = model.t2s_transformer.decode_next_token(
active_batch.xy_pos,
active_batch.k_cache,
active_batch.v_cache,
active_batch.decode_attn_mask,
)
if active_batch.decode_attn_mask is not None:
active_batch.decode_attn_mask = F.pad(active_batch.decode_attn_mask, (0, 1), value=False)
logits = model.ar_predict_layer(xy_dec[:, -1])
if active_batch.step_idx < 11:
logits = logits[:, :-1]
finished_items, keep_indices, updated_sequences = _sample_per_request(model, active_batch, logits, max_steps=max_steps)
if len(keep_indices) == 0:
return None, finished_items
device = logits.device
keep_tensor = torch.LongTensor(keep_indices).to(device)
active_batch.request_ids = [active_batch.request_ids[i] for i in keep_indices]
active_batch.states = [active_batch.states[i] for i in keep_indices]
active_batch.y_sequences = updated_sequences
active_batch.prefix_lens = torch.index_select(active_batch.prefix_lens, dim=0, index=keep_tensor)
if active_batch.decode_attn_mask is not None:
active_batch.decode_attn_mask = torch.index_select(active_batch.decode_attn_mask, dim=0, index=keep_tensor)
if active_batch.k_cache is not None and active_batch.v_cache is not None:
for cache_index in range(len(active_batch.k_cache)):
active_batch.k_cache[cache_index] = torch.index_select(active_batch.k_cache[cache_index], dim=0, index=keep_tensor)
active_batch.v_cache[cache_index] = torch.index_select(active_batch.v_cache[cache_index], dim=0, index=keep_tensor)
active_batch.xy_pos = build_next_xy_pos(model, active_batch.y_sequences)
active_batch.step_idx += 1
return active_batch, finished_items
def run_scheduler_batch(
model: Any,
states: Sequence[T2SRequestState],
max_steps: int,
) -> List[T2SFinishedItem]:
return run_scheduler_continuous(model, states, max_steps=max_steps)
def _pad_cache_left(cache: torch.Tensor, target_len: int) -> torch.Tensor:
pad_len = target_len - cache.shape[1]
if pad_len <= 0:
return cache
return F.pad(cache, (0, 0, pad_len, 0), value=0)
def _pad_decode_mask_left(mask: torch.Tensor, target_len: int) -> torch.Tensor:
pad_len = target_len - mask.shape[-1]
if pad_len <= 0:
return mask
return F.pad(mask, (pad_len, 0), value=True)
def _fit_decode_mask_length(mask: torch.Tensor, target_len: int) -> torch.Tensor:
if mask.shape[-1] > target_len:
return mask[:, :, :, -target_len:]
if mask.shape[-1] < target_len:
return _pad_decode_mask_left(mask, target_len)
return mask
def _materialize_decode_mask_for_request(running_request: T2SRunningRequest) -> torch.Tensor:
expected_mask_len = running_request.k_cache[0].shape[1] + 1
if running_request.decode_attn_mask is not None:
return _fit_decode_mask_length(running_request.decode_attn_mask, expected_mask_len)
current_mask_len = running_request.k_cache[0].shape[1] + 1
return torch.zeros(
(1, 1, 1, current_mask_len),
dtype=torch.bool,
device=running_request.k_cache[0].device,
)
@torch.inference_mode()
def run_prefill_step(
model: Any,
states: Sequence[T2SRequestState],
max_steps: int,
) -> Tuple[List[T2SRunningRequest], List[T2SFinishedItem]]:
if not states:
return [], []
active_batch = build_prefill_batch(model, states)
xy_dec, k_cache, v_cache = model.t2s_transformer.process_prompt(active_batch.xy_pos, active_batch.prefill_attn_mask, None)
decode_attn_mask = F.pad(active_batch.key_padding_mask.unsqueeze(1).unsqueeze(1), (0, 1), value=False)
if len(states) == 1 and not decode_attn_mask.any().item():
decode_attn_mask = None
logits = model.ar_predict_layer(xy_dec[:, -1])
running_requests: List[T2SRunningRequest] = []
finished_items: List[T2SFinishedItem] = []
for batch_index, state in enumerate(states):
logits_i = logits[batch_index : batch_index + 1].clone()
if 0 < 11:
logits_i = logits_i[:, :-1]
current_history = active_batch.y_sequences[batch_index]
sampled = sample(
logits_i,
current_history.unsqueeze(0),
top_k=state.top_k,
top_p=state.top_p,
repetition_penalty=state.repetition_penalty,
temperature=state.temperature,
)[0]
sampled_token = int(sampled[0, 0].item())
argmax_token = int(torch.argmax(logits_i[0], dim=-1).item())
new_history = torch.cat([current_history, sampled.view(-1)], dim=0)
prefix_len = int(active_batch.prefix_lens[batch_index].item())
finish_reason: Optional[str] = None
if state.early_stop_num != -1 and (new_history.shape[0] - prefix_len) > state.early_stop_num:
finish_reason = "early_stop"
elif 1 >= max_steps:
finish_reason = "max_step"
elif sampled_token == model.EOS:
finish_reason = "eos_sample"
elif argmax_token == model.EOS:
finish_reason = "eos_argmax"
if finish_reason is not None:
finished_items.append(
T2SFinishedItem(
request_id=state.request_id,
semantic_tokens=new_history[prefix_len:-1].clone(),
finish_idx=0,
finish_reason=finish_reason,
)
)
continue
real_kv_len = int(active_batch.x_lens[batch_index].item()) + prefix_len
request_k_cache = [layer[batch_index : batch_index + 1, -real_kv_len:, :].clone() for layer in k_cache]
request_v_cache = [layer[batch_index : batch_index + 1, -real_kv_len:, :].clone() for layer in v_cache]
request_decode_attn_mask = None
if decode_attn_mask is not None:
request_decode_attn_mask = decode_attn_mask[batch_index : batch_index + 1].clone()
request_decode_attn_mask = _fit_decode_mask_length(request_decode_attn_mask, real_kv_len + 1)
if not request_decode_attn_mask.any().item():
request_decode_attn_mask = None
running_requests.append(
T2SRunningRequest(
state=state,
y_sequence=new_history,
prefix_len=prefix_len,
decode_attn_mask=request_decode_attn_mask,
k_cache=request_k_cache,
v_cache=request_v_cache,
step_idx=1,
)
)
return running_requests, finished_items
def _build_decode_batch_from_running(
model: Any,
running_requests: Sequence[T2SRunningRequest],
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor], Optional[torch.Tensor]]:
xy_pos = build_next_xy_pos(model, [item.y_sequence for item in running_requests])
max_kv_len = max(item.k_cache[0].shape[1] for item in running_requests)
num_layers = len(running_requests[0].k_cache)
batched_k_cache: List[torch.Tensor] = []
batched_v_cache: List[torch.Tensor] = []
for layer_index in range(num_layers):
batched_k_cache.append(
torch.cat([_pad_cache_left(item.k_cache[layer_index], max_kv_len) for item in running_requests], dim=0)
)
batched_v_cache.append(
torch.cat([_pad_cache_left(item.v_cache[layer_index], max_kv_len) for item in running_requests], dim=0)
)
if all(item.decode_attn_mask is None for item in running_requests):
batched_decode_attn_mask = None
else:
materialized_masks = [_materialize_decode_mask_for_request(item) for item in running_requests]
max_mask_len = max(mask.shape[-1] for mask in materialized_masks)
batched_decode_attn_mask = torch.cat(
[_pad_decode_mask_left(mask, max_mask_len) for mask in materialized_masks],
dim=0,
)
return xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask
@torch.inference_mode()
def run_decode_step_for_running(
model: Any,
running_requests: Sequence[T2SRunningRequest],
max_steps: int,
) -> Tuple[List[T2SRunningRequest], List[T2SFinishedItem]]:
if not running_requests:
return [], []
xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask = _build_decode_batch_from_running(
model, running_requests
)
xy_dec, next_k_cache, next_v_cache = model.t2s_transformer.decode_next_token(
xy_pos,
batched_k_cache,
batched_v_cache,
batched_decode_attn_mask,
)
logits = model.ar_predict_layer(xy_dec[:, -1])
next_running: List[T2SRunningRequest] = []
finished_items: List[T2SFinishedItem] = []
for batch_index, running_request in enumerate(running_requests):
current_idx = running_request.step_idx
logits_i = logits[batch_index : batch_index + 1].clone()
if current_idx < 11:
logits_i = logits_i[:, :-1]
sampled = sample(
logits_i,
running_request.y_sequence.unsqueeze(0),
top_k=running_request.state.top_k,
top_p=running_request.state.top_p,
repetition_penalty=running_request.state.repetition_penalty,
temperature=running_request.state.temperature,
)[0]
sampled_token = int(sampled[0, 0].item())
argmax_token = int(torch.argmax(logits_i[0], dim=-1).item())
new_history = torch.cat([running_request.y_sequence, sampled.view(-1)], dim=0)
finish_reason: Optional[str] = None
if running_request.state.early_stop_num != -1 and (new_history.shape[0] - running_request.prefix_len) > running_request.state.early_stop_num:
finish_reason = "early_stop"
elif current_idx + 1 >= max_steps:
finish_reason = "max_step"
elif sampled_token == model.EOS:
finish_reason = "eos_sample"
elif argmax_token == model.EOS:
finish_reason = "eos_argmax"
if finish_reason is not None:
finished_items.append(
T2SFinishedItem(
request_id=running_request.state.request_id,
semantic_tokens=new_history[running_request.prefix_len:-1].clone(),
finish_idx=current_idx,
finish_reason=finish_reason,
)
)
continue
real_next_kv_len = running_request.k_cache[0].shape[1] + 1
request_k_cache = [layer[batch_index : batch_index + 1, -real_next_kv_len:, :].clone() for layer in next_k_cache]
request_v_cache = [layer[batch_index : batch_index + 1, -real_next_kv_len:, :].clone() for layer in next_v_cache]
if batched_decode_attn_mask is None:
next_decode_attn_mask = None
else:
current_decode_mask_len = running_request.k_cache[0].shape[1] + 1
current_decode_attn_mask = batched_decode_attn_mask[
batch_index : batch_index + 1, :, :, -current_decode_mask_len:
]
next_decode_attn_mask = F.pad(current_decode_attn_mask, (0, 1), value=False)
next_decode_attn_mask = _fit_decode_mask_length(next_decode_attn_mask, real_next_kv_len + 1)
if not next_decode_attn_mask.any().item():
next_decode_attn_mask = None
next_running.append(
T2SRunningRequest(
state=running_request.state,
y_sequence=new_history,
prefix_len=running_request.prefix_len,
decode_attn_mask=next_decode_attn_mask,
k_cache=request_k_cache,
v_cache=request_v_cache,
step_idx=current_idx + 1,
)
)
return next_running, finished_items
@torch.inference_mode()
def run_scheduler_continuous(
model: Any,
states: Sequence[T2SRequestState],
max_steps: int,
) -> List[T2SFinishedItem]:
pending = sorted(states, key=lambda item: (item.ready_step, item.request_id))
running_requests: List[T2SRunningRequest] = []
finished: List[T2SFinishedItem] = []
current_tick = 0
while pending or running_requests:
admitted: List[T2SRequestState] = []
while pending and pending[0].ready_step <= current_tick:
admitted.append(pending.pop(0))
admitted_running, admitted_finished = run_prefill_step(model, admitted, max_steps=max_steps)
finished.extend(admitted_finished)
if running_requests:
running_requests, step_finished = run_decode_step_for_running(
model,
running_requests,
max_steps=max_steps,
)
finished.extend(step_finished)
running_requests.extend(admitted_running)
if not running_requests and pending:
current_tick = max(current_tick + 1, pending[0].ready_step)
continue
current_tick += 1
finished.sort(key=lambda item: item.request_id)
return finished

View File

@ -180,10 +180,15 @@ def _merge_erhua(initials: list[str], finals: list[str], word: str, pos: str) ->
def _g2p(segments):
phones_list = []
word2ph = []
for seg in segments:
g2pw_batch_results = []
g2pw_batch_cursor = 0
processed_segments = [re.sub("[a-zA-Z]+", "", seg) for seg in segments]
if is_g2pw:
batch_inputs = [seg for seg in processed_segments if seg]
g2pw_batch_results = g2pw._g2pw(batch_inputs) if batch_inputs else []
for seg in processed_segments:
pinyins = []
# Replace all English words in the sentence
seg = re.sub("[a-zA-Z]+", "", seg)
seg_cut = psg.lcut(seg)
seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
initials = []
@ -204,8 +209,10 @@ def _g2p(segments):
finals = sum(finals, [])
print("pypinyin结果", initials, finals)
else:
# g2pw采用整句推理
pinyins = g2pw.lazy_pinyin(seg, neutral_tone_with_five=True, style=Style.TONE3)
# g2pw采用整句推理批量推理逐句取结果
if seg:
pinyins = g2pw_batch_results[g2pw_batch_cursor]
g2pw_batch_cursor += 1
pre_word_length = 0
for word, pos in seg_cut:

View File

@ -18,6 +18,7 @@ Credits
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import numpy as np
@ -37,6 +38,8 @@ def prepare_onnx_input(
use_mask: bool = False,
window_size: int = None,
max_len: int = 512,
char2id: Optional[Dict[str, int]] = None,
char_phoneme_masks: Optional[Dict[str, List[int]]] = None,
) -> Dict[str, np.array]:
if window_size is not None:
truncated_texts, truncated_query_ids = _truncate_texts(
@ -48,33 +51,88 @@ def prepare_onnx_input(
phoneme_masks = []
char_ids = []
position_ids = []
tokenized_cache = {}
if char2id is None:
char2id = {char: idx for idx, char in enumerate(chars)}
if use_mask:
if char_phoneme_masks is None:
char_phoneme_masks = {
char: [1 if i in char2phonemes[char] else 0 for i in range(len(labels))]
for char in char2phonemes
}
else:
full_phoneme_mask = [1] * len(labels)
for idx in range(len(texts)):
text = (truncated_texts if window_size else texts)[idx].lower()
query_id = (truncated_query_ids if window_size else query_ids)[idx]
try:
tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text)
except Exception:
print(f'warning: text "{text}" is invalid')
return {}
cached = tokenized_cache.get(text)
if cached is None:
try:
tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text)
except Exception:
print(f'warning: text "{text}" is invalid')
return {}
text, query_id, tokens, text2token, token2text = _truncate(
max_len=max_len, text=text, query_id=query_id, tokens=tokens, text2token=text2token, token2text=token2text
)
if len(tokens) <= max_len - 2:
processed_tokens = ["[CLS]"] + tokens + ["[SEP]"]
shared_input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
shared_token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
shared_attention_mask = list(np.ones((len(processed_tokens),), dtype=int))
cached = {
"is_short": True,
"tokens": tokens,
"text2token": text2token,
"token2text": token2text,
"input_id": shared_input_id,
"token_type_id": shared_token_type_id,
"attention_mask": shared_attention_mask,
}
else:
cached = {
"is_short": False,
"tokens": tokens,
"text2token": text2token,
"token2text": token2text,
}
tokenized_cache[text] = cached
processed_tokens = ["[CLS]"] + tokens + ["[SEP]"]
if cached["is_short"]:
text_for_query = text
query_id_for_query = query_id
text2token_for_query = cached["text2token"]
input_id = cached["input_id"]
token_type_id = cached["token_type_id"]
attention_mask = cached["attention_mask"]
else:
(
text_for_query,
query_id_for_query,
tokens_for_query,
text2token_for_query,
_token2text_for_query,
) = _truncate(
max_len=max_len,
text=text,
query_id=query_id,
tokens=cached["tokens"],
text2token=cached["text2token"],
token2text=cached["token2text"],
)
processed_tokens = ["[CLS]"] + tokens_for_query + ["[SEP]"]
input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
attention_mask = list(np.ones((len(processed_tokens),), dtype=int))
input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
attention_mask = list(np.ones((len(processed_tokens),), dtype=int))
query_char = text[query_id]
phoneme_mask = (
[1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] if use_mask else [1] * len(labels)
)
char_id = chars.index(query_char)
position_id = text2token[query_id] + 1 # [CLS] token locate at first place
query_char = text_for_query[query_id_for_query]
if use_mask:
phoneme_mask = char_phoneme_masks[query_char]
else:
phoneme_mask = full_phoneme_mask
char_id = char2id[query_char]
position_id = text2token_for_query[query_id_for_query] + 1 # [CLS] token locate at first place
input_ids.append(input_id)
token_type_ids.append(token_type_id)
@ -83,10 +141,15 @@ def prepare_onnx_input(
char_ids.append(char_id)
position_ids.append(position_id)
max_token_length = max(len(seq) for seq in input_ids)
def _pad_sequences(sequences, pad_value=0):
return [seq + [pad_value] * (max_token_length - len(seq)) for seq in sequences]
outputs = {
"input_ids": np.array(input_ids).astype(np.int64),
"token_type_ids": np.array(token_type_ids).astype(np.int64),
"attention_masks": np.array(attention_masks).astype(np.int64),
"input_ids": np.array(_pad_sequences(input_ids, pad_value=0)).astype(np.int64),
"token_type_ids": np.array(_pad_sequences(token_type_ids, pad_value=0)).astype(np.int64),
"attention_masks": np.array(_pad_sequences(attention_masks, pad_value=0)).astype(np.int64),
"phoneme_masks": np.array(phoneme_masks).astype(np.float32),
"char_ids": np.array(char_ids).astype(np.int64),
"position_ids": np.array(position_ids).astype(np.int64),

View File

@ -10,7 +10,6 @@ from typing import Any, Dict, List, Tuple
import numpy as np
import onnxruntime
import requests
import torch
from opencc import OpenCC
from pypinyin import Style, pinyin
from transformers.models.auto.tokenization_auto import AutoTokenizer
@ -22,9 +21,8 @@ from .utils import load_config
onnxruntime.set_default_logger_severity(3)
try:
onnxruntime.preload_dlls()
except:
except Exception:
pass
# traceback.print_exc()
warnings.filterwarnings("ignore")
model_version = "1.1"
@ -55,6 +53,24 @@ def predict(session, onnx_input: Dict[str, Any], labels: List[str]) -> Tuple[Lis
return all_preds, all_confidences
def _load_json_from_candidates(filename: str, candidate_dirs: List[str]) -> Dict[str, Any]:
for candidate_dir in candidate_dirs:
if not candidate_dir:
continue
json_path = os.path.join(candidate_dir, filename)
if os.path.exists(json_path):
with open(json_path, "r", encoding="utf-8") as fr:
return json.load(fr)
raise FileNotFoundError(f"Cannot locate {filename} in candidate dirs: {candidate_dirs}")
def _find_first_existing_file(*paths: str) -> str:
for path in paths:
if path and os.path.exists(path):
return path
raise FileNotFoundError(f"Files not found: {paths}")
def download_and_decompress(model_dir: str = "G2PWModel/"):
if not os.path.exists(model_dir):
parent_directory = os.path.dirname(model_dir)
@ -62,7 +78,7 @@ def download_and_decompress(model_dir: str = "G2PWModel/"):
extract_dir = os.path.join(parent_directory, "G2PWModel_1.1")
extract_dir_new = os.path.join(parent_directory, "G2PWModel")
print("Downloading g2pw model...")
modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip" # "https://paddlespeech.cdn.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip"
modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip"
with requests.get(modelscope_url, stream=True) as r:
r.raise_for_status()
with open(zip_dir, "wb") as f:
@ -79,7 +95,7 @@ def download_and_decompress(model_dir: str = "G2PWModel/"):
return model_dir
class G2PWOnnxConverter:
class _G2PWBaseOnnxConverter:
def __init__(
self,
model_dir: str = "G2PWModel/",
@ -87,33 +103,16 @@ class G2PWOnnxConverter:
model_source: str = None,
enable_non_tradional_chinese: bool = False,
):
uncompress_path = download_and_decompress(model_dir)
sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
sess_options.intra_op_num_threads = 2 if torch.cuda.is_available() else 0
if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
self.session_g2pW = onnxruntime.InferenceSession(
os.path.join(uncompress_path, "g2pW.onnx"),
sess_options=sess_options,
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
else:
self.session_g2pW = onnxruntime.InferenceSession(
os.path.join(uncompress_path, "g2pW.onnx"),
sess_options=sess_options,
providers=["CPUExecutionProvider"],
)
self.config = load_config(config_path=os.path.join(uncompress_path, "config.py"), use_default=True)
self.model_dir = download_and_decompress(model_dir)
self.config = load_config(config_path=os.path.join(self.model_dir, "config.py"), use_default=True)
self.model_source = model_source if model_source else self.config.model_source
self.enable_opencc = enable_non_tradional_chinese
self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)
polyphonic_chars_path = os.path.join(uncompress_path, "POLYPHONIC_CHARS.txt")
monophonic_chars_path = os.path.join(uncompress_path, "MONOPHONIC_CHARS.txt")
polyphonic_chars_path = os.path.join(self.model_dir, "POLYPHONIC_CHARS.txt")
monophonic_chars_path = os.path.join(self.model_dir, "MONOPHONIC_CHARS.txt")
self.polyphonic_chars = [
line.split("\t") for line in open(polyphonic_chars_path, encoding="utf-8").read().strip().split("\n")
]
@ -149,31 +148,47 @@ class G2PWOnnxConverter:
)
self.chars = sorted(list(self.char2phonemes.keys()))
self.char2id = {char: idx for idx, char in enumerate(self.chars)}
self.char_phoneme_masks = (
{
char: [1 if i in self.char2phonemes[char] else 0 for i in range(len(self.labels))]
for char in self.char2phonemes
}
if self.config.use_mask
else None
)
self.polyphonic_chars_new = set(self.chars)
for char in self.non_polyphonic:
if char in self.polyphonic_chars_new:
self.polyphonic_chars_new.remove(char)
self.polyphonic_chars_new.discard(char)
self.monophonic_chars_dict = {char: phoneme for char, phoneme in self.monophonic_chars}
for char in self.non_monophonic:
if char in self.monophonic_chars_dict:
self.monophonic_chars_dict.pop(char)
self.monophonic_chars_dict.pop(char, None)
self.pos_tags = ["UNK", "A", "C", "D", "I", "N", "P", "T", "V", "DE", "SHI"]
default_asset_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "G2PWModel"))
candidate_asset_dirs = [self.model_dir, default_asset_dir]
self.bopomofo_convert_dict = _load_json_from_candidates(
"bopomofo_to_pinyin_wo_tune_dict.json", candidate_asset_dirs
)
self.char_bopomofo_dict = _load_json_from_candidates("char_bopomofo_dict.json", candidate_asset_dirs)
with open(os.path.join(uncompress_path, "bopomofo_to_pinyin_wo_tune_dict.json"), "r", encoding="utf-8") as fr:
self.bopomofo_convert_dict = json.load(fr)
self.style_convert_func = {
"bopomofo": lambda x: x,
"pinyin": self._convert_bopomofo_to_pinyin,
}[style]
with open(os.path.join(uncompress_path, "char_bopomofo_dict.json"), "r", encoding="utf-8") as fr:
self.char_bopomofo_dict = json.load(fr)
if self.enable_opencc:
self.cc = OpenCC("s2tw")
self.enable_sentence_dedup = os.getenv("g2pw_sentence_dedup", "true").strip().lower() in {
"1",
"true",
"yes",
"y",
"on",
}
# 聚焦到多音字附近上下文默认左右各16字设为0表示关闭裁剪整句
self.polyphonic_context_chars = max(0, int(os.getenv("g2pw_polyphonic_context_chars", "16")))
def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
tone = bopomofo[-1]
@ -181,9 +196,8 @@ class G2PWOnnxConverter:
component = self.bopomofo_convert_dict.get(bopomofo[:-1])
if component:
return component + tone
else:
print(f'Warning: "{bopomofo}" cannot convert to pinyin')
return None
print(f'Warning: "{bopomofo}" cannot convert to pinyin')
return None
def __call__(self, sentences: List[str]) -> List[List[str]]:
if isinstance(sentences, str):
@ -197,51 +211,147 @@ class G2PWOnnxConverter:
translated_sentences.append(translated_sent)
sentences = translated_sentences
texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences)
texts, model_query_ids, result_query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences)
if len(texts) == 0:
# sentences no polyphonic words
return partial_results
onnx_input = prepare_onnx_input(
model_input = prepare_onnx_input(
tokenizer=self.tokenizer,
labels=self.labels,
char2phonemes=self.char2phonemes,
chars=self.chars,
texts=texts,
query_ids=query_ids,
query_ids=model_query_ids,
use_mask=self.config.use_mask,
window_size=None,
char2id=self.char2id,
char_phoneme_masks=self.char_phoneme_masks,
)
preds, confidences = predict(session=self.session_g2pW, onnx_input=onnx_input, labels=self.labels)
if not model_input:
return partial_results
if self.enable_sentence_dedup:
preds, _confidences = self._predict_with_sentence_dedup(model_input=model_input, texts=texts)
else:
preds, _confidences = self._predict(model_input=model_input)
if self.config.use_char_phoneme:
preds = [pred.split(" ")[1] for pred in preds]
results = partial_results
for sent_id, query_id, pred in zip(sent_ids, query_ids, preds):
for sent_id, query_id, pred in zip(sent_ids, result_query_ids, preds):
results[sent_id][query_id] = self.style_convert_func(pred)
return results
def _prepare_data(self, sentences: List[str]) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
texts, query_ids, sent_ids, partial_results = [], [], [], []
def _prepare_data(
self, sentences: List[str]
) -> Tuple[List[str], List[int], List[int], List[int], List[List[str]]]:
texts, model_query_ids, result_query_ids, sent_ids, partial_results = [], [], [], [], []
for sent_id, sent in enumerate(sentences):
# pypinyin works well for Simplified Chinese than Traditional Chinese
sent_s = tranditional_to_simplified(sent)
pypinyin_result = pinyin(sent_s, neutral_tone_with_five=True, style=Style.TONE3)
partial_result = [None] * len(sent)
polyphonic_indices: List[int] = []
for i, char in enumerate(sent):
if char in self.polyphonic_chars_new:
texts.append(sent)
query_ids.append(i)
sent_ids.append(sent_id)
polyphonic_indices.append(i)
elif char in self.monophonic_chars_dict:
partial_result[i] = self.style_convert_func(self.monophonic_chars_dict[char])
elif char in self.char_bopomofo_dict:
partial_result[i] = pypinyin_result[i][0]
# partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])
else:
partial_result[i] = pypinyin_result[i][0]
if polyphonic_indices:
if self.polyphonic_context_chars > 0:
left = max(0, polyphonic_indices[0] - self.polyphonic_context_chars)
right = min(len(sent), polyphonic_indices[-1] + self.polyphonic_context_chars + 1)
sent_for_predict = sent[left:right]
query_offset = left
else:
sent_for_predict = sent
query_offset = 0
for index in polyphonic_indices:
texts.append(sent_for_predict)
model_query_ids.append(index - query_offset)
result_query_ids.append(index)
sent_ids.append(sent_id)
partial_results.append(partial_result)
return texts, query_ids, sent_ids, partial_results
return texts, model_query_ids, result_query_ids, sent_ids, partial_results
def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]:
raise NotImplementedError
def _predict_with_sentence_dedup(
self, model_input: Dict[str, Any], texts: List[str]
) -> Tuple[List[str], List[float]]:
if len(texts) <= 1:
return self._predict(model_input=model_input)
grouped_indices: Dict[str, List[int]] = {}
for idx, text in enumerate(texts):
grouped_indices.setdefault(text, []).append(idx)
if all(len(indices) == 1 for indices in grouped_indices.values()):
return self._predict(model_input=model_input)
preds: List[str] = [""] * len(texts)
confidences: List[float] = [0.0] * len(texts)
for indices in grouped_indices.values():
group_input = {name: value[indices] for name, value in model_input.items()}
if len(indices) > 1:
for name in ("input_ids", "token_type_ids", "attention_masks"):
group_input[name] = group_input[name][:1]
group_preds, group_confidences = self._predict(model_input=group_input)
for output_idx, pred, confidence in zip(indices, group_preds, group_confidences):
preds[output_idx] = pred
confidences[output_idx] = confidence
return preds, confidences
class G2PWOnnxConverter(_G2PWBaseOnnxConverter):
def __init__(
self,
model_dir: str = "G2PWModel/",
style: str = "bopomofo",
model_source: str = None,
enable_non_tradional_chinese: bool = False,
):
super().__init__(
model_dir=model_dir,
style=style,
model_source=model_source,
enable_non_tradional_chinese=enable_non_tradional_chinese,
)
sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
sess_options.intra_op_num_threads = 2
onnx_path = _find_first_existing_file(
os.path.join(self.model_dir, "g2pW.onnx"),
os.path.join(self.model_dir, "g2pw.onnx"),
)
if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
self.session_g2pw = onnxruntime.InferenceSession(
onnx_path,
sess_options=sess_options,
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
else:
self.session_g2pw = onnxruntime.InferenceSession(
onnx_path,
sess_options=sess_options,
providers=["CPUExecutionProvider"],
)
def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]:
return predict(session=self.session_g2pw, onnx_input=model_input, labels=self.labels)

1228
api_v3.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,250 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import argparse
import asyncio
import json
import subprocess
import threading
import time
import wave
from pathlib import Path
from typing import Any, Dict, List, Optional
import httpx
ROOT_DIR = Path(__file__).resolve().parents[1]
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Benchmark api_v3 /tts_scheduler_submit concurrency and GPU memory.")
parser.add_argument("--base-url", type=str, default="http://127.0.0.1:9880")
parser.add_argument("--endpoint", type=str, default="/tts_scheduler_submit")
parser.add_argument("--concurrency", type=int, required=True)
parser.add_argument("--timeout-sec", type=float, default=120.0)
parser.add_argument("--server-pid", type=int, default=None)
parser.add_argument("--poll-interval-sec", type=float, default=0.1)
parser.add_argument("--text-lang", type=str, default="zh")
parser.add_argument("--prompt-lang", type=str, default="zh")
parser.add_argument("--media-type", type=str, default="wav")
parser.add_argument("--top-k", type=int, default=15)
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--repetition-penalty", type=float, default=1.35)
parser.add_argument("--sample-steps", type=int, default=32)
parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_cn.txt")
parser.add_argument("--wav-dir", type=Path, default=ROOT_DIR / "testwav")
parser.add_argument("--output-dir", type=Path, default=ROOT_DIR / "TEMP/api_v3_bench")
return parser.parse_args()
def load_requests(args: argparse.Namespace) -> List[Dict[str, Any]]:
wav_paths_all = sorted(args.wav_dir.glob("*.wav"))
wav_paths: List[Path] = []
for wav_path in wav_paths_all:
with wave.open(str(wav_path), "rb") as handle:
duration = handle.getnframes() / float(handle.getframerate())
if 3.0 <= duration <= 10.0:
wav_paths.append(wav_path)
if not wav_paths:
raise FileNotFoundError(f"没有找到 3-10 秒合法 wav: {args.wav_dir}")
text_lines = [line.strip() for line in args.text_file.read_text(encoding="utf-8").splitlines() if line.strip()]
if not text_lines:
raise ValueError(f"没有找到有效文本行: {args.text_file}")
requests: List[Dict[str, Any]] = []
for index in range(args.concurrency):
wav_path = wav_paths[index % len(wav_paths)]
lab_path = wav_path.with_suffix(".lab")
if not lab_path.exists():
raise FileNotFoundError(f"缺少参考文本: {lab_path}")
requests.append(
{
"request_id": f"bench_{args.concurrency:03d}_{index:03d}",
"text": text_lines[index % len(text_lines)],
"text_lang": args.text_lang,
"ref_audio_path": str(wav_path),
"prompt_lang": args.prompt_lang,
"prompt_text": lab_path.read_text(encoding="utf-8").strip(),
"top_k": int(args.top_k),
"top_p": float(args.top_p),
"temperature": float(args.temperature),
"repetition_penalty": float(args.repetition_penalty),
"sample_steps": int(args.sample_steps),
"media_type": args.media_type,
"timeout_sec": float(args.timeout_sec),
}
)
return requests
class GpuMemoryPoller:
def __init__(self, server_pid: Optional[int], interval_sec: float):
self.server_pid = server_pid
self.interval_sec = interval_sec
self._stop = threading.Event()
self.samples: List[Dict[str, Any]] = []
self.thread: Optional[threading.Thread] = None
def _query_memory_mb(self) -> Optional[int]:
try:
result = subprocess.run(
[
"nvidia-smi",
"--query-compute-apps=pid,used_gpu_memory",
"--format=csv,noheader,nounits",
],
check=True,
capture_output=True,
text=True,
)
except Exception:
return None
total = 0
found = False
for line in result.stdout.splitlines():
line = line.strip()
if not line:
continue
parts = [item.strip() for item in line.split(",")]
if len(parts) != 2:
continue
try:
pid = int(parts[0])
used_mb = int(parts[1])
except ValueError:
continue
if self.server_pid is None or pid == self.server_pid:
total += used_mb
found = True
if self.server_pid is None:
return total
return total if found else 0
def _run(self) -> None:
while not self._stop.is_set():
used_mb = self._query_memory_mb()
self.samples.append({"ts": time.time(), "used_mb": used_mb})
self._stop.wait(self.interval_sec)
def start(self) -> None:
self.thread = threading.Thread(target=self._run, daemon=True)
self.thread.start()
def stop(self) -> None:
self._stop.set()
if self.thread is not None:
self.thread.join(timeout=2.0)
def summary(self) -> Dict[str, Any]:
valid = [item for item in self.samples if item["used_mb"] is not None]
peak = max(valid, key=lambda item: item["used_mb"]) if valid else None
first = valid[0] if valid else None
last = valid[-1] if valid else None
return {
"server_pid": self.server_pid,
"sample_count": int(len(self.samples)),
"start_used_mb": None if first is None else int(first["used_mb"]),
"peak_used_mb": None if peak is None else int(peak["used_mb"]),
"peak_delta_mb": None if peak is None or first is None else int(peak["used_mb"] - first["used_mb"]),
"end_used_mb": None if last is None else int(last["used_mb"]),
"peak_ts": None if peak is None else float(peak["ts"]),
"samples": self.samples,
}
async def submit_one(client: httpx.AsyncClient, url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
started = time.perf_counter()
try:
response = await client.post(url, json=payload)
elapsed_ms = (time.perf_counter() - started) * 1000.0
item = {
"request_id": payload["request_id"],
"status_code": int(response.status_code),
"elapsed_ms": float(elapsed_ms),
"content_type": response.headers.get("content-type"),
"audio_bytes": int(len(response.content)),
"headers": {key: value for key, value in response.headers.items() if key.lower().startswith("x-")},
}
if response.status_code != 200:
try:
item["error_body"] = response.json()
except Exception:
item["error_body"] = response.text
return item
except Exception as exc:
return {
"request_id": payload["request_id"],
"status_code": -1,
"elapsed_ms": float((time.perf_counter() - started) * 1000.0),
"exception": repr(exc),
}
async def run_benchmark(args: argparse.Namespace) -> Dict[str, Any]:
payloads = load_requests(args)
url = args.base_url.rstrip("/") + args.endpoint
poller = GpuMemoryPoller(server_pid=args.server_pid, interval_sec=args.poll_interval_sec)
limits = httpx.Limits(max_connections=args.concurrency, max_keepalive_connections=args.concurrency)
timeout = httpx.Timeout(connect=10.0, read=args.timeout_sec + 10.0, write=10.0, pool=10.0)
started = time.perf_counter()
poller.start()
try:
async with httpx.AsyncClient(limits=limits, timeout=timeout) as client:
results = await asyncio.gather(*[submit_one(client, url, payload) for payload in payloads])
finally:
poller.stop()
wall_ms = (time.perf_counter() - started) * 1000.0
ok_results = [item for item in results if item["status_code"] == 200]
failed_results = [item for item in results if item["status_code"] != 200]
request_total_ms = []
worker_total_ms = []
for item in ok_results:
headers = item.get("headers", {})
if "x-request-total-ms" in headers:
request_total_ms.append(float(headers["x-request-total-ms"]))
if "x-worker-total-ms" in headers:
worker_total_ms.append(float(headers["x-worker-total-ms"]))
return {
"concurrency": int(args.concurrency),
"server_pid": args.server_pid,
"request_count": int(len(payloads)),
"wall_ms": float(wall_ms),
"success_count": int(len(ok_results)),
"failure_count": int(len(failed_results)),
"request_total_ms_avg": float(sum(request_total_ms) / len(request_total_ms)) if request_total_ms else None,
"request_total_ms_max": float(max(request_total_ms)) if request_total_ms else None,
"worker_total_ms_avg": float(sum(worker_total_ms) / len(worker_total_ms)) if worker_total_ms else None,
"worker_total_ms_max": float(max(worker_total_ms)) if worker_total_ms else None,
"gpu_memory": poller.summary(),
"results": results,
}
def main() -> None:
args = parse_args()
output_dir = args.output_dir / f"concurrency_{args.concurrency:02d}"
output_dir.mkdir(parents=True, exist_ok=True)
summary = asyncio.run(run_benchmark(args))
summary_path = output_dir / "summary.json"
summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
print(json.dumps({
"concurrency": summary["concurrency"],
"success_count": summary["success_count"],
"failure_count": summary["failure_count"],
"wall_ms": summary["wall_ms"],
"gpu_peak_used_mb": summary["gpu_memory"]["peak_used_mb"],
"request_total_ms_avg": summary["request_total_ms_avg"],
"request_total_ms_max": summary["request_total_ms_max"],
"summary_path": str(summary_path),
}, ensure_ascii=False, indent=2))
if __name__ == "__main__":
main()

View File

@ -0,0 +1,887 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import argparse
import gc
import contextlib
import json
import random
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple
import numpy as np
import torch
ROOT_DIR = Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
sys.path.append(str(ROOT_DIR))
gpt_sovits_dir = ROOT_DIR / "GPT_SoVITS"
if str(gpt_sovits_dir) not in sys.path:
sys.path.append(str(gpt_sovits_dir))
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config # noqa: E402
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( # noqa: E402
SchedulerRequestSpec,
T2SRequestState,
T2SRunningRequest,
_build_decode_batch_from_running,
build_prefill_batch,
prepare_request_state,
run_decode_step_for_running,
run_prefill_step,
)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Break down T2S CUDA memory by stage and tensor groups.")
parser.add_argument("--config", type=Path, default=ROOT_DIR / "GPT_SoVITS/configs/tts_infer.yaml")
parser.add_argument("--request-manifest", type=Path, default=None)
parser.add_argument("--scenario", type=str, default="auto4", choices=["auto4", "single"])
parser.add_argument("--auto-count", type=int, default=4)
parser.add_argument("--auto-wav-dir", type=Path, default=ROOT_DIR / "testwav")
parser.add_argument("--auto-text-file", type=Path, default=ROOT_DIR / "test_cn.txt")
parser.add_argument("--ref-audio", type=Path, default=ROOT_DIR / "test.wav")
parser.add_argument("--prompt-text", type=str, default="是啊,主要是因为有调研需求的学者少了。")
parser.add_argument("--prompt-lang", type=str, default="zh")
parser.add_argument("--text", type=str, default=None)
parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_en.txt")
parser.add_argument("--text-lang", type=str, default="zh")
parser.add_argument("--top-k", type=int, default=15)
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--repetition-penalty", type=float, default=1.35)
parser.add_argument("--early-stop-num", type=int, default=-1)
parser.add_argument("--max-steps", type=int, default=1500)
parser.add_argument("--seed", type=int, default=1234)
parser.add_argument("--warmup", action="store_true", default=False)
parser.add_argument("--worker-rounds", type=int, default=1)
parser.add_argument("--worker-grad-mode", type=str, default="default", choices=["default", "inference_mode"])
parser.add_argument("--compare-worker-grad-modes", action="store_true", default=False)
parser.add_argument(
"--output-dir",
type=Path,
default=ROOT_DIR / "TEMP/t2s_memory_breakdown/run1",
)
return parser.parse_args()
def set_seed(seed: int, use_cuda: bool) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if use_cuda and torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def _sync_device(device: Any) -> None:
try:
device_str = str(device)
if device_str.startswith("cuda") and torch.cuda.is_available():
torch.cuda.synchronize(device)
elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"):
torch.mps.synchronize()
except Exception:
pass
def bytes_to_mb(num_bytes: int) -> float:
return float(num_bytes) / (1024.0 * 1024.0)
def tensor_nbytes(tensor: Optional[torch.Tensor]) -> int:
if tensor is None:
return 0
return int(tensor.numel() * tensor.element_size())
def tensor_list_nbytes(items: Sequence[torch.Tensor]) -> int:
return int(sum(tensor_nbytes(item) for item in items))
def model_nbytes(module: torch.nn.Module) -> int:
total = 0
for parameter in module.parameters():
total += tensor_nbytes(parameter)
for buffer in module.buffers():
total += tensor_nbytes(buffer)
return int(total)
def build_module_weight_summary(tts: TTS) -> Dict[str, Any]:
modules = {
"t2s_model": tts.t2s_model,
"t2s_core": tts.t2s_model.model if tts.t2s_model is not None else None,
"vits_model": tts.vits_model,
"bert_model": tts.bert_model,
"cnhuhbert_model": tts.cnhuhbert_model,
"vocoder": tts.vocoder,
"sv_model": tts.sv_model,
}
by_module = {}
total_bytes = 0
for name, module in modules.items():
module_bytes = model_nbytes(module) if module is not None else 0
by_module[name] = {
"bytes": int(module_bytes),
"mb": bytes_to_mb(module_bytes),
}
total_bytes += module_bytes
return {
"by_module": by_module,
"total_bytes": int(total_bytes),
"total_mb": bytes_to_mb(total_bytes),
}
def snapshot_live_cuda_tensors(top_k: int = 40) -> Dict[str, Any]:
storages: Dict[int, Dict[str, Any]] = {}
tensor_views: List[Dict[str, Any]] = []
for obj in gc.get_objects():
try:
tensor = None
if torch.is_tensor(obj):
tensor = obj
elif hasattr(obj, "data") and torch.is_tensor(obj.data):
tensor = obj.data
if tensor is None or not tensor.is_cuda:
continue
storage = tensor.untyped_storage()
storage_ptr = int(storage.data_ptr())
if storage_ptr not in storages:
storages[storage_ptr] = {
"storage_ptr": storage_ptr,
"storage_bytes": int(storage.nbytes()),
"dtype": str(tensor.dtype),
"shape": list(tensor.shape),
"device": str(tensor.device),
}
tensor_views.append(
{
"shape": list(tensor.shape),
"dtype": str(tensor.dtype),
"bytes": tensor_nbytes(tensor),
"device": str(tensor.device),
}
)
except Exception:
continue
storage_list = sorted(storages.values(), key=lambda item: item["storage_bytes"], reverse=True)
tensor_views.sort(key=lambda item: item["bytes"], reverse=True)
return {
"unique_storage_count": int(len(storage_list)),
"unique_storage_total_bytes": int(sum(item["storage_bytes"] for item in storage_list)),
"unique_storage_total_mb": bytes_to_mb(sum(item["storage_bytes"] for item in storage_list)),
"top_storages": storage_list[:top_k],
"top_tensor_views": tensor_views[:top_k],
}
def build_single_spec(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
text = args.text if args.text is not None else args.text_file.read_text(encoding="utf-8").strip()
return [
SchedulerRequestSpec(
request_id="req_000",
ref_audio_path=args.ref_audio,
prompt_text=args.prompt_text,
prompt_lang=args.prompt_lang,
text=text,
text_lang=args.text_lang,
top_k=args.top_k,
top_p=args.top_p,
temperature=args.temperature,
repetition_penalty=args.repetition_penalty,
early_stop_num=args.early_stop_num,
ready_step=0,
)
]
def build_auto_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
wav_paths = sorted(args.auto_wav_dir.glob("*.wav"))[: args.auto_count]
if len(wav_paths) < args.auto_count:
raise ValueError(f"auto wav count不足目录 {args.auto_wav_dir} 只有 {len(wav_paths)} 条 wav")
text_lines = [line.strip() for line in args.auto_text_file.read_text(encoding="utf-8").splitlines() if line.strip()]
if len(text_lines) < args.auto_count:
raise ValueError(f"auto text lines不足文件 {args.auto_text_file} 只有 {len(text_lines)} 行有效文本")
specs: List[SchedulerRequestSpec] = []
for index, wav_path in enumerate(wav_paths):
lab_path = wav_path.with_suffix(".lab")
if not lab_path.exists():
raise FileNotFoundError(f"找不到参考文本 {lab_path}")
specs.append(
SchedulerRequestSpec(
request_id=f"req_{index:03d}",
ref_audio_path=wav_path,
prompt_text=lab_path.read_text(encoding="utf-8").strip(),
prompt_lang="zh",
text=text_lines[index],
text_lang=args.text_lang,
top_k=args.top_k,
top_p=args.top_p,
temperature=args.temperature,
repetition_penalty=args.repetition_penalty,
early_stop_num=args.early_stop_num,
ready_step=0,
)
)
return specs
def load_request_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
if args.request_manifest is not None:
payload = json.loads(args.request_manifest.read_text(encoding="utf-8"))
raw_requests = payload["requests"] if isinstance(payload, dict) else payload
specs: List[SchedulerRequestSpec] = []
for index, item in enumerate(raw_requests):
text = item.get("text")
text_file = item.get("text_file")
if text is None and text_file is None:
raise ValueError(f"request[{index}] must provide text or text_file")
if text is None:
text = Path(text_file).read_text(encoding="utf-8").strip()
specs.append(
SchedulerRequestSpec(
request_id=item.get("request_id", f"req_{index:03d}"),
ref_audio_path=Path(item["ref_audio_path"]),
prompt_text=item["prompt_text"],
prompt_lang=item.get("prompt_lang", "zh"),
text=text,
text_lang=item.get("text_lang", "zh"),
top_k=int(item.get("top_k", args.top_k)),
top_p=float(item.get("top_p", args.top_p)),
temperature=float(item.get("temperature", args.temperature)),
repetition_penalty=float(item.get("repetition_penalty", args.repetition_penalty)),
early_stop_num=int(item.get("early_stop_num", args.early_stop_num)),
ready_step=int(item.get("ready_step", 0)),
)
)
return specs
if args.scenario == "single":
return build_single_spec(args)
return build_auto_specs(args)
def load_pipeline(config_path: Path) -> TTS:
tts_config = TTS_Config(str(config_path))
print(tts_config)
return TTS(tts_config)
def cuda_mem_snapshot(device: Any) -> Dict[str, float]:
if not (str(device).startswith("cuda") and torch.cuda.is_available()):
return {
"allocated_mb": 0.0,
"reserved_mb": 0.0,
"max_allocated_mb": 0.0,
"max_reserved_mb": 0.0,
}
_sync_device(device)
return {
"allocated_mb": bytes_to_mb(torch.cuda.memory_allocated(device)),
"reserved_mb": bytes_to_mb(torch.cuda.memory_reserved(device)),
"max_allocated_mb": bytes_to_mb(torch.cuda.max_memory_allocated(device)),
"max_reserved_mb": bytes_to_mb(torch.cuda.max_memory_reserved(device)),
}
def stage_run(device: Any, fn) -> Tuple[Any, Dict[str, float]]:
if str(device).startswith("cuda") and torch.cuda.is_available():
gc.collect()
_sync_device(device)
torch.cuda.reset_peak_memory_stats(device)
before = cuda_mem_snapshot(device)
started = time.perf_counter()
result = fn()
_sync_device(device)
elapsed_ms = (time.perf_counter() - started) * 1000.0
after = cuda_mem_snapshot(device)
after["elapsed_ms"] = float(elapsed_ms)
after["delta_allocated_mb"] = float(after["allocated_mb"] - before["allocated_mb"])
after["delta_reserved_mb"] = float(after["reserved_mb"] - before["reserved_mb"])
after["stage_peak_over_before_mb"] = float(max(after["max_allocated_mb"] - before["allocated_mb"], 0.0))
return result, after
class GlobalPeakRecorder:
def __init__(self, device: Any):
self.device = device
self.checkpoints: List[Dict[str, Any]] = []
if str(device).startswith("cuda") and torch.cuda.is_available():
gc.collect()
_sync_device(device)
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats(device)
def record(self, label: str, **extra: Any) -> None:
snapshot = cuda_mem_snapshot(self.device)
snapshot["label"] = label
snapshot.update(extra)
self.checkpoints.append(snapshot)
def summary(self) -> Dict[str, Any]:
peak = max(self.checkpoints, key=lambda item: item["max_allocated_mb"]) if self.checkpoints else None
return {
"peak_allocated_mb": 0.0 if peak is None else float(peak["max_allocated_mb"]),
"peak_reserved_mb": 0.0 if peak is None else float(peak["max_reserved_mb"]),
"peak_label": None if peak is None else peak["label"],
"checkpoints": self.checkpoints,
}
def summarise_state_tensors(states: Sequence[T2SRequestState]) -> Dict[str, Any]:
per_request = []
total = {
"phones_bytes": 0,
"prompt_phones_bytes": 0,
"all_phones_bytes": 0,
"all_bert_features_bytes": 0,
"prompt_semantic_bytes": 0,
"refer_spec_bytes": 0,
"raw_audio_bytes": 0,
"audio_16k_bytes": 0,
}
for state in states:
spec_audio, audio_16k = state.refer_spec
item = {
"request_id": state.request_id,
"prompt_semantic_len": int(state.prompt_semantic.shape[0]),
"phones_len": int(state.phones.shape[0]),
"all_phones_len": int(state.all_phones.shape[0]),
"bert_frames": int(state.all_bert_features.shape[-1]),
"phones_bytes": tensor_nbytes(state.phones),
"prompt_phones_bytes": tensor_nbytes(state.prompt_phones),
"all_phones_bytes": tensor_nbytes(state.all_phones),
"all_bert_features_bytes": tensor_nbytes(state.all_bert_features),
"prompt_semantic_bytes": tensor_nbytes(state.prompt_semantic),
"refer_spec_bytes": tensor_nbytes(spec_audio),
"audio_16k_bytes": tensor_nbytes(audio_16k),
"raw_audio_bytes": tensor_nbytes(state.raw_audio),
}
for key in total:
total[key] += int(item[key])
per_request.append(item)
total["total_bytes"] = int(sum(total.values()))
total["total_mb"] = bytes_to_mb(total["total_bytes"])
return {"per_request": per_request, "total": total}
def summarise_prefill_batch(active_batch: Any) -> Dict[str, Any]:
y_sequence_bytes = int(sum(tensor_nbytes(item) for item in active_batch.y_sequences))
fields = {
"x_bytes": tensor_nbytes(active_batch.x),
"x_lens_bytes": tensor_nbytes(active_batch.x_lens),
"prefix_lens_bytes": tensor_nbytes(active_batch.prefix_lens),
"xy_pos_bytes": tensor_nbytes(active_batch.xy_pos),
"key_padding_mask_bytes": tensor_nbytes(active_batch.key_padding_mask),
"prefill_attn_mask_bytes": tensor_nbytes(active_batch.prefill_attn_mask),
"y_sequence_bytes": y_sequence_bytes,
}
fields["total_bytes"] = int(sum(fields.values()))
fields["total_mb"] = bytes_to_mb(fields["total_bytes"])
fields["batch_size"] = int(len(active_batch.states))
fields["max_x_len"] = int(active_batch.x.shape[1])
fields["src_len"] = int(active_batch.xy_pos.shape[1])
fields["prefill_attn_mask_shape"] = list(active_batch.prefill_attn_mask.shape)
return fields
def summarise_running_requests(running_requests: Sequence[T2SRunningRequest]) -> Dict[str, Any]:
per_request = []
total_private_k_bytes = 0
total_private_v_bytes = 0
total_decode_mask_bytes = 0
total_y_sequence_bytes = 0
for item in running_requests:
k_bytes = tensor_list_nbytes(item.k_cache)
v_bytes = tensor_list_nbytes(item.v_cache)
mask_bytes = tensor_nbytes(item.decode_attn_mask)
y_bytes = tensor_nbytes(item.y_sequence)
total_private_k_bytes += k_bytes
total_private_v_bytes += v_bytes
total_decode_mask_bytes += mask_bytes
total_y_sequence_bytes += y_bytes
per_request.append(
{
"request_id": item.state.request_id,
"step_idx": int(item.step_idx),
"prefix_len": int(item.prefix_len),
"history_len": int(item.y_sequence.shape[0]),
"kv_len": int(item.k_cache[0].shape[1]),
"k_cache_bytes": k_bytes,
"v_cache_bytes": v_bytes,
"decode_mask_bytes": mask_bytes,
"y_sequence_bytes": y_bytes,
}
)
total_bytes = total_private_k_bytes + total_private_v_bytes + total_decode_mask_bytes + total_y_sequence_bytes
return {
"per_request": per_request,
"totals": {
"private_k_cache_bytes": int(total_private_k_bytes),
"private_v_cache_bytes": int(total_private_v_bytes),
"private_kv_cache_bytes": int(total_private_k_bytes + total_private_v_bytes),
"decode_mask_bytes": int(total_decode_mask_bytes),
"y_sequence_bytes": int(total_y_sequence_bytes),
"total_bytes": int(total_bytes),
"total_mb": bytes_to_mb(total_bytes),
},
}
def summarise_decode_batch(
xy_pos: torch.Tensor,
batched_k_cache: Sequence[torch.Tensor],
batched_v_cache: Sequence[torch.Tensor],
batched_decode_attn_mask: Optional[torch.Tensor],
running_requests: Sequence[T2SRunningRequest],
) -> Dict[str, Any]:
private_k_bytes = int(sum(tensor_list_nbytes(item.k_cache) for item in running_requests))
private_v_bytes = int(sum(tensor_list_nbytes(item.v_cache) for item in running_requests))
batched_k_bytes = tensor_list_nbytes(batched_k_cache)
batched_v_bytes = tensor_list_nbytes(batched_v_cache)
batched_mask_bytes = tensor_nbytes(batched_decode_attn_mask)
xy_pos_bytes = tensor_nbytes(xy_pos)
total_bytes = batched_k_bytes + batched_v_bytes + batched_mask_bytes + xy_pos_bytes
return {
"batch_size": int(len(running_requests)),
"xy_pos_bytes": int(xy_pos_bytes),
"batched_k_cache_bytes": int(batched_k_bytes),
"batched_v_cache_bytes": int(batched_v_bytes),
"batched_kv_cache_bytes": int(batched_k_bytes + batched_v_bytes),
"batched_decode_mask_bytes": int(batched_mask_bytes),
"private_kv_cache_bytes_reference": int(private_k_bytes + private_v_bytes),
"kv_padding_overhead_bytes": int((batched_k_bytes + batched_v_bytes) - (private_k_bytes + private_v_bytes)),
"total_bytes": int(total_bytes),
"total_mb": bytes_to_mb(total_bytes),
"xy_pos_shape": list(xy_pos.shape),
"batched_decode_mask_shape": None if batched_decode_attn_mask is None else list(batched_decode_attn_mask.shape),
"layer_k_cache_shape": list(batched_k_cache[0].shape),
}
def summarise_decode_outputs(
xy_dec: torch.Tensor,
next_k_cache: Sequence[torch.Tensor],
next_v_cache: Sequence[torch.Tensor],
) -> Dict[str, Any]:
xy_dec_bytes = tensor_nbytes(xy_dec)
next_k_bytes = tensor_list_nbytes(next_k_cache)
next_v_bytes = tensor_list_nbytes(next_v_cache)
total_bytes = xy_dec_bytes + next_k_bytes + next_v_bytes
return {
"xy_dec_bytes": int(xy_dec_bytes),
"next_k_cache_bytes": int(next_k_bytes),
"next_v_cache_bytes": int(next_v_bytes),
"next_kv_cache_bytes": int(next_k_bytes + next_v_bytes),
"total_bytes": int(total_bytes),
"total_mb": bytes_to_mb(total_bytes),
"xy_dec_shape": list(xy_dec.shape),
"layer_next_k_cache_shape": list(next_k_cache[0].shape),
}
def top_rankings(summary: Dict[str, Any]) -> List[Dict[str, Any]]:
ranking = [
("request_state_total", summary["prepare_stage"]["request_state"]["total"]["total_bytes"]),
("prefill_batch_total", summary["prefill_batch"]["tensor_bytes"]["total_bytes"]),
("running_private_kv", summary["prefill_step"]["running_requests"]["totals"]["private_kv_cache_bytes"]),
("decode_batched_kv", summary["decode_batch"]["tensor_bytes"]["batched_kv_cache_bytes"]),
("decode_kv_padding_overhead", summary["decode_batch"]["tensor_bytes"]["kv_padding_overhead_bytes"]),
("decode_outputs_next_kv", summary["decode_outputs"]["tensor_bytes"]["next_kv_cache_bytes"]),
("prefill_attn_mask", summary["prefill_batch"]["tensor_bytes"]["prefill_attn_mask_bytes"]),
]
ranking.sort(key=lambda item: item[1], reverse=True)
return [{"name": name, "bytes": int(value), "mb": bytes_to_mb(int(value))} for name, value in ranking]
def synthesize_finished_item(tts: TTS, state: T2SRequestState, semantic_tokens: torch.Tensor) -> Tuple[int, np.ndarray]:
semantic_tokens = semantic_tokens.unsqueeze(0).unsqueeze(0).to(tts.configs.device)
phones = state.phones.unsqueeze(0).to(tts.configs.device)
audio_fragment = tts.synthesize_audio_request_local(
semantic_tokens=semantic_tokens,
phones=phones,
prompt_semantic=state.prompt_semantic,
prompt_phones=state.prompt_phones,
refer_spec=state.refer_spec,
raw_audio=state.raw_audio,
raw_sr=state.raw_sr,
speed=1.0,
sample_steps=32,
)
output_sr = tts.configs.sampling_rate if not tts.configs.use_vocoder else tts.vocoder_configs["sr"]
return tts.audio_postprocess(
audio=[[audio_fragment]],
sr=int(output_sr),
batch_index_list=None,
speed_factor=1.0,
split_bucket=False,
fragment_interval=0.0,
super_sampling=False,
)
def simulate_worker_end_to_end(
tts: TTS,
specs: Sequence[SchedulerRequestSpec],
max_steps: int,
rounds: int,
grad_mode: str = "default",
) -> Dict[str, Any]:
device = tts.configs.device
recorder = GlobalPeakRecorder(device)
recorder.record("after_model_load")
state_map: Dict[str, T2SRequestState] = {}
per_round: List[Dict[str, Any]] = []
for round_index in range(rounds):
grad_context = torch.inference_mode if grad_mode == "inference_mode" else contextlib.nullcontext
with grad_context():
states = [prepare_request_state(tts, spec) for spec in specs]
state_map = {state.request_id: state for state in states}
recorder.record(
"after_prepare_states",
round_index=int(round_index),
request_count=int(len(states)),
grad_mode=grad_mode,
)
pending = list(states)
running_requests: List[T2SRunningRequest] = []
round_events: List[Dict[str, Any]] = []
current_tick = 0
while pending or running_requests:
admitted = pending
pending = []
if admitted:
recorder.record(
"before_prefill",
round_index=int(round_index),
tick=int(current_tick),
admitted_count=int(len(admitted)),
running_count=int(len(running_requests)),
grad_mode=grad_mode,
)
with grad_context():
admitted_running, admitted_finished = run_prefill_step(tts.t2s_model.model, admitted, max_steps=max_steps)
recorder.record(
"after_prefill",
round_index=int(round_index),
tick=int(current_tick),
admitted_running_count=int(len(admitted_running)),
admitted_finished_count=int(len(admitted_finished)),
running_count=int(len(running_requests)),
grad_mode=grad_mode,
)
round_events.append(
{
"tick": int(current_tick),
"event": "prefill",
"admitted_count": int(len(admitted)),
"admitted_running_count": int(len(admitted_running)),
"admitted_finished_count": int(len(admitted_finished)),
}
)
for item in admitted_finished:
recorder.record(
"before_synth_prefill_finished",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
finished_request_id=item.request_id,
semantic_len=int(item.semantic_tokens.shape[0]),
grad_mode=grad_mode,
)
with grad_context():
sample_rate, audio_data = synthesize_finished_item(tts, state_map[item.request_id], item.semantic_tokens)
recorder.record(
"after_synth_prefill_finished",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
finished_request_id=item.request_id,
sample_rate=int(sample_rate),
audio_samples=int(audio_data.shape[0]),
grad_mode=grad_mode,
)
running_requests.extend(admitted_running)
recorder.record(
"after_extend_running",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
grad_mode=grad_mode,
)
if running_requests:
recorder.record(
"before_decode",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
grad_mode=grad_mode,
)
with grad_context():
running_requests, step_finished = run_decode_step_for_running(
tts.t2s_model.model,
running_requests,
max_steps=max_steps,
)
recorder.record(
"after_decode",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
finished_count=int(len(step_finished)),
grad_mode=grad_mode,
)
round_events.append(
{
"tick": int(current_tick),
"event": "decode",
"running_count_after_decode": int(len(running_requests)),
"finished_count": int(len(step_finished)),
}
)
for item in step_finished:
recorder.record(
"before_synth_decode_finished",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
finished_request_id=item.request_id,
semantic_len=int(item.semantic_tokens.shape[0]),
grad_mode=grad_mode,
)
with grad_context():
sample_rate, audio_data = synthesize_finished_item(tts, state_map[item.request_id], item.semantic_tokens)
recorder.record(
"after_synth_decode_finished",
round_index=int(round_index),
tick=int(current_tick),
running_count=int(len(running_requests)),
finished_request_id=item.request_id,
sample_rate=int(sample_rate),
audio_samples=int(audio_data.shape[0]),
grad_mode=grad_mode,
)
current_tick += 1
recorder.record(
"after_round_complete",
round_index=int(round_index),
running_count=0,
grad_mode=grad_mode,
)
per_round.append(
{
"round_index": int(round_index),
"events": round_events,
}
)
return {
"grad_mode": grad_mode,
"rounds": per_round,
"timeline": recorder.summary(),
}
def main() -> None:
args = parse_args()
args.output_dir.mkdir(parents=True, exist_ok=True)
tts = load_pipeline(args.config)
model = tts.t2s_model.model
device = tts.configs.device
use_cuda = str(device).startswith("cuda") and torch.cuda.is_available()
set_seed(args.seed, use_cuda)
specs = load_request_specs(args)
if args.early_stop_num == -1:
for spec in specs:
spec.early_stop_num = int(tts.configs.hz * tts.configs.max_sec)
if args.warmup and specs:
warmup_spec = specs[:1]
_ = [prepare_request_state(tts, spec) for spec in warmup_spec]
gc.collect()
if use_cuda:
torch.cuda.empty_cache()
_sync_device(device)
states, prepare_mem = stage_run(device, lambda: [prepare_request_state(tts, spec) for spec in specs])
request_state_summary = summarise_state_tensors(states)
active_batch, prefill_batch_mem = stage_run(device, lambda: build_prefill_batch(model, states))
prefill_batch_tensor_summary = summarise_prefill_batch(active_batch)
prefill_result, prefill_step_mem = stage_run(device, lambda: run_prefill_step(model, states, max_steps=args.max_steps))
running_requests, finished_items = prefill_result
running_requests_summary = summarise_running_requests(running_requests)
finished_after_prefill_summary = [
{
"request_id": item.request_id,
"finish_idx": int(item.finish_idx),
"finish_reason": item.finish_reason,
"semantic_len": int(item.semantic_tokens.shape[0]),
}
for item in finished_items
]
if not running_requests:
raise RuntimeError(f"prefill 后没有 running requests全部在首步结束: {[item.request_id for item in finished_items]}")
decode_batch_result, decode_batch_mem = stage_run(
device,
lambda: _build_decode_batch_from_running(model, running_requests),
)
xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask = decode_batch_result
decode_batch_tensor_summary = summarise_decode_batch(
xy_pos,
batched_k_cache,
batched_v_cache,
batched_decode_attn_mask,
running_requests,
)
decode_out_result, decode_step_mem = stage_run(
device,
lambda: model.t2s_transformer.decode_next_token(
xy_pos,
batched_k_cache,
batched_v_cache,
batched_decode_attn_mask,
),
)
xy_dec, next_k_cache, next_v_cache = decode_out_result
decode_output_tensor_summary = summarise_decode_outputs(xy_dec, next_k_cache, next_v_cache)
del active_batch
del running_requests
del finished_items
del xy_pos
del batched_k_cache
del batched_v_cache
del batched_decode_attn_mask
del xy_dec
del next_k_cache
del next_v_cache
gc.collect()
if use_cuda:
_sync_device(device)
torch.cuda.empty_cache()
end_to_end_worker = simulate_worker_end_to_end(
tts=tts,
specs=specs,
max_steps=args.max_steps,
rounds=args.worker_rounds,
grad_mode=args.worker_grad_mode,
)
live_cuda_tensors_after_worker = snapshot_live_cuda_tensors()
worker_inference_mode = None
if args.compare_worker_grad_modes:
gc.collect()
if use_cuda:
_sync_device(device)
torch.cuda.empty_cache()
worker_inference_mode = simulate_worker_end_to_end(
tts=tts,
specs=specs,
max_steps=args.max_steps,
rounds=args.worker_rounds,
grad_mode="inference_mode",
)
summary = {
"meta": {
"scenario": args.scenario if args.request_manifest is None else "manifest",
"seed": int(args.seed),
"device": str(device),
"dtype": str(next(model.parameters()).dtype),
"request_count": int(len(specs)),
"num_layers": int(model.num_layers),
"num_heads": int(model.num_head),
"model_dim": int(model.model_dim),
"model_weights_mb": bytes_to_mb(model_nbytes(model)),
},
"loaded_module_weights": build_module_weight_summary(tts),
"requests": [
{
"request_id": spec.request_id,
"ref_audio_path": str(spec.ref_audio_path),
"prompt_text": spec.prompt_text,
"text": spec.text,
}
for spec in specs
],
"prepare_stage": {
"memory": prepare_mem,
"request_state": request_state_summary,
},
"prefill_batch": {
"memory": prefill_batch_mem,
"tensor_bytes": prefill_batch_tensor_summary,
},
"prefill_step": {
"memory": prefill_step_mem,
"running_requests": running_requests_summary,
"finished_after_prefill": finished_after_prefill_summary,
},
"decode_batch": {
"memory": decode_batch_mem,
"tensor_bytes": decode_batch_tensor_summary,
},
"decode_outputs": {
"memory": decode_step_mem,
"tensor_bytes": decode_output_tensor_summary,
},
"end_to_end_worker": end_to_end_worker,
"live_cuda_tensors_after_worker": live_cuda_tensors_after_worker,
"end_to_end_worker_inference_mode": worker_inference_mode,
}
summary["top_rankings"] = top_rankings(summary)
summary_path = args.output_dir / "t2s_memory_breakdown_summary.json"
summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
print(json.dumps(summary["meta"], ensure_ascii=False, indent=2))
print("[top_rankings]")
for item in summary["top_rankings"]:
print(f"- {item['name']}: {item['mb']:.3f} MB")
print("[worker_peak]")
print(
json.dumps(
{
"peak_label": summary["end_to_end_worker"]["timeline"]["peak_label"],
"peak_allocated_mb": summary["end_to_end_worker"]["timeline"]["peak_allocated_mb"],
"peak_reserved_mb": summary["end_to_end_worker"]["timeline"]["peak_reserved_mb"],
},
ensure_ascii=False,
indent=2,
)
)
if worker_inference_mode is not None:
print("[worker_peak_inference_mode]")
print(
json.dumps(
{
"peak_label": worker_inference_mode["timeline"]["peak_label"],
"peak_allocated_mb": worker_inference_mode["timeline"]["peak_allocated_mb"],
"peak_reserved_mb": worker_inference_mode["timeline"]["peak_reserved_mb"],
},
ensure_ascii=False,
indent=2,
)
)
print(f"[summary] {summary_path}")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,180 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import argparse
import json
import random
import sys
from pathlib import Path
from typing import Any, Dict, List
import numpy as np
import torch
ROOT_DIR = Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
sys.path.append(str(ROOT_DIR))
gpt_sovits_dir = ROOT_DIR / "GPT_SoVITS"
if str(gpt_sovits_dir) not in sys.path:
sys.path.append(str(gpt_sovits_dir))
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( # noqa: E402
SchedulerRequestSpec,
T2SFinishedItem,
T2SRequestState,
prepare_request_state,
run_scheduler_continuous,
)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="T2S request-local scheduler prototype.")
parser.add_argument("--config", type=Path, default=ROOT_DIR / "GPT_SoVITS/configs/tts_infer.yaml")
parser.add_argument("--request-manifest", type=Path, default=None)
parser.add_argument("--ref-audio", type=Path, default=ROOT_DIR / "test.wav")
parser.add_argument("--prompt-text", type=str, default="是啊,主要是因为有调研需求的学者少了。")
parser.add_argument("--prompt-lang", type=str, default="zh")
parser.add_argument("--text-file", type=Path, default=ROOT_DIR / "test_en.txt")
parser.add_argument("--text", type=str, default=None)
parser.add_argument("--text-lang", type=str, default="en")
parser.add_argument("--top-k", type=int, default=15)
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--repetition-penalty", type=float, default=1.35)
parser.add_argument("--early-stop-num", type=int, default=-1)
parser.add_argument("--max-steps", type=int, default=1500)
parser.add_argument("--seed", type=int, default=1234)
parser.add_argument("--output-dir", type=Path, default=ROOT_DIR / "TEMP/t2s_scheduler/output_run")
return parser.parse_args()
def set_seed(seed: int, use_cuda: bool) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if use_cuda and torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def load_pipeline(config_path: Path):
try:
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"缺少运行依赖,请先在 GPT-SoVITS 推理环境中安装 requirements 后再运行该脚本。"
) from exc
tts_config = TTS_Config(str(config_path))
print(tts_config)
return TTS(tts_config)
def load_request_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
if args.request_manifest is not None:
payload = json.loads(args.request_manifest.read_text(encoding="utf-8"))
raw_requests = payload["requests"] if isinstance(payload, dict) else payload
specs: List[SchedulerRequestSpec] = []
for index, item in enumerate(raw_requests):
text = item.get("text")
text_file = item.get("text_file")
if text is None and text_file is None:
raise ValueError(f"request[{index}] must provide text or text_file")
if text is None:
text = Path(text_file).read_text(encoding="utf-8")
specs.append(
SchedulerRequestSpec(
request_id=item.get("request_id", f"req_{index:03d}"),
ref_audio_path=Path(item["ref_audio_path"]),
prompt_text=item["prompt_text"],
prompt_lang=item.get("prompt_lang", "zh"),
text=text,
text_lang=item.get("text_lang", "zh"),
top_k=int(item.get("top_k", args.top_k)),
top_p=float(item.get("top_p", args.top_p)),
temperature=float(item.get("temperature", args.temperature)),
repetition_penalty=float(item.get("repetition_penalty", args.repetition_penalty)),
early_stop_num=int(item.get("early_stop_num", args.early_stop_num)),
ready_step=int(item.get("ready_step", 0)),
)
)
return specs
text = args.text if args.text is not None else args.text_file.read_text(encoding="utf-8")
return [
SchedulerRequestSpec(
request_id="req_000",
ref_audio_path=args.ref_audio,
prompt_text=args.prompt_text,
prompt_lang=args.prompt_lang,
text=text,
text_lang=args.text_lang,
top_k=args.top_k,
top_p=args.top_p,
temperature=args.temperature,
repetition_penalty=args.repetition_penalty,
early_stop_num=args.early_stop_num,
ready_step=0,
)
]
def summarise_requests(states: List[T2SRequestState]) -> List[Dict[str, Any]]:
return [
{
"request_id": state.request_id,
"ready_step": int(state.ready_step),
"ref_audio_path": str(state.ref_audio_path),
"prompt_semantic_len": int(state.prompt_semantic.shape[0]),
"all_phone_len": int(state.all_phones.shape[0]),
"bert_len": int(state.all_bert_features.shape[-1]),
"norm_text": state.norm_text,
}
for state in states
]
def summarise_finished(items: List[T2SFinishedItem]) -> List[Dict[str, Any]]:
return [
{
"request_id": item.request_id,
"semantic_len": int(item.semantic_tokens.shape[0]),
"finish_idx": int(item.finish_idx),
"finish_reason": item.finish_reason,
}
for item in items
]
def main() -> None:
args = parse_args()
args.output_dir.mkdir(parents=True, exist_ok=True)
tts = load_pipeline(args.config)
model = tts.t2s_model.model
use_cuda = str(tts.configs.device).startswith("cuda")
set_seed(args.seed, use_cuda)
request_specs = load_request_specs(args)
states = [prepare_request_state(tts, spec) for spec in request_specs]
finished = run_scheduler_continuous(model, states, max_steps=args.max_steps)
summary = {
"request_count": len(states),
"max_steps": args.max_steps,
"requests": summarise_requests(states),
"finished": summarise_finished(finished),
}
output_path = args.output_dir / "scheduler_prototype_summary.json"
output_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
print(json.dumps(summary, ensure_ascii=False, indent=2))
print(f"[saved] {output_path}")
if __name__ == "__main__":
try:
main()
except ModuleNotFoundError as exc:
print(f"[error] {exc}")
raise SystemExit(1) from None