From 19ca3f3f6a5c753e8c225252668f471ac89c063a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=99=BD=E8=8F=9C=E5=B7=A5=E5=8E=821145=E5=8F=B7=E5=91=98?= =?UTF-8?q?=E5=B7=A5?= <114749500+baicai-1145@users.noreply.github.com> Date: Mon, 22 Sep 2025 07:43:59 +0800 Subject: [PATCH] Supports phoneme and word-level timestamp output for multilingual text --- GPT_SoVITS/inference_webui.py | 379 ++-- GPT_SoVITS/text/cleaner.py | 20 +- GPT_SoVITS/text/english.py | 26 + GPT_SoVITS/text/japanese.py | 24 + GPT_SoVITS/text/korean.py | 699 +++---- api.py | 3494 +++++++++++++++++---------------- 6 files changed, 2473 insertions(+), 2169 deletions(-) diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index f1cadd76..53047cb4 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -932,7 +932,40 @@ def get_tts_wav( # derive phoneme-level timestamps (20ms per frame at 32kHz, hop=640) if last_attn is not None: attn_heads_mean = last_attn.mean(dim=1)[0] # [T_ssl, T_text] - assign = attn_heads_mean.argmax(dim=-1) # [T_ssl] + # Viterbi monotonic alignment (stay or advance by 1) + def _viterbi_monotonic(p): + T, N = p.shape + if T < N: + return p.argmax(dim=-1) + eps = 1e-8 + cost = -torch.log(p + eps) + dp = torch.empty((T, N), dtype=cost.dtype, device=cost.device) + prev = torch.zeros((T, N), dtype=torch.uint8, device=cost.device) + dp[0, 0] = cost[0, 0] + if N > 1: + dp[0, 1:] = float("inf") + for t_i in range(1, T): + # j=0: only stay + dp[t_i, 0] = dp[t_i - 1, 0] + cost[t_i, 0] + prev[t_i, 0] = 0 + if N > 1: + stay = dp[t_i - 1, 1:] + cost[t_i, 1:] + move = dp[t_i - 1, :-1] + cost[t_i, 1:] + better_move = move < stay + dp[t_i, 1:] = torch.where(better_move, move, stay) + prev[t_i, 1:] = better_move.to(torch.uint8) + # backtrack from (T-1, N-1) + j = N - 1 + assign_bt = torch.empty(T, dtype=torch.long, device=cost.device) + for t_i in range(T - 1, -1, -1): + assign_bt[t_i] = j + if t_i > 0 and prev[t_i, j] == 1: + j = j - 1 + if j < 0: + j = 0 + return assign_bt + + assign = _viterbi_monotonic(attn_heads_mean) frame_time = 0.02 / max(speed, 1e-6) # collapse consecutive frames pointing to same phoneme id ph_spans = [] @@ -954,122 +987,186 @@ def get_tts_wav( "start_s": start_f * frame_time, "end_s": assign.shape[0] * frame_time, }) - # char/word aggregation - # obtain word2ph for current text segment - _, word2ph, norm_text_seg = clean_text_inf(text, text_language, version) - # char spans (for zh/yue where word2ph is char-based) - char_spans = [] - if word2ph: - ph_to_char = [] - for ch_idx, repeat in enumerate(word2ph): - ph_to_char += [ch_idx] * repeat - if ph_spans and ph_to_char: - for span in ph_spans: - ph_idx = span["phoneme_id"] - if 0 <= ph_idx < len(ph_to_char): - char_idx = ph_to_char[ph_idx] - if len(char_spans) == 0 or char_spans[-1]["char_index"] != char_idx: - char_spans.append({ - "char_index": char_idx, - "char": norm_text_seg[char_idx] if char_idx < len(norm_text_seg) else "", - "start_s": span["start_s"], - "end_s": span["end_s"], - }) - else: - char_spans[-1]["end_s"] = span["end_s"] - # post-merge by char_index across the whole segment to remove jitter fragments - if char_spans: - # group by char_index - groups = {} - for cs in char_spans: - groups.setdefault(cs["char_index"], []).append(cs) - merged = [] - gap_merge_s = 0.08 - # adaptive minimal duration: at least one frame, but not lower than 15ms - min_dur_s = max(0.015, frame_time) - for ci, lst in groups.items(): - lst = sorted(lst, key=lambda x: x["start_s"]) - cur = None - for it in lst: 
- if cur is None: - cur = {"char_index": ci, "char": it.get("char", ""), "start_s": it["start_s"], "end_s": it["end_s"]} - else: - if it["start_s"] - cur["end_s"] <= gap_merge_s: - if it["end_s"] > cur["end_s"]: - cur["end_s"] = it["end_s"] - else: - if cur["end_s"] - cur["start_s"] >= min_dur_s: - merged.append(cur) - cur = {"char_index": ci, "char": it.get("char", ""), "start_s": it["start_s"], "end_s": it["end_s"]} - if cur is not None and (cur["end_s"] - cur["start_s"]) >= min_dur_s: - merged.append(cur) - # sort merged by time - char_spans = sorted(merged, key=lambda x: x["start_s"]) - # remap normalized chars back to original input text to avoid spurious '.'/'?' - def _build_norm_to_orig_map(orig, norm): - all_punc_local = set(punctuation).union(set(splits)) - mapping = [-1] * len(norm) - o = 0 - n = 0 - while n < len(norm) and o < len(orig): - if norm[n] == orig[o]: - mapping[n] = o - n += 1 - o += 1 - else: - # skip spaces/punctuations on either side - if orig[o].isspace() or orig[o] in all_punc_local: - o += 1 - elif norm[n].isspace() or norm[n] in all_punc_local: - n += 1 - else: - # characters differ (e.g., compatibility form). Advance normalized index. - n += 1 - return mapping - norm2orig = _build_norm_to_orig_map(text, norm_text_seg) - all_punc_local = set(punctuation).union(set(splits)) - remapped = [] - for cs in char_spans: - ci = cs["char_index"] - ch_norm = cs.get("char", "") - oi = norm2orig[ci] if ci < len(norm2orig) else -1 - if oi != -1: - cs["char"] = text[oi] - remapped.append(cs) - else: - # drop normalized-only punctuations/spaces not present in original - if ch_norm and (ch_norm in all_punc_local or ch_norm.isspace()): + # char/word aggregation for multi-language text + def _build_mixed_mappings(_text, _ui_lang, _version): + # replicate segmentation logic from get_phones_and_bert + _text = re.sub(r' {2,}', ' ', _text) + textlist = [] + langlist = [] + if _ui_lang == "all_zh": + for tmp in LangSegmenter.getTexts(_text, "zh"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif _ui_lang == "all_yue": + for tmp in LangSegmenter.getTexts(_text, "zh"): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif _ui_lang == "all_ja": + for tmp in LangSegmenter.getTexts(_text, "ja"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif _ui_lang == "all_ko": + for tmp in LangSegmenter.getTexts(_text, "ko"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif _ui_lang == "en": + langlist.append("en") + textlist.append(_text) + elif _ui_lang == "auto": + for tmp in LangSegmenter.getTexts(_text): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif _ui_lang == "auto_yue": + for tmp in LangSegmenter.getTexts(_text): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + else: + for tmp in LangSegmenter.getTexts(_text): + if langlist: + if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"): + textlist[-1] += tmp["text"] continue - remapped.append(cs) - char_spans = remapped - # word spans - word_spans = [] - if text_language == "en": - # build phoneme-to-word map using dictionary per-word g2p - try: - from text.english import g2p as g2p_en - except Exception: - g2p_en = None - words = norm_text_seg.split() + if tmp["lang"] == "en": + langlist.append(tmp["lang"]) + else: + langlist.append(_ui_lang) + textlist.append(tmp["text"]) + + # aggregate mappings 
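+                # Mapping semantics: zh/yue/ja segments index phonemes by character
+                # position in the concatenated normalized text (their word slots are -1);
+                # en/ko segments index phonemes by whitespace token, collected in
+                # word_tokens (their char slots are -1).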
+ ph_to_char = [] ph_to_word = [] - if g2p_en: - for w_idx, w in enumerate(words): - phs_w = g2p_en(w) - ph_to_word += [w_idx] * len(phs_w) - if ph_spans and ph_to_word: - for span in ph_spans: - ph_idx = span["phoneme_id"] - if 0 <= ph_idx < len(ph_to_word): - wi = ph_to_word[ph_idx] - if len(word_spans) == 0 or word_spans[-1]["word_index"] != wi: - word_spans.append({ - "word_index": wi, - "word": words[wi] if wi < len(words) else "", - "start_s": span["start_s"], - "end_s": span["end_s"], - }) + word_tokens = [] + norm_text_agg = [] + import re as _re + for seg_text, seg_lang in zip(textlist, langlist): + seg_phones, seg_word2ph, seg_norm = clean_text_inf(seg_text, seg_lang, _version) + norm_text_agg.append(seg_norm) + if seg_lang in {"zh", "yue", "ja"} and seg_word2ph: + # char-based + char_base_idx = len("".join(norm_text_agg[:-1])) + for ch_idx, cnt in enumerate(seg_word2ph): + global_char_idx = char_base_idx + ch_idx + ph_to_char += [global_char_idx] * cnt + ph_to_word += [-1] * len(seg_phones) + elif seg_lang in {"en", "ko"} and seg_word2ph: + # word-based + tokens_seg = [t for t in _re.findall(r"\S+", seg_norm) if not all((c in punctuation) for c in t)] + base_word_idx = len(word_tokens) + t_idx = 0 + for cnt in seg_word2ph: + if t_idx < len(tokens_seg): + ph_to_word += [base_word_idx + t_idx] * cnt + t_idx += 1 + ph_to_char += [-1] * len(seg_phones) + word_tokens.extend(tokens_seg) + else: + # unknown mapping; fill with -1 + ph_to_char += [-1] * len(seg_phones) + ph_to_word += [-1] * len(seg_phones) + norm_text_agg = "".join(norm_text_agg) + return norm_text_agg, ph_to_char, ph_to_word, word_tokens + + norm_text_seg, ph_to_char_map, ph_to_word_map, word_tokens = _build_mixed_mappings(text, text_language, version) + + # char spans (for zh/yue/ja parts only) + char_spans = [] + if ph_spans and ph_to_char_map and len(ph_to_char_map) >= 1: + for span in ph_spans: + ph_idx = span["phoneme_id"] + if 0 <= ph_idx < len(ph_to_char_map): + ci = ph_to_char_map[ph_idx] + if ci == -1: + continue + if len(char_spans) == 0 or char_spans[-1]["char_index"] != ci: + char_spans.append({ + "char_index": ci, + "char": norm_text_seg[ci] if ci < len(norm_text_seg) else "", + "start_s": span["start_s"], + "end_s": span["end_s"], + }) + else: + char_spans[-1]["end_s"] = span["end_s"] + # post-merge and remap for chars + if char_spans: + groups = {} + for cs in char_spans: + groups.setdefault(cs["char_index"], []).append(cs) + merged = [] + gap_merge_s = 0.08 + min_dur_s = max(0.015, frame_time) + for ci, lst in groups.items(): + lst = sorted(lst, key=lambda x: x["start_s"]) + cur = None + for it in lst: + if cur is None: + cur = {"char_index": ci, "char": it.get("char", ""), "start_s": it["start_s"], "end_s": it["end_s"]} + else: + if it["start_s"] - cur["end_s"] <= gap_merge_s: + if it["end_s"] > cur["end_s"]: + cur["end_s"] = it["end_s"] else: - word_spans[-1]["end_s"] = span["end_s"] + if cur["end_s"] - cur["start_s"] >= min_dur_s: + merged.append(cur) + cur = {"char_index": ci, "char": it.get("char", ""), "start_s": it["start_s"], "end_s": it["end_s"]} + if cur is not None and (cur["end_s"] - cur["start_s"]) >= min_dur_s: + merged.append(cur) + char_spans = sorted(merged, key=lambda x: x["start_s"]) + def _build_norm_to_orig_map(orig, norm): + all_punc_local = set(punctuation).union(set(splits)) + mapping = [-1] * len(norm) + o = 0 + n = 0 + while n < len(norm) and o < len(orig): + if norm[n] == orig[o]: + mapping[n] = o + n += 1 + o += 1 + else: + if orig[o].isspace() or orig[o] in all_punc_local: + 
o += 1 + elif norm[n].isspace() or norm[n] in all_punc_local: + n += 1 + else: + n += 1 + return mapping + norm2orig = _build_norm_to_orig_map(text, norm_text_seg) + all_punc_local = set(punctuation).union(set(splits)) + remapped = [] + for cs in char_spans: + ci = cs["char_index"] + ch_norm = cs.get("char", "") + oi = norm2orig[ci] if ci < len(norm2orig) else -1 + if oi != -1: + cs["char"] = text[oi] + remapped.append(cs) + else: + if ch_norm and (ch_norm in all_punc_local or ch_norm.isspace()): + continue + remapped.append(cs) + char_spans = remapped + + # word spans (for en/ko parts only) + word_spans = [] + if ph_spans and ph_to_word_map and len(ph_to_word_map) >= 1: + for span in ph_spans: + ph_idx = span["phoneme_id"] + if 0 <= ph_idx < len(ph_to_word_map): + wi = ph_to_word_map[ph_idx] + if wi == -1: + continue + if len(word_spans) == 0 or word_spans[-1]["word_index"] != wi: + word_spans.append({ + "word_index": wi, + "word": word_tokens[wi] if wi < len(word_tokens) else "", + "start_s": span["start_s"], + "end_s": span["end_s"], + }) + else: + word_spans[-1]["end_s"] = span["end_s"] # add absolute offsets and record segment timing audio_len_s = float(audio.shape[0]) / sr_hz if 'ph_spans' in locals() and ph_spans: @@ -1185,31 +1282,24 @@ def get_tts_wav( srt_lines = [] idx_counter = 1 for rec in timestamps_all: - # 优先按字分段 - if rec.get("char_spans") and len(rec["char_spans"]): - for c in rec["char_spans"]: - st = c["start_s"] - ed = c["end_s"] - txt = c.get("char", "") + cs = rec.get("char_spans") or [] + ws = rec.get("word_spans") or [] + entries = [] + # 统一时间轴:存在则合并后按开始时间排序输出 + for c in cs: + entries.append({"text": c.get("char", ""), "start": c["start_s"], "end": c["end_s"]}) + for w in ws: + entries.append({"text": w.get("word", ""), "start": w["start_s"], "end": w["end_s"]}) + if entries: + entries.sort(key=lambda x: x["start"]) + for e in entries: srt_lines.append(str(idx_counter)) - srt_lines.append(f"{_fmt_srt_time(st)} --> { _fmt_srt_time(ed)}") - srt_lines.append(txt) + srt_lines.append(f"{_fmt_srt_time(e['start'])} --> { _fmt_srt_time(e['end'])}") + srt_lines.append(e["text"]) srt_lines.append("") idx_counter += 1 continue - # 次选按词(英文) - if rec.get("word_spans") and len(rec["word_spans"]): - for w in rec["word_spans"]: - st = w["start_s"] - ed = w["end_s"] - txt = w.get("word", "") - srt_lines.append(str(idx_counter)) - srt_lines.append(f"{_fmt_srt_time(st)} --> { _fmt_srt_time(ed)}") - srt_lines.append(txt) - srt_lines.append("") - idx_counter += 1 - continue - # 最后按整段兜底 + # 兜底:整段 st = rec.get("segment_start_s") ed = rec.get("segment_end_s") text_line = rec.get("text", "") @@ -1228,8 +1318,17 @@ def get_tts_wav( except Exception: srt_path = None - # Return audio, timestamps and SRT path for UI - yield (opt_sr, (audio_opt * 32767).astype(np.int16)), timestamps_all, srt_path + # Also write JSON timestamps to temp file for download + json_path = None + try: + with tempfile.NamedTemporaryFile(delete=False, suffix=".json", mode="w", encoding="utf-8") as fjson: + json.dump(timestamps_all, fjson, ensure_ascii=False, indent=2) + json_path = fjson.name + except Exception: + json_path = None + + # Return audio, timestamps, SRT path and JSON path for UI + yield (opt_sr, (audio_opt * 32767).astype(np.int16)), timestamps_all, srt_path, json_path def split(todo_text): @@ -1516,8 +1615,10 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css with gr.Row(): inference_button = gr.Button(value=i18n("合成语音"), variant="primary", size="lg", scale=25) output = 
gr.Audio(label=i18n("输出的语音"), scale=14) - timestamps_box = gr.JSON(label=i18n("时间戳(音素/字/词)")) + with gr.Accordion(i18n("时间戳(音素/字/词) - 点击展开"), open=False): + timestamps_box = gr.JSON(label=i18n("时间戳JSON预览")) srt_file = gr.File(label=i18n("下载SRT字幕")) + json_file = gr.File(label=i18n("下载时间戳JSON")) inference_button.click( get_tts_wav, @@ -1539,7 +1640,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css if_sr_Checkbox, pause_second_slider, ], - [output, timestamps_box, srt_file], + [output, timestamps_box, srt_file, json_file], ) SoVITS_dropdown.change( change_sovits_weights, diff --git a/GPT_SoVITS/text/cleaner.py b/GPT_SoVITS/text/cleaner.py index 7ba8f376..4176db43 100644 --- a/GPT_SoVITS/text/cleaner.py +++ b/GPT_SoVITS/text/cleaner.py @@ -39,18 +39,24 @@ def clean_text(text, language, version=None): norm_text = language_module.text_normalize(text) else: norm_text = text + if language == "zh" or language == "yue": ########## phones, word2ph = language_module.g2p(norm_text) assert len(phones) == sum(word2ph) assert len(norm_text) == len(word2ph) - elif language == "en": - phones = language_module.g2p(norm_text) - if len(phones) < 4: - phones = [","] + phones - word2ph = None else: - phones = language_module.g2p(norm_text) - word2ph = None + # Try per-language word2ph helpers + if hasattr(language_module, "g2p_with_word2ph"): + try: + phones, word2ph = language_module.g2p_with_word2ph(norm_text, keep_punc=False) + except Exception: + phones = language_module.g2p(norm_text) + word2ph = None + else: + phones = language_module.g2p(norm_text) + word2ph = None + if language == "en" and len(phones) < 4: + phones = [","] + phones phones = ["UNK" if ph not in symbols else ph for ph in phones] return phones, word2ph, norm_text diff --git a/GPT_SoVITS/text/english.py b/GPT_SoVITS/text/english.py index f6c69449..47d2c1b1 100644 --- a/GPT_SoVITS/text/english.py +++ b/GPT_SoVITS/text/english.py @@ -368,6 +368,32 @@ def g2p(text): return replace_phs(phones) +def g2p_with_word2ph(text, keep_punc=False): + """ + Returns (phones, word2ph) for English by whitespace tokenization. + - Tokenize by spaces; for each token, call g2p(token) + - word2ph per token = max(1, num_valid_phones) if keep_punc else skip pure punctuations + """ + tokens = re.split(r"(\s+)", text) + phones_all = [] + word2ph = [] + punc_set = set(punctuation) + for tok in tokens: + if tok.strip() == "": + if keep_punc: + word2ph.append(1) + continue + if all((c in punc_set or c.isspace()) for c in tok): + if keep_punc: + word2ph.append(1) + continue + phs = g2p(tok) + phs_valid = [p for p in phs if p not in punc_set] + phones_all.extend(phs_valid) + word2ph.append(max(1, len(phs_valid))) + return phones_all, word2ph + + if __name__ == "__main__": print(g2p("hello")) print(g2p(text_normalize("e.g. 
I used openai's AI tool to draw a picture."))) diff --git a/GPT_SoVITS/text/japanese.py b/GPT_SoVITS/text/japanese.py index a54d0cf0..6a461d12 100644 --- a/GPT_SoVITS/text/japanese.py +++ b/GPT_SoVITS/text/japanese.py @@ -271,6 +271,30 @@ def g2p(norm_text, with_prosody=True): return phones +# Helper for alignment: build phones and word2ph by per-character g2p (ignoring prosody markers) +def g2p_with_word2ph(text, keep_punc=False): + """ + Returns (phones, word2ph) + - Per-character g2p; ignore prosody markers like '[', ']','^', '$', '#', '_' + - Punctuation counted as 1 if keep_punc else skipped + """ + norm_text = text_normalize(text) + phones_all = [] + word2ph = [] + prosody_markers = {'[', ']', '^', '$', '#', '_'} + punc_set = set(punctuation) + for ch in norm_text: + if ch.isspace() or ch in punc_set: + if keep_punc: + word2ph.append(1) + continue + phs = preprocess_jap(ch, with_prosody=True) + phs = [post_replace_ph(p) for p in phs if p not in prosody_markers and p not in punc_set] + phones_all.extend(phs) + word2ph.append(max(1, len(phs))) + return phones_all, word2ph + + if __name__ == "__main__": phones = g2p("Hello.こんにちは!今日もNiCe天気ですね!tokyotowerに行きましょう!") print(phones) diff --git a/GPT_SoVITS/text/korean.py b/GPT_SoVITS/text/korean.py index 254b05cf..9c31b680 100644 --- a/GPT_SoVITS/text/korean.py +++ b/GPT_SoVITS/text/korean.py @@ -1,337 +1,362 @@ -# reference: https://github.com/ORI-Muchim/MB-iSTFT-VITS-Korean/blob/main/text/korean.py - -import re -from jamo import h2j, j2hcj -import ko_pron -from g2pk2 import G2p - -import importlib -import os - -# 防止win下无法读取模型 -if os.name == "nt": - - class win_G2p(G2p): - def check_mecab(self): - super().check_mecab() - spam_spec = importlib.util.find_spec("eunjeon") - non_found = spam_spec is None - if non_found: - print("you have to install eunjeon. install it...") - else: - installpath = spam_spec.submodule_search_locations[0] - if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", installpath)): - import sys - from eunjeon import Mecab as _Mecab - - class Mecab(_Mecab): - def get_dicpath(installpath): - if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", installpath)): - import shutil - - python_dir = os.getcwd() - if installpath[: len(python_dir)].upper() == python_dir.upper(): - dicpath = os.path.join(os.path.relpath(installpath, python_dir), "data", "mecabrc") - else: - if not os.path.exists("TEMP"): - os.mkdir("TEMP") - if not os.path.exists(os.path.join("TEMP", "ko")): - os.mkdir(os.path.join("TEMP", "ko")) - if os.path.exists(os.path.join("TEMP", "ko", "ko_dict")): - shutil.rmtree(os.path.join("TEMP", "ko", "ko_dict")) - - shutil.copytree( - os.path.join(installpath, "data"), os.path.join("TEMP", "ko", "ko_dict") - ) - dicpath = os.path.join("TEMP", "ko", "ko_dict", "mecabrc") - else: - dicpath = os.path.abspath(os.path.join(installpath, "data/mecabrc")) - return dicpath - - def __init__(self, dicpath=get_dicpath(installpath)): - super().__init__(dicpath=dicpath) - - sys.modules["eunjeon"].Mecab = Mecab - - G2p = win_G2p - - -from text.symbols2 import symbols - -# This is a list of Korean classifiers preceded by pure Korean numerals. 
-_korean_classifiers = ( - "군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통" -) - -# List of (hangul, hangul divided) pairs: -_hangul_divided = [ - (re.compile("%s" % x[0]), x[1]) - for x in [ - # ('ㄳ', 'ㄱㅅ'), # g2pk2, A Syllable-ending Rule - # ('ㄵ', 'ㄴㅈ'), - # ('ㄶ', 'ㄴㅎ'), - # ('ㄺ', 'ㄹㄱ'), - # ('ㄻ', 'ㄹㅁ'), - # ('ㄼ', 'ㄹㅂ'), - # ('ㄽ', 'ㄹㅅ'), - # ('ㄾ', 'ㄹㅌ'), - # ('ㄿ', 'ㄹㅍ'), - # ('ㅀ', 'ㄹㅎ'), - # ('ㅄ', 'ㅂㅅ'), - ("ㅘ", "ㅗㅏ"), - ("ㅙ", "ㅗㅐ"), - ("ㅚ", "ㅗㅣ"), - ("ㅝ", "ㅜㅓ"), - ("ㅞ", "ㅜㅔ"), - ("ㅟ", "ㅜㅣ"), - ("ㅢ", "ㅡㅣ"), - ("ㅑ", "ㅣㅏ"), - ("ㅒ", "ㅣㅐ"), - ("ㅕ", "ㅣㅓ"), - ("ㅖ", "ㅣㅔ"), - ("ㅛ", "ㅣㅗ"), - ("ㅠ", "ㅣㅜ"), - ] -] - -# List of (Latin alphabet, hangul) pairs: -_latin_to_hangul = [ - (re.compile("%s" % x[0], re.IGNORECASE), x[1]) - for x in [ - ("a", "에이"), - ("b", "비"), - ("c", "시"), - ("d", "디"), - ("e", "이"), - ("f", "에프"), - ("g", "지"), - ("h", "에이치"), - ("i", "아이"), - ("j", "제이"), - ("k", "케이"), - ("l", "엘"), - ("m", "엠"), - ("n", "엔"), - ("o", "오"), - ("p", "피"), - ("q", "큐"), - ("r", "아르"), - ("s", "에스"), - ("t", "티"), - ("u", "유"), - ("v", "브이"), - ("w", "더블유"), - ("x", "엑스"), - ("y", "와이"), - ("z", "제트"), - ] -] - -# List of (ipa, lazy ipa) pairs: -_ipa_to_lazy_ipa = [ - (re.compile("%s" % x[0], re.IGNORECASE), x[1]) - for x in [ - ("t͡ɕ", "ʧ"), - ("d͡ʑ", "ʥ"), - ("ɲ", "n^"), - ("ɕ", "ʃ"), - ("ʷ", "w"), - ("ɭ", "l`"), - ("ʎ", "ɾ"), - ("ɣ", "ŋ"), - ("ɰ", "ɯ"), - ("ʝ", "j"), - ("ʌ", "ə"), - ("ɡ", "g"), - ("\u031a", "#"), - ("\u0348", "="), - ("\u031e", ""), - ("\u0320", ""), - ("\u0339", ""), - ] -] - - -def fix_g2pk2_error(text): - new_text = "" - i = 0 - while i < len(text) - 4: - if (text[i : i + 3] == "ㅇㅡㄹ" or text[i : i + 3] == "ㄹㅡㄹ") and text[i + 3] == " " and text[i + 4] == "ㄹ": - new_text += text[i : i + 3] + " " + "ㄴ" - i += 5 - else: - new_text += text[i] - i += 1 - - new_text += text[i:] - return new_text - - -def latin_to_hangul(text): - for regex, replacement in _latin_to_hangul: - text = re.sub(regex, replacement, text) - return text - - -def divide_hangul(text): - text = j2hcj(h2j(text)) - for regex, replacement in _hangul_divided: - text = re.sub(regex, replacement, text) - return text - - -def hangul_number(num, sino=True): - """Reference https://github.com/Kyubyong/g2pK""" - num = re.sub(",", "", num) - - if num == "0": - return "영" - if not sino and num == "20": - return "스무" - - digits = "123456789" - names = "일이삼사오육칠팔구" - digit2name = {d: n for d, n in zip(digits, names)} - - modifiers = "한 두 세 네 다섯 여섯 일곱 여덟 아홉" - decimals = "열 스물 서른 마흔 쉰 예순 일흔 여든 아흔" - digit2mod = {d: mod for d, mod in zip(digits, modifiers.split())} - digit2dec = {d: dec for d, dec in zip(digits, decimals.split())} - - spelledout = [] - for i, digit in enumerate(num): - i = len(num) - i - 1 - if sino: - if i == 0: - name = digit2name.get(digit, "") - elif i == 1: - name = digit2name.get(digit, "") + "십" - name = name.replace("일십", "십") - else: - if i == 0: - name = digit2mod.get(digit, "") - elif i == 1: - name = digit2dec.get(digit, "") - if digit == "0": - if i % 4 == 0: - last_three = spelledout[-min(3, len(spelledout)) :] - if "".join(last_three) == "": - spelledout.append("") - continue - else: - spelledout.append("") - continue - if i == 2: - name = digit2name.get(digit, "") + "백" - name = name.replace("일백", "백") - elif i == 3: - name = digit2name.get(digit, "") + "천" - name = name.replace("일천", "천") - elif i == 4: - name = digit2name.get(digit, "") + "만" - name = name.replace("일만", "만") - elif i == 5: - name = digit2name.get(digit, "") + "십" - name = name.replace("일십", 
"십") - elif i == 6: - name = digit2name.get(digit, "") + "백" - name = name.replace("일백", "백") - elif i == 7: - name = digit2name.get(digit, "") + "천" - name = name.replace("일천", "천") - elif i == 8: - name = digit2name.get(digit, "") + "억" - elif i == 9: - name = digit2name.get(digit, "") + "십" - elif i == 10: - name = digit2name.get(digit, "") + "백" - elif i == 11: - name = digit2name.get(digit, "") + "천" - elif i == 12: - name = digit2name.get(digit, "") + "조" - elif i == 13: - name = digit2name.get(digit, "") + "십" - elif i == 14: - name = digit2name.get(digit, "") + "백" - elif i == 15: - name = digit2name.get(digit, "") + "천" - spelledout.append(name) - return "".join(elem for elem in spelledout) - - -def number_to_hangul(text): - """Reference https://github.com/Kyubyong/g2pK""" - tokens = set(re.findall(r"(\d[\d,]*)([\uac00-\ud71f]+)", text)) - for token in tokens: - num, classifier = token - if classifier[:2] in _korean_classifiers or classifier[0] in _korean_classifiers: - spelledout = hangul_number(num, sino=False) - else: - spelledout = hangul_number(num, sino=True) - text = text.replace(f"{num}{classifier}", f"{spelledout}{classifier}") - # digit by digit for remaining digits - digits = "0123456789" - names = "영일이삼사오육칠팔구" - for d, n in zip(digits, names): - text = text.replace(d, n) - return text - - -def korean_to_lazy_ipa(text): - text = latin_to_hangul(text) - text = number_to_hangul(text) - text = re.sub("[\uac00-\ud7af]+", lambda x: ko_pron.romanise(x.group(0), "ipa").split("] ~ [")[0], text) - for regex, replacement in _ipa_to_lazy_ipa: - text = re.sub(regex, replacement, text) - return text - - -_g2p = G2p() - - -def korean_to_ipa(text): - text = latin_to_hangul(text) - text = number_to_hangul(text) - text = _g2p(text) - text = fix_g2pk2_error(text) - text = korean_to_lazy_ipa(text) - return text.replace("ʧ", "tʃ").replace("ʥ", "dʑ") - - -def post_replace_ph(ph): - rep_map = { - ":": ",", - ";": ",", - ",": ",", - "。": ".", - "!": "!", - "?": "?", - "\n": ".", - "·": ",", - "、": ",", - "...": "…", - " ": "空", - } - if ph in rep_map.keys(): - ph = rep_map[ph] - if ph in symbols: - return ph - if ph not in symbols: - ph = "停" - return ph - - -def g2p(text): - text = latin_to_hangul(text) - text = _g2p(text) - text = divide_hangul(text) - text = fix_g2pk2_error(text) - text = re.sub(r"([\u3131-\u3163])$", r"\1.", text) - # text = "".join([post_replace_ph(i) for i in text]) - text = [post_replace_ph(i) for i in text] - return text - - -if __name__ == "__main__": - text = "안녕하세요" - print(g2p(text)) +# reference: https://github.com/ORI-Muchim/MB-iSTFT-VITS-Korean/blob/main/text/korean.py + +import re +from jamo import h2j, j2hcj +import ko_pron +from g2pk2 import G2p + +import importlib +import os + +# 防止win下无法读取模型 +if os.name == "nt": + + class win_G2p(G2p): + def check_mecab(self): + super().check_mecab() + spam_spec = importlib.util.find_spec("eunjeon") + non_found = spam_spec is None + if non_found: + print("you have to install eunjeon. 
install it...") + else: + installpath = spam_spec.submodule_search_locations[0] + if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", installpath)): + import sys + from eunjeon import Mecab as _Mecab + + class Mecab(_Mecab): + def get_dicpath(installpath): + if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", installpath)): + import shutil + + python_dir = os.getcwd() + if installpath[: len(python_dir)].upper() == python_dir.upper(): + dicpath = os.path.join(os.path.relpath(installpath, python_dir), "data", "mecabrc") + else: + if not os.path.exists("TEMP"): + os.mkdir("TEMP") + if not os.path.exists(os.path.join("TEMP", "ko")): + os.mkdir(os.path.join("TEMP", "ko")) + if os.path.exists(os.path.join("TEMP", "ko", "ko_dict")): + shutil.rmtree(os.path.join("TEMP", "ko", "ko_dict")) + + shutil.copytree( + os.path.join(installpath, "data"), os.path.join("TEMP", "ko", "ko_dict") + ) + dicpath = os.path.join("TEMP", "ko", "ko_dict", "mecabrc") + else: + dicpath = os.path.abspath(os.path.join(installpath, "data/mecabrc")) + return dicpath + + def __init__(self, dicpath=get_dicpath(installpath)): + super().__init__(dicpath=dicpath) + + sys.modules["eunjeon"].Mecab = Mecab + + G2p = win_G2p + + +from text.symbols2 import symbols + +# This is a list of Korean classifiers preceded by pure Korean numerals. +_korean_classifiers = ( + "군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통" +) + +# List of (hangul, hangul divided) pairs: +_hangul_divided = [ + (re.compile("%s" % x[0]), x[1]) + for x in [ + # ('ㄳ', 'ㄱㅅ'), # g2pk2, A Syllable-ending Rule + # ('ㄵ', 'ㄴㅈ'), + # ('ㄶ', 'ㄴㅎ'), + # ('ㄺ', 'ㄹㄱ'), + # ('ㄻ', 'ㄹㅁ'), + # ('ㄼ', 'ㄹㅂ'), + # ('ㄽ', 'ㄹㅅ'), + # ('ㄾ', 'ㄹㅌ'), + # ('ㄿ', 'ㄹㅍ'), + # ('ㅀ', 'ㄹㅎ'), + # ('ㅄ', 'ㅂㅅ'), + ("ㅘ", "ㅗㅏ"), + ("ㅙ", "ㅗㅐ"), + ("ㅚ", "ㅗㅣ"), + ("ㅝ", "ㅜㅓ"), + ("ㅞ", "ㅜㅔ"), + ("ㅟ", "ㅜㅣ"), + ("ㅢ", "ㅡㅣ"), + ("ㅑ", "ㅣㅏ"), + ("ㅒ", "ㅣㅐ"), + ("ㅕ", "ㅣㅓ"), + ("ㅖ", "ㅣㅔ"), + ("ㅛ", "ㅣㅗ"), + ("ㅠ", "ㅣㅜ"), + ] +] + +# List of (Latin alphabet, hangul) pairs: +_latin_to_hangul = [ + (re.compile("%s" % x[0], re.IGNORECASE), x[1]) + for x in [ + ("a", "에이"), + ("b", "비"), + ("c", "시"), + ("d", "디"), + ("e", "이"), + ("f", "에프"), + ("g", "지"), + ("h", "에이치"), + ("i", "아이"), + ("j", "제이"), + ("k", "케이"), + ("l", "엘"), + ("m", "엠"), + ("n", "엔"), + ("o", "오"), + ("p", "피"), + ("q", "큐"), + ("r", "아르"), + ("s", "에스"), + ("t", "티"), + ("u", "유"), + ("v", "브이"), + ("w", "더블유"), + ("x", "엑스"), + ("y", "와이"), + ("z", "제트"), + ] +] + +# List of (ipa, lazy ipa) pairs: +_ipa_to_lazy_ipa = [ + (re.compile("%s" % x[0], re.IGNORECASE), x[1]) + for x in [ + ("t͡ɕ", "ʧ"), + ("d͡ʑ", "ʥ"), + ("ɲ", "n^"), + ("ɕ", "ʃ"), + ("ʷ", "w"), + ("ɭ", "l`"), + ("ʎ", "ɾ"), + ("ɣ", "ŋ"), + ("ɰ", "ɯ"), + ("ʝ", "j"), + ("ʌ", "ə"), + ("ɡ", "g"), + ("\u031a", "#"), + ("\u0348", "="), + ("\u031e", ""), + ("\u0320", ""), + ("\u0339", ""), + ] +] + + +def fix_g2pk2_error(text): + new_text = "" + i = 0 + while i < len(text) - 4: + if (text[i : i + 3] == "ㅇㅡㄹ" or text[i : i + 3] == "ㄹㅡㄹ") and text[i + 3] == " " and text[i + 4] == "ㄹ": + new_text += text[i : i + 3] + " " + "ㄴ" + i += 5 + else: + new_text += text[i] + i += 1 + + new_text += text[i:] + return new_text + + +def latin_to_hangul(text): + for regex, replacement in _latin_to_hangul: + text = re.sub(regex, replacement, text) + return text + + +def divide_hangul(text): + text = j2hcj(h2j(text)) + for regex, replacement in _hangul_divided: + text = re.sub(regex, replacement, text) + return text + + +def hangul_number(num, sino=True): + """Reference https://github.com/Kyubyong/g2pK""" + 
num = re.sub(",", "", num) + + if num == "0": + return "영" + if not sino and num == "20": + return "스무" + + digits = "123456789" + names = "일이삼사오육칠팔구" + digit2name = {d: n for d, n in zip(digits, names)} + + modifiers = "한 두 세 네 다섯 여섯 일곱 여덟 아홉" + decimals = "열 스물 서른 마흔 쉰 예순 일흔 여든 아흔" + digit2mod = {d: mod for d, mod in zip(digits, modifiers.split())} + digit2dec = {d: dec for d, dec in zip(digits, decimals.split())} + + spelledout = [] + for i, digit in enumerate(num): + i = len(num) - i - 1 + if sino: + if i == 0: + name = digit2name.get(digit, "") + elif i == 1: + name = digit2name.get(digit, "") + "십" + name = name.replace("일십", "십") + else: + if i == 0: + name = digit2mod.get(digit, "") + elif i == 1: + name = digit2dec.get(digit, "") + if digit == "0": + if i % 4 == 0: + last_three = spelledout[-min(3, len(spelledout)) :] + if "".join(last_three) == "": + spelledout.append("") + continue + else: + spelledout.append("") + continue + if i == 2: + name = digit2name.get(digit, "") + "백" + name = name.replace("일백", "백") + elif i == 3: + name = digit2name.get(digit, "") + "천" + name = name.replace("일천", "천") + elif i == 4: + name = digit2name.get(digit, "") + "만" + name = name.replace("일만", "만") + elif i == 5: + name = digit2name.get(digit, "") + "십" + name = name.replace("일십", "십") + elif i == 6: + name = digit2name.get(digit, "") + "백" + name = name.replace("일백", "백") + elif i == 7: + name = digit2name.get(digit, "") + "천" + name = name.replace("일천", "천") + elif i == 8: + name = digit2name.get(digit, "") + "억" + elif i == 9: + name = digit2name.get(digit, "") + "십" + elif i == 10: + name = digit2name.get(digit, "") + "백" + elif i == 11: + name = digit2name.get(digit, "") + "천" + elif i == 12: + name = digit2name.get(digit, "") + "조" + elif i == 13: + name = digit2name.get(digit, "") + "십" + elif i == 14: + name = digit2name.get(digit, "") + "백" + elif i == 15: + name = digit2name.get(digit, "") + "천" + spelledout.append(name) + return "".join(elem for elem in spelledout) + + +def number_to_hangul(text): + """Reference https://github.com/Kyubyong/g2pK""" + tokens = set(re.findall(r"(\d[\d,]*)([\uac00-\ud71f]+)", text)) + for token in tokens: + num, classifier = token + if classifier[:2] in _korean_classifiers or classifier[0] in _korean_classifiers: + spelledout = hangul_number(num, sino=False) + else: + spelledout = hangul_number(num, sino=True) + text = text.replace(f"{num}{classifier}", f"{spelledout}{classifier}") + # digit by digit for remaining digits + digits = "0123456789" + names = "영일이삼사오육칠팔구" + for d, n in zip(digits, names): + text = text.replace(d, n) + return text + + +def korean_to_lazy_ipa(text): + text = latin_to_hangul(text) + text = number_to_hangul(text) + text = re.sub("[\uac00-\ud7af]+", lambda x: ko_pron.romanise(x.group(0), "ipa").split("] ~ [")[0], text) + for regex, replacement in _ipa_to_lazy_ipa: + text = re.sub(regex, replacement, text) + return text + + +_g2p = G2p() + + +def korean_to_ipa(text): + text = latin_to_hangul(text) + text = number_to_hangul(text) + text = _g2p(text) + text = fix_g2pk2_error(text) + text = korean_to_lazy_ipa(text) + return text.replace("ʧ", "tʃ").replace("ʥ", "dʑ") + + +def post_replace_ph(ph): + rep_map = { + ":": ",", + ";": ",", + ",": ",", + "。": ".", + "!": "!", + "?": "?", + "\n": ".", + "·": ",", + "、": ",", + "...": "…", + " ": "空", + } + if ph in rep_map.keys(): + ph = rep_map[ph] + if ph in symbols: + return ph + if ph not in symbols: + ph = "停" + return ph + + +def g2p(text): + text = latin_to_hangul(text) + text = 
_g2p(text) + text = divide_hangul(text) + text = fix_g2pk2_error(text) + text = re.sub(r"([\u3131-\u3163])$", r"\1.", text) + # text = "".join([post_replace_ph(i) for i in text]) + text = [post_replace_ph(i) for i in text] + return text + + +# Helper for alignment: build phones and word2ph by space-tokenizing Korean text +def g2p_with_word2ph(text, keep_punc=False): + """ + Returns (phones, word2ph) + - Tokenize by spaces; for each non-space token, call g2p(token) + - Exclude separator/placeholders from phones: '.', '空', '停' and common punctuations + - word2ph per token = max(1, num_valid_phones) if keep_punc else skip spaces entirely + """ + # prepare + tokens = re.split(r"(\s+)", text) + phones_all = [] + word2ph = [] + punc_set = {",", ".", "!", "?", "、", "。", ";", ":", ":"} + for tok in tokens: + if tok.strip() == "": + if keep_punc: + word2ph.append(1) + continue + phs = g2p(tok) + # filter non-phonetic markers + phs_valid = [p for p in phs if p not in {'.', '空', '停'} and p not in punc_set] + phones_all.extend(phs_valid) + word2ph.append(max(1, len(phs_valid))) + return phones_all, word2ph + +if __name__ == "__main__": + text = "안녕하세요" + print(g2p(text)) diff --git a/api.py b/api.py index 8a9c36e5..6df9ff77 100644 --- a/api.py +++ b/api.py @@ -1,1686 +1,1808 @@ -""" -# api.py usage - -` python api.py -dr "123.wav" -dt "一二三。" -dl "zh" ` - -## 执行参数: - -`-s` - `SoVITS模型路径, 可在 config.py 中指定` -`-g` - `GPT模型路径, 可在 config.py 中指定` - -调用请求缺少参考音频时使用 -`-dr` - `默认参考音频路径` -`-dt` - `默认参考音频文本` -`-dl` - `默认参考音频语种, "中文","英文","日文","韩文","粤语,"zh","en","ja","ko","yue"` - -`-d` - `推理设备, "cuda","cpu"` -`-a` - `绑定地址, 默认"127.0.0.1"` -`-p` - `绑定端口, 默认9880, 可在 config.py 中指定` -`-fp` - `覆盖 config.py 使用全精度` -`-hp` - `覆盖 config.py 使用半精度` -`-sm` - `流式返回模式, 默认不启用, "close","c", "normal","n", "keepalive","k"` -·-mt` - `返回的音频编码格式, 流式默认ogg, 非流式默认wav, "wav", "ogg", "aac"` -·-st` - `返回的音频数据类型, 默认int16, "int16", "int32"` -·-cp` - `文本切分符号设定, 默认为空, 以",.,。"字符串的方式传入` - -`-hb` - `cnhubert路径` -`-b` - `bert路径` - -## 调用: - -### 推理 - -endpoint: `/` - -使用执行参数指定的参考音频: -GET: - `http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh` -POST: -```json -{ - "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", - "text_language": "zh" -} -``` - -使用执行参数指定的参考音频并设定分割符号: -GET: - `http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&cut_punc=,。` -POST: -```json -{ - "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", - "text_language": "zh", - "cut_punc": ",。", -} -``` - -手动指定当次推理所使用的参考音频: -GET: - `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh` -POST: -```json -{ - "refer_wav_path": "123.wav", - "prompt_text": "一二三。", - "prompt_language": "zh", - "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", - "text_language": "zh" -} -``` - -RESP: -成功: 直接返回 wav 音频流, http code 200 -失败: 返回包含错误信息的 json, http code 400 - -手动指定当次推理所使用的参考音频,并提供参数: -GET: - `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&top_k=20&top_p=0.6&temperature=0.6&speed=1&inp_refs="456.wav"&inp_refs="789.wav"` -POST: -```json -{ - "refer_wav_path": "123.wav", - "prompt_text": "一二三。", - "prompt_language": "zh", - "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", - "text_language": "zh", - "top_k": 20, - "top_p": 0.6, - "temperature": 0.6, - "speed": 1, - "inp_refs": ["456.wav","789.wav"] -} -``` - -RESP: -成功: 直接返回 wav 音频流, http code 200 -失败: 返回包含错误信息的 json, http code 400 - - -### 更换默认参考音频 - 
-endpoint: `/change_refer` - -key与推理端一样 - -GET: - `http://127.0.0.1:9880/change_refer?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh` -POST: -```json -{ - "refer_wav_path": "123.wav", - "prompt_text": "一二三。", - "prompt_language": "zh" -} -``` - -RESP: -成功: json, http code 200 -失败: json, 400 - - -### 命令控制 - -endpoint: `/control` - -command: -"restart": 重新运行 -"exit": 结束运行 - -GET: - `http://127.0.0.1:9880/control?command=restart` -POST: -```json -{ - "command": "restart" -} -``` - -RESP: 无 - -""" - -import argparse -import os -import re -import sys - -now_dir = os.getcwd() -sys.path.append(now_dir) -sys.path.append("%s/GPT_SoVITS" % (now_dir)) - -import signal -from text.LangSegmenter import LangSegmenter -from time import time as ttime -import torch -import torchaudio -import librosa -import soundfile as sf -from fastapi import FastAPI, Request, Query -from fastapi.responses import StreamingResponse, JSONResponse -import uvicorn -from transformers import AutoModelForMaskedLM, AutoTokenizer -import numpy as np -from feature_extractor import cnhubert -from io import BytesIO -from module.models import Generator, SynthesizerTrn, SynthesizerTrnV3 -from peft import LoraConfig, get_peft_model -from AR.models.t2s_lightning_module import Text2SemanticLightningModule -from text import cleaned_text_to_sequence -from text.cleaner import clean_text -from module.mel_processing import spectrogram_torch -import config as global_config -import logging -import subprocess - - -class DefaultRefer: - def __init__(self, path, text, language): - self.path = args.default_refer_path - self.text = args.default_refer_text - self.language = args.default_refer_language - - def is_ready(self) -> bool: - return is_full(self.path, self.text, self.language) - - -def is_empty(*items): # 任意一项不为空返回False - for item in items: - if item is not None and item != "": - return False - return True - - -def is_full(*items): # 任意一项为空返回False - for item in items: - if item is None or item == "": - return False - return True - - -bigvgan_model = hifigan_model = sv_cn_model = None - - -def clean_hifigan_model(): - global hifigan_model - if hifigan_model: - hifigan_model = hifigan_model.cpu() - hifigan_model = None - try: - torch.cuda.empty_cache() - except: - pass - - -def clean_bigvgan_model(): - global bigvgan_model - if bigvgan_model: - bigvgan_model = bigvgan_model.cpu() - bigvgan_model = None - try: - torch.cuda.empty_cache() - except: - pass - - -def clean_sv_cn_model(): - global sv_cn_model - if sv_cn_model: - sv_cn_model.embedding_model = sv_cn_model.embedding_model.cpu() - sv_cn_model = None - try: - torch.cuda.empty_cache() - except: - pass - - -def init_bigvgan(): - global bigvgan_model, hifigan_model, sv_cn_model - from BigVGAN import bigvgan - - bigvgan_model = bigvgan.BigVGAN.from_pretrained( - "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), - use_cuda_kernel=False, - ) # if True, RuntimeError: Ninja is required to load C++ extensions - # remove weight norm in the model and set to eval mode - bigvgan_model.remove_weight_norm() - bigvgan_model = bigvgan_model.eval() - - if is_half == True: - bigvgan_model = bigvgan_model.half().to(device) - else: - bigvgan_model = bigvgan_model.to(device) - - -def init_hifigan(): - global hifigan_model, bigvgan_model, sv_cn_model - hifigan_model = Generator( - initial_channel=100, - resblock="1", - resblock_kernel_sizes=[3, 7, 11], - resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], - upsample_rates=[10, 6, 2, 2, 2], - 
upsample_initial_channel=512, - upsample_kernel_sizes=[20, 12, 4, 4, 4], - gin_channels=0, - is_bias=True, - ) - hifigan_model.eval() - hifigan_model.remove_weight_norm() - state_dict_g = torch.load( - "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), - map_location="cpu", - weights_only=False, - ) - print("loading vocoder", hifigan_model.load_state_dict(state_dict_g)) - if is_half == True: - hifigan_model = hifigan_model.half().to(device) - else: - hifigan_model = hifigan_model.to(device) - - -from sv import SV - - -def init_sv_cn(): - global hifigan_model, bigvgan_model, sv_cn_model - sv_cn_model = SV(device, is_half) - - -resample_transform_dict = {} - - -def resample(audio_tensor, sr0, sr1, device): - global resample_transform_dict - key = "%s-%s-%s" % (sr0, sr1, str(device)) - if key not in resample_transform_dict: - resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device) - return resample_transform_dict[key](audio_tensor) - - -from module.mel_processing import mel_spectrogram_torch - -spec_min = -12 -spec_max = 2 - - -def norm_spec(x): - return (x - spec_min) / (spec_max - spec_min) * 2 - 1 - - -def denorm_spec(x): - return (x + 1) / 2 * (spec_max - spec_min) + spec_min - - -mel_fn = lambda x: mel_spectrogram_torch( - x, - **{ - "n_fft": 1024, - "win_size": 1024, - "hop_size": 256, - "num_mels": 100, - "sampling_rate": 24000, - "fmin": 0, - "fmax": None, - "center": False, - }, -) -mel_fn_v4 = lambda x: mel_spectrogram_torch( - x, - **{ - "n_fft": 1280, - "win_size": 1280, - "hop_size": 320, - "num_mels": 100, - "sampling_rate": 32000, - "fmin": 0, - "fmax": None, - "center": False, - }, -) - - -sr_model = None - - -def audio_sr(audio, sr): - global sr_model - if sr_model == None: - from tools.audio_sr import AP_BWE - - try: - sr_model = AP_BWE(device, DictToAttrRecursive) - except FileNotFoundError: - logger.info("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载") - return audio.cpu().detach().numpy(), sr - return sr_model(audio, sr) - - -class Speaker: - def __init__(self, name, gpt, sovits, phones=None, bert=None, prompt=None): - self.name = name - self.sovits = sovits - self.gpt = gpt - self.phones = phones - self.bert = bert - self.prompt = prompt - - -speaker_list = {} - - -class Sovits: - def __init__(self, vq_model, hps): - self.vq_model = vq_model - self.hps = hps - - -from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new - - -def get_sovits_weights(sovits_path): - from config import pretrained_sovits_name - - path_sovits_v3 = pretrained_sovits_name["v3"] - path_sovits_v4 = pretrained_sovits_name["v4"] - is_exist_s2gv3 = os.path.exists(path_sovits_v3) - is_exist_s2gv4 = os.path.exists(path_sovits_v4) - - version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path) - is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4 - path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4 - - if if_lora_v3 == True and is_exist == False: - logger.info("SoVITS %s 底模缺失,无法加载相应 LoRA 权重" % model_version) - - dict_s2 = load_sovits_new(sovits_path) - hps = dict_s2["config"] - hps = DictToAttrRecursive(hps) - hps.model.semantic_frame_rate = "25hz" - if "enc_p.text_embedding.weight" not in dict_s2["weight"]: - hps.model.version = "v2" # v3model,v2sybomls - elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322: - hps.model.version = "v1" - else: - hps.model.version = "v2" - - model_params_dict = vars(hps.model) - if model_version not in {"v3", "v4"}: - if "Pro" 
in model_version: - hps.model.version = model_version - if sv_cn_model == None: - init_sv_cn() - - vq_model = SynthesizerTrn( - hps.data.filter_length // 2 + 1, - hps.train.segment_size // hps.data.hop_length, - n_speakers=hps.data.n_speakers, - **model_params_dict, - ) - else: - hps.model.version = model_version - vq_model = SynthesizerTrnV3( - hps.data.filter_length // 2 + 1, - hps.train.segment_size // hps.data.hop_length, - n_speakers=hps.data.n_speakers, - **model_params_dict, - ) - if model_version == "v3": - init_bigvgan() - if model_version == "v4": - init_hifigan() - - model_version = hps.model.version - logger.info(f"模型版本: {model_version}") - if "pretrained" not in sovits_path: - try: - del vq_model.enc_q - except: - pass - if is_half == True: - vq_model = vq_model.half().to(device) - else: - vq_model = vq_model.to(device) - vq_model.eval() - if if_lora_v3 == False: - vq_model.load_state_dict(dict_s2["weight"], strict=False) - else: - path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4 - vq_model.load_state_dict(load_sovits_new(path_sovits)["weight"], strict=False) - lora_rank = dict_s2["lora_rank"] - lora_config = LoraConfig( - target_modules=["to_k", "to_q", "to_v", "to_out.0"], - r=lora_rank, - lora_alpha=lora_rank, - init_lora_weights=True, - ) - vq_model.cfm = get_peft_model(vq_model.cfm, lora_config) - vq_model.load_state_dict(dict_s2["weight"], strict=False) - vq_model.cfm = vq_model.cfm.merge_and_unload() - # torch.save(vq_model.state_dict(),"merge_win.pth") - vq_model.eval() - - sovits = Sovits(vq_model, hps) - return sovits - - -class Gpt: - def __init__(self, max_sec, t2s_model): - self.max_sec = max_sec - self.t2s_model = t2s_model - - -global hz -hz = 50 - - -def get_gpt_weights(gpt_path): - dict_s1 = torch.load(gpt_path, map_location="cpu", weights_only=False) - config = dict_s1["config"] - max_sec = config["data"]["max_sec"] - t2s_model = Text2SemanticLightningModule(config, "****", is_train=False) - t2s_model.load_state_dict(dict_s1["weight"]) - if is_half == True: - t2s_model = t2s_model.half() - t2s_model = t2s_model.to(device) - t2s_model.eval() - # total = sum([param.nelement() for param in t2s_model.parameters()]) - # logger.info("Number of parameter: %.2fM" % (total / 1e6)) - - gpt = Gpt(max_sec, t2s_model) - return gpt - - -def change_gpt_sovits_weights(gpt_path, sovits_path): - try: - gpt = get_gpt_weights(gpt_path) - sovits = get_sovits_weights(sovits_path) - except Exception as e: - return JSONResponse({"code": 400, "message": str(e)}, status_code=400) - - speaker_list["default"] = Speaker(name="default", gpt=gpt, sovits=sovits) - return JSONResponse({"code": 0, "message": "Success"}, status_code=200) - - -def get_bert_feature(text, word2ph): - with torch.no_grad(): - inputs = tokenizer(text, return_tensors="pt") - for i in inputs: - inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题,精度随bert_model - res = bert_model(**inputs, output_hidden_states=True) - res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] - assert len(word2ph) == len(text) - phone_level_feature = [] - for i in range(len(word2ph)): - repeat_feature = res[i].repeat(word2ph[i], 1) - phone_level_feature.append(repeat_feature) - phone_level_feature = torch.cat(phone_level_feature, dim=0) - # if(is_half==True):phone_level_feature=phone_level_feature.half() - return phone_level_feature.T - - -def clean_text_inf(text, language, version): - language = language.replace("all_", "") - phones, word2ph, norm_text = clean_text(text, language, version) - phones = 
cleaned_text_to_sequence(phones, version) - return phones, word2ph, norm_text - - -def get_bert_inf(phones, word2ph, norm_text, language): - language = language.replace("all_", "") - if language == "zh": - bert = get_bert_feature(norm_text, word2ph).to(device) # .to(dtype) - else: - bert = torch.zeros( - (1024, len(phones)), - dtype=torch.float16 if is_half == True else torch.float32, - ).to(device) - - return bert - - -from text import chinese - - -def get_phones_and_bert(text, language, version, final=False): - text = re.sub(r' {2,}', ' ', text) - textlist = [] - langlist = [] - if language == "all_zh": - for tmp in LangSegmenter.getTexts(text,"zh"): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "all_yue": - for tmp in LangSegmenter.getTexts(text,"zh"): - if tmp["lang"] == "zh": - tmp["lang"] = "yue" - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "all_ja": - for tmp in LangSegmenter.getTexts(text,"ja"): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "all_ko": - for tmp in LangSegmenter.getTexts(text,"ko"): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "en": - langlist.append("en") - textlist.append(text) - elif language == "auto": - for tmp in LangSegmenter.getTexts(text): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "auto_yue": - for tmp in LangSegmenter.getTexts(text): - if tmp["lang"] == "zh": - tmp["lang"] = "yue" - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - else: - for tmp in LangSegmenter.getTexts(text): - if langlist: - if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"): - textlist[-1] += tmp["text"] - continue - if tmp["lang"] == "en": - langlist.append(tmp["lang"]) - else: - # 因无法区别中日韩文汉字,以用户输入为准 - langlist.append(language) - textlist.append(tmp["text"]) - phones_list = [] - bert_list = [] - norm_text_list = [] - for i in range(len(textlist)): - lang = langlist[i] - phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version) - bert = get_bert_inf(phones, word2ph, norm_text, lang) - phones_list.append(phones) - norm_text_list.append(norm_text) - bert_list.append(bert) - bert = torch.cat(bert_list, dim=1) - phones = sum(phones_list, []) - norm_text = "".join(norm_text_list) - - if not final and len(phones) < 6: - return get_phones_and_bert("." 
+ text, language, version, final=True) - - return phones, bert.to(torch.float16 if is_half == True else torch.float32), norm_text - - -class DictToAttrRecursive(dict): - def __init__(self, input_dict): - super().__init__(input_dict) - for key, value in input_dict.items(): - if isinstance(value, dict): - value = DictToAttrRecursive(value) - self[key] = value - setattr(self, key, value) - - def __getattr__(self, item): - try: - return self[item] - except KeyError: - raise AttributeError(f"Attribute {item} not found") - - def __setattr__(self, key, value): - if isinstance(value, dict): - value = DictToAttrRecursive(value) - super(DictToAttrRecursive, self).__setitem__(key, value) - super().__setattr__(key, value) - - def __delattr__(self, item): - try: - del self[item] - except KeyError: - raise AttributeError(f"Attribute {item} not found") - - -def get_spepc(hps, filename, dtype, device, is_v2pro=False): - sr1 = int(hps.data.sampling_rate) - audio, sr0 = torchaudio.load(filename) - if sr0 != sr1: - audio = audio.to(device) - if audio.shape[0] == 2: - audio = audio.mean(0).unsqueeze(0) - audio = resample(audio, sr0, sr1, device) - else: - audio = audio.to(device) - if audio.shape[0] == 2: - audio = audio.mean(0).unsqueeze(0) - - maxx = audio.abs().max() - if maxx > 1: - audio /= min(2, maxx) - spec = spectrogram_torch( - audio, - hps.data.filter_length, - hps.data.sampling_rate, - hps.data.hop_length, - hps.data.win_length, - center=False, - ) - spec = spec.to(dtype) - if is_v2pro == True: - audio = resample(audio, sr1, 16000, device).to(dtype) - return spec, audio - - -def pack_audio(audio_bytes, data, rate): - if media_type == "ogg": - audio_bytes = pack_ogg(audio_bytes, data, rate) - elif media_type == "aac": - audio_bytes = pack_aac(audio_bytes, data, rate) - else: - # wav无法流式, 先暂存raw - audio_bytes = pack_raw(audio_bytes, data, rate) - - return audio_bytes - - -def pack_ogg(audio_bytes, data, rate): - # Author: AkagawaTsurunaki - # Issue: - # Stack overflow probabilistically occurs - # when the function `sf_writef_short` of `libsndfile_64bit.dll` is called - # using the Python library `soundfile` - # Note: - # This is an issue related to `libsndfile`, not this project itself. - # It happens when you generate a large audio tensor (about 499804 frames in my PC) - # and try to convert it to an ogg file. - # Related: - # https://github.com/RVC-Boss/GPT-SoVITS/issues/1199 - # https://github.com/libsndfile/libsndfile/issues/1023 - # https://github.com/bastibe/python-soundfile/issues/396 - # Suggestion: - # Or split the whole audio data into smaller audio segment to avoid stack overflow? - - def handle_pack_ogg(): - with sf.SoundFile(audio_bytes, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: - audio_file.write(data) - - import threading - - # See: https://docs.python.org/3/library/threading.html - # The stack size of this thread is at least 32768 - # If stack overflow error still occurs, just modify the `stack_size`. - # stack_size = n * 4096, where n should be a positive integer. - # Here we chose n = 4096. - stack_size = 4096 * 4096 - try: - threading.stack_size(stack_size) - pack_ogg_thread = threading.Thread(target=handle_pack_ogg) - pack_ogg_thread.start() - pack_ogg_thread.join() - except RuntimeError as e: - # If changing the thread stack size is unsupported, a RuntimeError is raised. 
- print("RuntimeError: {}".format(e)) - print("Changing the thread stack size is unsupported.") - except ValueError as e: - # If the specified stack size is invalid, a ValueError is raised and the stack size is unmodified. - print("ValueError: {}".format(e)) - print("The specified stack size is invalid.") - - return audio_bytes - - -def pack_raw(audio_bytes, data, rate): - audio_bytes.write(data.tobytes()) - - return audio_bytes - - -def pack_wav(audio_bytes, rate): - if is_int32: - data = np.frombuffer(audio_bytes.getvalue(), dtype=np.int32) - wav_bytes = BytesIO() - sf.write(wav_bytes, data, rate, format="WAV", subtype="PCM_32") - else: - data = np.frombuffer(audio_bytes.getvalue(), dtype=np.int16) - wav_bytes = BytesIO() - sf.write(wav_bytes, data, rate, format="WAV") - return wav_bytes - - -def pack_aac(audio_bytes, data, rate): - if is_int32: - pcm = "s32le" - bit_rate = "256k" - else: - pcm = "s16le" - bit_rate = "128k" - process = subprocess.Popen( - [ - "ffmpeg", - "-f", - pcm, # 输入16位有符号小端整数PCM - "-ar", - str(rate), # 设置采样率 - "-ac", - "1", # 单声道 - "-i", - "pipe:0", # 从管道读取输入 - "-c:a", - "aac", # 音频编码器为AAC - "-b:a", - bit_rate, # 比特率 - "-vn", # 不包含视频 - "-f", - "adts", # 输出AAC数据流格式 - "pipe:1", # 将输出写入管道 - ], - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - out, _ = process.communicate(input=data.tobytes()) - audio_bytes.write(out) - - return audio_bytes - - -def read_clean_buffer(audio_bytes): - audio_chunk = audio_bytes.getvalue() - audio_bytes.truncate(0) - audio_bytes.seek(0) - - return audio_bytes, audio_chunk - - -def cut_text(text, punc): - punc_list = [p for p in punc if p in {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}] - if len(punc_list) > 0: - punds = r"[" + "".join(punc_list) + r"]" - text = text.strip("\n") - items = re.split(f"({punds})", text) - mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])] - # 在句子不存在符号或句尾无符号的时候保证文本完整 - if len(items) % 2 == 1: - mergeitems.append(items[-1]) - text = "\n".join(mergeitems) - - while "\n\n" in text: - text = text.replace("\n\n", "\n") - - return text - - -def only_punc(text): - return not any(t.isalnum() or t.isalpha() for t in text) - - -splits = { - ",", - "。", - "?", - "!", - ",", - ".", - "?", - "!", - "~", - ":", - ":", - "—", - "…", -} - - -def get_tts_wav( - ref_wav_path, - prompt_text, - prompt_language, - text, - text_language, - top_k=15, - top_p=0.6, - temperature=0.6, - speed=1, - inp_refs=None, - sample_steps=32, - if_sr=False, - spk="default", -): - infer_sovits = speaker_list[spk].sovits - vq_model = infer_sovits.vq_model - hps = infer_sovits.hps - version = vq_model.version - - infer_gpt = speaker_list[spk].gpt - t2s_model = infer_gpt.t2s_model - max_sec = infer_gpt.max_sec - - if version == "v3": - if sample_steps not in [4, 8, 16, 32, 64, 128]: - sample_steps = 32 - elif version == "v4": - if sample_steps not in [4, 8, 16, 32]: - sample_steps = 8 - - if if_sr and version != "v3": - if_sr = False - - t0 = ttime() - prompt_text = prompt_text.strip("\n") - if prompt_text[-1] not in splits: - prompt_text += "。" if prompt_language != "en" else "." 
- prompt_language, text = prompt_language, text.strip("\n") - dtype = torch.float16 if is_half == True else torch.float32 - zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32) - with torch.no_grad(): - wav16k, sr = librosa.load(ref_wav_path, sr=16000) - wav16k = torch.from_numpy(wav16k) - zero_wav_torch = torch.from_numpy(zero_wav) - if is_half == True: - wav16k = wav16k.half().to(device) - zero_wav_torch = zero_wav_torch.half().to(device) - else: - wav16k = wav16k.to(device) - zero_wav_torch = zero_wav_torch.to(device) - wav16k = torch.cat([wav16k, zero_wav_torch]) - ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float() - codes = vq_model.extract_latent(ssl_content) - prompt_semantic = codes[0, 0] - prompt = prompt_semantic.unsqueeze(0).to(device) - - is_v2pro = version in {"v2Pro", "v2ProPlus"} - if version not in {"v3", "v4"}: - refers = [] - if is_v2pro: - sv_emb = [] - if sv_cn_model == None: - init_sv_cn() - if inp_refs: - for path in inp_refs: - try: #####这里加上提取sv的逻辑,要么一堆sv一堆refer,要么单个sv单个refer - refer, audio_tensor = get_spepc(hps, path.name, dtype, device, is_v2pro) - refers.append(refer) - if is_v2pro: - sv_emb.append(sv_cn_model.compute_embedding3(audio_tensor)) - except Exception as e: - logger.error(e) - if len(refers) == 0: - refers, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device, is_v2pro) - refers = [refers] - if is_v2pro: - sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)] - else: - refer, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device) - - t1 = ttime() - # os.environ['version'] = version - prompt_language = dict_language[prompt_language.lower()] - text_language = dict_language[text_language.lower()] - phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version) - texts = text.split("\n") - audio_bytes = BytesIO() - - for text in texts: - # 简单防止纯符号引发参考音频泄露 - if only_punc(text): - continue - - audio_opt = [] - if text[-1] not in splits: - text += "。" if text_language != "en" else "." 
- phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version) - bert = torch.cat([bert1, bert2], 1) - - all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0) - bert = bert.to(device).unsqueeze(0) - all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) - t2 = ttime() - with torch.no_grad(): - pred_semantic, idx = t2s_model.model.infer_panel( - all_phoneme_ids, - all_phoneme_len, - prompt, - bert, - # prompt_phone_len=ph_offset, - top_k=top_k, - top_p=top_p, - temperature=temperature, - early_stop_num=hz * max_sec, - ) - pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) - t3 = ttime() - - if version not in {"v3", "v4"}: - if is_v2pro: - o, attn, y_mask = vq_model.decode_with_alignment( - pred_semantic, - torch.LongTensor(phones2).to(device).unsqueeze(0), - refers, - speed=speed, - sv_emb=sv_emb, - ) - audio = o.detach().cpu().numpy()[0, 0] - else: - o, attn, y_mask = vq_model.decode_with_alignment( - pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed - ) - audio = o.detach().cpu().numpy()[0, 0] - else: - phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0) - phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0) - - fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer) - ref_audio, sr = torchaudio.load(ref_wav_path) - ref_audio = ref_audio.to(device).float() - if ref_audio.shape[0] == 2: - ref_audio = ref_audio.mean(0).unsqueeze(0) - - tgt_sr = 24000 if version == "v3" else 32000 - if sr != tgt_sr: - ref_audio = resample(ref_audio, sr, tgt_sr, device) - mel2 = mel_fn(ref_audio) if version == "v3" else mel_fn_v4(ref_audio) - mel2 = norm_spec(mel2) - T_min = min(mel2.shape[2], fea_ref.shape[2]) - mel2 = mel2[:, :, :T_min] - fea_ref = fea_ref[:, :, :T_min] - Tref = 468 if version == "v3" else 500 - Tchunk = 934 if version == "v3" else 1000 - if T_min > Tref: - mel2 = mel2[:, :, -Tref:] - fea_ref = fea_ref[:, :, -Tref:] - T_min = Tref - chunk_len = Tchunk - T_min - mel2 = mel2.to(dtype) - fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed) - cfm_resss = [] - idx = 0 - while 1: - fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len] - if fea_todo_chunk.shape[-1] == 0: - break - idx += chunk_len - fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1) - cfm_res = vq_model.cfm.inference( - fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0 - ) - cfm_res = cfm_res[:, :, mel2.shape[2] :] - mel2 = cfm_res[:, :, -T_min:] - fea_ref = fea_todo_chunk[:, :, -T_min:] - cfm_resss.append(cfm_res) - cfm_res = torch.cat(cfm_resss, 2) - cfm_res = denorm_spec(cfm_res) - if version == "v3": - if bigvgan_model == None: - init_bigvgan() - else: # v4 - if hifigan_model == None: - init_hifigan() - vocoder_model = bigvgan_model if version == "v3" else hifigan_model - with torch.inference_mode(): - wav_gen = vocoder_model(cfm_res) - audio = wav_gen[0][0].cpu().detach().numpy() - - max_audio = np.abs(audio).max() - if max_audio > 1: - audio /= max_audio - audio_opt.append(audio) - audio_opt.append(zero_wav) - audio_opt = np.concatenate(audio_opt, 0) - t4 = ttime() - - if version in {"v1", "v2", "v2Pro", "v2ProPlus"}: - sr = 32000 - elif version == "v3": - sr = 24000 - else: - sr = 48000 # v4 - - if if_sr and sr == 24000: - audio_opt = torch.from_numpy(audio_opt).float().to(device) - audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr) - max_audio = np.abs(audio_opt).max() - if max_audio > 1: - 
audio_opt /= max_audio - sr = 48000 - - if is_int32: - audio_bytes = pack_audio(audio_bytes, (audio_opt * 2147483647).astype(np.int32), sr) - else: - audio_bytes = pack_audio(audio_bytes, (audio_opt * 32768).astype(np.int16), sr) - # logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) - if stream_mode == "normal": - audio_bytes, audio_chunk = read_clean_buffer(audio_bytes) - # For backward compatibility, yield audio chunk only - yield audio_chunk - - if not stream_mode == "normal": - if media_type == "wav": - if version in {"v1", "v2", "v2Pro", "v2ProPlus"}: - sr = 32000 - elif version == "v3": - sr = 48000 if if_sr else 24000 - else: - sr = 48000 # v4 - audio_bytes = pack_wav(audio_bytes, sr) - # extend: for stream_mode!=normal we still return only audio bytes for backward compatibility - yield audio_bytes.getvalue() - - -def handle_control(command): - if command == "restart": - os.execl(g_config.python_exec, g_config.python_exec, *sys.argv) - elif command == "exit": - os.kill(os.getpid(), signal.SIGTERM) - exit(0) - - -def handle_change(path, text, language): - if is_empty(path, text, language): - return JSONResponse( - {"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language"'}, status_code=400 - ) - - if path != "" or path is not None: - default_refer.path = path - if text != "" or text is not None: - default_refer.text = text - if language != "" or language is not None: - default_refer.language = language - - logger.info(f"当前默认参考音频路径: {default_refer.path}") - logger.info(f"当前默认参考音频文本: {default_refer.text}") - logger.info(f"当前默认参考音频语种: {default_refer.language}") - logger.info(f"is_ready: {default_refer.is_ready()}") - - return JSONResponse({"code": 0, "message": "Success"}, status_code=200) - - -def handle( - refer_wav_path, - prompt_text, - prompt_language, - text, - text_language, - cut_punc, - top_k, - top_p, - temperature, - speed, - inp_refs, - sample_steps, - if_sr, -): - if ( - refer_wav_path == "" - or refer_wav_path is None - or prompt_text == "" - or prompt_text is None - or prompt_language == "" - or prompt_language is None - ): - refer_wav_path, prompt_text, prompt_language = ( - default_refer.path, - default_refer.text, - default_refer.language, - ) - if not default_refer.is_ready(): - return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400) - - if cut_punc == None: - text = cut_text(text, default_cut_punc) - else: - text = cut_text(text, cut_punc) - - gen = get_tts_wav( - refer_wav_path, - prompt_text, - prompt_language, - text, - text_language, - top_k, - top_p, - temperature, - speed, - inp_refs, - sample_steps, - if_sr, - ) - # Consume the generator to collect bytes and timestamps; pack as multipart/mixed JSON with audio for compatibility - # For simplicity, keep legacy streaming behaviour - return StreamingResponse(gen, media_type="audio/" + media_type) - - -# -------------------------------- -# 初始化部分 -# -------------------------------- -dict_language = { - "中文": "all_zh", - "粤语": "all_yue", - "英文": "en", - "日文": "all_ja", - "韩文": "all_ko", - "中英混合": "zh", - "粤英混合": "yue", - "日英混合": "ja", - "韩英混合": "ko", - "多语种混合": "auto", # 多语种启动切分识别语种 - "多语种混合(粤语)": "auto_yue", - "all_zh": "all_zh", - "all_yue": "all_yue", - "en": "en", - "all_ja": "all_ja", - "all_ko": "all_ko", - "zh": "zh", - "yue": "yue", - "ja": "ja", - "ko": "ko", - "auto": "auto", - "auto_yue": "auto_yue", -} - -# logger -logging.config.dictConfig(uvicorn.config.LOGGING_CONFIG) -logger = logging.getLogger("uvicorn") - -# 获取配置 -g_config = 
global_config.Config() - -# 获取参数 -parser = argparse.ArgumentParser(description="GPT-SoVITS api") - -parser.add_argument("-s", "--sovits_path", type=str, default=g_config.sovits_path, help="SoVITS模型路径") -parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, help="GPT模型路径") -parser.add_argument("-dr", "--default_refer_path", type=str, default="", help="默认参考音频路径") -parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本") -parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种") -parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu") -parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0") -parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880") -parser.add_argument( - "-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度" -) -parser.add_argument( - "-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度" -) -# bool值的用法为 `python ./api.py -fp ...` -# 此时 full_precision==True, half_precision==False -parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="流式返回模式, close / normal / keepalive") -parser.add_argument("-mt", "--media_type", type=str, default="wav", help="音频编码格式, wav / ogg / aac") -parser.add_argument("-st", "--sub_type", type=str, default="int16", help="音频数据类型, int16 / int32") -parser.add_argument("-cp", "--cut_punc", type=str, default="", help="文本切分符号设定, 符号范围,.;?!、,。?!;:…") -# 切割常用分句符为 `python ./api.py -cp ".?!。?!"` -parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path") -parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path") - -args = parser.parse_args() -sovits_path = args.sovits_path -gpt_path = args.gpt_path -device = args.device -port = args.port -host = args.bind_addr -cnhubert_base_path = args.hubert_path -bert_path = args.bert_path -default_cut_punc = args.cut_punc - -# 应用参数配置 -default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language) - -# 模型路径检查 -if sovits_path == "": - sovits_path = g_config.pretrained_sovits_path - logger.warning(f"未指定SoVITS模型路径, fallback后当前值: {sovits_path}") -if gpt_path == "": - gpt_path = g_config.pretrained_gpt_path - logger.warning(f"未指定GPT模型路径, fallback后当前值: {gpt_path}") - -# 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用 -if default_refer.path == "" or default_refer.text == "" or default_refer.language == "": - default_refer.path, default_refer.text, default_refer.language = "", "", "" - logger.info("未指定默认参考音频") -else: - logger.info(f"默认参考音频路径: {default_refer.path}") - logger.info(f"默认参考音频文本: {default_refer.text}") - logger.info(f"默认参考音频语种: {default_refer.language}") - -# 获取半精度 -is_half = g_config.is_half -if args.full_precision: - is_half = False -if args.half_precision: - is_half = True -if args.full_precision and args.half_precision: - is_half = g_config.is_half # 炒饭fallback -logger.info(f"半精: {is_half}") - -# 流式返回模式 -if args.stream_mode.lower() in ["normal", "n"]: - stream_mode = "normal" - logger.info("流式返回已开启") -else: - stream_mode = "close" - -# 音频编码格式 -if args.media_type.lower() in ["aac", "ogg"]: - media_type = args.media_type.lower() -elif stream_mode == "close": - media_type = "wav" -else: - media_type = "ogg" -logger.info(f"编码格式: {media_type}") - -# 音频数据类型 -if 
args.sub_type.lower() == "int32": - is_int32 = True - logger.info("数据类型: int32") -else: - is_int32 = False - logger.info("数据类型: int16") - -# 初始化模型 -cnhubert.cnhubert_base_path = cnhubert_base_path -tokenizer = AutoTokenizer.from_pretrained(bert_path) -bert_model = AutoModelForMaskedLM.from_pretrained(bert_path) -ssl_model = cnhubert.get_model() -if is_half: - bert_model = bert_model.half().to(device) - ssl_model = ssl_model.half().to(device) -else: - bert_model = bert_model.to(device) - ssl_model = ssl_model.to(device) -change_gpt_sovits_weights(gpt_path=gpt_path, sovits_path=sovits_path) - - -# -------------------------------- -# 接口部分 -# -------------------------------- -app = FastAPI() - - -@app.post("/set_model") -async def set_model(request: Request): - json_post_raw = await request.json() - return change_gpt_sovits_weights( - gpt_path=json_post_raw.get("gpt_model_path"), sovits_path=json_post_raw.get("sovits_model_path") - ) - - -@app.get("/set_model") -async def set_model( - gpt_model_path: str = None, - sovits_model_path: str = None, -): - return change_gpt_sovits_weights(gpt_path=gpt_model_path, sovits_path=sovits_model_path) - - -@app.post("/control") -async def control(request: Request): - json_post_raw = await request.json() - return handle_control(json_post_raw.get("command")) - - -@app.get("/control") -async def control(command: str = None): - return handle_control(command) - - -@app.post("/change_refer") -async def change_refer(request: Request): - json_post_raw = await request.json() - return handle_change( - json_post_raw.get("refer_wav_path"), json_post_raw.get("prompt_text"), json_post_raw.get("prompt_language") - ) - - -@app.get("/change_refer") -async def change_refer(refer_wav_path: str = None, prompt_text: str = None, prompt_language: str = None): - return handle_change(refer_wav_path, prompt_text, prompt_language) - - -@app.post("/") -async def tts_endpoint(request: Request): - json_post_raw = await request.json() - return handle( - json_post_raw.get("refer_wav_path"), - json_post_raw.get("prompt_text"), - json_post_raw.get("prompt_language"), - json_post_raw.get("text"), - json_post_raw.get("text_language"), - json_post_raw.get("cut_punc"), - json_post_raw.get("top_k", 15), - json_post_raw.get("top_p", 1.0), - json_post_raw.get("temperature", 1.0), - json_post_raw.get("speed", 1.0), - json_post_raw.get("inp_refs", []), - json_post_raw.get("sample_steps", 32), - json_post_raw.get("if_sr", False), - ) - - -@app.get("/") -async def tts_endpoint( - refer_wav_path: str = None, - prompt_text: str = None, - prompt_language: str = None, - text: str = None, - text_language: str = None, - cut_punc: str = None, - top_k: int = 15, - top_p: float = 1.0, - temperature: float = 1.0, - speed: float = 1.0, - inp_refs: list = Query(default=[]), - sample_steps: int = 32, - if_sr: bool = False, -): - return handle( - refer_wav_path, - prompt_text, - prompt_language, - text, - text_language, - cut_punc, - top_k, - top_p, - temperature, - speed, - inp_refs, - sample_steps, - if_sr, - ) - - -def _fmt_srt_time(t): - h = int(t // 3600) - m = int((t % 3600) // 60) - s = int(t % 60) - ms = int(round((t - int(t)) * 1000)) - return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}" - - -def _build_norm_to_orig_map(orig, norm): - all_punc_local = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"} - mapping = [-1] * len(norm) - o = 0 - n = 0 - while n < len(norm) and o < len(orig): - if norm[n] == orig[o]: - mapping[n] = o - n += 1 - o += 1 - else: - if orig[o].isspace() or orig[o] in 
all_punc_local: - o += 1 - elif norm[n].isspace() or norm[n] in all_punc_local: - n += 1 - else: - n += 1 - return mapping - - -def synthesize_json( - refer_wav_path, - prompt_text, - prompt_language, - text, - text_language, - top_k, - top_p, - temperature, - speed, - inp_refs, - sample_steps, - if_sr, -): - infer_sovits = speaker_list["default"].sovits - vq_model = infer_sovits.vq_model - hps = infer_sovits.hps - version = vq_model.version - - if version in {"v3", "v4"}: - return JSONResponse({"code": 400, "message": "v3/v4 暂未提供时间戳JSON接口"}, status_code=400) - - # prepare refer features (same as handle) - dtype = torch.float16 if is_half else torch.float32 - zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half else np.float32) - with torch.no_grad(): - wav16k, sr = librosa.load(refer_wav_path, sr=16000) - wav16k = torch.from_numpy(wav16k) - zero_wav_torch = torch.from_numpy(zero_wav) - if is_half: - wav16k = wav16k.half().to(device) - zero_wav_torch = zero_wav_torch.half().to(device) - else: - wav16k = wav16k.to(device) - zero_wav_torch = zero_wav_torch.to(device) - wav16k = torch.cat([wav16k, zero_wav_torch]) - ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) - codes = vq_model.extract_latent(ssl_content) - prompt_semantic = codes[0, 0] - prompt = prompt_semantic.unsqueeze(0).to(device) - - is_v2pro = version in {"v2Pro", "v2ProPlus"} - refers = [] - if is_v2pro: - sv_emb = [] - if sv_cn_model == None: - init_sv_cn() - spec, audio_tensor = get_spepc(hps, refer_wav_path, dtype, device, is_v2pro) - refers = [spec] - if is_v2pro: - sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)] - - # text frontend - prompt_language = dict_language[prompt_language.lower()] - text_language = dict_language[text_language.lower()] - phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version) - texts = text.strip("\n").split("\n") - - # iterate - audio_segments = [] - timestamps_all = [] - sr_hz = int(hps.data.sampling_rate) - elapsed_s = 0.0 - for seg in texts: - if only_punc(seg): - continue - if seg[-1] not in splits: - seg += "。" if text_language != "en" else "." 
- phones2, bert2, norm_text2 = get_phones_and_bert(seg, text_language, version) - bert = torch.cat([bert1, bert2], 1) - all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0) - bert = bert.to(device).unsqueeze(0) - all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) - with torch.no_grad(): - pred_semantic, idx = speaker_list["default"].gpt.t2s_model.model.infer_panel( - all_phoneme_ids, all_phoneme_len, prompt, bert, top_k=top_k, top_p=top_p, temperature=temperature, early_stop_num=hz * speaker_list["default"].gpt.max_sec, - ) - pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) - phones2_tensor = torch.LongTensor(phones2).to(device).unsqueeze(0) - if is_v2pro: - o, attn, y_mask = vq_model.decode_with_alignment( - pred_semantic, phones2_tensor, refers, speed=speed, sv_emb=sv_emb - ) - else: - o, attn, y_mask = vq_model.decode_with_alignment( - pred_semantic, phones2_tensor, refers, speed=speed - ) - audio = o[0][0].detach().cpu().numpy() - # timestamps - frame_time = 0.02 / max(float(speed), 1e-6) - ph_spans = [] - if attn is not None: - attn_mean = attn.mean(dim=1)[0] - assign = attn_mean.argmax(dim=-1) - if assign.numel() > 0: - start_f = 0 - cur_ph = int(assign[0].item()) - for f in range(1, assign.shape[0]): - ph = int(assign[f].item()) - if ph != cur_ph: - ph_spans.append({"phoneme_id": cur_ph, "start_s": start_f * frame_time, "end_s": f * frame_time}) - start_f = f - cur_ph = ph - ph_spans.append({"phoneme_id": cur_ph, "start_s": start_f * frame_time, "end_s": assign.shape[0] * frame_time}) - - # char aggregation - _, word2ph, norm_text_seg = clean_text_inf(seg, text_language, version) - char_spans = [] - if word2ph: - ph_to_char = [] - for ch_idx, repeat in enumerate(word2ph): - ph_to_char += [ch_idx] * repeat - if ph_spans and ph_to_char: - for span in ph_spans: - ph_idx = span["phoneme_id"] - if 0 <= ph_idx < len(ph_to_char): - char_idx = ph_to_char[ph_idx] - char_spans.append({"char_index": char_idx, "char": norm_text_seg[char_idx] if char_idx < len(norm_text_seg) else "", "start_s": span["start_s"], "end_s": span["end_s"]}) - # merge by char_index - if char_spans: - groups = {} - for cs in char_spans: - groups.setdefault(cs["char_index"], []).append(cs) - merged = [] - gap_merge_s = 0.08 - min_dur_s = max(0.015, frame_time) - for ci, lst in groups.items(): - lst = sorted(lst, key=lambda x: x["start_s"]) - cur = None - for it in lst: - if cur is None: - cur = {"char_index": ci, "char": it.get("char", ""), "start_s": it["start_s"], "end_s": it["end_s"]} - else: - if it["start_s"] - cur["end_s"] <= gap_merge_s: - if it["end_s"] > cur["end_s"]: - cur["end_s"] = it["end_s"] - else: - if cur["end_s"] - cur["start_s"] >= min_dur_s: - merged.append(cur) - cur = {"char_index": ci, "char": it.get("char", ""), "start_s": it["start_s"], "end_s": it["end_s"]} - if cur is not None and (cur["end_s"] - cur["start_s"]) >= min_dur_s: - merged.append(cur) - # remap to original input text - norm2orig = _build_norm_to_orig_map(seg, norm_text_seg) - remapped = [] - punc = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"} - for cs in merged: - ci = cs["char_index"] - oi = norm2orig[ci] if ci < len(norm2orig) else -1 - ch_norm = cs.get("char", "") - if oi != -1: - cs["char"] = seg[oi] - remapped.append(cs) - else: - if ch_norm and (ch_norm in punc or ch_norm.isspace()): - continue - remapped.append(cs) - char_spans = sorted(remapped, key=lambda x: x["start_s"]) if remapped else [] - - # offset and record - audio_len_s = float(audio.shape[0]) / 
sr_hz - for coll in (char_spans,): - for d in coll: - d["start_s"] += elapsed_s - d["end_s"] += elapsed_s - timestamps_all.append({ - "segment_index": len(timestamps_all), - "char_spans": char_spans, - "text": norm_text_seg, - "segment_start_s": elapsed_s, - "segment_end_s": elapsed_s + audio_len_s, - }) - elapsed_s += audio_len_s + 0.3 - audio_segments.append(audio) - - # concatenate audio - if len(audio_segments) == 0: - return JSONResponse({"code": 400, "message": "无有效文本"}, status_code=400) - pad = np.zeros(int(sr_hz * 0.3), dtype=audio_segments[0].dtype) - out = [] - for i, a in enumerate(audio_segments): - out.append(a) - if i < len(audio_segments) - 1: - out.append(pad) - audio_np = np.concatenate(out, 0) - mx = np.abs(audio_np).max() - if mx > 1: - audio_np = audio_np / mx - - # srt - srt_lines = [] - idx_counter = 1 - for rec in timestamps_all: - if rec.get("char_spans"): - for c in rec["char_spans"]: - srt_lines.append(str(idx_counter)) - srt_lines.append(f"{_fmt_srt_time(c['start_s'])} --> { _fmt_srt_time(c['end_s'])}") - srt_lines.append(c.get("char", "")) - srt_lines.append("") - idx_counter += 1 - else: - st = rec.get("segment_start_s") - ed = rec.get("segment_end_s") - if st is not None and ed is not None: - srt_lines.append(str(idx_counter)) - srt_lines.append(f"{_fmt_srt_time(st)} --> { _fmt_srt_time(ed)}") - srt_lines.append(rec.get("text", "")) - srt_lines.append("") - idx_counter += 1 - srt_text = "\n".join(srt_lines) - - # pack wav - import base64 - buf = BytesIO() - sf.write(buf, audio_np, sr_hz, format="WAV") - wav_b64 = base64.b64encode(buf.getvalue()).decode("utf-8") - return JSONResponse({"code": 0, "sr": sr_hz, "audio_wav_base64": wav_b64, "timestamps": timestamps_all, "srt": srt_text}) - - -@app.post("/tts_json") -async def tts_json_post(request: Request): - body = await request.json() - return synthesize_json( - body.get("refer_wav_path"), - body.get("prompt_text"), - body.get("prompt_language"), - body.get("text"), - body.get("text_language"), - body.get("top_k", 15), - body.get("top_p", 0.6), - body.get("temperature", 0.6), - body.get("speed", 1.0), - body.get("inp_refs", []), - body.get("sample_steps", 32), - body.get("if_sr", False), - ) - - -@app.get("/tts_json") -async def tts_json_get( - refer_wav_path: str, - prompt_text: str, - prompt_language: str, - text: str, - text_language: str, - top_k: int = 15, - top_p: float = 0.6, - temperature: float = 0.6, - speed: float = 1.0, - inp_refs: list = Query(default=[]), - sample_steps: int = 32, - if_sr: bool = False, -): - return synthesize_json( - refer_wav_path, - prompt_text, - prompt_language, - text, - text_language, - top_k, - top_p, - temperature, - speed, - inp_refs, - sample_steps, - if_sr, - ) - - -if __name__ == "__main__": - uvicorn.run(app, host=host, port=port, workers=1) +""" +# api.py usage + +` python api.py -dr "123.wav" -dt "一二三。" -dl "zh" ` + +## 执行参数: + +`-s` - `SoVITS模型路径, 可在 config.py 中指定` +`-g` - `GPT模型路径, 可在 config.py 中指定` + +调用请求缺少参考音频时使用 +`-dr` - `默认参考音频路径` +`-dt` - `默认参考音频文本` +`-dl` - `默认参考音频语种, "中文","英文","日文","韩文","粤语,"zh","en","ja","ko","yue"` + +`-d` - `推理设备, "cuda","cpu"` +`-a` - `绑定地址, 默认"127.0.0.1"` +`-p` - `绑定端口, 默认9880, 可在 config.py 中指定` +`-fp` - `覆盖 config.py 使用全精度` +`-hp` - `覆盖 config.py 使用半精度` +`-sm` - `流式返回模式, 默认不启用, "close","c", "normal","n", "keepalive","k"` +·-mt` - `返回的音频编码格式, 流式默认ogg, 非流式默认wav, "wav", "ogg", "aac"` +·-st` - `返回的音频数据类型, 默认int16, "int16", "int32"` +·-cp` - `文本切分符号设定, 默认为空, 以",.,。"字符串的方式传入` + +`-hb` - `cnhubert路径` +`-b` - `bert路径` + +## 调用: + +### 推理 + 
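+Note: in addition to the legacy endpoints documented below, this patch adds a JSON inference
+endpoint `/v1/tts_json`, which runs a single non-streaming synthesis and returns the audio together
+with character/word-level timestamps. The request body below is only a sketch: `refer_wav_path` is
+read by the handler, while the remaining fields and the response keys (`audio_wav_base64`,
+`timestamps`, `srt`) are assumed from the field names used elsewhere in this patch.
+
+POST `/v1/tts_json`:
+```json
+{
+    "refer_wav_path": "123.wav",
+    "prompt_text": "一二三。",
+    "prompt_language": "zh",
+    "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
+    "text_language": "zh"
+}
+```
+
+The legacy `/` endpoint below is unchanged:
+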
+endpoint: `/` + +使用执行参数指定的参考音频: +GET: + `http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh` +POST: +```json +{ + "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", + "text_language": "zh" +} +``` + +使用执行参数指定的参考音频并设定分割符号: +GET: + `http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&cut_punc=,。` +POST: +```json +{ + "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", + "text_language": "zh", + "cut_punc": ",。", +} +``` + +手动指定当次推理所使用的参考音频: +GET: + `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh` +POST: +```json +{ + "refer_wav_path": "123.wav", + "prompt_text": "一二三。", + "prompt_language": "zh", + "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", + "text_language": "zh" +} +``` + +RESP: +成功: 直接返回 wav 音频流, http code 200 +失败: 返回包含错误信息的 json, http code 400 + +手动指定当次推理所使用的参考音频,并提供参数: +GET: + `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&top_k=20&top_p=0.6&temperature=0.6&speed=1&inp_refs="456.wav"&inp_refs="789.wav"` +POST: +```json +{ + "refer_wav_path": "123.wav", + "prompt_text": "一二三。", + "prompt_language": "zh", + "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", + "text_language": "zh", + "top_k": 20, + "top_p": 0.6, + "temperature": 0.6, + "speed": 1, + "inp_refs": ["456.wav","789.wav"] +} +``` + +RESP: +成功: 直接返回 wav 音频流, http code 200 +失败: 返回包含错误信息的 json, http code 400 + + +### 更换默认参考音频 + +endpoint: `/change_refer` + +key与推理端一样 + +GET: + `http://127.0.0.1:9880/change_refer?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh` +POST: +```json +{ + "refer_wav_path": "123.wav", + "prompt_text": "一二三。", + "prompt_language": "zh" +} +``` + +RESP: +成功: json, http code 200 +失败: json, 400 + + +### 命令控制 + +endpoint: `/control` + +command: +"restart": 重新运行 +"exit": 结束运行 + +GET: + `http://127.0.0.1:9880/control?command=restart` +POST: +```json +{ + "command": "restart" +} +``` + +RESP: 无 + +""" + +import argparse +import json +import os +import re +import sys + +now_dir = os.getcwd() +sys.path.append(now_dir) +sys.path.append("%s/GPT_SoVITS" % (now_dir)) + +import signal +from text.LangSegmenter import LangSegmenter +from time import time as ttime +import torch +import torchaudio +import librosa +import soundfile as sf +from fastapi import FastAPI, Request, Query +from fastapi.responses import StreamingResponse, JSONResponse +import base64 +import io +import uvicorn +from transformers import AutoModelForMaskedLM, AutoTokenizer +import numpy as np +from feature_extractor import cnhubert +from io import BytesIO +from module.models import Generator, SynthesizerTrn, SynthesizerTrnV3 +from peft import LoraConfig, get_peft_model +from AR.models.t2s_lightning_module import Text2SemanticLightningModule +from text import cleaned_text_to_sequence +from text.cleaner import clean_text +from module.mel_processing import spectrogram_torch +import config as global_config +import logging +import subprocess + + +class DefaultRefer: + def __init__(self, path, text, language): + self.path = args.default_refer_path + self.text = args.default_refer_text + self.language = args.default_refer_language + + def is_ready(self) -> bool: + return is_full(self.path, self.text, self.language) + + +def is_empty(*items): # 任意一项不为空返回False + for item in items: + if item is not None and item != "": + return False + return True + + +def is_full(*items): # 任意一项为空返回False + for item in items: + if item is None or item == 
"": + return False + return True + + +bigvgan_model = hifigan_model = sv_cn_model = None + + +def clean_hifigan_model(): + global hifigan_model + if hifigan_model: + hifigan_model = hifigan_model.cpu() + hifigan_model = None + try: + torch.cuda.empty_cache() + except: + pass + + +def clean_bigvgan_model(): + global bigvgan_model + if bigvgan_model: + bigvgan_model = bigvgan_model.cpu() + bigvgan_model = None + try: + torch.cuda.empty_cache() + except: + pass + + +def clean_sv_cn_model(): + global sv_cn_model + if sv_cn_model: + sv_cn_model.embedding_model = sv_cn_model.embedding_model.cpu() + sv_cn_model = None + try: + torch.cuda.empty_cache() + except: + pass + + +def init_bigvgan(): + global bigvgan_model, hifigan_model, sv_cn_model + from BigVGAN import bigvgan + + bigvgan_model = bigvgan.BigVGAN.from_pretrained( + "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), + use_cuda_kernel=False, + ) # if True, RuntimeError: Ninja is required to load C++ extensions + # remove weight norm in the model and set to eval mode + bigvgan_model.remove_weight_norm() + bigvgan_model = bigvgan_model.eval() + + if is_half == True: + bigvgan_model = bigvgan_model.half().to(device) + else: + bigvgan_model = bigvgan_model.to(device) + + +def init_hifigan(): + global hifigan_model, bigvgan_model, sv_cn_model + hifigan_model = Generator( + initial_channel=100, + resblock="1", + resblock_kernel_sizes=[3, 7, 11], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_rates=[10, 6, 2, 2, 2], + upsample_initial_channel=512, + upsample_kernel_sizes=[20, 12, 4, 4, 4], + gin_channels=0, + is_bias=True, + ) + hifigan_model.eval() + hifigan_model.remove_weight_norm() + state_dict_g = torch.load( + "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), + map_location="cpu", + weights_only=False, + ) + print("loading vocoder", hifigan_model.load_state_dict(state_dict_g)) + if is_half == True: + hifigan_model = hifigan_model.half().to(device) + else: + hifigan_model = hifigan_model.to(device) + + +from sv import SV + + +def init_sv_cn(): + global hifigan_model, bigvgan_model, sv_cn_model + sv_cn_model = SV(device, is_half) + + +resample_transform_dict = {} + + +def resample(audio_tensor, sr0, sr1, device): + global resample_transform_dict + key = "%s-%s-%s" % (sr0, sr1, str(device)) + if key not in resample_transform_dict: + resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device) + return resample_transform_dict[key](audio_tensor) + + +from module.mel_processing import mel_spectrogram_torch + +spec_min = -12 +spec_max = 2 + + +def norm_spec(x): + return (x - spec_min) / (spec_max - spec_min) * 2 - 1 + + +def denorm_spec(x): + return (x + 1) / 2 * (spec_max - spec_min) + spec_min + + +mel_fn = lambda x: mel_spectrogram_torch( + x, + **{ + "n_fft": 1024, + "win_size": 1024, + "hop_size": 256, + "num_mels": 100, + "sampling_rate": 24000, + "fmin": 0, + "fmax": None, + "center": False, + }, +) +mel_fn_v4 = lambda x: mel_spectrogram_torch( + x, + **{ + "n_fft": 1280, + "win_size": 1280, + "hop_size": 320, + "num_mels": 100, + "sampling_rate": 32000, + "fmin": 0, + "fmax": None, + "center": False, + }, +) + + +sr_model = None + + +def audio_sr(audio, sr): + global sr_model + if sr_model == None: + from tools.audio_sr import AP_BWE + + try: + sr_model = AP_BWE(device, DictToAttrRecursive) + except FileNotFoundError: + logger.info("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载") + return audio.cpu().detach().numpy(), sr + return 
sr_model(audio, sr) + + +class Speaker: + def __init__(self, name, gpt, sovits, phones=None, bert=None, prompt=None): + self.name = name + self.sovits = sovits + self.gpt = gpt + self.phones = phones + self.bert = bert + self.prompt = prompt + + +speaker_list = {} + + +class Sovits: + def __init__(self, vq_model, hps): + self.vq_model = vq_model + self.hps = hps + + +from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new + + +def get_sovits_weights(sovits_path): + from config import pretrained_sovits_name + + path_sovits_v3 = pretrained_sovits_name["v3"] + path_sovits_v4 = pretrained_sovits_name["v4"] + is_exist_s2gv3 = os.path.exists(path_sovits_v3) + is_exist_s2gv4 = os.path.exists(path_sovits_v4) + + version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path) + is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4 + path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4 + + if if_lora_v3 == True and is_exist == False: + logger.info("SoVITS %s 底模缺失,无法加载相应 LoRA 权重" % model_version) + + dict_s2 = load_sovits_new(sovits_path) + hps = dict_s2["config"] + hps = DictToAttrRecursive(hps) + hps.model.semantic_frame_rate = "25hz" + if "enc_p.text_embedding.weight" not in dict_s2["weight"]: + hps.model.version = "v2" # v3model,v2sybomls + elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322: + hps.model.version = "v1" + else: + hps.model.version = "v2" + + model_params_dict = vars(hps.model) + if model_version not in {"v3", "v4"}: + if "Pro" in model_version: + hps.model.version = model_version + if sv_cn_model == None: + init_sv_cn() + + vq_model = SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **model_params_dict, + ) + else: + hps.model.version = model_version + vq_model = SynthesizerTrnV3( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **model_params_dict, + ) + if model_version == "v3": + init_bigvgan() + if model_version == "v4": + init_hifigan() + + model_version = hps.model.version + logger.info(f"模型版本: {model_version}") + if "pretrained" not in sovits_path: + try: + del vq_model.enc_q + except: + pass + if is_half == True: + vq_model = vq_model.half().to(device) + else: + vq_model = vq_model.to(device) + vq_model.eval() + if if_lora_v3 == False: + vq_model.load_state_dict(dict_s2["weight"], strict=False) + else: + path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4 + vq_model.load_state_dict(load_sovits_new(path_sovits)["weight"], strict=False) + lora_rank = dict_s2["lora_rank"] + lora_config = LoraConfig( + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + r=lora_rank, + lora_alpha=lora_rank, + init_lora_weights=True, + ) + vq_model.cfm = get_peft_model(vq_model.cfm, lora_config) + vq_model.load_state_dict(dict_s2["weight"], strict=False) + vq_model.cfm = vq_model.cfm.merge_and_unload() + # torch.save(vq_model.state_dict(),"merge_win.pth") + vq_model.eval() + + sovits = Sovits(vq_model, hps) + return sovits + + +class Gpt: + def __init__(self, max_sec, t2s_model): + self.max_sec = max_sec + self.t2s_model = t2s_model + + +global hz +hz = 50 + + +def get_gpt_weights(gpt_path): + dict_s1 = torch.load(gpt_path, map_location="cpu", weights_only=False) + config = dict_s1["config"] + max_sec = config["data"]["max_sec"] + t2s_model = Text2SemanticLightningModule(config, "****", is_train=False) + 
t2s_model.load_state_dict(dict_s1["weight"]) + if is_half == True: + t2s_model = t2s_model.half() + t2s_model = t2s_model.to(device) + t2s_model.eval() + # total = sum([param.nelement() for param in t2s_model.parameters()]) + # logger.info("Number of parameter: %.2fM" % (total / 1e6)) + + gpt = Gpt(max_sec, t2s_model) + return gpt + + +def change_gpt_sovits_weights(gpt_path, sovits_path): + try: + gpt = get_gpt_weights(gpt_path) + sovits = get_sovits_weights(sovits_path) + except Exception as e: + return JSONResponse({"code": 400, "message": str(e)}, status_code=400) + + speaker_list["default"] = Speaker(name="default", gpt=gpt, sovits=sovits) + return JSONResponse({"code": 0, "message": "Success"}, status_code=200) + + +def get_bert_feature(text, word2ph): + with torch.no_grad(): + inputs = tokenizer(text, return_tensors="pt") + for i in inputs: + inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题,精度随bert_model + res = bert_model(**inputs, output_hidden_states=True) + res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + assert len(word2ph) == len(text) + phone_level_feature = [] + for i in range(len(word2ph)): + repeat_feature = res[i].repeat(word2ph[i], 1) + phone_level_feature.append(repeat_feature) + phone_level_feature = torch.cat(phone_level_feature, dim=0) + # if(is_half==True):phone_level_feature=phone_level_feature.half() + return phone_level_feature.T + + +def clean_text_inf(text, language, version): + language = language.replace("all_", "") + phones, word2ph, norm_text = clean_text(text, language, version) + phones = cleaned_text_to_sequence(phones, version) + return phones, word2ph, norm_text + + +def get_bert_inf(phones, word2ph, norm_text, language): + language = language.replace("all_", "") + if language == "zh": + bert = get_bert_feature(norm_text, word2ph).to(device) # .to(dtype) + else: + bert = torch.zeros( + (1024, len(phones)), + dtype=torch.float16 if is_half == True else torch.float32, + ).to(device) + + return bert + + +from text import chinese + + +def get_phones_and_bert(text, language, version, final=False): + text = re.sub(r' {2,}', ' ', text) + textlist = [] + langlist = [] + if language == "all_zh": + for tmp in LangSegmenter.getTexts(text,"zh"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_yue": + for tmp in LangSegmenter.getTexts(text,"zh"): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_ja": + for tmp in LangSegmenter.getTexts(text,"ja"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_ko": + for tmp in LangSegmenter.getTexts(text,"ko"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "en": + langlist.append("en") + textlist.append(text) + elif language == "auto": + for tmp in LangSegmenter.getTexts(text): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "auto_yue": + for tmp in LangSegmenter.getTexts(text): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + else: + for tmp in LangSegmenter.getTexts(text): + if langlist: + if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"): + textlist[-1] += tmp["text"] + continue + if tmp["lang"] == "en": + langlist.append(tmp["lang"]) + else: + # 因无法区别中日韩文汉字,以用户输入为准 + langlist.append(language) + textlist.append(tmp["text"]) + phones_list = [] + bert_list = [] + 
norm_text_list = [] + for i in range(len(textlist)): + lang = langlist[i] + phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version) + bert = get_bert_inf(phones, word2ph, norm_text, lang) + phones_list.append(phones) + norm_text_list.append(norm_text) + bert_list.append(bert) + bert = torch.cat(bert_list, dim=1) + phones = sum(phones_list, []) + norm_text = "".join(norm_text_list) + + if not final and len(phones) < 6: + return get_phones_and_bert("." + text, language, version, final=True) + + return phones, bert.to(torch.float16 if is_half == True else torch.float32), norm_text + + +def _viterbi_monotonic(p: torch.Tensor): + T, N = p.shape + if T < N: + return p.argmax(dim=-1) + eps = 1e-8 + cost = -torch.log(p + eps) + dp = torch.empty((T, N), dtype=cost.dtype, device=cost.device) + prev = torch.zeros((T, N), dtype=torch.uint8, device=cost.device) + dp[0, 0] = cost[0, 0] + if N > 1: + dp[0, 1:] = float("inf") + for t_i in range(1, T): + dp[t_i, 0] = dp[t_i - 1, 0] + cost[t_i, 0] + prev[t_i, 0] = 0 + if N > 1: + stay = dp[t_i - 1, 1:] + cost[t_i, 1:] + move = dp[t_i - 1, :-1] + cost[t_i, 1:] + better_move = move < stay + dp[t_i, 1:] = torch.where(better_move, move, stay) + prev[t_i, 1:] = better_move.to(torch.uint8) + j = N - 1 + assign_bt = torch.empty(T, dtype=torch.long, device=cost.device) + for t_i in range(T - 1, -1, -1): + assign_bt[t_i] = j + if t_i > 0 and prev[t_i, j] == 1: + j = max(j - 1, 0) + return assign_bt + + +def _build_mixed_mappings_for_api(_text: str, _ui_lang: str, _version: str): + _text = re.sub(r' {2,}', ' ', _text) + textlist = [] + langlist = [] + if _ui_lang == "all_zh": + for tmp in LangSegmenter.getTexts(_text, "zh"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif _ui_lang == "all_yue": + for tmp in LangSegmenter.getTexts(_text, "zh"): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif _ui_lang == "all_ja": + for tmp in LangSegmenter.getTexts(_text, "ja"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif _ui_lang == "all_ko": + for tmp in LangSegmenter.getTexts(_text, "ko"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif _ui_lang == "en": + langlist.append("en") + textlist.append(_text) + elif _ui_lang == "auto": + for tmp in LangSegmenter.getTexts(_text): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif _ui_lang == "auto_yue": + for tmp in LangSegmenter.getTexts(_text): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + else: + for tmp in LangSegmenter.getTexts(_text): + if langlist: + if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"): + textlist[-1] += tmp["text"] + continue + if tmp["lang"] == "en": + langlist.append(tmp["lang"]) + else: + langlist.append(_ui_lang) + textlist.append(tmp["text"]) + + ph_to_char = [] + ph_to_word = [] + word_tokens = [] + norm_text_agg = [] + import re as _re + for seg_text, seg_lang in zip(textlist, langlist): + seg_phones, seg_word2ph, seg_norm = clean_text_inf(seg_text, seg_lang, _version) + norm_text_agg.append(seg_norm) + if seg_lang in {"zh", "yue", "ja"} and seg_word2ph: + char_base_idx = len("".join(norm_text_agg[:-1])) + for ch_idx, cnt in enumerate(seg_word2ph): + global_char_idx = char_base_idx + ch_idx + ph_to_char += [global_char_idx] * cnt + ph_to_word += [-1] * len(seg_phones) + elif seg_lang in {"en", "ko"} and seg_word2ph: + tokens_seg = 
[t for t in _re.findall(r"\S+", seg_norm) if not all((c in splits) for c in t)] + base_word_idx = len(word_tokens) + t_idx = 0 + for cnt in seg_word2ph: + if t_idx < len(tokens_seg): + ph_to_word += [base_word_idx + t_idx] * cnt + t_idx += 1 + ph_to_char += [-1] * len(seg_phones) + word_tokens.extend(tokens_seg) + else: + ph_to_char += [-1] * len(seg_phones) + ph_to_word += [-1] * len(seg_phones) + norm_text_agg = "".join(norm_text_agg) + return norm_text_agg, ph_to_char, ph_to_word, word_tokens + +class DictToAttrRecursive(dict): + def __init__(self, input_dict): + super().__init__(input_dict) + for key, value in input_dict.items(): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + self[key] = value + setattr(self, key, value) + + def __getattr__(self, item): + try: + return self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + def __setattr__(self, key, value): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + super(DictToAttrRecursive, self).__setitem__(key, value) + super().__setattr__(key, value) + + def __delattr__(self, item): + try: + del self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + +def get_spepc(hps, filename, dtype, device, is_v2pro=False): + sr1 = int(hps.data.sampling_rate) + audio, sr0 = torchaudio.load(filename) + if sr0 != sr1: + audio = audio.to(device) + if audio.shape[0] == 2: + audio = audio.mean(0).unsqueeze(0) + audio = resample(audio, sr0, sr1, device) + else: + audio = audio.to(device) + if audio.shape[0] == 2: + audio = audio.mean(0).unsqueeze(0) + + maxx = audio.abs().max() + if maxx > 1: + audio /= min(2, maxx) + spec = spectrogram_torch( + audio, + hps.data.filter_length, + hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + center=False, + ) + spec = spec.to(dtype) + if is_v2pro == True: + audio = resample(audio, sr1, 16000, device).to(dtype) + return spec, audio + + +def pack_audio(audio_bytes, data, rate): + if media_type == "ogg": + audio_bytes = pack_ogg(audio_bytes, data, rate) + elif media_type == "aac": + audio_bytes = pack_aac(audio_bytes, data, rate) + else: + # wav无法流式, 先暂存raw + audio_bytes = pack_raw(audio_bytes, data, rate) + + return audio_bytes + + +def pack_ogg(audio_bytes, data, rate): + # Author: AkagawaTsurunaki + # Issue: + # Stack overflow probabilistically occurs + # when the function `sf_writef_short` of `libsndfile_64bit.dll` is called + # using the Python library `soundfile` + # Note: + # This is an issue related to `libsndfile`, not this project itself. + # It happens when you generate a large audio tensor (about 499804 frames in my PC) + # and try to convert it to an ogg file. + # Related: + # https://github.com/RVC-Boss/GPT-SoVITS/issues/1199 + # https://github.com/libsndfile/libsndfile/issues/1023 + # https://github.com/bastibe/python-soundfile/issues/396 + # Suggestion: + # Or split the whole audio data into smaller audio segment to avoid stack overflow? + + def handle_pack_ogg(): + with sf.SoundFile(audio_bytes, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: + audio_file.write(data) + + import threading + + # See: https://docs.python.org/3/library/threading.html + # The stack size of this thread is at least 32768 + # If stack overflow error still occurs, just modify the `stack_size`. + # stack_size = n * 4096, where n should be a positive integer. + # Here we chose n = 4096. 
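+    # 4096 * 4096 bytes = 16 MiB of thread stack, comfortably above the documented 32768-byte minimum.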
+ stack_size = 4096 * 4096 + try: + threading.stack_size(stack_size) + pack_ogg_thread = threading.Thread(target=handle_pack_ogg) + pack_ogg_thread.start() + pack_ogg_thread.join() + except RuntimeError as e: + # If changing the thread stack size is unsupported, a RuntimeError is raised. + print("RuntimeError: {}".format(e)) + print("Changing the thread stack size is unsupported.") + except ValueError as e: + # If the specified stack size is invalid, a ValueError is raised and the stack size is unmodified. + print("ValueError: {}".format(e)) + print("The specified stack size is invalid.") + + return audio_bytes + + +def pack_raw(audio_bytes, data, rate): + audio_bytes.write(data.tobytes()) + + return audio_bytes + + +def pack_wav(audio_bytes, rate): + if is_int32: + data = np.frombuffer(audio_bytes.getvalue(), dtype=np.int32) + wav_bytes = BytesIO() + sf.write(wav_bytes, data, rate, format="WAV", subtype="PCM_32") + else: + data = np.frombuffer(audio_bytes.getvalue(), dtype=np.int16) + wav_bytes = BytesIO() + sf.write(wav_bytes, data, rate, format="WAV") + return wav_bytes + + +def pack_aac(audio_bytes, data, rate): + if is_int32: + pcm = "s32le" + bit_rate = "256k" + else: + pcm = "s16le" + bit_rate = "128k" + process = subprocess.Popen( + [ + "ffmpeg", + "-f", + pcm, # 输入16位有符号小端整数PCM + "-ar", + str(rate), # 设置采样率 + "-ac", + "1", # 单声道 + "-i", + "pipe:0", # 从管道读取输入 + "-c:a", + "aac", # 音频编码器为AAC + "-b:a", + bit_rate, # 比特率 + "-vn", # 不包含视频 + "-f", + "adts", # 输出AAC数据流格式 + "pipe:1", # 将输出写入管道 + ], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + out, _ = process.communicate(input=data.tobytes()) + audio_bytes.write(out) + + return audio_bytes + + +def read_clean_buffer(audio_bytes): + audio_chunk = audio_bytes.getvalue() + audio_bytes.truncate(0) + audio_bytes.seek(0) + + return audio_bytes, audio_chunk + + +def cut_text(text, punc): + punc_list = [p for p in punc if p in {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}] + if len(punc_list) > 0: + punds = r"[" + "".join(punc_list) + r"]" + text = text.strip("\n") + items = re.split(f"({punds})", text) + mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])] + # 在句子不存在符号或句尾无符号的时候保证文本完整 + if len(items) % 2 == 1: + mergeitems.append(items[-1]) + text = "\n".join(mergeitems) + + while "\n\n" in text: + text = text.replace("\n\n", "\n") + + return text + + +def only_punc(text): + return not any(t.isalnum() or t.isalpha() for t in text) + + +splits = { + ",", + "。", + "?", + "!", + ",", + ".", + "?", + "!", + "~", + ":", + ":", + "—", + "…", +} + + +def get_tts_wav( + ref_wav_path, + prompt_text, + prompt_language, + text, + text_language, + top_k=15, + top_p=0.6, + temperature=0.6, + speed=1, + inp_refs=None, + sample_steps=32, + if_sr=False, + spk="default", +): + infer_sovits = speaker_list[spk].sovits + vq_model = infer_sovits.vq_model + hps = infer_sovits.hps + version = vq_model.version + + infer_gpt = speaker_list[spk].gpt + t2s_model = infer_gpt.t2s_model + max_sec = infer_gpt.max_sec + + if version == "v3": + if sample_steps not in [4, 8, 16, 32, 64, 128]: + sample_steps = 32 + elif version == "v4": + if sample_steps not in [4, 8, 16, 32]: + sample_steps = 8 + + if if_sr and version != "v3": + if_sr = False + + t0 = ttime() + prompt_text = prompt_text.strip("\n") + if prompt_text[-1] not in splits: + prompt_text += "。" if prompt_language != "en" else "." 
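+    # strip the target text, pick the compute dtype, and prepare a short silence pad that is
+    # concatenated to the 16 kHz reference before cnhubert/SSL feature extraction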
+ prompt_language, text = prompt_language, text.strip("\n") + dtype = torch.float16 if is_half == True else torch.float32 + zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32) + with torch.no_grad(): + wav16k, sr = librosa.load(ref_wav_path, sr=16000) + wav16k = torch.from_numpy(wav16k) + zero_wav_torch = torch.from_numpy(zero_wav) + if is_half == True: + wav16k = wav16k.half().to(device) + zero_wav_torch = zero_wav_torch.half().to(device) + else: + wav16k = wav16k.to(device) + zero_wav_torch = zero_wav_torch.to(device) + wav16k = torch.cat([wav16k, zero_wav_torch]) + ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float() + codes = vq_model.extract_latent(ssl_content) + prompt_semantic = codes[0, 0] + prompt = prompt_semantic.unsqueeze(0).to(device) + + is_v2pro = version in {"v2Pro", "v2ProPlus"} + if version not in {"v3", "v4"}: + refers = [] + if is_v2pro: + sv_emb = [] + if sv_cn_model == None: + init_sv_cn() + if inp_refs: + for path in inp_refs: + try: #####这里加上提取sv的逻辑,要么一堆sv一堆refer,要么单个sv单个refer + refer, audio_tensor = get_spepc(hps, path.name, dtype, device, is_v2pro) + refers.append(refer) + if is_v2pro: + sv_emb.append(sv_cn_model.compute_embedding3(audio_tensor)) + except Exception as e: + logger.error(e) + if len(refers) == 0: + refers, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device, is_v2pro) + refers = [refers] + if is_v2pro: + sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)] + else: + refer, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device) + + t1 = ttime() + # os.environ['version'] = version + prompt_language = dict_language[prompt_language.lower()] + text_language = dict_language[text_language.lower()] + phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version) + texts = text.split("\n") + audio_bytes = BytesIO() + + for text in texts: + # 简单防止纯符号引发参考音频泄露 + if only_punc(text): + continue + + audio_opt = [] + if text[-1] not in splits: + text += "。" if text_language != "en" else "." 
+ phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version) + bert = torch.cat([bert1, bert2], 1) + + all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0) + bert = bert.to(device).unsqueeze(0) + all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) + t2 = ttime() + with torch.no_grad(): + pred_semantic, idx = t2s_model.model.infer_panel( + all_phoneme_ids, + all_phoneme_len, + prompt, + bert, + # prompt_phone_len=ph_offset, + top_k=top_k, + top_p=top_p, + temperature=temperature, + early_stop_num=hz * max_sec, + ) + pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) + t3 = ttime() + + if version not in {"v3", "v4"}: + if is_v2pro: + o, attn, y_mask = vq_model.decode_with_alignment( + pred_semantic, + torch.LongTensor(phones2).to(device).unsqueeze(0), + refers, + speed=speed, + sv_emb=sv_emb, + ) + audio = o.detach().cpu().numpy()[0, 0] + else: + o, attn, y_mask = vq_model.decode_with_alignment( + pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed + ) + audio = o.detach().cpu().numpy()[0, 0] + else: + phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0) + phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0) + + fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer) + ref_audio, sr = torchaudio.load(ref_wav_path) + ref_audio = ref_audio.to(device).float() + if ref_audio.shape[0] == 2: + ref_audio = ref_audio.mean(0).unsqueeze(0) + + tgt_sr = 24000 if version == "v3" else 32000 + if sr != tgt_sr: + ref_audio = resample(ref_audio, sr, tgt_sr, device) + mel2 = mel_fn(ref_audio) if version == "v3" else mel_fn_v4(ref_audio) + mel2 = norm_spec(mel2) + T_min = min(mel2.shape[2], fea_ref.shape[2]) + mel2 = mel2[:, :, :T_min] + fea_ref = fea_ref[:, :, :T_min] + Tref = 468 if version == "v3" else 500 + Tchunk = 934 if version == "v3" else 1000 + if T_min > Tref: + mel2 = mel2[:, :, -Tref:] + fea_ref = fea_ref[:, :, -Tref:] + T_min = Tref + chunk_len = Tchunk - T_min + mel2 = mel2.to(dtype) + fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed) + cfm_resss = [] + idx = 0 + while 1: + fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len] + if fea_todo_chunk.shape[-1] == 0: + break + idx += chunk_len + fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1) + cfm_res = vq_model.cfm.inference( + fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0 + ) + cfm_res = cfm_res[:, :, mel2.shape[2] :] + mel2 = cfm_res[:, :, -T_min:] + fea_ref = fea_todo_chunk[:, :, -T_min:] + cfm_resss.append(cfm_res) + cfm_res = torch.cat(cfm_resss, 2) + cfm_res = denorm_spec(cfm_res) + if version == "v3": + if bigvgan_model == None: + init_bigvgan() + else: # v4 + if hifigan_model == None: + init_hifigan() + vocoder_model = bigvgan_model if version == "v3" else hifigan_model + with torch.inference_mode(): + wav_gen = vocoder_model(cfm_res) + audio = wav_gen[0][0].cpu().detach().numpy() + + max_audio = np.abs(audio).max() + if max_audio > 1: + audio /= max_audio + audio_opt.append(audio) + audio_opt.append(zero_wav) + audio_opt = np.concatenate(audio_opt, 0) + t4 = ttime() + + if version in {"v1", "v2", "v2Pro", "v2ProPlus"}: + sr = 32000 + elif version == "v3": + sr = 24000 + else: + sr = 48000 # v4 + + if if_sr and sr == 24000: + audio_opt = torch.from_numpy(audio_opt).float().to(device) + audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr) + max_audio = np.abs(audio_opt).max() + if max_audio > 1: + 
audio_opt /= max_audio + sr = 48000 + + if is_int32: + audio_bytes = pack_audio(audio_bytes, (audio_opt * 2147483647).astype(np.int32), sr) + else: + audio_bytes = pack_audio(audio_bytes, (audio_opt * 32768).astype(np.int16), sr) + # logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) + if stream_mode == "normal": + audio_bytes, audio_chunk = read_clean_buffer(audio_bytes) + # For backward compatibility, yield audio chunk only + yield audio_chunk + + if not stream_mode == "normal": + if media_type == "wav": + if version in {"v1", "v2", "v2Pro", "v2ProPlus"}: + sr = 32000 + elif version == "v3": + sr = 48000 if if_sr else 24000 + else: + sr = 48000 # v4 + audio_bytes = pack_wav(audio_bytes, sr) + # extend: for stream_mode!=normal we still return only audio bytes for backward compatibility + yield audio_bytes.getvalue() + + +def handle_control(command): + if command == "restart": + os.execl(g_config.python_exec, g_config.python_exec, *sys.argv) + elif command == "exit": + os.kill(os.getpid(), signal.SIGTERM) + exit(0) + + +def handle_change(path, text, language): + if is_empty(path, text, language): + return JSONResponse( + {"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language"'}, status_code=400 + ) + + if path != "" or path is not None: + default_refer.path = path + if text != "" or text is not None: + default_refer.text = text + if language != "" or language is not None: + default_refer.language = language + + logger.info(f"当前默认参考音频路径: {default_refer.path}") + logger.info(f"当前默认参考音频文本: {default_refer.text}") + logger.info(f"当前默认参考音频语种: {default_refer.language}") + logger.info(f"is_ready: {default_refer.is_ready()}") + + return JSONResponse({"code": 0, "message": "Success"}, status_code=200) + + +def handle( + refer_wav_path, + prompt_text, + prompt_language, + text, + text_language, + cut_punc, + top_k, + top_p, + temperature, + speed, + inp_refs, + sample_steps, + if_sr, +): + if ( + refer_wav_path == "" + or refer_wav_path is None + or prompt_text == "" + or prompt_text is None + or prompt_language == "" + or prompt_language is None + ): + refer_wav_path, prompt_text, prompt_language = ( + default_refer.path, + default_refer.text, + default_refer.language, + ) + if not default_refer.is_ready(): + return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400) + + if cut_punc == None: + text = cut_text(text, default_cut_punc) + else: + text = cut_text(text, cut_punc) + + gen = get_tts_wav( + refer_wav_path, + prompt_text, + prompt_language, + text, + text_language, + top_k, + top_p, + temperature, + speed, + inp_refs, + sample_steps, + if_sr, + ) + # Consume the generator to collect bytes and timestamps; pack as multipart/mixed JSON with audio for compatibility + # For simplicity, keep legacy streaming behaviour + return StreamingResponse(gen, media_type="audio/" + media_type) + + +## /v1/tts_json endpoint moved below after app is initialized + + +# -------------------------------- +# 初始化部分 +# -------------------------------- +dict_language = { + "中文": "all_zh", + "粤语": "all_yue", + "英文": "en", + "日文": "all_ja", + "韩文": "all_ko", + "中英混合": "zh", + "粤英混合": "yue", + "日英混合": "ja", + "韩英混合": "ko", + "多语种混合": "auto", # 多语种启动切分识别语种 + "多语种混合(粤语)": "auto_yue", + "all_zh": "all_zh", + "all_yue": "all_yue", + "en": "en", + "all_ja": "all_ja", + "all_ko": "all_ko", + "zh": "zh", + "yue": "yue", + "ja": "ja", + "ko": "ko", + "auto": "auto", + "auto_yue": "auto_yue", +} + +# logger +logging.config.dictConfig(uvicorn.config.LOGGING_CONFIG) +logger = 
logging.getLogger("uvicorn") + +# 获取配置 +g_config = global_config.Config() + +# 获取参数 +parser = argparse.ArgumentParser(description="GPT-SoVITS api") + +parser.add_argument("-s", "--sovits_path", type=str, default=g_config.sovits_path, help="SoVITS模型路径") +parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, help="GPT模型路径") +parser.add_argument("-dr", "--default_refer_path", type=str, default="", help="默认参考音频路径") +parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本") +parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种") +parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu") +parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0") +parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880") +parser.add_argument( + "-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度" +) +parser.add_argument( + "-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度" +) +# bool值的用法为 `python ./api.py -fp ...` +# 此时 full_precision==True, half_precision==False +parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="流式返回模式, close / normal / keepalive") +parser.add_argument("-mt", "--media_type", type=str, default="wav", help="音频编码格式, wav / ogg / aac") +parser.add_argument("-st", "--sub_type", type=str, default="int16", help="音频数据类型, int16 / int32") +parser.add_argument("-cp", "--cut_punc", type=str, default="", help="文本切分符号设定, 符号范围,.;?!、,。?!;:…") +# 切割常用分句符为 `python ./api.py -cp ".?!。?!"` +parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path") +parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path") + +args = parser.parse_args() +sovits_path = args.sovits_path +gpt_path = args.gpt_path +device = args.device +port = args.port +host = args.bind_addr +cnhubert_base_path = args.hubert_path +bert_path = args.bert_path +default_cut_punc = args.cut_punc + +# 应用参数配置 +default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language) + +# 模型路径检查 +if sovits_path == "": + sovits_path = g_config.pretrained_sovits_path + logger.warning(f"未指定SoVITS模型路径, fallback后当前值: {sovits_path}") +if gpt_path == "": + gpt_path = g_config.pretrained_gpt_path + logger.warning(f"未指定GPT模型路径, fallback后当前值: {gpt_path}") + +# 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用 +if default_refer.path == "" or default_refer.text == "" or default_refer.language == "": + default_refer.path, default_refer.text, default_refer.language = "", "", "" + logger.info("未指定默认参考音频") +else: + logger.info(f"默认参考音频路径: {default_refer.path}") + logger.info(f"默认参考音频文本: {default_refer.text}") + logger.info(f"默认参考音频语种: {default_refer.language}") + +# 获取半精度 +is_half = g_config.is_half +if args.full_precision: + is_half = False +if args.half_precision: + is_half = True +if args.full_precision and args.half_precision: + is_half = g_config.is_half # 炒饭fallback +logger.info(f"半精: {is_half}") + +# 流式返回模式 +if args.stream_mode.lower() in ["normal", "n"]: + stream_mode = "normal" + logger.info("流式返回已开启") +else: + stream_mode = "close" + +# 音频编码格式 +if args.media_type.lower() in ["aac", "ogg"]: + media_type = args.media_type.lower() +elif stream_mode == "close": + media_type = "wav" +else: + media_type = "ogg" 
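+# Note: WAV needs a complete header carrying the final data length, so it is only used when
+# streaming is disabled; with streaming enabled the output falls back to ogg unless
+# aac/ogg was requested explicitly above.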
+logger.info(f"编码格式: {media_type}") + +# 音频数据类型 +if args.sub_type.lower() == "int32": + is_int32 = True + logger.info("数据类型: int32") +else: + is_int32 = False + logger.info("数据类型: int16") + +# 初始化模型 +cnhubert.cnhubert_base_path = cnhubert_base_path +tokenizer = AutoTokenizer.from_pretrained(bert_path) +bert_model = AutoModelForMaskedLM.from_pretrained(bert_path) +ssl_model = cnhubert.get_model() +if is_half: + bert_model = bert_model.half().to(device) + ssl_model = ssl_model.half().to(device) +else: + bert_model = bert_model.to(device) + ssl_model = ssl_model.to(device) +change_gpt_sovits_weights(gpt_path=gpt_path, sovits_path=sovits_path) + + +# -------------------------------- +# 接口部分 +# -------------------------------- +app = FastAPI() + + +@app.post("/set_model") +async def set_model(request: Request): + json_post_raw = await request.json() + return change_gpt_sovits_weights( + gpt_path=json_post_raw.get("gpt_model_path"), sovits_path=json_post_raw.get("sovits_model_path") + ) + + +@app.get("/set_model") +async def set_model( + gpt_model_path: str = None, + sovits_model_path: str = None, +): + return change_gpt_sovits_weights(gpt_path=gpt_model_path, sovits_path=sovits_model_path) + + +@app.post("/control") +async def control(request: Request): + json_post_raw = await request.json() + return handle_control(json_post_raw.get("command")) + + +@app.get("/control") +async def control(command: str = None): + return handle_control(command) + + +@app.post("/change_refer") +async def change_refer(request: Request): + json_post_raw = await request.json() + return handle_change( + json_post_raw.get("refer_wav_path"), json_post_raw.get("prompt_text"), json_post_raw.get("prompt_language") + ) + + +@app.get("/change_refer") +async def change_refer(refer_wav_path: str = None, prompt_text: str = None, prompt_language: str = None): + return handle_change(refer_wav_path, prompt_text, prompt_language) + + +@app.post("/") +async def tts_endpoint(request: Request): + json_post_raw = await request.json() + return handle( + json_post_raw.get("refer_wav_path"), + json_post_raw.get("prompt_text"), + json_post_raw.get("prompt_language"), + json_post_raw.get("text"), + json_post_raw.get("text_language"), + json_post_raw.get("cut_punc"), + json_post_raw.get("top_k", 15), + json_post_raw.get("top_p", 1.0), + json_post_raw.get("temperature", 1.0), + json_post_raw.get("speed", 1.0), + json_post_raw.get("inp_refs", []), + json_post_raw.get("sample_steps", 32), + json_post_raw.get("if_sr", False), + ) + + +@app.get("/") +async def tts_endpoint( + refer_wav_path: str = None, + prompt_text: str = None, + prompt_language: str = None, + text: str = None, + text_language: str = None, + cut_punc: str = None, + top_k: int = 15, + top_p: float = 1.0, + temperature: float = 1.0, + speed: float = 1.0, + inp_refs: list = Query(default=[]), + sample_steps: int = 32, + if_sr: bool = False, +): + return handle( + refer_wav_path, + prompt_text, + prompt_language, + text, + text_language, + cut_punc, + top_k, + top_p, + temperature, + speed, + inp_refs, + sample_steps, + if_sr, + ) + + +@app.post("/v1/tts_json") +async def tts_json_post(request: Request): + data = await request.json() + try: + # 直接进行一次非流式推理并输出音频+时间戳(v2/v2Pro/v2ProPlus带对齐) + infer_sovits = speaker_list["default"].sovits + vq_model = infer_sovits.vq_model + hps = infer_sovits.hps + version = vq_model.version + infer_gpt = speaker_list["default"].gpt + t2s_model = infer_gpt.t2s_model + max_sec = infer_gpt.max_sec + + refer_wav_path = data.get("refer_wav_path") + 
prompt_text = data.get("prompt_text") or "" + prompt_language = dict_language[data.get("prompt_language", "zh").lower()] + text = data.get("text") or "" + text_language = dict_language[data.get("text_language", "zh").lower()] + top_k = int(data.get("top_k", 15)) + top_p = float(data.get("top_p", 1.0)) + temperature = float(data.get("temperature", 1.0)) + speed = float(data.get("speed", 1.0)) + inp_refs = data.get("inp_refs", []) + sample_steps = int(data.get("sample_steps", 32)) + if_sr = bool(data.get("if_sr", False)) + + if is_empty(refer_wav_path, prompt_text, prompt_language): + refer_wav_path, prompt_text, prompt_language = ( + default_refer.path, + default_refer.text, + default_refer.language, + ) + if not default_refer.is_ready(): + return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400) + + prompt_text = prompt_text.strip("\n") + if len(prompt_text) > 0 and prompt_text[-1] not in splits: + prompt_text += "。" if prompt_language != "en" else "." + text = text.strip("\n") + if len(text) > 0 and text[-1] not in splits: + text += "。" if text_language != "en" else "." + + dtype = torch.float16 if is_half == True else torch.float32 + with torch.no_grad(): + wav16k, sr_r = librosa.load(refer_wav_path, sr=16000) + wav16k = torch.from_numpy(wav16k) + zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32) + zero_wav_torch = torch.from_numpy(zero_wav) + if is_half == True: + wav16k = wav16k.half().to(device) + zero_wav_torch = zero_wav_torch.half().to(device) + else: + wav16k = wav16k.to(device) + zero_wav_torch = zero_wav_torch.to(device) + wav16k = torch.cat([wav16k, zero_wav_torch]) + ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) + codes = vq_model.extract_latent(ssl_content) + prompt_semantic = codes[0, 0] + prompt = prompt_semantic.unsqueeze(0).to(device) + + phones1, bert1, _ = get_phones_and_bert(prompt_text, prompt_language, version) + phones2, bert2, norm_text = get_phones_and_bert(text, text_language, version) + bert = torch.cat([bert1, bert2], 1) + all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0) + bert = bert.to(device).unsqueeze(0) + all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) + + with torch.no_grad(): + pred_semantic, idx = t2s_model.model.infer_panel( + all_phoneme_ids, + all_phoneme_len, + prompt, + bert, + top_k=top_k, + top_p=top_p, + temperature=temperature, + early_stop_num=hz * max_sec, + ) + pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) + + is_v2pro = version in {"v2Pro", "v2ProPlus"} + attn = None + if version not in {"v3", "v4"}: + refers = [] + if is_v2pro: + sv_emb = [] + if sv_cn_model == None: + init_sv_cn() + if inp_refs: + for path in inp_refs: + try: + refer, audio_tensor = get_spepc(hps, path, dtype, device, is_v2pro) + refers.append(refer) + if is_v2pro: + sv_emb.append(sv_cn_model.compute_embedding3(audio_tensor)) + except Exception as e: + logger.error(e) + if len(refers) == 0: + refers, audio_tensor = get_spepc(hps, refer_wav_path, dtype, device, is_v2pro) + refers = [refers] + if is_v2pro: + sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)] + phones2_tensor = torch.LongTensor(phones2).to(device).unsqueeze(0) + if is_v2pro: + o, attn, y_mask = vq_model.decode_with_alignment( + pred_semantic, phones2_tensor, refers, speed=speed, sv_emb=sv_emb + ) + else: + o, attn, y_mask = vq_model.decode_with_alignment( + pred_semantic, phones2_tensor, refers, speed=speed + ) + audio = 
o[0][0].detach().cpu().numpy() + else: + phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0) + phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0) + refer, audio_tensor = get_spepc(hps, refer_wav_path, dtype, device) + ref_audio, sr0 = torchaudio.load(refer_wav_path) + ref_audio = ref_audio.to(device).float() + if ref_audio.shape[0] == 2: + ref_audio = ref_audio.mean(0).unsqueeze(0) + tgt_sr = 24000 if version == "v3" else 32000 + if sr0 != tgt_sr: + ref_audio = resample(ref_audio, sr0, tgt_sr, device) + mel2 = mel_fn(ref_audio) if version == "v3" else mel_fn_v4(ref_audio) + mel2 = norm_spec(mel2) + fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer) + T_min = min(mel2.shape[2], fea_ref.shape[2]) + mel2 = mel2[:, :, :T_min] + fea_ref = fea_ref[:, :, :T_min] + Tref = 468 if version == "v3" else 500 + Tchunk = 934 if version == "v3" else 1000 + if T_min > Tref: + mel2 = mel2[:, :, -Tref:] + fea_ref = fea_ref[:, :, -Tref:] + T_min = Tref + chunk_len = Tchunk - T_min + mel2 = mel2.to(dtype) + fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed) + cfm_resss = [] + idx_p = 0 + while True: + fea_todo_chunk = fea_todo[:, :, idx_p : idx_p + chunk_len] + if fea_todo_chunk.shape[-1] == 0: + break + idx_p += chunk_len + fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1) + cfm_res = vq_model.cfm.inference( + fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0 + ) + cfm_res = cfm_res[:, :, mel2.shape[2] :] + mel2 = cfm_res[:, :, -T_min:] + fea_ref = fea_todo_chunk[:, :, -T_min:] + cfm_resss.append(cfm_res) + cfm_res = torch.cat(cfm_resss, 2) + cfm_res = denorm_spec(cfm_res) + if version == "v3": + if bigvgan_model == None: + init_bigvgan() + else: + if hifigan_model == None: + init_hifigan() + vocoder_model = bigvgan_model if version == "v3" else hifigan_model + with torch.inference_mode(): + wav_gen = vocoder_model(cfm_res) + audio = wav_gen[0][0].cpu().detach().numpy() + + max_audio = np.abs(audio).max() + if max_audio > 1: + audio = audio / max_audio + + if version in {"v1", "v2", "v2Pro", "v2ProPlus"}: + sr = 32000 + elif version == "v3": + sr = 24000 + else: + sr = 48000 + if if_sr and sr == 24000: + audio_t = torch.from_numpy(audio).float().to(device) + audio_t, sr = audio_sr(audio_t.unsqueeze(0), sr) + audio = audio_t.cpu().detach().numpy()[0] + max_audio = np.abs(audio).max() + if max_audio > 1: + audio /= max_audio + sr = 48000 + + timestamps_all = [] + if attn is not None: + attn_heads_mean = attn.mean(dim=1)[0] + assign = _viterbi_monotonic(attn_heads_mean) + frame_time = 0.02 / max(speed, 1e-6) + ph_spans = [] + if assign.numel() > 0: + start_f = 0 + cur_ph = int(assign[0].item()) + for f in range(1, assign.shape[0]): + ph = int(assign[f].item()) + if ph != cur_ph: + ph_spans.append({ + "phoneme_id": cur_ph, + "start_s": start_f * frame_time, + "end_s": f * frame_time, + }) + start_f = f + cur_ph = ph + ph_spans.append({ + "phoneme_id": cur_ph, + "start_s": start_f * frame_time, + "end_s": assign.shape[0] * frame_time, + }) + + norm_text_seg, ph_to_char_map, ph_to_word_map, word_tokens = _build_mixed_mappings_for_api(text, text_language, version) + + char_spans = [] + if ph_spans and ph_to_char_map and len(ph_to_char_map) >= 1: + for span in ph_spans: + ph_idx = span["phoneme_id"] + if 0 <= ph_idx < len(ph_to_char_map): + ci = ph_to_char_map[ph_idx] + if ci == -1: + continue + if len(char_spans) == 0 or char_spans[-1]["char_index"] != ci: + char_spans.append({ 
+ "char_index": ci, + "char": norm_text_seg[ci] if ci < len(norm_text_seg) else "", + "start_s": span["start_s"], + "end_s": span["end_s"], + }) + else: + char_spans[-1]["end_s"] = span["end_s"] + word_spans = [] + if ph_spans and ph_to_word_map and len(ph_to_word_map) >= 1: + for span in ph_spans: + ph_idx = span["phoneme_id"] + if 0 <= ph_idx < len(ph_to_word_map): + wi = ph_to_word_map[ph_idx] + if wi == -1: + continue + if len(word_spans) == 0 or word_spans[-1]["word_index"] != wi: + word_spans.append({ + "word_index": wi, + "word": word_tokens[wi] if wi < len(word_tokens) else "", + "start_s": span["start_s"], + "end_s": span["end_s"], + }) + else: + word_spans[-1]["end_s"] = span["end_s"] + + audio_len_s = float(len(audio)) / float(sr) + timestamps_all.append({ + "segment_index": 0, + "phoneme_spans": ph_spans, + "char_spans": char_spans, + "word_spans": word_spans, + "segment_start_s": 0.0, + "segment_end_s": audio_len_s, + "text": norm_text_seg, + }) + + import soundfile as sf + bio = io.BytesIO() + sf.write(bio, (audio * 32767).astype("int16"), sr, format="WAV", subtype="PCM_16") + wav_b64 = base64.b64encode(bio.getvalue()).decode("ascii") + + def _fmt_srt_time(t): + h = int(t // 3600) + m = int((t % 3600) // 60) + s = int(t % 60) + ms = int(round((t - int(t)) * 1000)) + return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}" + + srt_b64 = None + if timestamps_all: + srt_lines = [] + idx_counter = 1 + for rec in timestamps_all: + cs = rec.get("char_spans") or [] + ws = rec.get("word_spans") or [] + entries = [] + for c in cs: + entries.append({"text": c.get("char", ""), "start": c["start_s"], "end": c["end_s"]}) + for w in ws: + entries.append({"text": w.get("word", ""), "start": w["start_s"], "end": w["end_s"]}) + if entries: + entries.sort(key=lambda x: x["start"]) + for e in entries: + srt_lines.append(str(idx_counter)) + srt_lines.append(f"{_fmt_srt_time(e['start'])} --> { _fmt_srt_time(e['end'])}") + srt_lines.append(e["text"]) + srt_lines.append("") + idx_counter += 1 + srt_b64 = base64.b64encode("\n".join(srt_lines).encode("utf-8")).decode("ascii") + + timestamps_json_b64 = base64.b64encode(json.dumps(timestamps_all, ensure_ascii=False).encode("utf-8")).decode("ascii") + + resp = { + "sr": sr, + "audio_wav_base64": wav_b64, + "timestamps": timestamps_all, + "timestamps_json_base64": timestamps_json_b64, + } + if srt_b64: + resp["srt_base64"] = srt_b64 + return JSONResponse({"code": 0, "data": resp}) + except Exception as e: + return JSONResponse({"code": 400, "message": str(e)}, status_code=400) + + +if __name__ == "__main__": + uvicorn.run(app, host=host, port=port, workers=1)
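
Minimal client sketch for the new /v1/tts_json endpoint added above (illustrative only, not part of the patch): it assumes the API is running locally on the default port 9880 and that a default reference audio was configured via -dr/-dt/-dl; the URL, file names, and input text are placeholders, while the request and response fields mirror the handler above.

    import base64
    import requests

    payload = {
        "text": "你好,这是一个带时间戳的合成示例。Hello with timestamps.",
        "text_language": "auto",
        # "refer_wav_path" / "prompt_text" / "prompt_language" may be supplied here
        # instead of relying on the server-side default reference audio.
    }
    resp = requests.post("http://127.0.0.1:9880/v1/tts_json", json=payload).json()
    assert resp.get("code") == 0, resp
    data = resp["data"]

    # synthesized audio (16-bit WAV; sample rate is also reported in data["sr"])
    with open("out.wav", "wb") as f:
        f.write(base64.b64decode(data["audio_wav_base64"]))

    # phoneme/char/word spans in seconds, grouped per segment
    for seg in data["timestamps"]:
        for w in seg["word_spans"]:
            print(f'{w["word"]}\t{w["start_s"]:.3f}\t{w["end_s"]:.3f}')

    # optional SRT subtitle built from the char/word spans
    if "srt_base64" in data:
        with open("out.srt", "wb") as f:
            f.write(base64.b64decode(data["srt_base64"]))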