diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py
index ac905f4b..f55b7508 100644
--- a/GPT_SoVITS/AR/models/t2s_model.py
+++ b/GPT_SoVITS/AR/models/t2s_model.py
@@ -351,6 +351,13 @@ class Text2SemanticDecoder(nn.Module):
             blocks.append(block)
 
         self.t2s_transformer = T2STransformer(self.num_layers, blocks)
+        self.last_infer_stats = {}
+
+    def _set_last_infer_stats(self, stats):
+        self.last_infer_stats = stats
+
+    def get_last_infer_stats(self):
+        return dict(self.last_infer_stats)
 
     def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
         x = self.ar_text_embedding(x)
@@ -593,7 +600,19 @@ class Text2SemanticDecoder(nn.Module):
         repetition_penalty: float = 1.35,
         **kwargs,
     ):
+        requested_enable_mask_free_fastpath = bool(kwargs.get("enable_mask_free_fastpath", True))
         if prompts is None:
+            self._set_last_infer_stats(
+                {
+                    "infer_mode": "batch_infer_prompt_free_fallback",
+                    "requested_enable_mask_free_fastpath": requested_enable_mask_free_fastpath,
+                    "batch_size": int(len(x)),
+                    "prefill_after_mask_all_visible": None,
+                    "fastpath_hit": False,
+                    "generated_token_count": 0,
+                    "generated_token_count_list": [],
+                }
+            )
             print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
             return self.infer_panel_naive_batched(
                 x,
@@ -608,6 +627,7 @@ class Text2SemanticDecoder(nn.Module):
             )
 
         max_len = kwargs.get("max_len", x_lens.max())
+        enable_mask_free_fastpath = requested_enable_mask_free_fastpath
         x_list = []
         for x_item, bert_item in zip(x, bert_feature):
             # max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
@@ -698,17 +718,30 @@ class Text2SemanticDecoder(nn.Module):
         y_list = [None] * y.shape[0]
         batch_idx_map = list(range(y.shape[0]))
         idx_list = [None] * y.shape[0]
+        decode_attn_mask = attn_mask
+        prefill_after_mask_all_visible = None
+        fastpath_hit = False
         for idx in tqdm(range(1500)):
             if idx == 0:
                 xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
             else:
-                xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
+                xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(
+                    xy_pos, k_cache, v_cache, decode_attn_mask
+                )
 
             logits = self.ar_predict_layer(xy_dec[:, -1])
 
             if idx == 0:
                 attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
+                prefill_after_mask_all_visible = not attn_mask.any().item()
+                if enable_mask_free_fastpath and y.shape[0] == 1 and prefill_after_mask_all_visible:
+                    decode_attn_mask = None
+                    fastpath_hit = True
+                else:
+                    decode_attn_mask = attn_mask
             else:
-                attn_mask = F.pad(attn_mask, (0, 1), value=False)
+                if decode_attn_mask is not None:
+                    attn_mask = F.pad(attn_mask, (0, 1), value=False)
+                    decode_attn_mask = attn_mask
             if idx < 11:  ###至少预测出10个token不然不给停止(0.4s)
                 logits = logits[:, :-1]
@@ -740,7 +773,9 @@ class Text2SemanticDecoder(nn.Module):
             if reserved_idx_of_batch_for_y is not None:
                 # index = torch.LongTensor(batch_idx_map).to(y.device)
                 y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
-                attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
+                if decode_attn_mask is not None:
+                    attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
+                    decode_attn_mask = attn_mask
                 if k_cache is not None:
                     for i in range(len(k_cache)):
                         k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
@@ -775,6 +810,18 @@ class Text2SemanticDecoder(nn.Module):
             if idx_list[i] is None:
                 idx_list[i] = 1500 - 1  ###如果没有生成到EOS,就用最大长度代替
 
+        self._set_last_infer_stats(
+            {
+                "infer_mode": "batch_infer",
+                "requested_enable_mask_free_fastpath": enable_mask_free_fastpath,
+                "batch_size": int(len(x)),
+                "prefill_after_mask_all_visible": prefill_after_mask_all_visible,
+                "fastpath_hit": fastpath_hit,
+                "generated_token_count": int(sum(idx_list)),
+                "generated_token_count_list": [int(item) for item in idx_list],
+                "max_len": int(max_len),
+            }
+        )
         if ref_free:
             return y_list, [0] * x.shape[0]
         # print(idx_list)
