mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-05-12 21:08:11 +08:00
Implement last inference statistics tracking in Text2SemanticDecoder and enhance TTS class with prompt semantic extraction. This includes methods for setting and retrieving inference stats, as well as improvements to audio processing and feature extraction in TTS.
This commit is contained in:
parent
b250e62402
commit
30a4557d8d
@ -351,6 +351,13 @@ class Text2SemanticDecoder(nn.Module):
|
||||
blocks.append(block)
|
||||
|
||||
self.t2s_transformer = T2STransformer(self.num_layers, blocks)
|
||||
self.last_infer_stats = {}
|
||||
|
||||
def _set_last_infer_stats(self, stats):
    """Record the statistics of the most recent inference call.

    The mapping is copied before it is stored, so a caller that keeps
    mutating its own dict afterwards cannot alter the recorded snapshot.
    ``None`` is stored as an empty dict so that ``self.last_infer_stats``
    is always a mapping.

    Args:
        stats: Mapping of statistic name to value, or ``None`` to clear
            the recorded statistics.
    """
    # Top-level copy only: the values are stored as given.
    self.last_infer_stats = dict(stats) if stats is not None else {}
|
||||
|
||||
def get_last_infer_stats(self):
    """Return an independent snapshot of the most recent inference statistics.

    The stored mapping can hold nested mutable values (for example the
    ``generated_token_count_list`` list), so a deep copy is returned.
    Mutating the result, including its nested containers, never changes
    the statistics kept on the instance.

    Returns:
        dict: A deep copy of ``self.last_infer_stats``.
    """
    # Local import: the file's top-level import block is not visible here.
    import copy

    return copy.deepcopy(dict(self.last_infer_stats))
