Support word-level timestamp output via attention weights

白菜工厂1145号员工 2025-09-21 08:23:17 +08:00
parent 11aa78bd9b
commit 705df4c414
7 changed files with 210 additions and 32 deletions

View File

@@ -44,10 +44,20 @@ def synthesize(
result_list = list(synthesis_result)
if result_list:
last_sampling_rate, last_audio_data = result_list[-1]
# new: ((sr, audio), timestamps)
last = result_list[-1]
(last_sampling_rate, last_audio_data), timestamps = last
output_wav_path = os.path.join(output_path, "output.wav")
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
print(f"Audio saved to {output_wav_path}")
# Optionally save timestamps
try:
import json
with open(os.path.join(output_path, "timestamps.json"), "w", encoding="utf-8") as f:
json.dump(timestamps, f, ensure_ascii=False, indent=2)
print("Timestamps saved to timestamps.json")
except Exception:
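# timestamp saving is optional; don't fail the audio export if the JSON can't be written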
pass
def main():
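A minimal sketch (not part of the diff) of how a caller could read the timestamps.json written above. It assumes the structure produced by get_tts_wav in this commit: a list of per-segment dicts with "phoneme_spans", "char_spans" and "word_spans", each span carrying "start_s"/"end_s" in seconds; the file path is just an example.

import json

with open("output/timestamps.json", encoding="utf-8") as f:
    segments = json.load(f)

for seg in segments:
    print(f"segment {seg['segment_index']}:")
    # word_spans is only filled for English text, char_spans for zh/yue
    for span in seg.get("word_spans") or seg.get("char_spans") or []:
        label = span.get("word") or span.get("char") or ""
        print(f"  {label!r}: {span['start_s']:.2f}s -> {span['end_s']:.2f}s")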

View File

@@ -299,7 +299,8 @@ class GPTSoVITSGUI(QMainWindow):
result_list = list(synthesis_result)
if result_list:
last_sampling_rate, last_audio_data = result_list[-1]
last = result_list[-1]
(last_sampling_rate, last_audio_data), _timestamps = last  # new yield format: ((sr, audio), timestamps)
output_wav_path = os.path.join(output_path, "output.wav")
sf.write(output_wav_path, last_audio_data, last_sampling_rate)

View File

@@ -849,6 +849,7 @@ def get_tts_wav(
if not ref_free:
phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
timestamps_all = []
for i_text, text in enumerate(texts):
# Fix the error caused by blank lines in the input target text
if len(text.strip()) == 0:
@@ -912,14 +913,101 @@ def get_tts_wav(
refers = [refers]
if is_v2pro:
sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)]
# compute alignment and audio (v2/v2Pro only)
phones2_tensor = torch.LongTensor(phones2).to(device).unsqueeze(0)
if is_v2pro:
audio = vq_model.decode(
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed, sv_emb=sv_emb
)[0][0]
o, attn, y_mask = vq_model.decode_with_alignment(
pred_semantic, phones2_tensor, refers, speed=speed, sv_emb=sv_emb
)
else:
audio = vq_model.decode(
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed
)[0][0]
o, attn, y_mask = vq_model.decode_with_alignment(
pred_semantic, phones2_tensor, refers, speed=speed
)
audio = o[0][0]
last_attn = attn # [B,H,T_ssl,T_text]
last_y_mask = y_mask # [B,1,T_ssl]
# derive phoneme-level timestamps (20ms per frame at 32kHz, hop=640)
if last_attn is not None:
attn_heads_mean = last_attn.mean(dim=1)[0] # [T_ssl, T_text]
assign = attn_heads_mean.argmax(dim=-1) # [T_ssl]
frame_time = 0.02 / max(speed, 1e-6)
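# 640-sample hop at 32 kHz gives 640 / 32000 = 0.02 s per decoder frame; dividing by speed shortens each frame's duration when synthesis is sped up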
# collapse consecutive frames pointing to same phoneme id
ph_spans = []
if assign.numel() > 0:
start_f = 0
cur_ph = int(assign[0].item())
for f in range(1, assign.shape[0]):
ph = int(assign[f].item())
if ph != cur_ph:
ph_spans.append({
"phoneme_id": cur_ph,
"start_s": start_f * frame_time,
"end_s": f * frame_time,
})
start_f = f
cur_ph = ph
ph_spans.append({
"phoneme_id": cur_ph,
"start_s": start_f * frame_time,
"end_s": assign.shape[0] * frame_time,
})
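# note: "phoneme_id" is the index into the phones2 sequence (the attention's text axis), not a phoneme symbol;
# non-monotonic attention can produce several separate spans pointing at the same index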
# char/word aggregation
# obtain word2ph for current text segment
_, word2ph, norm_text_seg = clean_text_inf(text, text_language, version)
# char spans (for zh/yue where word2ph is char-based)
char_spans = []
if word2ph:
ph_to_char = []
for ch_idx, repeat in enumerate(word2ph):
ph_to_char += [ch_idx] * repeat
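# word2ph[i] is the number of phonemes emitted for the i-th character of norm_text_seg,
# so expanding it gives a phoneme-index -> character-index lookup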
if ph_spans and ph_to_char:
for span in ph_spans:
ph_idx = span["phoneme_id"]
if 0 <= ph_idx < len(ph_to_char):
char_idx = ph_to_char[ph_idx]
if len(char_spans) == 0 or char_spans[-1]["char_index"] != char_idx:
char_spans.append({
"char_index": char_idx,
"char": norm_text_seg[char_idx] if char_idx < len(norm_text_seg) else "",
"start_s": span["start_s"],
"end_s": span["end_s"],
})
else:
char_spans[-1]["end_s"] = span["end_s"]
# word spans
word_spans = []
if text_language == "en":
# build phoneme-to-word map using dictionary per-word g2p
try:
from text.english import g2p as g2p_en
except Exception:
g2p_en = None
words = norm_text_seg.split()
ph_to_word = []
if g2p_en:
for w_idx, w in enumerate(words):
phs_w = g2p_en(w)
ph_to_word += [w_idx] * len(phs_w)
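# assumes concatenating per-word g2p output reproduces the phoneme order of phones2;
# punctuation or normalization differences can shift indices; out-of-range spans are skipped by the bounds check below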
if ph_spans and ph_to_word:
for span in ph_spans:
ph_idx = span["phoneme_id"]
if 0 <= ph_idx < len(ph_to_word):
wi = ph_to_word[ph_idx]
if len(word_spans) == 0 or word_spans[-1]["word_index"] != wi:
word_spans.append({
"word_index": wi,
"word": words[wi] if wi < len(words) else "",
"start_s": span["start_s"],
"end_s": span["end_s"],
})
else:
word_spans[-1]["end_s"] = span["end_s"]
timestamps_all.append({
"segment_index": i_text,
"phoneme_spans": ph_spans,
"char_spans": char_spans,
"word_spans": word_spans,
})
else:
refer, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device)
phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
@@ -998,7 +1086,8 @@ def get_tts_wav(
audio_opt /= max_audio
else:
audio_opt = audio_opt.cpu().detach().numpy()
yield opt_sr, (audio_opt * 32767).astype(np.int16)
# Return audio and timestamps for UI consumption: ((sr, audio), timestamps)
yield (opt_sr, (audio_opt * 32767).astype(np.int16)), timestamps_all
def split(todo_text):
@@ -1285,6 +1374,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
with gr.Row():
inference_button = gr.Button(value=i18n("合成语音"), variant="primary", size="lg", scale=25)
output = gr.Audio(label=i18n("输出的语音"), scale=14)
timestamps_box = gr.JSON(label=i18n("时间戳(音素/字/词)"))
inference_button.click(
get_tts_wav,
@@ -1306,7 +1396,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
if_sr_Checkbox,
pause_second_slider,
],
[output],
[output, timestamps_box],
)
SoVITS_dropdown.change(
change_sovits_weights,
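The span derivation added to get_tts_wav above boils down to two steps: an argmax over the text axis of the cross-attention, then a run-length collapse of consecutive frames that point at the same text position. A self-contained sketch of just that step, under the same assumptions as the code above (20 ms per frame); the name spans_from_attention and the toy tensor are illustrative, not part of the commit:

import torch

def spans_from_attention(attn: torch.Tensor, frame_time: float = 0.02):
    """attn: [B, H, T_frames, T_text] cross-attention weights.
    Returns a list of {"token_index", "start_s", "end_s"} dicts for batch item 0."""
    assign = attn.mean(dim=1)[0].argmax(dim=-1)  # most-attended text position per frame, [T_frames]
    spans, start_f = [], 0
    for f in range(1, assign.numel() + 1):
        # close the current run when the attended position changes or we hit the last frame
        if f == assign.numel() or assign[f] != assign[start_f]:
            spans.append({
                "token_index": int(assign[start_f]),
                "start_s": start_f * frame_time,
                "end_s": f * frame_time,
            })
            start_f = f
    return spans

# toy check: 2 heads, 10 frames, 4 text tokens
print(spans_from_attention(torch.rand(1, 2, 10, 4)))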

View File

@@ -0,0 +1,5 @@
from .models import SynthesizerTrn
__all__ = ["SynthesizerTrn"]

View File

@@ -93,6 +93,23 @@ def subsequent_mask(length):
return mask
def attn_to_token_alignment(attn, reduce="mean"):
"""
Convert attention [B, H, T_tgt, T_src] to argmax alignment per target step.
Returns [B, T_tgt] indices.
"""
if attn is None:
return None
if reduce == "mean":
attn_ = attn.mean(dim=1)
elif reduce == "max":
attn_ = attn.max(dim=1).values
else:
attn_ = attn
return attn_.argmax(dim=-1)
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
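A quick usage sketch for the new helper, with a random tensor standing in for real attention weights; the import path assumes the helper lives in GPT_SoVITS/module/commons.py alongside subsequent_mask, as this hunk suggests:

import torch
from module import commons  # i.e. GPT_SoVITS/module/commons.py on sys.path

attn = torch.rand(2, 4, 50, 12)  # [B, H, T_tgt, T_src]
align_mean = commons.attn_to_token_alignment(attn)               # mean over heads, then argmax -> [2, 50]
align_max = commons.attn_to_token_alignment(attn, reduce="max")  # max over heads instead
print(align_mean.shape, align_max.shape)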

View File

@@ -1004,6 +1004,67 @@ class SynthesizerTrn(nn.Module):
o = self.dec((z * y_mask)[:, :, :], g=ge)
return o
@torch.no_grad()
def decode_with_alignment(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
"""
Decode and also return the cross-attention alignment from MRTE (ssl → text).
Returns:
o (Tensor): waveform tensor [B, 1, T_audio]
attn (Tensor): attention weights [B, n_heads, T_ssl, T_text]
y_mask (Tensor): ssl/time mask [B, 1, T_ssl]
"""
def get_ge(refer, sv_emb):
ge = None
if refer is not None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
if self.version == "v1":
ge = self.ref_enc(refer * refer_mask, refer_mask)
else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
if self.is_v2pro:
sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
ge += sv_emb.unsqueeze(-1)
ge = self.prelu(ge)
return ge
if type(refer) == list:
ges = []
for idx, _refer in enumerate(refer):
ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None)
ges.append(ge)
ge = torch.stack(ges, 0).mean(0)
else:
ge = get_ge(refer, sv_emb)
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
x, m_p, logs_p, y_mask = self.enc_p(
quantized,
y_lengths,
text,
text_lengths,
self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
speed,
)
# MRTE cross-attention collected during enc_p forward
# Shape: [B, n_heads, T_ssl, T_text]
try:
attn = self.enc_p.mrte.cross_attention.attn
except Exception:
attn = None
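# attn stays None when the attribute chain is missing, i.e. the MRTE cross-attention module does not cache its weights; callers must handle that case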
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)
o = self.dec((z * y_mask)[:, :, :], g=ge)
return o, attn, y_mask
def extract_latent(self, x):
ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
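For reference, decode_with_alignment replaces vq_model.decode only on the non-v3/v4 path (api.py guards it with version not in {"v3", "v4"}; the WebUI comment says v2/v2Pro). The call pattern at the updated call sites, with variables as in inference_webui.py/api.py:

phones2_tensor = torch.LongTensor(phones2).to(device).unsqueeze(0)
o, attn, y_mask = vq_model.decode_with_alignment(pred_semantic, phones2_tensor, refers, speed=speed)
audio = o[0][0]  # same waveform tensor the old decode()[0][0] returned
# attn: [B, n_heads, T_ssl, T_text] or None; y_mask: [B, 1, T_ssl]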

api.py
View File

@@ -948,27 +948,19 @@ def get_tts_wav(
if version not in {"v3", "v4"}:
if is_v2pro:
audio = (
vq_model.decode(
pred_semantic,
torch.LongTensor(phones2).to(device).unsqueeze(0),
refers,
speed=speed,
sv_emb=sv_emb,
)
.detach()
.cpu()
.numpy()[0, 0]
o, attn, y_mask = vq_model.decode_with_alignment(
pred_semantic,
torch.LongTensor(phones2).to(device).unsqueeze(0),
refers,
speed=speed,
sv_emb=sv_emb,
)
audio = o.detach().cpu().numpy()[0, 0]
else:
audio = (
vq_model.decode(
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed
)
.detach()
.cpu()
.numpy()[0, 0]
o, attn, y_mask = vq_model.decode_with_alignment(
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed
)
audio = o.detach().cpu().numpy()[0, 0]
else:
phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
@@ -1054,6 +1046,7 @@ def get_tts_wav(
# logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
if stream_mode == "normal":
audio_bytes, audio_chunk = read_clean_buffer(audio_bytes)
# For backward compatibility, yield audio chunk only
yield audio_chunk
if not stream_mode == "normal":
@@ -1065,6 +1058,7 @@ def get_tts_wav(
else:
sr = 48000 # v4
audio_bytes = pack_wav(audio_bytes, sr)
# For stream_mode != "normal", still return only the packed audio bytes for backward compatibility
yield audio_bytes.getvalue()
@@ -1133,8 +1127,7 @@ def handle(
else:
text = cut_text(text, cut_punc)
return StreamingResponse(
get_tts_wav(
gen = get_tts_wav(
refer_wav_path,
prompt_text,
prompt_language,
@@ -1147,9 +1140,10 @@ def handle(
inp_refs,
sample_steps,
if_sr,
),
media_type="audio/" + media_type,
)
# Timestamps are not exposed over the HTTP API yet; keep the legacy audio-only streaming behaviour for compatibility
return StreamingResponse(gen, media_type="audio/" + media_type)
# --------------------------------