mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-10-15 21:26:51 +08:00)

commit 338498ad68 (parent 19ca3f3f6a): fix
@@ -1046,12 +1046,15 @@ def get_tts_wav(
 seg_phones, seg_word2ph, seg_norm = clean_text_inf(seg_text, seg_lang, _version)
 norm_text_agg.append(seg_norm)
 if seg_lang in {"zh", "yue", "ja"} and seg_word2ph:
-    # char-based
+    # char-based; also expose as word-level (one char == one word)
     char_base_idx = len("".join(norm_text_agg[:-1]))
     for ch_idx, cnt in enumerate(seg_word2ph):
         global_char_idx = char_base_idx + ch_idx
         ph_to_char += [global_char_idx] * cnt
-    ph_to_word += [-1] * len(seg_phones)
+        token = seg_norm[ch_idx] if ch_idx < len(seg_norm) else ""
+        word_tokens.append(token)
+        word_idx = len(word_tokens) - 1
+        ph_to_word += [word_idx] * cnt
 elif seg_lang in {"en", "ko"} and seg_word2ph:
     # word-based
     tokens_seg = [t for t in _re.findall(r"\S+", seg_norm) if not all((c in punctuation) for c in t)]
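What this hunk changes: for zh/yue/ja segments the old code filled ph_to_word with -1, i.e. no word-level alignment at all; the new code registers every normalized character as its own one-character word, so ph_to_word now carries a real index for each phoneme. A minimal, self-contained sketch of the resulting mapping, using toy data in place of clean_text_inf output (the phoneme strings and counts below are illustrative, not real cleaner output):

seg_norm = "你好"                          # normalized text: two characters
seg_word2ph = [2, 3]                       # phonemes per character
ph_to_char, ph_to_word, word_tokens = [], [], []
char_base_idx = 0                          # offset of this segment in the full text

for ch_idx, cnt in enumerate(seg_word2ph):
    ph_to_char += [char_base_idx + ch_idx] * cnt
    # new in this commit: each CJK character doubles as a one-char word
    token = seg_norm[ch_idx] if ch_idx < len(seg_norm) else ""
    word_tokens.append(token)
    ph_to_word += [len(word_tokens) - 1] * cnt

print(ph_to_char)   # [0, 0, 1, 1, 1]
print(ph_to_word)   # [0, 0, 1, 1, 1]   (was [-1] * 5 before this commit)
print(word_tokens)  # ['你', '好']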
@@ -1282,23 +1285,27 @@ def get_tts_wav(
 srt_lines = []
 idx_counter = 1
 for rec in timestamps_all:
-    cs = rec.get("char_spans") or []
     ws = rec.get("word_spans") or []
-    entries = []
-    # 统一时间轴:存在则合并后按开始时间排序输出(unified timeline: merge whatever exists, emit sorted by start time)
-    for c in cs:
-        entries.append({"text": c.get("char", ""), "start": c["start_s"], "end": c["end_s"]})
-    for w in ws:
-        entries.append({"text": w.get("word", ""), "start": w["start_s"], "end": w["end_s"]})
-    if entries:
-        entries.sort(key=lambda x: x["start"])
-        for e in entries:
+    cs = rec.get("char_spans") or []
+    if ws:
+        for w in ws:
             srt_lines.append(str(idx_counter))
-            srt_lines.append(f"{_fmt_srt_time(e['start'])} --> { _fmt_srt_time(e['end'])}")
-            srt_lines.append(e["text"])
+            srt_lines.append(f"{_fmt_srt_time(w['start_s'])} --> { _fmt_srt_time(w['end_s'])}")
+            srt_lines.append(w.get("word", ""))
             srt_lines.append("")
             idx_counter += 1
         continue
+
+    if cs:
+        for c in cs:
+            srt_lines.append(str(idx_counter))
+            srt_lines.append(f"{_fmt_srt_time(c['start_s'])} --> { _fmt_srt_time(c['end_s'])}")
+            srt_lines.append(c.get("char", ""))
+            srt_lines.append("")
+            idx_counter += 1
+        continue
+
     # 兜底:整段(fallback: emit the whole segment as one entry)
     st = rec.get("segment_start_s")
     ed = rec.get("segment_end_s")
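The behavioral change here: instead of merging char_spans and word_spans into one sorted timeline, the new code emits word_spans when present, falls back to char_spans, and only then to the whole segment, so a subtitle entry is never duplicated at both granularities. Both branches format spans with _fmt_srt_time, which is defined elsewhere in the file; a minimal sketch consistent with how the hunk uses it (float seconds in, SRT HH:MM:SS,mmm timestamp out; this is an assumed implementation, not necessarily the repo's):

def _fmt_srt_time(seconds: float) -> str:
    # SRT uses a comma, not a dot, before the milliseconds
    ms = int(round(max(seconds, 0.0) * 1000))
    h, rem = divmod(ms, 3_600_000)
    m, rem = divmod(rem, 60_000)
    s, ms = divmod(rem, 1000)
    return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"

print(_fmt_srt_time(3.5))    # 00:00:03,500
print(_fmt_srt_time(61.25))  # 00:01:01,250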
api.py (34 lines changed)
@@ -701,7 +701,10 @@ def _build_mixed_mappings_for_api(_text: str, _ui_lang: str, _version: str):
     for ch_idx, cnt in enumerate(seg_word2ph):
         global_char_idx = char_base_idx + ch_idx
         ph_to_char += [global_char_idx] * cnt
-    ph_to_word += [-1] * len(seg_phones)
+        token = seg_norm[ch_idx] if ch_idx < len(seg_norm) else ""
+        word_tokens.append(token)
+        word_idx = len(word_tokens) - 1
+        ph_to_word += [word_idx] * cnt
 elif seg_lang in {"en", "ko"} and seg_word2ph:
     tokens_seg = [t for t in _re.findall(r"\S+", seg_norm) if not all((c in splits) for c in t)]
     base_word_idx = len(word_tokens)
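The en/ko context line shows how word tokens are extracted: split the normalized text on whitespace and drop tokens that consist entirely of separator characters (api.py filters against splits, while the webui version filters against punctuation). A standalone sketch, with a stand-in splits set since the real one is defined in the module:

import re as _re

splits = {",", ".", "?", "!", ";", ":"}   # stand-in separator set

seg_norm = "Hello , world !"
tokens_seg = [t for t in _re.findall(r"\S+", seg_norm) if not all((c in splits) for c in t)]
print(tokens_seg)  # ['Hello', 'world'] (',' and '!' are separator-only, so dropped)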
@@ -1772,21 +1775,26 @@ async def tts_json_post(request: Request):
 srt_lines = []
 idx_counter = 1
 for rec in timestamps_all:
-    cs = rec.get("char_spans") or []
     ws = rec.get("word_spans") or []
-    entries = []
-    for c in cs:
-        entries.append({"text": c.get("char", ""), "start": c["start_s"], "end": c["end_s"]})
-    for w in ws:
-        entries.append({"text": w.get("word", ""), "start": w["start_s"], "end": w["end_s"]})
-    if entries:
-        entries.sort(key=lambda x: x["start"])
-        for e in entries:
+    cs = rec.get("char_spans") or []
+    if ws:
+        for w in ws:
             srt_lines.append(str(idx_counter))
-            srt_lines.append(f"{_fmt_srt_time(e['start'])} --> { _fmt_srt_time(e['end'])}")
-            srt_lines.append(e["text"])
+            srt_lines.append(f"{_fmt_srt_time(w['start_s'])} --> { _fmt_srt_time(w['end_s'])}")
+            srt_lines.append(w.get("word", ""))
             srt_lines.append("")
             idx_counter += 1
+        continue
+
+    if cs:
+        for c in cs:
+            srt_lines.append(str(idx_counter))
+            srt_lines.append(f"{_fmt_srt_time(c['start_s'])} --> { _fmt_srt_time(c['end_s'])}")
+            srt_lines.append(c.get("char", ""))
+            srt_lines.append("")
+            idx_counter += 1
+        continue
+
 srt_b64 = base64.b64encode("\n".join(srt_lines).encode("utf-8")).decode("ascii")

 timestamps_json_b64 = base64.b64encode(json.dumps(timestamps_all, ensure_ascii=False).encode("utf-8")).decode("ascii")
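On the client side, both fields decode back to text with base64 plus json. A self-contained sketch (the payload below is a stand-in for the real tts_json_post response):

import base64
import json

# stand-in payload; real values come from the endpoint's JSON response
resp = {
    "srt_b64": base64.b64encode("1\n00:00:00,000 --> 00:00:00,500\nhi\n".encode("utf-8")).decode("ascii"),
    "timestamps_json_b64": base64.b64encode(json.dumps([{"segment_start_s": 0.0}]).encode("utf-8")).decode("ascii"),
}

srt_text = base64.b64decode(resp["srt_b64"]).decode("utf-8")
timestamps = json.loads(base64.b64decode(resp["timestamps_json_b64"]).decode("utf-8"))
print(srt_text)
print(timestamps)  # [{'segment_start_s': 0.0}]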
@@ -1805,4 +1813,4 @@ async def tts_json_post(request: Request):


 if __name__ == "__main__":
     uvicorn.run(app, host=host, port=port, workers=1)