@@ -811,6 +858,17 @@ class Text2SemanticDecoder(nn.Module):
             y_list.append(y[0])
             idx_list.append(idx)
 
+        self._set_last_infer_stats(
+            {
+                "infer_mode": "naive_batched",
+                "requested_enable_mask_free_fastpath": bool(kwargs.get("enable_mask_free_fastpath", True)),
+                "batch_size": int(len(x)),
+                "prefill_after_mask_all_visible": None,
+                "fastpath_hit": False,
+                "generated_token_count": int(sum(idx_list)),
+                "generated_token_count_list": [int(item) for item in idx_list],
+            }
+        )
         return y_list, idx_list
 
     def infer_panel_naive(
@@ -957,6 +1015,18 @@ class Text2SemanticDecoder(nn.Module):
 
         if not streaming_mode:
+            generated_token_count = max(int(y.shape[1] - prefix_len), 0)
+            self._set_last_infer_stats(
+                {
+                    "infer_mode": "naive",
+                    "requested_enable_mask_free_fastpath": bool(kwargs.get("enable_mask_free_fastpath", True)),
+                    "batch_size": int(x.shape[0]),
+                    "prefill_after_mask_all_visible": True if prompts is not None else None,
+                    "fastpath_hit": True if prompts is not None else False,
+                    "generated_token_count": generated_token_count,
+                    "generated_token_count_list": [generated_token_count],
+                }
+            )
             if ref_free:
                 yield y, 0
             yield y, idx
 
diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index 9c8344b0..e1efd973 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -1227,6 +1227,9 @@ class TTS:
         ###### inference ######
         t_34 = 0.0
         t_45 = 0.0
+        t2s_observe_batch_count = 0
+        t2s_observe_fastpath_hits = 0
+        t2s_observe_generated_tokens = 0
         audio = []
         is_first_package = True
         output_sr = self.configs.sampling_rate if not self.configs.use_vocoder else self.vocoder_configs["sr"]
@@ -1280,6 +1283,29 @@ class TTS:
                 )
                 t4 = time.perf_counter()
                 t_34 += t4 - t3
+                if hasattr(self.t2s_model.model, "get_last_infer_stats"):
+                    t2s_stats = self.t2s_model.model.get_last_infer_stats()
+                    if t2s_stats:
+                        generated_token_count = int(t2s_stats.get("generated_token_count", 0))
+                        t2s_total_ms = (t4 - t3) * 1000.0
+                        avg_decode_ms_per_token = (
+                            t2s_total_ms / generated_token_count if generated_token_count > 0 else 0.0
+                        )
+                        t2s_observe_batch_count += 1
+                        t2s_observe_generated_tokens += generated_token_count
+                        if bool(t2s_stats.get("fastpath_hit", False)):
+                            t2s_observe_fastpath_hits += 1
+                        print(
+                            "[t2s_observe] "
+                            f"mode={t2s_stats.get('infer_mode')} "
+                            f"batch_size={t2s_stats.get('batch_size')} "
+                            f"tokens={generated_token_count} "
+                            f"t2s_ms={t2s_total_ms:.3f} "
+                            f"avg_decode_ms_per_token={avg_decode_ms_per_token:.3f} "
+                            f"requested_fastpath={t2s_stats.get('requested_enable_mask_free_fastpath')} "
+                            f"prefill_all_visible={t2s_stats.get('prefill_after_mask_all_visible')} "
+                            f"fastpath_hit={t2s_stats.get('fastpath_hit')}"
+                        )
 
                 batch_audio_fragment = []
 
@@ -1500,6 +1526,18 @@ class TTS:
 
         if not (return_fragment or streaming_mode):
             print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
+            if t2s_observe_batch_count > 0:
+                request_avg_decode_ms_per_token = (
+                    (t_34 * 1000.0) / t2s_observe_generated_tokens if t2s_observe_generated_tokens > 0 else 0.0
+                )
+                print(
+                    "[t2s_request_observe] "
+                    f"batches={t2s_observe_batch_count} "
+                    f"fastpath_hits={t2s_observe_fastpath_hits} "
+                    f"generated_tokens={t2s_observe_generated_tokens} "
+                    f"t2s_total_ms={t_34 * 1000.0:.3f} "
+                    f"avg_decode_ms_per_token={request_avg_decode_ms_per_token:.3f}"
+                )
             if len(audio) == 0:
                 yield output_sr, np.zeros(int(output_sr), dtype=np.int16)
                 return