mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-06-08 07:38:18 +08:00
Merge 845b181360b5c5f4a5ed827ca01cfd5dfd8f58b1 into 2d9193b0d3c0eae0c3a14d8c68a839f1bae157dc
This commit is contained in:
commit
1c07b3339e
@ -351,6 +351,13 @@ class Text2SemanticDecoder(nn.Module):
|
||||
blocks.append(block)
|
||||
|
||||
self.t2s_transformer = T2STransformer(self.num_layers, blocks)
|
||||
self.last_infer_stats = {}
|
||||
|
||||
def _set_last_infer_stats(self, stats):
|
||||
self.last_infer_stats = stats
|
||||
|
||||
def get_last_infer_stats(self):
    """Return an independent snapshot of the last recorded inference stats.

    Returns:
        A new dict. Nested containers (for example the per-item token count
        list) are copied as well, so the caller may mutate the result freely
        without affecting the stored record.
    """
    import copy

    # A shallow dict() copy would still share nested lists with the stored stats.
    return copy.deepcopy(dict(self.last_infer_stats))
|
||||
|
||||
def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
|
||||
x = self.ar_text_embedding(x)
|
||||
@ -593,7 +600,19 @@ class Text2SemanticDecoder(nn.Module):
|
||||
repetition_penalty: float = 1.35,
|
||||
**kwargs,
|
||||
):
|
||||
requested_enable_mask_free_fastpath = bool(kwargs.get("enable_mask_free_fastpath", True))
|
||||
if prompts is None:
|
||||
self._set_last_infer_stats(
|
||||
{
|
||||
"infer_mode": "batch_infer_prompt_free_fallback",
|
||||
"requested_enable_mask_free_fastpath": requested_enable_mask_free_fastpath,
|
||||
"batch_size": int(len(x)),
|
||||
"prefill_after_mask_all_visible": None,
|
||||
"fastpath_hit": False,
|
||||
"generated_token_count": 0,
|
||||
"generated_token_count_list": [],
|
||||
}
|
||||
)
|
||||
print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
|
||||
return self.infer_panel_naive_batched(
|
||||
x,
|
||||
@ -608,6 +627,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
)
|
||||
|
||||
max_len = kwargs.get("max_len", x_lens.max())
|
||||
enable_mask_free_fastpath = requested_enable_mask_free_fastpath
|
||||
x_list = []
|
||||
for x_item, bert_item in zip(x, bert_feature):
|
||||
# max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
|
||||
@ -698,17 +718,30 @@ class Text2SemanticDecoder(nn.Module):
|
||||
y_list = [None] * y.shape[0]
|
||||
batch_idx_map = list(range(y.shape[0]))
|
||||
idx_list = [None] * y.shape[0]
|
||||
decode_attn_mask = attn_mask
|
||||
prefill_after_mask_all_visible = None
|
||||
fastpath_hit = False
|
||||
for idx in tqdm(range(1500)):
|
||||
if idx == 0:
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
|
||||
else:
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(
|
||||
xy_pos, k_cache, v_cache, decode_attn_mask
|
||||
)
|
||||
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||
|
||||
if idx == 0:
|
||||
attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
|
||||
prefill_after_mask_all_visible = not attn_mask.any().item()
|
||||
if enable_mask_free_fastpath and y.shape[0] == 1 and prefill_after_mask_all_visible:
|
||||
decode_attn_mask = None
|
||||
fastpath_hit = True
|
||||
else:
|
||||
decode_attn_mask = attn_mask
|
||||
else:
|
||||
attn_mask = F.pad(attn_mask, (0, 1), value=False)
|
||||
if decode_attn_mask is not None:
|
||||
attn_mask = F.pad(attn_mask, (0, 1), value=False)
|
||||
decode_attn_mask = attn_mask
|
||||
|
||||
if idx < 11: ###至少预测出10个token不然不给停止(0.4s)
|
||||
logits = logits[:, :-1]
|
||||
@ -740,7 +773,9 @@ class Text2SemanticDecoder(nn.Module):
|
||||
if reserved_idx_of_batch_for_y is not None:
|
||||
# index = torch.LongTensor(batch_idx_map).to(y.device)
|
||||
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
|
||||
attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
|
||||
if decode_attn_mask is not None:
|
||||
attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
|
||||
decode_attn_mask = attn_mask
|
||||
if k_cache is not None:
|
||||
for i in range(len(k_cache)):
|
||||
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
|
||||
@ -775,6 +810,18 @@ class Text2SemanticDecoder(nn.Module):
|
||||
if idx_list[i] is None:
|
||||
idx_list[i] = 1500 - 1 ###如果没有生成到EOS,就用最大长度代替
|
||||
|
||||
self._set_last_infer_stats(
|
||||
{
|
||||
"infer_mode": "batch_infer",
|
||||
"requested_enable_mask_free_fastpath": enable_mask_free_fastpath,
|
||||
"batch_size": int(len(x)),
|
||||
"prefill_after_mask_all_visible": prefill_after_mask_all_visible,
|
||||
"fastpath_hit": fastpath_hit,
|
||||
"generated_token_count": int(sum(idx_list)),
|
||||
"generated_token_count_list": [int(item) for item in idx_list],
|
||||
"max_len": int(max_len),
|
||||
}
|
||||
)
|
||||
if ref_free:
|
||||
return y_list, [0] * x.shape[0]
|
||||
# print(idx_list)
|
||||
@ -811,6 +858,17 @@ class Text2SemanticDecoder(nn.Module):
|
||||
y_list.append(y[0])
|
||||
idx_list.append(idx)
|
||||
|
||||
self._set_last_infer_stats(
|
||||
{
|
||||
"infer_mode": "naive_batched",
|
||||
"requested_enable_mask_free_fastpath": bool(kwargs.get("enable_mask_free_fastpath", True)),
|
||||
"batch_size": int(len(x)),
|
||||
"prefill_after_mask_all_visible": None,
|
||||
"fastpath_hit": False,
|
||||
"generated_token_count": int(sum(idx_list)),
|
||||
"generated_token_count_list": [int(item) for item in idx_list],
|
||||
}
|
||||
)
|
||||
return y_list, idx_list
|
||||
|
||||
def infer_panel_naive(
|
||||
@ -957,6 +1015,18 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
|
||||
if not streaming_mode:
|
||||
generated_token_count = max(int(y.shape[1] - prefix_len), 0)
|
||||
self._set_last_infer_stats(
|
||||
{
|
||||
"infer_mode": "naive",
|
||||
"requested_enable_mask_free_fastpath": bool(kwargs.get("enable_mask_free_fastpath", True)),
|
||||
"batch_size": int(x.shape[0]),
|
||||
"prefill_after_mask_all_visible": True if prompts is not None else None,
|
||||
"fastpath_hit": True if prompts is not None else False,
|
||||
"generated_token_count": generated_token_count,
|
||||
"generated_token_count_list": [generated_token_count],
|
||||
}
|
||||
)
|
||||
if ref_free:
|
||||
yield y, 0
|
||||
yield y, idx
|
||||
|
||||
@ -5,6 +5,7 @@ import random
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from copy import deepcopy
|
||||
|
||||
import torchaudio
|
||||
@ -33,7 +34,12 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
from tools.audio_sr import AP_BWE
|
||||
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||
from TTS_infer_pack.text_segmentation_method import splits
|
||||
from TTS_infer_pack.TextPreprocessor import TextPreprocessor
|
||||
from TTS_infer_pack.TextPreprocessor import TextPreprocessor, StageLimiter
|
||||
from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker
|
||||
from TTS_infer_pack.prepare_ref_semantic_batch_worker import (
|
||||
PrepareRefSemanticBatchWorker,
|
||||
prepare_prompt_semantic_wav16k,
|
||||
)
|
||||
from sv import SV
|
||||
|
||||
resample_transform_dict = {}
|
||||
@ -442,11 +448,56 @@ class TTS:
|
||||
"upsample_rate": None,
|
||||
"overlapped_len": None,
|
||||
}
|
||||
self.prepare_bert_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_BERT_SLOTS", "1")))
|
||||
self.prepare_ref_audio_stage_limiter = StageLimiter(int(os.environ.get("GPTSOVITS_PREPARE_REF_SLOTS", "2")))
|
||||
self.prepare_bert_batch_worker = None
|
||||
self.prepare_ref_semantic_batch_worker = None
|
||||
default_text_cpu_workers = 16
|
||||
self.prepare_text_cpu_workers = max(
|
||||
0,
|
||||
int(os.environ.get("GPTSOVITS_PREPARE_TEXT_CPU_WORKERS", str(default_text_cpu_workers))),
|
||||
)
|
||||
self.prepare_text_cpu_executor = None
|
||||
if self.prepare_text_cpu_workers > 0:
|
||||
self.prepare_text_cpu_executor = ThreadPoolExecutor(
|
||||
max_workers=self.prepare_text_cpu_workers,
|
||||
thread_name_prefix="prepare-text-cpu",
|
||||
)
|
||||
|
||||
self._init_models()
|
||||
|
||||
if os.environ.get("GPTSOVITS_PREPARE_BERT_BATCHING", "1") != "0":
|
||||
self.prepare_bert_batch_worker = PrepareBertBatchWorker(
|
||||
bert_model=self.bert_model,
|
||||
tokenizer=self.bert_tokenizer,
|
||||
device=self.configs.device,
|
||||
stage_limiter=self.prepare_bert_stage_limiter,
|
||||
batch_window_ms=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_WINDOW_MS", "5")),
|
||||
max_batch_items=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_ITEMS", "16")),
|
||||
max_batch_tokens=int(os.environ.get("GPTSOVITS_PREPARE_BERT_BATCH_MAX_TOKENS", "4096")),
|
||||
)
|
||||
if os.environ.get("GPTSOVITS_PREPARE_REF_BATCHING", "0") != "0":
|
||||
ref_max_batch_samples = os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_SAMPLES")
|
||||
if ref_max_batch_samples is None:
|
||||
ref_max_batch_samples = os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_FRAMES", "960000")
|
||||
self.prepare_ref_semantic_batch_worker = PrepareRefSemanticBatchWorker(
|
||||
ssl_model=self.cnhuhbert_model,
|
||||
vits_model=self.vits_model,
|
||||
device=self.configs.device,
|
||||
is_half=self.configs.is_half,
|
||||
zero_wav_samples=int(self.configs.sampling_rate * 0.3),
|
||||
stage_limiter=self.prepare_ref_audio_stage_limiter,
|
||||
batch_window_ms=int(os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_WINDOW_MS", "5")),
|
||||
max_batch_items=int(os.environ.get("GPTSOVITS_PREPARE_REF_BATCH_MAX_ITEMS", "8")),
|
||||
max_batch_samples=int(ref_max_batch_samples),
|
||||
)
|
||||
|
||||
self.text_preprocessor: TextPreprocessor = TextPreprocessor(
|
||||
self.bert_model, self.bert_tokenizer, self.configs.device
|
||||
self.bert_model,
|
||||
self.bert_tokenizer,
|
||||
self.configs.device,
|
||||
bert_stage_limiter=self.prepare_bert_stage_limiter,
|
||||
bert_batch_worker=self.prepare_bert_batch_worker,
|
||||
)
|
||||
|
||||
self.prompt_cache: dict = {
|
||||
@ -755,33 +806,52 @@ class TTS:
|
||||
Args:
|
||||
ref_audio_path: str, the path of the reference audio.
|
||||
"""
|
||||
self._set_prompt_semantic(ref_audio_path)
|
||||
self._set_ref_spec(ref_audio_path)
|
||||
bundle = self.extract_ref_audio_bundle(ref_audio_path)
|
||||
if self.prompt_cache["refer_spec"] in [[], None]:
|
||||
self.prompt_cache["refer_spec"] = [bundle["refer_spec"]]
|
||||
else:
|
||||
self.prompt_cache["refer_spec"][0] = bundle["refer_spec"]
|
||||
self.prompt_cache["prompt_semantic"] = bundle["prompt_semantic"]
|
||||
self.prompt_cache["raw_audio"] = bundle["raw_audio"]
|
||||
self.prompt_cache["raw_sr"] = bundle["raw_sr"]
|
||||
self._set_ref_audio_path(ref_audio_path)
|
||||
|
||||
def _set_ref_audio_path(self, ref_audio_path):
|
||||
self.prompt_cache["ref_audio_path"] = ref_audio_path
|
||||
|
||||
def _set_ref_spec(self, ref_audio_path):
|
||||
spec_audio = self._get_ref_spec(ref_audio_path)
|
||||
if self.prompt_cache["refer_spec"] in [[], None]:
|
||||
self.prompt_cache["refer_spec"] = [spec_audio]
|
||||
else:
|
||||
self.prompt_cache["refer_spec"][0] = spec_audio
|
||||
|
||||
def _get_ref_spec(self, ref_audio_path):
|
||||
def _load_ref_audio_raw(self, ref_audio_path: str):
|
||||
raw_audio, raw_sr = torchaudio.load(ref_audio_path)
|
||||
raw_audio = raw_audio.to(self.configs.device).float()
|
||||
self.prompt_cache["raw_audio"] = raw_audio
|
||||
self.prompt_cache["raw_sr"] = raw_sr
|
||||
return raw_audio.float(), int(raw_sr)
|
||||
|
||||
@torch.inference_mode()
|
||||
def _extract_prompt_semantic_from_prepared_wav16k(self, wav16k: torch.Tensor):
|
||||
wav16k = wav16k.to(self.configs.device)
|
||||
if self.configs.is_half:
|
||||
wav16k = wav16k.half()
|
||||
hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)
|
||||
codes = self.vits_model.extract_latent(hubert_feature)
|
||||
return codes[0, 0].to(self.configs.device)
|
||||
|
||||
@torch.inference_mode()
def _extract_prompt_semantic_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
    """Extract prompt semantic codes from raw reference audio.

    Args:
        raw_audio: Reference audio as loaded from disk.
        raw_sr: Sample rate of ``raw_audio``.

    Returns:
        Whatever ``_extract_prompt_semantic_from_prepared_wav16k`` returns
        for the prepared waveform.
    """
    # 0.3 s worth of samples at the model sampling rate is passed on as the
    # amount of zero padding the preparation helper should use.
    zero_wav_samples = int(self.configs.sampling_rate * 0.3)
    wav16k = prepare_prompt_semantic_wav16k(
        raw_audio=raw_audio,
        raw_sr=raw_sr,
        zero_wav_samples=zero_wav_samples,
    )
    return self._extract_prompt_semantic_from_prepared_wav16k(wav16k)
|
||||
|
||||
def extract_prompt_semantic(self, ref_wav_path: str):
    """Load a reference audio file and return its prompt semantic codes.

    Args:
        ref_wav_path: Path of the reference audio file.

    Returns:
        The result of ``_extract_prompt_semantic_from_raw`` for the loaded
        audio.
    """
    audio, sample_rate = self._load_ref_audio_raw(ref_wav_path)
    return self._extract_prompt_semantic_from_raw(audio, sample_rate)
|
||||
|
||||
def _extract_ref_spec_from_raw(self, raw_audio: torch.Tensor, raw_sr: int):
|
||||
raw_audio_device = raw_audio.to(self.configs.device).float()
|
||||
|
||||
if raw_sr != self.configs.sampling_rate:
|
||||
audio = raw_audio.to(self.configs.device)
|
||||
audio = raw_audio_device
|
||||
if audio.shape[0] == 2:
|
||||
audio = audio.mean(0).unsqueeze(0)
|
||||
audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device)
|
||||
else:
|
||||
audio = raw_audio.to(self.configs.device)
|
||||
audio = raw_audio_device
|
||||
if audio.shape[0] == 2:
|
||||
audio = audio.mean(0).unsqueeze(0)
|
||||
|
||||
@ -804,33 +874,163 @@ class TTS:
|
||||
audio = audio.half()
|
||||
else:
|
||||
audio = None
|
||||
return spec, audio, raw_audio, raw_sr
|
||||
|
||||
def extract_ref_spec(self, ref_audio_path: str):
    """Load a reference audio file and compute its reference spectrogram data.

    Args:
        ref_audio_path: Path of the reference audio file.

    Returns:
        The result of ``_extract_ref_spec_from_raw`` for the loaded audio.
    """
    audio, sample_rate = self._load_ref_audio_raw(ref_audio_path)
    return self._extract_ref_spec_from_raw(audio, sample_rate)
|
||||
|
||||
def extract_ref_audio_bundle(self, ref_audio_path: str):
    """Load a reference clip once and derive all prompt artifacts from it.

    Two execution paths exist:

    * No batch worker configured: prompt-semantic extraction and reference
      spectrogram extraction both run inline under one stage-limiter slot.
    * Batch worker configured: prompt-semantic extraction is delegated to
      ``prepare_ref_semantic_batch_worker.submit`` and only the reference
      spectrogram extraction takes a stage-limiter slot.

    Args:
        ref_audio_path: Path of the reference audio file.

    Returns:
        A dict with ``prompt_semantic``, ``refer_spec`` (first two items of
        ``_extract_ref_spec_from_raw``), ``raw_audio``, ``raw_sr`` and a
        ``profile`` dict of timing / concurrency figures (milliseconds and
        counts, all floats).
    """
    load_start = time.perf_counter()
    raw_audio, raw_sr = self._load_ref_audio_raw(ref_audio_path)
    load_ms = (time.perf_counter() - load_start) * 1000.0

    # Fallback values for the prompt-semantic part of the profile; the dict
    # order is also the order in which these keys appear in the result.
    semantic_profile = {
        "prompt_semantic_wait_ms": 0.0,
        "prompt_semantic_cpu_prepare_ms": 0.0,
        "prompt_semantic_forward_ms": 0.0,
        "prompt_semantic_scatter_ms": 0.0,
        "prompt_semantic_stage_slots": 0.0,
        "prompt_semantic_stage_inflight_peak": 0.0,
        "prompt_semantic_batch_size": 1.0,
        "prompt_semantic_batch_samples": 0.0,
    }
    semantic_keys = tuple(semantic_profile)

    if self.prepare_ref_semantic_batch_worker is None:
        # Inline path: both extraction steps share one limiter slot.
        with self.prepare_ref_audio_stage_limiter.enter() as limiter_stats:
            prompt_semantic_start = time.perf_counter()
            prompt_semantic = self._extract_prompt_semantic_from_raw(raw_audio, raw_sr)
            prompt_semantic_ms = (time.perf_counter() - prompt_semantic_start) * 1000.0
            ref_spec_start = time.perf_counter()
            refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2]
            ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0
        audio_stage_wait_ms = float(limiter_stats["wait_ms"])
        audio_stage_slots = float(limiter_stats["slots"])
        audio_stage_inflight_peak = float(limiter_stats["peak_inflight"])
        semantic_profile.update(
            {
                "prompt_semantic_wait_ms": audio_stage_wait_ms,
                "prompt_semantic_forward_ms": prompt_semantic_ms,
                "prompt_semantic_stage_slots": audio_stage_slots,
                "prompt_semantic_stage_inflight_peak": audio_stage_inflight_peak,
            }
        )
        # The wait was already attributed to the shared slot above.
        ref_spec_wait_ms = 0.0
    else:
        # Batched path: the worker reports its own profile figures.
        prompt_semantic, worker_profile = self.prepare_ref_semantic_batch_worker.submit(raw_audio, raw_sr)
        semantic_profile.update(worker_profile)
        prompt_semantic_ms = (
            float(semantic_profile.get("prompt_semantic_cpu_prepare_ms", 0.0))
            + float(semantic_profile.get("prompt_semantic_forward_ms", 0.0))
            + float(semantic_profile.get("prompt_semantic_scatter_ms", 0.0))
        )
        with self.prepare_ref_audio_stage_limiter.enter() as ref_spec_limiter_stats:
            ref_spec_start = time.perf_counter()
            refer_spec = self._extract_ref_spec_from_raw(raw_audio, raw_sr)[:2]
            ref_spec_ms = (time.perf_counter() - ref_spec_start) * 1000.0
        ref_spec_wait_ms = float(ref_spec_limiter_stats["wait_ms"])
        audio_stage_wait_ms = float(semantic_profile.get("prompt_semantic_wait_ms", 0.0)) + ref_spec_wait_ms
        audio_stage_slots = max(
            float(semantic_profile.get("prompt_semantic_stage_slots", 0.0)),
            float(ref_spec_limiter_stats["slots"]),
        )
        audio_stage_inflight_peak = max(
            float(semantic_profile.get("prompt_semantic_stage_inflight_peak", 0.0)),
            float(ref_spec_limiter_stats["peak_inflight"]),
        )

    profile = {
        "audio_load_ms": load_ms,
        "audio_stage_wait_ms": audio_stage_wait_ms,
        "audio_stage_slots": audio_stage_slots,
        "audio_stage_inflight_peak": audio_stage_inflight_peak,
        "prompt_semantic_ms": prompt_semantic_ms,
    }
    # Only the known prompt-semantic keys are exported; extra keys a worker
    # may have reported are intentionally left out.
    for key in semantic_keys:
        profile[key] = float(semantic_profile[key])
    profile["ref_spec_wait_ms"] = ref_spec_wait_ms
    profile["ref_spec_ms"] = ref_spec_ms
    profile["bundle_total_ms"] = load_ms + audio_stage_wait_ms + prompt_semantic_ms + ref_spec_ms

    return {
        "prompt_semantic": prompt_semantic,
        "refer_spec": refer_spec,
        "raw_audio": raw_audio,
        "raw_sr": raw_sr,
        "profile": profile,
    }
|
||||
|
||||
def extract_text_features(self, text: str, language: str, profile: dict | None = None):
|
||||
return self.text_preprocessor.segment_and_extract_feature_for_text(
|
||||
text, language, self.configs.version, profile=profile
|
||||
)
|
||||
|
||||
def _set_ref_audio_path(self, ref_audio_path):
|
||||
self.prompt_cache["ref_audio_path"] = ref_audio_path
|
||||
|
||||
def _set_ref_spec(self, ref_audio_path):
|
||||
spec_audio = self._get_ref_spec(ref_audio_path)
|
||||
if self.prompt_cache["refer_spec"] in [[], None]:
|
||||
self.prompt_cache["refer_spec"] = [spec_audio]
|
||||
else:
|
||||
self.prompt_cache["refer_spec"][0] = spec_audio
|
||||
|
||||
def _get_ref_spec(self, ref_audio_path):
|
||||
spec, audio, raw_audio, raw_sr = self.extract_ref_spec(ref_audio_path)
|
||||
self.prompt_cache["raw_audio"] = raw_audio
|
||||
self.prompt_cache["raw_sr"] = raw_sr
|
||||
return spec, audio
|
||||
|
||||
def _set_prompt_semantic(self, ref_wav_path: str):
|
||||
zero_wav = np.zeros(
|
||||
int(self.configs.sampling_rate * 0.3),
|
||||
dtype=np.float16 if self.configs.is_half else np.float32,
|
||||
)
|
||||
with torch.no_grad():
|
||||
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
|
||||
if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
|
||||
raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
|
||||
wav16k = torch.from_numpy(wav16k)
|
||||
zero_wav_torch = torch.from_numpy(zero_wav)
|
||||
wav16k = wav16k.to(self.configs.device)
|
||||
zero_wav_torch = zero_wav_torch.to(self.configs.device)
|
||||
if self.configs.is_half:
|
||||
wav16k = wav16k.half()
|
||||
zero_wav_torch = zero_wav_torch.half()
|
||||
|
||||
wav16k = torch.cat([wav16k, zero_wav_torch])
|
||||
hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(
|
||||
1, 2
|
||||
) # .float()
|
||||
codes = self.vits_model.extract_latent(hubert_feature)
|
||||
|
||||
prompt_semantic = codes[0, 0].to(self.configs.device)
|
||||
self.prompt_cache["prompt_semantic"] = prompt_semantic
|
||||
prompt_semantic = self.extract_prompt_semantic(ref_wav_path)
|
||||
self.prompt_cache["prompt_semantic"] = prompt_semantic
|
||||
|
||||
def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length: int = None):
|
||||
seq = sequences[0]
|
||||
@ -1227,6 +1427,9 @@ class TTS:
|
||||
###### inference ######
|
||||
t_34 = 0.0
|
||||
t_45 = 0.0
|
||||
t2s_observe_batch_count = 0
|
||||
t2s_observe_fastpath_hits = 0
|
||||
t2s_observe_generated_tokens = 0
|
||||
audio = []
|
||||
is_first_package = True
|
||||
output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"]
|
||||
@ -1280,6 +1483,29 @@ class TTS:
|
||||
)
|
||||
t4 = time.perf_counter()
|
||||
t_34 += t4 - t3
|
||||
if hasattr(self.t2s_model.model, "get_last_infer_stats"):
|
||||
t2s_stats = self.t2s_model.model.get_last_infer_stats()
|
||||
if t2s_stats:
|
||||
generated_token_count = int(t2s_stats.get("generated_token_count", 0))
|
||||
t2s_total_ms = (t4 - t3) * 1000.0
|
||||
avg_decode_ms_per_token = (
|
||||
t2s_total_ms / generated_token_count if generated_token_count > 0 else 0.0
|
||||
)
|
||||
t2s_observe_batch_count += 1
|
||||
t2s_observe_generated_tokens += generated_token_count
|
||||
if bool(t2s_stats.get("fastpath_hit", False)):
|
||||
t2s_observe_fastpath_hits += 1
|
||||
print(
|
||||
"[t2s_observe] "
|
||||
f"mode={t2s_stats.get('infer_mode')} "
|
||||
f"batch_size={t2s_stats.get('batch_size')} "
|
||||
f"tokens={generated_token_count} "
|
||||
f"t2s_ms={t2s_total_ms:.3f} "
|
||||
f"avg_decode_ms_per_token={avg_decode_ms_per_token:.3f} "
|
||||
f"requested_fastpath={t2s_stats.get('requested_enable_mask_free_fastpath')} "
|
||||
f"prefill_all_visible={t2s_stats.get('prefill_after_mask_all_visible')} "
|
||||
f"fastpath_hit={t2s_stats.get('fastpath_hit')}"
|
||||
)
|
||||
|
||||
|
||||
batch_audio_fragment = []
|
||||
@ -1500,6 +1726,18 @@ class TTS:
|
||||
|
||||
if not (return_fragment or streaming_mode):
|
||||
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
|
||||
if t2s_observe_batch_count > 0:
|
||||
request_avg_decode_ms_per_token = (
|
||||
(t_34 * 1000.0) / t2s_observe_generated_tokens if t2s_observe_generated_tokens > 0 else 0.0
|
||||
)
|
||||
print(
|
||||
"[t2s_request_observe] "
|
||||
f"batches={t2s_observe_batch_count} "
|
||||
f"fastpath_hits={t2s_observe_fastpath_hits} "
|
||||
f"generated_tokens={t2s_observe_generated_tokens} "
|
||||
f"t2s_total_ms={t_34 * 1000.0:.3f} "
|
||||
f"avg_decode_ms_per_token={request_avg_decode_ms_per_token:.3f}"
|
||||
)
|
||||
if len(audio) == 0:
|
||||
yield output_sr, np.zeros(int(output_sr), dtype=np.int16)
|
||||
return
|
||||
@ -1663,6 +1901,116 @@ class TTS:
|
||||
|
||||
return audio
|
||||
|
||||
def using_vocoder_synthesis_request_local(
    self,
    semantic_tokens: torch.Tensor,
    phones: torch.Tensor,
    prompt_semantic: torch.Tensor,
    prompt_phones: torch.Tensor,
    refer_audio_spec: torch.Tensor,
    raw_audio: torch.Tensor,
    raw_sr: int,
    speed: float = 1.0,
    sample_steps: int = 32,
):
    """Synthesize audio through the CFM + vocoder path using per-request prompt data.

    All prompt inputs are passed in explicitly instead of being read from
    ``self.prompt_cache``.

    Args:
        semantic_tokens: Semantic tokens of the text to synthesize.
        phones: Phoneme ids of the text to synthesize.
        prompt_semantic: Prompt semantic codes (two leading axes are added).
        prompt_phones: Prompt phoneme ids (one leading axis is added).
        refer_audio_spec: Reference spectrogram.
        raw_audio: Raw reference audio; a 2-channel input is averaged to mono.
        raw_sr: Sample rate of ``raw_audio``.
        speed: Speed factor forwarded to ``decode_encp``.
        sample_steps: Number of CFM sampling steps.

    Returns:
        ``wav_gen[0][0]`` of the vocoder output.

    Raises:
        ValueError: If the configured ``T_chunk`` leaves no room for new
            frames after the reference context.
    """
    device = self.configs.device
    prompt_semantic_tokens = prompt_semantic.unsqueeze(0).unsqueeze(0).to(device)
    prompt_phones = prompt_phones.unsqueeze(0).to(device)
    refer_audio_spec = refer_audio_spec.to(dtype=self.precision, device=device)

    fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
    ref_audio = raw_audio.to(device).float()
    if ref_audio.shape[0] == 2:
        ref_audio = ref_audio.mean(0).unsqueeze(0)

    # v3 uses a 24 kHz mel front end, every other version handled here 32 kHz.
    is_v3 = self.configs.version == "v3"
    tgt_sr = 24000 if is_v3 else 32000
    if raw_sr != tgt_sr:
        ref_audio = resample(ref_audio, raw_sr, tgt_sr, device)

    mel_spec_fn = mel_fn if is_v3 else mel_fn_v4
    mel2 = norm_spec(mel_spec_fn(ref_audio))

    # Align the reference mel and reference features on a common length,
    # keeping at most the last T_ref frames.
    T_min = min(mel2.shape[2], fea_ref.shape[2])
    mel2 = mel2[:, :, :T_min]
    fea_ref = fea_ref[:, :, :T_min]
    T_ref = self.vocoder_configs["T_ref"]
    T_chunk = self.vocoder_configs["T_chunk"]
    if T_min > T_ref:
        mel2 = mel2[:, :, -T_ref:]
        fea_ref = fea_ref[:, :, -T_ref:]
        T_min = T_ref
    chunk_len = T_chunk - T_min
    if chunk_len <= 0:
        # A non-positive stride would yield empty or nonsensical slices below.
        raise ValueError(f"T_chunk ({T_chunk}) must be larger than the reference length ({T_min})")

    mel2 = mel2.to(self.precision)
    fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)

    cfm_chunks = []
    for start in range(0, fea_todo.shape[-1], chunk_len):
        fea_todo_chunk = fea_todo[:, :, start : start + chunk_len]
        fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)

        cfm_res = self.vits_model.cfm.inference(
            fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0
        )
        # Drop the frames that correspond to the reference context.
        cfm_res = cfm_res[:, :, mel2.shape[2] :]

        # The tail of this chunk becomes the context for the next one.
        mel2 = cfm_res[:, :, -T_min:]
        fea_ref = fea_todo_chunk[:, :, -T_min:]

        cfm_chunks.append(cfm_res)

    cfm_res = denorm_spec(torch.cat(cfm_chunks, 2))

    with torch.inference_mode():
        wav_gen = self.vocoder(cfm_res)
        audio = wav_gen[0][0]

    return audio
|
||||
|
||||
@torch.inference_mode()
def synthesize_audio_request_local(
    self,
    semantic_tokens: torch.Tensor,
    phones: torch.Tensor,
    prompt_semantic: torch.Tensor,
    prompt_phones: torch.Tensor,
    refer_spec: tuple,
    raw_audio: torch.Tensor,
    raw_sr: int,
    speed: float = 1.0,
    sample_steps: int = 32,
):
    """Synthesize one audio fragment from per-request prompt data.

    Dispatches to the vocoder path when ``configs.use_vocoder`` is set and
    otherwise decodes directly with ``vits_model.decode``.

    Args:
        semantic_tokens: Semantic tokens of the text to synthesize.
        phones: Phoneme ids of the text to synthesize.
        prompt_semantic: Prompt semantic codes (vocoder path only).
        prompt_phones: Prompt phoneme ids (vocoder path only).
        refer_spec: Pair ``(reference spectrogram, audio tensor or None)``.
        raw_audio: Raw reference audio (vocoder path only).
        raw_sr: Sample rate of ``raw_audio`` (vocoder path only).
        speed: Speed factor.
        sample_steps: CFM sampling steps (vocoder path only).

    Returns:
        The synthesized waveform tensor.

    Raises:
        ValueError: On a v2Pro model when ``refer_spec`` carries no audio
            tensor for the speaker-embedding model.
    """
    refer_audio_spec, audio_tensor = refer_spec

    if self.configs.use_vocoder:
        return self.using_vocoder_synthesis_request_local(
            semantic_tokens=semantic_tokens,
            phones=phones,
            prompt_semantic=prompt_semantic,
            prompt_phones=prompt_phones,
            refer_audio_spec=refer_audio_spec,
            raw_audio=raw_audio,
            raw_sr=raw_sr,
            speed=speed,
            sample_steps=sample_steps,
        )

    refer_audio_spec_list = [refer_audio_spec.to(dtype=self.precision, device=self.configs.device)]
    sv_emb = None
    if self.is_v2pro:
        # v2Pro additionally conditions on a speaker embedding of the reference audio.
        if audio_tensor is None:
            raise ValueError(i18n("v2Pro request-local synthesis 缺少 16k 参考音频"))
        sv_emb = self.sv_model.compute_embedding3(audio_tensor).to(self.configs.device)
    decoded = self.vits_model.decode(
        semantic_tokens,
        phones,
        refer_audio_spec_list,
        speed=speed,
        sv_emb=sv_emb,
    )
    return decoded.detach()[0, 0, :]
|
||||
|
||||
def using_vocoder_synthesis_batched_infer(
|
||||
self,
|
||||
idx_list: List[int],
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
@ -16,6 +18,7 @@ from text.cleaner import clean_text
|
||||
from text import cleaned_text_to_sequence
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
|
||||
from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker
|
||||
|
||||
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||
|
||||
@ -49,12 +52,60 @@ def merge_short_text_in_array(texts: str, threshold: int) -> list:
|
||||
return result
|
||||
|
||||
|
||||
class StageLimiter:
    """Cap how many callers may run a pipeline stage at the same time.

    Besides gating via a bounded semaphore, the limiter tracks the current
    and the highest observed number of in-flight callers so that they can be
    reported in profiling output.
    """

    def __init__(self, slots: int):
        """Create a limiter with ``slots`` concurrent slots (at least one)."""
        self.slots = max(1, int(slots))
        self.semaphore = threading.BoundedSemaphore(self.slots)
        # Guards the two counters below.
        self.lock = threading.Lock()
        self.inflight = 0
        self.peak_inflight = 0

    @contextmanager
    def enter(self):
        """Block until a slot is free, then hold it for the ``with`` body.

        Yields:
            A dict with ``wait_ms`` (time spent waiting for the slot),
            ``inflight`` and ``peak_inflight`` (counters right after
            entering) and ``slots`` (configured capacity).
        """
        wait_start = time.perf_counter()
        self.semaphore.acquire()
        try:
            # Everything after the acquire sits inside this try so the slot
            # is handed back even if the bookkeeping itself is interrupted.
            wait_ms = (time.perf_counter() - wait_start) * 1000.0
            with self.lock:
                self.inflight += 1
                current_inflight = self.inflight
                if current_inflight > self.peak_inflight:
                    self.peak_inflight = current_inflight
                peak_inflight = self.peak_inflight
            try:
                yield {
                    "wait_ms": wait_ms,
                    "inflight": current_inflight,
                    "peak_inflight": peak_inflight,
                    "slots": self.slots,
                }
            finally:
                with self.lock:
                    self.inflight = max(0, self.inflight - 1)
        finally:
            self.semaphore.release()

    def snapshot(self) -> Dict[str, int]:
        """Return the capacity and the current / peak in-flight counters."""
        with self.lock:
            return {
                "slots": self.slots,
                "inflight": self.inflight,
                "peak_inflight": self.peak_inflight,
            }
|
||||
|
||||
|
||||
class TextPreprocessor:
|
||||
def __init__(self, bert_model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, device: torch.device):
|
||||
def __init__(
|
||||
self,
|
||||
bert_model: AutoModelForMaskedLM,
|
||||
tokenizer: AutoTokenizer,
|
||||
device: torch.device,
|
||||
bert_stage_limiter: StageLimiter | None = None,
|
||||
bert_batch_worker: PrepareBertBatchWorker | None = None,
|
||||
):
|
||||
self.bert_model = bert_model
|
||||
self.tokenizer = tokenizer
|
||||
self.device = device
|
||||
self.bert_lock = threading.RLock()
|
||||
self.bert_stage_limiter = bert_stage_limiter
|
||||
self.bert_batch_worker = bert_batch_worker
|
||||
|
||||
def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
|
||||
print(f"############ {i18n('切分文本')} ############")
|
||||
@ -115,86 +166,136 @@ class TextPreprocessor:
|
||||
return texts
|
||||
|
||||
def segment_and_extract_feature_for_text(
|
||||
self, text: str, language: str, version: str = "v1"
|
||||
self, text: str, language: str, version: str = "v1", profile: Dict | None = None
|
||||
) -> Tuple[list, torch.Tensor, str]:
|
||||
return self.get_phones_and_bert(text, language, version)
|
||||
return self.get_phones_and_bert(text, language, version, profile=profile)
|
||||
|
||||
def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False):
|
||||
with self.bert_lock:
|
||||
text = re.sub(r' {2,}', ' ', text)
|
||||
textlist = []
|
||||
langlist = []
|
||||
if language == "all_zh":
|
||||
for tmp in LangSegmenter.getTexts(text,"zh"):
|
||||
def _split_text_by_language(self, text: str, language: str) -> Tuple[List[str], List[str]]:
|
||||
textlist = []
|
||||
langlist = []
|
||||
if language == "all_zh":
|
||||
for tmp in LangSegmenter.getTexts(text, "zh"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_yue":
|
||||
for tmp in LangSegmenter.getTexts(text, "zh"):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_ja":
|
||||
for tmp in LangSegmenter.getTexts(text, "ja"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_ko":
|
||||
for tmp in LangSegmenter.getTexts(text, "ko"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "en":
|
||||
langlist.append("en")
|
||||
textlist.append(text)
|
||||
elif language == "auto":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "auto_yue":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
else:
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if langlist:
|
||||
same_group = (tmp["lang"] == "en" and langlist[-1] == "en") or (
|
||||
tmp["lang"] != "en" and langlist[-1] != "en"
|
||||
)
|
||||
if same_group:
|
||||
textlist[-1] += tmp["text"]
|
||||
continue
|
||||
if tmp["lang"] == "en":
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_yue":
|
||||
for tmp in LangSegmenter.getTexts(text,"zh"):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_ja":
|
||||
for tmp in LangSegmenter.getTexts(text,"ja"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_ko":
|
||||
for tmp in LangSegmenter.getTexts(text,"ko"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "en":
|
||||
langlist.append("en")
|
||||
textlist.append(text)
|
||||
elif language == "auto":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "auto_yue":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
else:
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if langlist:
|
||||
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
|
||||
textlist[-1] += tmp["text"]
|
||||
continue
|
||||
if tmp["lang"] == "en":
|
||||
langlist.append(tmp["lang"])
|
||||
else:
|
||||
# 因无法区别中日韩文汉字,以用户输入为准
|
||||
langlist.append(language)
|
||||
textlist.append(tmp["text"])
|
||||
# print(textlist)
|
||||
# print(langlist)
|
||||
phones_list = []
|
||||
bert_list = []
|
||||
norm_text_list = []
|
||||
for i in range(len(textlist)):
|
||||
lang = langlist[i]
|
||||
phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
|
||||
bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
|
||||
phones_list.append(phones)
|
||||
norm_text_list.append(norm_text)
|
||||
bert_list.append(bert)
|
||||
bert = torch.cat(bert_list, dim=1)
|
||||
phones = sum(phones_list, [])
|
||||
norm_text = "".join(norm_text_list)
|
||||
else:
|
||||
langlist.append(language)
|
||||
textlist.append(tmp["text"])
|
||||
return textlist, langlist
|
||||
|
||||
if not final and len(phones) < 6:
|
||||
return self.get_phones_and_bert("." + text, language, version, final=True)
|
||||
def get_phones_and_bert(
|
||||
self, text: str, language: str, version: str, final: bool = False, profile: Dict | None = None
|
||||
):
|
||||
text = re.sub(r' {2,}', ' ', text)
|
||||
textlist, langlist = self._split_text_by_language(text, language)
|
||||
phones_list = []
|
||||
bert_list = []
|
||||
norm_text_list = []
|
||||
for segment_text, segment_lang in zip(textlist, langlist):
|
||||
phones, word2ph, norm_text = self.clean_text_inf(segment_text, segment_lang, version)
|
||||
bert = self.get_bert_inf(phones, word2ph, norm_text, segment_lang, profile=profile)
|
||||
phones_list.append(phones)
|
||||
norm_text_list.append(norm_text)
|
||||
bert_list.append(bert)
|
||||
bert = torch.cat(bert_list, dim=1)
|
||||
phones = sum(phones_list, [])
|
||||
norm_text = "".join(norm_text_list)
|
||||
|
||||
return phones, bert, norm_text
|
||||
if not final and len(phones) < 6:
|
||||
return self.get_phones_and_bert("." + text, language, version, final=True, profile=profile)
|
||||
|
||||
def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
|
||||
with torch.no_grad():
|
||||
inputs = self.tokenizer(text, return_tensors="pt")
|
||||
for i in inputs:
|
||||
inputs[i] = inputs[i].to(self.device)
|
||||
res = self.bert_model(**inputs, output_hidden_states=True)
|
||||
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
|
||||
return phones, bert, norm_text
|
||||
|
||||
def _accumulate_profile(self, profile: Dict | None, key: str, value: float) -> None:
|
||||
if profile is None:
|
||||
return
|
||||
profile[key] = float(profile.get(key, 0.0)) + float(value)
|
||||
|
||||
def _update_profile_peak(self, profile: Dict | None, key: str, value: float) -> None:
|
||||
if profile is None:
|
||||
return
|
||||
profile[key] = float(max(float(profile.get(key, 0.0)), float(value)))
|
||||
|
||||
def get_bert_feature(self, text: str, word2ph: list, profile: Dict | None = None) -> torch.Tensor:
|
||||
if self.bert_batch_worker is not None:
|
||||
feature, worker_profile = self.bert_batch_worker.submit(text, word2ph)
|
||||
self._accumulate_profile(profile, "bert_wait_ms", worker_profile.get("bert_wait_ms", 0.0))
|
||||
self._accumulate_profile(profile, "bert_forward_ms", worker_profile.get("bert_forward_ms", 0.0))
|
||||
self._accumulate_profile(profile, "bert_tokenize_ms", worker_profile.get("bert_tokenize_ms", 0.0))
|
||||
self._accumulate_profile(profile, "bert_scatter_ms", worker_profile.get("bert_scatter_ms", 0.0))
|
||||
self._accumulate_profile(profile, "bert_calls", worker_profile.get("bert_calls", 1.0))
|
||||
self._update_profile_peak(
|
||||
profile, "bert_stage_inflight_peak", worker_profile.get("bert_stage_inflight_peak", 0.0)
|
||||
)
|
||||
self._update_profile_peak(profile, "bert_batch_size_peak", worker_profile.get("bert_batch_size", 0.0))
|
||||
self._update_profile_peak(profile, "bert_batch_tokens_peak", worker_profile.get("bert_batch_tokens", 0.0))
|
||||
if profile is not None:
|
||||
profile["bert_stage_slots"] = float(worker_profile.get("bert_stage_slots", 0.0))
|
||||
return feature
|
||||
|
||||
limiter_stats = {"wait_ms": 0.0, "inflight": 1, "peak_inflight": 1, "slots": 0}
|
||||
if self.bert_stage_limiter is None:
|
||||
forward_start = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
inputs = self.tokenizer(text, return_tensors="pt")
|
||||
for i in inputs:
|
||||
inputs[i] = inputs[i].to(self.device)
|
||||
res = self.bert_model(**inputs, output_hidden_states=True)
|
||||
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
|
||||
forward_ms = (time.perf_counter() - forward_start) * 1000.0
|
||||
else:
|
||||
with self.bert_stage_limiter.enter() as limiter_stats:
|
||||
forward_start = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
inputs = self.tokenizer(text, return_tensors="pt")
|
||||
for i in inputs:
|
||||
inputs[i] = inputs[i].to(self.device)
|
||||
res = self.bert_model(**inputs, output_hidden_states=True)
|
||||
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
|
||||
forward_ms = (time.perf_counter() - forward_start) * 1000.0
|
||||
self._accumulate_profile(profile, "bert_wait_ms", limiter_stats["wait_ms"])
|
||||
self._accumulate_profile(profile, "bert_forward_ms", forward_ms)
|
||||
self._accumulate_profile(profile, "bert_calls", 1.0)
|
||||
self._update_profile_peak(profile, "bert_stage_inflight_peak", limiter_stats["peak_inflight"])
|
||||
if profile is not None:
|
||||
profile["bert_stage_slots"] = float(limiter_stats["slots"])
|
||||
assert len(word2ph) == len(text)
|
||||
phone_level_feature = []
|
||||
for i in range(len(word2ph)):
|
||||
@ -209,10 +310,10 @@ class TextPreprocessor:
|
||||
phones = cleaned_text_to_sequence(phones, version)
|
||||
return phones, word2ph, norm_text
|
||||
|
||||
def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str):
|
||||
def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str, profile: Dict | None = None):
|
||||
language = language.replace("all_", "")
|
||||
if language == "zh":
|
||||
feature = self.get_bert_feature(norm_text, word2ph).to(self.device)
|
||||
feature = self.get_bert_feature(norm_text, word2ph, profile=profile).to(self.device)
|
||||
else:
|
||||
feature = torch.zeros(
|
||||
(1024, len(phones)),
|
||||
@ -236,4 +337,4 @@ class TextPreprocessor:
|
||||
punctuations = "".join(re.escape(p) for p in punctuation)
|
||||
pattern = f"([{punctuations}])([{punctuations}])+"
|
||||
result = re.sub(pattern, r"\1", text)
|
||||
return result
|
||||
return result
|
||||
|
||||
@ -1 +1,11 @@
|
||||
from . import TTS, text_segmentation_method
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
|
||||
__all__ = ["TTS", "TextPreprocessor", "text_segmentation_method", "t2s_scheduler"]
|
||||
|
||||
|
||||
def __getattr__(name: str):
    """Lazily import and return the submodule ``name`` of this package.

    Only names listed in ``__all__`` are resolved; anything else raises
    ``AttributeError`` like a normal missing module attribute.
    """
    if name not in __all__:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    return importlib.import_module(f"{__name__}.{name}")
|
||||
|
||||
197
GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py
Normal file
197
GPT_SoVITS/TTS_infer_pack/prepare_bert_batch_worker.py
Normal file
@ -0,0 +1,197 @@
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Deque, Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
class BertFeatureTask:
    """One queued BERT feature request and the slot its result is delivered in.

    The submitting thread creates the task and waits on ``done_event``; the
    batch worker fills in either ``result_feature`` and ``profile`` or ``error``
    and then sets the event.
    """

    # Inputs supplied by the caller. The worker requires
    # len(word2ph) == len(norm_text); word2ph[i] is the repeat count applied to
    # the feature row of character i.
    norm_text: str
    word2ph: List[int]

    # Bookkeeping generated per task.
    task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
    created_at: float = field(default_factory=time.perf_counter)  # perf_counter timestamp, used for wait-time stats
    done_event: threading.Event = field(default_factory=threading.Event)

    # Outcome written by the worker thread.
    result_feature: torch.Tensor | None = None
    error: Exception | None = None
    profile: Dict[str, float] = field(default_factory=dict)
|
||||
|
||||
|
||||
class PrepareBertBatchWorker:
    """Batch BERT feature requests submitted by concurrent callers.

    Callers block in :meth:`submit`. A single daemon thread pulls queued tasks
    into a batch (bounded by a time window, an item count and an estimated
    token budget), runs one tokenizer + model forward pass for the whole batch
    and hands every caller its own phone-level feature matrix.

    If the shared tokenizer/forward step fails for a batch of several tasks,
    each task is retried on its own so that one bad request cannot fail the
    unrelated requests that happened to be batched with it.
    """

    def __init__(
        self,
        bert_model,
        tokenizer,
        device,
        stage_limiter=None,
        batch_window_ms: int = 5,
        max_batch_items: int = 16,
        max_batch_tokens: int = 4096,
    ):
        """Store the model handles, clamp the batching limits and start the worker thread.

        Args:
            bert_model: Callable as ``bert_model(**inputs, output_hidden_states=True)``;
                the result must expose ``["hidden_states"]``.
            tokenizer: Callable as ``tokenizer(texts, return_tensors="pt", padding=True)``;
                the result must contain ``"attention_mask"``.
            device: Target passed to ``Tensor.to`` for every tokenizer output.
            stage_limiter: Optional object whose ``enter()`` context manager yields a
                stats dict with ``wait_ms``, ``peak_inflight`` and ``slots``.
            batch_window_ms: How long to wait for more tasks after the first one.
            max_batch_items: Upper bound on tasks per batch (at least 1).
            max_batch_tokens: Upper bound on the estimated tokens per batch (at least 16).
        """
        self.bert_model = bert_model
        self.tokenizer = tokenizer
        self.device = device
        self.stage_limiter = stage_limiter
        self.batch_window_s = max(0.0, float(batch_window_ms) / 1000.0)
        self.max_batch_items = max(1, int(max_batch_items))
        self.max_batch_tokens = max(16, int(max_batch_tokens))

        # A single condition guards the queue and all counters below.
        self.condition = threading.Condition()
        self.pending_tasks: Deque[BertFeatureTask] = deque()
        self.pending_peak = 0
        self.total_submitted = 0
        self.total_finished = 0
        self.total_batches = 0
        self.active_batch_size = 0
        self.active_batch_peak = 0
        self.worker_thread = threading.Thread(target=self._run_loop, name="prepare-bert-batch-worker", daemon=True)
        self.worker_thread.start()

    def _estimate_task_tokens(self, task: BertFeatureTask) -> int:
        """Estimate the token count of a task as one token per character plus two."""
        return max(1, len(task.norm_text) + 2)

    def submit(self, norm_text: str, word2ph: List[int]) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Queue one request and block until the worker has processed it.

        Returns:
            ``(feature, profile)`` where ``profile`` is a copy of the per-task
            timing/statistics dict filled in by the worker.

        Raises:
            Exception: Whatever error the worker recorded for this task.
            RuntimeError: If the worker finished without a result or an error.
        """
        task = BertFeatureTask(norm_text=str(norm_text), word2ph=list(word2ph))
        with self.condition:
            self.pending_tasks.append(task)
            self.total_submitted += 1
            self.pending_peak = max(self.pending_peak, len(self.pending_tasks))
            self.condition.notify_all()
        task.done_event.wait()
        if task.error is not None:
            raise task.error
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if task.result_feature is None:
            raise RuntimeError(f"bert batch worker finished task {task.task_id} without a result")
        return task.result_feature, dict(task.profile)

    def snapshot(self) -> Dict[str, int]:
        """Return the queue counters and configured limits, read under the lock."""
        with self.condition:
            return {
                "pending": len(self.pending_tasks),
                "pending_peak": self.pending_peak,
                "total_submitted": self.total_submitted,
                "total_finished": self.total_finished,
                "total_batches": self.total_batches,
                "active_batch_size": self.active_batch_size,
                "active_batch_peak": self.active_batch_peak,
                "batch_window_ms": int(self.batch_window_s * 1000.0),
                "max_batch_items": self.max_batch_items,
                "max_batch_tokens": self.max_batch_tokens,
            }

    def _collect_batch(self) -> List[BertFeatureTask]:
        """Block for the first task, then gather more until a limit is hit.

        Gathering stops when the batch window has elapsed, ``max_batch_items``
        is reached, or the next queued task would exceed ``max_batch_tokens``.
        """
        with self.condition:
            while not self.pending_tasks:
                self.condition.wait()

            batch: List[BertFeatureTask] = [self.pending_tasks.popleft()]
            batch_tokens = self._estimate_task_tokens(batch[0])
            deadline = time.perf_counter() + self.batch_window_s

            while len(batch) < self.max_batch_items:
                remaining = deadline - time.perf_counter()
                if remaining <= 0:
                    break
                if not self.pending_tasks:
                    self.condition.wait(timeout=remaining)
                    continue
                next_tokens = self._estimate_task_tokens(self.pending_tasks[0])
                if batch_tokens + next_tokens > self.max_batch_tokens:
                    break
                batch.append(self.pending_tasks.popleft())
                batch_tokens += next_tokens

            self.active_batch_size = len(batch)
            self.active_batch_peak = max(self.active_batch_peak, self.active_batch_size)
            return batch

    def _finalize_batch(self, batch: List[BertFeatureTask]) -> None:
        """Update the counters once a collected batch has been fully handled."""
        with self.condition:
            self.active_batch_size = 0
            self.total_batches += 1
            self.total_finished += len(batch)

    def _tokenize_and_forward(self, texts: List[str]):
        """Tokenize ``texts`` as one padded batch and run the model on it.

        Returns:
            ``(attention_mask_cpu, outputs, tokenize_ms, forward_ms)``.
        """
        tokenize_start = time.perf_counter()
        inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
        tokenize_ms = (time.perf_counter() - tokenize_start) * 1000.0
        # Keep a CPU copy of the mask: it is needed later to find each row's true length.
        attention_mask_cpu = inputs["attention_mask"].cpu()
        for key in inputs:
            inputs[key] = inputs[key].to(self.device)
        forward_start = time.perf_counter()
        with torch.no_grad():
            outputs = self.bert_model(**inputs, output_hidden_states=True)
        forward_ms = (time.perf_counter() - forward_start) * 1000.0
        return attention_mask_cpu, outputs, tokenize_ms, forward_ms

    def _run_batch(self, batch: List[BertFeatureTask]) -> None:
        """Run one forward pass for ``batch`` and scatter the results to the tasks.

        Errors that concern a single task (length mismatches, bad ``word2ph``)
        are stored on that task only. Errors from the shared tokenizer/forward
        step propagate to the caller before any task has been touched.
        """
        batch_started = time.perf_counter()
        texts = [task.norm_text for task in batch]
        batch_tokens = sum(self._estimate_task_tokens(task) for task in batch)

        limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0}
        if self.stage_limiter is None:
            attention_mask_cpu, outputs, tokenize_ms, forward_ms = self._tokenize_and_forward(texts)
        else:
            with self.stage_limiter.enter() as limiter_stats:
                attention_mask_cpu, outputs, tokenize_ms, forward_ms = self._tokenize_and_forward(texts)

        # Third-from-last hidden layer, same layer the unbatched path uses.
        hidden = outputs["hidden_states"][-3].detach().cpu()
        scatter_start = time.perf_counter()
        for batch_index, task in enumerate(batch):
            try:
                text_len = len(task.word2ph)
                if text_len != len(task.norm_text):
                    raise AssertionError(
                        f"word2ph/text length mismatch: task={task.task_id} word2ph={text_len} text={len(task.norm_text)}"
                    )
                seq_len = int(attention_mask_cpu[batch_index].sum().item())
                # Drop the first and last non-padding positions (the special tokens).
                char_features = hidden[batch_index, 1 : seq_len - 1]
                if char_features.shape[0] != text_len:
                    raise AssertionError(
                        f"bert token length mismatch: task={task.task_id} token_len={char_features.shape[0]} text_len={text_len}"
                    )
                phone_level_feature = [
                    char_features[char_index].repeat(repeat_count, 1)
                    for char_index, repeat_count in enumerate(task.word2ph)
                ]
                task.result_feature = torch.cat(phone_level_feature, dim=0).T
                task.profile = {
                    "bert_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]),
                    "bert_forward_ms": float(forward_ms),
                    "bert_tokenize_ms": float(tokenize_ms),
                    "bert_scatter_ms": 0.0,
                    "bert_calls": 1.0,
                    "bert_stage_slots": float(limiter_stats["slots"]),
                    "bert_stage_inflight_peak": float(limiter_stats["peak_inflight"]),
                    "bert_batch_size": float(len(batch)),
                    "bert_batch_tokens": float(batch_tokens),
                }
            except Exception as exc:
                task.error = exc
        scatter_ms = (time.perf_counter() - scatter_start) * 1000.0
        for task in batch:
            if task.result_feature is not None:
                task.profile["bert_scatter_ms"] = float(scatter_ms)
            task.done_event.set()

    @staticmethod
    def _fail_task(task: BertFeatureTask, exc: Exception) -> None:
        """Record ``exc`` on ``task`` and wake its submitter."""
        task.error = exc
        task.done_event.set()

    def _recover_failed_batch(self, batch: List[BertFeatureTask], batch_exc: Exception) -> None:
        """Handle a batch whose shared tokenizer/forward step raised.

        A single-task batch simply fails with ``batch_exc``. A larger batch is
        retried one task at a time so only the tasks that really fail get an error.
        """
        unfinished = [task for task in batch if not task.done_event.is_set()]
        if len(batch) == 1:
            for task in unfinished:
                self._fail_task(task, batch_exc)
            return
        for task in unfinished:
            try:
                self._run_batch([task])
            except Exception as task_exc:
                self._fail_task(task, task_exc)

    def _process_batch(self, batch: List[BertFeatureTask]) -> None:
        """Run ``batch``, isolate failures, and always update the counters."""
        try:
            self._run_batch(batch)
        except Exception as batch_exc:
            self._recover_failed_batch(batch, batch_exc)
        finally:
            self._finalize_batch(batch)

    def _run_loop(self) -> None:
        """Worker thread body: collect and process batches forever."""
        while True:
            self._process_batch(self._collect_batch())
|
||||
262
GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py
Normal file
262
GPT_SoVITS/TTS_infer_pack/prepare_ref_semantic_batch_worker.py
Normal file
@ -0,0 +1,262 @@
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Deque, Dict, List, Tuple
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
REF_AUDIO_MIN_SAMPLES_16K = 48000
|
||||
REF_AUDIO_MAX_SAMPLES_16K = 160000
|
||||
|
||||
|
||||
def prepare_prompt_semantic_wav16k(raw_audio: torch.Tensor, raw_sr: int, zero_wav_samples: int) -> torch.Tensor:
    """Turn raw reference audio into a mono 16 kHz float32 waveform tensor.

    A 2-D input whose first dimension is not 1 is averaged over that dimension;
    audio not already at 16 kHz is resampled with librosa. ``zero_wav_samples``
    zeros are appended after the length check when the value is positive.

    Raises:
        OSError: If the 16 kHz waveform has fewer than ``REF_AUDIO_MIN_SAMPLES_16K``
            or more than ``REF_AUDIO_MAX_SAMPLES_16K`` samples (3 s / 10 s at 16 kHz).
    """
    needs_downmix = raw_audio.dim() == 2 and raw_audio.shape[0] != 1
    mono = raw_audio.mean(0, keepdim=True) if needs_downmix else raw_audio
    samples = mono.squeeze(0).cpu().numpy()
    if raw_sr != 16000:
        samples = librosa.resample(samples, orig_sr=raw_sr, target_sr=16000)

    sample_count = samples.shape[0]
    if not (REF_AUDIO_MIN_SAMPLES_16K <= sample_count <= REF_AUDIO_MAX_SAMPLES_16K):
        raise OSError("参考音频在3~10秒范围外,请更换!")

    samples = np.ascontiguousarray(samples, dtype=np.float32)
    if zero_wav_samples > 0:
        trailing_silence = np.zeros(int(zero_wav_samples), dtype=np.float32)
        samples = np.concatenate([samples, trailing_silence], axis=0)
    return torch.from_numpy(samples)
|
||||
|
||||
|
||||
def conv1d_output_lengths(input_lengths: torch.Tensor, conv1d: torch.nn.Conv1d | None) -> torch.Tensor:
    """Compute the per-sequence output lengths produced by ``conv1d``.

    Uses the standard formula
    ``floor((L + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1``
    and clamps the result at zero.

    Args:
        input_lengths: Integer tensor of sequence lengths before the convolution.
        conv1d: The convolution layer, or ``None`` to return the lengths unchanged.

    Returns:
        A ``torch.long`` tensor with one output length per input length.
    """
    if conv1d is None:
        return input_lengths.to(dtype=torch.long)

    padding_spec = conv1d.padding
    if isinstance(padding_spec, str):
        # Conv1d keeps string padding modes verbatim; indexing them would yield a letter.
        if padding_spec == "same":
            # "same" keeps the sequence length (PyTorch only allows it with stride 1).
            return torch.clamp(input_lengths, min=0).to(dtype=torch.long)
        padding = 0  # "valid" means no padding
    else:
        padding = int(padding_spec[0])

    kernel_size = int(conv1d.kernel_size[0])
    stride = int(conv1d.stride[0])
    dilation = int(conv1d.dilation[0])
    receptive_span = dilation * (kernel_size - 1) + 1
    output_lengths = torch.div(input_lengths + 2 * padding - receptive_span, stride, rounding_mode="floor") + 1
    return torch.clamp(output_lengths, min=0).to(dtype=torch.long)
|
||||
|
||||
|
||||
@dataclass
class RefSemanticTask:
    """One queued prompt-semantic extraction request and its result slot.

    The submitting thread creates the task and waits on ``done_event``; the
    batch worker fills in either ``result_prompt_semantic`` and ``profile`` or
    ``error`` and then sets the event.
    """

    # Inputs supplied by the caller: the reference waveform and its sample rate.
    raw_audio: torch.Tensor
    raw_sr: int

    # Bookkeeping generated per task.
    task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
    created_at: float = field(default_factory=time.perf_counter)  # perf_counter timestamp, used for wait-time stats
    done_event: threading.Event = field(default_factory=threading.Event)

    # Outcome written by the worker thread.
    result_prompt_semantic: torch.Tensor | None = None
    error: Exception | None = None
    profile: Dict[str, float] = field(default_factory=dict)
|
||||
|
||||
|
||||
class PrepareRefSemanticBatchWorker:
    """Batch prompt-semantic extraction for reference audio from concurrent callers.

    Callers block in :meth:`submit`. A single daemon thread pulls queued tasks
    into a batch (bounded by a time window, an item count and an estimated
    sample budget), pads the 16 kHz waveforms into one tensor, runs the SSL
    model and ``vits_model.extract_latent`` once, and gives every caller the
    code sequence trimmed to its own length.

    Failures are isolated per task: a reference audio that fails preparation
    (for example one outside the allowed duration) only fails its own request,
    and if the shared forward pass fails for a multi-task batch each task is
    retried on its own.
    """

    def __init__(
        self,
        ssl_model,
        vits_model,
        device,
        is_half: bool,
        zero_wav_samples: int,
        stage_limiter=None,
        batch_window_ms: int = 5,
        max_batch_items: int = 8,
        max_batch_samples: int = 960000,
    ):
        """Store the model handles, clamp the batching limits and start the worker thread.

        Args:
            ssl_model: Object whose ``.model(input_values, attention_mask=...)`` result
                exposes ``["last_hidden_state"]``.
            vits_model: Object providing ``extract_latent``; its optional ``ssl_proj``
                attribute is used to derive per-task code lengths.
            device: Target passed to ``Tensor.to`` for the batched inputs.
            is_half: Cast the input waveforms to half precision before the forward pass.
            zero_wav_samples: Number of zero samples appended to every waveform.
            stage_limiter: Optional object whose ``enter()`` context manager yields a
                stats dict with ``wait_ms``, ``peak_inflight`` and ``slots``.
            batch_window_ms: How long to wait for more tasks after the first one.
            max_batch_items: Upper bound on tasks per batch (at least 1).
            max_batch_samples: Upper bound on the estimated 16 kHz samples per batch.
        """
        self.ssl_model = ssl_model
        self.vits_model = vits_model
        self.device = device
        self.is_half = bool(is_half)
        self.zero_wav_samples = max(0, int(zero_wav_samples))
        self.stage_limiter = stage_limiter
        self.batch_window_s = max(0.0, float(batch_window_ms) / 1000.0)
        self.max_batch_items = max(1, int(max_batch_items))
        # The budget can never be smaller than one minimum-length padded waveform.
        self.max_batch_samples = max(REF_AUDIO_MIN_SAMPLES_16K + self.zero_wav_samples, int(max_batch_samples))

        # A single condition guards the queue and all counters below.
        self.condition = threading.Condition()
        self.pending_tasks: Deque[RefSemanticTask] = deque()
        self.pending_peak = 0
        self.total_submitted = 0
        self.total_finished = 0
        self.total_batches = 0
        self.active_batch_size = 0
        self.active_batch_peak = 0
        self.active_batch_samples = 0
        self.active_batch_samples_peak = 0
        self.worker_thread = threading.Thread(
            target=self._run_loop,
            name="prepare-ref-semantic-batch-worker",
            daemon=True,
        )
        self.worker_thread.start()

    def _estimate_task_samples(self, task: RefSemanticTask) -> int:
        """Estimate the padded 16 kHz sample count of a task without resampling it."""
        raw_len = int(task.raw_audio.shape[-1]) if task.raw_audio.dim() > 0 else 0
        base = int(round(raw_len * 16000.0 / max(1, int(task.raw_sr))))
        return max(REF_AUDIO_MIN_SAMPLES_16K, base) + self.zero_wav_samples

    def submit(self, raw_audio: torch.Tensor, raw_sr: int) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Queue one reference audio and block until the worker has processed it.

        Returns:
            ``(prompt_semantic, profile)`` where ``profile`` is a copy of the
            per-task timing/statistics dict filled in by the worker.

        Raises:
            Exception: Whatever error the worker recorded for this task.
            RuntimeError: If the worker finished without a result or an error.
        """
        task = RefSemanticTask(raw_audio=raw_audio, raw_sr=int(raw_sr))
        with self.condition:
            self.pending_tasks.append(task)
            self.total_submitted += 1
            self.pending_peak = max(self.pending_peak, len(self.pending_tasks))
            self.condition.notify_all()
        task.done_event.wait()
        if task.error is not None:
            raise task.error
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if task.result_prompt_semantic is None:
            raise RuntimeError(f"ref semantic batch worker finished task {task.task_id} without a result")
        return task.result_prompt_semantic, dict(task.profile)

    def snapshot(self) -> Dict[str, int]:
        """Return the queue counters and configured limits, read under the lock."""
        with self.condition:
            return {
                "pending": len(self.pending_tasks),
                "pending_peak": self.pending_peak,
                "total_submitted": self.total_submitted,
                "total_finished": self.total_finished,
                "total_batches": self.total_batches,
                "active_batch_size": self.active_batch_size,
                "active_batch_peak": self.active_batch_peak,
                "active_batch_samples": self.active_batch_samples,
                "active_batch_samples_peak": self.active_batch_samples_peak,
                "batch_window_ms": int(self.batch_window_s * 1000.0),
                "max_batch_items": self.max_batch_items,
                "max_batch_samples": self.max_batch_samples,
            }

    def _collect_batch(self) -> List[RefSemanticTask]:
        """Block for the first task, then gather more until a limit is hit.

        Gathering stops when the batch window has elapsed, ``max_batch_items``
        is reached, or the next queued task would exceed ``max_batch_samples``.
        """
        with self.condition:
            while not self.pending_tasks:
                self.condition.wait()

            batch: List[RefSemanticTask] = [self.pending_tasks.popleft()]
            batch_samples = self._estimate_task_samples(batch[0])
            deadline = time.perf_counter() + self.batch_window_s

            while len(batch) < self.max_batch_items:
                remaining = deadline - time.perf_counter()
                if remaining <= 0:
                    break
                if not self.pending_tasks:
                    self.condition.wait(timeout=remaining)
                    continue
                next_samples = self._estimate_task_samples(self.pending_tasks[0])
                if batch_samples + next_samples > self.max_batch_samples:
                    break
                batch.append(self.pending_tasks.popleft())
                batch_samples += next_samples

            self.active_batch_size = len(batch)
            self.active_batch_samples = batch_samples
            self.active_batch_peak = max(self.active_batch_peak, self.active_batch_size)
            self.active_batch_samples_peak = max(self.active_batch_samples_peak, self.active_batch_samples)
            return batch

    def _finalize_batch(self, batch: List[RefSemanticTask]) -> None:
        """Update the counters once a collected batch has been fully handled."""
        with self.condition:
            self.active_batch_size = 0
            self.active_batch_samples = 0
            self.total_batches += 1
            self.total_finished += len(batch)

    def _get_hidden_lengths(self, attention_mask: torch.Tensor, hidden_length: int) -> torch.Tensor:
        """Return the number of valid SSL feature frames for every batch row.

        Prefers the model's own ``_get_feature_vector_attention_mask``, then
        ``_get_feat_extract_output_lengths``; if the model offers neither, every
        row is assumed to span the full ``hidden_length``.
        """
        model = self.ssl_model.model
        if hasattr(model, "_get_feature_vector_attention_mask"):
            feature_mask = model._get_feature_vector_attention_mask(hidden_length, attention_mask)
            return feature_mask.to(dtype=torch.long).sum(dim=1)
        raw_lengths = attention_mask.to(dtype=torch.long).sum(dim=1)
        if hasattr(model, "_get_feat_extract_output_lengths"):
            return model._get_feat_extract_output_lengths(raw_lengths).to(dtype=torch.long)
        return torch.full((attention_mask.shape[0],), int(hidden_length), dtype=torch.long, device=attention_mask.device)

    def _forward(self, input_values_cpu: torch.Tensor, attention_mask_cpu: torch.Tensor):
        """Move the padded batch to the device and run SSL model + latent extraction.

        Returns:
            ``(codes, hidden_lengths, forward_ms)``.
        """
        input_values = input_values_cpu.to(self.device)
        attention_mask = attention_mask_cpu.to(self.device)
        if self.is_half:
            input_values = input_values.half()
        forward_start = time.perf_counter()
        outputs = self.ssl_model.model(input_values, attention_mask=attention_mask)
        hubert_feature = outputs["last_hidden_state"].transpose(1, 2)
        hidden_lengths = self._get_hidden_lengths(attention_mask, int(hubert_feature.shape[-1]))
        codes = self.vits_model.extract_latent(hubert_feature)
        forward_ms = (time.perf_counter() - forward_start) * 1000.0
        return codes, hidden_lengths, forward_ms

    @torch.inference_mode()
    def _run_batch(self, batch: List[RefSemanticTask]) -> None:
        """Prepare, run and scatter one batch.

        Tasks whose audio cannot be prepared are failed immediately and left
        out of the forward pass. Errors from the shared forward pass propagate
        to the caller; the remaining tasks are untouched at that point.
        """
        batch_started = time.perf_counter()
        prepared_start = time.perf_counter()
        runnable: List[RefSemanticTask] = []
        prepared_wavs: List[torch.Tensor] = []
        for task in batch:
            try:
                wav = prepare_prompt_semantic_wav16k(task.raw_audio, int(task.raw_sr), self.zero_wav_samples)
            except Exception as exc:
                # A bad reference audio must not fail the other requests in the batch.
                self._fail_task(task, exc)
                continue
            runnable.append(task)
            prepared_wavs.append(wav)
        cpu_prepare_ms = (time.perf_counter() - prepared_start) * 1000.0
        if not runnable:
            return

        wav_lengths = torch.tensor([int(wav.shape[0]) for wav in prepared_wavs], dtype=torch.long)
        batch_samples = int(wav_lengths.sum().item())
        max_wav_len = int(wav_lengths.max().item())

        # Right-pad every waveform with zeros and mark the real samples in the mask.
        input_values_cpu = torch.zeros((len(runnable), max_wav_len), dtype=torch.float32)
        attention_mask_cpu = torch.zeros((len(runnable), max_wav_len), dtype=torch.long)
        for batch_index, wav in enumerate(prepared_wavs):
            wav_len = int(wav.shape[0])
            input_values_cpu[batch_index, :wav_len] = wav
            attention_mask_cpu[batch_index, :wav_len] = 1

        limiter_stats = {"wait_ms": 0.0, "peak_inflight": 1, "slots": 0}
        if self.stage_limiter is None:
            codes, hidden_lengths, forward_ms = self._forward(input_values_cpu, attention_mask_cpu)
        else:
            with self.stage_limiter.enter() as limiter_stats:
                codes, hidden_lengths, forward_ms = self._forward(input_values_cpu, attention_mask_cpu)

        code_lengths = conv1d_output_lengths(hidden_lengths.detach().cpu(), getattr(self.vits_model, "ssl_proj", None))
        scatter_start = time.perf_counter()
        for batch_index, task in enumerate(runnable):
            try:
                code_len = int(code_lengths[batch_index].item())
                task.result_prompt_semantic = codes[batch_index, 0, :code_len].detach().clone()
                task.profile = {
                    "prompt_semantic_wait_ms": (batch_started - task.created_at) * 1000.0 + float(limiter_stats["wait_ms"]),
                    "prompt_semantic_cpu_prepare_ms": float(cpu_prepare_ms),
                    "prompt_semantic_forward_ms": float(forward_ms),
                    "prompt_semantic_scatter_ms": 0.0,
                    "prompt_semantic_calls": 1.0,
                    "prompt_semantic_stage_slots": float(limiter_stats["slots"]),
                    "prompt_semantic_stage_inflight_peak": float(limiter_stats["peak_inflight"]),
                    "prompt_semantic_batch_size": float(len(runnable)),
                    "prompt_semantic_batch_samples": float(batch_samples),
                }
            except Exception as exc:
                task.error = exc
        scatter_ms = (time.perf_counter() - scatter_start) * 1000.0
        for task in runnable:
            if task.result_prompt_semantic is not None:
                task.profile["prompt_semantic_scatter_ms"] = float(scatter_ms)
            task.done_event.set()

    @staticmethod
    def _fail_task(task: RefSemanticTask, exc: Exception) -> None:
        """Record ``exc`` on ``task`` and wake its submitter."""
        task.error = exc
        task.done_event.set()

    def _recover_failed_batch(self, batch: List[RefSemanticTask], batch_exc: Exception) -> None:
        """Handle a batch whose shared forward step raised.

        Tasks that were already completed (for example failed during
        preparation) are skipped. A single-task batch fails with ``batch_exc``;
        a larger batch is retried one task at a time.
        """
        unfinished = [task for task in batch if not task.done_event.is_set()]
        if len(batch) == 1:
            for task in unfinished:
                self._fail_task(task, batch_exc)
            return
        for task in unfinished:
            try:
                self._run_batch([task])
            except Exception as task_exc:
                self._fail_task(task, task_exc)

    def _process_batch(self, batch: List[RefSemanticTask]) -> None:
        """Run ``batch``, isolate failures, and always update the counters."""
        try:
            self._run_batch(batch)
        except Exception as batch_exc:
            self._recover_failed_batch(batch, batch_exc)
        finally:
            self._finalize_batch(batch)

    def _run_loop(self) -> None:
        """Worker thread body: collect and process batches forever."""
        while True:
            self._process_batch(self._collect_batch())
|
||||
734
GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py
Normal file
734
GPT_SoVITS/TTS_infer_pack/t2s_scheduler.py
Normal file
@ -0,0 +1,734 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import Future
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from AR.models.utils import make_pad_mask_left, sample
|
||||
|
||||
|
||||
def _sync_device(device: Any) -> None:
|
||||
try:
|
||||
device_str = str(device)
|
||||
if device_str.startswith("cuda") and torch.cuda.is_available():
|
||||
torch.cuda.synchronize(device)
|
||||
elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"):
|
||||
torch.mps.synchronize()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
class SchedulerRequestSpec:
    """Raw description of one scheduler request before any feature extraction.

    Field meanings below are inferred from the names and from how
    ``prepare_request_state`` reads them; confirm against the callers.
    """

    request_id: str

    # Reference (prompt) audio and its transcript.
    ref_audio_path: Path
    prompt_text: str
    prompt_lang: str

    # Text to synthesize.
    text: str
    text_lang: str

    # Sampling controls.
    top_k: int
    top_p: float
    temperature: float
    repetition_penalty: float
    early_stop_num: int

    # NOTE(review): presumably the scheduler step at which the request becomes eligible.
    ready_step: int = 0
|
||||
|
||||
|
||||
@dataclass
class T2SRequestState:
    """A request after preparation: the original inputs plus extracted features.

    Field meanings below are inferred from the names; confirm against
    ``prepare_request_state`` and the scheduler before relying on them.
    """

    # Copied from the request spec.
    request_id: str
    ref_audio_path: Path
    prompt_text: str
    prompt_lang: str
    text: str
    text_lang: str

    # Normalized text as returned by text feature extraction.
    norm_prompt_text: str
    norm_text: str

    # Phone id sequences and BERT features (presumably prompt + target combined in the all_* fields).
    phones: torch.LongTensor
    prompt_phones: torch.LongTensor
    all_phones: torch.LongTensor
    all_bert_features: torch.Tensor

    # Reference-audio derived data.
    prompt_semantic: torch.LongTensor
    refer_spec: Tuple[torch.Tensor, Optional[torch.Tensor]]
    raw_audio: torch.Tensor
    raw_sr: int

    # Sampling controls.
    top_k: int
    top_p: float
    temperature: float
    repetition_penalty: float
    early_stop_num: int

    ready_step: int
    # Timing/statistics collected while preparing the request.
    prepare_profile: Dict[str, float]
|
||||
|
||||
|
||||
@dataclass
class T2SRunningRequest:
    """Decoding state of one request that is currently being generated.

    Field meanings below are inferred from the names; confirm against the
    scheduler code that creates and updates these objects.
    """

    state: T2SRequestState
    y_sequence: torch.LongTensor  # presumably prompt tokens followed by the tokens generated so far
    prefix_len: int
    decode_attn_mask: Optional[torch.Tensor]
    # Attention caches, one tensor per list entry.
    k_cache: List[torch.Tensor]
    v_cache: List[torch.Tensor]
    step_idx: int
|
||||
|
||||
|
||||
@dataclass
class T2SFinishedItem:
    """Result record emitted when a request stops decoding."""

    # Identifier of the request this result belongs to.
    request_id: str
    # Generated tokens only: the prompt prefix and the final sampled token are
    # sliced off by the code that builds this record.
    semantic_tokens: torch.LongTensor
    # Step index at which the request finished (0 when it ends during prefill).
    finish_idx: int
    # One of "early_stop", "max_step", "eos_sample" or "eos_argmax".
    finish_reason: str
|
||||
|
||||
|
||||
@dataclass
class T2SActiveBatch:
    """Mutable batch container built by ``build_prefill_batch``."""

    # Parallel per-request lists; index i refers to the same request in both.
    request_ids: List[str]
    states: List[T2SRequestState]

    # Left-padded text-side embeddings and the true length of each row.
    x: torch.Tensor
    x_lens: torch.LongTensor

    # Token history per request (prompt tokens plus anything sampled later)
    # and the number of prompt tokens in each.
    y_sequences: List[torch.LongTensor]
    prefix_lens: torch.LongTensor

    # Transformer input: text block and audio block concatenated on dim 1.
    xy_pos: torch.Tensor
    # True at padded key positions (text padding, then audio padding).
    key_padding_mask: torch.Tensor
    # Combined structural + padding mask used for the prompt pass.
    prefill_attn_mask: torch.Tensor
    # Filled in once the prompt pass has run.
    decode_attn_mask: Optional[torch.Tensor]
    k_cache: Optional[List[torch.Tensor]]
    v_cache: Optional[List[torch.Tensor]]

    # Number of decode steps completed so far and whether prefill has run.
    step_idx: int
    prefill_done: bool
|
||||
|
||||
|
||||
def normalize_sentence(text: str, language: str) -> str:
    """Trim ``text`` and make sure it ends with sentence punctuation.

    Args:
        text: Raw sentence; surrounding newlines and whitespace are removed.
        language: Language code. ``"en"`` gets an ASCII full stop appended,
            every other value gets the ideographic full stop.

    Returns:
        The trimmed text, with a terminator appended when the last character
        is not already a recognised punctuation mark. Empty input stays empty.
    """
    text = text.strip("\n").strip()
    if not text:
        return text
    # Both ASCII and full-width forms count as terminators; without the
    # full-width entries a sentence ending in e.g. a full-width question mark
    # would get a second terminator appended.
    terminal_marks = {
        ",", ".", "?", "!", ";", ":",
        "\uff0c", "\u3002", "\uff1f", "\uff01", "\uff1b", "\uff1a",
        "\u2026",
    }
    if text[-1] not in terminal_marks:
        text += "。" if language != "en" else "."
    return text
|
||||
|
||||
|
||||
@torch.inference_mode()
def prepare_request_state(
    tts: Any,
    spec: SchedulerRequestSpec,
) -> T2SRequestState:
    """Preprocess one request and pack the result into a ``T2SRequestState``.

    The function normalizes the prompt text, extracts text features for the
    target and the prompt (the prompt side runs on
    ``tts.prepare_text_cpu_executor`` when that attribute is set), loads the
    reference-audio bundle, and moves phone ids and BERT features to
    ``tts.configs.device``. Every stage is timed with device synchronization
    around the timers; the measurements end up in ``prepare_profile``.

    Args:
        tts: Pipeline object providing ``configs.device``, ``precision``,
            ``extract_text_features`` and ``extract_ref_audio_bundle``.
        spec: The raw request description.

    Returns:
        The prepared request state.

    Raises:
        ValueError: If text preprocessing returns no phones for the target
            text or for the prompt text.
    """
    device = tts.configs.device
    prepare_start = time.perf_counter()
    _sync_device(device)
    prepare_sync_start = time.perf_counter()
    prompt_text = normalize_sentence(spec.prompt_text, spec.prompt_lang)
    text = spec.text.strip("\n")

    prompt_text_profile: Dict[str, float] = {}
    text_features_profile: Dict[str, float] = {}
    text_feature_pair_start = time.perf_counter()

    def _extract_prompt_features():
        # Shared by the executor path and the inline path so both are timed
        # in exactly the same way.
        _sync_device(device)
        prompt_start = time.perf_counter()
        result = tts.extract_text_features(prompt_text, spec.prompt_lang, profile=prompt_text_profile)
        _sync_device(device)
        return result, (time.perf_counter() - prompt_start) * 1000.0

    # NOTE(review): torch.inference_mode() is thread-local, so the closure is
    # presumably not covered by this decorator when it runs on the executor —
    # confirm extract_text_features disables autograd itself.
    executor = getattr(tts, "prepare_text_cpu_executor", None)
    prompt_future = executor.submit(_extract_prompt_features) if executor is not None else None

    _sync_device(device)
    text_features_start = time.perf_counter()
    phones, bert_features, norm_text = tts.extract_text_features(text, spec.text_lang, profile=text_features_profile)
    _sync_device(device)
    text_features_ms = (time.perf_counter() - text_features_start) * 1000.0

    if prompt_future is None:
        (prompt_phones, prompt_bert_features, prompt_norm_text), prompt_text_features_ms = _extract_prompt_features()
        prompt_text_profile["parallel_future_wait_ms"] = 0.0
    else:
        prompt_wait_start = time.perf_counter()
        (prompt_phones, prompt_bert_features, prompt_norm_text), prompt_text_features_ms = prompt_future.result()
        prompt_text_profile["parallel_future_wait_ms"] = (time.perf_counter() - prompt_wait_start) * 1000.0

    text_feature_pair_ms = (time.perf_counter() - text_feature_pair_start) * 1000.0
    if phones is None:
        raise ValueError(f"{spec.request_id} text preprocessing returned no phones")
    if prompt_phones is None:
        # Fail here with a clear message instead of a TypeError from the
        # list concatenation further down.
        raise ValueError(f"{spec.request_id} prompt text preprocessing returned no phones")

    _sync_device(device)
    ref_audio_bundle_start = time.perf_counter()
    ref_audio_bundle = tts.extract_ref_audio_bundle(str(spec.ref_audio_path))
    prompt_semantic = ref_audio_bundle["prompt_semantic"].long()
    spec_audio, audio_16k = ref_audio_bundle["refer_spec"]
    raw_audio = ref_audio_bundle["raw_audio"]
    raw_sr = int(ref_audio_bundle["raw_sr"])
    _sync_device(device)
    ref_audio_bundle_ms = (time.perf_counter() - ref_audio_bundle_start) * 1000.0
    bundle_profile = ref_audio_bundle.get("profile", {})
    # Without a finer-grained figure the whole bundle time is attributed to
    # the prompt-semantic stage.
    prompt_semantic_ms = float(bundle_profile.get("prompt_semantic_ms", ref_audio_bundle_ms))
    ref_spec_ms = float(bundle_profile.get("ref_spec_ms", 0.0))
    audio_load_ms = float(bundle_profile.get("audio_load_ms", 0.0))

    _sync_device(device)
    tensorize_start = time.perf_counter()
    phones_tensor = torch.LongTensor(phones).to(device)
    prompt_phones_tensor = torch.LongTensor(prompt_phones).to(device)
    all_phones = torch.LongTensor(prompt_phones + phones).to(device)
    all_bert_features = torch.cat([prompt_bert_features, bert_features], dim=1).to(
        dtype=tts.precision, device=device
    )
    _sync_device(device)
    tensorize_ms = (time.perf_counter() - tensorize_start) * 1000.0

    _sync_device(device)

    # Counters copied verbatim from the text-feature profiles, once with a
    # "prompt_text_" prefix and once with a "text_" prefix.
    bert_keys = (
        "bert_wait_ms",
        "bert_forward_ms",
        "bert_tokenize_ms",
        "bert_scatter_ms",
        "bert_calls",
        "bert_stage_slots",
        "bert_stage_inflight_peak",
        "bert_batch_size_peak",
        "bert_batch_tokens_peak",
    )
    audio_stage_keys = (
        "audio_stage_wait_ms",
        "audio_stage_slots",
        "audio_stage_inflight_peak",
    )
    # Counters copied from the reference-audio bundle profile under the same name.
    semantic_stage_keys = (
        "prompt_semantic_wait_ms",
        "prompt_semantic_cpu_prepare_ms",
        "prompt_semantic_forward_ms",
        "prompt_semantic_scatter_ms",
        "prompt_semantic_stage_slots",
        "prompt_semantic_stage_inflight_peak",
        "prompt_semantic_batch_size",
        "prompt_semantic_batch_samples",
        "ref_spec_wait_ms",
    )

    prepare_profile: Dict[str, float] = {
        "prompt_text_features_ms": prompt_text_features_ms,
        "text_features_ms": text_features_ms,
    }
    for key in bert_keys:
        prepare_profile[f"prompt_text_{key}"] = float(prompt_text_profile.get(key, 0.0))
    prepare_profile["prompt_text_parallel_future_wait_ms"] = float(
        prompt_text_profile.get("parallel_future_wait_ms", 0.0)
    )
    for key in bert_keys:
        prepare_profile[f"text_{key}"] = float(text_features_profile.get(key, 0.0))
    prepare_profile["text_feature_pair_ms"] = text_feature_pair_ms
    prepare_profile["text_cpu_parallel_workers"] = float(getattr(tts, "prepare_text_cpu_workers", 0))
    prepare_profile["audio_load_ms"] = audio_load_ms
    for key in audio_stage_keys:
        prepare_profile[key] = float(bundle_profile.get(key, 0.0))
    prepare_profile["prompt_semantic_ms"] = prompt_semantic_ms
    for key in semantic_stage_keys:
        prepare_profile[key] = float(bundle_profile.get(key, 0.0))
    prepare_profile["ref_spec_ms"] = ref_spec_ms
    prepare_profile["ref_audio_bundle_ms"] = ref_audio_bundle_ms
    prepare_profile["tensorize_ms"] = tensorize_ms
    # "total_ms" starts after the first device sync, "wall_total_ms" before it.
    prepare_profile["total_ms"] = (time.perf_counter() - prepare_sync_start) * 1000.0
    prepare_profile["wall_total_ms"] = (time.perf_counter() - prepare_start) * 1000.0

    return T2SRequestState(
        request_id=spec.request_id,
        ref_audio_path=spec.ref_audio_path,
        prompt_text=prompt_text,
        prompt_lang=spec.prompt_lang,
        text=text,
        text_lang=spec.text_lang,
        norm_prompt_text=prompt_norm_text,
        norm_text=norm_text,
        phones=phones_tensor,
        prompt_phones=prompt_phones_tensor,
        all_phones=all_phones,
        all_bert_features=all_bert_features,
        prompt_semantic=prompt_semantic,
        refer_spec=(spec_audio, audio_16k),
        raw_audio=raw_audio,
        raw_sr=raw_sr,
        top_k=spec.top_k,
        top_p=spec.top_p,
        temperature=spec.temperature,
        repetition_penalty=spec.repetition_penalty,
        early_stop_num=spec.early_stop_num,
        ready_step=spec.ready_step,
        prepare_profile=prepare_profile,
    )
|
||||
|
||||
|
||||
def _left_pad_hidden(hidden: torch.Tensor, target_len: int) -> torch.Tensor:
    """Prepend zero rows so that ``hidden`` has ``target_len`` rows.

    The input is returned unchanged (same object) when it already has at
    least ``target_len`` rows. Padding is applied to the second-to-last
    dimension, which is dim 0 for the 2-D tensors passed by callers in this
    module.
    """
    missing_rows = target_len - hidden.shape[0]
    if missing_rows <= 0:
        return hidden
    # F.pad pairs go from the last dim backwards: (last_lo, last_hi, prev_lo, prev_hi).
    return F.pad(hidden, (0, 0, missing_rows, 0), value=0)
|
||||
|
||||
|
||||
def _ensure_audio_pe(model: Any, max_position: int, dtype: torch.dtype, device: torch.device) -> None:
    """Make sure the audio positional table covers index ``max_position``.

    If ``model.ar_audio_position.pe`` is missing or shorter than
    ``max_position + 1`` entries, ``extend_pe`` is called with a zero tensor
    of the required length. Otherwise the existing table is only converted
    when its dtype or device differs from the requested ones.
    """
    position = model.ar_audio_position
    needed = max_position + 1
    table = position.pe

    if table is None or table.size(1) < needed:
        # Only the shape, dtype and device of this dummy are specified here.
        position.extend_pe(torch.zeros(1, needed, position.embedding_dim, device=device, dtype=dtype))
        return

    if table.dtype != dtype or table.device != device:
        position.pe = table.to(dtype=dtype, device=device)
|
||||
|
||||
|
||||
@torch.inference_mode()
def build_prefill_batch(model: Any, states: Sequence[T2SRequestState]) -> T2SActiveBatch:
    """Embed, left-pad and mask a group of requests for the prompt pass.

    For each request the text side (phones + projected BERT features) and the
    audio side (prompt semantic tokens) are embedded with their positional
    modules. Both sides are then left-padded independently to the batch
    maximum and concatenated on dim 1, so each row is laid out as
    ``[x_pad, x_real, y_pad, y_real]``.

    The prefill mask lets text rows see the text block only and audio rows
    see all text plus earlier-or-equal audio positions; padded keys are masked
    for every row. ``True`` means "do not attend".

    Args:
        model: Object exposing the embedding/position/projection modules used below.
        states: Prepared requests; must be non-empty (``max`` is taken over it).

    Returns:
        A ``T2SActiveBatch`` with ``prefill_done=False`` and no caches yet.
    """
    x_items: List[torch.Tensor] = []
    y_pos_items: List[torch.Tensor] = []
    y_sequences: List[torch.LongTensor] = []

    for state in states:
        text_emb = model.ar_text_embedding(state.all_phones.unsqueeze(0))
        bert_proj = model.bert_proj(state.all_bert_features.transpose(0, 1).unsqueeze(0))
        x_items.append(model.ar_text_position(text_emb + bert_proj).squeeze(0))

        y_emb = model.ar_audio_embedding(state.prompt_semantic.unsqueeze(0))
        y_pos_items.append(model.ar_audio_position(y_emb).squeeze(0))
        y_sequences.append(state.prompt_semantic.clone())

    x_lens = [item.shape[0] for item in x_items]
    prefix_lens = [item.shape[0] for item in y_pos_items]
    max_x_len = max(x_lens)
    max_prefix_len = max(prefix_lens)

    x_batch = torch.stack([_left_pad_hidden(item, max_x_len) for item in x_items], dim=0)
    y_pos_batch = torch.stack([_left_pad_hidden(item, max_prefix_len) for item in y_pos_items], dim=0)
    xy_pos = torch.cat([x_batch, y_pos_batch], dim=1)

    device = x_batch.device
    x_lens_tensor = torch.tensor(x_lens, dtype=torch.long, device=device)
    prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.long, device=device)

    # Padding mask per block, concatenated in the same order as xy_pos.
    key_padding_mask = torch.cat(
        [
            make_pad_mask_left(x_lens_tensor, max_x_len),
            make_pad_mask_left(prefix_lens_tensor, max_prefix_len),
        ],
        dim=1,
    ).bool()

    # Text rows: text columns open, audio columns blocked.
    text_rows = F.pad(
        torch.zeros(max_x_len, max_x_len, dtype=torch.bool, device=device),
        (0, max_prefix_len),
        value=True,
    )
    # Audio rows: text columns open, audio columns causal.
    audio_causal = torch.triu(
        torch.ones(max_prefix_len, max_prefix_len, dtype=torch.bool, device=device),
        diagonal=1,
    )
    audio_rows = F.pad(audio_causal, (max_x_len, 0), value=False)
    structural_mask = torch.cat([text_rows, audio_rows], dim=0).unsqueeze(0)

    # Broadcast the per-key padding mask over every query row, then add a
    # singleton dim at position 1.
    attn_mask = structural_mask.logical_or(key_padding_mask.unsqueeze(1)).unsqueeze(1)

    return T2SActiveBatch(
        request_ids=[state.request_id for state in states],
        states=list(states),
        x=x_batch,
        x_lens=x_lens_tensor,
        y_sequences=y_sequences,
        prefix_lens=prefix_lens_tensor,
        xy_pos=xy_pos,
        key_padding_mask=key_padding_mask,
        prefill_attn_mask=attn_mask,
        decode_attn_mask=None,
        k_cache=None,
        v_cache=None,
        step_idx=0,
        prefill_done=False,
    )
|
||||
|
||||
|
||||
def build_next_xy_pos(model: Any, y_sequences: Sequence[torch.LongTensor]) -> torch.Tensor:
    """Build the transformer input for the next decode step.

    Only the newest token of every sequence is embedded. Its positional
    vector is looked up at index ``len(sequence) - 1``, so sequences of
    different lengths each get their own position.

    Args:
        model: Object exposing ``ar_audio_embedding`` and ``ar_audio_position``
            (with ``pe``, ``x_scale`` and ``alpha``).
        y_sequences: One token history per request.

    Returns:
        ``embedding * x_scale + alpha * positional_vector`` for the newest
        token of each sequence.
    """
    newest_tokens = torch.stack([seq[-1:] for seq in y_sequences], dim=0)
    token_emb = model.ar_audio_embedding(newest_tokens)

    position_ids = torch.LongTensor([int(seq.shape[0] - 1) for seq in y_sequences]).to(token_emb.device)
    # Grow / convert the positional table before indexing into it.
    _ensure_audio_pe(model, int(position_ids.max().item()), token_emb.dtype, token_emb.device)

    audio_position = model.ar_audio_position
    pos_emb = audio_position.pe[0].index_select(0, position_ids).unsqueeze(1)
    pos_emb = pos_emb.to(dtype=token_emb.dtype, device=token_emb.device)
    return token_emb * audio_position.x_scale + audio_position.alpha * pos_emb
|
||||
|
||||
|
||||
def _sample_per_request(
    model: Any,
    active_batch: T2SActiveBatch,
    logits: torch.Tensor,
    max_steps: int,
) -> Tuple[List[T2SFinishedItem], List[int], List[torch.LongTensor]]:
    """Sample one token per request and split the batch into finished / kept.

    Each row is sampled with that request's own sampling parameters. A request
    finishes when, in priority order: its early-stop budget is exceeded, the
    step limit is reached, the sampled token equals ``model.EOS``, or the
    arg-max token equals ``model.EOS``.

    Args:
        model: Object exposing ``EOS``.
        active_batch: Batch whose ``states``, ``y_sequences``, ``prefix_lens``
            and ``step_idx`` are read; it is not modified here.
        logits: One row of logits per request in ``active_batch``.
        max_steps: Step limit; the request finishes when ``step_idx + 1``
            reaches it.

    Returns:
        ``(finished_items, keep_indices, updated_sequences)`` where the last
        two lists describe the requests that continue decoding.
    """
    finished_items: List[T2SFinishedItem] = []
    keep_indices: List[int] = []
    updated_sequences: List[torch.LongTensor] = []

    step_idx = active_batch.step_idx
    step_limit_reached = step_idx + 1 >= max_steps

    for batch_index, state in enumerate(active_batch.states):
        history = active_batch.y_sequences[batch_index]
        prefix_len = int(active_batch.prefix_lens[batch_index].item())

        # Sample from a private copy; the arg-max below reads the shared
        # ``logits`` row instead.
        row_logits = logits[batch_index : batch_index + 1].clone()
        sampled = sample(
            row_logits,
            history.unsqueeze(0),
            top_k=state.top_k,
            top_p=state.top_p,
            repetition_penalty=state.repetition_penalty,
            temperature=state.temperature,
        )[0]
        sampled_token = int(sampled[0, 0].item())
        argmax_token = int(torch.argmax(logits[batch_index], dim=-1).item())
        new_history = torch.cat([history, sampled.view(-1)], dim=0)
        generated_count = new_history.shape[0] - prefix_len

        if state.early_stop_num != -1 and generated_count > state.early_stop_num:
            finish_reason: Optional[str] = "early_stop"
        elif step_limit_reached:
            finish_reason = "max_step"
        elif sampled_token == model.EOS:
            finish_reason = "eos_sample"
        elif argmax_token == model.EOS:
            finish_reason = "eos_argmax"
        else:
            finish_reason = None

        if finish_reason is None:
            keep_indices.append(batch_index)
            updated_sequences.append(new_history)
            continue

        finished_items.append(
            T2SFinishedItem(
                request_id=state.request_id,
                # Drop the prompt prefix and the token sampled in this step.
                semantic_tokens=new_history[prefix_len:-1].clone(),
                finish_idx=step_idx,
                finish_reason=finish_reason,
            )
        )

    return finished_items, keep_indices, updated_sequences
|
||||
|
||||
|
||||
def decode_one_step(
    model: Any,
    active_batch: T2SActiveBatch,
    max_steps: int,
) -> Tuple[Optional[T2SActiveBatch], List[T2SFinishedItem]]:
    """Advance a whole ``T2SActiveBatch`` by one step, in place.

    The first call runs the prompt pass and creates the decode mask from the
    key-padding mask; later calls decode one token against the cached
    keys/values. Requests that finish are removed from the batch (ids, states,
    prefix lengths, mask and caches are filtered with the same indices).

    Args:
        model: Object exposing ``t2s_transformer``, ``ar_predict_layer`` and ``EOS``.
        active_batch: Batch to advance; it is mutated.
        max_steps: Step limit forwarded to ``_sample_per_request``.

    Returns:
        ``(active_batch, finished_items)``; the first element is ``None`` when
        no request is left.
    """
    transformer = model.t2s_transformer

    if active_batch.prefill_done:
        xy_dec, active_batch.k_cache, active_batch.v_cache = transformer.decode_next_token(
            active_batch.xy_pos,
            active_batch.k_cache,
            active_batch.v_cache,
            active_batch.decode_attn_mask,
        )
        if active_batch.decode_attn_mask is not None:
            # One more visible key column for the token that was just cached.
            active_batch.decode_attn_mask = F.pad(active_batch.decode_attn_mask, (0, 1), value=False)
    else:
        xy_dec, active_batch.k_cache, active_batch.v_cache = transformer.process_prompt(
            active_batch.xy_pos, active_batch.prefill_attn_mask, None
        )
        padding_mask_4d = active_batch.key_padding_mask.unsqueeze(1).unsqueeze(1)
        active_batch.decode_attn_mask = F.pad(padding_mask_4d, (0, 1), value=False)
        active_batch.prefill_done = True

    logits = model.ar_predict_layer(xy_dec[:, -1])
    if active_batch.step_idx < 11:
        # The last logit column is removed during the first 11 steps
        # (presumably the EOS entry, so generation cannot stop immediately — TODO confirm).
        logits = logits[:, :-1]

    finished_items, keep_indices, updated_sequences = _sample_per_request(
        model, active_batch, logits, max_steps=max_steps
    )
    if not keep_indices:
        return None, finished_items

    keep_tensor = torch.LongTensor(keep_indices).to(logits.device)
    active_batch.request_ids = [active_batch.request_ids[i] for i in keep_indices]
    active_batch.states = [active_batch.states[i] for i in keep_indices]
    active_batch.y_sequences = updated_sequences
    active_batch.prefix_lens = torch.index_select(active_batch.prefix_lens, dim=0, index=keep_tensor)

    if active_batch.decode_attn_mask is not None:
        active_batch.decode_attn_mask = torch.index_select(active_batch.decode_attn_mask, dim=0, index=keep_tensor)

    k_cache, v_cache = active_batch.k_cache, active_batch.v_cache
    if k_cache is not None and v_cache is not None:
        # Entries are replaced in the existing lists rather than rebuilding them.
        for layer_index in range(len(k_cache)):
            k_cache[layer_index] = torch.index_select(k_cache[layer_index], dim=0, index=keep_tensor)
            v_cache[layer_index] = torch.index_select(v_cache[layer_index], dim=0, index=keep_tensor)

    active_batch.xy_pos = build_next_xy_pos(model, active_batch.y_sequences)
    active_batch.step_idx += 1
    return active_batch, finished_items
|
||||
|
||||
|
||||
def run_scheduler_batch(
    model: Any,
    states: Sequence[T2SRequestState],
    max_steps: int,
) -> List[T2SFinishedItem]:
    """Run a set of requests to completion.

    Thin entry point that forwards its arguments to
    ``run_scheduler_continuous`` and returns that function's result unchanged.
    """
    finished_items = run_scheduler_continuous(model, states, max_steps=max_steps)
    return finished_items
|
||||
|
||||
|
||||
def _pad_cache_left(cache: torch.Tensor, target_len: int) -> torch.Tensor:
    """Zero-pad a cache tensor at the front of dim 1 up to ``target_len``.

    The tensor is returned unchanged when dim 1 is already long enough.
    Callers pass 3-D tensors, for which the padded second-to-last dimension
    is dim 1.
    """
    current_len = cache.shape[1]
    if current_len >= target_len:
        return cache
    # (0, 0) leaves the last dim alone; (n, 0) prepends n entries before it.
    return F.pad(cache, (0, 0, target_len - current_len, 0), value=0)
|
||||
|
||||
|
||||
def _pad_decode_mask_left(mask: torch.Tensor, target_len: int) -> torch.Tensor:
    """Extend the last dim of ``mask`` to ``target_len`` by prepending ``True``.

    The prepended columns are ``True`` (masked), matching the zero padding
    that ``_pad_cache_left`` places at the front of shorter caches. A mask
    that is already long enough is returned unchanged.
    """
    current_len = mask.shape[-1]
    if current_len >= target_len:
        return mask
    return F.pad(mask, (target_len - current_len, 0), value=True)
|
||||
|
||||
|
||||
def _fit_decode_mask_length(mask: torch.Tensor, target_len: int) -> torch.Tensor:
    """Return ``mask`` with its last dim made exactly ``target_len`` long.

    A longer mask keeps only its trailing ``target_len`` columns, a shorter
    one is extended on the left via ``_pad_decode_mask_left``, and a mask of
    the right length is returned as is. ``mask`` is indexed with four axes.
    """
    surplus = mask.shape[-1] - target_len
    if surplus > 0:
        return mask[:, :, :, -target_len:]
    if surplus < 0:
        return _pad_decode_mask_left(mask, target_len)
    return mask
|
||||
|
||||
|
||||
def _materialize_decode_mask_for_request(running_request: T2SRunningRequest) -> torch.Tensor:
    """Return a concrete decode mask for ``running_request``.

    The mask length is the request's cached length plus one (the slot for the
    token being decoded). An existing mask is fitted to that length; a missing
    mask is replaced by an all-``False`` tensor of shape ``(1, 1, 1, length)``
    on the cache's device.
    """
    first_layer_cache = running_request.k_cache[0]
    mask_len = first_layer_cache.shape[1] + 1

    stored_mask = running_request.decode_attn_mask
    if stored_mask is None:
        return torch.zeros(
            (1, 1, 1, mask_len),
            dtype=torch.bool,
            device=first_layer_cache.device,
        )
    return _fit_decode_mask_length(stored_mask, mask_len)
|
||||
|
||||
|
||||
@torch.inference_mode()
def run_prefill_step(
    model: Any,
    states: Sequence[T2SRequestState],
    max_steps: int,
) -> Tuple[List[T2SRunningRequest], List[T2SFinishedItem]]:
    """Run the prompt pass for newly admitted requests and sample their first token.

    The requests are batched with ``build_prefill_batch``, pushed through
    ``process_prompt`` once, and then split again: every surviving request
    gets its own slice of the key/value caches and (if needed) its own decode
    mask.

    Args:
        model: Object exposing ``t2s_transformer``, ``ar_predict_layer`` and ``EOS``.
        states: Requests to prefill; an empty sequence is a no-op.
        max_steps: Step limit; a value of 1 or less finishes every request here.

    Returns:
        ``(running_requests, finished_items)``. Running requests start with
        ``step_idx == 1``; finished items carry ``finish_idx == 0``.
    """
    if not states:
        return [], []

    active_batch = build_prefill_batch(model, states)
    xy_dec, k_cache, v_cache = model.t2s_transformer.process_prompt(
        active_batch.xy_pos, active_batch.prefill_attn_mask, None
    )
    decode_attn_mask = F.pad(active_batch.key_padding_mask.unsqueeze(1).unsqueeze(1), (0, 1), value=False)
    if len(states) == 1 and not decode_attn_mask.any().item():
        decode_attn_mask = None
    logits = model.ar_predict_layer(xy_dec[:, -1])

    # Prefill is always step 0 of a request.
    step_idx = 0
    # Each row is laid out as [x_pad, x_real, y_pad, y_real]; these are the
    # padded widths of the two blocks.
    max_x_len = active_batch.x.shape[1]
    max_prefix_len = active_batch.xy_pos.shape[1] - max_x_len

    running_requests: List[T2SRunningRequest] = []
    finished_items: List[T2SFinishedItem] = []

    for batch_index, state in enumerate(states):
        logits_i = logits[batch_index : batch_index + 1].clone()
        if step_idx < 11:
            # Same rule as the decode steps: the last logit column is removed
            # during the first 11 steps.
            logits_i = logits_i[:, :-1]
        current_history = active_batch.y_sequences[batch_index]
        sampled = sample(
            logits_i,
            current_history.unsqueeze(0),
            top_k=state.top_k,
            top_p=state.top_p,
            repetition_penalty=state.repetition_penalty,
            temperature=state.temperature,
        )[0]
        sampled_token = int(sampled[0, 0].item())
        # NOTE(review): the arg-max is read after sample() has seen the same
        # tensor; if sample() edits its input in place, this reflects those
        # edits — keep the statement order.
        argmax_token = int(torch.argmax(logits_i[0], dim=-1).item())
        new_history = torch.cat([current_history, sampled.view(-1)], dim=0)
        prefix_len = int(active_batch.prefix_lens[batch_index].item())

        finish_reason: Optional[str] = None
        if state.early_stop_num != -1 and (new_history.shape[0] - prefix_len) > state.early_stop_num:
            finish_reason = "early_stop"
        elif step_idx + 1 >= max_steps:
            finish_reason = "max_step"
        elif sampled_token == model.EOS:
            finish_reason = "eos_sample"
        elif argmax_token == model.EOS:
            finish_reason = "eos_argmax"

        if finish_reason is not None:
            finished_items.append(
                T2SFinishedItem(
                    request_id=state.request_id,
                    semantic_tokens=new_history[prefix_len:-1].clone(),
                    finish_idx=step_idx,
                    finish_reason=finish_reason,
                )
            )
            continue

        # Keep everything from this request's first real text position to the
        # end of the row. Any prompt padding sits between the text and the
        # prompt tokens, so a window of only x_len + prefix_len positions
        # would keep that padding and cut off the same number of real text
        # positions. The padding that remains is covered by the decode mask.
        kept_kv_len = int(active_batch.x_lens[batch_index].item()) + max_prefix_len
        request_k_cache = [layer[batch_index : batch_index + 1, -kept_kv_len:, :].clone() for layer in k_cache]
        request_v_cache = [layer[batch_index : batch_index + 1, -kept_kv_len:, :].clone() for layer in v_cache]
        request_decode_attn_mask = None
        if decode_attn_mask is not None:
            request_decode_attn_mask = decode_attn_mask[batch_index : batch_index + 1].clone()
            request_decode_attn_mask = _fit_decode_mask_length(request_decode_attn_mask, kept_kv_len + 1)
            if not request_decode_attn_mask.any().item():
                # Nothing left to hide: let the decode path run without a mask.
                request_decode_attn_mask = None

        running_requests.append(
            T2SRunningRequest(
                state=state,
                y_sequence=new_history,
                prefix_len=prefix_len,
                decode_attn_mask=request_decode_attn_mask,
                k_cache=request_k_cache,
                v_cache=request_v_cache,
                step_idx=step_idx + 1,
            )
        )

    return running_requests, finished_items
|
||||
|
||||
|
||||
def _build_decode_batch_from_running(
    model: Any,
    running_requests: Sequence[T2SRunningRequest],
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor], Optional[torch.Tensor]]:
    """Assemble per-request decode state into one batched decode input.

    Caches are zero-padded on the left to the longest cached length and
    concatenated on the batch dimension, layer by layer. A decode mask is
    built whenever any request carries one or the cached lengths differ;
    in the latter case the zero padding must be hidden from attention.

    Args:
        model: Forwarded to ``build_next_xy_pos``.
        running_requests: Non-empty sequence of requests to decode together.

    Returns:
        ``(xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask)``.
        The mask is ``None`` only when no request has a mask and all caches
        already have the same length.
    """
    xy_pos = build_next_xy_pos(model, [item.y_sequence for item in running_requests])
    kv_lens = [item.k_cache[0].shape[1] for item in running_requests]
    max_kv_len = max(kv_lens)
    num_layers = len(running_requests[0].k_cache)

    batched_k_cache: List[torch.Tensor] = [
        torch.cat([_pad_cache_left(item.k_cache[layer_index], max_kv_len) for item in running_requests], dim=0)
        for layer_index in range(num_layers)
    ]
    batched_v_cache: List[torch.Tensor] = [
        torch.cat([_pad_cache_left(item.v_cache[layer_index], max_kv_len) for item in running_requests], dim=0)
        for layer_index in range(num_layers)
    ]

    no_request_masks = all(item.decode_attn_mask is None for item in running_requests)
    no_cache_padding = all(kv_len == max_kv_len for kv_len in kv_lens)
    if no_request_masks and no_cache_padding:
        # Mask-free path: every key position of every row is real.
        return xy_pos, batched_k_cache, batched_v_cache, None

    # Requests without a mask get an all-False one; left padding with True
    # then hides the zero entries added by _pad_cache_left.
    materialized_masks = [_materialize_decode_mask_for_request(item) for item in running_requests]
    max_mask_len = max(mask.shape[-1] for mask in materialized_masks)
    batched_decode_attn_mask = torch.cat(
        [_pad_decode_mask_left(mask, max_mask_len) for mask in materialized_masks],
        dim=0,
    )
    return xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask
|
||||
|
||||
|
||||
@torch.inference_mode()
def run_decode_step_for_running(
    model: Any,
    running_requests: Sequence[T2SRunningRequest],
    max_steps: int,
) -> Tuple[List[T2SRunningRequest], List[T2SFinishedItem]]:
    """Decode one token for every running request and re-split the batch.

    The requests are batched with ``_build_decode_batch_from_running``, run
    through ``decode_next_token`` once, and sampled row by row with each
    request's own parameters and step counter. Survivors get fresh
    ``T2SRunningRequest`` objects holding their own trailing slice of the
    updated caches.

    Args:
        model: Object exposing ``t2s_transformer``, ``ar_predict_layer`` and ``EOS``.
        running_requests: Requests to advance; an empty sequence is a no-op.
        max_steps: Per-request step limit.

    Returns:
        ``(next_running, finished_items)``.
    """
    if not running_requests:
        return [], []

    xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask = _build_decode_batch_from_running(
        model, running_requests
    )
    xy_dec, next_k_cache, next_v_cache = model.t2s_transformer.decode_next_token(
        xy_pos,
        batched_k_cache,
        batched_v_cache,
        batched_decode_attn_mask,
    )
    logits = model.ar_predict_layer(xy_dec[:, -1])

    next_running: List[T2SRunningRequest] = []
    finished_items: List[T2SFinishedItem] = []

    for row, request in enumerate(running_requests):
        state = request.state
        step = request.step_idx
        row_slice = slice(row, row + 1)

        row_logits = logits[row_slice].clone()
        if step < 11:
            # The last logit column is removed during a request's first 11 steps.
            row_logits = row_logits[:, :-1]
        sampled = sample(
            row_logits,
            request.y_sequence.unsqueeze(0),
            top_k=state.top_k,
            top_p=state.top_p,
            repetition_penalty=state.repetition_penalty,
            temperature=state.temperature,
        )[0]
        sampled_token = int(sampled[0, 0].item())
        # Read after sample() on purpose; keep this order.
        argmax_token = int(torch.argmax(row_logits[0], dim=-1).item())
        new_history = torch.cat([request.y_sequence, sampled.view(-1)], dim=0)
        generated_count = new_history.shape[0] - request.prefix_len

        if state.early_stop_num != -1 and generated_count > state.early_stop_num:
            finish_reason: Optional[str] = "early_stop"
        elif step + 1 >= max_steps:
            finish_reason = "max_step"
        elif sampled_token == model.EOS:
            finish_reason = "eos_sample"
        elif argmax_token == model.EOS:
            finish_reason = "eos_argmax"
        else:
            finish_reason = None

        if finish_reason is not None:
            finished_items.append(
                T2SFinishedItem(
                    request_id=state.request_id,
                    semantic_tokens=new_history[request.prefix_len:-1].clone(),
                    finish_idx=step,
                    finish_reason=finish_reason,
                )
            )
            continue

        # The request's own entries sit at the right end of the left-padded
        # batch cache; one new entry was appended by this step.
        next_kv_len = request.k_cache[0].shape[1] + 1
        request_k_cache = [layer[row_slice, -next_kv_len:, :].clone() for layer in next_k_cache]
        request_v_cache = [layer[row_slice, -next_kv_len:, :].clone() for layer in next_v_cache]

        next_decode_attn_mask = None
        if batched_decode_attn_mask is not None:
            own_mask = batched_decode_attn_mask[row_slice, :, :, -next_kv_len:]
            next_decode_attn_mask = F.pad(own_mask, (0, 1), value=False)
            next_decode_attn_mask = _fit_decode_mask_length(next_decode_attn_mask, next_kv_len + 1)
            if not next_decode_attn_mask.any().item():
                next_decode_attn_mask = None

        next_running.append(
            T2SRunningRequest(
                state=state,
                y_sequence=new_history,
                prefix_len=request.prefix_len,
                decode_attn_mask=next_decode_attn_mask,
                k_cache=request_k_cache,
                v_cache=request_v_cache,
                step_idx=step + 1,
            )
        )

    return next_running, finished_items
|
||||
|
||||
|
||||
@torch.inference_mode()
def run_scheduler_continuous(
    model: Any,
    states: Sequence[T2SRequestState],
    max_steps: int,
) -> List[T2SFinishedItem]:
    """Drive all requests to completion with tick-based continuous batching.

    Requests wait until the tick counter reaches their ``ready_step``. On each
    tick the newly ready requests are prefilled first, then the requests that
    were already running decode one step, and finally the new requests join
    the running set. When nothing is running, the counter jumps ahead to the
    next waiting request.

    Args:
        model: Forwarded to the prefill and decode steps.
        states: Prepared requests in any order.
        max_steps: Per-request step limit.

    Returns:
        All finished items, sorted by ``request_id``.
    """
    waiting = sorted(states, key=lambda item: (item.ready_step, item.request_id))
    running: List[T2SRunningRequest] = []
    finished: List[T2SFinishedItem] = []
    tick = 0

    while waiting or running:
        # ``waiting`` is sorted by ready_step, so the ready requests form a prefix.
        ready_count = 0
        while ready_count < len(waiting) and waiting[ready_count].ready_step <= tick:
            ready_count += 1
        admitted, waiting = waiting[:ready_count], waiting[ready_count:]

        newly_running, done_in_prefill = run_prefill_step(model, admitted, max_steps=max_steps)
        finished.extend(done_in_prefill)

        if running:
            running, done_in_decode = run_decode_step_for_running(
                model,
                running,
                max_steps=max_steps,
            )
            finished.extend(done_in_decode)

        # Requests prefilled on this tick take their first decode step next tick.
        running.extend(newly_running)

        if waiting and not running:
            # Idle: skip straight to the tick of the next waiting request.
            tick = max(tick + 1, waiting[0].ready_step)
        else:
            tick += 1

    finished.sort(key=lambda item: item.request_id)
    return finished
|
||||
@ -180,10 +180,15 @@ def _merge_erhua(initials: list[str], finals: list[str], word: str, pos: str) ->
|
||||
def _g2p(segments):
|
||||
phones_list = []
|
||||
word2ph = []
|
||||
for seg in segments:
|
||||
g2pw_batch_results = []
|
||||
g2pw_batch_cursor = 0
|
||||
processed_segments = [re.sub("[a-zA-Z]+", "", seg) for seg in segments]
|
||||
if is_g2pw:
|
||||
batch_inputs = [seg for seg in processed_segments if seg]
|
||||
g2pw_batch_results = g2pw._g2pw(batch_inputs) if batch_inputs else []
|
||||
|
||||
for seg in processed_segments:
|
||||
pinyins = []
|
||||
# Replace all English words in the sentence
|
||||
seg = re.sub("[a-zA-Z]+", "", seg)
|
||||
seg_cut = psg.lcut(seg)
|
||||
seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
|
||||
initials = []
|
||||
@ -204,8 +209,10 @@ def _g2p(segments):
|
||||
finals = sum(finals, [])
|
||||
print("pypinyin结果", initials, finals)
|
||||
else:
|
||||
# g2pw采用整句推理
|
||||
pinyins = g2pw.lazy_pinyin(seg, neutral_tone_with_five=True, style=Style.TONE3)
|
||||
# g2pw采用整句推理(批量推理,逐句取结果)
|
||||
if seg:
|
||||
pinyins = g2pw_batch_results[g2pw_batch_cursor]
|
||||
g2pw_batch_cursor += 1
|
||||
|
||||
pre_word_length = 0
|
||||
for word, pos in seg_cut:
|
||||
|
||||
@ -18,6 +18,7 @@ Credits
|
||||
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
@ -37,6 +38,8 @@ def prepare_onnx_input(
|
||||
use_mask: bool = False,
|
||||
window_size: int = None,
|
||||
max_len: int = 512,
|
||||
char2id: Optional[Dict[str, int]] = None,
|
||||
char_phoneme_masks: Optional[Dict[str, List[int]]] = None,
|
||||
) -> Dict[str, np.array]:
|
||||
if window_size is not None:
|
||||
truncated_texts, truncated_query_ids = _truncate_texts(
|
||||
@ -48,33 +51,88 @@ def prepare_onnx_input(
|
||||
phoneme_masks = []
|
||||
char_ids = []
|
||||
position_ids = []
|
||||
tokenized_cache = {}
|
||||
|
||||
if char2id is None:
|
||||
char2id = {char: idx for idx, char in enumerate(chars)}
|
||||
if use_mask:
|
||||
if char_phoneme_masks is None:
|
||||
char_phoneme_masks = {
|
||||
char: [1 if i in char2phonemes[char] else 0 for i in range(len(labels))]
|
||||
for char in char2phonemes
|
||||
}
|
||||
else:
|
||||
full_phoneme_mask = [1] * len(labels)
|
||||
|
||||
for idx in range(len(texts)):
|
||||
text = (truncated_texts if window_size else texts)[idx].lower()
|
||||
query_id = (truncated_query_ids if window_size else query_ids)[idx]
|
||||
|
||||
try:
|
||||
tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text)
|
||||
except Exception:
|
||||
print(f'warning: text "{text}" is invalid')
|
||||
return {}
|
||||
cached = tokenized_cache.get(text)
|
||||
if cached is None:
|
||||
try:
|
||||
tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text)
|
||||
except Exception:
|
||||
print(f'warning: text "{text}" is invalid')
|
||||
return {}
|
||||
|
||||
text, query_id, tokens, text2token, token2text = _truncate(
|
||||
max_len=max_len, text=text, query_id=query_id, tokens=tokens, text2token=text2token, token2text=token2text
|
||||
)
|
||||
if len(tokens) <= max_len - 2:
|
||||
processed_tokens = ["[CLS]"] + tokens + ["[SEP]"]
|
||||
shared_input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
|
||||
shared_token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
|
||||
shared_attention_mask = list(np.ones((len(processed_tokens),), dtype=int))
|
||||
cached = {
|
||||
"is_short": True,
|
||||
"tokens": tokens,
|
||||
"text2token": text2token,
|
||||
"token2text": token2text,
|
||||
"input_id": shared_input_id,
|
||||
"token_type_id": shared_token_type_id,
|
||||
"attention_mask": shared_attention_mask,
|
||||
}
|
||||
else:
|
||||
cached = {
|
||||
"is_short": False,
|
||||
"tokens": tokens,
|
||||
"text2token": text2token,
|
||||
"token2text": token2text,
|
||||
}
|
||||
tokenized_cache[text] = cached
|
||||
|
||||
processed_tokens = ["[CLS]"] + tokens + ["[SEP]"]
|
||||
if cached["is_short"]:
|
||||
text_for_query = text
|
||||
query_id_for_query = query_id
|
||||
text2token_for_query = cached["text2token"]
|
||||
input_id = cached["input_id"]
|
||||
token_type_id = cached["token_type_id"]
|
||||
attention_mask = cached["attention_mask"]
|
||||
else:
|
||||
(
|
||||
text_for_query,
|
||||
query_id_for_query,
|
||||
tokens_for_query,
|
||||
text2token_for_query,
|
||||
_token2text_for_query,
|
||||
) = _truncate(
|
||||
max_len=max_len,
|
||||
text=text,
|
||||
query_id=query_id,
|
||||
tokens=cached["tokens"],
|
||||
text2token=cached["text2token"],
|
||||
token2text=cached["token2text"],
|
||||
)
|
||||
processed_tokens = ["[CLS]"] + tokens_for_query + ["[SEP]"]
|
||||
input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
|
||||
token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
|
||||
attention_mask = list(np.ones((len(processed_tokens),), dtype=int))
|
||||
|
||||
input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
|
||||
token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
|
||||
attention_mask = list(np.ones((len(processed_tokens),), dtype=int))
|
||||
|
||||
query_char = text[query_id]
|
||||
phoneme_mask = (
|
||||
[1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] if use_mask else [1] * len(labels)
|
||||
)
|
||||
char_id = chars.index(query_char)
|
||||
position_id = text2token[query_id] + 1 # [CLS] token locate at first place
|
||||
query_char = text_for_query[query_id_for_query]
|
||||
if use_mask:
|
||||
phoneme_mask = char_phoneme_masks[query_char]
|
||||
else:
|
||||
phoneme_mask = full_phoneme_mask
|
||||
char_id = char2id[query_char]
|
||||
position_id = text2token_for_query[query_id_for_query] + 1 # [CLS] token locate at first place
|
||||
|
||||
input_ids.append(input_id)
|
||||
token_type_ids.append(token_type_id)
|
||||
@ -83,10 +141,15 @@ def prepare_onnx_input(
|
||||
char_ids.append(char_id)
|
||||
position_ids.append(position_id)
|
||||
|
||||
max_token_length = max(len(seq) for seq in input_ids)
|
||||
|
||||
def _pad_sequences(sequences, pad_value=0):
|
||||
return [seq + [pad_value] * (max_token_length - len(seq)) for seq in sequences]
|
||||
|
||||
outputs = {
|
||||
"input_ids": np.array(input_ids).astype(np.int64),
|
||||
"token_type_ids": np.array(token_type_ids).astype(np.int64),
|
||||
"attention_masks": np.array(attention_masks).astype(np.int64),
|
||||
"input_ids": np.array(_pad_sequences(input_ids, pad_value=0)).astype(np.int64),
|
||||
"token_type_ids": np.array(_pad_sequences(token_type_ids, pad_value=0)).astype(np.int64),
|
||||
"attention_masks": np.array(_pad_sequences(attention_masks, pad_value=0)).astype(np.int64),
|
||||
"phoneme_masks": np.array(phoneme_masks).astype(np.float32),
|
||||
"char_ids": np.array(char_ids).astype(np.int64),
|
||||
"position_ids": np.array(position_ids).astype(np.int64),
|
||||
|
||||
@ -10,7 +10,6 @@ from typing import Any, Dict, List, Tuple
|
||||
import numpy as np
|
||||
import onnxruntime
|
||||
import requests
|
||||
import torch
|
||||
from opencc import OpenCC
|
||||
from pypinyin import Style, pinyin
|
||||
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
||||
@ -22,9 +21,8 @@ from .utils import load_config
|
||||
onnxruntime.set_default_logger_severity(3)
|
||||
try:
|
||||
onnxruntime.preload_dlls()
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
# traceback.print_exc()
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
model_version = "1.1"
|
||||
@ -55,6 +53,24 @@ def predict(session, onnx_input: Dict[str, Any], labels: List[str]) -> Tuple[Lis
|
||||
return all_preds, all_confidences
|
||||
|
||||
|
||||
def _load_json_from_candidates(filename: str, candidate_dirs: List[str]) -> Dict[str, Any]:
|
||||
for candidate_dir in candidate_dirs:
|
||||
if not candidate_dir:
|
||||
continue
|
||||
json_path = os.path.join(candidate_dir, filename)
|
||||
if os.path.exists(json_path):
|
||||
with open(json_path, "r", encoding="utf-8") as fr:
|
||||
return json.load(fr)
|
||||
raise FileNotFoundError(f"Cannot locate {filename} in candidate dirs: {candidate_dirs}")
|
||||
|
||||
|
||||
def _find_first_existing_file(*paths: str) -> str:
|
||||
for path in paths:
|
||||
if path and os.path.exists(path):
|
||||
return path
|
||||
raise FileNotFoundError(f"Files not found: {paths}")
|
||||
|
||||
|
||||
def download_and_decompress(model_dir: str = "G2PWModel/"):
|
||||
if not os.path.exists(model_dir):
|
||||
parent_directory = os.path.dirname(model_dir)
|
||||
@ -62,7 +78,7 @@ def download_and_decompress(model_dir: str = "G2PWModel/"):
|
||||
extract_dir = os.path.join(parent_directory, "G2PWModel_1.1")
|
||||
extract_dir_new = os.path.join(parent_directory, "G2PWModel")
|
||||
print("Downloading g2pw model...")
|
||||
modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip" # "https://paddlespeech.cdn.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip"
|
||||
modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip"
|
||||
with requests.get(modelscope_url, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
with open(zip_dir, "wb") as f:
|
||||
@ -79,7 +95,7 @@ def download_and_decompress(model_dir: str = "G2PWModel/"):
|
||||
return model_dir
|
||||
|
||||
|
||||
class G2PWOnnxConverter:
|
||||
class _G2PWBaseOnnxConverter:
|
||||
def __init__(
|
||||
self,
|
||||
model_dir: str = "G2PWModel/",
|
||||
@ -87,33 +103,16 @@ class G2PWOnnxConverter:
|
||||
model_source: str = None,
|
||||
enable_non_tradional_chinese: bool = False,
|
||||
):
|
||||
uncompress_path = download_and_decompress(model_dir)
|
||||
|
||||
sess_options = onnxruntime.SessionOptions()
|
||||
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
|
||||
sess_options.intra_op_num_threads = 2 if torch.cuda.is_available() else 0
|
||||
if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
|
||||
self.session_g2pW = onnxruntime.InferenceSession(
|
||||
os.path.join(uncompress_path, "g2pW.onnx"),
|
||||
sess_options=sess_options,
|
||||
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
|
||||
)
|
||||
else:
|
||||
self.session_g2pW = onnxruntime.InferenceSession(
|
||||
os.path.join(uncompress_path, "g2pW.onnx"),
|
||||
sess_options=sess_options,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
self.config = load_config(config_path=os.path.join(uncompress_path, "config.py"), use_default=True)
|
||||
self.model_dir = download_and_decompress(model_dir)
|
||||
self.config = load_config(config_path=os.path.join(self.model_dir, "config.py"), use_default=True)
|
||||
|
||||
self.model_source = model_source if model_source else self.config.model_source
|
||||
self.enable_opencc = enable_non_tradional_chinese
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)
|
||||
|
||||
polyphonic_chars_path = os.path.join(uncompress_path, "POLYPHONIC_CHARS.txt")
|
||||
monophonic_chars_path = os.path.join(uncompress_path, "MONOPHONIC_CHARS.txt")
|
||||
polyphonic_chars_path = os.path.join(self.model_dir, "POLYPHONIC_CHARS.txt")
|
||||
monophonic_chars_path = os.path.join(self.model_dir, "MONOPHONIC_CHARS.txt")
|
||||
|
||||
self.polyphonic_chars = [
|
||||
line.split("\t") for line in open(polyphonic_chars_path, encoding="utf-8").read().strip().split("\n")
|
||||
]
|
||||
@ -149,31 +148,47 @@ class G2PWOnnxConverter:
|
||||
)
|
||||
|
||||
self.chars = sorted(list(self.char2phonemes.keys()))
|
||||
self.char2id = {char: idx for idx, char in enumerate(self.chars)}
|
||||
self.char_phoneme_masks = (
|
||||
{
|
||||
char: [1 if i in self.char2phonemes[char] else 0 for i in range(len(self.labels))]
|
||||
for char in self.char2phonemes
|
||||
}
|
||||
if self.config.use_mask
|
||||
else None
|
||||
)
|
||||
|
||||
self.polyphonic_chars_new = set(self.chars)
|
||||
for char in self.non_polyphonic:
|
||||
if char in self.polyphonic_chars_new:
|
||||
self.polyphonic_chars_new.remove(char)
|
||||
self.polyphonic_chars_new.discard(char)
|
||||
|
||||
self.monophonic_chars_dict = {char: phoneme for char, phoneme in self.monophonic_chars}
|
||||
for char in self.non_monophonic:
|
||||
if char in self.monophonic_chars_dict:
|
||||
self.monophonic_chars_dict.pop(char)
|
||||
self.monophonic_chars_dict.pop(char, None)
|
||||
|
||||
self.pos_tags = ["UNK", "A", "C", "D", "I", "N", "P", "T", "V", "DE", "SHI"]
|
||||
default_asset_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "G2PWModel"))
|
||||
candidate_asset_dirs = [self.model_dir, default_asset_dir]
|
||||
self.bopomofo_convert_dict = _load_json_from_candidates(
|
||||
"bopomofo_to_pinyin_wo_tune_dict.json", candidate_asset_dirs
|
||||
)
|
||||
self.char_bopomofo_dict = _load_json_from_candidates("char_bopomofo_dict.json", candidate_asset_dirs)
|
||||
|
||||
with open(os.path.join(uncompress_path, "bopomofo_to_pinyin_wo_tune_dict.json"), "r", encoding="utf-8") as fr:
|
||||
self.bopomofo_convert_dict = json.load(fr)
|
||||
self.style_convert_func = {
|
||||
"bopomofo": lambda x: x,
|
||||
"pinyin": self._convert_bopomofo_to_pinyin,
|
||||
}[style]
|
||||
|
||||
with open(os.path.join(uncompress_path, "char_bopomofo_dict.json"), "r", encoding="utf-8") as fr:
|
||||
self.char_bopomofo_dict = json.load(fr)
|
||||
|
||||
if self.enable_opencc:
|
||||
self.cc = OpenCC("s2tw")
|
||||
self.enable_sentence_dedup = os.getenv("g2pw_sentence_dedup", "true").strip().lower() in {
|
||||
"1",
|
||||
"true",
|
||||
"yes",
|
||||
"y",
|
||||
"on",
|
||||
}
|
||||
# 聚焦到多音字附近上下文,默认左右各16字;设为0表示关闭裁剪(整句)。
|
||||
self.polyphonic_context_chars = max(0, int(os.getenv("g2pw_polyphonic_context_chars", "16")))
|
||||
|
||||
def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
|
||||
tone = bopomofo[-1]
|
||||
@ -181,9 +196,8 @@ class G2PWOnnxConverter:
|
||||
component = self.bopomofo_convert_dict.get(bopomofo[:-1])
|
||||
if component:
|
||||
return component + tone
|
||||
else:
|
||||
print(f'Warning: "{bopomofo}" cannot convert to pinyin')
|
||||
return None
|
||||
print(f'Warning: "{bopomofo}" cannot convert to pinyin')
|
||||
return None
|
||||
|
||||
def __call__(self, sentences: List[str]) -> List[List[str]]:
|
||||
if isinstance(sentences, str):
|
||||
@ -197,51 +211,147 @@ class G2PWOnnxConverter:
|
||||
translated_sentences.append(translated_sent)
|
||||
sentences = translated_sentences
|
||||
|
||||
texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences)
|
||||
texts, model_query_ids, result_query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences)
|
||||
if len(texts) == 0:
|
||||
# sentences no polyphonic words
|
||||
return partial_results
|
||||
|
||||
onnx_input = prepare_onnx_input(
|
||||
model_input = prepare_onnx_input(
|
||||
tokenizer=self.tokenizer,
|
||||
labels=self.labels,
|
||||
char2phonemes=self.char2phonemes,
|
||||
chars=self.chars,
|
||||
texts=texts,
|
||||
query_ids=query_ids,
|
||||
query_ids=model_query_ids,
|
||||
use_mask=self.config.use_mask,
|
||||
window_size=None,
|
||||
char2id=self.char2id,
|
||||
char_phoneme_masks=self.char_phoneme_masks,
|
||||
)
|
||||
|
||||
preds, confidences = predict(session=self.session_g2pW, onnx_input=onnx_input, labels=self.labels)
|
||||
if not model_input:
|
||||
return partial_results
|
||||
|
||||
if self.enable_sentence_dedup:
|
||||
preds, _confidences = self._predict_with_sentence_dedup(model_input=model_input, texts=texts)
|
||||
else:
|
||||
preds, _confidences = self._predict(model_input=model_input)
|
||||
|
||||
if self.config.use_char_phoneme:
|
||||
preds = [pred.split(" ")[1] for pred in preds]
|
||||
|
||||
results = partial_results
|
||||
for sent_id, query_id, pred in zip(sent_ids, query_ids, preds):
|
||||
for sent_id, query_id, pred in zip(sent_ids, result_query_ids, preds):
|
||||
results[sent_id][query_id] = self.style_convert_func(pred)
|
||||
|
||||
return results
|
||||
|
||||
def _prepare_data(self, sentences: List[str]) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
|
||||
texts, query_ids, sent_ids, partial_results = [], [], [], []
|
||||
def _prepare_data(
|
||||
self, sentences: List[str]
|
||||
) -> Tuple[List[str], List[int], List[int], List[int], List[List[str]]]:
|
||||
texts, model_query_ids, result_query_ids, sent_ids, partial_results = [], [], [], [], []
|
||||
for sent_id, sent in enumerate(sentences):
|
||||
# pypinyin works well for Simplified Chinese than Traditional Chinese
|
||||
sent_s = tranditional_to_simplified(sent)
|
||||
pypinyin_result = pinyin(sent_s, neutral_tone_with_five=True, style=Style.TONE3)
|
||||
partial_result = [None] * len(sent)
|
||||
polyphonic_indices: List[int] = []
|
||||
for i, char in enumerate(sent):
|
||||
if char in self.polyphonic_chars_new:
|
||||
texts.append(sent)
|
||||
query_ids.append(i)
|
||||
sent_ids.append(sent_id)
|
||||
polyphonic_indices.append(i)
|
||||
elif char in self.monophonic_chars_dict:
|
||||
partial_result[i] = self.style_convert_func(self.monophonic_chars_dict[char])
|
||||
elif char in self.char_bopomofo_dict:
|
||||
partial_result[i] = pypinyin_result[i][0]
|
||||
# partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])
|
||||
else:
|
||||
partial_result[i] = pypinyin_result[i][0]
|
||||
|
||||
if polyphonic_indices:
|
||||
if self.polyphonic_context_chars > 0:
|
||||
left = max(0, polyphonic_indices[0] - self.polyphonic_context_chars)
|
||||
right = min(len(sent), polyphonic_indices[-1] + self.polyphonic_context_chars + 1)
|
||||
sent_for_predict = sent[left:right]
|
||||
query_offset = left
|
||||
else:
|
||||
sent_for_predict = sent
|
||||
query_offset = 0
|
||||
|
||||
for index in polyphonic_indices:
|
||||
texts.append(sent_for_predict)
|
||||
model_query_ids.append(index - query_offset)
|
||||
result_query_ids.append(index)
|
||||
sent_ids.append(sent_id)
|
||||
|
||||
partial_results.append(partial_result)
|
||||
return texts, query_ids, sent_ids, partial_results
|
||||
return texts, model_query_ids, result_query_ids, sent_ids, partial_results
|
||||
|
||||
def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]:
    """Run the model on ``model_input`` and return ``(predictions, confidences)``.

    The base converter has no inference backend of its own, so this hook
    always raises; subclasses are expected to override it.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    raise NotImplementedError()
