白菜工厂1145号员工 2025-10-05 01:16:14 +08:00
parent 19ca3f3f6a
commit 338498ad68
2 changed files with 42 additions and 27 deletions


@@ -1046,12 +1046,15 @@ def get_tts_wav(
             seg_phones, seg_word2ph, seg_norm = clean_text_inf(seg_text, seg_lang, _version)
             norm_text_agg.append(seg_norm)
             if seg_lang in {"zh", "yue", "ja"} and seg_word2ph:
-                # char-based
+                # char-based; also expose as word-level (one char == one word)
                 char_base_idx = len("".join(norm_text_agg[:-1]))
                 for ch_idx, cnt in enumerate(seg_word2ph):
                     global_char_idx = char_base_idx + ch_idx
                     ph_to_char += [global_char_idx] * cnt
-                ph_to_word += [-1] * len(seg_phones)
+                    token = seg_norm[ch_idx] if ch_idx < len(seg_norm) else ""
+                    word_tokens.append(token)
+                    word_idx = len(word_tokens) - 1
+                    ph_to_word += [word_idx] * cnt
             elif seg_lang in {"en", "ko"} and seg_word2ph:
                 # word-based
                 tokens_seg = [t for t in _re.findall(r"\S+", seg_norm) if not all((c in punctuation) for c in t)]
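Note on the zh/yue/ja branch: seg_word2ph gives the phoneme count of each normalized character, and the added lines now register every character as its own word token, so CJK segments get word-level spans instead of the previous ph_to_word filled with -1. A minimal sketch of the mapping, with names taken from the hunk and purely illustrative phoneme counts:

    def map_cjk_segment(seg_norm, seg_word2ph, char_base_idx,
                        ph_to_char, ph_to_word, word_tokens):
        # One normalized character == one "word"; its phonemes share both indices.
        for ch_idx, cnt in enumerate(seg_word2ph):
            global_char_idx = char_base_idx + ch_idx
            ph_to_char += [global_char_idx] * cnt
            token = seg_norm[ch_idx] if ch_idx < len(seg_norm) else ""
            word_tokens.append(token)
            ph_to_word += [len(word_tokens) - 1] * cnt

    # Illustrative counts only: "你好" with seg_word2ph [2, 2] and char_base_idx 0
    # yields ph_to_char == [0, 0, 1, 1] and word_tokens ["你", "好"].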
@@ -1282,23 +1285,27 @@ def get_tts_wav(
         srt_lines = []
         idx_counter = 1
         for rec in timestamps_all:
-            cs = rec.get("char_spans") or []
             ws = rec.get("word_spans") or []
-            entries = []
-            # Unified timeline: if spans exist, merge them and emit sorted by start time
-            for c in cs:
-                entries.append({"text": c.get("char", ""), "start": c["start_s"], "end": c["end_s"]})
-            for w in ws:
-                entries.append({"text": w.get("word", ""), "start": w["start_s"], "end": w["end_s"]})
-            if entries:
-                entries.sort(key=lambda x: x["start"])
-                for e in entries:
+            cs = rec.get("char_spans") or []
+            if ws:
+                for w in ws:
                     srt_lines.append(str(idx_counter))
-                    srt_lines.append(f"{_fmt_srt_time(e['start'])} --> { _fmt_srt_time(e['end'])}")
-                    srt_lines.append(e["text"])
+                    srt_lines.append(f"{_fmt_srt_time(w['start_s'])} --> { _fmt_srt_time(w['end_s'])}")
+                    srt_lines.append(w.get("word", ""))
                     srt_lines.append("")
                     idx_counter += 1
                 continue
+            if cs:
+                for c in cs:
+                    srt_lines.append(str(idx_counter))
+                    srt_lines.append(f"{_fmt_srt_time(c['start_s'])} --> { _fmt_srt_time(c['end_s'])}")
+                    srt_lines.append(c.get("char", ""))
+                    srt_lines.append("")
+                    idx_counter += 1
+                continue
             # Fallback: whole segment
             st = rec.get("segment_start_s")
             ed = rec.get("segment_end_s")
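The rewritten loop no longer merges char and word spans onto one timeline; it emits one SRT cue per word span when word_spans exist, falls back to one cue per char span, and only then to the whole segment. Each cue is the standard four lines: index, "start --> end" time range, text, blank separator. _fmt_srt_time is the repo's own helper and is not shown in this diff; a minimal sketch of a compatible formatter, offered only as an assumption about its behavior:

    def _fmt_srt_time_sketch(seconds: float) -> str:
        # SRT uses HH:MM:SS,mmm with a comma before the milliseconds.
        ms = int(round(seconds * 1000))
        h, rem = divmod(ms, 3_600_000)
        m, rem = divmod(rem, 60_000)
        s, ms = divmod(rem, 1000)
        return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"

    # A word span {"word": "hello", "start_s": 0.12, "end_s": 0.48} becomes the cue:
    # 1
    # 00:00:00,120 --> 00:00:00,480
    # hello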

api.py

@@ -701,7 +701,10 @@ def _build_mixed_mappings_for_api(_text: str, _ui_lang: str, _version: str):
             for ch_idx, cnt in enumerate(seg_word2ph):
                 global_char_idx = char_base_idx + ch_idx
                 ph_to_char += [global_char_idx] * cnt
-            ph_to_word += [-1] * len(seg_phones)
+                token = seg_norm[ch_idx] if ch_idx < len(seg_norm) else ""
+                word_tokens.append(token)
+                word_idx = len(word_tokens) - 1
+                ph_to_word += [word_idx] * cnt
         elif seg_lang in {"en", "ko"} and seg_word2ph:
             tokens_seg = [t for t in _re.findall(r"\S+", seg_norm) if not all((c in splits) for c in t)]
             base_word_idx = len(word_tokens)
@@ -1772,21 +1775,26 @@ async def tts_json_post(request: Request):
         srt_lines = []
         idx_counter = 1
         for rec in timestamps_all:
-            cs = rec.get("char_spans") or []
             ws = rec.get("word_spans") or []
-            entries = []
-            for c in cs:
-                entries.append({"text": c.get("char", ""), "start": c["start_s"], "end": c["end_s"]})
-            for w in ws:
-                entries.append({"text": w.get("word", ""), "start": w["start_s"], "end": w["end_s"]})
-            if entries:
-                entries.sort(key=lambda x: x["start"])
-                for e in entries:
+            cs = rec.get("char_spans") or []
+            if ws:
+                for w in ws:
                     srt_lines.append(str(idx_counter))
-                    srt_lines.append(f"{_fmt_srt_time(e['start'])} --> { _fmt_srt_time(e['end'])}")
-                    srt_lines.append(e["text"])
+                    srt_lines.append(f"{_fmt_srt_time(w['start_s'])} --> { _fmt_srt_time(w['end_s'])}")
+                    srt_lines.append(w.get("word", ""))
                     srt_lines.append("")
                     idx_counter += 1
                 continue
+            if cs:
+                for c in cs:
+                    srt_lines.append(str(idx_counter))
+                    srt_lines.append(f"{_fmt_srt_time(c['start_s'])} --> { _fmt_srt_time(c['end_s'])}")
+                    srt_lines.append(c.get("char", ""))
+                    srt_lines.append("")
+                    idx_counter += 1
+                continue
         srt_b64 = base64.b64encode("\n".join(srt_lines).encode("utf-8")).decode("ascii")
         timestamps_json_b64 = base64.b64encode(json.dumps(timestamps_all, ensure_ascii=False).encode("utf-8")).decode("ascii")
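Both the SRT text and the raw timestamp records are returned to the API client base64-encoded. A client could decode them roughly as follows; this is only a sketch, and the JSON field names "srt_b64" and "timestamps_json_b64" are an assumption mirroring the server-side variable names rather than something shown in this hunk:

    import base64
    import json

    def decode_timing_fields(resp: dict):
        # Assumed response fields; adjust to the actual API schema.
        srt_text = base64.b64decode(resp["srt_b64"]).decode("utf-8")
        timestamps = json.loads(base64.b64decode(resp["timestamps_json_b64"]).decode("utf-8"))
        return srt_text, timestamps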
@@ -1805,4 +1813,4 @@ async def tts_json_post(request: Request):
 if __name__ == "__main__":
     uvicorn.run(app, host=host, port=port, workers=1)