from imghdr import tests  # NOTE(review): unused here; `imghdr` is removed in Python 3.13 — confirm nothing re-exports this before deleting
import os
import sys
import threading

from tqdm import tqdm

now_dir = os.getcwd()
sys.path.append(now_dir)

import re

import torch
from text.LangSegmenter import LangSegmenter
from text import chinese
from typing import Dict, List, Tuple
from text.cleaner import clean_text
from text import cleaned_text_to_sequence
from transformers import AutoModelForMaskedLM, AutoTokenizer
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
from export_torch_script_v3 import extract_bert_features
from tools.i18n.i18n import I18nAuto, scan_language_list

# UI language: env var first, overridable by a trailing CLI argument.
language = os.environ.get("language", "Auto")
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language)

# Punctuation collapsed by replace_consecutive_punctuation().
punctuation = set(["!", "?", "…", ",", ".", "-"])


def get_first(text: str) -> str:
    """Return the leading fragment of *text* up to the first sentence-splitting mark in `splits`."""
    pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
    return re.split(pattern, text)[0].strip()


def merge_short_text_in_array(texts: List[str], threshold: int) -> list:
    """Greedily concatenate consecutive entries of *texts* until each merged chunk
    reaches *threshold* characters.

    A trailing fragment shorter than *threshold* is appended to the last merged
    chunk (or becomes the only chunk when nothing else was produced).
    """
    if len(texts) < 2:
        return texts
    result = []
    buffer = ""
    for piece in texts:
        buffer += piece
        if len(buffer) >= threshold:
            result.append(buffer)
            buffer = ""
    if buffer:
        if not result:
            result.append(buffer)
        else:
            result[-1] += buffer
    return result