|
||||
|
||||
def _predict_with_sentence_dedup(
    self, model_input: Dict[str, Any], texts: List[str]
) -> Tuple[List[str], List[float]]:
    """Predict per query while sending each distinct text's token inputs only once.

    Queries are grouped by their text. For a group with several queries the
    token-level arrays (``input_ids``, ``token_type_ids``, ``attention_masks``)
    are reduced to their first row, while every other array in ``model_input``
    keeps one row per query. Results are scattered back to the original query
    order.

    Args:
        model_input: Mapping of input name to an array indexable by a list of
            row indices.
        texts: The text each query row was built from, aligned with the rows
            of ``model_input``.

    Returns:
        ``(preds, confidences)`` aligned with ``texts``.
    """
    total = len(texts)
    if total <= 1:
        # Nothing to share with a single query (or none at all).
        return self._predict(model_input=model_input)

    rows_by_text: Dict[str, List[int]] = {}
    for row, sentence in enumerate(texts):
        if sentence in rows_by_text:
            rows_by_text[sentence].append(row)
        else:
            rows_by_text[sentence] = [row]

    if len(rows_by_text) == total:
        # Every text is unique, so one plain batched call is enough.
        return self._predict(model_input=model_input)

    shared_names = ("input_ids", "token_type_ids", "attention_masks")
    preds: List[str] = [""] * total
    confidences: List[float] = [0.0] * total
    for rows in rows_by_text.values():
        batch = {}
        for name, value in model_input.items():
            batch[name] = value[rows]
        if len(rows) > 1:
            # NOTE(review): presumably the backend broadcasts this single token
            # row across the group's queries - confirm against predict().
            for name in shared_names:
                batch[name] = batch[name][:1]

        batch_preds, batch_confidences = self._predict(model_input=batch)
        for row, pred, confidence in zip(rows, batch_preds, batch_confidences):
            preds[row] = pred
            confidences[row] = confidence

    return preds, confidences
