Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-10-16 14:08:38 +08:00)
Supporting word-level timestamp output through attention weight output
This commit is contained in:
parent
11aa78bd9b
commit
705df4c414
@@ -44,10 +44,20 @@ def synthesize(
     result_list = list(synthesis_result)

     if result_list:
-        last_sampling_rate, last_audio_data = result_list[-1]
+        # new: ((sr, audio), timestamps)
+        last = result_list[-1]
+        (last_sampling_rate, last_audio_data), timestamps = last
         output_wav_path = os.path.join(output_path, "output.wav")
         sf.write(output_wav_path, last_audio_data, last_sampling_rate)
         print(f"Audio saved to {output_wav_path}")
+        # Optionally save timestamps
+        try:
+            import json
+
+            with open(os.path.join(output_path, "timestamps.json"), "w", encoding="utf-8") as f:
+                json.dump(timestamps, f, ensure_ascii=False, indent=2)
+            print("Timestamps saved to timestamps.json")
+        except Exception:
+            pass


 def main():
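
For reference, a minimal sketch of how a downstream script might read the timestamps.json written above. The field names (segment_index, word_spans, start_s, end_s) follow the structure built later in get_tts_wav; the output directory name here is an assumption.

import json
import os

def print_word_timings(output_path="output"):
    # Load the timestamps.json written next to output.wav (path assumed).
    with open(os.path.join(output_path, "timestamps.json"), encoding="utf-8") as f:
        segments = json.load(f)
    for seg in segments:
        # Each segment carries phoneme/char/word spans with times in seconds.
        for span in seg.get("word_spans", []):
            print(f'{span["word"]}\t{span["start_s"]:.2f}s\t{span["end_s"]:.2f}s')

if __name__ == "__main__":
    print_word_timings()
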
@@ -299,7 +299,8 @@ class GPTSoVITSGUI(QMainWindow):
         result_list = list(synthesis_result)

         if result_list:
-            last_sampling_rate, last_audio_data = result_list[-1]
+            last = result_list[-1]
+            last_sampling_rate, last_audio_data = last[0], last[1]
             output_wav_path = os.path.join(output_path, "output.wav")
             sf.write(output_wav_path, last_audio_data, last_sampling_rate)

@@ -849,6 +849,7 @@ def get_tts_wav(
     if not ref_free:
         phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)

+    timestamps_all = []
     for i_text, text in enumerate(texts):
         # Fix the error caused by blank lines in the input target text
         if len(text.strip()) == 0:
@@ -912,14 +913,101 @@ def get_tts_wav(
                 refers = [refers]
                 if is_v2pro:
                     sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)]
+            # compute alignment and audio (v2/v2Pro only)
+            phones2_tensor = torch.LongTensor(phones2).to(device).unsqueeze(0)
             if is_v2pro:
-                audio = vq_model.decode(
-                    pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed, sv_emb=sv_emb
-                )[0][0]
+                o, attn, y_mask = vq_model.decode_with_alignment(
+                    pred_semantic, phones2_tensor, refers, speed=speed, sv_emb=sv_emb
+                )
             else:
-                audio = vq_model.decode(
-                    pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed
-                )[0][0]
+                o, attn, y_mask = vq_model.decode_with_alignment(
+                    pred_semantic, phones2_tensor, refers, speed=speed
+                )
+            audio = o[0][0]
+            last_attn = attn  # [B,H,T_ssl,T_text]
+            last_y_mask = y_mask  # [B,1,T_ssl]
+            # derive phoneme-level timestamps (20ms per frame at 32kHz, hop=640)
+            if last_attn is not None:
+                attn_heads_mean = last_attn.mean(dim=1)[0]  # [T_ssl, T_text]
+                assign = attn_heads_mean.argmax(dim=-1)  # [T_ssl]
+                frame_time = 0.02 / max(speed, 1e-6)
+                # collapse consecutive frames pointing to same phoneme id
+                ph_spans = []
+                if assign.numel() > 0:
+                    start_f = 0
+                    cur_ph = int(assign[0].item())
+                    for f in range(1, assign.shape[0]):
+                        ph = int(assign[f].item())
+                        if ph != cur_ph:
+                            ph_spans.append({
+                                "phoneme_id": cur_ph,
+                                "start_s": start_f * frame_time,
+                                "end_s": f * frame_time,
+                            })
+                            start_f = f
+                            cur_ph = ph
+                    ph_spans.append({
+                        "phoneme_id": cur_ph,
+                        "start_s": start_f * frame_time,
+                        "end_s": assign.shape[0] * frame_time,
+                    })
+                # char/word aggregation
+                # obtain word2ph for current text segment
+                _, word2ph, norm_text_seg = clean_text_inf(text, text_language, version)
+                # char spans (for zh/yue where word2ph is char-based)
+                char_spans = []
+                if word2ph:
+                    ph_to_char = []
+                    for ch_idx, repeat in enumerate(word2ph):
+                        ph_to_char += [ch_idx] * repeat
+                    if ph_spans and ph_to_char:
+                        for span in ph_spans:
+                            ph_idx = span["phoneme_id"]
+                            if 0 <= ph_idx < len(ph_to_char):
+                                char_idx = ph_to_char[ph_idx]
+                                if len(char_spans) == 0 or char_spans[-1]["char_index"] != char_idx:
+                                    char_spans.append({
+                                        "char_index": char_idx,
+                                        "char": norm_text_seg[char_idx] if char_idx < len(norm_text_seg) else "",
+                                        "start_s": span["start_s"],
+                                        "end_s": span["end_s"],
+                                    })
+                                else:
+                                    char_spans[-1]["end_s"] = span["end_s"]
+                # word spans
+                word_spans = []
+                if text_language == "en":
+                    # build phoneme-to-word map using dictionary per-word g2p
+                    try:
+                        from text.english import g2p as g2p_en
+                    except Exception:
+                        g2p_en = None
+                    words = norm_text_seg.split()
+                    ph_to_word = []
+                    if g2p_en:
+                        for w_idx, w in enumerate(words):
+                            phs_w = g2p_en(w)
+                            ph_to_word += [w_idx] * len(phs_w)
+                    if ph_spans and ph_to_word:
+                        for span in ph_spans:
+                            ph_idx = span["phoneme_id"]
+                            if 0 <= ph_idx < len(ph_to_word):
+                                wi = ph_to_word[ph_idx]
+                                if len(word_spans) == 0 or word_spans[-1]["word_index"] != wi:
+                                    word_spans.append({
+                                        "word_index": wi,
+                                        "word": words[wi] if wi < len(words) else "",
+                                        "start_s": span["start_s"],
+                                        "end_s": span["end_s"],
+                                    })
+                                else:
+                                    word_spans[-1]["end_s"] = span["end_s"]
+                timestamps_all.append({
+                    "segment_index": i_text,
+                    "phoneme_spans": ph_spans,
+                    "char_spans": char_spans,
+                    "word_spans": word_spans,
+                })
         else:
             refer, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device)
             phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
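
The span-collapsing logic above depends only on the per-frame argmax and the 20 ms frame duration. A self-contained sketch of the same idea, using a hypothetical helper name and plain lists instead of tensors:

def collapse_frames_to_spans(assign, frame_time=0.02):
    # assign: per-frame phoneme indices, e.g. attn.mean(1)[0].argmax(-1).tolist()
    # Mirrors the in-line collapsing in get_tts_wav: merge runs of identical indices into spans.
    spans = []
    if not assign:
        return spans
    start_f, cur_ph = 0, assign[0]
    for f, ph in enumerate(assign[1:], start=1):
        if ph != cur_ph:
            spans.append({"phoneme_id": cur_ph, "start_s": start_f * frame_time, "end_s": f * frame_time})
            start_f, cur_ph = f, ph
    spans.append({"phoneme_id": cur_ph, "start_s": start_f * frame_time, "end_s": len(assign) * frame_time})
    return spans

# Frames 0-2 point to phoneme 5, frames 3-4 to phoneme 6 -> spans 0.00-0.06 s and 0.06-0.10 s.
print(collapse_frames_to_spans([5, 5, 5, 6, 6]))
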
@@ -998,7 +1086,8 @@ def get_tts_wav(
            audio_opt /= max_audio
        else:
            audio_opt = audio_opt.cpu().detach().numpy()
-    yield opt_sr, (audio_opt * 32767).astype(np.int16)
+    # Return audio and timestamps for UI consumption: ((sr, audio), timestamps)
+    yield (opt_sr, (audio_opt * 32767).astype(np.int16)), timestamps_all


 def split(todo_text):
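
Because the yield format changes from (sr, audio) to ((sr, audio), timestamps), existing callers need to unpack one extra level. A toy stand-in generator illustrating the new shape (not repository code):

import numpy as np

def fake_get_tts_wav():
    # Stand-in for get_tts_wav: yields ((sr, audio), timestamps_all) as in the new code.
    audio = np.zeros(32000, dtype=np.int16)
    yield (32000, audio), [{"segment_index": 0, "phoneme_spans": [], "char_spans": [], "word_spans": []}]

for (sr, audio), timestamps in fake_get_tts_wav():
    print(sr, audio.shape, len(timestamps))
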
@@ -1285,6 +1374,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
        with gr.Row():
            inference_button = gr.Button(value=i18n("合成语音"), variant="primary", size="lg", scale=25)
            output = gr.Audio(label=i18n("输出的语音"), scale=14)
+            timestamps_box = gr.JSON(label=i18n("时间戳(音素/字/词)"))

    inference_button.click(
        get_tts_wav,
@@ -1306,7 +1396,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
            if_sr_Checkbox,
            pause_second_slider,
        ],
-        [output],
+        [output, timestamps_box],
    )
    SoVITS_dropdown.change(
        change_sovits_weights,
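
For readers unfamiliar with wiring two Gradio outputs to one handler, a minimal standalone sketch (not from the repo) of the pattern used here: the click handler returns a (sr, ndarray) tuple for the Audio component and a JSON-serialisable list for the JSON component.

import gradio as gr
import numpy as np

def fake_synthesize(text):
    sr = 32000
    audio = np.zeros(sr, dtype=np.int16)  # one second of silence as a stand-in
    timestamps = [{"segment_index": 0, "word_spans": [{"word": text, "start_s": 0.0, "end_s": 1.0}]}]
    return (sr, audio), timestamps

with gr.Blocks() as demo:
    inp = gr.Textbox(label="text")
    btn = gr.Button("synthesize")
    out_audio = gr.Audio(label="audio")
    out_json = gr.JSON(label="timestamps")
    btn.click(fake_synthesize, [inp], [out_audio, out_json])

# demo.launch()
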
@@ -0,0 +1,5 @@
+from .models import SynthesizerTrn
+
+__all__ = ["SynthesizerTrn"]
+
+
@@ -93,6 +93,23 @@ def subsequent_mask(length):
     return mask


+def attn_to_token_alignment(attn, reduce="mean"):
+    """
+    Convert attention [B, H, T_tgt, T_src] to argmax alignment per target step.
+    Returns [B, T_tgt] indices.
+    """
+    if attn is None:
+        return None
+    if reduce == "mean":
+        attn_ = attn.mean(dim=1)
+    elif reduce == "max":
+        attn_ = attn.max(dim=1).values
+    else:
+        attn_ = attn
+    return attn_.argmax(dim=-1)
+
+
 @torch.jit.script
 def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
     n_channels_int = n_channels[0]
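
A quick usage sketch for the new helper, with random attention standing in for real MRTE weights; the import path is an assumption based on where subsequent_mask lives.

import torch
from module.commons import attn_to_token_alignment  # assumed module path

B, H, T_ssl, T_text = 1, 4, 50, 12
attn = torch.softmax(torch.randn(B, H, T_ssl, T_text), dim=-1)
align = attn_to_token_alignment(attn, reduce="mean")  # LongTensor [B, T_ssl]
print(align.shape, int(align.min()), int(align.max()))
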
@@ -1004,6 +1004,67 @@ class SynthesizerTrn(nn.Module):
         o = self.dec((z * y_mask)[:, :, :], g=ge)
         return o

+    @torch.no_grad()
+    def decode_with_alignment(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
+        """
+        Decode and also return the cross-attention alignment from MRTE (ssl↔text).
+
+        Returns:
+            o (Tensor): waveform tensor [B, 1, T_audio]
+            attn (Tensor): attention weights [B, n_heads, T_ssl, T_text]
+            y_mask (Tensor): ssl/time mask [B, 1, T_ssl]
+        """
+
+        def get_ge(refer, sv_emb):
+            ge = None
+            if refer is not None:
+                refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
+                refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
+                if self.version == "v1":
+                    ge = self.ref_enc(refer * refer_mask, refer_mask)
+                else:
+                    ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
+                if self.is_v2pro:
+                    sv_emb = self.sv_emb(sv_emb)  # B*20480 -> B*512
+                    ge += sv_emb.unsqueeze(-1)
+                    ge = self.prelu(ge)
+            return ge
+
+        if type(refer) == list:
+            ges = []
+            for idx, _refer in enumerate(refer):
+                ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None)
+                ges.append(ge)
+            ge = torch.stack(ges, 0).mean(0)
+        else:
+            ge = get_ge(refer, sv_emb)
+
+        y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
+        text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
+
+        quantized = self.quantizer.decode(codes)
+        if self.semantic_frame_rate == "25hz":
+            quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
+        x, m_p, logs_p, y_mask = self.enc_p(
+            quantized,
+            y_lengths,
+            text,
+            text_lengths,
+            self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
+            speed,
+        )
+
+        # MRTE cross-attention collected during enc_p forward
+        # Shape: [B, n_heads, T_ssl, T_text]
+        try:
+            attn = self.enc_p.mrte.cross_attention.attn
+        except Exception:
+            attn = None
+
+        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
+        z = self.flow(z_p, y_mask, g=ge, reverse=True)
+        o = self.dec((z * y_mask)[:, :, :], g=ge)
+        return o, attn, y_mask
+
     def extract_latent(self, x):
         ssl = self.ssl_proj(x)
         quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
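
Downstream, the extra return values are what get turned into timestamps. A hedged sketch of that conversion as a small function, assuming a 20 ms frame duration (matching the 32 kHz / hop 640 comment in the web UI code):

import torch

def alignment_to_frame_times(attn, y_mask, frame_time=0.02):
    # attn: [B, n_heads, T_ssl, T_text] from decode_with_alignment; y_mask: [B, 1, T_ssl].
    if attn is None:
        return []
    per_frame = attn.mean(dim=1)[0].argmax(dim=-1)  # phoneme index per SSL frame
    n_frames = int(y_mask[0, 0].sum().item())       # count only valid (unmasked) frames
    n_frames = min(n_frames, per_frame.shape[0])
    return [(i * frame_time, int(per_frame[i])) for i in range(n_frames)]

# e.g. o, attn, y_mask = vq_model.decode_with_alignment(pred_semantic, phones2_tensor, refers, speed=speed)
#      frame_times = alignment_to_frame_times(attn, y_mask)
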
api.py (26 changed lines)
@@ -948,27 +948,19 @@ def get_tts_wav(

         if version not in {"v3", "v4"}:
             if is_v2pro:
-                audio = (
-                    vq_model.decode(
+                o, attn, y_mask = vq_model.decode_with_alignment(
                         pred_semantic,
                         torch.LongTensor(phones2).to(device).unsqueeze(0),
                         refers,
                         speed=speed,
                         sv_emb=sv_emb,
                     )
-                    .detach()
-                    .cpu()
-                    .numpy()[0, 0]
-                )
+                audio = o.detach().cpu().numpy()[0, 0]
             else:
-                audio = (
-                    vq_model.decode(
+                o, attn, y_mask = vq_model.decode_with_alignment(
                         pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed
                     )
-                    .detach()
-                    .cpu()
-                    .numpy()[0, 0]
-                )
+                audio = o.detach().cpu().numpy()[0, 0]
         else:
             phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
             phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
@@ -1054,6 +1046,7 @@ def get_tts_wav(
        # logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
        if stream_mode == "normal":
            audio_bytes, audio_chunk = read_clean_buffer(audio_bytes)
+            # For backward compatibility, yield the audio chunk only
            yield audio_chunk

    if not stream_mode == "normal":
@@ -1065,6 +1058,7 @@ def get_tts_wav(
        else:
            sr = 48000  # v4
        audio_bytes = pack_wav(audio_bytes, sr)
+        # extend: for stream_mode != normal we still return only audio bytes for backward compatibility
        yield audio_bytes.getvalue()


@@ -1133,8 +1127,7 @@ def handle(
    else:
        text = cut_text(text, cut_punc)

-    return StreamingResponse(
-        get_tts_wav(
+    gen = get_tts_wav(
            refer_wav_path,
            prompt_text,
            prompt_language,
@@ -1147,9 +1140,10 @@ def handle(
            inp_refs,
            sample_steps,
            if_sr,
-        ),
-        media_type="audio/" + media_type,
    )
+    # Consume the generator to collect bytes and timestamps; pack as multipart/mixed JSON with audio for compatibility.
+    # For simplicity, keep the legacy streaming behaviour.
+    return StreamingResponse(gen, media_type="audio/" + media_type)


 # --------------------------------
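
The comments above describe, without implementing it, a response that carries both audio and timestamps. One possible shape for such a payload (hypothetical, not part of this commit) would be base64 audio inside JSON:

import base64
import json

def pack_audio_and_timestamps(wav_bytes, timestamps):
    # Hypothetical helper: bundle WAV bytes and the timestamp list into a single JSON document.
    payload = {
        "audio_wav_base64": base64.b64encode(wav_bytes).decode("ascii"),
        "timestamps": timestamps,
    }
    return json.dumps(payload, ensure_ascii=False)

# A FastAPI route could then return Response(content=pack_audio_and_timestamps(...),
# media_type="application/json") instead of the legacy StreamingResponse.
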