From 2ff9a1533ff8a7813c34682bfeb17f3f0f5fe253 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=99=BD=E8=8F=9C=E5=B7=A5=E5=8E=821145=E5=8F=B7=E5=91=98?= =?UTF-8?q?=E5=B7=A5?= <114749500+baicai-1145@users.noreply.github.com> Date: Mon, 22 Sep 2025 00:27:14 +0800 Subject: [PATCH] Fix the timestamp processing logic for phonemes, characters, and words --- GPT_SoVITS/inference_webui.py | 149 ++++++++++++++++- api.py | 297 ++++++++++++++++++++++++++++++++++ 2 files changed, 443 insertions(+), 3 deletions(-) diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index 54d35301..f1cadd76 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -596,6 +596,7 @@ def get_first(text): from text import chinese +import tempfile def get_phones_and_bert(text, language, version, final=False): @@ -850,6 +851,8 @@ def get_tts_wav( phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version) timestamps_all = [] + elapsed_s = 0.0 + sr_hz = int(hps.data.sampling_rate) for i_text, text in enumerate(texts): # 解决输入目标文本的空行导致报错的问题 if len(text.strip()) == 0: @@ -974,6 +977,71 @@ def get_tts_wav( }) else: char_spans[-1]["end_s"] = span["end_s"] + # post-merge by char_index across the whole segment to remove jitter fragments + if char_spans: + # group by char_index + groups = {} + for cs in char_spans: + groups.setdefault(cs["char_index"], []).append(cs) + merged = [] + gap_merge_s = 0.08 + # adaptive minimal duration: at least one frame, but not lower than 15ms + min_dur_s = max(0.015, frame_time) + for ci, lst in groups.items(): + lst = sorted(lst, key=lambda x: x["start_s"]) + cur = None + for it in lst: + if cur is None: + cur = {"char_index": ci, "char": it.get("char", ""), "start_s": it["start_s"], "end_s": it["end_s"]} + else: + if it["start_s"] - cur["end_s"] <= gap_merge_s: + if it["end_s"] > cur["end_s"]: + cur["end_s"] = it["end_s"] + else: + if cur["end_s"] - cur["start_s"] >= min_dur_s: + merged.append(cur) + cur = {"char_index": ci, "char": it.get("char", ""), "start_s": it["start_s"], "end_s": it["end_s"]} + if cur is not None and (cur["end_s"] - cur["start_s"]) >= min_dur_s: + merged.append(cur) + # sort merged by time + char_spans = sorted(merged, key=lambda x: x["start_s"]) + # remap normalized chars back to original input text to avoid spurious '.'/'?' + def _build_norm_to_orig_map(orig, norm): + all_punc_local = set(punctuation).union(set(splits)) + mapping = [-1] * len(norm) + o = 0 + n = 0 + while n < len(norm) and o < len(orig): + if norm[n] == orig[o]: + mapping[n] = o + n += 1 + o += 1 + else: + # skip spaces/punctuations on either side + if orig[o].isspace() or orig[o] in all_punc_local: + o += 1 + elif norm[n].isspace() or norm[n] in all_punc_local: + n += 1 + else: + # characters differ (e.g., compatibility form). Advance normalized index. 
+ n += 1 + return mapping + norm2orig = _build_norm_to_orig_map(text, norm_text_seg) + all_punc_local = set(punctuation).union(set(splits)) + remapped = [] + for cs in char_spans: + ci = cs["char_index"] + ch_norm = cs.get("char", "") + oi = norm2orig[ci] if ci < len(norm2orig) else -1 + if oi != -1: + cs["char"] = text[oi] + remapped.append(cs) + else: + # drop normalized-only punctuations/spaces not present in original + if ch_norm and (ch_norm in all_punc_local or ch_norm.isspace()): + continue + remapped.append(cs) + char_spans = remapped # word spans word_spans = [] if text_language == "en": @@ -1002,12 +1070,32 @@ def get_tts_wav( }) else: word_spans[-1]["end_s"] = span["end_s"] + # add absolute offsets and record segment timing + audio_len_s = float(audio.shape[0]) / sr_hz + if 'ph_spans' in locals() and ph_spans: + for d in ph_spans: + d["start_s"] += elapsed_s + d["end_s"] += elapsed_s + if char_spans: + for d in char_spans: + d["start_s"] += elapsed_s + d["end_s"] += elapsed_s + if word_spans: + for d in word_spans: + d["start_s"] += elapsed_s + d["end_s"] += elapsed_s + seg_start_s = elapsed_s + seg_end_s = elapsed_s + audio_len_s timestamps_all.append({ "segment_index": i_text, "phoneme_spans": ph_spans, "char_spans": char_spans, "word_spans": word_spans, + "segment_start_s": seg_start_s, + "segment_end_s": seg_end_s, + "text": norm_text_seg, }) + elapsed_s += audio_len_s + float(pause_second) else: refer, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device) phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0) @@ -1086,8 +1174,62 @@ def get_tts_wav( audio_opt /= max_audio else: audio_opt = audio_opt.cpu().detach().numpy() - # Return audio and timestamps for UI consumption: ((sr, audio), timestamps) - yield (opt_sr, (audio_opt * 32767).astype(np.int16)), timestamps_all + # Build SRT file content from timestamps_all + def _fmt_srt_time(t): + h = int(t // 3600) + m = int((t % 3600) // 60) + s = int(t % 60) + ms = int(round((t - int(t)) * 1000)) + return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}" + + srt_lines = [] + idx_counter = 1 + for rec in timestamps_all: + # 优先按字分段 + if rec.get("char_spans") and len(rec["char_spans"]): + for c in rec["char_spans"]: + st = c["start_s"] + ed = c["end_s"] + txt = c.get("char", "") + srt_lines.append(str(idx_counter)) + srt_lines.append(f"{_fmt_srt_time(st)} --> { _fmt_srt_time(ed)}") + srt_lines.append(txt) + srt_lines.append("") + idx_counter += 1 + continue + # 次选按词(英文) + if rec.get("word_spans") and len(rec["word_spans"]): + for w in rec["word_spans"]: + st = w["start_s"] + ed = w["end_s"] + txt = w.get("word", "") + srt_lines.append(str(idx_counter)) + srt_lines.append(f"{_fmt_srt_time(st)} --> { _fmt_srt_time(ed)}") + srt_lines.append(txt) + srt_lines.append("") + idx_counter += 1 + continue + # 最后按整段兜底 + st = rec.get("segment_start_s") + ed = rec.get("segment_end_s") + text_line = rec.get("text", "") + if st is not None and ed is not None: + srt_lines.append(str(idx_counter)) + srt_lines.append(f"{_fmt_srt_time(st)} --> { _fmt_srt_time(ed)}") + srt_lines.append(text_line) + srt_lines.append("") + idx_counter += 1 + + srt_path = None + try: + with tempfile.NamedTemporaryFile(delete=False, suffix=".srt", mode="w", encoding="utf-8") as f: + f.write("\n".join(srt_lines)) + srt_path = f.name + except Exception: + srt_path = None + + # Return audio, timestamps and SRT path for UI + yield (opt_sr, (audio_opt * 32767).astype(np.int16)), timestamps_all, srt_path def split(todo_text): @@ -1375,6 +1517,7 @@ with 
gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css inference_button = gr.Button(value=i18n("合成语音"), variant="primary", size="lg", scale=25) output = gr.Audio(label=i18n("输出的语音"), scale=14) timestamps_box = gr.JSON(label=i18n("时间戳(音素/字/词)")) + srt_file = gr.File(label=i18n("下载SRT字幕")) inference_button.click( get_tts_wav, @@ -1396,7 +1539,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css if_sr_Checkbox, pause_second_slider, ], - [output, timestamps_box], + [output, timestamps_box, srt_file], ) SoVITS_dropdown.change( change_sovits_weights, diff --git a/api.py b/api.py index 5a2035c1..8a9c36e5 100644 --- a/api.py +++ b/api.py @@ -1385,5 +1385,302 @@ async def tts_endpoint( ) +def _fmt_srt_time(t): + h = int(t // 3600) + m = int((t % 3600) // 60) + s = int(t % 60) + ms = int(round((t - int(t)) * 1000)) + return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}" + + +def _build_norm_to_orig_map(orig, norm): + all_punc_local = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"} + mapping = [-1] * len(norm) + o = 0 + n = 0 + while n < len(norm) and o < len(orig): + if norm[n] == orig[o]: + mapping[n] = o + n += 1 + o += 1 + else: + if orig[o].isspace() or orig[o] in all_punc_local: + o += 1 + elif norm[n].isspace() or norm[n] in all_punc_local: + n += 1 + else: + n += 1 + return mapping + + +def synthesize_json( + refer_wav_path, + prompt_text, + prompt_language, + text, + text_language, + top_k, + top_p, + temperature, + speed, + inp_refs, + sample_steps, + if_sr, +): + infer_sovits = speaker_list["default"].sovits + vq_model = infer_sovits.vq_model + hps = infer_sovits.hps + version = vq_model.version + + if version in {"v3", "v4"}: + return JSONResponse({"code": 400, "message": "v3/v4 暂未提供时间戳JSON接口"}, status_code=400) + + # prepare refer features (same as handle) + dtype = torch.float16 if is_half else torch.float32 + zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half else np.float32) + with torch.no_grad(): + wav16k, sr = librosa.load(refer_wav_path, sr=16000) + wav16k = torch.from_numpy(wav16k) + zero_wav_torch = torch.from_numpy(zero_wav) + if is_half: + wav16k = wav16k.half().to(device) + zero_wav_torch = zero_wav_torch.half().to(device) + else: + wav16k = wav16k.to(device) + zero_wav_torch = zero_wav_torch.to(device) + wav16k = torch.cat([wav16k, zero_wav_torch]) + ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) + codes = vq_model.extract_latent(ssl_content) + prompt_semantic = codes[0, 0] + prompt = prompt_semantic.unsqueeze(0).to(device) + + is_v2pro = version in {"v2Pro", "v2ProPlus"} + refers = [] + if is_v2pro: + sv_emb = [] + if sv_cn_model == None: + init_sv_cn() + spec, audio_tensor = get_spepc(hps, refer_wav_path, dtype, device, is_v2pro) + refers = [spec] + if is_v2pro: + sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)] + + # text frontend + prompt_language = dict_language[prompt_language.lower()] + text_language = dict_language[text_language.lower()] + phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version) + texts = text.strip("\n").split("\n") + + # iterate + audio_segments = [] + timestamps_all = [] + sr_hz = int(hps.data.sampling_rate) + elapsed_s = 0.0 + for seg in texts: + if only_punc(seg): + continue + if seg[-1] not in splits: + seg += "。" if text_language != "en" else "." 
+ phones2, bert2, norm_text2 = get_phones_and_bert(seg, text_language, version) + bert = torch.cat([bert1, bert2], 1) + all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0) + bert = bert.to(device).unsqueeze(0) + all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) + with torch.no_grad(): + pred_semantic, idx = speaker_list["default"].gpt.t2s_model.model.infer_panel( + all_phoneme_ids, all_phoneme_len, prompt, bert, top_k=top_k, top_p=top_p, temperature=temperature, early_stop_num=hz * speaker_list["default"].gpt.max_sec, + ) + pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) + phones2_tensor = torch.LongTensor(phones2).to(device).unsqueeze(0) + if is_v2pro: + o, attn, y_mask = vq_model.decode_with_alignment( + pred_semantic, phones2_tensor, refers, speed=speed, sv_emb=sv_emb + ) + else: + o, attn, y_mask = vq_model.decode_with_alignment( + pred_semantic, phones2_tensor, refers, speed=speed + ) + audio = o[0][0].detach().cpu().numpy() + # timestamps + frame_time = 0.02 / max(float(speed), 1e-6) + ph_spans = [] + if attn is not None: + attn_mean = attn.mean(dim=1)[0] + assign = attn_mean.argmax(dim=-1) + if assign.numel() > 0: + start_f = 0 + cur_ph = int(assign[0].item()) + for f in range(1, assign.shape[0]): + ph = int(assign[f].item()) + if ph != cur_ph: + ph_spans.append({"phoneme_id": cur_ph, "start_s": start_f * frame_time, "end_s": f * frame_time}) + start_f = f + cur_ph = ph + ph_spans.append({"phoneme_id": cur_ph, "start_s": start_f * frame_time, "end_s": assign.shape[0] * frame_time}) + + # char aggregation + _, word2ph, norm_text_seg = clean_text_inf(seg, text_language, version) + char_spans = [] + if word2ph: + ph_to_char = [] + for ch_idx, repeat in enumerate(word2ph): + ph_to_char += [ch_idx] * repeat + if ph_spans and ph_to_char: + for span in ph_spans: + ph_idx = span["phoneme_id"] + if 0 <= ph_idx < len(ph_to_char): + char_idx = ph_to_char[ph_idx] + char_spans.append({"char_index": char_idx, "char": norm_text_seg[char_idx] if char_idx < len(norm_text_seg) else "", "start_s": span["start_s"], "end_s": span["end_s"]}) + # merge by char_index + if char_spans: + groups = {} + for cs in char_spans: + groups.setdefault(cs["char_index"], []).append(cs) + merged = [] + gap_merge_s = 0.08 + min_dur_s = max(0.015, frame_time) + for ci, lst in groups.items(): + lst = sorted(lst, key=lambda x: x["start_s"]) + cur = None + for it in lst: + if cur is None: + cur = {"char_index": ci, "char": it.get("char", ""), "start_s": it["start_s"], "end_s": it["end_s"]} + else: + if it["start_s"] - cur["end_s"] <= gap_merge_s: + if it["end_s"] > cur["end_s"]: + cur["end_s"] = it["end_s"] + else: + if cur["end_s"] - cur["start_s"] >= min_dur_s: + merged.append(cur) + cur = {"char_index": ci, "char": it.get("char", ""), "start_s": it["start_s"], "end_s": it["end_s"]} + if cur is not None and (cur["end_s"] - cur["start_s"]) >= min_dur_s: + merged.append(cur) + # remap to original input text + norm2orig = _build_norm_to_orig_map(seg, norm_text_seg) + remapped = [] + punc = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"} + for cs in merged: + ci = cs["char_index"] + oi = norm2orig[ci] if ci < len(norm2orig) else -1 + ch_norm = cs.get("char", "") + if oi != -1: + cs["char"] = seg[oi] + remapped.append(cs) + else: + if ch_norm and (ch_norm in punc or ch_norm.isspace()): + continue + remapped.append(cs) + char_spans = sorted(remapped, key=lambda x: x["start_s"]) if remapped else [] + + # offset and record + audio_len_s = float(audio.shape[0]) / 
sr_hz + for coll in (char_spans,): + for d in coll: + d["start_s"] += elapsed_s + d["end_s"] += elapsed_s + timestamps_all.append({ + "segment_index": len(timestamps_all), + "char_spans": char_spans, + "text": norm_text_seg, + "segment_start_s": elapsed_s, + "segment_end_s": elapsed_s + audio_len_s, + }) + elapsed_s += audio_len_s + 0.3 + audio_segments.append(audio) + + # concatenate audio + if len(audio_segments) == 0: + return JSONResponse({"code": 400, "message": "无有效文本"}, status_code=400) + pad = np.zeros(int(sr_hz * 0.3), dtype=audio_segments[0].dtype) + out = [] + for i, a in enumerate(audio_segments): + out.append(a) + if i < len(audio_segments) - 1: + out.append(pad) + audio_np = np.concatenate(out, 0) + mx = np.abs(audio_np).max() + if mx > 1: + audio_np = audio_np / mx + + # srt + srt_lines = [] + idx_counter = 1 + for rec in timestamps_all: + if rec.get("char_spans"): + for c in rec["char_spans"]: + srt_lines.append(str(idx_counter)) + srt_lines.append(f"{_fmt_srt_time(c['start_s'])} --> { _fmt_srt_time(c['end_s'])}") + srt_lines.append(c.get("char", "")) + srt_lines.append("") + idx_counter += 1 + else: + st = rec.get("segment_start_s") + ed = rec.get("segment_end_s") + if st is not None and ed is not None: + srt_lines.append(str(idx_counter)) + srt_lines.append(f"{_fmt_srt_time(st)} --> { _fmt_srt_time(ed)}") + srt_lines.append(rec.get("text", "")) + srt_lines.append("") + idx_counter += 1 + srt_text = "\n".join(srt_lines) + + # pack wav + import base64 + buf = BytesIO() + sf.write(buf, audio_np, sr_hz, format="WAV") + wav_b64 = base64.b64encode(buf.getvalue()).decode("utf-8") + return JSONResponse({"code": 0, "sr": sr_hz, "audio_wav_base64": wav_b64, "timestamps": timestamps_all, "srt": srt_text}) + + +@app.post("/tts_json") +async def tts_json_post(request: Request): + body = await request.json() + return synthesize_json( + body.get("refer_wav_path"), + body.get("prompt_text"), + body.get("prompt_language"), + body.get("text"), + body.get("text_language"), + body.get("top_k", 15), + body.get("top_p", 0.6), + body.get("temperature", 0.6), + body.get("speed", 1.0), + body.get("inp_refs", []), + body.get("sample_steps", 32), + body.get("if_sr", False), + ) + + +@app.get("/tts_json") +async def tts_json_get( + refer_wav_path: str, + prompt_text: str, + prompt_language: str, + text: str, + text_language: str, + top_k: int = 15, + top_p: float = 0.6, + temperature: float = 0.6, + speed: float = 1.0, + inp_refs: list = Query(default=[]), + sample_steps: int = 32, + if_sr: bool = False, +): + return synthesize_json( + refer_wav_path, + prompt_text, + prompt_language, + text, + text_language, + top_k, + top_p, + temperature, + speed, + inp_refs, + sample_steps, + if_sr, + ) + + if __name__ == "__main__": uvicorn.run(app, host=host, port=port, workers=1)
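
---

For reviewers, a minimal standalone sketch of the normalized-to-original character remapping that this patch introduces (`_build_norm_to_orig_map` in both `inference_webui.py` and `api.py`). The helper below mirrors the patch's logic but is simplified for illustration: the punctuation set and the example strings are assumptions made up for this demo, not code taken verbatim from the patch.

```python
# Standalone sketch of the norm->orig index mapping the patch uses to drop
# normalized-only punctuation (e.g. a sentence-final "。" appended by the text
# frontend) from char-level timestamp spans. Illustrative only; the real
# helpers live in the patched files and use the project's punctuation sets.
PUNC = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}  # assumed set for this demo


def build_norm_to_orig_map(orig: str, norm: str) -> list:
    """Map each index of the normalized text to an index in the original text,
    or -1 when the normalized character has no counterpart (inserted punctuation)."""
    mapping = [-1] * len(norm)
    o = n = 0
    while n < len(norm) and o < len(orig):
        if norm[n] == orig[o]:
            mapping[n] = o
            n += 1
            o += 1
        elif orig[o].isspace() or orig[o] in PUNC:
            o += 1  # original-only space/punctuation: skip it
        elif norm[n].isspace() or norm[n] in PUNC:
            n += 1  # normalized-only space/punctuation: stays mapped to -1 and is later dropped
        else:
            n += 1  # characters differ (e.g. compatibility forms): advance the normalized side
    return mapping


if __name__ == "__main__":
    orig = "你好世界"      # user input without a trailing stop
    norm = "你好世界。"    # frontend appends "。" during normalization
    print(build_norm_to_orig_map(orig, norm))  # [0, 1, 2, 3, -1] -> the "。" span gets dropped
```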
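
A hedged client-side example of the new `/tts_json` endpoint added to `api.py`. It assumes the API server is running locally on its default host/port (shown here as `127.0.0.1:9880`), that the `requests` package is installed, and that the reference-audio path and prompt texts are placeholders; the response keys (`code`, `sr`, `audio_wav_base64`, `timestamps`, `srt`) follow the JSON built in `synthesize_json` above. This is a usage sketch, not part of the patch.

```python
# Sketch of a /tts_json client: fetch audio plus timestamps and SRT in one call.
# Host/port, file paths and prompt texts below are placeholders (assumptions).
import base64

import requests  # assumed available; any HTTP client works

payload = {
    "refer_wav_path": "ref/example_ref.wav",   # placeholder reference audio
    "prompt_text": "参考音频对应的文本",
    "prompt_language": "zh",                   # language codes follow api.py's dict_language
    "text": "今天天气不错。\n我们去散步吧。",
    "text_language": "zh",
    "top_k": 15,
    "top_p": 0.6,
    "temperature": 0.6,
    "speed": 1.0,
}

resp = requests.post("http://127.0.0.1:9880/tts_json", json=payload, timeout=600)
data = resp.json()
if data.get("code") != 0:
    raise RuntimeError(f"TTS failed: {data}")

# The WAV is returned base64-encoded; span times are absolute seconds from stream start.
with open("out.wav", "wb") as f:
    f.write(base64.b64decode(data["audio_wav_base64"]))
with open("out.srt", "w", encoding="utf-8") as f:
    f.write(data["srt"])

for seg in data["timestamps"]:
    print(seg["segment_index"], f'{seg["segment_start_s"]:.2f}-{seg["segment_end_s"]:.2f}', seg["text"])
```

Note that `synthesize_json` rejects the v3/v4 SoVITS weights with an HTTP 400, so the endpoint is only usable with model versions that expose `decode_with_alignment`.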