|
||||
|
||||
|
||||
class G2PWOnnxConverter(_G2PWBaseOnnxConverter):
    """Converter whose predictions come from an onnxruntime inference session."""

    def __init__(
        self,
        model_dir: str = "G2PWModel/",
        style: str = "bopomofo",
        model_source: str = None,
        enable_non_tradional_chinese: bool = False,
    ):
        """Initialise the shared converter state, then open the ONNX session.

        All arguments are forwarded unchanged to the base class. The model
        file is looked up in ``self.model_dir`` as ``g2pW.onnx`` first and
        ``g2pw.onnx`` second.
        """
        super().__init__(
            model_dir=model_dir,
            style=style,
            model_source=model_source,
            enable_non_tradional_chinese=enable_non_tradional_chinese,
        )

        session_options = onnxruntime.SessionOptions()
        session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        session_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
        session_options.intra_op_num_threads = 2

        weights_path = _find_first_existing_file(
            *(os.path.join(self.model_dir, file_name) for file_name in ("g2pW.onnx", "g2pw.onnx"))
        )

        # CPU is always listed; CUDA is put in front of it when onnxruntime reports it.
        providers = ["CPUExecutionProvider"]
        if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
            providers.insert(0, "CUDAExecutionProvider")

        self.session_g2pw = onnxruntime.InferenceSession(
            weights_path,
            sess_options=session_options,
            providers=providers,
        )

    def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]:
        """Delegate to the module-level ``predict`` using this converter's session and labels."""
        return predict(session=self.session_g2pw, onnx_input=model_input, labels=self.labels)
