mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-16 05:36:34 +08:00
Supports phoneme and word-level timestamp output for multilingual text
This commit is contained in:
parent
2ff9a1533f
commit
19ca3f3f6a
@ -932,7 +932,40 @@ def get_tts_wav(
|
||||
# derive phoneme-level timestamps (20ms per frame at 32kHz, hop=640)
|
||||
if last_attn is not None:
|
||||
attn_heads_mean = last_attn.mean(dim=1)[0] # [T_ssl, T_text]
|
||||
assign = attn_heads_mean.argmax(dim=-1) # [T_ssl]
|
||||
# Viterbi monotonic alignment (stay or advance by 1)
|
||||
def _viterbi_monotonic(p):
|
||||
T, N = p.shape
|
||||
if T < N:
|
||||
return p.argmax(dim=-1)
|
||||
eps = 1e-8
|
||||
cost = -torch.log(p + eps)
|
||||
dp = torch.empty((T, N), dtype=cost.dtype, device=cost.device)
|
||||
prev = torch.zeros((T, N), dtype=torch.uint8, device=cost.device)
|
||||
dp[0, 0] = cost[0, 0]
|
||||
if N > 1:
|
||||
dp[0, 1:] = float("inf")
|
||||
for t_i in range(1, T):
|
||||
# j=0: only stay
|
||||
dp[t_i, 0] = dp[t_i - 1, 0] + cost[t_i, 0]
|
||||
prev[t_i, 0] = 0
|
||||
if N > 1:
|
||||
stay = dp[t_i - 1, 1:] + cost[t_i, 1:]
|
||||
move = dp[t_i - 1, :-1] + cost[t_i, 1:]
|
||||
better_move = move < stay
|
||||
dp[t_i, 1:] = torch.where(better_move, move, stay)
|
||||
prev[t_i, 1:] = better_move.to(torch.uint8)
|
||||
# backtrack from (T-1, N-1)
|
||||
j = N - 1
|
||||
assign_bt = torch.empty(T, dtype=torch.long, device=cost.device)
|
||||
for t_i in range(T - 1, -1, -1):
|
||||
assign_bt[t_i] = j
|
||||
if t_i > 0 and prev[t_i, j] == 1:
|
||||
j = j - 1
|
||||
if j < 0:
|
||||
j = 0
|
||||
return assign_bt
|
||||
|
||||
assign = _viterbi_monotonic(attn_heads_mean)
|
||||
frame_time = 0.02 / max(speed, 1e-6)
|
||||
# collapse consecutive frames pointing to same phoneme id
|
||||
ph_spans = []
|
||||
@ -954,38 +987,116 @@ def get_tts_wav(
|
||||
"start_s": start_f * frame_time,
|
||||
"end_s": assign.shape[0] * frame_time,
|
||||
})
|
||||
# char/word aggregation
|
||||
# obtain word2ph for current text segment
|
||||
_, word2ph, norm_text_seg = clean_text_inf(text, text_language, version)
|
||||
# char spans (for zh/yue where word2ph is char-based)
|
||||
char_spans = []
|
||||
if word2ph:
|
||||
# char/word aggregation for multi-language text
|
||||
def _build_mixed_mappings(_text, _ui_lang, _version):
|
||||
# replicate segmentation logic from get_phones_and_bert
|
||||
_text = re.sub(r' {2,}', ' ', _text)
|
||||
textlist = []
|
||||
langlist = []
|
||||
if _ui_lang == "all_zh":
|
||||
for tmp in LangSegmenter.getTexts(_text, "zh"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif _ui_lang == "all_yue":
|
||||
for tmp in LangSegmenter.getTexts(_text, "zh"):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif _ui_lang == "all_ja":
|
||||
for tmp in LangSegmenter.getTexts(_text, "ja"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif _ui_lang == "all_ko":
|
||||
for tmp in LangSegmenter.getTexts(_text, "ko"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif _ui_lang == "en":
|
||||
langlist.append("en")
|
||||
textlist.append(_text)
|
||||
elif _ui_lang == "auto":
|
||||
for tmp in LangSegmenter.getTexts(_text):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif _ui_lang == "auto_yue":
|
||||
for tmp in LangSegmenter.getTexts(_text):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
else:
|
||||
for tmp in LangSegmenter.getTexts(_text):
|
||||
if langlist:
|
||||
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
|
||||
textlist[-1] += tmp["text"]
|
||||
continue
|
||||
if tmp["lang"] == "en":
|
||||
langlist.append(tmp["lang"])
|
||||
else:
|
||||
langlist.append(_ui_lang)
|
||||
textlist.append(tmp["text"])
|
||||
|
||||
# aggregate mappings
|
||||
ph_to_char = []
|
||||
for ch_idx, repeat in enumerate(word2ph):
|
||||
ph_to_char += [ch_idx] * repeat
|
||||
if ph_spans and ph_to_char:
|
||||
ph_to_word = []
|
||||
word_tokens = []
|
||||
norm_text_agg = []
|
||||
import re as _re
|
||||
for seg_text, seg_lang in zip(textlist, langlist):
|
||||
seg_phones, seg_word2ph, seg_norm = clean_text_inf(seg_text, seg_lang, _version)
|
||||
norm_text_agg.append(seg_norm)
|
||||
if seg_lang in {"zh", "yue", "ja"} and seg_word2ph:
|
||||
# char-based
|
||||
char_base_idx = len("".join(norm_text_agg[:-1]))
|
||||
for ch_idx, cnt in enumerate(seg_word2ph):
|
||||
global_char_idx = char_base_idx + ch_idx
|
||||
ph_to_char += [global_char_idx] * cnt
|
||||
ph_to_word += [-1] * len(seg_phones)
|
||||
elif seg_lang in {"en", "ko"} and seg_word2ph:
|
||||
# word-based
|
||||
tokens_seg = [t for t in _re.findall(r"\S+", seg_norm) if not all((c in punctuation) for c in t)]
|
||||
base_word_idx = len(word_tokens)
|
||||
t_idx = 0
|
||||
for cnt in seg_word2ph:
|
||||
if t_idx < len(tokens_seg):
|
||||
ph_to_word += [base_word_idx + t_idx] * cnt
|
||||
t_idx += 1
|
||||
ph_to_char += [-1] * len(seg_phones)
|
||||
word_tokens.extend(tokens_seg)
|
||||
else:
|
||||
# unknown mapping; fill with -1
|
||||
ph_to_char += [-1] * len(seg_phones)
|
||||
ph_to_word += [-1] * len(seg_phones)
|
||||
norm_text_agg = "".join(norm_text_agg)
|
||||
return norm_text_agg, ph_to_char, ph_to_word, word_tokens
|
||||
|
||||
norm_text_seg, ph_to_char_map, ph_to_word_map, word_tokens = _build_mixed_mappings(text, text_language, version)
|
||||
|
||||
# char spans (for zh/yue/ja parts only)
|
||||
char_spans = []
|
||||
if ph_spans and ph_to_char_map and len(ph_to_char_map) >= 1:
|
||||
for span in ph_spans:
|
||||
ph_idx = span["phoneme_id"]
|
||||
if 0 <= ph_idx < len(ph_to_char):
|
||||
char_idx = ph_to_char[ph_idx]
|
||||
if len(char_spans) == 0 or char_spans[-1]["char_index"] != char_idx:
|
||||
if 0 <= ph_idx < len(ph_to_char_map):
|
||||
ci = ph_to_char_map[ph_idx]
|
||||
if ci == -1:
|
||||
continue
|
||||
if len(char_spans) == 0 or char_spans[-1]["char_index"] != ci:
|
||||
char_spans.append({
|
||||
"char_index": char_idx,
|
||||
"char": norm_text_seg[char_idx] if char_idx < len(norm_text_seg) else "",
|
||||
"char_index": ci,
|
||||
"char": norm_text_seg[ci] if ci < len(norm_text_seg) else "",
|
||||
"start_s": span["start_s"],
|
||||
"end_s": span["end_s"],
|
||||
})
|
||||
else:
|
||||
char_spans[-1]["end_s"] = span["end_s"]
|
||||
# post-merge by char_index across the whole segment to remove jitter fragments
|
||||
# post-merge and remap for chars
|
||||
if char_spans:
|
||||
# group by char_index
|
||||
groups = {}
|
||||
for cs in char_spans:
|
||||
groups.setdefault(cs["char_index"], []).append(cs)
|
||||
merged = []
|
||||
gap_merge_s = 0.08
|
||||
# adaptive minimal duration: at least one frame, but not lower than 15ms
|
||||
min_dur_s = max(0.015, frame_time)
|
||||
for ci, lst in groups.items():
|
||||
lst = sorted(lst, key=lambda x: x["start_s"])
|
||||
@ -1003,9 +1114,7 @@ def get_tts_wav(
|
||||
cur = {"char_index": ci, "char": it.get("char", ""), "start_s": it["start_s"], "end_s": it["end_s"]}
|
||||
if cur is not None and (cur["end_s"] - cur["start_s"]) >= min_dur_s:
|
||||
merged.append(cur)
|
||||
# sort merged by time
|
||||
char_spans = sorted(merged, key=lambda x: x["start_s"])
|
||||
# remap normalized chars back to original input text to avoid spurious '.'/'?'
|
||||
def _build_norm_to_orig_map(orig, norm):
|
||||
all_punc_local = set(punctuation).union(set(splits))
|
||||
mapping = [-1] * len(norm)
|
||||
@ -1017,13 +1126,11 @@ def get_tts_wav(
|
||||
n += 1
|
||||
o += 1
|
||||
else:
|
||||
# skip spaces/punctuations on either side
|
||||
if orig[o].isspace() or orig[o] in all_punc_local:
|
||||
o += 1
|
||||
elif norm[n].isspace() or norm[n] in all_punc_local:
|
||||
n += 1
|
||||
else:
|
||||
# characters differ (e.g., compatibility form). Advance normalized index.
|
||||
n += 1
|
||||
return mapping
|
||||
norm2orig = _build_norm_to_orig_map(text, norm_text_seg)
|
||||
@ -1037,34 +1144,24 @@ def get_tts_wav(
|
||||
cs["char"] = text[oi]
|
||||
remapped.append(cs)
|
||||
else:
|
||||
# drop normalized-only punctuations/spaces not present in original
|
||||
if ch_norm and (ch_norm in all_punc_local or ch_norm.isspace()):
|
||||
continue
|
||||
remapped.append(cs)
|
||||
char_spans = remapped
|
||||
# word spans
|
||||
|
||||
# word spans (for en/ko parts only)
|
||||
word_spans = []
|
||||
if text_language == "en":
|
||||
# build phoneme-to-word map using dictionary per-word g2p
|
||||
try:
|
||||
from text.english import g2p as g2p_en
|
||||
except Exception:
|
||||
g2p_en = None
|
||||
words = norm_text_seg.split()
|
||||
ph_to_word = []
|
||||
if g2p_en:
|
||||
for w_idx, w in enumerate(words):
|
||||
phs_w = g2p_en(w)
|
||||
ph_to_word += [w_idx] * len(phs_w)
|
||||
if ph_spans and ph_to_word:
|
||||
if ph_spans and ph_to_word_map and len(ph_to_word_map) >= 1:
|
||||
for span in ph_spans:
|
||||
ph_idx = span["phoneme_id"]
|
||||
if 0 <= ph_idx < len(ph_to_word):
|
||||
wi = ph_to_word[ph_idx]
|
||||
if 0 <= ph_idx < len(ph_to_word_map):
|
||||
wi = ph_to_word_map[ph_idx]
|
||||
if wi == -1:
|
||||
continue
|
||||
if len(word_spans) == 0 or word_spans[-1]["word_index"] != wi:
|
||||
word_spans.append({
|
||||
"word_index": wi,
|
||||
"word": words[wi] if wi < len(words) else "",
|
||||
"word": word_tokens[wi] if wi < len(word_tokens) else "",
|
||||
"start_s": span["start_s"],
|
||||
"end_s": span["end_s"],
|
||||
})
|
||||
@ -1185,31 +1282,24 @@ def get_tts_wav(
|
||||
srt_lines = []
|
||||
idx_counter = 1
|
||||
for rec in timestamps_all:
|
||||
# 优先按字分段
|
||||
if rec.get("char_spans") and len(rec["char_spans"]):
|
||||
for c in rec["char_spans"]:
|
||||
st = c["start_s"]
|
||||
ed = c["end_s"]
|
||||
txt = c.get("char", "")
|
||||
cs = rec.get("char_spans") or []
|
||||
ws = rec.get("word_spans") or []
|
||||
entries = []
|
||||
# 统一时间轴:存在则合并后按开始时间排序输出
|
||||
for c in cs:
|
||||
entries.append({"text": c.get("char", ""), "start": c["start_s"], "end": c["end_s"]})
|
||||
for w in ws:
|
||||
entries.append({"text": w.get("word", ""), "start": w["start_s"], "end": w["end_s"]})
|
||||
if entries:
|
||||
entries.sort(key=lambda x: x["start"])
|
||||
for e in entries:
|
||||
srt_lines.append(str(idx_counter))
|
||||
srt_lines.append(f"{_fmt_srt_time(st)} --> { _fmt_srt_time(ed)}")
|
||||
srt_lines.append(txt)
|
||||
srt_lines.append(f"{_fmt_srt_time(e['start'])} --> { _fmt_srt_time(e['end'])}")
|
||||
srt_lines.append(e["text"])
|
||||
srt_lines.append("")
|
||||
idx_counter += 1
|
||||
continue
|
||||
# 次选按词(英文)
|
||||
if rec.get("word_spans") and len(rec["word_spans"]):
|
||||
for w in rec["word_spans"]:
|
||||
st = w["start_s"]
|
||||
ed = w["end_s"]
|
||||
txt = w.get("word", "")
|
||||
srt_lines.append(str(idx_counter))
|
||||
srt_lines.append(f"{_fmt_srt_time(st)} --> { _fmt_srt_time(ed)}")
|
||||
srt_lines.append(txt)
|
||||
srt_lines.append("")
|
||||
idx_counter += 1
|
||||
continue
|
||||
# 最后按整段兜底
|
||||
# 兜底:整段
|
||||
st = rec.get("segment_start_s")
|
||||
ed = rec.get("segment_end_s")
|
||||
text_line = rec.get("text", "")
|
||||
@ -1228,8 +1318,17 @@ def get_tts_wav(
|
||||
except Exception:
|
||||
srt_path = None
|
||||
|
||||
# Return audio, timestamps and SRT path for UI
|
||||
yield (opt_sr, (audio_opt * 32767).astype(np.int16)), timestamps_all, srt_path
|
||||
# Also write JSON timestamps to temp file for download
|
||||
json_path = None
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".json", mode="w", encoding="utf-8") as fjson:
|
||||
json.dump(timestamps_all, fjson, ensure_ascii=False, indent=2)
|
||||
json_path = fjson.name
|
||||
except Exception:
|
||||
json_path = None
|
||||
|
||||
# Return audio, timestamps, SRT path and JSON path for UI
|
||||
yield (opt_sr, (audio_opt * 32767).astype(np.int16)), timestamps_all, srt_path, json_path
|
||||
|
||||
|
||||
def split(todo_text):
|
||||
@ -1516,8 +1615,10 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
with gr.Row():
|
||||
inference_button = gr.Button(value=i18n("合成语音"), variant="primary", size="lg", scale=25)
|
||||
output = gr.Audio(label=i18n("输出的语音"), scale=14)
|
||||
timestamps_box = gr.JSON(label=i18n("时间戳(音素/字/词)"))
|
||||
with gr.Accordion(i18n("时间戳(音素/字/词) - 点击展开"), open=False):
|
||||
timestamps_box = gr.JSON(label=i18n("时间戳JSON预览"))
|
||||
srt_file = gr.File(label=i18n("下载SRT字幕"))
|
||||
json_file = gr.File(label=i18n("下载时间戳JSON"))
|
||||
|
||||
inference_button.click(
|
||||
get_tts_wav,
|
||||
@ -1539,7 +1640,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
if_sr_Checkbox,
|
||||
pause_second_slider,
|
||||
],
|
||||
[output, timestamps_box, srt_file],
|
||||
[output, timestamps_box, srt_file, json_file],
|
||||
)
|
||||
SoVITS_dropdown.change(
|
||||
change_sovits_weights,
|
||||
|
@ -39,18 +39,24 @@ def clean_text(text, language, version=None):
|
||||
norm_text = language_module.text_normalize(text)
|
||||
else:
|
||||
norm_text = text
|
||||
|
||||
if language == "zh" or language == "yue": ##########
|
||||
phones, word2ph = language_module.g2p(norm_text)
|
||||
assert len(phones) == sum(word2ph)
|
||||
assert len(norm_text) == len(word2ph)
|
||||
elif language == "en":
|
||||
else:
|
||||
# Try per-language word2ph helpers
|
||||
if hasattr(language_module, "g2p_with_word2ph"):
|
||||
try:
|
||||
phones, word2ph = language_module.g2p_with_word2ph(norm_text, keep_punc=False)
|
||||
except Exception:
|
||||
phones = language_module.g2p(norm_text)
|
||||
if len(phones) < 4:
|
||||
phones = [","] + phones
|
||||
word2ph = None
|
||||
else:
|
||||
phones = language_module.g2p(norm_text)
|
||||
word2ph = None
|
||||
if language == "en" and len(phones) < 4:
|
||||
phones = [","] + phones
|
||||
phones = ["UNK" if ph not in symbols else ph for ph in phones]
|
||||
return phones, word2ph, norm_text
|
||||
|
||||
|
@ -368,6 +368,32 @@ def g2p(text):
|
||||
return replace_phs(phones)
|
||||
|
||||
|
||||
def g2p_with_word2ph(text, keep_punc=False):
|
||||
"""
|
||||
Returns (phones, word2ph) for English by whitespace tokenization.
|
||||
- Tokenize by spaces; for each token, call g2p(token)
|
||||
- word2ph per token = max(1, num_valid_phones) if keep_punc else skip pure punctuations
|
||||
"""
|
||||
tokens = re.split(r"(\s+)", text)
|
||||
phones_all = []
|
||||
word2ph = []
|
||||
punc_set = set(punctuation)
|
||||
for tok in tokens:
|
||||
if tok.strip() == "":
|
||||
if keep_punc:
|
||||
word2ph.append(1)
|
||||
continue
|
||||
if all((c in punc_set or c.isspace()) for c in tok):
|
||||
if keep_punc:
|
||||
word2ph.append(1)
|
||||
continue
|
||||
phs = g2p(tok)
|
||||
phs_valid = [p for p in phs if p not in punc_set]
|
||||
phones_all.extend(phs_valid)
|
||||
word2ph.append(max(1, len(phs_valid)))
|
||||
return phones_all, word2ph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(g2p("hello"))
|
||||
print(g2p(text_normalize("e.g. I used openai's AI tool to draw a picture.")))
|
||||
|
@ -271,6 +271,30 @@ def g2p(norm_text, with_prosody=True):
|
||||
return phones
|
||||
|
||||
|
||||
# Helper for alignment: build phones and word2ph by per-character g2p (ignoring prosody markers)
|
||||
def g2p_with_word2ph(text, keep_punc=False):
|
||||
"""
|
||||
Returns (phones, word2ph)
|
||||
- Per-character g2p; ignore prosody markers like '[', ']','^', '$', '#', '_'
|
||||
- Punctuation counted as 1 if keep_punc else skipped
|
||||
"""
|
||||
norm_text = text_normalize(text)
|
||||
phones_all = []
|
||||
word2ph = []
|
||||
prosody_markers = {'[', ']', '^', '$', '#', '_'}
|
||||
punc_set = set(punctuation)
|
||||
for ch in norm_text:
|
||||
if ch.isspace() or ch in punc_set:
|
||||
if keep_punc:
|
||||
word2ph.append(1)
|
||||
continue
|
||||
phs = preprocess_jap(ch, with_prosody=True)
|
||||
phs = [post_replace_ph(p) for p in phs if p not in prosody_markers and p not in punc_set]
|
||||
phones_all.extend(phs)
|
||||
word2ph.append(max(1, len(phs)))
|
||||
return phones_all, word2ph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
phones = g2p("Hello.こんにちは!今日もNiCe天気ですね!tokyotowerに行きましょう!")
|
||||
print(phones)
|
||||
|
@ -332,6 +332,31 @@ def g2p(text):
|
||||
return text
|
||||
|
||||
|
||||
# Helper for alignment: build phones and word2ph by space-tokenizing Korean text
|
||||
def g2p_with_word2ph(text, keep_punc=False):
|
||||
"""
|
||||
Returns (phones, word2ph)
|
||||
- Tokenize by spaces; for each non-space token, call g2p(token)
|
||||
- Exclude separator/placeholders from phones: '.', '空', '停' and common punctuations
|
||||
- word2ph per token = max(1, num_valid_phones) if keep_punc else skip spaces entirely
|
||||
"""
|
||||
# prepare
|
||||
tokens = re.split(r"(\s+)", text)
|
||||
phones_all = []
|
||||
word2ph = []
|
||||
punc_set = {",", ".", "!", "?", "、", "。", ";", ":", ":"}
|
||||
for tok in tokens:
|
||||
if tok.strip() == "":
|
||||
if keep_punc:
|
||||
word2ph.append(1)
|
||||
continue
|
||||
phs = g2p(tok)
|
||||
# filter non-phonetic markers
|
||||
phs_valid = [p for p in phs if p not in {'.', '空', '停'} and p not in punc_set]
|
||||
phones_all.extend(phs_valid)
|
||||
word2ph.append(max(1, len(phs_valid)))
|
||||
return phones_all, word2ph
|
||||
|
||||
if __name__ == "__main__":
|
||||
text = "안녕하세요"
|
||||
print(g2p(text))
|
||||
|
586
api.py
586
api.py
@ -141,6 +141,7 @@ RESP: 无
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
@ -158,6 +159,8 @@ import librosa
|
||||
import soundfile as sf
|
||||
from fastapi import FastAPI, Request, Query
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
import base64
|
||||
import io
|
||||
import uvicorn
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
import numpy as np
|
||||
@ -609,6 +612,112 @@ def get_phones_and_bert(text, language, version, final=False):
|
||||
return phones, bert.to(torch.float16 if is_half == True else torch.float32), norm_text
|
||||
|
||||
|
||||
def _viterbi_monotonic(p: torch.Tensor):
|
||||
T, N = p.shape
|
||||
if T < N:
|
||||
return p.argmax(dim=-1)
|
||||
eps = 1e-8
|
||||
cost = -torch.log(p + eps)
|
||||
dp = torch.empty((T, N), dtype=cost.dtype, device=cost.device)
|
||||
prev = torch.zeros((T, N), dtype=torch.uint8, device=cost.device)
|
||||
dp[0, 0] = cost[0, 0]
|
||||
if N > 1:
|
||||
dp[0, 1:] = float("inf")
|
||||
for t_i in range(1, T):
|
||||
dp[t_i, 0] = dp[t_i - 1, 0] + cost[t_i, 0]
|
||||
prev[t_i, 0] = 0
|
||||
if N > 1:
|
||||
stay = dp[t_i - 1, 1:] + cost[t_i, 1:]
|
||||
move = dp[t_i - 1, :-1] + cost[t_i, 1:]
|
||||
better_move = move < stay
|
||||
dp[t_i, 1:] = torch.where(better_move, move, stay)
|
||||
prev[t_i, 1:] = better_move.to(torch.uint8)
|
||||
j = N - 1
|
||||
assign_bt = torch.empty(T, dtype=torch.long, device=cost.device)
|
||||
for t_i in range(T - 1, -1, -1):
|
||||
assign_bt[t_i] = j
|
||||
if t_i > 0 and prev[t_i, j] == 1:
|
||||
j = max(j - 1, 0)
|
||||
return assign_bt
|
||||
|
||||
|
||||
def _build_mixed_mappings_for_api(_text: str, _ui_lang: str, _version: str):
|
||||
_text = re.sub(r' {2,}', ' ', _text)
|
||||
textlist = []
|
||||
langlist = []
|
||||
if _ui_lang == "all_zh":
|
||||
for tmp in LangSegmenter.getTexts(_text, "zh"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif _ui_lang == "all_yue":
|
||||
for tmp in LangSegmenter.getTexts(_text, "zh"):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif _ui_lang == "all_ja":
|
||||
for tmp in LangSegmenter.getTexts(_text, "ja"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif _ui_lang == "all_ko":
|
||||
for tmp in LangSegmenter.getTexts(_text, "ko"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif _ui_lang == "en":
|
||||
langlist.append("en")
|
||||
textlist.append(_text)
|
||||
elif _ui_lang == "auto":
|
||||
for tmp in LangSegmenter.getTexts(_text):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif _ui_lang == "auto_yue":
|
||||
for tmp in LangSegmenter.getTexts(_text):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
else:
|
||||
for tmp in LangSegmenter.getTexts(_text):
|
||||
if langlist:
|
||||
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
|
||||
textlist[-1] += tmp["text"]
|
||||
continue
|
||||
if tmp["lang"] == "en":
|
||||
langlist.append(tmp["lang"])
|
||||
else:
|
||||
langlist.append(_ui_lang)
|
||||
textlist.append(tmp["text"])
|
||||
|
||||
ph_to_char = []
|
||||
ph_to_word = []
|
||||
word_tokens = []
|
||||
norm_text_agg = []
|
||||
import re as _re
|
||||
for seg_text, seg_lang in zip(textlist, langlist):
|
||||
seg_phones, seg_word2ph, seg_norm = clean_text_inf(seg_text, seg_lang, _version)
|
||||
norm_text_agg.append(seg_norm)
|
||||
if seg_lang in {"zh", "yue", "ja"} and seg_word2ph:
|
||||
char_base_idx = len("".join(norm_text_agg[:-1]))
|
||||
for ch_idx, cnt in enumerate(seg_word2ph):
|
||||
global_char_idx = char_base_idx + ch_idx
|
||||
ph_to_char += [global_char_idx] * cnt
|
||||
ph_to_word += [-1] * len(seg_phones)
|
||||
elif seg_lang in {"en", "ko"} and seg_word2ph:
|
||||
tokens_seg = [t for t in _re.findall(r"\S+", seg_norm) if not all((c in splits) for c in t)]
|
||||
base_word_idx = len(word_tokens)
|
||||
t_idx = 0
|
||||
for cnt in seg_word2ph:
|
||||
if t_idx < len(tokens_seg):
|
||||
ph_to_word += [base_word_idx + t_idx] * cnt
|
||||
t_idx += 1
|
||||
ph_to_char += [-1] * len(seg_phones)
|
||||
word_tokens.extend(tokens_seg)
|
||||
else:
|
||||
ph_to_char += [-1] * len(seg_phones)
|
||||
ph_to_word += [-1] * len(seg_phones)
|
||||
norm_text_agg = "".join(norm_text_agg)
|
||||
return norm_text_agg, ph_to_char, ph_to_word, word_tokens
|
||||
|
||||
class DictToAttrRecursive(dict):
|
||||
def __init__(self, input_dict):
|
||||
super().__init__(input_dict)
|
||||
@ -1146,6 +1255,9 @@ def handle(
|
||||
return StreamingResponse(gen, media_type="audio/" + media_type)
|
||||
|
||||
|
||||
## /v1/tts_json endpoint moved below after app is initialized
|
||||
|
||||
|
||||
# --------------------------------
|
||||
# 初始化部分
|
||||
# --------------------------------
|
||||
@ -1385,64 +1497,55 @@ async def tts_endpoint(
|
||||
)
|
||||
|
||||
|
||||
def _fmt_srt_time(t):
|
||||
h = int(t // 3600)
|
||||
m = int((t % 3600) // 60)
|
||||
s = int(t % 60)
|
||||
ms = int(round((t - int(t)) * 1000))
|
||||
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
|
||||
|
||||
|
||||
def _build_norm_to_orig_map(orig, norm):
|
||||
all_punc_local = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}
|
||||
mapping = [-1] * len(norm)
|
||||
o = 0
|
||||
n = 0
|
||||
while n < len(norm) and o < len(orig):
|
||||
if norm[n] == orig[o]:
|
||||
mapping[n] = o
|
||||
n += 1
|
||||
o += 1
|
||||
else:
|
||||
if orig[o].isspace() or orig[o] in all_punc_local:
|
||||
o += 1
|
||||
elif norm[n].isspace() or norm[n] in all_punc_local:
|
||||
n += 1
|
||||
else:
|
||||
n += 1
|
||||
return mapping
|
||||
|
||||
|
||||
def synthesize_json(
|
||||
refer_wav_path,
|
||||
prompt_text,
|
||||
prompt_language,
|
||||
text,
|
||||
text_language,
|
||||
top_k,
|
||||
top_p,
|
||||
temperature,
|
||||
speed,
|
||||
inp_refs,
|
||||
sample_steps,
|
||||
if_sr,
|
||||
):
|
||||
@app.post("/v1/tts_json")
|
||||
async def tts_json_post(request: Request):
|
||||
data = await request.json()
|
||||
try:
|
||||
# 直接进行一次非流式推理并输出音频+时间戳(v2/v2Pro/v2ProPlus带对齐)
|
||||
infer_sovits = speaker_list["default"].sovits
|
||||
vq_model = infer_sovits.vq_model
|
||||
hps = infer_sovits.hps
|
||||
version = vq_model.version
|
||||
infer_gpt = speaker_list["default"].gpt
|
||||
t2s_model = infer_gpt.t2s_model
|
||||
max_sec = infer_gpt.max_sec
|
||||
|
||||
if version in {"v3", "v4"}:
|
||||
return JSONResponse({"code": 400, "message": "v3/v4 暂未提供时间戳JSON接口"}, status_code=400)
|
||||
refer_wav_path = data.get("refer_wav_path")
|
||||
prompt_text = data.get("prompt_text") or ""
|
||||
prompt_language = dict_language[data.get("prompt_language", "zh").lower()]
|
||||
text = data.get("text") or ""
|
||||
text_language = dict_language[data.get("text_language", "zh").lower()]
|
||||
top_k = int(data.get("top_k", 15))
|
||||
top_p = float(data.get("top_p", 1.0))
|
||||
temperature = float(data.get("temperature", 1.0))
|
||||
speed = float(data.get("speed", 1.0))
|
||||
inp_refs = data.get("inp_refs", [])
|
||||
sample_steps = int(data.get("sample_steps", 32))
|
||||
if_sr = bool(data.get("if_sr", False))
|
||||
|
||||
# prepare refer features (same as handle)
|
||||
dtype = torch.float16 if is_half else torch.float32
|
||||
zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half else np.float32)
|
||||
if is_empty(refer_wav_path, prompt_text, prompt_language):
|
||||
refer_wav_path, prompt_text, prompt_language = (
|
||||
default_refer.path,
|
||||
default_refer.text,
|
||||
default_refer.language,
|
||||
)
|
||||
if not default_refer.is_ready():
|
||||
return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)
|
||||
|
||||
prompt_text = prompt_text.strip("\n")
|
||||
if len(prompt_text) > 0 and prompt_text[-1] not in splits:
|
||||
prompt_text += "。" if prompt_language != "en" else "."
|
||||
text = text.strip("\n")
|
||||
if len(text) > 0 and text[-1] not in splits:
|
||||
text += "。" if text_language != "en" else "."
|
||||
|
||||
dtype = torch.float16 if is_half == True else torch.float32
|
||||
with torch.no_grad():
|
||||
wav16k, sr = librosa.load(refer_wav_path, sr=16000)
|
||||
wav16k, sr_r = librosa.load(refer_wav_path, sr=16000)
|
||||
wav16k = torch.from_numpy(wav16k)
|
||||
zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
|
||||
zero_wav_torch = torch.from_numpy(zero_wav)
|
||||
if is_half:
|
||||
if is_half == True:
|
||||
wav16k = wav16k.half().to(device)
|
||||
zero_wav_torch = zero_wav_torch.half().to(device)
|
||||
else:
|
||||
@ -1454,43 +1557,48 @@ def synthesize_json(
|
||||
prompt_semantic = codes[0, 0]
|
||||
prompt = prompt_semantic.unsqueeze(0).to(device)
|
||||
|
||||
phones1, bert1, _ = get_phones_and_bert(prompt_text, prompt_language, version)
|
||||
phones2, bert2, norm_text = get_phones_and_bert(text, text_language, version)
|
||||
bert = torch.cat([bert1, bert2], 1)
|
||||
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
|
||||
bert = bert.to(device).unsqueeze(0)
|
||||
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred_semantic, idx = t2s_model.model.infer_panel(
|
||||
all_phoneme_ids,
|
||||
all_phoneme_len,
|
||||
prompt,
|
||||
bert,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
early_stop_num=hz * max_sec,
|
||||
)
|
||||
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
|
||||
|
||||
is_v2pro = version in {"v2Pro", "v2ProPlus"}
|
||||
attn = None
|
||||
if version not in {"v3", "v4"}:
|
||||
refers = []
|
||||
if is_v2pro:
|
||||
sv_emb = []
|
||||
if sv_cn_model == None:
|
||||
init_sv_cn()
|
||||
spec, audio_tensor = get_spepc(hps, refer_wav_path, dtype, device, is_v2pro)
|
||||
refers = [spec]
|
||||
if inp_refs:
|
||||
for path in inp_refs:
|
||||
try:
|
||||
refer, audio_tensor = get_spepc(hps, path, dtype, device, is_v2pro)
|
||||
refers.append(refer)
|
||||
if is_v2pro:
|
||||
sv_emb.append(sv_cn_model.compute_embedding3(audio_tensor))
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
if len(refers) == 0:
|
||||
refers, audio_tensor = get_spepc(hps, refer_wav_path, dtype, device, is_v2pro)
|
||||
refers = [refers]
|
||||
if is_v2pro:
|
||||
sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)]
|
||||
|
||||
# text frontend
|
||||
prompt_language = dict_language[prompt_language.lower()]
|
||||
text_language = dict_language[text_language.lower()]
|
||||
phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
|
||||
texts = text.strip("\n").split("\n")
|
||||
|
||||
# iterate
|
||||
audio_segments = []
|
||||
timestamps_all = []
|
||||
sr_hz = int(hps.data.sampling_rate)
|
||||
elapsed_s = 0.0
|
||||
for seg in texts:
|
||||
if only_punc(seg):
|
||||
continue
|
||||
if seg[-1] not in splits:
|
||||
seg += "。" if text_language != "en" else "."
|
||||
phones2, bert2, norm_text2 = get_phones_and_bert(seg, text_language, version)
|
||||
bert = torch.cat([bert1, bert2], 1)
|
||||
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
|
||||
bert = bert.to(device).unsqueeze(0)
|
||||
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
|
||||
with torch.no_grad():
|
||||
pred_semantic, idx = speaker_list["default"].gpt.t2s_model.model.infer_panel(
|
||||
all_phoneme_ids, all_phoneme_len, prompt, bert, top_k=top_k, top_p=top_p, temperature=temperature, early_stop_num=hz * speaker_list["default"].gpt.max_sec,
|
||||
)
|
||||
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
|
||||
phones2_tensor = torch.LongTensor(phones2).to(device).unsqueeze(0)
|
||||
if is_v2pro:
|
||||
o, attn, y_mask = vq_model.decode_with_alignment(
|
||||
@ -1501,185 +1609,199 @@ def synthesize_json(
|
||||
pred_semantic, phones2_tensor, refers, speed=speed
|
||||
)
|
||||
audio = o[0][0].detach().cpu().numpy()
|
||||
# timestamps
|
||||
frame_time = 0.02 / max(float(speed), 1e-6)
|
||||
ph_spans = []
|
||||
else:
|
||||
phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
|
||||
phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
|
||||
refer, audio_tensor = get_spepc(hps, refer_wav_path, dtype, device)
|
||||
ref_audio, sr0 = torchaudio.load(refer_wav_path)
|
||||
ref_audio = ref_audio.to(device).float()
|
||||
if ref_audio.shape[0] == 2:
|
||||
ref_audio = ref_audio.mean(0).unsqueeze(0)
|
||||
tgt_sr = 24000 if version == "v3" else 32000
|
||||
if sr0 != tgt_sr:
|
||||
ref_audio = resample(ref_audio, sr0, tgt_sr, device)
|
||||
mel2 = mel_fn(ref_audio) if version == "v3" else mel_fn_v4(ref_audio)
|
||||
mel2 = norm_spec(mel2)
|
||||
fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
|
||||
T_min = min(mel2.shape[2], fea_ref.shape[2])
|
||||
mel2 = mel2[:, :, :T_min]
|
||||
fea_ref = fea_ref[:, :, :T_min]
|
||||
Tref = 468 if version == "v3" else 500
|
||||
Tchunk = 934 if version == "v3" else 1000
|
||||
if T_min > Tref:
|
||||
mel2 = mel2[:, :, -Tref:]
|
||||
fea_ref = fea_ref[:, :, -Tref:]
|
||||
T_min = Tref
|
||||
chunk_len = Tchunk - T_min
|
||||
mel2 = mel2.to(dtype)
|
||||
fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed)
|
||||
cfm_resss = []
|
||||
idx_p = 0
|
||||
while True:
|
||||
fea_todo_chunk = fea_todo[:, :, idx_p : idx_p + chunk_len]
|
||||
if fea_todo_chunk.shape[-1] == 0:
|
||||
break
|
||||
idx_p += chunk_len
|
||||
fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
|
||||
cfm_res = vq_model.cfm.inference(
|
||||
fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0
|
||||
)
|
||||
cfm_res = cfm_res[:, :, mel2.shape[2] :]
|
||||
mel2 = cfm_res[:, :, -T_min:]
|
||||
fea_ref = fea_todo_chunk[:, :, -T_min:]
|
||||
cfm_resss.append(cfm_res)
|
||||
cfm_res = torch.cat(cfm_resss, 2)
|
||||
cfm_res = denorm_spec(cfm_res)
|
||||
if version == "v3":
|
||||
if bigvgan_model == None:
|
||||
init_bigvgan()
|
||||
else:
|
||||
if hifigan_model == None:
|
||||
init_hifigan()
|
||||
vocoder_model = bigvgan_model if version == "v3" else hifigan_model
|
||||
with torch.inference_mode():
|
||||
wav_gen = vocoder_model(cfm_res)
|
||||
audio = wav_gen[0][0].cpu().detach().numpy()
|
||||
|
||||
max_audio = np.abs(audio).max()
|
||||
if max_audio > 1:
|
||||
audio = audio / max_audio
|
||||
|
||||
if version in {"v1", "v2", "v2Pro", "v2ProPlus"}:
|
||||
sr = 32000
|
||||
elif version == "v3":
|
||||
sr = 24000
|
||||
else:
|
||||
sr = 48000
|
||||
if if_sr and sr == 24000:
|
||||
audio_t = torch.from_numpy(audio).float().to(device)
|
||||
audio_t, sr = audio_sr(audio_t.unsqueeze(0), sr)
|
||||
audio = audio_t.cpu().detach().numpy()[0]
|
||||
max_audio = np.abs(audio).max()
|
||||
if max_audio > 1:
|
||||
audio /= max_audio
|
||||
sr = 48000
|
||||
|
||||
timestamps_all = []
|
||||
if attn is not None:
|
||||
attn_mean = attn.mean(dim=1)[0]
|
||||
assign = attn_mean.argmax(dim=-1)
|
||||
attn_heads_mean = attn.mean(dim=1)[0]
|
||||
assign = _viterbi_monotonic(attn_heads_mean)
|
||||
frame_time = 0.02 / max(speed, 1e-6)
|
||||
ph_spans = []
|
||||
if assign.numel() > 0:
|
||||
start_f = 0
|
||||
cur_ph = int(assign[0].item())
|
||||
for f in range(1, assign.shape[0]):
|
||||
ph = int(assign[f].item())
|
||||
if ph != cur_ph:
|
||||
ph_spans.append({"phoneme_id": cur_ph, "start_s": start_f * frame_time, "end_s": f * frame_time})
|
||||
ph_spans.append({
|
||||
"phoneme_id": cur_ph,
|
||||
"start_s": start_f * frame_time,
|
||||
"end_s": f * frame_time,
|
||||
})
|
||||
start_f = f
|
||||
cur_ph = ph
|
||||
ph_spans.append({"phoneme_id": cur_ph, "start_s": start_f * frame_time, "end_s": assign.shape[0] * frame_time})
|
||||
ph_spans.append({
|
||||
"phoneme_id": cur_ph,
|
||||
"start_s": start_f * frame_time,
|
||||
"end_s": assign.shape[0] * frame_time,
|
||||
})
|
||||
|
||||
norm_text_seg, ph_to_char_map, ph_to_word_map, word_tokens = _build_mixed_mappings_for_api(text, text_language, version)
|
||||
|
||||
# char aggregation
|
||||
_, word2ph, norm_text_seg = clean_text_inf(seg, text_language, version)
|
||||
char_spans = []
|
||||
if word2ph:
|
||||
ph_to_char = []
|
||||
for ch_idx, repeat in enumerate(word2ph):
|
||||
ph_to_char += [ch_idx] * repeat
|
||||
if ph_spans and ph_to_char:
|
||||
if ph_spans and ph_to_char_map and len(ph_to_char_map) >= 1:
|
||||
for span in ph_spans:
|
||||
ph_idx = span["phoneme_id"]
|
||||
if 0 <= ph_idx < len(ph_to_char):
|
||||
char_idx = ph_to_char[ph_idx]
|
||||
char_spans.append({"char_index": char_idx, "char": norm_text_seg[char_idx] if char_idx < len(norm_text_seg) else "", "start_s": span["start_s"], "end_s": span["end_s"]})
|
||||
# merge by char_index
|
||||
if char_spans:
|
||||
groups = {}
|
||||
for cs in char_spans:
|
||||
groups.setdefault(cs["char_index"], []).append(cs)
|
||||
merged = []
|
||||
gap_merge_s = 0.08
|
||||
min_dur_s = max(0.015, frame_time)
|
||||
for ci, lst in groups.items():
|
||||
lst = sorted(lst, key=lambda x: x["start_s"])
|
||||
cur = None
|
||||
for it in lst:
|
||||
if cur is None:
|
||||
cur = {"char_index": ci, "char": it.get("char", ""), "start_s": it["start_s"], "end_s": it["end_s"]}
|
||||
else:
|
||||
if it["start_s"] - cur["end_s"] <= gap_merge_s:
|
||||
if it["end_s"] > cur["end_s"]:
|
||||
cur["end_s"] = it["end_s"]
|
||||
else:
|
||||
if cur["end_s"] - cur["start_s"] >= min_dur_s:
|
||||
merged.append(cur)
|
||||
cur = {"char_index": ci, "char": it.get("char", ""), "start_s": it["start_s"], "end_s": it["end_s"]}
|
||||
if cur is not None and (cur["end_s"] - cur["start_s"]) >= min_dur_s:
|
||||
merged.append(cur)
|
||||
# remap to original input text
|
||||
norm2orig = _build_norm_to_orig_map(seg, norm_text_seg)
|
||||
remapped = []
|
||||
punc = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}
|
||||
for cs in merged:
|
||||
ci = cs["char_index"]
|
||||
oi = norm2orig[ci] if ci < len(norm2orig) else -1
|
||||
ch_norm = cs.get("char", "")
|
||||
if oi != -1:
|
||||
cs["char"] = seg[oi]
|
||||
remapped.append(cs)
|
||||
else:
|
||||
if ch_norm and (ch_norm in punc or ch_norm.isspace()):
|
||||
if 0 <= ph_idx < len(ph_to_char_map):
|
||||
ci = ph_to_char_map[ph_idx]
|
||||
if ci == -1:
|
||||
continue
|
||||
remapped.append(cs)
|
||||
char_spans = sorted(remapped, key=lambda x: x["start_s"]) if remapped else []
|
||||
|
||||
# offset and record
|
||||
audio_len_s = float(audio.shape[0]) / sr_hz
|
||||
for coll in (char_spans,):
|
||||
for d in coll:
|
||||
d["start_s"] += elapsed_s
|
||||
d["end_s"] += elapsed_s
|
||||
timestamps_all.append({
|
||||
"segment_index": len(timestamps_all),
|
||||
"char_spans": char_spans,
|
||||
"text": norm_text_seg,
|
||||
"segment_start_s": elapsed_s,
|
||||
"segment_end_s": elapsed_s + audio_len_s,
|
||||
if len(char_spans) == 0 or char_spans[-1]["char_index"] != ci:
|
||||
char_spans.append({
|
||||
"char_index": ci,
|
||||
"char": norm_text_seg[ci] if ci < len(norm_text_seg) else "",
|
||||
"start_s": span["start_s"],
|
||||
"end_s": span["end_s"],
|
||||
})
|
||||
elapsed_s += audio_len_s + 0.3
|
||||
audio_segments.append(audio)
|
||||
else:
|
||||
char_spans[-1]["end_s"] = span["end_s"]
|
||||
word_spans = []
|
||||
if ph_spans and ph_to_word_map and len(ph_to_word_map) >= 1:
|
||||
for span in ph_spans:
|
||||
ph_idx = span["phoneme_id"]
|
||||
if 0 <= ph_idx < len(ph_to_word_map):
|
||||
wi = ph_to_word_map[ph_idx]
|
||||
if wi == -1:
|
||||
continue
|
||||
if len(word_spans) == 0 or word_spans[-1]["word_index"] != wi:
|
||||
word_spans.append({
|
||||
"word_index": wi,
|
||||
"word": word_tokens[wi] if wi < len(word_tokens) else "",
|
||||
"start_s": span["start_s"],
|
||||
"end_s": span["end_s"],
|
||||
})
|
||||
else:
|
||||
word_spans[-1]["end_s"] = span["end_s"]
|
||||
|
||||
# concatenate audio
|
||||
if len(audio_segments) == 0:
|
||||
return JSONResponse({"code": 400, "message": "无有效文本"}, status_code=400)
|
||||
pad = np.zeros(int(sr_hz * 0.3), dtype=audio_segments[0].dtype)
|
||||
out = []
|
||||
for i, a in enumerate(audio_segments):
|
||||
out.append(a)
|
||||
if i < len(audio_segments) - 1:
|
||||
out.append(pad)
|
||||
audio_np = np.concatenate(out, 0)
|
||||
mx = np.abs(audio_np).max()
|
||||
if mx > 1:
|
||||
audio_np = audio_np / mx
|
||||
audio_len_s = float(len(audio)) / float(sr)
|
||||
timestamps_all.append({
|
||||
"segment_index": 0,
|
||||
"phoneme_spans": ph_spans,
|
||||
"char_spans": char_spans,
|
||||
"word_spans": word_spans,
|
||||
"segment_start_s": 0.0,
|
||||
"segment_end_s": audio_len_s,
|
||||
"text": norm_text_seg,
|
||||
})
|
||||
|
||||
# srt
|
||||
import soundfile as sf
|
||||
bio = io.BytesIO()
|
||||
sf.write(bio, (audio * 32767).astype("int16"), sr, format="WAV", subtype="PCM_16")
|
||||
wav_b64 = base64.b64encode(bio.getvalue()).decode("ascii")
|
||||
|
||||
def _fmt_srt_time(t):
|
||||
h = int(t // 3600)
|
||||
m = int((t % 3600) // 60)
|
||||
s = int(t % 60)
|
||||
ms = int(round((t - int(t)) * 1000))
|
||||
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
|
||||
|
||||
srt_b64 = None
|
||||
if timestamps_all:
|
||||
srt_lines = []
|
||||
idx_counter = 1
|
||||
for rec in timestamps_all:
|
||||
if rec.get("char_spans"):
|
||||
for c in rec["char_spans"]:
|
||||
cs = rec.get("char_spans") or []
|
||||
ws = rec.get("word_spans") or []
|
||||
entries = []
|
||||
for c in cs:
|
||||
entries.append({"text": c.get("char", ""), "start": c["start_s"], "end": c["end_s"]})
|
||||
for w in ws:
|
||||
entries.append({"text": w.get("word", ""), "start": w["start_s"], "end": w["end_s"]})
|
||||
if entries:
|
||||
entries.sort(key=lambda x: x["start"])
|
||||
for e in entries:
|
||||
srt_lines.append(str(idx_counter))
|
||||
srt_lines.append(f"{_fmt_srt_time(c['start_s'])} --> { _fmt_srt_time(c['end_s'])}")
|
||||
srt_lines.append(c.get("char", ""))
|
||||
srt_lines.append(f"{_fmt_srt_time(e['start'])} --> { _fmt_srt_time(e['end'])}")
|
||||
srt_lines.append(e["text"])
|
||||
srt_lines.append("")
|
||||
idx_counter += 1
|
||||
else:
|
||||
st = rec.get("segment_start_s")
|
||||
ed = rec.get("segment_end_s")
|
||||
if st is not None and ed is not None:
|
||||
srt_lines.append(str(idx_counter))
|
||||
srt_lines.append(f"{_fmt_srt_time(st)} --> { _fmt_srt_time(ed)}")
|
||||
srt_lines.append(rec.get("text", ""))
|
||||
srt_lines.append("")
|
||||
idx_counter += 1
|
||||
srt_text = "\n".join(srt_lines)
|
||||
srt_b64 = base64.b64encode("\n".join(srt_lines).encode("utf-8")).decode("ascii")
|
||||
|
||||
# pack wav
|
||||
import base64
|
||||
buf = BytesIO()
|
||||
sf.write(buf, audio_np, sr_hz, format="WAV")
|
||||
wav_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
|
||||
return JSONResponse({"code": 0, "sr": sr_hz, "audio_wav_base64": wav_b64, "timestamps": timestamps_all, "srt": srt_text})
|
||||
timestamps_json_b64 = base64.b64encode(json.dumps(timestamps_all, ensure_ascii=False).encode("utf-8")).decode("ascii")
|
||||
|
||||
|
||||
@app.post("/tts_json")
|
||||
async def tts_json_post(request: Request):
|
||||
body = await request.json()
|
||||
return synthesize_json(
|
||||
body.get("refer_wav_path"),
|
||||
body.get("prompt_text"),
|
||||
body.get("prompt_language"),
|
||||
body.get("text"),
|
||||
body.get("text_language"),
|
||||
body.get("top_k", 15),
|
||||
body.get("top_p", 0.6),
|
||||
body.get("temperature", 0.6),
|
||||
body.get("speed", 1.0),
|
||||
body.get("inp_refs", []),
|
||||
body.get("sample_steps", 32),
|
||||
body.get("if_sr", False),
|
||||
)
|
||||
|
||||
|
||||
@app.get("/tts_json")
|
||||
async def tts_json_get(
|
||||
refer_wav_path: str,
|
||||
prompt_text: str,
|
||||
prompt_language: str,
|
||||
text: str,
|
||||
text_language: str,
|
||||
top_k: int = 15,
|
||||
top_p: float = 0.6,
|
||||
temperature: float = 0.6,
|
||||
speed: float = 1.0,
|
||||
inp_refs: list = Query(default=[]),
|
||||
sample_steps: int = 32,
|
||||
if_sr: bool = False,
|
||||
):
|
||||
return synthesize_json(
|
||||
refer_wav_path,
|
||||
prompt_text,
|
||||
prompt_language,
|
||||
text,
|
||||
text_language,
|
||||
top_k,
|
||||
top_p,
|
||||
temperature,
|
||||
speed,
|
||||
inp_refs,
|
||||
sample_steps,
|
||||
if_sr,
|
||||
)
|
||||
resp = {
|
||||
"sr": sr,
|
||||
"audio_wav_base64": wav_b64,
|
||||
"timestamps": timestamps_all,
|
||||
"timestamps_json_base64": timestamps_json_b64,
|
||||
}
|
||||
if srt_b64:
|
||||
resp["srt_base64"] = srt_b64
|
||||
return JSONResponse({"code": 0, "data": resp})
|
||||
except Exception as e:
|
||||
return JSONResponse({"code": 400, "message": str(e)}, status_code=400)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user