GPT-SoVITS/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
baicai-1145 17cb2e5acf Implement G2PW processing enhancements in TTS framework
Add support for G2PW processing in the TTS system by introducing new methods and classes for handling G2PW segments. Update PrepareCoordinator to manage G2PW worker threads and integrate G2PW profiling into the existing framework. Enhance text preprocessing to identify segments requiring G2PW and streamline the resolution of these segments. This update improves the overall performance and maintainability of the TTS system by optimizing the handling of Chinese text processing.
2026-03-12 23:04:39 +08:00

579 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import os
import sys
import threading
import time
from contextlib import contextmanager
from dataclasses import dataclass
from tqdm import tqdm
now_dir = os.getcwd()
sys.path.append(now_dir)
import re
import torch
from text.LangSegmenter import LangSegmenter
from text import chinese
from typing import Dict, List, Optional, Tuple
from text.cleaner import clean_text
from text import cleaned_text_to_sequence
from transformers import AutoModelForMaskedLM, AutoTokenizer
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
from TTS_infer_pack.prepare_bert_batch_worker import PrepareBertBatchWorker
from TTS_infer_pack.text_cpu_preprocess import preprocess_text_segments_payload
from tools.i18n.i18n import I18nAuto, scan_language_list
# Resolve the UI language: environment variable first, then an explicit
# CLI override (last argv token) when it names a known language.
language = os.environ.get("language", "Auto")
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language)
# Punctuation marks collapsed by replace_consecutive_punctuation().
# NOTE(review): restored the CJK ellipsis "…", which renders as an empty
# string in viewers that hide "ambiguous Unicode" characters.
punctuation = set(["!", "?", "…", ",", ".", "-"])
def get_first(text: str) -> str:
    """Return the leading portion of *text* up to (not including) the
    first sentence separator listed in ``splits``, stripped of whitespace."""
    separator_class = "[" + "".join(re.escape(ch) for ch in splits) + "]"
    head, *_ = re.split(separator_class, text)
    return head.strip()
def merge_short_text_in_array(texts: list, threshold: int) -> list:
    """Greedily merge consecutive chunks so each emitted chunk has at least
    ``threshold`` characters.

    Chunks are concatenated in order; once the running buffer reaches the
    threshold it is flushed. A short trailing remainder is appended to the
    last emitted chunk (or becomes the sole chunk when nothing was emitted).
    Inputs with fewer than two chunks are returned unchanged.

    Note: the parameter annotation was corrected from ``str`` to ``list`` —
    the function has always operated on a list of strings.
    """
    if len(texts) < 2:
        return texts
    result = []
    buffer = ""
    for chunk in texts:
        buffer += chunk
        if len(buffer) >= threshold:
            result.append(buffer)
            buffer = ""
    if buffer:
        if not result:
            result.append(buffer)
        else:
            # Fold the short tail into the previous chunk.
            result[-1] += buffer
    return result
class StageLimiter:
    """Bounds how many callers may be inside a pipeline stage at once.

    Wraps a bounded semaphore and, under a lock, tracks the current
    occupancy (``inflight``) and its historical maximum (``peak_inflight``).
    """

    def __init__(self, slots: int):
        # At least one slot, even for zero/negative requests.
        self.slots = max(1, int(slots))
        self.semaphore = threading.BoundedSemaphore(self.slots)
        self.lock = threading.Lock()
        self.inflight = 0
        self.peak_inflight = 0

    @contextmanager
    def enter(self):
        """Acquire a slot for the duration of the ``with`` body.

        Yields a dict with the milliseconds spent waiting for a slot, the
        occupancy at entry, the peak occupancy so far, and the slot count.
        """
        started = time.perf_counter()
        self.semaphore.acquire()
        waited_ms = (time.perf_counter() - started) * 1000.0
        with self.lock:
            self.inflight += 1
            occupancy = self.inflight
            self.peak_inflight = max(self.peak_inflight, occupancy)
            peak = self.peak_inflight
        try:
            yield {
                "wait_ms": waited_ms,
                "inflight": occupancy,
                "peak_inflight": peak,
                "slots": self.slots,
            }
        finally:
            with self.lock:
                # Clamp at zero defensively; release the slot afterwards.
                self.inflight = max(0, self.inflight - 1)
            self.semaphore.release()

    def snapshot(self) -> Dict[str, int]:
        """Return the current counters as a dict (thread-safe)."""
        with self.lock:
            return {
                "slots": self.slots,
                "inflight": self.inflight,
                "peak_inflight": self.peak_inflight,
            }