|
||||
|
||||
250
tools/bench_api_v3_scheduler_submit.py
Normal file
250
tools/bench_api_v3_scheduler_submit.py
Normal file
@ -0,0 +1,250 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
ROOT_DIR = Path(__file__).resolve().parents[1]
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Parse the benchmark's command-line options.

    ``--concurrency`` is the only required option; every other option has a
    default. Path-valued defaults are resolved relative to ``ROOT_DIR``.
    """
    option_table = (
        ("--base-url", {"type": str, "default": "http://127.0.0.1:9880"}),
        ("--endpoint", {"type": str, "default": "/tts_scheduler_submit"}),
        ("--concurrency", {"type": int, "required": True}),
        ("--timeout-sec", {"type": float, "default": 120.0}),
        ("--server-pid", {"type": int, "default": None}),
        ("--poll-interval-sec", {"type": float, "default": 0.1}),
        ("--text-lang", {"type": str, "default": "zh"}),
        ("--prompt-lang", {"type": str, "default": "zh"}),
        ("--media-type", {"type": str, "default": "wav"}),
        ("--top-k", {"type": int, "default": 15}),
        ("--top-p", {"type": float, "default": 1.0}),
        ("--temperature", {"type": float, "default": 1.0}),
        ("--repetition-penalty", {"type": float, "default": 1.35}),
        ("--sample-steps", {"type": int, "default": 32}),
        ("--text-file", {"type": Path, "default": ROOT_DIR / "test_cn.txt"}),
        ("--wav-dir", {"type": Path, "default": ROOT_DIR / "testwav"}),
        ("--output-dir", {"type": Path, "default": ROOT_DIR / "TEMP/api_v3_bench"}),
    )
    cli = argparse.ArgumentParser(
        description="Benchmark api_v3 /tts_scheduler_submit concurrency and GPU memory."
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
|
||||
|
||||
|
||||
def load_requests(args: argparse.Namespace) -> List[Dict[str, Any]]:
    """Build one request payload per concurrent slot.

    Reference audio is taken from the ``*.wav`` files in ``args.wav_dir``
    whose duration is between 3 and 10 seconds (inclusive); files that the
    ``wave`` module cannot parse, or that report a non-positive frame rate,
    are skipped instead of aborting the run. Each selected wav needs a
    sibling ``.lab`` file holding its prompt text. Wavs and the non-empty
    lines of ``args.text_file`` are both assigned round-robin by index.

    Args:
        args: Parsed options; reads ``wav_dir``, ``text_file``,
            ``concurrency`` and the per-request sampling/format options.

    Returns:
        ``args.concurrency`` payload dictionaries.

    Raises:
        FileNotFoundError: If no usable wav exists, or a selected wav has no
            ``.lab`` file.
        ValueError: If ``args.text_file`` has no non-empty line.
    """
    wav_paths: List[Path] = []
    for wav_path in sorted(args.wav_dir.glob("*.wav")):
        try:
            with wave.open(str(wav_path), "rb") as handle:
                frame_count = handle.getnframes()
                frame_rate = handle.getframerate()
        except (wave.Error, EOFError):
            # Truncated or non-PCM file: not a usable reference, keep scanning.
            continue
        if frame_rate <= 0:
            # The duration is undefined; avoid a ZeroDivisionError below.
            continue
        if 3.0 <= frame_count / float(frame_rate) <= 10.0:
            wav_paths.append(wav_path)
    if not wav_paths:
        raise FileNotFoundError(f"没有找到 3-10 秒合法 wav: {args.wav_dir}")

    text_lines = [line.strip() for line in args.text_file.read_text(encoding="utf-8").splitlines() if line.strip()]
    if not text_lines:
        raise ValueError(f"没有找到有效文本行: {args.text_file}")

    # Wavs are reused when concurrency exceeds their count; read each .lab once.
    prompt_texts: Dict[Path, str] = {}
    payloads: List[Dict[str, Any]] = []
    for index in range(args.concurrency):
        wav_path = wav_paths[index % len(wav_paths)]
        if wav_path not in prompt_texts:
            lab_path = wav_path.with_suffix(".lab")
            if not lab_path.exists():
                raise FileNotFoundError(f"缺少参考文本: {lab_path}")
            prompt_texts[wav_path] = lab_path.read_text(encoding="utf-8").strip()
        payloads.append(
            {
                "request_id": f"bench_{args.concurrency:03d}_{index:03d}",
                "text": text_lines[index % len(text_lines)],
                "text_lang": args.text_lang,
                "ref_audio_path": str(wav_path),
                "prompt_lang": args.prompt_lang,
                "prompt_text": prompt_texts[wav_path],
                "top_k": int(args.top_k),
                "top_p": float(args.top_p),
                "temperature": float(args.temperature),
                "repetition_penalty": float(args.repetition_penalty),
                "sample_steps": int(args.sample_steps),
                "media_type": args.media_type,
                "timeout_sec": float(args.timeout_sec),
            }
        )
    return payloads
|
||||
|
||||
|
||||
class GpuMemoryPoller:
    """Sample GPU memory usage from ``nvidia-smi`` on a background thread.

    Each sample is a dict ``{"ts": <time.time()>, "used_mb": <int or None>}``;
    ``used_mb`` is ``None`` when the ``nvidia-smi`` query failed.
    """

    def __init__(self, server_pid: Optional[int], interval_sec: float):
        """Store the settings; polling does not begin until ``start()``.

        Args:
            server_pid: Only count memory of this process id; ``None`` sums
                every process that ``nvidia-smi`` lists.
            interval_sec: Wait between two consecutive samples.
        """
        self.server_pid = server_pid
        self.interval_sec = interval_sec
        self._stop = threading.Event()
        self.samples: List[Dict[str, Any]] = []
        self.thread: Optional[threading.Thread] = None

    @staticmethod
    def _sum_used_mb(report: str, server_pid: Optional[int]) -> int:
        """Sum the memory column of ``pid, used_memory`` CSV lines.

        Lines that do not consist of exactly two integer fields are ignored.
        When ``server_pid`` is given, only lines for that pid are counted, so
        a pid that is absent from ``report`` yields 0.
        """
        total = 0
        for line in report.splitlines():
            fields = [item.strip() for item in line.split(",")]
            if len(fields) != 2:
                continue
            try:
                pid = int(fields[0])
                used_mb = int(fields[1])
            except ValueError:
                continue
            if server_pid is None or pid == server_pid:
                total += used_mb
        return total

    def _query_memory_mb(self) -> Optional[int]:
        """Return the currently used memory in MB, or ``None`` if the query failed."""
        try:
            result = subprocess.run(
                [
                    "nvidia-smi",
                    "--query-compute-apps=pid,used_gpu_memory",
                    "--format=csv,noheader,nounits",
                ],
                check=True,
                capture_output=True,
                text=True,
            )
        except (OSError, subprocess.SubprocessError, ValueError):
            # Missing binary, non-zero exit status or undecodable output:
            # record the sample as unavailable rather than stopping the poller.
            return None
        return self._sum_used_mb(result.stdout, self.server_pid)

    def _run(self) -> None:
        """Thread body: take a sample, then wait for the interval or for ``stop()``."""
        while not self._stop.is_set():
            used_mb = self._query_memory_mb()
            self.samples.append({"ts": time.time(), "used_mb": used_mb})
            self._stop.wait(self.interval_sec)

    def start(self) -> None:
        """Launch the polling thread as a daemon."""
        self.thread = threading.Thread(target=self._run, daemon=True)
        self.thread.start()

    def stop(self) -> None:
        """Signal the polling thread to finish and wait up to two seconds for it."""
        self._stop.set()
        if self.thread is not None:
            self.thread.join(timeout=2.0)

    def summary(self) -> Dict[str, Any]:
        """Summarise the samples collected so far.

        Returns:
            A dict with the first, peak and last valid readings, the peak's
            increase over the first reading and its timestamp, the number of
            samples, and the samples themselves. Reading-derived fields are
            ``None`` when no sample carries a value.
        """
        # Work on a snapshot: stop() gives up joining after its timeout, so the
        # thread may still append while the summary is being assembled.
        samples = list(self.samples)
        valid = [item for item in samples if item["used_mb"] is not None]
        peak = max(valid, key=lambda item: item["used_mb"]) if valid else None
        first = valid[0] if valid else None
        last = valid[-1] if valid else None
        return {
            "server_pid": self.server_pid,
            "sample_count": int(len(samples)),
            "start_used_mb": None if first is None else int(first["used_mb"]),
            "peak_used_mb": None if peak is None else int(peak["used_mb"]),
            "peak_delta_mb": None if peak is None or first is None else int(peak["used_mb"] - first["used_mb"]),
            "end_used_mb": None if last is None else int(last["used_mb"]),
            "peak_ts": None if peak is None else float(peak["ts"]),
            "samples": samples,
        }
|
||||
|
||||
|
||||
async def submit_one(client: "httpx.AsyncClient", url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """POST one payload as JSON and describe the outcome without raising.

    Args:
        client: Client whose ``post`` coroutine sends the request.
        url: Full endpoint URL.
        payload: Request body; its ``request_id`` is echoed in the result.

    Returns:
        On a received response: status code, elapsed milliseconds, content
        type, body size and the ``x-`` prefixed response headers, plus
        ``error_body`` when the status is not 200. If anything raises, a
        record with ``status_code`` -1 and the exception's ``repr``.
    """
    started = time.perf_counter()
    try:
        response = await client.post(url, json=payload)
        elapsed_ms = (time.perf_counter() - started) * 1000.0
        item = {
            "request_id": payload["request_id"],
            "status_code": int(response.status_code),
            "elapsed_ms": float(elapsed_ms),
            "content_type": response.headers.get("content-type"),
            "audio_bytes": int(len(response.content)),
            # Store the names lower-cased: the timing headers are later looked
            # up by their lower-case names in this plain, case-sensitive dict.
            "headers": {
                key.lower(): value for key, value in response.headers.items() if key.lower().startswith("x-")
            },
        }
        if response.status_code != 200:
            # Prefer the structured error; fall back to the raw text body.
            try:
                item["error_body"] = response.json()
            except Exception:
                item["error_body"] = response.text
        return item
    except Exception as exc:
        # A benchmark run must survive individual request failures, so every
        # error is turned into a result record.
        return {
            "request_id": payload["request_id"],
            "status_code": -1,
            "elapsed_ms": float((time.perf_counter() - started) * 1000.0),
            "exception": repr(exc),
        }
|
||||
|
||||
|
||||
async def run_benchmark(args: argparse.Namespace) -> Dict[str, Any]:
    """Send all payloads concurrently while GPU memory is polled, then summarise.

    Args:
        args: Parsed options; reads ``base_url``, ``endpoint``, ``server_pid``,
            ``poll_interval_sec``, ``concurrency`` and ``timeout_sec`` here and
            is also passed to ``load_requests``.

    Returns:
        A dict with request counts, wall time, average/maximum of the
        ``x-request-total-ms`` and ``x-worker-total-ms`` headers reported by
        successful responses (``None`` when no response carried the header),
        the GPU memory summary and the per-request results.
    """

    def _mean_or_none(values: List[float]) -> Optional[float]:
        return float(sum(values) / len(values)) if values else None

    def _max_or_none(values: List[float]) -> Optional[float]:
        return float(max(values)) if values else None

    payloads = load_requests(args)
    target_url = args.base_url.rstrip("/") + args.endpoint
    memory_poller = GpuMemoryPoller(server_pid=args.server_pid, interval_sec=args.poll_interval_sec)

    client_options = {
        "limits": httpx.Limits(max_connections=args.concurrency, max_keepalive_connections=args.concurrency),
        # The read timeout leaves a 10 second margin over the per-request timeout.
        "timeout": httpx.Timeout(connect=10.0, read=args.timeout_sec + 10.0, write=10.0, pool=10.0),
    }

    t0 = time.perf_counter()
    memory_poller.start()
    try:
        async with httpx.AsyncClient(**client_options) as client:
            results = await asyncio.gather(*(submit_one(client, target_url, payload) for payload in payloads))
    finally:
        # Stop sampling even when the requests fail as a whole.
        memory_poller.stop()
    wall_ms = (time.perf_counter() - t0) * 1000.0

    succeeded = [item for item in results if item["status_code"] == 200]
    failure_count = len(results) - len(succeeded)

    timings: Dict[str, List[float]] = {"x-request-total-ms": [], "x-worker-total-ms": []}
    for item in succeeded:
        reported = item.get("headers", {})
        for header_name, bucket in timings.items():
            if header_name in reported:
                bucket.append(float(reported[header_name]))
    request_total_ms = timings["x-request-total-ms"]
    worker_total_ms = timings["x-worker-total-ms"]

    return {
        "concurrency": int(args.concurrency),
        "server_pid": args.server_pid,
        "request_count": int(len(payloads)),
        "wall_ms": float(wall_ms),
        "success_count": int(len(succeeded)),
        "failure_count": int(failure_count),
        "request_total_ms_avg": _mean_or_none(request_total_ms),
        "request_total_ms_max": _max_or_none(request_total_ms),
        "worker_total_ms_avg": _mean_or_none(worker_total_ms),
        "worker_total_ms_max": _max_or_none(worker_total_ms),
        "gpu_memory": memory_poller.summary(),
        "results": results,
    }
|
||||
|
||||
|
||||
def main() -> None:
    """Run the benchmark, write ``summary.json`` and print a short headline.

    The full summary goes to ``<output-dir>/concurrency_<NN>/summary.json``;
    the headline printed to stdout repeats its key figures and the path of
    that file.
    """
    args = parse_args()
    run_dir = args.output_dir / f"concurrency_{args.concurrency:02d}"
    run_dir.mkdir(parents=True, exist_ok=True)

    summary = asyncio.run(run_benchmark(args))

    summary_path = run_dir / "summary.json"
    summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")

    headline = {key: summary[key] for key in ("concurrency", "success_count", "failure_count", "wall_ms")}
    headline["gpu_peak_used_mb"] = summary["gpu_memory"]["peak_used_mb"]
    headline["request_total_ms_avg"] = summary["request_total_ms_avg"]
    headline["request_total_ms_max"] = summary["request_total_ms_max"]
    headline["summary_path"] = str(summary_path)
    print(json.dumps(headline, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
887
tools/t2s_memory_breakdown.py
Normal file
887
tools/t2s_memory_breakdown.py
Normal file
@ -0,0 +1,887 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import contextlib
|
||||
import json
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
ROOT_DIR = Path(__file__).resolve().parents[1]
|
||||
if str(ROOT_DIR) not in sys.path:
|
||||
sys.path.append(str(ROOT_DIR))
|
||||
gpt_sovits_dir = ROOT_DIR / "GPT_SoVITS"
|
||||
if str(gpt_sovits_dir) not in sys.path:
|
||||
sys.path.append(str(gpt_sovits_dir))
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config # noqa: E402
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( # noqa: E402
|
||||
SchedulerRequestSpec,
|
||||
T2SRequestState,
|
||||
T2SRunningRequest,
|
||||
_build_decode_batch_from_running,
|
||||
build_prefill_batch,
|
||||
prepare_request_state,
|
||||
run_decode_step_for_running,
|
||||
run_prefill_step,
|
||||
)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Build the CLI parser for the T2S memory breakdown tool and parse argv.

    Returns:
        The parsed ``argparse.Namespace``. Path-like options are converted to
        ``pathlib.Path`` objects by argparse.
    """
    parser = argparse.ArgumentParser(description="Break down T2S CUDA memory by stage and tensor groups.")
    add = parser.add_argument

    # Pipeline config and where the requests come from.
    add("--config", type=Path, default=ROOT_DIR / "GPT_SoVITS/configs/tts_infer.yaml")
    add("--request-manifest", type=Path, default=None)
    add("--scenario", type=str, default="auto4", choices=["auto4", "single"])
    add("--auto-count", type=int, default=4)
    add("--auto-wav-dir", type=Path, default=ROOT_DIR / "testwav")
    add("--auto-text-file", type=Path, default=ROOT_DIR / "test_cn.txt")

    # Reference audio / prompt and target text.
    add("--ref-audio", type=Path, default=ROOT_DIR / "test.wav")
    add("--prompt-text", type=str, default="是啊,主要是因为有调研需求的学者少了。")
    add("--prompt-lang", type=str, default="zh")
    add("--text", type=str, default=None)
    add("--text-file", type=Path, default=ROOT_DIR / "test_en.txt")
    add("--text-lang", type=str, default="zh")

    # Sampling parameters forwarded into each request spec.
    add("--top-k", type=int, default=15)
    add("--top-p", type=float, default=1.0)
    add("--temperature", type=float, default=1.0)
    add("--repetition-penalty", type=float, default=1.35)
    add("--early-stop-num", type=int, default=-1)
    add("--max-steps", type=int, default=1500)

    # Run control.
    add("--seed", type=int, default=1234)
    add("--warmup", action="store_true")
    add("--worker-rounds", type=int, default=1)
    add("--worker-grad-mode", type=str, default="default", choices=["default", "inference_mode"])
    add("--compare-worker-grad-modes", action="store_true")
    add(
        "--output-dir",
        type=Path,
        default=ROOT_DIR / "TEMP/t2s_memory_breakdown/run1",
    )
    return parser.parse_args()