|
||||
|
||||
def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
|
||||
x = self.ar_text_embedding(x)
|
||||
@ -593,7 +600,19 @@ class Text2SemanticDecoder(nn.Module):
|
||||
repetition_penalty: float = 1.35,
|
||||
**kwargs,
|
||||
):
|
||||
requested_enable_mask_free_fastpath = bool(kwargs.get("enable_mask_free_fastpath", True))
|
||||
if prompts is None:
|
||||
self._set_last_infer_stats(
|
||||
{
|
||||
"infer_mode": "batch_infer_prompt_free_fallback",
|
||||
"requested_enable_mask_free_fastpath": requested_enable_mask_free_fastpath,
|
||||
"batch_size": int(len(x)),
|
||||
"prefill_after_mask_all_visible": None,
|
||||
"fastpath_hit": False,
|
||||
"generated_token_count": 0,
|
||||
"generated_token_count_list": [],
|
||||
}
|
||||
)
|
||||
print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
|
||||
return self.infer_panel_naive_batched(
|
||||
x,
|
||||
@ -608,6 +627,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
)
|
||||
|
||||
max_len = kwargs.get("max_len", x_lens.max())
|
||||
enable_mask_free_fastpath = requested_enable_mask_free_fastpath
|
||||
x_list = []
|
||||
for x_item, bert_item in zip(x, bert_feature):
|
||||
# max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
|
||||
@ -698,17 +718,30 @@ class Text2SemanticDecoder(nn.Module):
|
||||
y_list = [None] * y.shape[0]
|
||||
batch_idx_map = list(range(y.shape[0]))
|
||||
idx_list = [None] * y.shape[0]
|
||||
decode_attn_mask = attn_mask
|
||||
prefill_after_mask_all_visible = None
|
||||
fastpath_hit = False
|
||||
for idx in tqdm(range(1500)):
|
||||
if idx == 0:
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
|
||||
else:
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(
|
||||
xy_pos, k_cache, v_cache, decode_attn_mask
|
||||
)
|
||||
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||
|
||||
if idx == 0:
|
||||
attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
|
||||
prefill_after_mask_all_visible = not attn_mask.any().item()
|
||||
if enable_mask_free_fastpath and y.shape[0] == 1 and prefill_after_mask_all_visible:
|
||||
decode_attn_mask = None
|
||||
fastpath_hit = True
|
||||
else:
|
||||
decode_attn_mask = attn_mask
|
||||
else:
|
||||
attn_mask = F.pad(attn_mask, (0, 1), value=False)
|
||||
if decode_attn_mask is not None:
|
||||
attn_mask = F.pad(attn_mask, (0, 1), value=False)
|
||||
decode_attn_mask = attn_mask
|
||||
|
||||
if idx < 11: ###至少预测出10个token不然不给停止(0.4s)
|
||||
logits = logits[:, :-1]
|
||||
@ -740,7 +773,9 @@ class Text2SemanticDecoder(nn.Module):
|
||||
if reserved_idx_of_batch_for_y is not None:
|
||||
# index = torch.LongTensor(batch_idx_map).to(y.device)
|
||||
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
|
||||
attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
|
||||
if decode_attn_mask is not None:
|
||||
attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
|
||||
decode_attn_mask = attn_mask
|
||||
if k_cache is not None:
|
||||
for i in range(len(k_cache)):
|
||||
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
|
||||
@ -775,6 +810,18 @@ class Text2SemanticDecoder(nn.Module):
|
||||
if idx_list[i] is None:
|
||||
idx_list[i] = 1500 - 1 ###如果没有生成到EOS,就用最大长度代替
|
||||
|
||||
self._set_last_infer_stats(
|
||||
{
|
||||
"infer_mode": "batch_infer",
|
||||
"requested_enable_mask_free_fastpath": enable_mask_free_fastpath,
|
||||
"batch_size": int(len(x)),
|
||||
"prefill_after_mask_all_visible": prefill_after_mask_all_visible,
|
||||
"fastpath_hit": fastpath_hit,
|
||||
"generated_token_count": int(sum(idx_list)),
|
||||
"generated_token_count_list": [int(item) for item in idx_list],
|
||||
"max_len": int(max_len),
|
||||
}
|
||||
)
|
||||
if ref_free:
|
||||
return y_list, [0] * x.shape[0]
|
||||
# print(idx_list)
|
||||
@ -811,6 +858,17 @@ class Text2SemanticDecoder(nn.Module):
|
||||
y_list.append(y[0])
|
||||
idx_list.append(idx)
|
||||
|
||||
self._set_last_infer_stats(
|
||||
{
|
||||
"infer_mode": "naive_batched",
|
||||
"requested_enable_mask_free_fastpath": bool(kwargs.get("enable_mask_free_fastpath", True)),
|
||||
"batch_size": int(len(x)),
|
||||
"prefill_after_mask_all_visible": None,
|
||||
"fastpath_hit": False,
|
||||
"generated_token_count": int(sum(idx_list)),
|
||||
"generated_token_count_list": [int(item) for item in idx_list],
|
||||
}
|
||||
)
|
||||
return y_list, idx_list
|
||||
|
||||
def infer_panel_naive(
|
||||
@ -957,6 +1015,18 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
|
||||
if not streaming_mode:
|
||||
generated_token_count = max(int(y.shape[1] - prefix_len), 0)
|
||||
self._set_last_infer_stats(
|
||||
{
|
||||
"infer_mode": "naive",
|
||||
"requested_enable_mask_free_fastpath": bool(kwargs.get("enable_mask_free_fastpath", True)),
|
||||
"batch_size": int(x.shape[0]),
|
||||
"prefill_after_mask_all_visible": True if prompts is not None else None,
|
||||
"fastpath_hit": True if prompts is not None else False,
|
||||
"generated_token_count": generated_token_count,
|
||||
"generated_token_count_list": [generated_token_count],
|
||||
}
|
||||
)
|
||||
if ref_free:
|
||||
yield y, 0
|
||||
yield y, idx
|
||||
|
||||
@ -1227,6 +1227,9 @@ class TTS:
|
||||
###### inference ######
|
||||
t_34 = 0.0
|
||||
t_45 = 0.0
|
||||
t2s_observe_batch_count = 0
|
||||
t2s_observe_fastpath_hits = 0
|
||||
t2s_observe_generated_tokens = 0
|
||||
audio = []
|
||||
is_first_package = True
|
||||
output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"]
|
||||
@ -1280,6 +1283,29 @@ class TTS:
|
||||
)
|
||||
t4 = time.perf_counter()
|
||||
t_34 += t4 - t3
|
||||
if hasattr(self.t2s_model.model, "get_last_infer_stats"):
|
||||
t2s_stats = self.t2s_model.model.get_last_infer_stats()
|
||||
if t2s_stats:
|
||||
generated_token_count = int(t2s_stats.get("generated_token_count", 0))
|
||||
t2s_total_ms = (t4 - t3) * 1000.0
|
||||
avg_decode_ms_per_token = (
|
||||
t2s_total_ms / generated_token_count if generated_token_count > 0 else 0.0
|
||||
)
|
||||
t2s_observe_batch_count += 1
|
||||
t2s_observe_generated_tokens += generated_token_count
|
||||
if bool(t2s_stats.get("fastpath_hit", False)):
|
||||
t2s_observe_fastpath_hits += 1
|
||||
print(
|
||||
"[t2s_observe] "
|
||||
f"mode={t2s_stats.get('infer_mode')} "
|
||||
f"batch_size={t2s_stats.get('batch_size')} "
|
||||
f"tokens={generated_token_count} "
|
||||
f"t2s_ms={t2s_total_ms:.3f} "
|
||||
f"avg_decode_ms_per_token={avg_decode_ms_per_token:.3f} "
|
||||
f"requested_fastpath={t2s_stats.get('requested_enable_mask_free_fastpath')} "
|
||||
f"prefill_all_visible={t2s_stats.get('prefill_after_mask_all_visible')} "
|
||||
f"fastpath_hit={t2s_stats.get('fastpath_hit')}"
|
||||
)
|
||||
|
||||
|
||||
batch_audio_fragment = []
|
||||
@ -1500,6 +1526,18 @@ class TTS:
|
||||
|
||||
if not (return_fragment or streaming_mode):
|
||||
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
|
||||
if t2s_observe_batch_count > 0:
|
||||
request_avg_decode_ms_per_token = (
|
||||
(t_34 * 1000.0) / t2s_observe_generated_tokens if t2s_observe_generated_tokens > 0 else 0.0
|
||||
)
|
||||
print(
|
||||
"[t2s_request_observe] "
|
||||
f"batches={t2s_observe_batch_count} "
|
||||
f"fastpath_hits={t2s_observe_fastpath_hits} "
|
||||
f"generated_tokens={t2s_observe_generated_tokens} "
|
||||
f"t2s_total_ms={t_34 * 1000.0:.3f} "
|
||||
f"avg_decode_ms_per_token={request_avg_decode_ms_per_token:.3f}"
|
||||
)
|
||||
if len(audio) == 0:
|
||||
yield output_sr, np.zeros(int(output_sr), dtype=np.int16)
|
||||
return
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user