class TextPreprocessor:
    """Splits raw text into sentences and extracts phoneme sequences plus
    BERT features for a GPT-SoVITS-style TTS front end."""

    def __init__(self, bert_model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, device: torch.device):
        self.bert_model = bert_model
        self.tokenizer = tokenizer
        self.device = device
        # Serializes BERT access across threads (model forward is not re-entrant here).
        self.bert_lock = threading.RLock()

    def preprocess1(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
        """Segment *text* and extract phones/BERT features for all sentences in one batch.

        Returns a list of dicts with keys "phones", "bert_features", "norm_text".
        """
        print(f"############ {i18n('切分文本')} ############")
        text = self.replace_consecutive_punctuation(text)
        texts = self.pre_seg_text(text, lang, text_split_method)
        result = []
        print(f"############ {i18n('提取文本Bert特征')} ############")
        # Ignore empty sentences.
        text_batch = [t for t in texts if t.strip()]
        if not text_batch:
            return []
        phones_list, bert_list, norm_texts = self.batch_get_phones_and_bert(text_batch, lang, version)
        for phones, bert_features, norm_text in zip(phones_list, bert_list, norm_texts):
            if phones is None or norm_text == "":
                continue
            result.append(
                {
                    "phones": phones,
                    "bert_features": bert_features,
                    "norm_text": norm_text,
                }
            )
        return result

    @torch.jit.script
    def build_phone_level_feature(res: torch.Tensor, word2ph: torch.IntTensor) -> torch.Tensor:
        """Expand word-level BERT features to phone level.

        Args:
            res: [N_words, hidden_dim] word-level features.
            word2ph: [N_words]; element i is how many phones word i maps to
                (i.e. how many times its feature row is repeated).

        Returns:
            [sum(word2ph), hidden_dim] phone-level features.

        NOTE(review): scripted as a free function (no `self`); it is not called
        by any method in this file — confirm external callers before removing.
        """
        phone_level_feature = []
        for i in range(word2ph.shape[0]):
            repeat_feature = res[i].repeat(word2ph[i].item(), 1)
            phone_level_feature.append(repeat_feature)
        return torch.cat(phone_level_feature, dim=0)

    def pre_seg_text(self, text: str, lang: str, text_split_method: str):
        """Normalize and split *text* into sentence-sized chunks.

        Applies the configured split method, drops blank/symbol-only lines,
        merges very short fragments, guarantees terminal punctuation, and
        breaks chunks longer than 510 chars so BERT's input limit holds.
        """
        text = text.strip("\n")
        if len(text) == 0:
            return []
        # Very short leading fragment: prepend a period so segmentation behaves.
        if text[0] not in splits and len(get_first(text)) < 4:
            text = "。" + text if lang != "en" else "." + text
        print(i18n("实际输入的目标文本:"))
        print(text)

        seg_method = get_seg_method(text_split_method)
        text = seg_method(text)

        while "\n\n" in text:
            text = text.replace("\n\n", "\n")

        _texts = text.split("\n")
        _texts = self.filter_text(_texts)
        _texts = merge_short_text_in_array(_texts, 5)
        texts = []

        for text in _texts:
            # Skip blank lines so empty input lines cannot crash downstream.
            if len(text.strip()) == 0:
                continue
            # Skip chunks that are pure punctuation/symbols.
            if not re.sub(r"\W+", "", text):
                continue
            if text[-1] not in splits:
                text += "。" if lang != "en" else "."

            # Chunks over 510 chars would overflow BERT's input; split them.
            if len(text) > 510:
                texts.extend(split_big_text(text))
            else:
                texts.append(text)

        print(i18n("实际输入的目标文本(切句后):"))
        print(texts)
        return texts

    def segment_and_extract_feature_for_text(
        self, text: str, language: str, version: str = "v1"
    ) -> Tuple[list, torch.Tensor, str]:
        """Thin wrapper: phones/BERT/normalized text for a single sentence."""
        return self.get_phones_and_bert(text, language, version)

    def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False):
        """Convert one sentence to (phones, bert_features, norm_text).

        Handles single-language modes ("en", "all_zh", ...) directly and
        mixed/auto modes by language-segmenting the text first. When the
        result has fewer than 6 phones, retries once with a "." prefix
        (*final* guards against infinite recursion).
        """
        with self.bert_lock:
            if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
                # Collapse runs of whitespace to single spaces.
                formattext = text
                while "  " in formattext:
                    formattext = formattext.replace("  ", " ")
                if language == "all_zh":
                    if re.search(r"[A-Za-z]", formattext):
                        # Latin letters inside Chinese: uppercase and re-normalize,
                        # then reprocess as mixed "zh".
                        formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
                        formattext = chinese.mix_text_normalize(formattext)
                        return self.get_phones_and_bert(formattext, "zh", version)
                    else:
                        phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
                        bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
                elif language == "all_yue" and re.search(r"[A-Za-z]", formattext):
                    formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
                    formattext = chinese.mix_text_normalize(formattext)
                    return self.get_phones_and_bert(formattext, "yue", version)
                else:
                    phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
                    # Non-Chinese languages get zero BERT features (dim 1024).
                    bert = torch.zeros(
                        (1024, len(phones)),
                        dtype=torch.float32,
                    ).to(self.device)
            elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
                textlist = []
                langlist = []
                if language == "auto":
                    for tmp in LangSegmenter.getTexts(text):
                        langlist.append(tmp["lang"])
                        textlist.append(tmp["text"])
                elif language == "auto_yue":
                    for tmp in LangSegmenter.getTexts(text):
                        if tmp["lang"] == "zh":
                            tmp["lang"] = "yue"
                        langlist.append(tmp["lang"])
                        textlist.append(tmp["text"])
                else:
                    for tmp in LangSegmenter.getTexts(text):
                        if tmp["lang"] == "en":
                            langlist.append(tmp["lang"])
                        else:
                            # CJK ideographs are ambiguous across zh/ja/ko:
                            # trust the user-supplied language.
                            langlist.append(language)
                        textlist.append(tmp["text"])
                phones_list = []
                bert_list = []
                norm_text_list = []
                for i in range(len(textlist)):
                    lang = langlist[i]
                    phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
                    bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
                    phones_list.append(phones)
                    norm_text_list.append(norm_text)
                    bert_list.append(bert)
                bert = torch.cat(bert_list, dim=1)
                phones = sum(phones_list, [])
                norm_text = "".join(norm_text_list)

            # Too few phones tends to break synthesis: pad with "." and retry once.
            if not final and len(phones) < 6:
                return self.get_phones_and_bert("." + text, language, version, final=True)

            return phones, bert, norm_text

    def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
        """Return phone-level BERT features, shape [hidden_dim, sum(word2ph)].

        Assumes the tokenizer yields one token per character of *text*
        (enforced by the assert below).
        """
        with torch.no_grad():
            inputs = self.tokenizer(text, return_tensors="pt")
            for i in inputs:
                inputs[i] = inputs[i].to(self.device)
            res = self.bert_model(**inputs, output_hidden_states=True)
            # Third-to-last hidden layer; drop [CLS]/[SEP]. Kept on device
            # until the final .cpu() below.
            res = torch.cat(res["hidden_states"][-3:-2], -1)[0][1:-1]
        assert len(word2ph) == len(text)
        # Vectorized word->phone expansion via repeat_interleave (no Python loop).
        word2ph_tensor = torch.tensor(word2ph, device=res.device)
        indices = torch.repeat_interleave(torch.arange(len(word2ph), device=res.device), word2ph_tensor)
        phone_level_feature = res[indices]
        # Move to CPU only once, at the end.
        phone_level_feature = phone_level_feature.cpu()
        return phone_level_feature.T

    def clean_text_inf(self, text: str, language: str, version: str = "v2"):
        """Normalize *text* and return (phone_id_sequence, word2ph, norm_text)."""
        language = language.replace("all_", "")
        phones, word2ph, norm_text = clean_text(text, language, version)
        phones = cleaned_text_to_sequence(phones, version)
        return phones, word2ph, norm_text

    def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str):
        """Real BERT features for Chinese; zeros (1024 x len(phones)) otherwise."""
        language = language.replace("all_", "")
        if language == "zh":
            feature = self.get_bert_feature(norm_text, word2ph).to(self.device)
        else:
            feature = torch.zeros(
                (1024, len(phones)),
                dtype=torch.float32,
            ).to(self.device)
        return feature

    def filter_text(self, texts):
        """Drop None/blank entries; raise if *texts* contains no usable text at all."""
        if all(text in [None, " ", "\n", ""] for text in texts):
            raise ValueError(i18n("请输入有效文本"))
        return [text for text in texts if text not in [None, " ", ""]]

    def replace_consecutive_punctuation(self, text):
        """Collapse runs of punctuation (e.g. "!!!" or "?!") to the first mark."""
        punctuations = "".join(re.escape(p) for p in punctuation)
        pattern = f"([{punctuations}])([{punctuations}])+"
        return re.sub(pattern, r"\1", text)

    def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
        """Segment *text* and extract features sentence-by-sentence (non-batched).

        Same output contract as preprocess1.
        """
        print(f"############ {i18n('切分文本')} ############")
        text = self.replace_consecutive_punctuation(text)
        texts = self.pre_seg_text(text, lang, text_split_method)
        result = []
        print(f"############ {i18n('提取文本Bert特征')} ############")
        for text in tqdm(texts):
            # BUG FIX: the previous revision commented out this call but kept the
            # loop body referencing phones/bert_features/norm_text (NameError),
            # and called extract_bert_features(texts) while discarding its result.
            phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
            if phones is None or norm_text == "":
                continue
            result.append(
                {
                    "phones": phones,
                    "bert_features": bert_features,
                    "norm_text": norm_text,
                }
            )
        return result

    def batch_get_phones_and_bert(self, texts: List[str], language: str, version: str):
        """Batched variant of get_phones_and_bert for a list of same-language sentences.

        Returns (phones_list, bert_list, norm_text_list), one entry per input.
        BERT features come from last_hidden_state (NOTE(review): the single-text
        path uses the third-to-last layer — confirm this difference is intended).
        """
        phones_list = []
        bert_list = []
        norm_text_list = []

        # Per-sentence cleaning: (phones, word2ph, norm_text).
        format_texts = [self.clean_text_inf(t, language, version) for t in texts]
        norm_texts = [x[2] for x in format_texts]
        word2phs = [x[1] for x in format_texts]

        # One padded batch through the tokenizer and BERT.
        inputs = self.tokenizer(norm_texts, return_tensors="pt", padding=True, truncation=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.bert_model(**inputs, output_hidden_states=True)
        hidden_states = outputs.last_hidden_state  # [batch_size, seq_len, hidden_dim]
        attention_mask = inputs["attention_mask"]

        for i in range(len(texts)):
            phones = cleaned_text_to_sequence(format_texts[i][0], version)
            word2ph = word2phs[i]

            # BUG FIX: clean_text yields word2ph=None for non-Chinese languages;
            # mirror get_bert_inf and emit zero features instead of crashing.
            if word2ph is None:
                bert_list.append(torch.zeros((1024, len(phones)), dtype=torch.float32))
                phones_list.append(phones)
                norm_text_list.append(norm_texts[i])
                continue

            # BUG FIX: slice by this row's real length from attention_mask.
            # The previous [1:-1] stripped [CLS] and the last *pad* position,
            # leaving [SEP]/padding embeddings inside res for shorter rows.
            seq_len = int(attention_mask[i].sum().item())
            res = hidden_states[i][1 : seq_len - 1].cpu()  # drop [CLS] and [SEP]
            word_features = []
            for j in range(len(word2ph)):
                if j >= res.shape[0]:
                    print(f"警告:BERT输出不足,跳过第 {i} 句中第 {j} 个 token")
                    continue
                word_features.append(res[j].repeat(word2ph[j], 1))
            phone_level_feature = torch.cat(word_features, dim=0)
            bert_list.append(phone_level_feature.T)
            phones_list.append(phones)
            norm_text_list.append(norm_texts[i])
        return phones_list, bert_list, norm_text_list