|
||||
|
||||
|
||||
def set_seed(seed: int, use_cuda: bool) -> None:
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    Args:
        seed: Seed applied to every generator.
        use_cuda: When true and CUDA is available, the CUDA generators are
            seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not (use_cuda and torch.cuda.is_available()):
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def _sync_device(device: Any) -> None:
|
||||
try:
|
||||
device_str = str(device)
|
||||
if device_str.startswith("cuda") and torch.cuda.is_available():
|
||||
torch.cuda.synchronize(device)
|
||||
elif device_str == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"):
|
||||
torch.mps.synchronize()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def bytes_to_mb(num_bytes: int) -> float:
    """Convert a byte count to mebibytes (1 MB = 1024 * 1024 bytes)."""
    bytes_per_mb = 1024.0 * 1024.0
    return float(num_bytes) / bytes_per_mb
|
||||
|
||||
|
||||
def tensor_nbytes(tensor: Optional[torch.Tensor]) -> int:
    """Return the logical size of ``tensor`` in bytes, or 0 for ``None``.

    The size is ``numel() * element_size()``, i.e. the bytes addressed by this
    tensor's elements, not the size of its underlying storage.
    """
    return 0 if tensor is None else int(tensor.element_size() * tensor.numel())
|
||||
|
||||
|
||||
def tensor_list_nbytes(items: Sequence[torch.Tensor]) -> int:
    """Return the summed ``tensor_nbytes`` of every tensor in ``items``."""
    total = 0
    for entry in items:
        total += tensor_nbytes(entry)
    return int(total)
