From 338498ad68cb8394f80edadd90d0f32be484da42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=99=BD=E8=8F=9C=E5=B7=A5=E5=8E=821145=E5=8F=B7=E5=91=98?= =?UTF-8?q?=E5=B7=A5?= <114749500+baicai-1145@users.noreply.github.com> Date: Sun, 5 Oct 2025 01:16:14 +0800 Subject: [PATCH] fix --- GPT_SoVITS/inference_webui.py | 35 +++++++++++++++++++++-------------- api.py | 34 +++++++++++++++++++++------------- 2 files changed, 42 insertions(+), 27 deletions(-) diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index 53047cb4..6bd9fb6b 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -1046,12 +1046,15 @@ def get_tts_wav( seg_phones, seg_word2ph, seg_norm = clean_text_inf(seg_text, seg_lang, _version) norm_text_agg.append(seg_norm) if seg_lang in {"zh", "yue", "ja"} and seg_word2ph: - # char-based + # char-based; also expose as word-level (one char == one word) char_base_idx = len("".join(norm_text_agg[:-1])) for ch_idx, cnt in enumerate(seg_word2ph): global_char_idx = char_base_idx + ch_idx ph_to_char += [global_char_idx] * cnt - ph_to_word += [-1] * len(seg_phones) + token = seg_norm[ch_idx] if ch_idx < len(seg_norm) else "" + word_tokens.append(token) + word_idx = len(word_tokens) - 1 + ph_to_word += [word_idx] * cnt elif seg_lang in {"en", "ko"} and seg_word2ph: # word-based tokens_seg = [t for t in _re.findall(r"\S+", seg_norm) if not all((c in punctuation) for c in t)] @@ -1282,23 +1285,27 @@ def get_tts_wav( srt_lines = [] idx_counter = 1 for rec in timestamps_all: - cs = rec.get("char_spans") or [] ws = rec.get("word_spans") or [] - entries = [] - # 统一时间轴:存在则合并后按开始时间排序输出 - for c in cs: - entries.append({"text": c.get("char", ""), "start": c["start_s"], "end": c["end_s"]}) - for w in ws: - entries.append({"text": w.get("word", ""), "start": w["start_s"], "end": w["end_s"]}) - if entries: - entries.sort(key=lambda x: x["start"]) - for e in entries: + cs = rec.get("char_spans") or [] + + if ws: + for w in ws: srt_lines.append(str(idx_counter)) - srt_lines.append(f"{_fmt_srt_time(e['start'])} --> { _fmt_srt_time(e['end'])}") - srt_lines.append(e["text"]) + srt_lines.append(f"{_fmt_srt_time(w['start_s'])} --> { _fmt_srt_time(w['end_s'])}") + srt_lines.append(w.get("word", "")) srt_lines.append("") idx_counter += 1 continue + + if cs: + for c in cs: + srt_lines.append(str(idx_counter)) + srt_lines.append(f"{_fmt_srt_time(c['start_s'])} --> { _fmt_srt_time(c['end_s'])}") + srt_lines.append(c.get("char", "")) + srt_lines.append("") + idx_counter += 1 + continue + # 兜底:整段 st = rec.get("segment_start_s") ed = rec.get("segment_end_s") diff --git a/api.py b/api.py index 6df9ff77..4afae608 100644 --- a/api.py +++ b/api.py @@ -701,7 +701,10 @@ def _build_mixed_mappings_for_api(_text: str, _ui_lang: str, _version: str): for ch_idx, cnt in enumerate(seg_word2ph): global_char_idx = char_base_idx + ch_idx ph_to_char += [global_char_idx] * cnt - ph_to_word += [-1] * len(seg_phones) + token = seg_norm[ch_idx] if ch_idx < len(seg_norm) else "" + word_tokens.append(token) + word_idx = len(word_tokens) - 1 + ph_to_word += [word_idx] * cnt elif seg_lang in {"en", "ko"} and seg_word2ph: tokens_seg = [t for t in _re.findall(r"\S+", seg_norm) if not all((c in splits) for c in t)] base_word_idx = len(word_tokens) @@ -1772,21 +1775,26 @@ async def tts_json_post(request: Request): srt_lines = [] idx_counter = 1 for rec in timestamps_all: - cs = rec.get("char_spans") or [] ws = rec.get("word_spans") or [] - entries = [] - for c in cs: - entries.append({"text": c.get("char", ""), "start": c["start_s"], "end": c["end_s"]}) - for w in ws: - entries.append({"text": w.get("word", ""), "start": w["start_s"], "end": w["end_s"]}) - if entries: - entries.sort(key=lambda x: x["start"]) - for e in entries: + cs = rec.get("char_spans") or [] + + if ws: + for w in ws: srt_lines.append(str(idx_counter)) - srt_lines.append(f"{_fmt_srt_time(e['start'])} --> { _fmt_srt_time(e['end'])}") - srt_lines.append(e["text"]) + srt_lines.append(f"{_fmt_srt_time(w['start_s'])} --> { _fmt_srt_time(w['end_s'])}") + srt_lines.append(w.get("word", "")) srt_lines.append("") idx_counter += 1 + continue + + if cs: + for c in cs: + srt_lines.append(str(idx_counter)) + srt_lines.append(f"{_fmt_srt_time(c['start_s'])} --> { _fmt_srt_time(c['end_s'])}") + srt_lines.append(c.get("char", "")) + srt_lines.append("") + idx_counter += 1 + continue srt_b64 = base64.b64encode("\n".join(srt_lines).encode("utf-8")).decode("ascii") timestamps_json_b64 = base64.b64encode(json.dumps(timestamps_all, ensure_ascii=False).encode("utf-8")).decode("ascii") @@ -1805,4 +1813,4 @@ async def tts_json_post(request: Request): if __name__ == "__main__": - uvicorn.run(app, host=host, port=port, workers=1) + uvicorn.run(app, host=host, port=port, workers=1) \ No newline at end of file