Implement last-inference statistics tracking in Text2SemanticDecoder and enhance the TTS class with prompt semantic extraction. This includes methods for setting and retrieving inference stats, as well as improvements to audio processing and feature extraction in TTS.

baicai-1145 2026-03-08 23:08:27 +08:00
parent b250e62402
commit 30a4557d8d
2 changed files with 111 additions and 3 deletions


@@ -351,6 +351,13 @@ class Text2SemanticDecoder(nn.Module):
blocks.append(block)
self.t2s_transformer = T2STransformer(self.num_layers, blocks)
self.last_infer_stats = {}
def _set_last_infer_stats(self, stats):
self.last_infer_stats = stats
def get_last_infer_stats(self):
return dict(self.last_infer_stats)
def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
x = self.ar_text_embedding(x)
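For context, these two methods implement a simple snapshot-and-copy pattern. The standalone sketch below mimics it with illustrative names (it is not the GPT-SoVITS class itself); the point is that get_last_infer_stats returns dict(self.last_infer_stats), a shallow copy, so callers can post-process the stats without mutating the decoder's internal state.

    class StatsTrackingDecoder:
        """Minimal mimic of the stats plumbing added above (illustrative)."""

        def __init__(self):
            self.last_infer_stats = {}  # overwritten by every inference call

        def _set_last_infer_stats(self, stats):
            self.last_infer_stats = stats

        def get_last_infer_stats(self):
            # dict(...) hands back a shallow copy, so mutating the returned
            # mapping does not touch the decoder's internal snapshot.
            return dict(self.last_infer_stats)

    decoder = StatsTrackingDecoder()
    decoder._set_last_infer_stats({"infer_mode": "naive", "generated_token_count": 42})
    stats = decoder.get_last_infer_stats()
    stats["generated_token_count"] = 0  # only the copy changes
    assert decoder.get_last_infer_stats()["generated_token_count"] == 42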
@@ -593,7 +600,19 @@ class Text2SemanticDecoder(nn.Module):
repetition_penalty: float = 1.35,
**kwargs,
):
requested_enable_mask_free_fastpath = bool(kwargs.get("enable_mask_free_fastpath", True))
if prompts is None:
self._set_last_infer_stats(
{
"infer_mode": "batch_infer_prompt_free_fallback",
"requested_enable_mask_free_fastpath": requested_enable_mask_free_fastpath,
"batch_size": int(len(x)),
"prefill_after_mask_all_visible": None,
"fastpath_hit": False,
"generated_token_count": 0,
"generated_token_count_list": [],
}
)
print("Warning: prompt-free is not supported by batch_infer! Switching to naive_infer")
return self.infer_panel_naive_batched(
x,
@@ -608,6 +627,7 @@ class Text2SemanticDecoder(nn.Module):
)
max_len = kwargs.get("max_len", x_lens.max())
enable_mask_free_fastpath = requested_enable_mask_free_fastpath
x_list = []
for x_item, bert_item in zip(x, bert_feature):
# max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
@@ -698,17 +718,30 @@ class Text2SemanticDecoder(nn.Module):
y_list = [None] * y.shape[0]
batch_idx_map = list(range(y.shape[0]))
idx_list = [None] * y.shape[0]
decode_attn_mask = attn_mask
prefill_after_mask_all_visible = None
fastpath_hit = False
for idx in tqdm(range(1500)):
if idx == 0:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
else:
- xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
+ xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(
+ xy_pos, k_cache, v_cache, decode_attn_mask
+ )
logits = self.ar_predict_layer(xy_dec[:, -1])
if idx == 0:
attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
prefill_after_mask_all_visible = not attn_mask.any().item()
if enable_mask_free_fastpath and y.shape[0] == 1 and prefill_after_mask_all_visible:
decode_attn_mask = None
fastpath_hit = True
else:
decode_attn_mask = attn_mask
else:
- attn_mask = F.pad(attn_mask, (0, 1), value=False)
+ if decode_attn_mask is not None:
+ attn_mask = F.pad(attn_mask, (0, 1), value=False)
+ decode_attn_mask = attn_mask
if idx < 11: ### must predict at least 10 tokens before stopping is allowed (~0.4s)
logits = logits[:, :-1]
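The fastpath above rests on a single equivalence: once the post-prefill decode mask masks nothing (batch size 1, no padding), applying it in attention is a no-op, so decode_attn_mask = None can skip both the per-token F.pad growth and the mask application. A minimal sketch of that equivalence, assuming torch >= 2.0 and using F.scaled_dot_product_attention as a stand-in for the decoder's own attention path (shapes are illustrative):

    import torch
    import torch.nn.functional as F

    q = torch.randn(1, 4, 1, 16)  # (bsz=1, heads, one decode-step query, head_dim)
    k = torch.randn(1, 4, 9, 16)  # 9 cached key/value positions
    v = torch.randn(1, 4, 9, 16)

    # In the decoder's convention True means "masked out"; an all-False mask
    # is exactly the `not attn_mask.any().item()` condition checked above.
    attn_mask = torch.zeros(1, 1, 1, 9, dtype=torch.bool)
    assert not attn_mask.any().item()  # prefill_after_mask_all_visible

    # SDPA's boolean convention is inverted (True = attend), so an all-visible
    # mask is all-True, and it produces the same output as passing no mask.
    with_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=~attn_mask)
    without_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=None)
    torch.testing.assert_close(with_mask, without_mask)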
@@ -740,7 +773,9 @@ class Text2SemanticDecoder(nn.Module):
if reserved_idx_of_batch_for_y is not None:
# index = torch.LongTensor(batch_idx_map).to(y.device)
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
- attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
+ if decode_attn_mask is not None:
+ attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
+ decode_attn_mask = attn_mask
if k_cache is not None:
for i in range(len(k_cache)):
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
@@ -775,6 +810,18 @@ class Text2SemanticDecoder(nn.Module):
if idx_list[i] is None:
idx_list[i] = 1500 - 1 ### if EOS was never generated, fall back to the max length
self._set_last_infer_stats(
{
"infer_mode": "batch_infer",
"requested_enable_mask_free_fastpath": enable_mask_free_fastpath,
"batch_size": int(len(x)),
"prefill_after_mask_all_visible": prefill_after_mask_all_visible,
"fastpath_hit": fastpath_hit,
"generated_token_count": int(sum(idx_list)),
"generated_token_count_list": [int(item) for item in idx_list],
"max_len": int(max_len),
}
)
if ref_free:
return y_list, [0] * x.shape[0]
# print(idx_list)
@@ -811,6 +858,17 @@ class Text2SemanticDecoder(nn.Module):
y_list.append(y[0])
idx_list.append(idx)
self._set_last_infer_stats(
{
"infer_mode": "naive_batched",
"requested_enable_mask_free_fastpath": bool(kwargs.get("enable_mask_free_fastpath", True)),
"batch_size": int(len(x)),
"prefill_after_mask_all_visible": None,
"fastpath_hit": False,
"generated_token_count": int(sum(idx_list)),
"generated_token_count_list": [int(item) for item in idx_list],
}
)
return y_list, idx_list
def infer_panel_naive(
@@ -957,6 +1015,18 @@ class Text2SemanticDecoder(nn.Module):
if not streaming_mode:
generated_token_count = max(int(y.shape[1] - prefix_len), 0)
self._set_last_infer_stats(
{
"infer_mode": "naive",
"requested_enable_mask_free_fastpath": bool(kwargs.get("enable_mask_free_fastpath", True)),
"batch_size": int(x.shape[0]),
"prefill_after_mask_all_visible": True if prompts is not None else None,
"fastpath_hit": True if prompts is not None else False,
"generated_token_count": generated_token_count,
"generated_token_count_list": [generated_token_count],
}
)
if ref_free:
yield y, 0
yield y, idx
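A worked example of the naive-mode count above (numbers are illustrative): y holds the prompt's semantic tokens plus everything decoded after them along dim 1, so the generated count is the final length minus the prompt prefix, floored at zero:

    prefix_len = 13   # semantic tokens contributed by the prompt
    final_len = 187   # y.shape[1] once decoding stops
    generated_token_count = max(final_len - prefix_len, 0)
    assert generated_token_count == 174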


@@ -1227,6 +1227,9 @@ class TTS:
###### inference ######
t_34 = 0.0
t_45 = 0.0
t2s_observe_batch_count = 0
t2s_observe_fastpath_hits = 0
t2s_observe_generated_tokens = 0
audio = []
is_first_package = True
output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"]
@@ -1280,6 +1283,29 @@ class TTS:
)
t4 = time.perf_counter()
t_34 += t4 - t3
if hasattr(self.t2s_model.model, "get_last_infer_stats"):
t2s_stats = self.t2s_model.model.get_last_infer_stats()
if t2s_stats:
generated_token_count = int(t2s_stats.get("generated_token_count", 0))
t2s_total_ms = (t4 - t3) * 1000.0
avg_decode_ms_per_token = (
t2s_total_ms / generated_token_count if generated_token_count > 0 else 0.0
)
t2s_observe_batch_count += 1
t2s_observe_generated_tokens += generated_token_count
if bool(t2s_stats.get("fastpath_hit", False)):
t2s_observe_fastpath_hits += 1
print(
"[t2s_observe] "
f"mode={t2s_stats.get('infer_mode')} "
f"batch_size={t2s_stats.get('batch_size')} "
f"tokens={generated_token_count} "
f"t2s_ms={t2s_total_ms:.3f} "
f"avg_decode_ms_per_token={avg_decode_ms_per_token:.3f} "
f"requested_fastpath={t2s_stats.get('requested_enable_mask_free_fastpath')} "
f"prefill_all_visible={t2s_stats.get('prefill_after_mask_all_visible')} "
f"fastpath_hit={t2s_stats.get('fastpath_hit')}"
)
batch_audio_fragment = []
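The per-batch metric printed above is plain wall-clock arithmetic: t2s time for the batch divided by the tokens it generated, guarded against empty output. A minimal sketch with illustrative numbers:

    t3, t4 = 10.000, 11.234  # perf_counter() seconds around the t2s call
    generated_token_count = 600
    t2s_total_ms = (t4 - t3) * 1000.0
    avg_decode_ms_per_token = (
        t2s_total_ms / generated_token_count if generated_token_count > 0 else 0.0
    )
    print(f"avg_decode_ms_per_token={avg_decode_ms_per_token:.3f}")  # -> 2.057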
@@ -1500,6 +1526,18 @@ class TTS:
if not (return_fragment or streaming_mode):
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
if t2s_observe_batch_count > 0:
request_avg_decode_ms_per_token = (
(t_34 * 1000.0) / t2s_observe_generated_tokens if t2s_observe_generated_tokens > 0 else 0.0
)
print(
"[t2s_request_observe] "
f"batches={t2s_observe_batch_count} "
f"fastpath_hits={t2s_observe_fastpath_hits} "
f"generated_tokens={t2s_observe_generated_tokens} "
f"t2s_total_ms={t_34 * 1000.0:.3f} "
f"avg_decode_ms_per_token={request_avg_decode_ms_per_token:.3f}"
)
if len(audio) == 0:
yield output_sr, np.zeros(int(output_sr), dtype=np.int16)
return
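One design note on the request-level line above: dividing the accumulated t_34 by the accumulated token count token-weights the batches, which is not the same as averaging the per-batch averages when batch sizes differ. A small illustration with made-up numbers:

    batches = [(1200.0, 600), (300.0, 100)]    # (t2s ms, generated tokens) per batch
    total_ms = sum(ms for ms, _ in batches)    # plays the role of t_34 * 1000.0
    total_tokens = sum(n for _, n in batches)  # t2s_observe_generated_tokens
    request_avg = total_ms / total_tokens if total_tokens > 0 else 0.0
    mean_of_batch_avgs = sum(ms / n for ms, n in batches) / len(batches)
    print(request_avg, mean_of_batch_avgs)     # 2.1428... vs 2.5; token-weighting matters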