|
||||
|
||||
|
||||
def model_nbytes(module: torch.nn.Module) -> int:
    """Return the bytes held by ``module``'s parameters plus its buffers."""
    weights = (*module.parameters(), *module.buffers())
    return int(sum(tensor_nbytes(weight) for weight in weights))
|
||||
|
||||
|
||||
def build_module_weight_summary(tts: TTS) -> Dict[str, Any]:
    """Summarise the weight footprint of every module loaded on ``tts``.

    Args:
        tts: Pipeline object exposing ``t2s_model``, ``vits_model``,
            ``bert_model``, ``cnhuhbert_model``, ``vocoder`` and ``sv_model``.
            Attributes that are ``None`` are reported with zero bytes.

    Returns:
        A dict with ``by_module`` (per-module ``bytes`` / ``mb``),
        ``total_bytes`` and ``total_mb``. Per-module figures are each module's
        full footprint; the totals count every distinct tensor only once.
    """
    modules = {
        "t2s_model": tts.t2s_model,
        # t2s_core is reached through t2s_model, so its weights are presumably
        # already part of the t2s_model figure; it is listed separately for
        # visibility only.
        "t2s_core": tts.t2s_model.model if tts.t2s_model is not None else None,
        "vits_model": tts.vits_model,
        "bert_model": tts.bert_model,
        "cnhuhbert_model": tts.cnhuhbert_model,
        "vocoder": tts.vocoder,
        "sv_model": tts.sv_model,
    }
    by_module = {}
    total_bytes = 0
    # Identity of every tensor already added to the total, so that a module
    # nested inside another listed module (or tied weights) is not counted
    # twice in total_bytes.
    counted_tensor_ids = set()
    for name, module in modules.items():
        module_bytes = model_nbytes(module) if module is not None else 0
        by_module[name] = {
            "bytes": int(module_bytes),
            "mb": bytes_to_mb(module_bytes),
        }
        if module is None:
            continue
        for tensor in (*module.parameters(), *module.buffers()):
            if id(tensor) in counted_tensor_ids:
                continue
            counted_tensor_ids.add(id(tensor))
            total_bytes += tensor_nbytes(tensor)
    return {
        "by_module": by_module,
        "total_bytes": int(total_bytes),
        "total_mb": bytes_to_mb(total_bytes),
    }
