Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-10-16 05:36:34 +08:00)

Fix the timestamp processing logic for phonemes, characters, and words

This commit is contained in:
parent 705df4c414
commit 2ff9a1533f
@@ -596,6 +596,7 @@ def get_first(text):
 from text import chinese
+import tempfile


 def get_phones_and_bert(text, language, version, final=False):
@@ -850,6 +851,8 @@ def get_tts_wav(
     phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)

     timestamps_all = []
+    elapsed_s = 0.0
+    sr_hz = int(hps.data.sampling_rate)
     for i_text, text in enumerate(texts):
         # fix the error caused by blank lines in the input target text
         if len(text.strip()) == 0:
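Aside, not part of the diff: `elapsed_s` and `sr_hz` exist so that per-segment span times (which start at zero for every segment) can be shifted onto the timeline of the final concatenated audio. A minimal runnable sketch of that bookkeeping, with hypothetical segment data and pause length:

import numpy as np

sr_hz = 32000                      # hypothetical sampling rate
pause_second = 0.3                 # hypothetical inter-segment pause
segments = [
    (np.zeros(16000), [{"start_s": 0.10, "end_s": 0.25}]),  # 0.5 s of audio
    (np.zeros(32000), [{"start_s": 0.05, "end_s": 0.40}]),  # 1.0 s of audio
]
elapsed_s = 0.0
for seg_audio, seg_spans in segments:
    for d in seg_spans:            # spans are relative to their own segment
        d["start_s"] += elapsed_s  # shift onto the global timeline
        d["end_s"] += elapsed_s
    elapsed_s += len(seg_audio) / sr_hz + pause_second
# the second segment's span is now 0.85-1.20 s in the concatenated output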
@@ -974,6 +977,71 @@ def get_tts_wav(
                     })
                 else:
                     char_spans[-1]["end_s"] = span["end_s"]
+            # post-merge by char_index across the whole segment to remove jitter fragments
+            if char_spans:
+                # group by char_index
+                groups = {}
+                for cs in char_spans:
+                    groups.setdefault(cs["char_index"], []).append(cs)
+                merged = []
+                gap_merge_s = 0.08
+                # adaptive minimal duration: at least one frame, but not lower than 15 ms
+                min_dur_s = max(0.015, frame_time)
+                for ci, lst in groups.items():
+                    lst = sorted(lst, key=lambda x: x["start_s"])
+                    cur = None
+                    for it in lst:
+                        if cur is None:
+                            cur = {"char_index": ci, "char": it.get("char", ""), "start_s": it["start_s"], "end_s": it["end_s"]}
+                        else:
+                            if it["start_s"] - cur["end_s"] <= gap_merge_s:
+                                if it["end_s"] > cur["end_s"]:
+                                    cur["end_s"] = it["end_s"]
+                            else:
+                                if cur["end_s"] - cur["start_s"] >= min_dur_s:
+                                    merged.append(cur)
+                                cur = {"char_index": ci, "char": it.get("char", ""), "start_s": it["start_s"], "end_s": it["end_s"]}
+                    if cur is not None and (cur["end_s"] - cur["start_s"]) >= min_dur_s:
+                        merged.append(cur)
+                # sort merged by time
+                char_spans = sorted(merged, key=lambda x: x["start_s"])
+                # remap normalized chars back to original input text to avoid spurious '.'/'?'
+                def _build_norm_to_orig_map(orig, norm):
+                    all_punc_local = set(punctuation).union(set(splits))
+                    mapping = [-1] * len(norm)
+                    o = 0
+                    n = 0
+                    while n < len(norm) and o < len(orig):
+                        if norm[n] == orig[o]:
+                            mapping[n] = o
+                            n += 1
+                            o += 1
+                        else:
+                            # skip spaces/punctuation on either side
+                            if orig[o].isspace() or orig[o] in all_punc_local:
+                                o += 1
+                            elif norm[n].isspace() or norm[n] in all_punc_local:
+                                n += 1
+                            else:
+                                # characters differ (e.g., compatibility form); advance the normalized index
+                                n += 1
+                    return mapping
+                norm2orig = _build_norm_to_orig_map(text, norm_text_seg)
+                all_punc_local = set(punctuation).union(set(splits))
+                remapped = []
+                for cs in char_spans:
+                    ci = cs["char_index"]
+                    ch_norm = cs.get("char", "")
+                    oi = norm2orig[ci] if ci < len(norm2orig) else -1
+                    if oi != -1:
+                        cs["char"] = text[oi]
+                        remapped.append(cs)
+                    else:
+                        # drop normalized-only punctuation/spaces not present in the original
+                        if ch_norm and (ch_norm in all_punc_local or ch_norm.isspace()):
+                            continue
+                        remapped.append(cs)
+                char_spans = remapped
             # word spans
             word_spans = []
             if text_language == "en":
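Aside, not part of the diff: the merge rule above can be exercised in isolation. With hypothetical spans for one character, fragments separated by at most gap_merge_s fuse, and leftovers shorter than min_dur_s are dropped:

# Sketch of the jitter-merge rule on toy data.
spans = [
    {"char_index": 0, "char": "你", "start_s": 0.00, "end_s": 0.04},
    {"char_index": 0, "char": "你", "start_s": 0.06, "end_s": 0.12},  # gap 0.02 <= 0.08: fused
    {"char_index": 0, "char": "你", "start_s": 0.30, "end_s": 0.31},  # far away and only 0.01 s: dropped
]
gap_merge_s, min_dur_s = 0.08, 0.015
merged, cur = [], None
for it in sorted(spans, key=lambda x: x["start_s"]):
    if cur and it["start_s"] - cur["end_s"] <= gap_merge_s:
        cur["end_s"] = max(cur["end_s"], it["end_s"])
    else:
        if cur and cur["end_s"] - cur["start_s"] >= min_dur_s:
            merged.append(cur)
        cur = dict(it)
if cur and cur["end_s"] - cur["start_s"] >= min_dur_s:
    merged.append(cur)
# merged == [{"char_index": 0, "char": "你", "start_s": 0.0, "end_s": 0.12}]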
@@ -1002,12 +1070,32 @@ def get_tts_wav(
                     })
                 else:
                     word_spans[-1]["end_s"] = span["end_s"]
+            # add absolute offsets and record segment timing
+            audio_len_s = float(audio.shape[0]) / sr_hz
+            if 'ph_spans' in locals() and ph_spans:
+                for d in ph_spans:
+                    d["start_s"] += elapsed_s
+                    d["end_s"] += elapsed_s
+            if char_spans:
+                for d in char_spans:
+                    d["start_s"] += elapsed_s
+                    d["end_s"] += elapsed_s
+            if word_spans:
+                for d in word_spans:
+                    d["start_s"] += elapsed_s
+                    d["end_s"] += elapsed_s
+            seg_start_s = elapsed_s
+            seg_end_s = elapsed_s + audio_len_s
             timestamps_all.append({
                 "segment_index": i_text,
                 "phoneme_spans": ph_spans,
                 "char_spans": char_spans,
                 "word_spans": word_spans,
+                "segment_start_s": seg_start_s,
+                "segment_end_s": seg_end_s,
+                "text": norm_text_seg,
             })
+            elapsed_s += audio_len_s + float(pause_second)
     else:
         refer, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device)
         phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
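Aside, not part of the diff: after this hunk, each entry of timestamps_all carries absolute times. A record looks roughly like this (all values hypothetical):

{
    "segment_index": 0,
    "phoneme_spans": [{"phoneme_id": 12, "start_s": 0.00, "end_s": 0.08}],
    "char_spans": [{"char_index": 0, "char": "你", "start_s": 0.00, "end_s": 0.18}],
    "word_spans": [],            # populated for English text
    "segment_start_s": 0.0,
    "segment_end_s": 1.42,
    "text": "你好。",
}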
@@ -1086,8 +1174,62 @@ def get_tts_wav(
         audio_opt /= max_audio
     else:
         audio_opt = audio_opt.cpu().detach().numpy()
-    # Return audio and timestamps for UI consumption: ((sr, audio), timestamps)
-    yield (opt_sr, (audio_opt * 32767).astype(np.int16)), timestamps_all
+    # Build SRT file content from timestamps_all
+    def _fmt_srt_time(t):
+        h = int(t // 3600)
+        m = int((t % 3600) // 60)
+        s = int(t % 60)
+        ms = int(round((t - int(t)) * 1000))
+        return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
+
+    srt_lines = []
+    idx_counter = 1
+    for rec in timestamps_all:
+        # prefer per-character segmentation
+        if rec.get("char_spans") and len(rec["char_spans"]):
+            for c in rec["char_spans"]:
+                st = c["start_s"]
+                ed = c["end_s"]
+                txt = c.get("char", "")
+                srt_lines.append(str(idx_counter))
+                srt_lines.append(f"{_fmt_srt_time(st)} --> {_fmt_srt_time(ed)}")
+                srt_lines.append(txt)
+                srt_lines.append("")
+                idx_counter += 1
+            continue
+        # otherwise per-word (English)
+        if rec.get("word_spans") and len(rec["word_spans"]):
+            for w in rec["word_spans"]:
+                st = w["start_s"]
+                ed = w["end_s"]
+                txt = w.get("word", "")
+                srt_lines.append(str(idx_counter))
+                srt_lines.append(f"{_fmt_srt_time(st)} --> {_fmt_srt_time(ed)}")
+                srt_lines.append(txt)
+                srt_lines.append("")
+                idx_counter += 1
+            continue
+        # finally fall back to the whole segment
+        st = rec.get("segment_start_s")
+        ed = rec.get("segment_end_s")
+        text_line = rec.get("text", "")
+        if st is not None and ed is not None:
+            srt_lines.append(str(idx_counter))
+            srt_lines.append(f"{_fmt_srt_time(st)} --> {_fmt_srt_time(ed)}")
+            srt_lines.append(text_line)
+            srt_lines.append("")
+            idx_counter += 1
+
+    srt_path = None
+    try:
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".srt", mode="w", encoding="utf-8") as f:
+            f.write("\n".join(srt_lines))
+            srt_path = f.name
+    except Exception:
+        srt_path = None
+
+    # Return audio, timestamps and SRT path for UI
+    yield (opt_sr, (audio_opt * 32767).astype(np.int16)), timestamps_all, srt_path


 def split(todo_text):
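Aside, not part of the diff: _fmt_srt_time follows the standard HH:MM:SS,mmm SRT convention, so a char span covering 0.00-0.18 s would be written to the file as:

1
00:00:00,000 --> 00:00:00,180
你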
@@ -1375,6 +1517,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
                 inference_button = gr.Button(value=i18n("合成语音"), variant="primary", size="lg", scale=25)
                 output = gr.Audio(label=i18n("输出的语音"), scale=14)
                 timestamps_box = gr.JSON(label=i18n("时间戳(音素/字/词)"))
+                srt_file = gr.File(label=i18n("下载SRT字幕"))

         inference_button.click(
             get_tts_wav,
@@ -1396,7 +1539,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
                 if_sr_Checkbox,
                 pause_second_slider,
             ],
-            [output, timestamps_box],
+            [output, timestamps_box, srt_file],
         )
         SoVITS_dropdown.change(
             change_sovits_weights,
api.py (297 additions)
@@ -1385,5 +1385,302 @@ async def tts_endpoint(
     )


+def _fmt_srt_time(t):
+    h = int(t // 3600)
+    m = int((t % 3600) // 60)
+    s = int(t % 60)
+    ms = int(round((t - int(t)) * 1000))
+    return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
+
+
+def _build_norm_to_orig_map(orig, norm):
+    all_punc_local = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}
+    mapping = [-1] * len(norm)
+    o = 0
+    n = 0
+    while n < len(norm) and o < len(orig):
+        if norm[n] == orig[o]:
+            mapping[n] = o
+            n += 1
+            o += 1
+        else:
+            if orig[o].isspace() or orig[o] in all_punc_local:
+                o += 1
+            elif norm[n].isspace() or norm[n] in all_punc_local:
+                n += 1
+            else:
+                n += 1
+    return mapping
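Aside, not part of the diff: a quick trace of the alignment helper. Characters are matched greedily, whitespace and punctuation on either side are skipped, and normalized-only characters keep -1:

mapping = _build_norm_to_orig_map("你好 世界。", "你好世界.")
# mapping == [0, 1, 3, 4, -1]
# "你好" matches directly; the original space is skipped, so "世界" maps to
# indices 3 and 4; the normalized "." has no original counterpart, hence -1.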
+def synthesize_json(
+    refer_wav_path,
+    prompt_text,
+    prompt_language,
+    text,
+    text_language,
+    top_k,
+    top_p,
+    temperature,
+    speed,
+    inp_refs,
+    sample_steps,
+    if_sr,
+):
+    infer_sovits = speaker_list["default"].sovits
+    vq_model = infer_sovits.vq_model
+    hps = infer_sovits.hps
+    version = vq_model.version
+
+    if version in {"v3", "v4"}:
+        return JSONResponse({"code": 400, "message": "v3/v4 do not provide the timestamp JSON endpoint yet"}, status_code=400)
+
+    # prepare refer features (same as handle)
+    dtype = torch.float16 if is_half else torch.float32
+    zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half else np.float32)
+    with torch.no_grad():
+        wav16k, sr = librosa.load(refer_wav_path, sr=16000)
+        wav16k = torch.from_numpy(wav16k)
+        zero_wav_torch = torch.from_numpy(zero_wav)
+        if is_half:
+            wav16k = wav16k.half().to(device)
+            zero_wav_torch = zero_wav_torch.half().to(device)
+        else:
+            wav16k = wav16k.to(device)
+            zero_wav_torch = zero_wav_torch.to(device)
+        wav16k = torch.cat([wav16k, zero_wav_torch])
+        ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)
+        codes = vq_model.extract_latent(ssl_content)
+        prompt_semantic = codes[0, 0]
+        prompt = prompt_semantic.unsqueeze(0).to(device)
+
+    is_v2pro = version in {"v2Pro", "v2ProPlus"}
+    refers = []
+    if is_v2pro:
+        sv_emb = []
+        if sv_cn_model is None:
+            init_sv_cn()
+    spec, audio_tensor = get_spepc(hps, refer_wav_path, dtype, device, is_v2pro)
+    refers = [spec]
+    if is_v2pro:
+        sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)]
+
+    # text frontend
+    prompt_language = dict_language[prompt_language.lower()]
+    text_language = dict_language[text_language.lower()]
+    phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
+    texts = text.strip("\n").split("\n")
+
+    # iterate
+    audio_segments = []
+    timestamps_all = []
+    sr_hz = int(hps.data.sampling_rate)
+    elapsed_s = 0.0
+    for seg in texts:
+        if only_punc(seg):
+            continue
+        if seg[-1] not in splits:
+            seg += "。" if text_language != "en" else "."
+        phones2, bert2, norm_text2 = get_phones_and_bert(seg, text_language, version)
+        bert = torch.cat([bert1, bert2], 1)
+        all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
+        bert = bert.to(device).unsqueeze(0)
+        all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
+        with torch.no_grad():
+            pred_semantic, idx = speaker_list["default"].gpt.t2s_model.model.infer_panel(
+                all_phoneme_ids, all_phoneme_len, prompt, bert, top_k=top_k, top_p=top_p, temperature=temperature, early_stop_num=hz * speaker_list["default"].gpt.max_sec,
+            )
+            pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
+        phones2_tensor = torch.LongTensor(phones2).to(device).unsqueeze(0)
+        if is_v2pro:
+            o, attn, y_mask = vq_model.decode_with_alignment(
+                pred_semantic, phones2_tensor, refers, speed=speed, sv_emb=sv_emb
+            )
+        else:
+            o, attn, y_mask = vq_model.decode_with_alignment(
+                pred_semantic, phones2_tensor, refers, speed=speed
+            )
+        audio = o[0][0].detach().cpu().numpy()
+
+        # timestamps
+        frame_time = 0.02 / max(float(speed), 1e-6)
+        ph_spans = []
+        if attn is not None:
+            attn_mean = attn.mean(dim=1)[0]
+            assign = attn_mean.argmax(dim=-1)
+            if assign.numel() > 0:
+                start_f = 0
+                cur_ph = int(assign[0].item())
+                for f in range(1, assign.shape[0]):
+                    ph = int(assign[f].item())
+                    if ph != cur_ph:
+                        ph_spans.append({"phoneme_id": cur_ph, "start_s": start_f * frame_time, "end_s": f * frame_time})
+                        start_f = f
+                        cur_ph = ph
+                ph_spans.append({"phoneme_id": cur_ph, "start_s": start_f * frame_time, "end_s": assign.shape[0] * frame_time})
+
+        # char aggregation
+        _, word2ph, norm_text_seg = clean_text_inf(seg, text_language, version)
+        char_spans = []
+        if word2ph:
+            ph_to_char = []
+            for ch_idx, repeat in enumerate(word2ph):
+                ph_to_char += [ch_idx] * repeat
+            if ph_spans and ph_to_char:
+                for span in ph_spans:
+                    ph_idx = span["phoneme_id"]
+                    if 0 <= ph_idx < len(ph_to_char):
+                        char_idx = ph_to_char[ph_idx]
+                        char_spans.append({"char_index": char_idx, "char": norm_text_seg[char_idx] if char_idx < len(norm_text_seg) else "", "start_s": span["start_s"], "end_s": span["end_s"]})
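Aside, not part of the diff: word2ph holds the phoneme count per normalized character, so expanding it gives a phoneme-index to character-index table. Toy values:

word2ph = [2, 2, 1]                 # hypothetical: 2+2+1 phonemes for 3 chars
ph_to_char = []
for ch_idx, repeat in enumerate(word2ph):
    ph_to_char += [ch_idx] * repeat
# ph_to_char == [0, 0, 1, 1, 2]: e.g. phoneme 3 belongs to character 1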
+        # merge by char_index
+        if char_spans:
+            groups = {}
+            for cs in char_spans:
+                groups.setdefault(cs["char_index"], []).append(cs)
+            merged = []
+            gap_merge_s = 0.08
+            min_dur_s = max(0.015, frame_time)
+            for ci, lst in groups.items():
+                lst = sorted(lst, key=lambda x: x["start_s"])
+                cur = None
+                for it in lst:
+                    if cur is None:
+                        cur = {"char_index": ci, "char": it.get("char", ""), "start_s": it["start_s"], "end_s": it["end_s"]}
+                    else:
+                        if it["start_s"] - cur["end_s"] <= gap_merge_s:
+                            if it["end_s"] > cur["end_s"]:
+                                cur["end_s"] = it["end_s"]
+                        else:
+                            if cur["end_s"] - cur["start_s"] >= min_dur_s:
+                                merged.append(cur)
+                            cur = {"char_index": ci, "char": it.get("char", ""), "start_s": it["start_s"], "end_s": it["end_s"]}
+                if cur is not None and (cur["end_s"] - cur["start_s"]) >= min_dur_s:
+                    merged.append(cur)
+            # remap to original input text
+            norm2orig = _build_norm_to_orig_map(seg, norm_text_seg)
+            remapped = []
+            punc = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}
+            for cs in merged:
+                ci = cs["char_index"]
+                oi = norm2orig[ci] if ci < len(norm2orig) else -1
+                ch_norm = cs.get("char", "")
+                if oi != -1:
+                    cs["char"] = seg[oi]
+                    remapped.append(cs)
+                else:
+                    if ch_norm and (ch_norm in punc or ch_norm.isspace()):
+                        continue
+                    remapped.append(cs)
+            char_spans = sorted(remapped, key=lambda x: x["start_s"]) if remapped else []
+
+        # offset and record
+        audio_len_s = float(audio.shape[0]) / sr_hz
+        for coll in (char_spans,):
+            for d in coll:
+                d["start_s"] += elapsed_s
+                d["end_s"] += elapsed_s
+        timestamps_all.append({
+            "segment_index": len(timestamps_all),
+            "char_spans": char_spans,
+            "text": norm_text_seg,
+            "segment_start_s": elapsed_s,
+            "segment_end_s": elapsed_s + audio_len_s,
+        })
+        elapsed_s += audio_len_s + 0.3
+        audio_segments.append(audio)
+
+    # concatenate audio
+    if len(audio_segments) == 0:
+        return JSONResponse({"code": 400, "message": "no valid text"}, status_code=400)
+    pad = np.zeros(int(sr_hz * 0.3), dtype=audio_segments[0].dtype)
+    out = []
+    for i, a in enumerate(audio_segments):
+        out.append(a)
+        if i < len(audio_segments) - 1:
+            out.append(pad)
+    audio_np = np.concatenate(out, 0)
+    mx = np.abs(audio_np).max()
+    if mx > 1:
+        audio_np = audio_np / mx
+
+    # srt
+    srt_lines = []
+    idx_counter = 1
+    for rec in timestamps_all:
+        if rec.get("char_spans"):
+            for c in rec["char_spans"]:
+                srt_lines.append(str(idx_counter))
+                srt_lines.append(f"{_fmt_srt_time(c['start_s'])} --> {_fmt_srt_time(c['end_s'])}")
+                srt_lines.append(c.get("char", ""))
+                srt_lines.append("")
+                idx_counter += 1
+        else:
+            st = rec.get("segment_start_s")
+            ed = rec.get("segment_end_s")
+            if st is not None and ed is not None:
+                srt_lines.append(str(idx_counter))
+                srt_lines.append(f"{_fmt_srt_time(st)} --> {_fmt_srt_time(ed)}")
+                srt_lines.append(rec.get("text", ""))
+                srt_lines.append("")
+                idx_counter += 1
+    srt_text = "\n".join(srt_lines)
+
+    # pack wav
+    import base64
+    buf = BytesIO()
+    sf.write(buf, audio_np, sr_hz, format="WAV")
+    wav_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
+    return JSONResponse({"code": 0, "sr": sr_hz, "audio_wav_base64": wav_b64, "timestamps": timestamps_all, "srt": srt_text})
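Aside, not part of the diff: the JSON body packs the whole WAV as base64 next to the timestamps and SRT text. A client-side sketch for unpacking it (field names as returned above; file paths hypothetical):

import base64

def save_tts_json_result(resp_json, wav_path="out.wav", srt_path="out.srt"):
    # the audio field is a complete base64-encoded WAV container
    with open(wav_path, "wb") as f:
        f.write(base64.b64decode(resp_json["audio_wav_base64"]))
    with open(srt_path, "w", encoding="utf-8") as f:
        f.write(resp_json["srt"])
    return resp_json["timestamps"]   # list of per-segment span records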
+
+
+@app.post("/tts_json")
+async def tts_json_post(request: Request):
+    body = await request.json()
+    return synthesize_json(
+        body.get("refer_wav_path"),
+        body.get("prompt_text"),
+        body.get("prompt_language"),
+        body.get("text"),
+        body.get("text_language"),
+        body.get("top_k", 15),
+        body.get("top_p", 0.6),
+        body.get("temperature", 0.6),
+        body.get("speed", 1.0),
+        body.get("inp_refs", []),
+        body.get("sample_steps", 32),
+        body.get("if_sr", False),
+    )
+
+
+@app.get("/tts_json")
+async def tts_json_get(
+    refer_wav_path: str,
+    prompt_text: str,
+    prompt_language: str,
+    text: str,
+    text_language: str,
+    top_k: int = 15,
+    top_p: float = 0.6,
+    temperature: float = 0.6,
+    speed: float = 1.0,
+    inp_refs: list = Query(default=[]),
+    sample_steps: int = 32,
+    if_sr: bool = False,
+):
+    return synthesize_json(
+        refer_wav_path,
+        prompt_text,
+        prompt_language,
+        text,
+        text_language,
+        top_k,
+        top_p,
+        temperature,
+        speed,
+        inp_refs,
+        sample_steps,
+        if_sr,
+    )
+
+
 if __name__ == "__main__":
     uvicorn.run(app, host=host, port=port, workers=1)
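Aside, not part of the diff: a sketch of calling the new endpoint with requests, assuming api.py's default host and port and hypothetical reference/text values:

import requests

resp = requests.post("http://127.0.0.1:9880/tts_json", json={
    "refer_wav_path": "ref.wav",          # hypothetical reference audio
    "prompt_text": "参考音频的文本",        # hypothetical transcript
    "prompt_language": "zh",
    "text": "需要合成的文本",               # hypothetical target text
    "text_language": "zh",
    "speed": 1.0,
})
data = resp.json()
assert data["code"] == 0                  # 0 on success, per the diff
# data["audio_wav_base64"], data["timestamps"], data["srt"] are now available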