From 705df4c41454a6f7da9998165b37094d2ba3e4ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=99=BD=E8=8F=9C=E5=B7=A5=E5=8E=821145=E5=8F=B7=E5=91=98?= =?UTF-8?q?=E5=B7=A5?= <114749500+baicai-1145@users.noreply.github.com> Date: Sun, 21 Sep 2025 08:23:17 +0800 Subject: [PATCH] Supporting word-level timestamp output through attention weight output --- GPT_SoVITS/inference_cli.py | 12 +++- GPT_SoVITS/inference_gui.py | 3 +- GPT_SoVITS/inference_webui.py | 106 +++++++++++++++++++++++++++++++--- GPT_SoVITS/module/__init__.py | 5 ++ GPT_SoVITS/module/commons.py | 17 ++++++ GPT_SoVITS/module/models.py | 61 +++++++++++++++++++ api.py | 38 +++++------- 7 files changed, 210 insertions(+), 32 deletions(-) diff --git a/GPT_SoVITS/inference_cli.py b/GPT_SoVITS/inference_cli.py index 459a3d36..cf2e00fb 100644 --- a/GPT_SoVITS/inference_cli.py +++ b/GPT_SoVITS/inference_cli.py @@ -44,10 +44,20 @@ def synthesize( result_list = list(synthesis_result) if result_list: - last_sampling_rate, last_audio_data = result_list[-1] + # new: ((sr, audio), timestamps) + last = result_list[-1] + (last_sampling_rate, last_audio_data), timestamps = last output_wav_path = os.path.join(output_path, "output.wav") sf.write(output_wav_path, last_audio_data, last_sampling_rate) print(f"Audio saved to {output_wav_path}") + # Optionally save timestamps + try: + import json + with open(os.path.join(output_path, "timestamps.json"), "w", encoding="utf-8") as f: + json.dump(timestamps, f, ensure_ascii=False, indent=2) + print("Timestamps saved to timestamps.json") + except Exception: + pass def main(): diff --git a/GPT_SoVITS/inference_gui.py b/GPT_SoVITS/inference_gui.py index 379f7fa8..a62f4acc 100644 --- a/GPT_SoVITS/inference_gui.py +++ b/GPT_SoVITS/inference_gui.py @@ -299,7 +299,8 @@ class GPTSoVITSGUI(QMainWindow): result_list = list(synthesis_result) if result_list: - last_sampling_rate, last_audio_data = result_list[-1] + last = result_list[-1] + last_sampling_rate, last_audio_data = last[0], last[1] output_wav_path = os.path.join(output_path, "output.wav") sf.write(output_wav_path, last_audio_data, last_sampling_rate) diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index a361ed58..54d35301 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -849,6 +849,7 @@ def get_tts_wav( if not ref_free: phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version) + timestamps_all = [] for i_text, text in enumerate(texts): # 解决输入目标文本的空行导致报错的问题 if len(text.strip()) == 0: @@ -912,14 +913,101 @@ def get_tts_wav( refers = [refers] if is_v2pro: sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)] + # compute alignment and audio (v2/v2Pro only) + phones2_tensor = torch.LongTensor(phones2).to(device).unsqueeze(0) if is_v2pro: - audio = vq_model.decode( - pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed, sv_emb=sv_emb - )[0][0] + o, attn, y_mask = vq_model.decode_with_alignment( + pred_semantic, phones2_tensor, refers, speed=speed, sv_emb=sv_emb + ) else: - audio = vq_model.decode( - pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed - )[0][0] + o, attn, y_mask = vq_model.decode_with_alignment( + pred_semantic, phones2_tensor, refers, speed=speed + ) + audio = o[0][0] + last_attn = attn # [B,H,T_ssl,T_text] + last_y_mask = y_mask # [B,1,T_ssl] + # derive phoneme-level timestamps (20ms per frame at 32kHz, hop=640) + if last_attn is not None: + attn_heads_mean = last_attn.mean(dim=1)[0] # 
[T_ssl, T_text] + assign = attn_heads_mean.argmax(dim=-1) # [T_ssl] + frame_time = 0.02 / max(speed, 1e-6) + # collapse consecutive frames pointing to same phoneme id + ph_spans = [] + if assign.numel() > 0: + start_f = 0 + cur_ph = int(assign[0].item()) + for f in range(1, assign.shape[0]): + ph = int(assign[f].item()) + if ph != cur_ph: + ph_spans.append({ + "phoneme_id": cur_ph, + "start_s": start_f * frame_time, + "end_s": f * frame_time, + }) + start_f = f + cur_ph = ph + ph_spans.append({ + "phoneme_id": cur_ph, + "start_s": start_f * frame_time, + "end_s": assign.shape[0] * frame_time, + }) + # char/word aggregation + # obtain word2ph for current text segment + _, word2ph, norm_text_seg = clean_text_inf(text, text_language, version) + # char spans (for zh/yue where word2ph is char-based) + char_spans = [] + if word2ph: + ph_to_char = [] + for ch_idx, repeat in enumerate(word2ph): + ph_to_char += [ch_idx] * repeat + if ph_spans and ph_to_char: + for span in ph_spans: + ph_idx = span["phoneme_id"] + if 0 <= ph_idx < len(ph_to_char): + char_idx = ph_to_char[ph_idx] + if len(char_spans) == 0 or char_spans[-1]["char_index"] != char_idx: + char_spans.append({ + "char_index": char_idx, + "char": norm_text_seg[char_idx] if char_idx < len(norm_text_seg) else "", + "start_s": span["start_s"], + "end_s": span["end_s"], + }) + else: + char_spans[-1]["end_s"] = span["end_s"] + # word spans + word_spans = [] + if text_language == "en": + # build phoneme-to-word map using dictionary per-word g2p + try: + from text.english import g2p as g2p_en + except Exception: + g2p_en = None + words = norm_text_seg.split() + ph_to_word = [] + if g2p_en: + for w_idx, w in enumerate(words): + phs_w = g2p_en(w) + ph_to_word += [w_idx] * len(phs_w) + if ph_spans and ph_to_word: + for span in ph_spans: + ph_idx = span["phoneme_id"] + if 0 <= ph_idx < len(ph_to_word): + wi = ph_to_word[ph_idx] + if len(word_spans) == 0 or word_spans[-1]["word_index"] != wi: + word_spans.append({ + "word_index": wi, + "word": words[wi] if wi < len(words) else "", + "start_s": span["start_s"], + "end_s": span["end_s"], + }) + else: + word_spans[-1]["end_s"] = span["end_s"] + timestamps_all.append({ + "segment_index": i_text, + "phoneme_spans": ph_spans, + "char_spans": char_spans, + "word_spans": word_spans, + }) else: refer, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device) phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0) @@ -998,7 +1086,8 @@ def get_tts_wav( audio_opt /= max_audio else: audio_opt = audio_opt.cpu().detach().numpy() - yield opt_sr, (audio_opt * 32767).astype(np.int16) + # Return audio and timestamps for UI consumption: ((sr, audio), timestamps) + yield (opt_sr, (audio_opt * 32767).astype(np.int16)), timestamps_all def split(todo_text): @@ -1285,6 +1374,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css with gr.Row(): inference_button = gr.Button(value=i18n("合成语音"), variant="primary", size="lg", scale=25) output = gr.Audio(label=i18n("输出的语音"), scale=14) + timestamps_box = gr.JSON(label=i18n("时间戳(音素/字/词)")) inference_button.click( get_tts_wav, @@ -1306,7 +1396,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css if_sr_Checkbox, pause_second_slider, ], - [output], + [output, timestamps_box], ) SoVITS_dropdown.change( change_sovits_weights, diff --git a/GPT_SoVITS/module/__init__.py b/GPT_SoVITS/module/__init__.py index e69de29b..ecb88a41 100644 --- a/GPT_SoVITS/module/__init__.py +++ b/GPT_SoVITS/module/__init__.py @@ -0,0 +1,5 @@ 
+from .models import SynthesizerTrn + +__all__ = ["SynthesizerTrn"] + + diff --git a/GPT_SoVITS/module/commons.py b/GPT_SoVITS/module/commons.py index 20392f91..d8521adc 100644 --- a/GPT_SoVITS/module/commons.py +++ b/GPT_SoVITS/module/commons.py @@ -93,6 +93,23 @@ def subsequent_mask(length): return mask +def attn_to_token_alignment(attn, reduce="mean"): + """ + Convert attention [B, H, T_tgt, T_src] to argmax alignment per target step. + Returns [B, T_tgt] indices. + """ + if attn is None: + return None + if reduce == "mean": + attn_ = attn.mean(dim=1) + elif reduce == "max": + attn_ = attn.max(dim=1).values + else: + attn_ = attn + return attn_.argmax(dim=-1) + + + @torch.jit.script def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): n_channels_int = n_channels[0] diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py index 1c8e662f..346e3df7 100644 --- a/GPT_SoVITS/module/models.py +++ b/GPT_SoVITS/module/models.py @@ -1004,6 +1004,67 @@ class SynthesizerTrn(nn.Module): o = self.dec((z * y_mask)[:, :, :], g=ge) return o + @torch.no_grad() + def decode_with_alignment(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None): + """ + Decode and also return cross-attention alignment from MRTE (ssl↔text). + + Returns: + o (Tensor): waveform tensor [B, 1, T_audio] + attn (Tensor): attention weights [B, n_heads, T_ssl, T_text] + y_mask (Tensor): ssl/time mask [B, 1, T_ssl] + """ + def get_ge(refer, sv_emb): + ge = None + if refer is not None: + refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) + refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype) + if self.version == "v1": + ge = self.ref_enc(refer * refer_mask, refer_mask) + else: + ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask) + if self.is_v2pro: + sv_emb = self.sv_emb(sv_emb) # B*20480->B*512 + ge += sv_emb.unsqueeze(-1) + ge = self.prelu(ge) + return ge + + if type(refer) == list: + ges = [] + for idx, _refer in enumerate(refer): + ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None) + ges.append(ge) + ge = torch.stack(ges, 0).mean(0) + else: + ge = get_ge(refer, sv_emb) + + y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device) + text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) + + quantized = self.quantizer.decode(codes) + if self.semantic_frame_rate == "25hz": + quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest") + x, m_p, logs_p, y_mask = self.enc_p( + quantized, + y_lengths, + text, + text_lengths, + self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge, + speed, + ) + + # MRTE cross-attention collected during enc_p forward + # Shape: [B, n_heads, T_ssl, T_text] + try: + attn = self.enc_p.mrte.cross_attention.attn + except Exception: + attn = None + + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + z = self.flow(z_p, y_mask, g=ge, reverse=True) + o = self.dec((z * y_mask)[:, :, :], g=ge) + return o, attn, y_mask + def extract_latent(self, x): ssl = self.ssl_proj(x) quantized, codes, commit_loss, quantized_list = self.quantizer(ssl) diff --git a/api.py b/api.py index cc0896a2..5a2035c1 100644 --- a/api.py +++ b/api.py @@ -948,27 +948,19 @@ def get_tts_wav( if version not in {"v3", "v4"}: if is_v2pro: - audio = ( - vq_model.decode( - pred_semantic, - torch.LongTensor(phones2).to(device).unsqueeze(0), - refers, - speed=speed, - sv_emb=sv_emb, - ) - .detach() - .cpu() - .numpy()[0, 0] + o, attn, y_mask = 
vq_model.decode_with_alignment(
+                    pred_semantic,
+                    torch.LongTensor(phones2).to(device).unsqueeze(0),
+                    refers,
+                    speed=speed,
+                    sv_emb=sv_emb,
                 )
+                audio = o.detach().cpu().numpy()[0, 0]
             else:
-                audio = (
-                    vq_model.decode(
-                        pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed
-                    )
-                    .detach()
-                    .cpu()
-                    .numpy()[0, 0]
+                o, attn, y_mask = vq_model.decode_with_alignment(
+                    pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed
                 )
+                audio = o.detach().cpu().numpy()[0, 0]
         else:
             phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
             phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
@@ -1054,6 +1046,7 @@ def get_tts_wav(
         # logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
         if stream_mode == "normal":
             audio_bytes, audio_chunk = read_clean_buffer(audio_bytes)
+            # Backward compatible: stream mode still yields only the audio chunk (no timestamps)
             yield audio_chunk
 
     if not stream_mode == "normal":
@@ -1065,6 +1058,7 @@ def get_tts_wav(
         else:
             sr = 48000  # v4
         audio_bytes = pack_wav(audio_bytes, sr)
+        # Non-stream mode likewise returns only the packed audio bytes, for backward compatibility
         yield audio_bytes.getvalue()
 
 
@@ -1133,8 +1127,7 @@ def handle(
     else:
         text = cut_text(text, cut_punc)
 
-    return StreamingResponse(
-        get_tts_wav(
+    gen = get_tts_wav(
             refer_wav_path,
             prompt_text,
             prompt_language,
             text,
             text_language,
             top_k,
             top_p,
             temperature,
             speed,
             inp_refs,
             sample_steps,
             if_sr,
-        ),
-        media_type="audio/" + media_type,
     )
+    # Keep the legacy streaming response for API compatibility;
+    # timestamps are currently exposed only through the WebUI and CLI code paths.
+    return StreamingResponse(gen, media_type="audio/" + media_type)
 
 
 # --------------------------------
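
Usage sketch (illustrative, not part of the patch): how a caller might consume the new
((sr, audio), timestamps) shape yielded by get_tts_wav in GPT_SoVITS/inference_webui.py.
The keyword argument names follow the existing WebUI signature; the file paths and the
Chinese i18n language labels below are assumptions that must match the local configuration.

    import soundfile as sf
    from GPT_SoVITS.inference_webui import get_tts_wav

    # Each yield is ((sampling_rate, audio_int16), timestamps_all), where timestamps_all
    # is a list of per-segment dicts containing phoneme_spans, char_spans and word_spans.
    for (sr, audio), timestamps in get_tts_wav(
        ref_wav_path="ref.wav",        # reference audio (assumed path)
        prompt_text="参考音频的文本",    # transcript of the reference audio
        prompt_language="中文",         # i18n labels as shown in the WebUI dropdowns
        text="需要合成的目标文本",
        text_language="中文",
    ):
        sf.write("output.wav", audio, sr)
        for seg in timestamps:
            # Prefer word-level spans when available (e.g. English), else char-level (zh/yue)
            for span in seg["word_spans"] or seg["char_spans"]:
                label = span.get("word") or span.get("char", "")
                print(f"{label}\t{span['start_s']:.2f}s -> {span['end_s']:.2f}s")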