|
||||
|
||||
|
||||
def snapshot_live_cuda_tensors(top_k: int = 40) -> Dict[str, Any]:
    """Scan the garbage collector's tracked objects for live CUDA tensors.

    An object is considered when it is a tensor itself or exposes a tensor via
    a ``data`` attribute. Storages are de-duplicated by data pointer, and the
    first tensor seen for a storage supplies its reported dtype and shape.
    Objects that raise while being inspected are skipped.

    Args:
        top_k: Maximum number of entries kept in each ``top_*`` list.

    Returns:
        Unique-storage count and total size, plus the ``top_k`` largest
        storages and tensor views.
    """
    storage_by_ptr: Dict[int, Dict[str, Any]] = {}
    views: List[Dict[str, Any]] = []
    for candidate in gc.get_objects():
        try:
            if torch.is_tensor(candidate):
                found = candidate
            else:
                found = getattr(candidate, "data", None)
                if not torch.is_tensor(found):
                    found = None
            if found is None or not found.is_cuda:
                continue
            backing = found.untyped_storage()
            ptr = int(backing.data_ptr())
            shape = list(found.shape)
            dtype_name = str(found.dtype)
            device_name = str(found.device)
            if ptr not in storage_by_ptr:
                storage_by_ptr[ptr] = {
                    "storage_ptr": ptr,
                    "storage_bytes": int(backing.nbytes()),
                    "dtype": dtype_name,
                    "shape": shape,
                    "device": device_name,
                }
            views.append(
                {
                    "shape": list(found.shape),
                    "dtype": dtype_name,
                    "bytes": tensor_nbytes(found),
                    "device": device_name,
                }
            )
        except Exception:
            # Best effort: some tracked objects cannot be inspected safely.
            continue
    ranked_storages = sorted(storage_by_ptr.values(), key=lambda entry: entry["storage_bytes"], reverse=True)
    views.sort(key=lambda entry: entry["bytes"], reverse=True)
    unique_total = sum(entry["storage_bytes"] for entry in ranked_storages)
    return {
        "unique_storage_count": int(len(ranked_storages)),
        "unique_storage_total_bytes": int(unique_total),
        "unique_storage_total_mb": bytes_to_mb(unique_total),
        "top_storages": ranked_storages[:top_k],
        "top_tensor_views": views[:top_k],
    }
|
||||
|
||||
|
||||
def build_single_spec(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
    """Build the one-request spec list for the ``single`` scenario.

    The target text is ``args.text`` when given, otherwise the stripped
    contents of ``args.text_file`` (read as UTF-8).
    """
    text = args.text
    if text is None:
        text = args.text_file.read_text(encoding="utf-8").strip()
    spec = SchedulerRequestSpec(
        request_id="req_000",
        ref_audio_path=args.ref_audio,
        prompt_text=args.prompt_text,
        prompt_lang=args.prompt_lang,
        text=text,
        text_lang=args.text_lang,
        top_k=args.top_k,
        top_p=args.top_p,
        temperature=args.temperature,
        repetition_penalty=args.repetition_penalty,
        early_stop_num=args.early_stop_num,
        ready_step=0,
    )
    return [spec]
|
||||
|
||||
|
||||
def build_auto_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
    """Build ``args.auto_count`` request specs from a wav directory and a text file.

    The first ``auto_count`` ``*.wav`` files (sorted by path) of
    ``args.auto_wav_dir`` are used as reference audio. Each wav needs a sibling
    ``.lab`` file holding its prompt text. Target texts are the first
    ``auto_count`` non-empty lines of ``args.auto_text_file``.

    Raises:
        ValueError: If ``auto_count`` is not positive, or if there are fewer
            wav files or non-empty text lines than ``auto_count``.
        FileNotFoundError: If a wav file has no matching ``.lab`` file.
    """
    # A non-positive count would otherwise slip past the shortage checks below
    # (a negative value slices from the end) and yield a meaningless request set.
    if args.auto_count < 1:
        raise ValueError(f"--auto-count must be >= 1, got {args.auto_count}")
    wav_paths = sorted(args.auto_wav_dir.glob("*.wav"))[: args.auto_count]
    if len(wav_paths) < args.auto_count:
        raise ValueError(f"auto wav count不足,目录 {args.auto_wav_dir} 只有 {len(wav_paths)} 条 wav")
    text_lines = [line.strip() for line in args.auto_text_file.read_text(encoding="utf-8").splitlines() if line.strip()]
    if len(text_lines) < args.auto_count:
        raise ValueError(f"auto text lines不足,文件 {args.auto_text_file} 只有 {len(text_lines)} 行有效文本")
    # Honour --prompt-lang like build_single_spec does; "zh" remains the value
    # used when the namespace does not carry the option.
    prompt_lang = getattr(args, "prompt_lang", "zh")
    specs: List[SchedulerRequestSpec] = []
    for index, wav_path in enumerate(wav_paths):
        lab_path = wav_path.with_suffix(".lab")
        if not lab_path.exists():
            raise FileNotFoundError(f"找不到参考文本 {lab_path}")
        specs.append(
            SchedulerRequestSpec(
                request_id=f"req_{index:03d}",
                ref_audio_path=wav_path,
                prompt_text=lab_path.read_text(encoding="utf-8").strip(),
                prompt_lang=prompt_lang,
                text=text_lines[index],
                text_lang=args.text_lang,
                top_k=args.top_k,
                top_p=args.top_p,
                temperature=args.temperature,
                repetition_penalty=args.repetition_penalty,
                early_stop_num=args.early_stop_num,
                ready_step=0,
            )
        )
    return specs
|
||||
|
||||
|
||||
def load_request_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
    """Resolve the request specs for this run.

    When ``args.request_manifest`` is set, the JSON manifest is used: either a
    list of request objects or a dict with a ``"requests"`` list. Each request
    must give ``text`` or ``text_file``; sampling options fall back to the CLI
    values. Without a manifest, ``args.scenario`` selects the single-request
    or the auto-generated request set.

    Raises:
        ValueError: If a manifest request provides neither ``text`` nor
            ``text_file``.
    """
    manifest_path = args.request_manifest
    if manifest_path is None:
        if args.scenario == "single":
            return build_single_spec(args)
        return build_auto_specs(args)

    payload = json.loads(manifest_path.read_text(encoding="utf-8"))
    raw_requests = payload["requests"] if isinstance(payload, dict) else payload
    specs: List[SchedulerRequestSpec] = []
    for index, item in enumerate(raw_requests):
        text = item.get("text")
        if text is None:
            text_file = item.get("text_file")
            if text_file is None:
                raise ValueError(f"request[{index}] must provide text or text_file")
            text = Path(text_file).read_text(encoding="utf-8").strip()
        spec = SchedulerRequestSpec(
            request_id=item.get("request_id", f"req_{index:03d}"),
            ref_audio_path=Path(item["ref_audio_path"]),
            prompt_text=item["prompt_text"],
            prompt_lang=item.get("prompt_lang", "zh"),
            text=text,
            text_lang=item.get("text_lang", "zh"),
            top_k=int(item.get("top_k", args.top_k)),
            top_p=float(item.get("top_p", args.top_p)),
            temperature=float(item.get("temperature", args.temperature)),
            repetition_penalty=float(item.get("repetition_penalty", args.repetition_penalty)),
            early_stop_num=int(item.get("early_stop_num", args.early_stop_num)),
            ready_step=int(item.get("ready_step", 0)),
        )
        specs.append(spec)
    return specs
|
||||
|
||||
|
||||
def load_pipeline(config_path: Path) -> TTS:
    """Build a ``TTS`` pipeline from the config file at ``config_path``.

    The parsed configuration is printed before the pipeline is constructed.
    """
    pipeline_config = TTS_Config(str(config_path))
    print(pipeline_config)
    pipeline = TTS(pipeline_config)
    return pipeline
|
||||
|
||||
|
||||
def cuda_mem_snapshot(device: Any) -> Dict[str, float]:
    """Read the CUDA allocator counters for ``device`` in megabytes.

    Returns current and peak allocated/reserved memory. When ``device`` is not
    a CUDA device, or CUDA is unavailable, every value is ``0.0``. The device
    is synchronized before the counters are read.
    """
    on_cuda = str(device).startswith("cuda") and torch.cuda.is_available()
    if not on_cuda:
        return dict.fromkeys(
            ("allocated_mb", "reserved_mb", "max_allocated_mb", "max_reserved_mb"),
            0.0,
        )
    _sync_device(device)
    readers = {
        "allocated_mb": torch.cuda.memory_allocated,
        "reserved_mb": torch.cuda.memory_reserved,
        "max_allocated_mb": torch.cuda.max_memory_allocated,
        "max_reserved_mb": torch.cuda.max_memory_reserved,
    }
    return {key: bytes_to_mb(read(device)) for key, read in readers.items()}
|
||||
|
||||
|
||||
def stage_run(device: Any, fn) -> Tuple[Any, Dict[str, float]]:
    """Run ``fn()`` as one measured stage and report its memory and time cost.

    On CUDA the peak-memory counters are reset (after a GC pass and a device
    sync) so that the reported peak belongs to this stage only.

    Returns:
        ``(result, stats)`` where ``result`` is ``fn()``'s return value and
        ``stats`` is the post-stage memory snapshot extended with
        ``elapsed_ms``, ``delta_allocated_mb``, ``delta_reserved_mb`` and
        ``stage_peak_over_before_mb`` (clamped at zero).
    """
    cuda_active = str(device).startswith("cuda") and torch.cuda.is_available()
    if cuda_active:
        gc.collect()
        _sync_device(device)
        torch.cuda.reset_peak_memory_stats(device)

    baseline = cuda_mem_snapshot(device)
    t_start = time.perf_counter()
    result = fn()
    # Wait for queued device work so the timing covers the whole stage.
    _sync_device(device)
    elapsed_ms = (time.perf_counter() - t_start) * 1000.0

    stats = cuda_mem_snapshot(device)
    peak_growth = stats["max_allocated_mb"] - baseline["allocated_mb"]
    stats.update(
        elapsed_ms=float(elapsed_ms),
        delta_allocated_mb=float(stats["allocated_mb"] - baseline["allocated_mb"]),
        delta_reserved_mb=float(stats["reserved_mb"] - baseline["reserved_mb"]),
        stage_peak_over_before_mb=float(max(peak_growth, 0.0)),
    )
    return result, stats
|
||||
|
||||
|
||||
class GlobalPeakRecorder:
    """Collects labelled CUDA memory checkpoints over one continuous run.

    On construction with an available CUDA device, the allocator cache is
    emptied and the peak counters are reset, so all later checkpoints share a
    single peak baseline.
    """

    def __init__(self, device: Any):
        self.device = device
        # One dict per record() call, in call order.
        self.checkpoints: List[Dict[str, Any]] = []
        cuda_active = str(device).startswith("cuda") and torch.cuda.is_available()
        if not cuda_active:
            return
        gc.collect()
        _sync_device(device)
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)

    def record(self, label: str, **extra: Any) -> None:
        """Append a memory snapshot tagged with ``label`` and any extra fields."""
        entry = {**cuda_mem_snapshot(self.device), "label": label, **extra}
        self.checkpoints.append(entry)

    def summary(self) -> Dict[str, Any]:
        """Return the checkpoint with the highest ``max_allocated_mb`` plus all checkpoints.

        With no checkpoints the peak values are ``0.0`` and the label is ``None``.
        On ties the earliest checkpoint is reported.
        """
        if not self.checkpoints:
            return {
                "peak_allocated_mb": 0.0,
                "peak_reserved_mb": 0.0,
                "peak_label": None,
                "checkpoints": self.checkpoints,
            }
        peak = max(self.checkpoints, key=lambda entry: entry["max_allocated_mb"])
        return {
            "peak_allocated_mb": float(peak["max_allocated_mb"]),
            "peak_reserved_mb": float(peak["max_reserved_mb"]),
            "peak_label": peak["label"],
            "checkpoints": self.checkpoints,
        }
|
||||
|
||||
|
||||
def summarise_state_tensors(states: Sequence[T2SRequestState]) -> Dict[str, Any]:
    """Report the byte size of the tensors held by each prepared request state.

    Returns:
        ``{"per_request": [...], "total": {...}}``. Each per-request entry
        holds a few lengths plus one ``*_bytes`` field per tensor; ``total``
        sums every ``*_bytes`` field and adds ``total_bytes`` / ``total_mb``.
    """
    # Order defines the key order of the "total" section in the JSON output.
    summed_keys = (
        "phones_bytes",
        "prompt_phones_bytes",
        "all_phones_bytes",
        "all_bert_features_bytes",
        "prompt_semantic_bytes",
        "refer_spec_bytes",
        "raw_audio_bytes",
        "audio_16k_bytes",
    )
    total = dict.fromkeys(summed_keys, 0)
    per_request = []
    for state in states:
        # refer_spec is unpacked as a (spectrogram, 16 kHz audio) pair.
        spec_audio, audio_16k = state.refer_spec
        entry = {
            "request_id": state.request_id,
            "prompt_semantic_len": int(state.prompt_semantic.shape[0]),
            "phones_len": int(state.phones.shape[0]),
            "all_phones_len": int(state.all_phones.shape[0]),
            "bert_frames": int(state.all_bert_features.shape[-1]),
            "phones_bytes": tensor_nbytes(state.phones),
            "prompt_phones_bytes": tensor_nbytes(state.prompt_phones),
            "all_phones_bytes": tensor_nbytes(state.all_phones),
            "all_bert_features_bytes": tensor_nbytes(state.all_bert_features),
            "prompt_semantic_bytes": tensor_nbytes(state.prompt_semantic),
            "refer_spec_bytes": tensor_nbytes(spec_audio),
            "audio_16k_bytes": tensor_nbytes(audio_16k),
            "raw_audio_bytes": tensor_nbytes(state.raw_audio),
        }
        for key in summed_keys:
            total[key] += int(entry[key])
        per_request.append(entry)
    grand_total = int(sum(total.values()))
    total["total_bytes"] = grand_total
    total["total_mb"] = bytes_to_mb(grand_total)
    return {"per_request": per_request, "total": total}
|
||||
|
||||
|
||||
def summarise_prefill_batch(active_batch: Any) -> Dict[str, Any]:
    """Report the byte size and key dimensions of a prefill batch.

    Args:
        active_batch: Batch object exposing ``x``, ``x_lens``, ``prefix_lens``,
            ``xy_pos``, ``key_padding_mask``, ``prefill_attn_mask``,
            ``y_sequences`` and ``states``.

    Returns:
        One ``*_bytes`` entry per tensor group, their ``total_bytes`` /
        ``total_mb``, and ``batch_size``, ``max_x_len``, ``src_len`` and
        ``prefill_attn_mask_shape`` (``None`` when the batch has no mask).
    """
    y_sequence_bytes = int(sum(tensor_nbytes(item) for item in active_batch.y_sequences))
    fields = {
        "x_bytes": tensor_nbytes(active_batch.x),
        "x_lens_bytes": tensor_nbytes(active_batch.x_lens),
        "prefix_lens_bytes": tensor_nbytes(active_batch.prefix_lens),
        "xy_pos_bytes": tensor_nbytes(active_batch.xy_pos),
        "key_padding_mask_bytes": tensor_nbytes(active_batch.key_padding_mask),
        "prefill_attn_mask_bytes": tensor_nbytes(active_batch.prefill_attn_mask),
        "y_sequence_bytes": y_sequence_bytes,
    }
    # The total must be computed before the non-byte fields are added below.
    fields["total_bytes"] = int(sum(fields.values()))
    fields["total_mb"] = bytes_to_mb(fields["total_bytes"])
    fields["batch_size"] = int(len(active_batch.states))
    fields["max_x_len"] = int(active_batch.x.shape[1])
    fields["src_len"] = int(active_batch.xy_pos.shape[1])
    # The byte accounting above already tolerates a missing mask; report its
    # shape the same way summarise_decode_batch does instead of failing on
    # None.shape.
    prefill_attn_mask = active_batch.prefill_attn_mask
    fields["prefill_attn_mask_shape"] = None if prefill_attn_mask is None else list(prefill_attn_mask.shape)
    return fields
|
||||
|
||||
|
||||
def summarise_running_requests(running_requests: Sequence[T2SRunningRequest]) -> Dict[str, Any]:
    """Report per-request and total bytes of the state kept by running requests.

    For each request the per-request K/V caches, decode attention mask and
    generated sequence are measured. ``kv_len`` is read from dimension 1 of the
    first K-cache entry.

    Returns:
        ``{"per_request": [...], "totals": {...}}``.
    """
    per_request = []
    k_total = v_total = mask_total = y_total = 0
    for running in running_requests:
        k_bytes = tensor_list_nbytes(running.k_cache)
        v_bytes = tensor_list_nbytes(running.v_cache)
        mask_bytes = tensor_nbytes(running.decode_attn_mask)
        y_bytes = tensor_nbytes(running.y_sequence)
        k_total += k_bytes
        v_total += v_bytes
        mask_total += mask_bytes
        y_total += y_bytes
        per_request.append(
            {
                "request_id": running.state.request_id,
                "step_idx": int(running.step_idx),
                "prefix_len": int(running.prefix_len),
                "history_len": int(running.y_sequence.shape[0]),
                "kv_len": int(running.k_cache[0].shape[1]),
                "k_cache_bytes": k_bytes,
                "v_cache_bytes": v_bytes,
                "decode_mask_bytes": mask_bytes,
                "y_sequence_bytes": y_bytes,
            }
        )
    kv_total = k_total + v_total
    grand_total = kv_total + mask_total + y_total
    totals = {
        "private_k_cache_bytes": int(k_total),
        "private_v_cache_bytes": int(v_total),
        "private_kv_cache_bytes": int(kv_total),
        "decode_mask_bytes": int(mask_total),
        "y_sequence_bytes": int(y_total),
        "total_bytes": int(grand_total),
        "total_mb": bytes_to_mb(grand_total),
    }
    return {"per_request": per_request, "totals": totals}
|
||||
|
||||
|
||||
def summarise_decode_batch(
    xy_pos: torch.Tensor,
    batched_k_cache: Sequence[torch.Tensor],
    batched_v_cache: Sequence[torch.Tensor],
    batched_decode_attn_mask: Optional[torch.Tensor],
    running_requests: Sequence[T2SRunningRequest],
) -> Dict[str, Any]:
    """Report the byte size of the tensors assembled for one batched decode step.

    The batched K/V caches are compared with the sum of the per-request caches
    they were built from; the difference is reported as
    ``kv_padding_overhead_bytes``.

    Returns:
        Byte counts per tensor group, ``total_bytes`` / ``total_mb`` (batched
        tensors only), and the shapes of ``xy_pos``, the decode mask (``None``
        when absent) and the first layer's K cache.
    """
    private_kv_bytes = int(sum(tensor_list_nbytes(req.k_cache) for req in running_requests)) + int(
        sum(tensor_list_nbytes(req.v_cache) for req in running_requests)
    )
    k_bytes = tensor_list_nbytes(batched_k_cache)
    v_bytes = tensor_list_nbytes(batched_v_cache)
    batched_kv_bytes = k_bytes + v_bytes
    mask_bytes = tensor_nbytes(batched_decode_attn_mask)
    xy_pos_bytes = tensor_nbytes(xy_pos)
    total_bytes = batched_kv_bytes + mask_bytes + xy_pos_bytes

    if batched_decode_attn_mask is None:
        mask_shape = None
    else:
        mask_shape = list(batched_decode_attn_mask.shape)

    return {
        "batch_size": int(len(running_requests)),
        "xy_pos_bytes": int(xy_pos_bytes),
        "batched_k_cache_bytes": int(k_bytes),
        "batched_v_cache_bytes": int(v_bytes),
        "batched_kv_cache_bytes": int(batched_kv_bytes),
        "batched_decode_mask_bytes": int(mask_bytes),
        "private_kv_cache_bytes_reference": int(private_kv_bytes),
        "kv_padding_overhead_bytes": int(batched_kv_bytes - private_kv_bytes),
        "total_bytes": int(total_bytes),
        "total_mb": bytes_to_mb(total_bytes),
        "xy_pos_shape": list(xy_pos.shape),
        "batched_decode_mask_shape": mask_shape,
        "layer_k_cache_shape": list(batched_k_cache[0].shape),
    }
|
||||
|
||||
|
||||
def summarise_decode_outputs(
    xy_dec: torch.Tensor,
    next_k_cache: Sequence[torch.Tensor],
    next_v_cache: Sequence[torch.Tensor],
) -> Dict[str, Any]:
    """Report the byte size of one decode step's outputs.

    Returns:
        Bytes of the decoder output and of the updated K/V caches, their
        ``total_bytes`` / ``total_mb``, and the shapes of ``xy_dec`` and the
        first layer's next K cache.
    """
    out_bytes = tensor_nbytes(xy_dec)
    k_bytes = tensor_list_nbytes(next_k_cache)
    v_bytes = tensor_list_nbytes(next_v_cache)
    kv_bytes = k_bytes + v_bytes
    total_bytes = out_bytes + kv_bytes
    return {
        "xy_dec_bytes": int(out_bytes),
        "next_k_cache_bytes": int(k_bytes),
        "next_v_cache_bytes": int(v_bytes),
        "next_kv_cache_bytes": int(kv_bytes),
        "total_bytes": int(total_bytes),
        "total_mb": bytes_to_mb(total_bytes),
        "xy_dec_shape": list(xy_dec.shape),
        "layer_next_k_cache_shape": list(next_k_cache[0].shape),
    }
