Supports phoneme- and word-level timestamp output for multilingual text

白菜工厂1145号员工 2025-09-22 07:43:59 +08:00
parent 2ff9a1533f
commit 19ca3f3f6a
6 changed files with 2473 additions and 2169 deletions

GPT_SoVITS/inference_webui.py

@@ -932,7 +932,40 @@ def get_tts_wav(
# derive phoneme-level timestamps (20ms per frame at 32kHz, hop=640)
if last_attn is not None:
    attn_heads_mean = last_attn.mean(dim=1)[0]  # [T_ssl, T_text]

    # Viterbi monotonic alignment (stay or advance by 1)
    def _viterbi_monotonic(p):
        T, N = p.shape
        if T < N:
            return p.argmax(dim=-1)
        eps = 1e-8
        cost = -torch.log(p + eps)
        dp = torch.empty((T, N), dtype=cost.dtype, device=cost.device)
        prev = torch.zeros((T, N), dtype=torch.uint8, device=cost.device)
        dp[0, 0] = cost[0, 0]
        if N > 1:
            dp[0, 1:] = float("inf")
        for t_i in range(1, T):
            # j=0: only stay
            dp[t_i, 0] = dp[t_i - 1, 0] + cost[t_i, 0]
            prev[t_i, 0] = 0
            if N > 1:
                stay = dp[t_i - 1, 1:] + cost[t_i, 1:]
                move = dp[t_i - 1, :-1] + cost[t_i, 1:]
                better_move = move < stay
                dp[t_i, 1:] = torch.where(better_move, move, stay)
                prev[t_i, 1:] = better_move.to(torch.uint8)
        # backtrack from (T-1, N-1)
        j = N - 1
        assign_bt = torch.empty(T, dtype=torch.long, device=cost.device)
        for t_i in range(T - 1, -1, -1):
            assign_bt[t_i] = j
            if t_i > 0 and prev[t_i, j] == 1:
                j = j - 1
                if j < 0:
                    j = 0
        return assign_bt

    assign = _viterbi_monotonic(attn_heads_mean)  # [T_ssl]
    frame_time = 0.02 / max(speed, 1e-6)
    # collapse consecutive frames pointing to same phoneme id
    ph_spans = []
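The _viterbi_monotonic pass replaces the previous per-frame argmax over the attention so the frame-to-phoneme assignment can never move backwards. A minimal sketch (illustrative only, not part of this commit) of the difference on a toy attention matrix:

# Toy example: 5 SSL frames attending over 3 phonemes.
import torch

attn = torch.tensor([
    [0.8, 0.1, 0.1],
    [0.2, 0.7, 0.1],
    [0.6, 0.3, 0.1],   # a plain argmax jumps back to phoneme 0 here
    [0.1, 0.8, 0.1],
    [0.1, 0.2, 0.7],
])
print(attn.argmax(dim=-1))  # tensor([0, 1, 0, 1, 2]) -- non-monotonic
# _viterbi_monotonic(attn) may only stay or advance by one phoneme per frame,
# so it returns the non-decreasing path tensor([0, 1, 1, 1, 2]).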
@@ -954,122 +987,186 @@ def get_tts_wav(
        "start_s": start_f * frame_time,
        "end_s": assign.shape[0] * frame_time,
    })

# char/word aggregation for multi-language text
def _build_mixed_mappings(_text, _ui_lang, _version):
    # replicate segmentation logic from get_phones_and_bert
    _text = re.sub(r' {2,}', ' ', _text)
    textlist = []
    langlist = []
    if _ui_lang == "all_zh":
        for tmp in LangSegmenter.getTexts(_text, "zh"):
            langlist.append(tmp["lang"])
            textlist.append(tmp["text"])
    elif _ui_lang == "all_yue":
        for tmp in LangSegmenter.getTexts(_text, "zh"):
            if tmp["lang"] == "zh":
                tmp["lang"] = "yue"
            langlist.append(tmp["lang"])
            textlist.append(tmp["text"])
    elif _ui_lang == "all_ja":
        for tmp in LangSegmenter.getTexts(_text, "ja"):
            langlist.append(tmp["lang"])
            textlist.append(tmp["text"])
    elif _ui_lang == "all_ko":
        for tmp in LangSegmenter.getTexts(_text, "ko"):
            langlist.append(tmp["lang"])
            textlist.append(tmp["text"])
    elif _ui_lang == "en":
        langlist.append("en")
        textlist.append(_text)
    elif _ui_lang == "auto":
        for tmp in LangSegmenter.getTexts(_text):
            langlist.append(tmp["lang"])
            textlist.append(tmp["text"])
    elif _ui_lang == "auto_yue":
        for tmp in LangSegmenter.getTexts(_text):
            if tmp["lang"] == "zh":
                tmp["lang"] = "yue"
            langlist.append(tmp["lang"])
            textlist.append(tmp["text"])
    else:
        for tmp in LangSegmenter.getTexts(_text):
            if langlist:
                if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
                    textlist[-1] += tmp["text"]
                    continue
            if tmp["lang"] == "en":
                langlist.append(tmp["lang"])
            else:
                langlist.append(_ui_lang)
            textlist.append(tmp["text"])

    # aggregate mappings
    ph_to_char = []
    ph_to_word = []
    word_tokens = []
    norm_text_agg = []
    import re as _re

    for seg_text, seg_lang in zip(textlist, langlist):
        seg_phones, seg_word2ph, seg_norm = clean_text_inf(seg_text, seg_lang, _version)
        norm_text_agg.append(seg_norm)
        if seg_lang in {"zh", "yue", "ja"} and seg_word2ph:
            # char-based
            char_base_idx = len("".join(norm_text_agg[:-1]))
            for ch_idx, cnt in enumerate(seg_word2ph):
                global_char_idx = char_base_idx + ch_idx
                ph_to_char += [global_char_idx] * cnt
            ph_to_word += [-1] * len(seg_phones)
        elif seg_lang in {"en", "ko"} and seg_word2ph:
            # word-based
            tokens_seg = [t for t in _re.findall(r"\S+", seg_norm) if not all((c in punctuation) for c in t)]
            base_word_idx = len(word_tokens)
            t_idx = 0
            for cnt in seg_word2ph:
                if t_idx < len(tokens_seg):
                    ph_to_word += [base_word_idx + t_idx] * cnt
                    t_idx += 1
            ph_to_char += [-1] * len(seg_phones)
            word_tokens.extend(tokens_seg)
        else:
            # unknown mapping; fill with -1
            ph_to_char += [-1] * len(seg_phones)
            ph_to_word += [-1] * len(seg_phones)
    norm_text_agg = "".join(norm_text_agg)
    return norm_text_agg, ph_to_char, ph_to_word, word_tokens

norm_text_seg, ph_to_char_map, ph_to_word_map, word_tokens = _build_mixed_mappings(text, text_language, version)
# char spans (for zh/yue/ja parts only)
char_spans = []
if ph_spans and ph_to_char_map and len(ph_to_char_map) >= 1:
    for span in ph_spans:
        ph_idx = span["phoneme_id"]
        if 0 <= ph_idx < len(ph_to_char_map):
            ci = ph_to_char_map[ph_idx]
            if ci == -1:
                continue
            if len(char_spans) == 0 or char_spans[-1]["char_index"] != ci:
                char_spans.append({
                    "char_index": ci,
                    "char": norm_text_seg[ci] if ci < len(norm_text_seg) else "",
                    "start_s": span["start_s"],
                    "end_s": span["end_s"],
                })
            else:
                char_spans[-1]["end_s"] = span["end_s"]
# post-merge and remap for chars
if char_spans:
    groups = {}
    for cs in char_spans:
        groups.setdefault(cs["char_index"], []).append(cs)
    merged = []
    gap_merge_s = 0.08
    min_dur_s = max(0.015, frame_time)  # at least one frame, but not below 15ms
    for ci, lst in groups.items():
        lst = sorted(lst, key=lambda x: x["start_s"])
        cur = None
        for it in lst:
            if cur is None:
                cur = {"char_index": ci, "char": it.get("char", ""), "start_s": it["start_s"], "end_s": it["end_s"]}
            else:
                if it["start_s"] - cur["end_s"] <= gap_merge_s:
                    if it["end_s"] > cur["end_s"]:
                        cur["end_s"] = it["end_s"]
                else:
                    if cur["end_s"] - cur["start_s"] >= min_dur_s:
                        merged.append(cur)
                    cur = {"char_index": ci, "char": it.get("char", ""), "start_s": it["start_s"], "end_s": it["end_s"]}
        if cur is not None and (cur["end_s"] - cur["start_s"]) >= min_dur_s:
            merged.append(cur)
    char_spans = sorted(merged, key=lambda x: x["start_s"])

    # remap normalized chars back to the original input text to avoid spurious '.'/'?'
    def _build_norm_to_orig_map(orig, norm):
        all_punc_local = set(punctuation).union(set(splits))
        mapping = [-1] * len(norm)
        o = 0
        n = 0
        while n < len(norm) and o < len(orig):
            if norm[n] == orig[o]:
                mapping[n] = o
                n += 1
                o += 1
            else:
                # skip spaces/punctuation on either side
                if orig[o].isspace() or orig[o] in all_punc_local:
                    o += 1
                elif norm[n].isspace() or norm[n] in all_punc_local:
                    n += 1
                else:
                    n += 1
        return mapping

    norm2orig = _build_norm_to_orig_map(text, norm_text_seg)
    all_punc_local = set(punctuation).union(set(splits))
    remapped = []
    for cs in char_spans:
        ci = cs["char_index"]
        ch_norm = cs.get("char", "")
        oi = norm2orig[ci] if ci < len(norm2orig) else -1
        if oi != -1:
            cs["char"] = text[oi]
            remapped.append(cs)
        else:
            # drop normalized-only punctuation/spaces not present in the original
            if ch_norm and (ch_norm in all_punc_local or ch_norm.isspace()):
                continue
            remapped.append(cs)
    char_spans = remapped
# word spans (for en/ko parts only)
word_spans = []
if ph_spans and ph_to_word_map and len(ph_to_word_map) >= 1:
    for span in ph_spans:
        ph_idx = span["phoneme_id"]
        if 0 <= ph_idx < len(ph_to_word_map):
            wi = ph_to_word_map[ph_idx]
            if wi == -1:
                continue
            if len(word_spans) == 0 or word_spans[-1]["word_index"] != wi:
                word_spans.append({
                    "word_index": wi,
                    "word": word_tokens[wi] if wi < len(word_tokens) else "",
                    "start_s": span["start_s"],
                    "end_s": span["end_s"],
                })
            else:
                word_spans[-1]["end_s"] = span["end_s"]
# add absolute offsets and record segment timing
audio_len_s = float(audio.shape[0]) / sr_hz
if 'ph_spans' in locals() and ph_spans:
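To make the mapping concrete, here is a small hypothetical illustration (segment contents and counts invented, not taken from the commit) of how per-segment word2ph counts expand into the global phoneme-to-char and phoneme-to-word maps that the span builders above consume:

# Hypothetical two-segment text: zh "你好" + en "hello".
# zh segment: word2ph = [2, 2] -> one entry per character
# en segment: word2ph = [4]    -> one entry per word token ("hello")
ph_to_char = [0, 0, 1, 1] + [-1] * 4   # en phones carry no char index
ph_to_word = [-1] * 4 + [0] * 4        # zh phones carry no word index
word_tokens = ["hello"]
# A phoneme span with phoneme_id == 2 maps to char index 1 ("好"), one with
# phoneme_id == 5 maps to word index 0 ("hello"); consecutive spans that hit
# the same index are merged into a single char/word span.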
@@ -1185,31 +1282,24 @@ def get_tts_wav(
srt_lines = []
idx_counter = 1
for rec in timestamps_all:
    cs = rec.get("char_spans") or []
    ws = rec.get("word_spans") or []
    entries = []
    # unified timeline: merge whichever spans exist, then emit them sorted by start time
    for c in cs:
        entries.append({"text": c.get("char", ""), "start": c["start_s"], "end": c["end_s"]})
    for w in ws:
        entries.append({"text": w.get("word", ""), "start": w["start_s"], "end": w["end_s"]})
    if entries:
        entries.sort(key=lambda x: x["start"])
        for e in entries:
            srt_lines.append(str(idx_counter))
            srt_lines.append(f"{_fmt_srt_time(e['start'])} --> {_fmt_srt_time(e['end'])}")
            srt_lines.append(e["text"])
            srt_lines.append("")
            idx_counter += 1
        continue
    # fallback: one entry for the whole segment
    st = rec.get("segment_start_s")
    ed = rec.get("segment_end_s")
    text_line = rec.get("text", "")
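_fmt_srt_time is defined elsewhere in this file and does not appear in the diff. A minimal sketch of such a helper, assuming it takes seconds as a float and produces the SRT HH:MM:SS,mmm form:

def _fmt_srt_time(seconds: float) -> str:
    # SRT timestamps use "HH:MM:SS,mmm" with a comma before the milliseconds
    ms = int(round(max(seconds, 0.0) * 1000))
    h, rem = divmod(ms, 3_600_000)
    m, rem = divmod(rem, 60_000)
    s, ms = divmod(rem, 1000)
    return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"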
@@ -1228,8 +1318,17 @@ def get_tts_wav(
except Exception:
    srt_path = None
# Also write JSON timestamps to a temp file for download
json_path = None
try:
    with tempfile.NamedTemporaryFile(delete=False, suffix=".json", mode="w", encoding="utf-8") as fjson:
        json.dump(timestamps_all, fjson, ensure_ascii=False, indent=2)
        json_path = fjson.name
except Exception:
    json_path = None
# Return audio, timestamps, SRT path and JSON path for the UI
yield (opt_sr, (audio_opt * 32767).astype(np.int16)), timestamps_all, srt_path, json_path
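The downloadable JSON is timestamps_all serialized as-is, one record per synthesized segment. A hypothetical, abridged record (key names taken from the span dictionaries above; values invented, and the "ph_spans" key name is assumed from the variable of the same name):

{
    "text": "你好 hello",
    "segment_start_s": 0.0,
    "segment_end_s": 1.42,
    "ph_spans": [{"phoneme_id": 0, "start_s": 0.0, "end_s": 0.12}],
    "char_spans": [{"char_index": 0, "char": "你", "start_s": 0.0, "end_s": 0.22}],
    "word_spans": [{"word_index": 0, "word": "hello", "start_s": 0.55, "end_s": 1.1}]
}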
def split(todo_text):
@@ -1516,8 +1615,10 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
with gr.Row():
    inference_button = gr.Button(value=i18n("合成语音"), variant="primary", size="lg", scale=25)
    output = gr.Audio(label=i18n("输出的语音"), scale=14)
with gr.Accordion(i18n("时间戳(音素/字/词) - 点击展开"), open=False):
    timestamps_box = gr.JSON(label=i18n("时间戳JSON预览"))
srt_file = gr.File(label=i18n("下载SRT字幕"))
json_file = gr.File(label=i18n("下载时间戳JSON"))

inference_button.click(
    get_tts_wav,
@@ -1539,7 +1640,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
        if_sr_Checkbox,
        pause_second_slider,
    ],
    [output, timestamps_box, srt_file, json_file],
)
SoVITS_dropdown.change(
    change_sovits_weights,

GPT_SoVITS/text/cleaner.py

@@ -39,18 +39,24 @@ def clean_text(text, language, version=None):
    norm_text = language_module.text_normalize(text)
else:
    norm_text = text
if language == "zh" or language == "yue":  ##########
    phones, word2ph = language_module.g2p(norm_text)
    assert len(phones) == sum(word2ph)
    assert len(norm_text) == len(word2ph)
else:
    # Try per-language word2ph helpers
    if hasattr(language_module, "g2p_with_word2ph"):
        try:
            phones, word2ph = language_module.g2p_with_word2ph(norm_text, keep_punc=False)
        except Exception:
            phones = language_module.g2p(norm_text)
            word2ph = None
    else:
        phones = language_module.g2p(norm_text)
        word2ph = None
    if language == "en" and len(phones) < 4:
        phones = [","] + phones
phones = ["UNK" if ph not in symbols else ph for ph in phones]
return phones, word2ph, norm_text
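A quick usage sketch of the new dispatch (illustrative only; assumes the GPT_SoVITS directory is on sys.path and the language assets are installed). When a language module provides g2p_with_word2ph, word2ph holds per-token phone counts that the alignment code above consumes; when the helper is missing or raises, word2ph stays None and only segment-level timing is possible:

from text.cleaner import clean_text

phones, word2ph, norm_text = clean_text("hello world", "en")
print(norm_text)
if word2ph is not None:
    # per-token phone counts, expected to cover the phone sequence
    print(len(phones), sum(word2ph))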

GPT_SoVITS/text/english.py

@@ -368,6 +368,32 @@ def g2p(text):
    return replace_phs(phones)


def g2p_with_word2ph(text, keep_punc=False):
    """
    Returns (phones, word2ph) for English by whitespace tokenization.
    - Tokenize by spaces; for each token, call g2p(token)
    - word2ph per token = max(1, num_valid_phones); pure-punctuation tokens count as 1 if keep_punc, else they are skipped
    """
    tokens = re.split(r"(\s+)", text)
    phones_all = []
    word2ph = []
    punc_set = set(punctuation)
    for tok in tokens:
        if tok.strip() == "":
            if keep_punc:
                word2ph.append(1)
            continue
        if all((c in punc_set or c.isspace()) for c in tok):
            if keep_punc:
                word2ph.append(1)
            continue
        phs = g2p(tok)
        phs_valid = [p for p in phs if p not in punc_set]
        phones_all.extend(phs_valid)
        word2ph.append(max(1, len(phs_valid)))
    return phones_all, word2ph


if __name__ == "__main__":
    print(g2p("hello"))
    print(g2p(text_normalize("e.g. I used openai's AI tool to draw a picture.")))
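A usage sketch (requires the repo's English G2P dependencies; the printed counts are illustrative and depend on the dictionary). Each word2ph entry counts the phones of one whitespace token, in order:

from text.english import g2p_with_word2ph

words = "How are you".split()
phones, word2ph = g2p_with_word2ph("How are you")
assert len(word2ph) == len(words)   # one count per non-punctuation token
assert sum(word2ph) == len(phones)  # counts partition the phone sequence
for w, n in zip(words, word2ph):
    print(w, n)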

GPT_SoVITS/text/japanese.py

@@ -271,6 +271,30 @@ def g2p(norm_text, with_prosody=True):
    return phones


# Helper for alignment: build phones and word2ph by per-character g2p (ignoring prosody markers)
def g2p_with_word2ph(text, keep_punc=False):
    """
    Returns (phones, word2ph)
    - Per-character g2p; ignore prosody markers like '[', ']', '^', '$', '#', '_'
    - Punctuation counts as 1 if keep_punc, else it is skipped
    """
    norm_text = text_normalize(text)
    phones_all = []
    word2ph = []
    prosody_markers = {'[', ']', '^', '$', '#', '_'}
    punc_set = set(punctuation)
    for ch in norm_text:
        if ch.isspace() or ch in punc_set:
            if keep_punc:
                word2ph.append(1)
            continue
        phs = preprocess_jap(ch, with_prosody=True)
        phs = [post_replace_ph(p) for p in phs if p not in prosody_markers and p not in punc_set]
        phones_all.extend(phs)
        word2ph.append(max(1, len(phs)))
    return phones_all, word2ph


if __name__ == "__main__":
    phones = g2p("Hello.こんにちは今日もNiCe天気ですねtokyotowerに行きましょう")
    print(phones)

GPT_SoVITS/text/korean.py

@@ -1,337 +1,362 @@
# reference: https://github.com/ORI-Muchim/MB-iSTFT-VITS-Korean/blob/main/text/korean.py

import re
from jamo import h2j, j2hcj
import ko_pron
from g2pk2 import G2p

import importlib
import os

# workaround so the model can be loaded on Windows
if os.name == "nt":

    class win_G2p(G2p):
        def check_mecab(self):
            super().check_mecab()
            spam_spec = importlib.util.find_spec("eunjeon")
            non_found = spam_spec is None
            if non_found:
                print("you have to install eunjeon. install it...")
            else:
                installpath = spam_spec.submodule_search_locations[0]
                if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", installpath)):
                    import sys
                    from eunjeon import Mecab as _Mecab

                    class Mecab(_Mecab):
                        def get_dicpath(installpath):
                            if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", installpath)):
                                import shutil

                                python_dir = os.getcwd()
                                if installpath[: len(python_dir)].upper() == python_dir.upper():
                                    dicpath = os.path.join(os.path.relpath(installpath, python_dir), "data", "mecabrc")
                                else:
                                    if not os.path.exists("TEMP"):
                                        os.mkdir("TEMP")
                                    if not os.path.exists(os.path.join("TEMP", "ko")):
                                        os.mkdir(os.path.join("TEMP", "ko"))
                                    if os.path.exists(os.path.join("TEMP", "ko", "ko_dict")):
                                        shutil.rmtree(os.path.join("TEMP", "ko", "ko_dict"))

                                    shutil.copytree(
                                        os.path.join(installpath, "data"), os.path.join("TEMP", "ko", "ko_dict")
                                    )
                                    dicpath = os.path.join("TEMP", "ko", "ko_dict", "mecabrc")
                            else:
                                dicpath = os.path.abspath(os.path.join(installpath, "data/mecabrc"))
                            return dicpath

                        def __init__(self, dicpath=get_dicpath(installpath)):
                            super().__init__(dicpath=dicpath)

                    sys.modules["eunjeon"].Mecab = Mecab

    G2p = win_G2p


from text.symbols2 import symbols

# This is a list of Korean classifiers preceded by pure Korean numerals.
_korean_classifiers = (
    "군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통"
)

# List of (hangul, hangul divided) pairs:
_hangul_divided = [
    (re.compile("%s" % x[0]), x[1])
    for x in [
        # ('ㄳ', 'ㄱㅅ'),  # g2pk2, A Syllable-ending Rule
        # ('ㄵ', 'ㄴㅈ'),
        # ('ㄶ', 'ㄴㅎ'),
        # ('ㄺ', 'ㄹㄱ'),
        # ('ㄻ', 'ㄹㅁ'),
        # ('ㄼ', 'ㄹㅂ'),
        # ('ㄽ', 'ㄹㅅ'),
        # ('ㄾ', 'ㄹㅌ'),
        # ('ㄿ', 'ㄹㅍ'),
        # ('ㅀ', 'ㄹㅎ'),
        # ('ㅄ', 'ㅂㅅ'),
        ("ㅘ", "ㅗㅏ"),
        ("ㅙ", "ㅗㅐ"),
        ("ㅚ", "ㅗㅣ"),
        ("ㅝ", "ㅜㅓ"),
        ("ㅞ", "ㅜㅔ"),
        ("ㅟ", "ㅜㅣ"),
        ("ㅢ", "ㅡㅣ"),
        ("ㅑ", "ㅣㅏ"),
        ("ㅒ", "ㅣㅐ"),
        ("ㅕ", "ㅣㅓ"),
        ("ㅖ", "ㅣㅔ"),
        ("ㅛ", "ㅣㅗ"),
        ("ㅠ", "ㅣㅜ"),
    ]
]

# List of (Latin alphabet, hangul) pairs:
_latin_to_hangul = [
    (re.compile("%s" % x[0], re.IGNORECASE), x[1])
    for x in [
        ("a", "에이"),
        ("b", "비"),
        ("c", "시"),
        ("d", "디"),
        ("e", "이"),
        ("f", "에프"),
        ("g", "지"),
        ("h", "에이치"),
        ("i", "아이"),
        ("j", "제이"),
        ("k", "케이"),
        ("l", "엘"),
        ("m", "엠"),
        ("n", "엔"),
        ("o", "오"),
        ("p", "피"),
        ("q", "큐"),
        ("r", "아르"),
        ("s", "에스"),
        ("t", "티"),
        ("u", "유"),
        ("v", "브이"),
        ("w", "더블유"),
        ("x", "엑스"),
        ("y", "와이"),
        ("z", "제트"),
    ]
]

# List of (ipa, lazy ipa) pairs:
_ipa_to_lazy_ipa = [
    (re.compile("%s" % x[0], re.IGNORECASE), x[1])
    for x in [
        ("t͡ɕ", "ʧ"),
        ("d͡ʑ", "ʥ"),
        ("ɲ", "n^"),
        ("ɕ", "ʃ"),
        ("ʷ", "w"),
        ("ɭ", "l`"),
        ("ʎ", "ɾ"),
        ("ɣ", "ŋ"),
        ("ɰ", "ɯ"),
        ("ʝ", "j"),
        ("ʌ", "ə"),
        ("ɡ", "g"),
        ("\u031a", "#"),
        ("\u0348", "="),
        ("\u031e", ""),
        ("\u0320", ""),
        ("\u0339", ""),
    ]
]


def fix_g2pk2_error(text):
    new_text = ""
    i = 0
    while i < len(text) - 4:
        if (text[i : i + 3] == "ㅇㅡㄹ" or text[i : i + 3] == "ㄹㅡㄹ") and text[i + 3] == " " and text[i + 4] == "":
            new_text += text[i : i + 3] + " " + ""
            i += 5
        else:
            new_text += text[i]
            i += 1

    new_text += text[i:]
    return new_text


def latin_to_hangul(text):
    for regex, replacement in _latin_to_hangul:
        text = re.sub(regex, replacement, text)
    return text


def divide_hangul(text):
    text = j2hcj(h2j(text))
    for regex, replacement in _hangul_divided:
        text = re.sub(regex, replacement, text)
    return text


def hangul_number(num, sino=True):
    """Reference https://github.com/Kyubyong/g2pK"""
    num = re.sub(",", "", num)

    if num == "0":
        return "영"
    if not sino and num == "20":
        return "스무"

    digits = "123456789"
    names = "일이삼사오육칠팔구"
    digit2name = {d: n for d, n in zip(digits, names)}

    modifiers = "한 두 세 네 다섯 여섯 일곱 여덟 아홉"
    decimals = "열 스물 서른 마흔 쉰 예순 일흔 여든 아흔"
    digit2mod = {d: mod for d, mod in zip(digits, modifiers.split())}
    digit2dec = {d: dec for d, dec in zip(digits, decimals.split())}

    spelledout = []
    for i, digit in enumerate(num):
        i = len(num) - i - 1
        if sino:
            if i == 0:
                name = digit2name.get(digit, "")
            elif i == 1:
                name = digit2name.get(digit, "") + "십"
                name = name.replace("일십", "십")
        else:
            if i == 0:
                name = digit2mod.get(digit, "")
            elif i == 1:
                name = digit2dec.get(digit, "")
        if digit == "0":
            if i % 4 == 0:
                last_three = spelledout[-min(3, len(spelledout)) :]
                if "".join(last_three) == "":
                    spelledout.append("")
                    continue
            else:
                spelledout.append("")
                continue
        if i == 2:
            name = digit2name.get(digit, "") + "백"
            name = name.replace("일백", "백")
        elif i == 3:
            name = digit2name.get(digit, "") + "천"
            name = name.replace("일천", "천")
        elif i == 4:
            name = digit2name.get(digit, "") + "만"
            name = name.replace("일만", "만")
        elif i == 5:
            name = digit2name.get(digit, "") + "십"
            name = name.replace("일십", "십")
        elif i == 6:
            name = digit2name.get(digit, "") + "백"
            name = name.replace("일백", "백")
        elif i == 7:
            name = digit2name.get(digit, "") + "천"
            name = name.replace("일천", "천")
        elif i == 8:
            name = digit2name.get(digit, "") + "억"
        elif i == 9:
            name = digit2name.get(digit, "") + "십"
        elif i == 10:
            name = digit2name.get(digit, "") + "백"
        elif i == 11:
            name = digit2name.get(digit, "") + "천"
        elif i == 12:
            name = digit2name.get(digit, "") + "조"
        elif i == 13:
            name = digit2name.get(digit, "") + "십"
        elif i == 14:
            name = digit2name.get(digit, "") + "백"
        elif i == 15:
            name = digit2name.get(digit, "") + "천"
        spelledout.append(name)

    return "".join(elem for elem in spelledout)


def number_to_hangul(text):
    """Reference https://github.com/Kyubyong/g2pK"""
    tokens = set(re.findall(r"(\d[\d,]*)([\uac00-\ud71f]+)", text))
    for token in tokens:
        num, classifier = token
        if classifier[:2] in _korean_classifiers or classifier[0] in _korean_classifiers:
            spelledout = hangul_number(num, sino=False)
        else:
            spelledout = hangul_number(num, sino=True)
        text = text.replace(f"{num}{classifier}", f"{spelledout}{classifier}")

    # digit by digit for remaining digits
    digits = "0123456789"
    names = "영일이삼사오육칠팔구"
    for d, n in zip(digits, names):
        text = text.replace(d, n)

    return text


def korean_to_lazy_ipa(text):
    text = latin_to_hangul(text)
    text = number_to_hangul(text)
    text = re.sub("[\uac00-\ud7af]+", lambda x: ko_pron.romanise(x.group(0), "ipa").split("] ~ [")[0], text)
    for regex, replacement in _ipa_to_lazy_ipa:
        text = re.sub(regex, replacement, text)
    return text


_g2p = G2p()


def korean_to_ipa(text):
    text = latin_to_hangul(text)
    text = number_to_hangul(text)
    text = _g2p(text)
    text = fix_g2pk2_error(text)
    text = korean_to_lazy_ipa(text)
    return text.replace("ʧ", "").replace("ʥ", "")


def post_replace_ph(ph):
    rep_map = {
        "": ",",
        "": ",",
        "": ",",
        "": ".",
        "": "!",
        "": "?",
        "\n": ".",
        "·": ",",
        "": ",",
        "...": "",
        " ": "",
    }
    if ph in rep_map.keys():
        ph = rep_map[ph]
    if ph in symbols:
        return ph
    if ph not in symbols:
        ph = ""
    return ph


def g2p(text):
    text = latin_to_hangul(text)
    text = _g2p(text)
    text = divide_hangul(text)
    text = fix_g2pk2_error(text)
    text = re.sub(r"([\u3131-\u3163])$", r"\1.", text)
    # text = "".join([post_replace_ph(i) for i in text])
    text = [post_replace_ph(i) for i in text]
    return text

# Helper for alignment: build phones and word2ph by space-tokenizing Korean text
def g2p_with_word2ph(text, keep_punc=False):
    """
    Returns (phones, word2ph)
    - Tokenize by spaces; for each non-space token, call g2p(token)
    - Exclude separator/placeholders from phones: '.', '', '' and common punctuations
    - word2ph per token = max(1, num_valid_phones) if keep_punc else skip spaces entirely
    """
    # prepare
    tokens = re.split(r"(\s+)", text)
    phones_all = []
    word2ph = []
    punc_set = {",", ".", "!", "?", "", "", "", ":", ""}
    for tok in tokens:
        if tok.strip() == "":
            if keep_punc:
                word2ph.append(1)
            continue
        phs = g2p(tok)
        # filter non-phonetic markers
        phs_valid = [p for p in phs if p not in {'.', '', ''} and p not in punc_set]
        phones_all.extend(phs_valid)
        word2ph.append(max(1, len(phs_valid)))
    return phones_all, word2ph


if __name__ == "__main__":
    text = "안녕하세요"
    print(g2p(text))

api.py (3494 changed lines): file diff suppressed because it is too large.