@dataclass
class PreparedTextSegment:
    """One language-homogeneous text segment with its cleaned phoneme data,
    produced by TextPreprocessor.preprocess_text_segments()."""

    # Language tag for this segment (e.g. "zh", "en"; may carry an "all_" prefix).
    language: str
    # Phoneme id sequence produced by the text cleaner.
    phones: List[int]
    # Per-character phoneme counts used to expand BERT features
    # (Chinese segments only); None for other languages.
    word2ph: Optional[List[int]]
    # Normalized text for this segment.
    norm_text: str
    # True when Chinese G2PW pronunciation resolution is still pending;
    # resolved later by TextPreprocessor.resolve_g2pw_segments().
    needs_g2pw: bool = False
class TextPreprocessor:
    """Converts raw input text into phoneme ids, BERT features and
    normalized text for the TTS pipeline.

    Responsibilities:
      * sentence segmentation / normalization (``pre_seg_text``);
      * per-language splitting and cleaning of each sentence;
      * deferred G2PW pronunciation resolution for Chinese segments;
      * BERT feature extraction — either directly (optionally throttled
        by a StageLimiter) or via an optional batching worker, with sync
        and async code paths.

    Profiling counters are accumulated into caller-supplied ``profile``
    dicts throughout; all profile handling is best-effort and optional.
    """

    def __init__(
        self,
        bert_model: AutoModelForMaskedLM,
        tokenizer: AutoTokenizer,
        device: torch.device,
        version: str = "v2",
        bert_stage_limiter: StageLimiter | None = None,
        bert_batch_worker: PrepareBertBatchWorker | None = None,
    ):
        self.bert_model = bert_model
        self.tokenizer = tokenizer
        self.device = device
        self.version = str(version)
        # Optional concurrency throttle for direct BERT forward passes.
        self.bert_stage_limiter = bert_stage_limiter
        # Optional batching worker; when present it replaces direct forwards.
        self.bert_batch_worker = bert_batch_worker

    def snapshot(self) -> Dict[str, object]:
        """Return a diagnostic snapshot of the device and helper states."""
        return {
            "device": str(self.device),
            "bert_stage_limiter": (
                None if self.bert_stage_limiter is None else dict(self.bert_stage_limiter.snapshot())
            ),
            "bert_batch_worker": None if self.bert_batch_worker is None else dict(self.bert_batch_worker.snapshot()),
        }

    def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
        """Full preprocessing entry point.

        Splits ``text`` into sentences, then extracts phones and BERT
        features for each. Returns a list of dicts with keys "phones",
        "bert_features" and "norm_text"; sentences that normalize to
        nothing are dropped.
        """
        print(f"############ {i18n('切分文本')} ############")
        text = self.replace_consecutive_punctuation(text)
        texts = self.pre_seg_text(text, lang, text_split_method)
        result = []
        print(f"############ {i18n('提取文本Bert特征')} ############")
        for text in tqdm(texts):
            phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
            if phones is None or norm_text == "":
                continue
            res = {
                "phones": phones,
                "bert_features": bert_features,
                "norm_text": norm_text,
            }
            result.append(res)
        return result

    def pre_seg_text(self, text: str, lang: str, text_split_method: str):
        """Split raw text into speakable sentences.

        Applies the configured split method, drops blank/punctuation-only
        lines, merges very short fragments, ensures each sentence ends with
        a separator, and breaks sentences longer than BERT's limit (510).
        """
        text = text.strip("\n")
        if len(text) == 0:
            return []
        if text[0] not in splits and len(get_first(text)) < 4:
            # Prepend a sentence terminator so a very short opening still
            # forms a complete sentence. NOTE(review): restored "。", which
            # renders as an empty string in viewers hiding ambiguous Unicode.
            text = "。" + text if lang != "en" else "." + text
        print(i18n("实际输入的目标文本:"))
        print(text)
        seg_method = get_seg_method(text_split_method)
        text = seg_method(text)
        # Collapse runs of blank lines before splitting.
        while "\n\n" in text:
            text = text.replace("\n\n", "\n")
        _texts = text.split("\n")
        _texts = self.filter_text(_texts)
        _texts = merge_short_text_in_array(_texts, 5)
        texts = []
        for text in _texts:
            # Skip blank lines in the input so they don't cause errors downstream.
            if len(text.strip()) == 0:
                continue
            if not re.sub(r"\W+", "", text):
                # Skip sentences made of punctuation/symbols only.
                continue
            if text[-1] not in splits:
                # Terminate the sentence (restored "。" — see note above).
                text += "。" if lang != "en" else "."
            # Split overly long sentences so BERT (max 510 tokens) won't fail.
            if len(text) > 510:
                texts.extend(split_big_text(text))
            else:
                texts.append(text)
        print(i18n("实际输入的目标文本(切句后):"))
        print(texts)
        return texts

    def segment_and_extract_feature_for_text(
        self, text: str, language: str, version: str = "v1", profile: Dict | None = None
    ) -> Tuple[list, torch.Tensor, str]:
        """Clean one sentence and return (phones, bert_features, norm_text)."""
        prepared_segments = self.preprocess_text_segments(text, language, version)
        return self.build_phones_and_bert_from_segments(prepared_segments, profile=profile)

    def _split_text_by_language(self, text: str, language: str) -> Tuple[List[str], List[str]]:
        """Split ``text`` into runs of a single language.

        Returns parallel lists (textlist, langlist). For "all_*" modes the
        segmenter is forced to one language family; for a concrete language
        (the final ``else``), non-English runs are coerced to that language
        and adjacent same-group runs are merged.
        """
        textlist = []
        langlist = []
        if language == "all_zh":
            for tmp in LangSegmenter.getTexts(text, "zh"):
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        elif language == "all_yue":
            for tmp in LangSegmenter.getTexts(text, "zh"):
                if tmp["lang"] == "zh":
                    # Cantonese mode: relabel Mandarin-detected runs as "yue".
                    tmp["lang"] = "yue"
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        elif language == "all_ja":
            for tmp in LangSegmenter.getTexts(text, "ja"):
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        elif language == "all_ko":
            for tmp in LangSegmenter.getTexts(text, "ko"):
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        elif language == "en":
            langlist.append("en")
            textlist.append(text)
        elif language == "auto":
            for tmp in LangSegmenter.getTexts(text):
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        elif language == "auto_yue":
            for tmp in LangSegmenter.getTexts(text):
                if tmp["lang"] == "zh":
                    tmp["lang"] = "yue"
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        else:
            for tmp in LangSegmenter.getTexts(text):
                if langlist:
                    # Merge consecutive runs of the same group
                    # (English vs. non-English) into one segment.
                    same_group = (tmp["lang"] == "en" and langlist[-1] == "en") or (
                        tmp["lang"] != "en" and langlist[-1] != "en"
                    )
                    if same_group:
                        textlist[-1] += tmp["text"]
                        continue
                if tmp["lang"] == "en":
                    langlist.append(tmp["lang"])
                else:
                    # Force non-English runs to the caller's declared language.
                    langlist.append(language)
                textlist.append(tmp["text"])
        return textlist, langlist

    def get_phones_and_bert(
        self, text: str, language: str, version: str, final: bool = False, profile: Dict | None = None
    ):
        """Convenience wrapper: clean ``text`` then build phones + BERT."""
        prepared_segments = self.preprocess_text_segments(text, language, version, final=final)
        return self.build_phones_and_bert_from_segments(prepared_segments, profile=profile)

    def preprocess_text_segments(
        self,
        text: str,
        language: str,
        version: str,
        final: bool = False,
    ) -> List[PreparedTextSegment]:
        """Run the CPU-side cleaner and wrap each payload in a
        PreparedTextSegment (G2PW resolution is deferred)."""
        payloads = preprocess_text_segments_payload(text, language, version, final=final)
        return [
            PreparedTextSegment(
                language=str(payload["language"]),
                phones=list(payload["phones"]),
                word2ph=None if payload["word2ph"] is None else list(payload["word2ph"]),
                norm_text=str(payload["norm_text"]),
                needs_g2pw=bool(payload.get("needs_g2pw", False)),
            )
            for payload in payloads
        ]

    def resolve_g2pw_segments(
        self,
        prepared_segments: List[PreparedTextSegment],
        profile: Dict | None = None,
    ) -> List[PreparedTextSegment]:
        """Resolve pending G2PW (Chinese grapheme-to-phoneme) segments.

        Batches all segments flagged ``needs_g2pw`` through
        ``chinese2.g2p_segments`` in one call, merges the returned G2PW
        timing counters into ``profile``, and replaces each pending segment
        in place with its resolved phones/word2ph/norm_text. Returns the
        (mutated) list.
        """
        zh_indices = [index for index, segment in enumerate(prepared_segments) if bool(segment.needs_g2pw)]
        if not zh_indices:
            return prepared_segments
        # Imported lazily: chinese2 pulls in the (heavy) G2PW runtime.
        from text import chinese2

        normalized_segments = [prepared_segments[index].norm_text for index in zh_indices]
        resolved_segments, g2pw_profile = chinese2.g2p_segments(normalized_segments, return_profile=True)
        self._accumulate_profile(profile, "g2pw_prepare_ms", g2pw_profile.get("g2pw_prepare_ms", 0.0))
        self._accumulate_profile(profile, "g2pw_predict_ms", g2pw_profile.get("g2pw_predict_ms", 0.0))
        self._accumulate_profile(profile, "g2pw_post_ms", g2pw_profile.get("g2pw_post_ms", 0.0))
        self._accumulate_profile(profile, "g2pw_total_ms", g2pw_profile.get("g2pw_total_ms", 0.0))
        self._accumulate_profile(profile, "g2pw_runtime_total_ms", g2pw_profile.get("g2pw_runtime_total_ms", 0.0))
        self._accumulate_profile(profile, "g2pw_runtime_queue_wait_ms", g2pw_profile.get("g2pw_runtime_queue_wait_ms", 0.0))
        self._accumulate_profile(
            profile,
            "g2pw_runtime_collect_wait_ms",
            g2pw_profile.get("g2pw_runtime_collect_wait_ms", 0.0),
        )
        self._accumulate_profile(profile, "g2pw_runtime_run_ms", g2pw_profile.get("g2pw_runtime_run_ms", 0.0))
        self._update_profile_peak(
            profile,
            "g2pw_runtime_batch_rows_peak",
            g2pw_profile.get("g2pw_runtime_batch_rows", 0.0),
        )
        self._update_profile_peak(
            profile,
            "g2pw_runtime_batch_requests_peak",
            g2pw_profile.get("g2pw_runtime_batch_requests", 0.0),
        )
        self._update_profile_peak(
            profile,
            "g2pw_runtime_pool_workers",
            g2pw_profile.get("g2pw_runtime_pool_workers", 0.0),
        )
        for index, (phones, word2ph, norm_text) in zip(zh_indices, resolved_segments):
            prepared_segments[index] = PreparedTextSegment(
                language=prepared_segments[index].language,
                phones=list(cleaned_text_to_sequence(phones, self.version)),
                word2ph=None if word2ph is None else list(word2ph),
                norm_text=str(norm_text),
                needs_g2pw=False,
            )
        return prepared_segments

    def build_phones_and_bert_from_segments(
        self,
        prepared_segments: List[PreparedTextSegment],
        profile: Dict | None = None,
    ) -> Tuple[list, torch.Tensor, str]:
        """Synchronous path: resolve G2PW, then extract BERT features per
        segment and concatenate everything into (phones, bert, norm_text)."""
        prepared_segments = self.resolve_g2pw_segments(prepared_segments, profile=profile)
        phones_list: List[List[int]] = []
        bert_list: List[torch.Tensor] = []
        norm_text_list: List[str] = []
        for segment in prepared_segments:
            bert = self.get_bert_inf(
                segment.phones,
                segment.word2ph,
                segment.norm_text,
                segment.language,
                profile=profile,
            )
            phones_list.append(segment.phones)
            norm_text_list.append(segment.norm_text)
            bert_list.append(bert)
        bert = torch.cat(bert_list, dim=1)
        phones = sum(phones_list, [])
        norm_text = "".join(norm_text_list)
        return phones, bert, norm_text

    def _accumulate_profile(self, profile: Dict | None, key: str, value: float) -> None:
        """Add ``value`` to ``profile[key]`` (no-op when profile is None)."""
        if profile is None:
            return
        profile[key] = float(profile.get(key, 0.0)) + float(value)

    def _update_profile_peak(self, profile: Dict | None, key: str, value: float) -> None:
        """Raise ``profile[key]`` to ``value`` if larger (no-op when None)."""
        if profile is None:
            return
        profile[key] = float(max(float(profile.get(key, 0.0)), float(value)))

    def _merge_bert_worker_profile(self, profile: Dict | None, worker_profile: Dict[str, float]) -> None:
        """Fold one batch-worker result's counters into ``profile``.

        Durations/counters are summed, depth/size metrics are tracked as
        peaks, and the static slot/window settings are overwritten.
        """
        self._accumulate_profile(profile, "bert_wait_ms", worker_profile.get("bert_wait_ms", 0.0))
        self._accumulate_profile(profile, "bert_admission_wait_ms", worker_profile.get("bert_admission_wait_ms", 0.0))
        self._accumulate_profile(profile, "bert_queue_wait_ms", worker_profile.get("bert_queue_wait_ms", 0.0))
        self._accumulate_profile(
            profile,
            "bert_batch_collect_wait_ms",
            worker_profile.get("bert_batch_collect_wait_ms", 0.0),
        )
        self._accumulate_profile(profile, "bert_forward_ms", worker_profile.get("bert_forward_ms", 0.0))
        self._accumulate_profile(profile, "bert_tokenize_ms", worker_profile.get("bert_tokenize_ms", 0.0))
        self._accumulate_profile(profile, "bert_scatter_ms", worker_profile.get("bert_scatter_ms", 0.0))
        self._accumulate_profile(profile, "bert_calls", worker_profile.get("bert_calls", 1.0))
        self._update_profile_peak(profile, "bert_stage_inflight_peak", worker_profile.get("bert_stage_inflight_peak", 0.0))
        self._update_profile_peak(profile, "bert_batch_size_peak", worker_profile.get("bert_batch_size", 0.0))
        self._update_profile_peak(profile, "bert_batch_tokens_peak", worker_profile.get("bert_batch_tokens", 0.0))
        self._update_profile_peak(
            profile,
            "bert_pending_depth_on_enqueue_peak",
            worker_profile.get("bert_pending_depth_on_enqueue", 0.0),
        )
        self._update_profile_peak(
            profile,
            "bert_pending_depth_on_collect_peak",
            worker_profile.get("bert_pending_depth_on_collect", 0.0),
        )
        self._update_profile_peak(profile, "bert_high_pressure_mode_peak", worker_profile.get("bert_high_pressure_mode", 0.0))
        if profile is not None:
            profile["bert_stage_slots"] = float(worker_profile.get("bert_stage_slots", 0.0))
            profile["bert_batch_window_ms"] = float(worker_profile.get("bert_batch_window_ms", 0.0))

    def _run_bert_forward(self, text: str):
        """Tokenize ``text`` and run one BERT forward pass.

        Returns (features, elapsed_ms) where features are the third-to-last
        hidden layer on CPU with [CLS]/[SEP] positions stripped. Extracted
        so the limited and unlimited paths in get_bert_feature share code.
        """
        forward_start = time.perf_counter()
        with torch.no_grad():
            inputs = self.tokenizer(text, return_tensors="pt")
            for key in inputs:
                inputs[key] = inputs[key].to(self.device)
            res = self.bert_model(**inputs, output_hidden_states=True)
            res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
        forward_ms = (time.perf_counter() - forward_start) * 1000.0
        return res, forward_ms

    def get_bert_feature(self, text: str, word2ph: list, profile: Dict | None = None) -> torch.Tensor:
        """Extract phone-level BERT features for Chinese ``text``.

        Prefers the batch worker when configured; otherwise runs a direct
        forward pass, optionally gated by the stage limiter. Character-level
        features are repeated per ``word2ph`` to phone level; the result has
        shape (hidden_dim, n_phones).
        """
        if self.bert_batch_worker is not None:
            feature, worker_profile = self.bert_batch_worker.submit(text, word2ph)
            self._merge_bert_worker_profile(profile, worker_profile)
            return feature
        # Defaults used when no limiter is configured.
        limiter_stats = {"wait_ms": 0.0, "inflight": 1, "peak_inflight": 1, "slots": 0}
        if self.bert_stage_limiter is None:
            res, forward_ms = self._run_bert_forward(text)
        else:
            with self.bert_stage_limiter.enter() as limiter_stats:
                res, forward_ms = self._run_bert_forward(text)
        self._accumulate_profile(profile, "bert_wait_ms", limiter_stats["wait_ms"])
        self._accumulate_profile(profile, "bert_forward_ms", forward_ms)
        self._accumulate_profile(profile, "bert_calls", 1.0)
        self._update_profile_peak(profile, "bert_stage_inflight_peak", limiter_stats["peak_inflight"])
        if profile is not None:
            profile["bert_stage_slots"] = float(limiter_stats["slots"])
        # One BERT feature row per input character is required for expansion.
        assert len(word2ph) == len(text)
        phone_level_feature = []
        for i in range(len(word2ph)):
            repeat_feature = res[i].repeat(word2ph[i], 1)
            phone_level_feature.append(repeat_feature)
        phone_level_feature = torch.cat(phone_level_feature, dim=0)
        return phone_level_feature.T

    def clean_text_inf(self, text: str, language: str, version: str = "v2"):
        """Clean ``text`` and return (phone ids, word2ph, normalized text)."""
        language = language.replace("all_", "")
        phones, word2ph, norm_text = clean_text(text, language, version)
        phones = cleaned_text_to_sequence(phones, version)
        return phones, word2ph, norm_text

    def get_bert_inf(
        self,
        phones: list,
        word2ph: Optional[list],
        norm_text: str,
        language: str,
        profile: Dict | None = None,
    ):
        """Return BERT features for one segment on ``self.device``.

        Chinese segments get real features (word2ph is mandatory); every
        other language gets a zero tensor of shape (1024, n_phones).
        """
        language = language.replace("all_", "")
        if language == "zh":
            if word2ph is None:
                raise ValueError("中文文本缺少 word2ph无法提取 BERT 特征")
            feature = self.get_bert_feature(norm_text, word2ph, profile=profile).to(self.device)
        else:
            feature = torch.zeros(
                (1024, len(phones)),
                dtype=torch.float32,
            ).to(self.device)
        return feature

    async def build_phones_and_bert_from_segments_async(
        self,
        prepared_segments: List[PreparedTextSegment],
        profile: Dict | None = None,
    ) -> Tuple[list, torch.Tensor, str]:
        """Async variant of build_phones_and_bert_from_segments.

        Non-Chinese segments (and all segments when no batch worker exists)
        are computed synchronously; Chinese segments are submitted to the
        batch worker concurrently and gathered before final concatenation.
        """
        prepared_segments = self.resolve_g2pw_segments(prepared_segments, profile=profile)
        segment_jobs = self._build_async_segment_jobs(prepared_segments, profile)
        pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = []
        for segment_index, segment in enumerate(prepared_segments):
            if segment.language.replace("all_", "") != "zh" or self.bert_batch_worker is None:
                continue
            if segment.word2ph is None:
                raise ValueError("中文文本缺少 word2ph无法提取 BERT 特征")
            pending_items.append(
                (
                    segment_jobs["bert_list"],
                    segment_index,
                    profile,
                    self.bert_batch_worker.submit_async(segment.norm_text, segment.word2ph),
                )
            )
        if pending_items:
            pending_results = await asyncio.gather(*[future for _, _, _, future in pending_items])
            for (bert_list, bert_index, item_profile, _), (feature, worker_profile) in zip(pending_items, pending_results):
                self._merge_bert_worker_profile(item_profile, worker_profile)
                bert_list[bert_index] = feature.to(self.device)
        return self._finalize_async_segment_jobs(segment_jobs)

    def _build_async_segment_jobs(
        self,
        prepared_segments: List[PreparedTextSegment],
        profile: Dict | None,
    ) -> Dict[str, List]:
        """Prepare per-segment job lists for the async paths.

        Chinese segments destined for the batch worker get a ``None``
        placeholder in "bert_list" (filled in after gather); everything
        else is computed synchronously here.
        """
        phones_list: List[List[int]] = []
        bert_list: List[torch.Tensor | None] = []
        norm_text_list: List[str] = []
        for segment in prepared_segments:
            phones_list.append(segment.phones)
            norm_text_list.append(segment.norm_text)
            segment_language = segment.language.replace("all_", "")
            if segment_language == "zh" and self.bert_batch_worker is not None:
                if segment.word2ph is None:
                    raise ValueError("中文文本缺少 word2ph无法提取 BERT 特征")
                bert_list.append(None)
                continue
            bert_list.append(
                self.get_bert_inf(
                    segment.phones,
                    segment.word2ph,
                    segment.norm_text,
                    segment.language,
                    profile=profile,
                )
            )
        return {
            "phones_list": phones_list,
            "bert_list": bert_list,
            "norm_text_list": norm_text_list,
        }

    @staticmethod
    def _finalize_async_segment_jobs(segment_jobs: Dict[str, List]) -> Tuple[list, torch.Tensor, str]:
        """Concatenate filled job lists into (phones, bert, norm_text)."""
        bert = torch.cat([feature for feature in segment_jobs["bert_list"] if feature is not None], dim=1)
        phones = sum(segment_jobs["phones_list"], [])
        norm_text = "".join(segment_jobs["norm_text_list"])
        return phones, bert, norm_text

    async def build_phones_and_bert_pair_from_segments_async(
        self,
        prompt_segments: List[PreparedTextSegment],
        target_segments: List[PreparedTextSegment],
        prompt_profile: Dict | None = None,
        target_profile: Dict | None = None,
    ) -> Tuple[Tuple[list, torch.Tensor, str], Tuple[list, torch.Tensor, str]]:
        """Process prompt and target segment lists concurrently.

        Submits Chinese segments from both lists to the batch worker in a
        single gather so prompt and target BERT work can share batches.
        Returns the (phones, bert, norm_text) triple for each list.
        """
        prompt_segments = self.resolve_g2pw_segments(prompt_segments, profile=prompt_profile)
        target_segments = self.resolve_g2pw_segments(target_segments, profile=target_profile)
        prompt_jobs = self._build_async_segment_jobs(prompt_segments, prompt_profile)
        target_jobs = self._build_async_segment_jobs(target_segments, target_profile)
        pending_items: List[Tuple[List[torch.Tensor | None], int, Dict | None, asyncio.Future]] = []
        for segment_jobs, prepared_segments, profile in (
            (prompt_jobs, prompt_segments, prompt_profile),
            (target_jobs, target_segments, target_profile),
        ):
            for segment_index, segment in enumerate(prepared_segments):
                if segment.language.replace("all_", "") != "zh" or self.bert_batch_worker is None:
                    continue
                if segment.word2ph is None:
                    raise ValueError("中文文本缺少 word2ph无法提取 BERT 特征")
                pending_items.append(
                    (
                        segment_jobs["bert_list"],
                        segment_index,
                        profile,
                        self.bert_batch_worker.submit_async(segment.norm_text, segment.word2ph),
                    )
                )
        if pending_items:
            pending_results = await asyncio.gather(*[future for _, _, _, future in pending_items])
            for (bert_list, bert_index, profile, _), (feature, worker_profile) in zip(pending_items, pending_results):
                self._merge_bert_worker_profile(profile, worker_profile)
                bert_list[bert_index] = feature.to(self.device)
        return self._finalize_async_segment_jobs(prompt_jobs), self._finalize_async_segment_jobs(target_jobs)

    def filter_text(self, texts):
        """Drop empty/whitespace-only entries; raise if nothing remains valid."""
        _text = []
        if all(text in [None, " ", "\n", ""] for text in texts):
            raise ValueError(i18n("请输入有效文本"))
        for text in texts:
            if text in [None, " ", ""]:
                pass
            else:
                _text.append(text)
        return _text

    def replace_consecutive_punctuation(self, text):
        """Collapse runs of punctuation (see module-level ``punctuation``)
        down to their first mark."""
        punctuations = "".join(re.escape(p) for p in punctuation)
        pattern = f"([{punctuations}])([{punctuations}])+"
        result = re.sub(pattern, r"\1", text)
        return result