mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-16 05:36:34 +08:00
Supporting word-level timestamp output through attention weight output
This commit is contained in:
parent
11aa78bd9b
commit
705df4c414
@ -44,10 +44,20 @@ def synthesize(
|
||||
result_list = list(synthesis_result)
|
||||
|
||||
if result_list:
|
||||
last_sampling_rate, last_audio_data = result_list[-1]
|
||||
# new: ((sr, audio), timestamps)
|
||||
last = result_list[-1]
|
||||
(last_sampling_rate, last_audio_data), timestamps = last
|
||||
output_wav_path = os.path.join(output_path, "output.wav")
|
||||
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
|
||||
print(f"Audio saved to {output_wav_path}")
|
||||
# Optionally save timestamps
|
||||
try:
|
||||
import json
|
||||
with open(os.path.join(output_path, "timestamps.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(timestamps, f, ensure_ascii=False, indent=2)
|
||||
print("Timestamps saved to timestamps.json")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -299,7 +299,8 @@ class GPTSoVITSGUI(QMainWindow):
|
||||
result_list = list(synthesis_result)
|
||||
|
||||
if result_list:
|
||||
last_sampling_rate, last_audio_data = result_list[-1]
|
||||
last = result_list[-1]
|
||||
last_sampling_rate, last_audio_data = last[0], last[1]
|
||||
output_wav_path = os.path.join(output_path, "output.wav")
|
||||
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
|
||||
|
||||
|
@ -849,6 +849,7 @@ def get_tts_wav(
|
||||
if not ref_free:
|
||||
phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
|
||||
|
||||
timestamps_all = []
|
||||
for i_text, text in enumerate(texts):
|
||||
# 解决输入目标文本的空行导致报错的问题
|
||||
if len(text.strip()) == 0:
|
||||
@ -912,14 +913,101 @@ def get_tts_wav(
|
||||
refers = [refers]
|
||||
if is_v2pro:
|
||||
sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)]
|
||||
# compute alignment and audio (v2/v2Pro only)
|
||||
phones2_tensor = torch.LongTensor(phones2).to(device).unsqueeze(0)
|
||||
if is_v2pro:
|
||||
audio = vq_model.decode(
|
||||
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed, sv_emb=sv_emb
|
||||
)[0][0]
|
||||
o, attn, y_mask = vq_model.decode_with_alignment(
|
||||
pred_semantic, phones2_tensor, refers, speed=speed, sv_emb=sv_emb
|
||||
)
|
||||
else:
|
||||
audio = vq_model.decode(
|
||||
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed
|
||||
)[0][0]
|
||||
o, attn, y_mask = vq_model.decode_with_alignment(
|
||||
pred_semantic, phones2_tensor, refers, speed=speed
|
||||
)
|
||||
audio = o[0][0]
|
||||
last_attn = attn # [B,H,T_ssl,T_text]
|
||||
last_y_mask = y_mask # [B,1,T_ssl]
|
||||
# derive phoneme-level timestamps (20ms per frame at 32kHz, hop=640)
|
||||
if last_attn is not None:
|
||||
attn_heads_mean = last_attn.mean(dim=1)[0] # [T_ssl, T_text]
|
||||
assign = attn_heads_mean.argmax(dim=-1) # [T_ssl]
|
||||
frame_time = 0.02 / max(speed, 1e-6)
|
||||
# collapse consecutive frames pointing to same phoneme id
|
||||
ph_spans = []
|
||||
if assign.numel() > 0:
|
||||
start_f = 0
|
||||
cur_ph = int(assign[0].item())
|
||||
for f in range(1, assign.shape[0]):
|
||||
ph = int(assign[f].item())
|
||||
if ph != cur_ph:
|
||||
ph_spans.append({
|
||||
"phoneme_id": cur_ph,
|
||||
"start_s": start_f * frame_time,
|
||||
"end_s": f * frame_time,
|
||||
})
|
||||
start_f = f
|
||||
cur_ph = ph
|
||||
ph_spans.append({
|
||||
"phoneme_id": cur_ph,
|
||||
"start_s": start_f * frame_time,
|
||||
"end_s": assign.shape[0] * frame_time,
|
||||
})
|
||||
# char/word aggregation
|
||||
# obtain word2ph for current text segment
|
||||
_, word2ph, norm_text_seg = clean_text_inf(text, text_language, version)
|
||||
# char spans (for zh/yue where word2ph is char-based)
|
||||
char_spans = []
|
||||
if word2ph:
|
||||
ph_to_char = []
|
||||
for ch_idx, repeat in enumerate(word2ph):
|
||||
ph_to_char += [ch_idx] * repeat
|
||||
if ph_spans and ph_to_char:
|
||||
for span in ph_spans:
|
||||
ph_idx = span["phoneme_id"]
|
||||
if 0 <= ph_idx < len(ph_to_char):
|
||||
char_idx = ph_to_char[ph_idx]
|
||||
if len(char_spans) == 0 or char_spans[-1]["char_index"] != char_idx:
|
||||
char_spans.append({
|
||||
"char_index": char_idx,
|
||||
"char": norm_text_seg[char_idx] if char_idx < len(norm_text_seg) else "",
|
||||
"start_s": span["start_s"],
|
||||
"end_s": span["end_s"],
|
||||
})
|
||||
else:
|
||||
char_spans[-1]["end_s"] = span["end_s"]
|
||||
# word spans
|
||||
word_spans = []
|
||||
if text_language == "en":
|
||||
# build phoneme-to-word map using dictionary per-word g2p
|
||||
try:
|
||||
from text.english import g2p as g2p_en
|
||||
except Exception:
|
||||
g2p_en = None
|
||||
words = norm_text_seg.split()
|
||||
ph_to_word = []
|
||||
if g2p_en:
|
||||
for w_idx, w in enumerate(words):
|
||||
phs_w = g2p_en(w)
|
||||
ph_to_word += [w_idx] * len(phs_w)
|
||||
if ph_spans and ph_to_word:
|
||||
for span in ph_spans:
|
||||
ph_idx = span["phoneme_id"]
|
||||
if 0 <= ph_idx < len(ph_to_word):
|
||||
wi = ph_to_word[ph_idx]
|
||||
if len(word_spans) == 0 or word_spans[-1]["word_index"] != wi:
|
||||
word_spans.append({
|
||||
"word_index": wi,
|
||||
"word": words[wi] if wi < len(words) else "",
|
||||
"start_s": span["start_s"],
|
||||
"end_s": span["end_s"],
|
||||
})
|
||||
else:
|
||||
word_spans[-1]["end_s"] = span["end_s"]
|
||||
timestamps_all.append({
|
||||
"segment_index": i_text,
|
||||
"phoneme_spans": ph_spans,
|
||||
"char_spans": char_spans,
|
||||
"word_spans": word_spans,
|
||||
})
|
||||
else:
|
||||
refer, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device)
|
||||
phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
|
||||
@ -998,7 +1086,8 @@ def get_tts_wav(
|
||||
audio_opt /= max_audio
|
||||
else:
|
||||
audio_opt = audio_opt.cpu().detach().numpy()
|
||||
yield opt_sr, (audio_opt * 32767).astype(np.int16)
|
||||
# Return audio and timestamps for UI consumption: ((sr, audio), timestamps)
|
||||
yield (opt_sr, (audio_opt * 32767).astype(np.int16)), timestamps_all
|
||||
|
||||
|
||||
def split(todo_text):
|
||||
@ -1285,6 +1374,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
with gr.Row():
|
||||
inference_button = gr.Button(value=i18n("合成语音"), variant="primary", size="lg", scale=25)
|
||||
output = gr.Audio(label=i18n("输出的语音"), scale=14)
|
||||
timestamps_box = gr.JSON(label=i18n("时间戳(音素/字/词)"))
|
||||
|
||||
inference_button.click(
|
||||
get_tts_wav,
|
||||
@ -1306,7 +1396,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
if_sr_Checkbox,
|
||||
pause_second_slider,
|
||||
],
|
||||
[output],
|
||||
[output, timestamps_box],
|
||||
)
|
||||
SoVITS_dropdown.change(
|
||||
change_sovits_weights,
|
||||
|
@ -0,0 +1,5 @@
|
||||
from .models import SynthesizerTrn
|
||||
|
||||
__all__ = ["SynthesizerTrn"]
|
||||
|
||||
|
@ -93,6 +93,23 @@ def subsequent_mask(length):
|
||||
return mask
|
||||
|
||||
|
||||
def attn_to_token_alignment(attn, reduce="mean"):
|
||||
"""
|
||||
Convert attention [B, H, T_tgt, T_src] to argmax alignment per target step.
|
||||
Returns [B, T_tgt] indices.
|
||||
"""
|
||||
if attn is None:
|
||||
return None
|
||||
if reduce == "mean":
|
||||
attn_ = attn.mean(dim=1)
|
||||
elif reduce == "max":
|
||||
attn_ = attn.max(dim=1).values
|
||||
else:
|
||||
attn_ = attn
|
||||
return attn_.argmax(dim=-1)
|
||||
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
||||
n_channels_int = n_channels[0]
|
||||
|
@ -1004,6 +1004,67 @@ class SynthesizerTrn(nn.Module):
|
||||
o = self.dec((z * y_mask)[:, :, :], g=ge)
|
||||
return o
|
||||
|
||||
@torch.no_grad()
|
||||
def decode_with_alignment(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
|
||||
"""
|
||||
Decode and also return cross-attention alignment from MRTE (ssl↔text).
|
||||
|
||||
Returns:
|
||||
o (Tensor): waveform tensor [B, 1, T_audio]
|
||||
attn (Tensor): attention weights [B, n_heads, T_ssl, T_text]
|
||||
y_mask (Tensor): ssl/time mask [B, 1, T_ssl]
|
||||
"""
|
||||
def get_ge(refer, sv_emb):
|
||||
ge = None
|
||||
if refer is not None:
|
||||
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
|
||||
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
|
||||
if self.version == "v1":
|
||||
ge = self.ref_enc(refer * refer_mask, refer_mask)
|
||||
else:
|
||||
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
|
||||
if self.is_v2pro:
|
||||
sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
|
||||
ge += sv_emb.unsqueeze(-1)
|
||||
ge = self.prelu(ge)
|
||||
return ge
|
||||
|
||||
if type(refer) == list:
|
||||
ges = []
|
||||
for idx, _refer in enumerate(refer):
|
||||
ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None)
|
||||
ges.append(ge)
|
||||
ge = torch.stack(ges, 0).mean(0)
|
||||
else:
|
||||
ge = get_ge(refer, sv_emb)
|
||||
|
||||
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
|
||||
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
|
||||
|
||||
quantized = self.quantizer.decode(codes)
|
||||
if self.semantic_frame_rate == "25hz":
|
||||
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
||||
x, m_p, logs_p, y_mask = self.enc_p(
|
||||
quantized,
|
||||
y_lengths,
|
||||
text,
|
||||
text_lengths,
|
||||
self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
|
||||
speed,
|
||||
)
|
||||
|
||||
# MRTE cross-attention collected during enc_p forward
|
||||
# Shape: [B, n_heads, T_ssl, T_text]
|
||||
try:
|
||||
attn = self.enc_p.mrte.cross_attention.attn
|
||||
except Exception:
|
||||
attn = None
|
||||
|
||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
||||
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
||||
o = self.dec((z * y_mask)[:, :, :], g=ge)
|
||||
return o, attn, y_mask
|
||||
|
||||
def extract_latent(self, x):
|
||||
ssl = self.ssl_proj(x)
|
||||
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
|
||||
|
38
api.py
38
api.py
@ -948,27 +948,19 @@ def get_tts_wav(
|
||||
|
||||
if version not in {"v3", "v4"}:
|
||||
if is_v2pro:
|
||||
audio = (
|
||||
vq_model.decode(
|
||||
pred_semantic,
|
||||
torch.LongTensor(phones2).to(device).unsqueeze(0),
|
||||
refers,
|
||||
speed=speed,
|
||||
sv_emb=sv_emb,
|
||||
)
|
||||
.detach()
|
||||
.cpu()
|
||||
.numpy()[0, 0]
|
||||
o, attn, y_mask = vq_model.decode_with_alignment(
|
||||
pred_semantic,
|
||||
torch.LongTensor(phones2).to(device).unsqueeze(0),
|
||||
refers,
|
||||
speed=speed,
|
||||
sv_emb=sv_emb,
|
||||
)
|
||||
audio = o.detach().cpu().numpy()[0, 0]
|
||||
else:
|
||||
audio = (
|
||||
vq_model.decode(
|
||||
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed
|
||||
)
|
||||
.detach()
|
||||
.cpu()
|
||||
.numpy()[0, 0]
|
||||
o, attn, y_mask = vq_model.decode_with_alignment(
|
||||
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed
|
||||
)
|
||||
audio = o.detach().cpu().numpy()[0, 0]
|
||||
else:
|
||||
phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
|
||||
phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
|
||||
@ -1054,6 +1046,7 @@ def get_tts_wav(
|
||||
# logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
|
||||
if stream_mode == "normal":
|
||||
audio_bytes, audio_chunk = read_clean_buffer(audio_bytes)
|
||||
# For backward compatibility, yield audio chunk only
|
||||
yield audio_chunk
|
||||
|
||||
if not stream_mode == "normal":
|
||||
@ -1065,6 +1058,7 @@ def get_tts_wav(
|
||||
else:
|
||||
sr = 48000 # v4
|
||||
audio_bytes = pack_wav(audio_bytes, sr)
|
||||
# extend: for stream_mode!=normal we still return only audio bytes for backward compatibility
|
||||
yield audio_bytes.getvalue()
|
||||
|
||||
|
||||
@ -1133,8 +1127,7 @@ def handle(
|
||||
else:
|
||||
text = cut_text(text, cut_punc)
|
||||
|
||||
return StreamingResponse(
|
||||
get_tts_wav(
|
||||
gen = get_tts_wav(
|
||||
refer_wav_path,
|
||||
prompt_text,
|
||||
prompt_language,
|
||||
@ -1147,9 +1140,10 @@ def handle(
|
||||
inp_refs,
|
||||
sample_steps,
|
||||
if_sr,
|
||||
),
|
||||
media_type="audio/" + media_type,
|
||||
)
|
||||
# Consume the generator to collect bytes and timestamps; pack as multipart/mixed JSON with audio for compatibility
|
||||
# For simplicity, keep legacy streaming behaviour
|
||||
return StreamingResponse(gen, media_type="audio/" + media_type)
|
||||
|
||||
|
||||
# --------------------------------
|
||||
|
Loading…
x
Reference in New Issue
Block a user