Supports phoneme- and word-level timestamp output for multilingual text

白菜工厂1145号员工 2025-09-22 07:43:59 +08:00
parent 2ff9a1533f
commit 19ca3f3f6a
6 changed files with 2473 additions and 2169 deletions
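For orientation, a minimal sketch of one per-segment timestamp record produced by this change; the field names are taken from the code in the diff below, while the variable name and all values are illustrative only.

# Illustrative only: shape of one entry in `timestamps_all` as assembled in the diff below.
example_segment = {
    "segment_index": 0,
    "phoneme_spans": [{"phoneme_id": 12, "start_s": 0.00, "end_s": 0.08}],
    "char_spans": [{"char_index": 0, "char": "你", "start_s": 0.00, "end_s": 0.18}],    # zh/yue/ja parts
    "word_spans": [{"word_index": 0, "word": "hello", "start_s": 0.18, "end_s": 0.52}],  # en/ko parts
    "segment_start_s": 0.0,
    "segment_end_s": 1.25,
    "text": "你hello",
}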


@@ -932,7 +932,40 @@ def get_tts_wav(
# derive phoneme-level timestamps (20ms per frame at 32kHz, hop=640)
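# (one frame = hop / sample_rate = 640 / 32000 = 0.02 s of output audio; frame_time below is
#  additionally divided by `speed`, so spans track the speed-adjusted output length)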
if last_attn is not None:
attn_heads_mean = last_attn.mean(dim=1)[0] # [T_ssl, T_text]
assign = attn_heads_mean.argmax(dim=-1) # [T_ssl]
# Viterbi monotonic alignment (stay or advance by 1)
def _viterbi_monotonic(p):
T, N = p.shape
if T < N:
return p.argmax(dim=-1)
eps = 1e-8
cost = -torch.log(p + eps)
dp = torch.empty((T, N), dtype=cost.dtype, device=cost.device)
prev = torch.zeros((T, N), dtype=torch.uint8, device=cost.device)
dp[0, 0] = cost[0, 0]
if N > 1:
dp[0, 1:] = float("inf")
for t_i in range(1, T):
# j=0: only stay
dp[t_i, 0] = dp[t_i - 1, 0] + cost[t_i, 0]
prev[t_i, 0] = 0
if N > 1:
stay = dp[t_i - 1, 1:] + cost[t_i, 1:]
move = dp[t_i - 1, :-1] + cost[t_i, 1:]
better_move = move < stay
dp[t_i, 1:] = torch.where(better_move, move, stay)
prev[t_i, 1:] = better_move.to(torch.uint8)
# backtrack from (T-1, N-1)
j = N - 1
assign_bt = torch.empty(T, dtype=torch.long, device=cost.device)
for t_i in range(T - 1, -1, -1):
assign_bt[t_i] = j
if t_i > 0 and prev[t_i, j] == 1:
j = j - 1
if j < 0:
j = 0
return assign_bt
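# The DP above scores every monotonic alignment path (at each frame the phoneme index
# either stays or advances by exactly one) by sum(-log p[t, j]); prev[t, j] == 1 marks an
# advance, and the backtrack from (T-1, N-1) recovers the per-frame phoneme assignment.
# Unlike the plain argmax fallback, this forbids jumping back and forth between phonemes:
# it can yield e.g. [0, 0, 1, 1] but never the non-monotonic [0, 1, 0, 1].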
assign = _viterbi_monotonic(attn_heads_mean)
frame_time = 0.02 / max(speed, 1e-6)
# collapse consecutive frames pointing to same phoneme id
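# e.g. with assign = [5, 5, 5, 9, 9] and frame_time = 0.02 this produces
#   [{"phoneme_id": 5, "start_s": 0.00, "end_s": 0.06},
#    {"phoneme_id": 9, "start_s": 0.06, "end_s": 0.10}]   (illustrative values)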
ph_spans = []
@@ -954,122 +987,186 @@ def get_tts_wav(
"start_s": start_f * frame_time,
"end_s": assign.shape[0] * frame_time,
})
# char/word aggregation
# obtain word2ph for current text segment
_, word2ph, norm_text_seg = clean_text_inf(text, text_language, version)
# char spans (for zh/yue where word2ph is char-based)
char_spans = []
if word2ph:
ph_to_char = []
for ch_idx, repeat in enumerate(word2ph):
ph_to_char += [ch_idx] * repeat
if ph_spans and ph_to_char:
for span in ph_spans:
ph_idx = span["phoneme_id"]
if 0 <= ph_idx < len(ph_to_char):
char_idx = ph_to_char[ph_idx]
if len(char_spans) == 0 or char_spans[-1]["char_index"] != char_idx:
char_spans.append({
"char_index": char_idx,
"char": norm_text_seg[char_idx] if char_idx < len(norm_text_seg) else "",
"start_s": span["start_s"],
"end_s": span["end_s"],
})
else:
char_spans[-1]["end_s"] = span["end_s"]
# post-merge by char_index across the whole segment to remove jitter fragments
if char_spans:
# group by char_index
groups = {}
for cs in char_spans:
groups.setdefault(cs["char_index"], []).append(cs)
merged = []
gap_merge_s = 0.08
# adaptive minimal duration: at least one frame, but not lower than 15ms
min_dur_s = max(0.015, frame_time)
for ci, lst in groups.items():
lst = sorted(lst, key=lambda x: x["start_s"])
cur = None
for it in lst:
if cur is None:
cur = {"char_index": ci, "char": it.get("char", ""), "start_s": it["start_s"], "end_s": it["end_s"]}
else:
if it["start_s"] - cur["end_s"] <= gap_merge_s:
if it["end_s"] > cur["end_s"]:
cur["end_s"] = it["end_s"]
else:
if cur["end_s"] - cur["start_s"] >= min_dur_s:
merged.append(cur)
cur = {"char_index": ci, "char": it.get("char", ""), "start_s": it["start_s"], "end_s": it["end_s"]}
if cur is not None and (cur["end_s"] - cur["start_s"]) >= min_dur_s:
merged.append(cur)
# sort merged by time
char_spans = sorted(merged, key=lambda x: x["start_s"])
# remap normalized chars back to original input text to avoid spurious '.'/'?'
def _build_norm_to_orig_map(orig, norm):
all_punc_local = set(punctuation).union(set(splits))
mapping = [-1] * len(norm)
o = 0
n = 0
while n < len(norm) and o < len(orig):
if norm[n] == orig[o]:
mapping[n] = o
n += 1
o += 1
else:
# skip spaces/punctuations on either side
if orig[o].isspace() or orig[o] in all_punc_local:
o += 1
elif norm[n].isspace() or norm[n] in all_punc_local:
n += 1
else:
# characters differ (e.g., compatibility form). Advance normalized index.
n += 1
return mapping
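# Illustrative example (not from the repo): with orig = "你好, world" and norm = "你好world",
# the mapping is [0, 1, 4, 5, 6, 7, 8] -- the ", " in the original is skipped, and every
# normalized character points back at the original character it came from.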
norm2orig = _build_norm_to_orig_map(text, norm_text_seg)
all_punc_local = set(punctuation).union(set(splits))
remapped = []
for cs in char_spans:
ci = cs["char_index"]
ch_norm = cs.get("char", "")
oi = norm2orig[ci] if ci < len(norm2orig) else -1
if oi != -1:
cs["char"] = text[oi]
remapped.append(cs)
else:
# drop normalized-only punctuations/spaces not present in original
if ch_norm and (ch_norm in all_punc_local or ch_norm.isspace()):
# char/word aggregation for multi-language text
def _build_mixed_mappings(_text, _ui_lang, _version):
# replicate segmentation logic from get_phones_and_bert
_text = re.sub(r' {2,}', ' ', _text)
textlist = []
langlist = []
if _ui_lang == "all_zh":
for tmp in LangSegmenter.getTexts(_text, "zh"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif _ui_lang == "all_yue":
for tmp in LangSegmenter.getTexts(_text, "zh"):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif _ui_lang == "all_ja":
for tmp in LangSegmenter.getTexts(_text, "ja"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif _ui_lang == "all_ko":
for tmp in LangSegmenter.getTexts(_text, "ko"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif _ui_lang == "en":
langlist.append("en")
textlist.append(_text)
elif _ui_lang == "auto":
for tmp in LangSegmenter.getTexts(_text):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif _ui_lang == "auto_yue":
for tmp in LangSegmenter.getTexts(_text):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(_text):
if langlist:
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
textlist[-1] += tmp["text"]
continue
remapped.append(cs)
char_spans = remapped
# word spans
word_spans = []
if text_language == "en":
# build phoneme-to-word map using dictionary per-word g2p
try:
from text.english import g2p as g2p_en
except Exception:
g2p_en = None
words = norm_text_seg.split()
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
langlist.append(_ui_lang)
textlist.append(tmp["text"])
# aggregate mappings
ph_to_char = []
ph_to_word = []
if g2p_en:
for w_idx, w in enumerate(words):
phs_w = g2p_en(w)
ph_to_word += [w_idx] * len(phs_w)
if ph_spans and ph_to_word:
for span in ph_spans:
ph_idx = span["phoneme_id"]
if 0 <= ph_idx < len(ph_to_word):
wi = ph_to_word[ph_idx]
if len(word_spans) == 0 or word_spans[-1]["word_index"] != wi:
word_spans.append({
"word_index": wi,
"word": words[wi] if wi < len(words) else "",
"start_s": span["start_s"],
"end_s": span["end_s"],
})
word_tokens = []
norm_text_agg = []
import re as _re
for seg_text, seg_lang in zip(textlist, langlist):
seg_phones, seg_word2ph, seg_norm = clean_text_inf(seg_text, seg_lang, _version)
norm_text_agg.append(seg_norm)
if seg_lang in {"zh", "yue", "ja"} and seg_word2ph:
# char-based
char_base_idx = len("".join(norm_text_agg[:-1]))
for ch_idx, cnt in enumerate(seg_word2ph):
global_char_idx = char_base_idx + ch_idx
ph_to_char += [global_char_idx] * cnt
ph_to_word += [-1] * len(seg_phones)
elif seg_lang in {"en", "ko"} and seg_word2ph:
# word-based
tokens_seg = [t for t in _re.findall(r"\S+", seg_norm) if not all((c in punctuation) for c in t)]
base_word_idx = len(word_tokens)
t_idx = 0
for cnt in seg_word2ph:
if t_idx < len(tokens_seg):
ph_to_word += [base_word_idx + t_idx] * cnt
t_idx += 1
ph_to_char += [-1] * len(seg_phones)
word_tokens.extend(tokens_seg)
else:
# unknown mapping; fill with -1
ph_to_char += [-1] * len(seg_phones)
ph_to_word += [-1] * len(seg_phones)
norm_text_agg = "".join(norm_text_agg)
return norm_text_agg, ph_to_char, ph_to_word, word_tokens
norm_text_seg, ph_to_char_map, ph_to_word_map, word_tokens = _build_mixed_mappings(text, text_language, version)
# char spans (for zh/yue/ja parts only)
char_spans = []
if ph_spans and ph_to_char_map and len(ph_to_char_map) >= 1:
for span in ph_spans:
ph_idx = span["phoneme_id"]
if 0 <= ph_idx < len(ph_to_char_map):
ci = ph_to_char_map[ph_idx]
if ci == -1:
continue
if len(char_spans) == 0 or char_spans[-1]["char_index"] != ci:
char_spans.append({
"char_index": ci,
"char": norm_text_seg[ci] if ci < len(norm_text_seg) else "",
"start_s": span["start_s"],
"end_s": span["end_s"],
})
else:
char_spans[-1]["end_s"] = span["end_s"]
# post-merge and remap for chars
if char_spans:
groups = {}
for cs in char_spans:
groups.setdefault(cs["char_index"], []).append(cs)
merged = []
gap_merge_s = 0.08
min_dur_s = max(0.015, frame_time)
for ci, lst in groups.items():
lst = sorted(lst, key=lambda x: x["start_s"])
cur = None
for it in lst:
if cur is None:
cur = {"char_index": ci, "char": it.get("char", ""), "start_s": it["start_s"], "end_s": it["end_s"]}
else:
if it["start_s"] - cur["end_s"] <= gap_merge_s:
if it["end_s"] > cur["end_s"]:
cur["end_s"] = it["end_s"]
else:
word_spans[-1]["end_s"] = span["end_s"]
if cur["end_s"] - cur["start_s"] >= min_dur_s:
merged.append(cur)
cur = {"char_index": ci, "char": it.get("char", ""), "start_s": it["start_s"], "end_s": it["end_s"]}
if cur is not None and (cur["end_s"] - cur["start_s"]) >= min_dur_s:
merged.append(cur)
char_spans = sorted(merged, key=lambda x: x["start_s"])
def _build_norm_to_orig_map(orig, norm):
all_punc_local = set(punctuation).union(set(splits))
mapping = [-1] * len(norm)
o = 0
n = 0
while n < len(norm) and o < len(orig):
if norm[n] == orig[o]:
mapping[n] = o
n += 1
o += 1
else:
if orig[o].isspace() or orig[o] in all_punc_local:
o += 1
elif norm[n].isspace() or norm[n] in all_punc_local:
n += 1
else:
n += 1
return mapping
norm2orig = _build_norm_to_orig_map(text, norm_text_seg)
all_punc_local = set(punctuation).union(set(splits))
remapped = []
for cs in char_spans:
ci = cs["char_index"]
ch_norm = cs.get("char", "")
oi = norm2orig[ci] if ci < len(norm2orig) else -1
if oi != -1:
cs["char"] = text[oi]
remapped.append(cs)
else:
if ch_norm and (ch_norm in all_punc_local or ch_norm.isspace()):
continue
remapped.append(cs)
char_spans = remapped
# word spans (for en/ko parts only)
word_spans = []
if ph_spans and ph_to_word_map and len(ph_to_word_map) >= 1:
for span in ph_spans:
ph_idx = span["phoneme_id"]
if 0 <= ph_idx < len(ph_to_word_map):
wi = ph_to_word_map[ph_idx]
if wi == -1:
continue
if len(word_spans) == 0 or word_spans[-1]["word_index"] != wi:
word_spans.append({
"word_index": wi,
"word": word_tokens[wi] if wi < len(word_tokens) else "",
"start_s": span["start_s"],
"end_s": span["end_s"],
})
else:
word_spans[-1]["end_s"] = span["end_s"]
# add absolute offsets and record segment timing
audio_len_s = float(audio.shape[0]) / sr_hz
if 'ph_spans' in locals() and ph_spans:
@@ -1185,31 +1282,24 @@ def get_tts_wav(
srt_lines = []
idx_counter = 1
for rec in timestamps_all:
# prefer char-level segmentation first
if rec.get("char_spans") and len(rec["char_spans"]):
for c in rec["char_spans"]:
st = c["start_s"]
ed = c["end_s"]
txt = c.get("char", "")
cs = rec.get("char_spans") or []
ws = rec.get("word_spans") or []
entries = []
# unified timeline: merge whatever char/word spans exist and emit them sorted by start time
for c in cs:
entries.append({"text": c.get("char", ""), "start": c["start_s"], "end": c["end_s"]})
for w in ws:
entries.append({"text": w.get("word", ""), "start": w["start_s"], "end": w["end_s"]})
if entries:
entries.sort(key=lambda x: x["start"])
for e in entries:
srt_lines.append(str(idx_counter))
srt_lines.append(f"{_fmt_srt_time(st)} --> { _fmt_srt_time(ed)}")
srt_lines.append(txt)
srt_lines.append(f"{_fmt_srt_time(e['start'])} --> { _fmt_srt_time(e['end'])}")
srt_lines.append(e["text"])
srt_lines.append("")
idx_counter += 1
continue
# next, fall back to word level (English)
if rec.get("word_spans") and len(rec["word_spans"]):
for w in rec["word_spans"]:
st = w["start_s"]
ed = w["end_s"]
txt = w.get("word", "")
srt_lines.append(str(idx_counter))
srt_lines.append(f"{_fmt_srt_time(st)} --> { _fmt_srt_time(ed)}")
srt_lines.append(txt)
srt_lines.append("")
idx_counter += 1
continue
# finally, fall back to the whole segment
# fallback: whole segment
st = rec.get("segment_start_s")
ed = rec.get("segment_end_s")
text_line = rec.get("text", "")
@@ -1228,8 +1318,17 @@ def get_tts_wav(
except Exception:
srt_path = None
# Return audio, timestamps and SRT path for UI
yield (opt_sr, (audio_opt * 32767).astype(np.int16)), timestamps_all, srt_path
# Also write JSON timestamps to temp file for download
json_path = None
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=".json", mode="w", encoding="utf-8") as fjson:
json.dump(timestamps_all, fjson, ensure_ascii=False, indent=2)
json_path = fjson.name
except Exception:
json_path = None
# Return audio, timestamps, SRT path and JSON path for UI
yield (opt_sr, (audio_opt * 32767).astype(np.int16)), timestamps_all, srt_path, json_path
def split(todo_text):
@@ -1516,8 +1615,10 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
with gr.Row():
inference_button = gr.Button(value=i18n("合成语音"), variant="primary", size="lg", scale=25)
output = gr.Audio(label=i18n("输出的语音"), scale=14)
timestamps_box = gr.JSON(label=i18n("时间戳(音素/字/词)"))
with gr.Accordion(i18n("时间戳(音素/字/词) - 点击展开"), open=False):
timestamps_box = gr.JSON(label=i18n("时间戳JSON预览"))
srt_file = gr.File(label=i18n("下载SRT字幕"))
json_file = gr.File(label=i18n("下载时间戳JSON"))
inference_button.click(
get_tts_wav,
@@ -1539,7 +1640,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
if_sr_Checkbox,
pause_second_slider,
],
[output, timestamps_box, srt_file],
[output, timestamps_box, srt_file, json_file],
)
SoVITS_dropdown.change(
change_sovits_weights,


@@ -39,18 +39,24 @@ def clean_text(text, language, version=None):
norm_text = language_module.text_normalize(text)
else:
norm_text = text
if language == "zh" or language == "yue": ##########
phones, word2ph = language_module.g2p(norm_text)
assert len(phones) == sum(word2ph)
assert len(norm_text) == len(word2ph)
elif language == "en":
phones = language_module.g2p(norm_text)
if len(phones) < 4:
phones = [","] + phones
word2ph = None
else:
phones = language_module.g2p(norm_text)
word2ph = None
# Try per-language word2ph helpers
if hasattr(language_module, "g2p_with_word2ph"):
try:
phones, word2ph = language_module.g2p_with_word2ph(norm_text, keep_punc=False)
except Exception:
phones = language_module.g2p(norm_text)
word2ph = None
else:
phones = language_module.g2p(norm_text)
word2ph = None
if language == "en" and len(phones) < 4:
phones = [","] + phones
phones = ["UNK" if ph not in symbols else ph for ph in phones]
return phones, word2ph, norm_text


@@ -368,6 +368,32 @@ def g2p(text):
return replace_phs(phones)
def g2p_with_word2ph(text, keep_punc=False):
"""
Returns (phones, word2ph) for English by whitespace tokenization.
- Tokenize by spaces; for each token, call g2p(token)
- word2ph per token = max(1, num_valid_phones); whitespace/punctuation tokens count as 1 only when keep_punc is True, otherwise they are skipped
"""
tokens = re.split(r"(\s+)", text)
phones_all = []
word2ph = []
punc_set = set(punctuation)
for tok in tokens:
if tok.strip() == "":
if keep_punc:
word2ph.append(1)
continue
if all((c in punc_set or c.isspace()) for c in tok):
if keep_punc:
word2ph.append(1)
continue
phs = g2p(tok)
phs_valid = [p for p in phs if p not in punc_set]
phones_all.extend(phs_valid)
word2ph.append(max(1, len(phs_valid)))
return phones_all, word2ph
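# Illustrative example (the exact ARPAbet phones depend on the g2p backend, so they are an
# assumption): g2p_with_word2ph("nice day") might return (['N', 'AY1', 'S', 'D', 'EY1'], [3, 2]),
# i.e. one word2ph entry per whitespace token, counting only its non-punctuation phones.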
if __name__ == "__main__":
print(g2p("hello"))
print(g2p(text_normalize("e.g. I used openai's AI tool to draw a picture.")))


@@ -271,6 +271,30 @@ def g2p(norm_text, with_prosody=True):
return phones
# Helper for alignment: build phones and word2ph by per-character g2p (ignoring prosody markers)
def g2p_with_word2ph(text, keep_punc=False):
"""
Returns (phones, word2ph)
- Per-character g2p; ignore prosody markers like '[', ']', '^', '$', '#', '_'
- Punctuation counted as 1 if keep_punc else skipped
"""
norm_text = text_normalize(text)
phones_all = []
word2ph = []
prosody_markers = {'[', ']', '^', '$', '#', '_'}
punc_set = set(punctuation)
for ch in norm_text:
if ch.isspace() or ch in punc_set:
if keep_punc:
word2ph.append(1)
continue
phs = preprocess_jap(ch, with_prosody=True)
phs = [post_replace_ph(p) for p in phs if p not in prosody_markers and p not in punc_set]
phones_all.extend(phs)
word2ph.append(max(1, len(phs)))
return phones_all, word2ph
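# Illustrative example (the exact phones come from the Japanese g2p backend and are an
# assumption): for "こんにちは" the helper emits one word2ph entry per character, e.g.
# word2ph = [2, 1, 2, 2, 2] with sum(word2ph) == len(phones), prosody markers and
# punctuation having been stripped from the phone list.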
if __name__ == "__main__":
phones = g2p("Hello.こんにちは今日もNiCe天気ですねtokyotowerに行きましょう")
print(phones)


@@ -332,6 +332,31 @@ def g2p(text):
return text
# Helper for alignment: build phones and word2ph by space-tokenizing Korean text
def g2p_with_word2ph(text, keep_punc=False):
"""
Returns (phones, word2ph)
- Tokenize by spaces; for each non-space token, call g2p(token)
- Exclude separator/placeholder symbols and common punctuation from phones
- word2ph per token = max(1, num_valid_phones); space/punctuation tokens count as 1 only when keep_punc is True, otherwise they are skipped
"""
# prepare
tokens = re.split(r"(\s+)", text)
phones_all = []
word2ph = []
punc_set = {",", ".", "!", "?", "", "", "", ":", ""}
for tok in tokens:
if tok.strip() == "":
if keep_punc:
word2ph.append(1)
continue
phs = g2p(tok)
# filter non-phonetic markers
phs_valid = [p for p in phs if p not in {'.', '', ''} and p not in punc_set]
phones_all.extend(phs_valid)
word2ph.append(max(1, len(phs_valid)))
return phones_all, word2ph
if __name__ == "__main__":
text = "안녕하세요"
print(g2p(text))

api.py (650 lines changed)

@@ -141,6 +141,7 @@ RESP: 无
"""
import argparse
import json
import os
import re
import sys
@@ -158,6 +159,8 @@ import librosa
import soundfile as sf
from fastapi import FastAPI, Request, Query
from fastapi.responses import StreamingResponse, JSONResponse
import base64
import io
import uvicorn
from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np
@@ -609,6 +612,112 @@ def get_phones_and_bert(text, language, version, final=False):
return phones, bert.to(torch.float16 if is_half == True else torch.float32), norm_text
def _viterbi_monotonic(p: torch.Tensor):
T, N = p.shape
if T < N:
return p.argmax(dim=-1)
eps = 1e-8
cost = -torch.log(p + eps)
dp = torch.empty((T, N), dtype=cost.dtype, device=cost.device)
prev = torch.zeros((T, N), dtype=torch.uint8, device=cost.device)
dp[0, 0] = cost[0, 0]
if N > 1:
dp[0, 1:] = float("inf")
for t_i in range(1, T):
dp[t_i, 0] = dp[t_i - 1, 0] + cost[t_i, 0]
prev[t_i, 0] = 0
if N > 1:
stay = dp[t_i - 1, 1:] + cost[t_i, 1:]
move = dp[t_i - 1, :-1] + cost[t_i, 1:]
better_move = move < stay
dp[t_i, 1:] = torch.where(better_move, move, stay)
prev[t_i, 1:] = better_move.to(torch.uint8)
j = N - 1
assign_bt = torch.empty(T, dtype=torch.long, device=cost.device)
for t_i in range(T - 1, -1, -1):
assign_bt[t_i] = j
if t_i > 0 and prev[t_i, j] == 1:
j = max(j - 1, 0)
return assign_bt
def _build_mixed_mappings_for_api(_text: str, _ui_lang: str, _version: str):
_text = re.sub(r' {2,}', ' ', _text)
textlist = []
langlist = []
if _ui_lang == "all_zh":
for tmp in LangSegmenter.getTexts(_text, "zh"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif _ui_lang == "all_yue":
for tmp in LangSegmenter.getTexts(_text, "zh"):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif _ui_lang == "all_ja":
for tmp in LangSegmenter.getTexts(_text, "ja"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif _ui_lang == "all_ko":
for tmp in LangSegmenter.getTexts(_text, "ko"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif _ui_lang == "en":
langlist.append("en")
textlist.append(_text)
elif _ui_lang == "auto":
for tmp in LangSegmenter.getTexts(_text):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif _ui_lang == "auto_yue":
for tmp in LangSegmenter.getTexts(_text):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(_text):
if langlist:
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
textlist[-1] += tmp["text"]
continue
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
langlist.append(_ui_lang)
textlist.append(tmp["text"])
ph_to_char = []
ph_to_word = []
word_tokens = []
norm_text_agg = []
import re as _re
for seg_text, seg_lang in zip(textlist, langlist):
seg_phones, seg_word2ph, seg_norm = clean_text_inf(seg_text, seg_lang, _version)
norm_text_agg.append(seg_norm)
if seg_lang in {"zh", "yue", "ja"} and seg_word2ph:
char_base_idx = len("".join(norm_text_agg[:-1]))
for ch_idx, cnt in enumerate(seg_word2ph):
global_char_idx = char_base_idx + ch_idx
ph_to_char += [global_char_idx] * cnt
ph_to_word += [-1] * len(seg_phones)
elif seg_lang in {"en", "ko"} and seg_word2ph:
tokens_seg = [t for t in _re.findall(r"\S+", seg_norm) if not all((c in splits) for c in t)]
base_word_idx = len(word_tokens)
t_idx = 0
for cnt in seg_word2ph:
if t_idx < len(tokens_seg):
ph_to_word += [base_word_idx + t_idx] * cnt
t_idx += 1
ph_to_char += [-1] * len(seg_phones)
word_tokens.extend(tokens_seg)
else:
ph_to_char += [-1] * len(seg_phones)
ph_to_word += [-1] * len(seg_phones)
norm_text_agg = "".join(norm_text_agg)
return norm_text_agg, ph_to_char, ph_to_word, word_tokens
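# Illustrative example (phone counts are assumptions): for _text = "你好world" under
# _ui_lang = "auto", the "你好" segment is char-based, contributing e.g.
#   ph_to_char = [0, 0, 1, 1]          ph_to_word = [-1, -1, -1, -1]
# while the "world" segment is word-based, appending e.g.
#   ph_to_char += [-1, -1, -1, -1]     ph_to_word += [0, 0, 0, 0]     word_tokens = ["world"]
# so every phoneme index later resolves to either a normalized-text character or a word
# token, never both.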
class DictToAttrRecursive(dict):
def __init__(self, input_dict):
super().__init__(input_dict)
@@ -1146,6 +1255,9 @@ def handle(
return StreamingResponse(gen, media_type="audio/" + media_type)
## /v1/tts_json endpoint moved below after app is initialized
# --------------------------------
# Initialization
# --------------------------------
@@ -1385,301 +1497,311 @@ async def tts_endpoint(
)
def _fmt_srt_time(t):
h = int(t // 3600)
m = int((t % 3600) // 60)
s = int(t % 60)
ms = int(round((t - int(t)) * 1000))
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
@app.post("/v1/tts_json")
async def tts_json_post(request: Request):
data = await request.json()
try:
# run one non-streaming inference and return audio + timestamps (v2/v2Pro/v2ProPlus, with alignment)
infer_sovits = speaker_list["default"].sovits
vq_model = infer_sovits.vq_model
hps = infer_sovits.hps
version = vq_model.version
infer_gpt = speaker_list["default"].gpt
t2s_model = infer_gpt.t2s_model
max_sec = infer_gpt.max_sec
refer_wav_path = data.get("refer_wav_path")
prompt_text = data.get("prompt_text") or ""
prompt_language = dict_language[data.get("prompt_language", "zh").lower()]
text = data.get("text") or ""
text_language = dict_language[data.get("text_language", "zh").lower()]
top_k = int(data.get("top_k", 15))
top_p = float(data.get("top_p", 1.0))
temperature = float(data.get("temperature", 1.0))
speed = float(data.get("speed", 1.0))
inp_refs = data.get("inp_refs", [])
sample_steps = int(data.get("sample_steps", 32))
if_sr = bool(data.get("if_sr", False))
def _build_norm_to_orig_map(orig, norm):
all_punc_local = {",", ".", ";", "?", "!", "，", "。", "、", "？", "！", "；", "：", "…"}
mapping = [-1] * len(norm)
o = 0
n = 0
while n < len(norm) and o < len(orig):
if norm[n] == orig[o]:
mapping[n] = o
n += 1
o += 1
else:
if orig[o].isspace() or orig[o] in all_punc_local:
o += 1
elif norm[n].isspace() or norm[n] in all_punc_local:
n += 1
if is_empty(refer_wav_path, prompt_text, prompt_language):
refer_wav_path, prompt_text, prompt_language = (
default_refer.path,
default_refer.text,
default_refer.language,
)
if not default_refer.is_ready():
return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)
prompt_text = prompt_text.strip("\n")
if len(prompt_text) > 0 and prompt_text[-1] not in splits:
prompt_text += "" if prompt_language != "en" else "."
text = text.strip("\n")
if len(text) > 0 and text[-1] not in splits:
text += "" if text_language != "en" else "."
dtype = torch.float16 if is_half == True else torch.float32
with torch.no_grad():
wav16k, sr_r = librosa.load(refer_wav_path, sr=16000)
wav16k = torch.from_numpy(wav16k)
zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
zero_wav_torch = torch.from_numpy(zero_wav)
if is_half == True:
wav16k = wav16k.half().to(device)
zero_wav_torch = zero_wav_torch.half().to(device)
else:
n += 1
return mapping
wav16k = wav16k.to(device)
zero_wav_torch = zero_wav_torch.to(device)
wav16k = torch.cat([wav16k, zero_wav_torch])
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)
codes = vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
prompt = prompt_semantic.unsqueeze(0).to(device)
def synthesize_json(
refer_wav_path,
prompt_text,
prompt_language,
text,
text_language,
top_k,
top_p,
temperature,
speed,
inp_refs,
sample_steps,
if_sr,
):
infer_sovits = speaker_list["default"].sovits
vq_model = infer_sovits.vq_model
hps = infer_sovits.hps
version = vq_model.version
if version in {"v3", "v4"}:
return JSONResponse({"code": 400, "message": "v3/v4 暂未提供时间戳JSON接口"}, status_code=400)
# prepare refer features (same as handle)
dtype = torch.float16 if is_half else torch.float32
zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half else np.float32)
with torch.no_grad():
wav16k, sr = librosa.load(refer_wav_path, sr=16000)
wav16k = torch.from_numpy(wav16k)
zero_wav_torch = torch.from_numpy(zero_wav)
if is_half:
wav16k = wav16k.half().to(device)
zero_wav_torch = zero_wav_torch.half().to(device)
else:
wav16k = wav16k.to(device)
zero_wav_torch = zero_wav_torch.to(device)
wav16k = torch.cat([wav16k, zero_wav_torch])
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)
codes = vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
prompt = prompt_semantic.unsqueeze(0).to(device)
is_v2pro = version in {"v2Pro", "v2ProPlus"}
refers = []
if is_v2pro:
sv_emb = []
if sv_cn_model == None:
init_sv_cn()
spec, audio_tensor = get_spepc(hps, refer_wav_path, dtype, device, is_v2pro)
refers = [spec]
if is_v2pro:
sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)]
# text frontend
prompt_language = dict_language[prompt_language.lower()]
text_language = dict_language[text_language.lower()]
phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
texts = text.strip("\n").split("\n")
# iterate
audio_segments = []
timestamps_all = []
sr_hz = int(hps.data.sampling_rate)
elapsed_s = 0.0
for seg in texts:
if only_punc(seg):
continue
if seg[-1] not in splits:
seg += "" if text_language != "en" else "."
phones2, bert2, norm_text2 = get_phones_and_bert(seg, text_language, version)
phones1, bert1, _ = get_phones_and_bert(prompt_text, prompt_language, version)
phones2, bert2, norm_text = get_phones_and_bert(text, text_language, version)
bert = torch.cat([bert1, bert2], 1)
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
bert = bert.to(device).unsqueeze(0)
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
with torch.no_grad():
pred_semantic, idx = speaker_list["default"].gpt.t2s_model.model.infer_panel(
all_phoneme_ids, all_phoneme_len, prompt, bert, top_k=top_k, top_p=top_p, temperature=temperature, early_stop_num=hz * speaker_list["default"].gpt.max_sec,
pred_semantic, idx = t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_len,
prompt,
bert,
top_k=top_k,
top_p=top_p,
temperature=temperature,
early_stop_num=hz * max_sec,
)
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
phones2_tensor = torch.LongTensor(phones2).to(device).unsqueeze(0)
if is_v2pro:
o, attn, y_mask = vq_model.decode_with_alignment(
pred_semantic, phones2_tensor, refers, speed=speed, sv_emb=sv_emb
)
is_v2pro = version in {"v2Pro", "v2ProPlus"}
attn = None
if version not in {"v3", "v4"}:
refers = []
if is_v2pro:
sv_emb = []
if sv_cn_model == None:
init_sv_cn()
if inp_refs:
for path in inp_refs:
try:
refer, audio_tensor = get_spepc(hps, path, dtype, device, is_v2pro)
refers.append(refer)
if is_v2pro:
sv_emb.append(sv_cn_model.compute_embedding3(audio_tensor))
except Exception as e:
logger.error(e)
if len(refers) == 0:
refers, audio_tensor = get_spepc(hps, refer_wav_path, dtype, device, is_v2pro)
refers = [refers]
if is_v2pro:
sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)]
phones2_tensor = torch.LongTensor(phones2).to(device).unsqueeze(0)
if is_v2pro:
o, attn, y_mask = vq_model.decode_with_alignment(
pred_semantic, phones2_tensor, refers, speed=speed, sv_emb=sv_emb
)
else:
o, attn, y_mask = vq_model.decode_with_alignment(
pred_semantic, phones2_tensor, refers, speed=speed
)
audio = o[0][0].detach().cpu().numpy()
else:
o, attn, y_mask = vq_model.decode_with_alignment(
pred_semantic, phones2_tensor, refers, speed=speed
)
audio = o[0][0].detach().cpu().numpy()
# timestamps
frame_time = 0.02 / max(float(speed), 1e-6)
ph_spans = []
phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
refer, audio_tensor = get_spepc(hps, refer_wav_path, dtype, device)
ref_audio, sr0 = torchaudio.load(refer_wav_path)
ref_audio = ref_audio.to(device).float()
if ref_audio.shape[0] == 2:
ref_audio = ref_audio.mean(0).unsqueeze(0)
tgt_sr = 24000 if version == "v3" else 32000
if sr0 != tgt_sr:
ref_audio = resample(ref_audio, sr0, tgt_sr, device)
mel2 = mel_fn(ref_audio) if version == "v3" else mel_fn_v4(ref_audio)
mel2 = norm_spec(mel2)
fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
T_min = min(mel2.shape[2], fea_ref.shape[2])
mel2 = mel2[:, :, :T_min]
fea_ref = fea_ref[:, :, :T_min]
Tref = 468 if version == "v3" else 500
Tchunk = 934 if version == "v3" else 1000
if T_min > Tref:
mel2 = mel2[:, :, -Tref:]
fea_ref = fea_ref[:, :, -Tref:]
T_min = Tref
chunk_len = Tchunk - T_min
mel2 = mel2.to(dtype)
fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed)
cfm_resss = []
idx_p = 0
while True:
fea_todo_chunk = fea_todo[:, :, idx_p : idx_p + chunk_len]
if fea_todo_chunk.shape[-1] == 0:
break
idx_p += chunk_len
fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
cfm_res = vq_model.cfm.inference(
fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0
)
cfm_res = cfm_res[:, :, mel2.shape[2] :]
mel2 = cfm_res[:, :, -T_min:]
fea_ref = fea_todo_chunk[:, :, -T_min:]
cfm_resss.append(cfm_res)
cfm_res = torch.cat(cfm_resss, 2)
cfm_res = denorm_spec(cfm_res)
if version == "v3":
if bigvgan_model == None:
init_bigvgan()
else:
if hifigan_model == None:
init_hifigan()
vocoder_model = bigvgan_model if version == "v3" else hifigan_model
with torch.inference_mode():
wav_gen = vocoder_model(cfm_res)
audio = wav_gen[0][0].cpu().detach().numpy()
max_audio = np.abs(audio).max()
if max_audio > 1:
audio = audio / max_audio
if version in {"v1", "v2", "v2Pro", "v2ProPlus"}:
sr = 32000
elif version == "v3":
sr = 24000
else:
sr = 48000
if if_sr and sr == 24000:
audio_t = torch.from_numpy(audio).float().to(device)
audio_t, sr = audio_sr(audio_t.unsqueeze(0), sr)
audio = audio_t.cpu().detach().numpy()[0]
max_audio = np.abs(audio).max()
if max_audio > 1:
audio /= max_audio
sr = 48000
timestamps_all = []
if attn is not None:
attn_mean = attn.mean(dim=1)[0]
assign = attn_mean.argmax(dim=-1)
attn_heads_mean = attn.mean(dim=1)[0]
assign = _viterbi_monotonic(attn_heads_mean)
frame_time = 0.02 / max(speed, 1e-6)
ph_spans = []
if assign.numel() > 0:
start_f = 0
cur_ph = int(assign[0].item())
for f in range(1, assign.shape[0]):
ph = int(assign[f].item())
if ph != cur_ph:
ph_spans.append({"phoneme_id": cur_ph, "start_s": start_f * frame_time, "end_s": f * frame_time})
ph_spans.append({
"phoneme_id": cur_ph,
"start_s": start_f * frame_time,
"end_s": f * frame_time,
})
start_f = f
cur_ph = ph
ph_spans.append({"phoneme_id": cur_ph, "start_s": start_f * frame_time, "end_s": assign.shape[0] * frame_time})
ph_spans.append({
"phoneme_id": cur_ph,
"start_s": start_f * frame_time,
"end_s": assign.shape[0] * frame_time,
})
# char aggregation
_, word2ph, norm_text_seg = clean_text_inf(seg, text_language, version)
char_spans = []
if word2ph:
ph_to_char = []
for ch_idx, repeat in enumerate(word2ph):
ph_to_char += [ch_idx] * repeat
if ph_spans and ph_to_char:
norm_text_seg, ph_to_char_map, ph_to_word_map, word_tokens = _build_mixed_mappings_for_api(text, text_language, version)
char_spans = []
if ph_spans and ph_to_char_map and len(ph_to_char_map) >= 1:
for span in ph_spans:
ph_idx = span["phoneme_id"]
if 0 <= ph_idx < len(ph_to_char):
char_idx = ph_to_char[ph_idx]
char_spans.append({"char_index": char_idx, "char": norm_text_seg[char_idx] if char_idx < len(norm_text_seg) else "", "start_s": span["start_s"], "end_s": span["end_s"]})
# merge by char_index
if char_spans:
groups = {}
for cs in char_spans:
groups.setdefault(cs["char_index"], []).append(cs)
merged = []
gap_merge_s = 0.08
min_dur_s = max(0.015, frame_time)
for ci, lst in groups.items():
lst = sorted(lst, key=lambda x: x["start_s"])
cur = None
for it in lst:
if cur is None:
cur = {"char_index": ci, "char": it.get("char", ""), "start_s": it["start_s"], "end_s": it["end_s"]}
else:
if it["start_s"] - cur["end_s"] <= gap_merge_s:
if it["end_s"] > cur["end_s"]:
cur["end_s"] = it["end_s"]
if 0 <= ph_idx < len(ph_to_char_map):
ci = ph_to_char_map[ph_idx]
if ci == -1:
continue
if len(char_spans) == 0 or char_spans[-1]["char_index"] != ci:
char_spans.append({
"char_index": ci,
"char": norm_text_seg[ci] if ci < len(norm_text_seg) else "",
"start_s": span["start_s"],
"end_s": span["end_s"],
})
else:
if cur["end_s"] - cur["start_s"] >= min_dur_s:
merged.append(cur)
cur = {"char_index": ci, "char": it.get("char", ""), "start_s": it["start_s"], "end_s": it["end_s"]}
if cur is not None and (cur["end_s"] - cur["start_s"]) >= min_dur_s:
merged.append(cur)
# remap to original input text
norm2orig = _build_norm_to_orig_map(seg, norm_text_seg)
remapped = []
punc = {",", ".", ";", "?", "!", "", "", "", "", "", ";", "", ""}
for cs in merged:
ci = cs["char_index"]
oi = norm2orig[ci] if ci < len(norm2orig) else -1
ch_norm = cs.get("char", "")
if oi != -1:
cs["char"] = seg[oi]
remapped.append(cs)
else:
if ch_norm and (ch_norm in punc or ch_norm.isspace()):
continue
remapped.append(cs)
char_spans = sorted(remapped, key=lambda x: x["start_s"]) if remapped else []
char_spans[-1]["end_s"] = span["end_s"]
word_spans = []
if ph_spans and ph_to_word_map and len(ph_to_word_map) >= 1:
for span in ph_spans:
ph_idx = span["phoneme_id"]
if 0 <= ph_idx < len(ph_to_word_map):
wi = ph_to_word_map[ph_idx]
if wi == -1:
continue
if len(word_spans) == 0 or word_spans[-1]["word_index"] != wi:
word_spans.append({
"word_index": wi,
"word": word_tokens[wi] if wi < len(word_tokens) else "",
"start_s": span["start_s"],
"end_s": span["end_s"],
})
else:
word_spans[-1]["end_s"] = span["end_s"]
# offset and record
audio_len_s = float(audio.shape[0]) / sr_hz
for coll in (char_spans,):
for d in coll:
d["start_s"] += elapsed_s
d["end_s"] += elapsed_s
timestamps_all.append({
"segment_index": len(timestamps_all),
"char_spans": char_spans,
"text": norm_text_seg,
"segment_start_s": elapsed_s,
"segment_end_s": elapsed_s + audio_len_s,
})
elapsed_s += audio_len_s + 0.3
audio_segments.append(audio)
audio_len_s = float(len(audio)) / float(sr)
timestamps_all.append({
"segment_index": 0,
"phoneme_spans": ph_spans,
"char_spans": char_spans,
"word_spans": word_spans,
"segment_start_s": 0.0,
"segment_end_s": audio_len_s,
"text": norm_text_seg,
})
# concatenate audio
if len(audio_segments) == 0:
return JSONResponse({"code": 400, "message": "无有效文本"}, status_code=400)
pad = np.zeros(int(sr_hz * 0.3), dtype=audio_segments[0].dtype)
out = []
for i, a in enumerate(audio_segments):
out.append(a)
if i < len(audio_segments) - 1:
out.append(pad)
audio_np = np.concatenate(out, 0)
mx = np.abs(audio_np).max()
if mx > 1:
audio_np = audio_np / mx
import soundfile as sf
bio = io.BytesIO()
sf.write(bio, (audio * 32767).astype("int16"), sr, format="WAV", subtype="PCM_16")
wav_b64 = base64.b64encode(bio.getvalue()).decode("ascii")
# srt
srt_lines = []
idx_counter = 1
for rec in timestamps_all:
if rec.get("char_spans"):
for c in rec["char_spans"]:
srt_lines.append(str(idx_counter))
srt_lines.append(f"{_fmt_srt_time(c['start_s'])} --> { _fmt_srt_time(c['end_s'])}")
srt_lines.append(c.get("char", ""))
srt_lines.append("")
idx_counter += 1
else:
st = rec.get("segment_start_s")
ed = rec.get("segment_end_s")
if st is not None and ed is not None:
srt_lines.append(str(idx_counter))
srt_lines.append(f"{_fmt_srt_time(st)} --> { _fmt_srt_time(ed)}")
srt_lines.append(rec.get("text", ""))
srt_lines.append("")
idx_counter += 1
srt_text = "\n".join(srt_lines)
def _fmt_srt_time(t):
h = int(t // 3600)
m = int((t % 3600) // 60)
s = int(t % 60)
ms = int(round((t - int(t)) * 1000))
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
# pack wav
import base64
buf = BytesIO()
sf.write(buf, audio_np, sr_hz, format="WAV")
wav_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
return JSONResponse({"code": 0, "sr": sr_hz, "audio_wav_base64": wav_b64, "timestamps": timestamps_all, "srt": srt_text})
srt_b64 = None
if timestamps_all:
srt_lines = []
idx_counter = 1
for rec in timestamps_all:
cs = rec.get("char_spans") or []
ws = rec.get("word_spans") or []
entries = []
for c in cs:
entries.append({"text": c.get("char", ""), "start": c["start_s"], "end": c["end_s"]})
for w in ws:
entries.append({"text": w.get("word", ""), "start": w["start_s"], "end": w["end_s"]})
if entries:
entries.sort(key=lambda x: x["start"])
for e in entries:
srt_lines.append(str(idx_counter))
srt_lines.append(f"{_fmt_srt_time(e['start'])} --> { _fmt_srt_time(e['end'])}")
srt_lines.append(e["text"])
srt_lines.append("")
idx_counter += 1
srt_b64 = base64.b64encode("\n".join(srt_lines).encode("utf-8")).decode("ascii")
timestamps_json_b64 = base64.b64encode(json.dumps(timestamps_all, ensure_ascii=False).encode("utf-8")).decode("ascii")
@app.post("/tts_json")
async def tts_json_post(request: Request):
body = await request.json()
return synthesize_json(
body.get("refer_wav_path"),
body.get("prompt_text"),
body.get("prompt_language"),
body.get("text"),
body.get("text_language"),
body.get("top_k", 15),
body.get("top_p", 0.6),
body.get("temperature", 0.6),
body.get("speed", 1.0),
body.get("inp_refs", []),
body.get("sample_steps", 32),
body.get("if_sr", False),
)
@app.get("/tts_json")
async def tts_json_get(
refer_wav_path: str,
prompt_text: str,
prompt_language: str,
text: str,
text_language: str,
top_k: int = 15,
top_p: float = 0.6,
temperature: float = 0.6,
speed: float = 1.0,
inp_refs: list = Query(default=[]),
sample_steps: int = 32,
if_sr: bool = False,
):
return synthesize_json(
refer_wav_path,
prompt_text,
prompt_language,
text,
text_language,
top_k,
top_p,
temperature,
speed,
inp_refs,
sample_steps,
if_sr,
)
resp = {
"sr": sr,
"audio_wav_base64": wav_b64,
"timestamps": timestamps_all,
"timestamps_json_base64": timestamps_json_b64,
}
if srt_b64:
resp["srt_base64"] = srt_b64
return JSONResponse({"code": 0, "data": resp})
except Exception as e:
return JSONResponse({"code": 400, "message": str(e)}, status_code=400)
if __name__ == "__main__":