|
||||
|
||||
|
||||
def top_rankings(summary: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Rank the main tensor groups of ``summary`` by size, largest first.

    Args:
        summary: The run summary; must already contain the ``prepare_stage``,
            ``prefill_batch``, ``prefill_step``, ``decode_batch`` and
            ``decode_outputs`` sections.

    Returns:
        A list of ``{"name", "bytes", "mb"}`` dicts in descending byte order.
        Groups of equal size keep the order in which they are listed here.
    """
    prefill_bytes = summary["prefill_batch"]["tensor_bytes"]
    decode_bytes = summary["decode_batch"]["tensor_bytes"]
    candidates = [
        ("request_state_total", summary["prepare_stage"]["request_state"]["total"]["total_bytes"]),
        ("prefill_batch_total", prefill_bytes["total_bytes"]),
        ("running_private_kv", summary["prefill_step"]["running_requests"]["totals"]["private_kv_cache_bytes"]),
        ("decode_batched_kv", decode_bytes["batched_kv_cache_bytes"]),
        ("decode_kv_padding_overhead", decode_bytes["kv_padding_overhead_bytes"]),
        ("decode_outputs_next_kv", summary["decode_outputs"]["tensor_bytes"]["next_kv_cache_bytes"]),
        ("prefill_attn_mask", prefill_bytes["prefill_attn_mask_bytes"]),
    ]
    ordered = sorted(candidates, key=lambda pair: pair[1], reverse=True)
    rankings = []
    for name, value in ordered:
        size = int(value)
        rankings.append({"name": name, "bytes": size, "mb": bytes_to_mb(size)})
    return rankings
|
||||
|
||||
|
||||
def synthesize_finished_item(tts: TTS, state: T2SRequestState, semantic_tokens: torch.Tensor) -> Tuple[int, np.ndarray]:
    """Turn one finished request's semantic tokens into post-processed audio.

    The tokens (with two leading singleton dimensions added) and the request's
    phones (with one) are moved to the pipeline device, synthesized with the
    request's own prompt and reference data, then passed through
    ``tts.audio_postprocess`` as a single fragment.

    Returns:
        The result of ``tts.audio_postprocess``; annotated as a
        ``(sample_rate, audio)`` pair.
    """
    target_device = tts.configs.device
    batched_tokens = semantic_tokens.unsqueeze(0).unsqueeze(0).to(target_device)
    batched_phones = state.phones.unsqueeze(0).to(target_device)
    fragment = tts.synthesize_audio_request_local(
        semantic_tokens=batched_tokens,
        phones=batched_phones,
        prompt_semantic=state.prompt_semantic,
        prompt_phones=state.prompt_phones,
        refer_spec=state.refer_spec,
        raw_audio=state.raw_audio,
        raw_sr=state.raw_sr,
        speed=1.0,
        sample_steps=32,
    )
    # The vocoder path has its own output sampling rate.
    if tts.configs.use_vocoder:
        output_sr = tts.vocoder_configs["sr"]
    else:
        output_sr = tts.configs.sampling_rate
    return tts.audio_postprocess(
        audio=[[fragment]],
        sr=int(output_sr),
        batch_index_list=None,
        speed_factor=1.0,
        split_bucket=False,
        fragment_interval=0.0,
        super_sampling=False,
    )
|
||||
|
||||
|
||||
def _synthesize_and_record(
    tts: TTS,
    recorder: GlobalPeakRecorder,
    state_map: Dict[str, T2SRequestState],
    item: Any,
    grad_context: Any,
    stage: str,
    round_index: int,
    tick: int,
    running_count: int,
    grad_mode: str,
) -> None:
    """Synthesize one finished request, with a memory checkpoint before and after.

    ``stage`` ("prefill" or "decode") selects the checkpoint labels
    ``before_synth_<stage>_finished`` / ``after_synth_<stage>_finished``.
    """
    recorder.record(
        f"before_synth_{stage}_finished",
        round_index=int(round_index),
        tick=int(tick),
        running_count=int(running_count),
        finished_request_id=item.request_id,
        semantic_len=int(item.semantic_tokens.shape[0]),
        grad_mode=grad_mode,
    )
    with grad_context():
        sample_rate, audio_data = synthesize_finished_item(tts, state_map[item.request_id], item.semantic_tokens)
    recorder.record(
        f"after_synth_{stage}_finished",
        round_index=int(round_index),
        tick=int(tick),
        running_count=int(running_count),
        finished_request_id=item.request_id,
        sample_rate=int(sample_rate),
        audio_samples=int(audio_data.shape[0]),
        grad_mode=grad_mode,
    )


def simulate_worker_end_to_end(
    tts: TTS,
    specs: Sequence[SchedulerRequestSpec],
    max_steps: int,
    rounds: int,
    grad_mode: str = "default",
) -> Dict[str, Any]:
    """Replay a worker loop (prepare, prefill, decode, synth) and track peak memory.

    Each round prepares all ``specs``, admits them in a single prefill, then
    runs decode steps until no request is running. Every finished request is
    synthesized immediately. A memory checkpoint is recorded around each phase.

    Args:
        tts: Loaded pipeline; ``tts.t2s_model.model`` is used for T2S steps.
        specs: Requests replayed in every round.
        max_steps: Step limit forwarded to the prefill and decode steps.
        rounds: Number of times the whole request set is replayed.
        grad_mode: ``"inference_mode"`` wraps model calls in
            ``torch.inference_mode``; any other value uses no grad context.

    Returns:
        ``{"grad_mode", "rounds", "timeline"}`` where ``rounds`` lists the
        prefill/decode events per round and ``timeline`` is the recorder
        summary.
    """
    device = tts.configs.device
    recorder = GlobalPeakRecorder(device)
    recorder.record("after_model_load")

    # Loop-invariant: the context-manager factory depends on grad_mode only.
    grad_context = torch.inference_mode if grad_mode == "inference_mode" else contextlib.nullcontext
    t2s_core = tts.t2s_model.model
    per_round: List[Dict[str, Any]] = []

    for round_index in range(rounds):
        with grad_context():
            states = [prepare_request_state(tts, spec) for spec in specs]
        state_map = {state.request_id: state for state in states}
        recorder.record(
            "after_prepare_states",
            round_index=int(round_index),
            request_count=int(len(states)),
            grad_mode=grad_mode,
        )

        pending = list(states)
        running_requests: List[T2SRunningRequest] = []
        round_events: List[Dict[str, Any]] = []
        current_tick = 0

        while pending or running_requests:
            # Everything pending is admitted at once; later ticks only decode.
            admitted = pending
            pending = []

            if admitted:
                recorder.record(
                    "before_prefill",
                    round_index=int(round_index),
                    tick=int(current_tick),
                    admitted_count=int(len(admitted)),
                    running_count=int(len(running_requests)),
                    grad_mode=grad_mode,
                )
                with grad_context():
                    admitted_running, admitted_finished = run_prefill_step(t2s_core, admitted, max_steps=max_steps)
                recorder.record(
                    "after_prefill",
                    round_index=int(round_index),
                    tick=int(current_tick),
                    admitted_running_count=int(len(admitted_running)),
                    admitted_finished_count=int(len(admitted_finished)),
                    running_count=int(len(running_requests)),
                    grad_mode=grad_mode,
                )
                round_events.append(
                    {
                        "tick": int(current_tick),
                        "event": "prefill",
                        "admitted_count": int(len(admitted)),
                        "admitted_running_count": int(len(admitted_running)),
                        "admitted_finished_count": int(len(admitted_finished)),
                    }
                )
                # Requests that finished during prefill are synthesized before
                # the newly admitted ones join the running set.
                for item in admitted_finished:
                    _synthesize_and_record(
                        tts,
                        recorder,
                        state_map,
                        item,
                        grad_context,
                        "prefill",
                        round_index,
                        current_tick,
                        len(running_requests),
                        grad_mode,
                    )
                running_requests.extend(admitted_running)
                recorder.record(
                    "after_extend_running",
                    round_index=int(round_index),
                    tick=int(current_tick),
                    running_count=int(len(running_requests)),
                    grad_mode=grad_mode,
                )

            if running_requests:
                recorder.record(
                    "before_decode",
                    round_index=int(round_index),
                    tick=int(current_tick),
                    running_count=int(len(running_requests)),
                    grad_mode=grad_mode,
                )
                with grad_context():
                    running_requests, step_finished = run_decode_step_for_running(
                        t2s_core,
                        running_requests,
                        max_steps=max_steps,
                    )
                recorder.record(
                    "after_decode",
                    round_index=int(round_index),
                    tick=int(current_tick),
                    running_count=int(len(running_requests)),
                    finished_count=int(len(step_finished)),
                    grad_mode=grad_mode,
                )
                round_events.append(
                    {
                        "tick": int(current_tick),
                        "event": "decode",
                        "running_count_after_decode": int(len(running_requests)),
                        "finished_count": int(len(step_finished)),
                    }
                )
                for item in step_finished:
                    _synthesize_and_record(
                        tts,
                        recorder,
                        state_map,
                        item,
                        grad_context,
                        "decode",
                        round_index,
                        current_tick,
                        len(running_requests),
                        grad_mode,
                    )
            current_tick += 1

        recorder.record(
            "after_round_complete",
            round_index=int(round_index),
            running_count=0,
            grad_mode=grad_mode,
        )
        per_round.append(
            {
                "round_index": int(round_index),
                "events": round_events,
            }
        )

    return {
        "grad_mode": grad_mode,
        "rounds": per_round,
        "timeline": recorder.summary(),
    }
|
||||
|
||||
|
||||
def main() -> None:
    """Run the T2S memory breakdown and write the JSON summary.

    Stages measured in isolation: request-state preparation, prefill batch
    construction, the prefill step, decode batch construction and one decode
    step. Afterwards a full worker loop is simulated (optionally a second time
    under ``torch.inference_mode``) to find the end-to-end peak. The summary is
    written to ``<output-dir>/t2s_memory_breakdown_summary.json`` and a short
    digest is printed.

    Raises:
        RuntimeError: If every request finishes during prefill, leaving
            nothing to build a decode batch from.
    """
    args = parse_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)

    tts = load_pipeline(args.config)
    model = tts.t2s_model.model
    device = tts.configs.device
    use_cuda = str(device).startswith("cuda") and torch.cuda.is_available()
    set_seed(args.seed, use_cuda)

    specs = load_request_specs(args)
    if args.early_stop_num == -1:
        # -1 means "use the pipeline default". Only fill in specs that still
        # carry the sentinel so explicit per-request values from a manifest
        # are not overwritten.
        default_early_stop = int(tts.configs.hz * tts.configs.max_sec)
        for spec in specs:
            if spec.early_stop_num == -1:
                spec.early_stop_num = default_early_stop

    if args.warmup and specs:
        # The warm-up result is discarded immediately so its tensors are
        # actually released by the collection and cache flush below.
        for spec in specs[:1]:
            prepare_request_state(tts, spec)
        gc.collect()
        if use_cuda:
            torch.cuda.empty_cache()
        _sync_device(device)

    states, prepare_mem = stage_run(device, lambda: [prepare_request_state(tts, spec) for spec in specs])
    request_state_summary = summarise_state_tensors(states)

    active_batch, prefill_batch_mem = stage_run(device, lambda: build_prefill_batch(model, states))
    prefill_batch_tensor_summary = summarise_prefill_batch(active_batch)

    prefill_result, prefill_step_mem = stage_run(device, lambda: run_prefill_step(model, states, max_steps=args.max_steps))
    running_requests, finished_items = prefill_result
    running_requests_summary = summarise_running_requests(running_requests)
    finished_after_prefill_summary = [
        {
            "request_id": item.request_id,
            "finish_idx": int(item.finish_idx),
            "finish_reason": item.finish_reason,
            "semantic_len": int(item.semantic_tokens.shape[0]),
        }
        for item in finished_items
    ]

    if not running_requests:
        raise RuntimeError(f"prefill 后没有 running requests,全部在首步结束: {[item.request_id for item in finished_items]}")

    decode_batch_result, decode_batch_mem = stage_run(
        device,
        lambda: _build_decode_batch_from_running(model, running_requests),
    )
    xy_pos, batched_k_cache, batched_v_cache, batched_decode_attn_mask = decode_batch_result
    decode_batch_tensor_summary = summarise_decode_batch(
        xy_pos,
        batched_k_cache,
        batched_v_cache,
        batched_decode_attn_mask,
        running_requests,
    )

    decode_out_result, decode_step_mem = stage_run(
        device,
        lambda: model.t2s_transformer.decode_next_token(
            xy_pos,
            batched_k_cache,
            batched_v_cache,
            batched_decode_attn_mask,
        ),
    )
    xy_dec, next_k_cache, next_v_cache = decode_out_result
    decode_output_tensor_summary = summarise_decode_outputs(xy_dec, next_k_cache, next_v_cache)

    # Release every reference to the stage tensors before the worker
    # simulation. The result tuples and `states` must go as well: they hold the
    # same tensors as the unpacked names, so deleting only the unpacked names
    # would free nothing.
    del active_batch
    del running_requests
    del finished_items
    del prefill_result
    del xy_pos
    del batched_k_cache
    del batched_v_cache
    del batched_decode_attn_mask
    del decode_batch_result
    del xy_dec
    del next_k_cache
    del next_v_cache
    del decode_out_result
    del states
    gc.collect()
    if use_cuda:
        _sync_device(device)
        torch.cuda.empty_cache()

    end_to_end_worker = simulate_worker_end_to_end(
        tts=tts,
        specs=specs,
        max_steps=args.max_steps,
        rounds=args.worker_rounds,
        grad_mode=args.worker_grad_mode,
    )
    live_cuda_tensors_after_worker = snapshot_live_cuda_tensors()

    worker_inference_mode = None
    if args.compare_worker_grad_modes:
        gc.collect()
        if use_cuda:
            _sync_device(device)
            torch.cuda.empty_cache()
        worker_inference_mode = simulate_worker_end_to_end(
            tts=tts,
            specs=specs,
            max_steps=args.max_steps,
            rounds=args.worker_rounds,
            grad_mode="inference_mode",
        )

    summary = {
        "meta": {
            "scenario": args.scenario if args.request_manifest is None else "manifest",
            "seed": int(args.seed),
            "device": str(device),
            "dtype": str(next(model.parameters()).dtype),
            "request_count": int(len(specs)),
            "num_layers": int(model.num_layers),
            "num_heads": int(model.num_head),
            "model_dim": int(model.model_dim),
            "model_weights_mb": bytes_to_mb(model_nbytes(model)),
        },
        "loaded_module_weights": build_module_weight_summary(tts),
        "requests": [
            {
                "request_id": spec.request_id,
                "ref_audio_path": str(spec.ref_audio_path),
                "prompt_text": spec.prompt_text,
                "text": spec.text,
            }
            for spec in specs
        ],
        "prepare_stage": {
            "memory": prepare_mem,
            "request_state": request_state_summary,
        },
        "prefill_batch": {
            "memory": prefill_batch_mem,
            "tensor_bytes": prefill_batch_tensor_summary,
        },
        "prefill_step": {
            "memory": prefill_step_mem,
            "running_requests": running_requests_summary,
            "finished_after_prefill": finished_after_prefill_summary,
        },
        "decode_batch": {
            "memory": decode_batch_mem,
            "tensor_bytes": decode_batch_tensor_summary,
        },
        "decode_outputs": {
            "memory": decode_step_mem,
            "tensor_bytes": decode_output_tensor_summary,
        },
        "end_to_end_worker": end_to_end_worker,
        "live_cuda_tensors_after_worker": live_cuda_tensors_after_worker,
        "end_to_end_worker_inference_mode": worker_inference_mode,
    }
    summary["top_rankings"] = top_rankings(summary)

    summary_path = args.output_dir / "t2s_memory_breakdown_summary.json"
    summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")

    print(json.dumps(summary["meta"], ensure_ascii=False, indent=2))
    print("[top_rankings]")
    for item in summary["top_rankings"]:
        print(f"- {item['name']}: {item['mb']:.3f} MB")
    print("[worker_peak]")
    print(
        json.dumps(
            {
                "peak_label": summary["end_to_end_worker"]["timeline"]["peak_label"],
                "peak_allocated_mb": summary["end_to_end_worker"]["timeline"]["peak_allocated_mb"],
                "peak_reserved_mb": summary["end_to_end_worker"]["timeline"]["peak_reserved_mb"],
            },
            ensure_ascii=False,
            indent=2,
        )
    )
    if worker_inference_mode is not None:
        print("[worker_peak_inference_mode]")
        print(
            json.dumps(
                {
                    "peak_label": worker_inference_mode["timeline"]["peak_label"],
                    "peak_allocated_mb": worker_inference_mode["timeline"]["peak_allocated_mb"],
                    "peak_reserved_mb": worker_inference_mode["timeline"]["peak_reserved_mb"],
                },
                ensure_ascii=False,
                indent=2,
            )
        )
    print(f"[summary] {summary_path}")
|
||||
|
||||
|
||||
# Script entry point: only run main() when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
||||
180
tools/t2s_scheduler_prototype.py
Normal file
180
tools/t2s_scheduler_prototype.py
Normal file
@ -0,0 +1,180 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import random
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# Repository root: the directory one level above the one holding this script.
ROOT_DIR = Path(__file__).resolve().parents[1]
gpt_sovits_dir = ROOT_DIR / "GPT_SoVITS"

# Append the repository root and then the GPT_SoVITS directory to sys.path,
# skipping any entry that is already present.
sys.path.extend(
    entry for entry in (str(ROOT_DIR), str(gpt_sovits_dir)) if entry not in sys.path
)
|
||||
|
||||
from GPT_SoVITS.TTS_infer_pack.t2s_scheduler import ( # noqa: E402
|
||||
SchedulerRequestSpec,
|
||||
T2SFinishedItem,
|
||||
T2SRequestState,
|
||||
prepare_request_state,
|
||||
run_scheduler_continuous,
|
||||
)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the scheduler prototype.

    Returns:
        The namespace produced by ``argparse`` from ``sys.argv``. Path-like
        options are converted to ``pathlib.Path``; defaults that point into the
        repository are resolved relative to ``ROOT_DIR``.
    """
    # (flag, add_argument keyword arguments), registered in declaration order.
    option_table = (
        ("--config", dict(type=Path, default=ROOT_DIR / "GPT_SoVITS/configs/tts_infer.yaml")),
        ("--request-manifest", dict(type=Path, default=None)),
        ("--ref-audio", dict(type=Path, default=ROOT_DIR / "test.wav")),
        ("--prompt-text", dict(type=str, default="是啊,主要是因为有调研需求的学者少了。")),
        ("--prompt-lang", dict(type=str, default="zh")),
        ("--text-file", dict(type=Path, default=ROOT_DIR / "test_en.txt")),
        ("--text", dict(type=str, default=None)),
        ("--text-lang", dict(type=str, default="en")),
        ("--top-k", dict(type=int, default=15)),
        ("--top-p", dict(type=float, default=1.0)),
        ("--temperature", dict(type=float, default=1.0)),
        ("--repetition-penalty", dict(type=float, default=1.35)),
        ("--early-stop-num", dict(type=int, default=-1)),
        ("--max-steps", dict(type=int, default=1500)),
        ("--seed", dict(type=int, default=1234)),
        ("--output-dir", dict(type=Path, default=ROOT_DIR / "TEMP/t2s_scheduler/output_run")),
    )

    cli = argparse.ArgumentParser(description="T2S request-local scheduler prototype.")
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
|
||||
|
||||
|
||||
def set_seed(seed: int, use_cuda: bool) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value passed to every generator.
        use_cuda: When true and CUDA is available, the CUDA generators are
            seeded as well (current device, then all devices).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # Nothing more to do unless CUDA seeding was requested and is possible.
    if not use_cuda or not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def load_pipeline(config_path: Path):
    """Build the TTS pipeline described by ``config_path``.

    The ``GPT_SoVITS.TTS_infer_pack.TTS`` import happens at call time so that an
    import failure can be re-raised with an actionable message.

    Args:
        config_path: Path to the TTS inference config; handed to ``TTS_Config``
            as a string.

    Returns:
        The ``TTS`` object constructed from the loaded ``TTS_Config``.

    Raises:
        ModuleNotFoundError: If the TTS module, or something it imports, cannot
            be found. The message names the missing module, and the ``name`` and
            ``path`` attributes of the original error are carried over.
    """
    try:
        from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
    except ModuleNotFoundError as exc:
        # The script's top-level handler prints only str(exc) and hides the
        # traceback, so the missing module must be named in the message itself.
        raise ModuleNotFoundError(
            "缺少运行依赖,请先在 GPT-SoVITS 推理环境中安装 requirements 后再运行该脚本。"
            f" (missing module: {exc.name})",
            name=exc.name,
            path=exc.path,
        ) from exc
    tts_config = TTS_Config(str(config_path))
    print(tts_config)
    return TTS(tts_config)
|
||||
|
||||
|
||||
def load_request_specs(args: argparse.Namespace) -> List[SchedulerRequestSpec]:
    """Build the list of scheduler request specs from the CLI arguments.

    Without ``--request-manifest`` a single request ``req_000`` is built from
    the CLI options, taking the text from ``--text`` or, failing that, from the
    contents of ``--text-file``.

    With a manifest, the JSON file is either a list of request objects or a
    dict holding that list under ``"requests"``. Each request object must
    provide ``ref_audio_path``, ``prompt_text`` and one of ``text`` /
    ``text_file``. ``request_id`` defaults to ``req_<index>``, ``prompt_lang``
    and ``text_lang`` default to ``"zh"``, ``ready_step`` defaults to 0, and the
    sampling parameters fall back to the corresponding CLI values.

    Args:
        args: Namespace returned by ``parse_args``.

    Returns:
        One ``SchedulerRequestSpec`` per request, in manifest order.

    Raises:
        ValueError: If a manifest entry has neither ``text`` nor ``text_file``.
        KeyError: If a manifest entry lacks ``ref_audio_path`` or ``prompt_text``.
    """
    manifest_path = args.request_manifest

    if manifest_path is None:
        # Single-request mode: everything comes from the command line.
        if args.text is not None:
            target_text = args.text
        else:
            target_text = args.text_file.read_text(encoding="utf-8")
        single_spec = SchedulerRequestSpec(
            request_id="req_000",
            ref_audio_path=args.ref_audio,
            prompt_text=args.prompt_text,
            prompt_lang=args.prompt_lang,
            text=target_text,
            text_lang=args.text_lang,
            top_k=args.top_k,
            top_p=args.top_p,
            temperature=args.temperature,
            repetition_penalty=args.repetition_penalty,
            early_stop_num=args.early_stop_num,
            ready_step=0,
        )
        return [single_spec]

    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
    entries = manifest["requests"] if isinstance(manifest, dict) else manifest

    collected: List[SchedulerRequestSpec] = []
    for position, entry in enumerate(entries):
        entry_text = entry.get("text")
        entry_text_file = entry.get("text_file")
        if entry_text is None:
            if entry_text_file is None:
                raise ValueError(f"request[{position}] must provide text or text_file")
            # Inline text is absent: load it from the referenced file.
            entry_text = Path(entry_text_file).read_text(encoding="utf-8")

        spec = SchedulerRequestSpec(
            request_id=entry.get("request_id", f"req_{position:03d}"),
            ref_audio_path=Path(entry["ref_audio_path"]),
            prompt_text=entry["prompt_text"],
            prompt_lang=entry.get("prompt_lang", "zh"),
            text=entry_text,
            text_lang=entry.get("text_lang", "zh"),
            top_k=int(entry.get("top_k", args.top_k)),
            top_p=float(entry.get("top_p", args.top_p)),
            temperature=float(entry.get("temperature", args.temperature)),
            repetition_penalty=float(entry.get("repetition_penalty", args.repetition_penalty)),
            early_stop_num=int(entry.get("early_stop_num", args.early_stop_num)),
            ready_step=int(entry.get("ready_step", 0)),
        )
        collected.append(spec)
    return collected
|
||||
|
||||
|
||||
def summarise_requests(states: List[T2SRequestState]) -> List[Dict[str, Any]]:
    """Convert prepared request states into JSON-serialisable summary rows.

    Args:
        states: Prepared request states. Each must expose ``request_id``,
            ``ready_step``, ``ref_audio_path``, ``norm_text`` and the array-like
            attributes ``prompt_semantic``, ``all_phones`` and
            ``all_bert_features`` (only their ``shape`` is read).

    Returns:
        One dict per state, in input order, with the identifier, ready step,
        reference audio path as a string, the first dimension of
        ``prompt_semantic`` and ``all_phones``, the last dimension of
        ``all_bert_features``, and the normalised text.
    """
    rows: List[Dict[str, Any]] = []
    for request in states:
        row: Dict[str, Any] = {}
        row["request_id"] = request.request_id
        row["ready_step"] = int(request.ready_step)
        row["ref_audio_path"] = str(request.ref_audio_path)
        row["prompt_semantic_len"] = int(request.prompt_semantic.shape[0])
        row["all_phone_len"] = int(request.all_phones.shape[0])
        row["bert_len"] = int(request.all_bert_features.shape[-1])
        row["norm_text"] = request.norm_text
        rows.append(row)
    return rows
|
||||
|
||||
|
||||
def summarise_finished(items: List[T2SFinishedItem]) -> List[Dict[str, Any]]:
    """Convert finished scheduler items into JSON-serialisable summary rows.

    Args:
        items: Finished items. Each must expose ``request_id``, ``finish_idx``,
            ``finish_reason`` and an array-like ``semantic_tokens`` (only its
            ``shape`` is read).

    Returns:
        One dict per item, in input order, with the request identifier, the
        first dimension of ``semantic_tokens``, the finish index as an int and
        the finish reason.
    """
    rows: List[Dict[str, Any]] = []
    for done in items:
        row: Dict[str, Any] = {}
        row["request_id"] = done.request_id
        row["semantic_len"] = int(done.semantic_tokens.shape[0])
        row["finish_idx"] = int(done.finish_idx)
        row["finish_reason"] = done.finish_reason
        rows.append(row)
    return rows
|
||||
|
||||
|
||||
def main() -> None:
    """Run the scheduler prototype end to end and save a JSON summary.

    Parses the CLI arguments, creates the output directory, loads the TTS
    pipeline, seeds the random generators, prepares one request state per
    request spec, runs ``run_scheduler_continuous`` on them, and writes the
    summary to ``scheduler_prototype_summary.json`` in the output directory
    while also printing it.
    """
    cli_args = parse_args()
    out_dir = cli_args.output_dir
    out_dir.mkdir(parents=True, exist_ok=True)

    pipeline = load_pipeline(cli_args.config)
    t2s_model = pipeline.t2s_model.model
    # CUDA generators are seeded only when the pipeline's device string starts with "cuda".
    on_cuda = str(pipeline.configs.device).startswith("cuda")
    set_seed(cli_args.seed, on_cuda)

    prepared = []
    for spec in load_request_specs(cli_args):
        prepared.append(prepare_request_state(pipeline, spec))
    completed = run_scheduler_continuous(t2s_model, prepared, max_steps=cli_args.max_steps)

    report = {
        "request_count": len(prepared),
        "max_steps": cli_args.max_steps,
        "requests": summarise_requests(prepared),
        "finished": summarise_finished(completed),
    }
    rendered = json.dumps(report, ensure_ascii=False, indent=2)

    report_path = out_dir / "scheduler_prototype_summary.json"
    report_path.write_text(rendered, encoding="utf-8")
    print(rendered)
    print(f"[saved] {report_path}")
|
||||
|
||||
|
||||
# Script entry point: run main() only when executed directly.
if __name__ == "__main__":
    try:
        main()
    except ModuleNotFoundError as exc:
        # Report a missing dependency on stderr, keeping the diagnostic out of
        # stdout, then exit with status 1 without showing a traceback.
        print(f"[error] {exc}", file=sys.stderr)
        raise SystemExit(1) from None
|
||||
Loading…
x
Reference in New Issue
Block a user