diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor_onnx.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor_onnx.py
new file mode 100644
index 00000000..42a8e2b1
--- /dev/null
+++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor_onnx.py
@@ -0,0 +1,233 @@
+import os
+import sys
+
+from tqdm import tqdm
+
+now_dir = os.getcwd()
+sys.path.append(now_dir)
+
+import re
+from text.LangSegmenter import LangSegmenter
+from typing import Dict, List, Tuple
+from text.cleaner import clean_text
+from text import cleaned_text_to_sequence
+from transformers import PreTrainedTokenizerFast
+from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
+import onnxruntime as ort
+import numpy as np
+
+from tools.i18n.i18n import I18nAuto, scan_language_list
+
+language = os.environ.get("language", "Auto")
+language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
+i18n = I18nAuto(language=language)
+punctuation = set(["!", "?", "…", ",", ".", "-"])
+
+
+def get_first(text: str) -> str:
+    pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
+    text = re.split(pattern, text)[0].strip()
+    return text
+
+
+def merge_short_text_in_array(texts: list, threshold: int) -> list:
+    if len(texts) < 2:
+        return texts
+    result = []
+    text = ""
+    for ele in texts:
+        text += ele
+        if len(text) >= threshold:
+            result.append(text)
+            text = ""
+    if len(text) > 0:
+        if len(result) == 0:
+            result.append(text)
+        else:
+            result[len(result) - 1] += text
+    return result
+
+
+class TextPreprocessorOnnx:
+    def __init__(self, onnx_package: str):
+        self.bert_model = ort.InferenceSession(os.path.join(onnx_package, "chinese-roberta-wwm-ext-large.onnx"))
+        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=os.path.join(onnx_package, "tokenizer.json"))
+
+    def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
+        print(f"############ {i18n('切分文本')} ############")
+        text = self.replace_consecutive_punctuation(text)
+        texts = self.pre_seg_text(text, lang, text_split_method)
+        result = []
+        print(f"############ {i18n('提取文本Bert特征')} ############")
+        for text in tqdm(texts):
+            phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
+            if phones is None or norm_text == "":
+                continue
+            res = {
+                "phones": phones,
+                "bert_features": bert_features,
+                "norm_text": norm_text,
+            }
+            result.append(res)
+        return result
+
+    def pre_seg_text(self, text: str, lang: str, text_split_method: str):
+        text = text.strip("\n")
+        if len(text) == 0:
+            return []
+        if text[0] not in splits and len(get_first(text)) < 4:
+            text = "。" + text if lang != "en" else "." + text
+        print(i18n("实际输入的目标文本:"))
+        print(text)
+
+        seg_method = get_seg_method(text_split_method)
+        text = seg_method(text)
+
+        while "\n\n" in text:
+            text = text.replace("\n\n", "\n")
+
+        _texts = text.split("\n")
+        _texts = self.filter_text(_texts)
+        _texts = merge_short_text_in_array(_texts, 5)
+        texts = []
+
+        for text in _texts:
+            # Skip blank lines in the target text; they would otherwise raise errors downstream
+            if len(text.strip()) == 0:
+                continue
+            if not re.sub(r"\W+", "", text):
+                # Skip segments that consist of punctuation/symbols only
+                continue
+            if text[-1] not in splits:
+                text += "。" if lang != "en" else "."
+
+            # Split overlong sentences (>510 chars) so BERT does not error out
+            if len(text) > 510:
+                texts.extend(split_big_text(text))
+            else:
+                texts.append(text)
+
+        print(i18n("实际输入的目标文本(切句后):"))
+        print(texts)
+        return texts
+
+    def segment_and_extract_feature_for_text(
+        self, text: str, language: str, version: str = "v1"
+    ) -> Tuple[list, np.ndarray, str]:
+        return self.get_phones_and_bert(text, language, version)
+
+    def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False):
+        text = re.sub(r' {2,}', ' ', text)
+        textlist = []
+        langlist = []
+        if language == "all_zh":
+            for tmp in LangSegmenter.getTexts(text, "zh"):
+                langlist.append(tmp["lang"])
+                textlist.append(tmp["text"])
+        elif language == "all_yue":
+            for tmp in LangSegmenter.getTexts(text, "zh"):
+                if tmp["lang"] == "zh":
+                    tmp["lang"] = "yue"
+                langlist.append(tmp["lang"])
+                textlist.append(tmp["text"])
+        elif language == "all_ja":
+            for tmp in LangSegmenter.getTexts(text, "ja"):
+                langlist.append(tmp["lang"])
+                textlist.append(tmp["text"])
+        elif language == "all_ko":
+            for tmp in LangSegmenter.getTexts(text, "ko"):
+                langlist.append(tmp["lang"])
+                textlist.append(tmp["text"])
+        elif language == "en":
+            langlist.append("en")
+            textlist.append(text)
+        elif language == "auto":
+            for tmp in LangSegmenter.getTexts(text):
+                langlist.append(tmp["lang"])
+                textlist.append(tmp["text"])
+        elif language == "auto_yue":
+            for tmp in LangSegmenter.getTexts(text):
+                if tmp["lang"] == "zh":
+                    tmp["lang"] = "yue"
+                langlist.append(tmp["lang"])
+                textlist.append(tmp["text"])
+        else:
+            for tmp in LangSegmenter.getTexts(text):
+                if langlist:
+                    if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
+                        textlist[-1] += tmp["text"]
+                        continue
+                if tmp["lang"] == "en":
+                    langlist.append(tmp["lang"])
+                else:
+                    # CJK han characters cannot be reliably told apart, so trust the user-specified language
+                    langlist.append(language)
+                textlist.append(tmp["text"])
+        # print(textlist)
+        # print(langlist)
+        phones_list = []
+        bert_list = []
+        norm_text_list = []
+        for i in range(len(textlist)):
+            lang = langlist[i]
+            phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
+            bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
+            phones_list.append(phones)
+            norm_text_list.append(norm_text)
+            bert_list.append(bert)
+        bert = np.concatenate(bert_list, axis=1)
+        phones = sum(phones_list, [])
+        norm_text = "".join(norm_text_list)
+
+        if not final and len(phones) < 6:
+            return self.get_phones_and_bert("." + text, language, version, final=True)
+
+        return phones, bert, norm_text
+
+    def get_bert_feature(self, text: str, word2ph: list) -> np.ndarray:
+        inputs = self.tokenizer(text, return_tensors="np")
+        [res] = self.bert_model.run(None, {
+            "input_ids": inputs["input_ids"]
+        })
+        assert len(word2ph) == len(text)
+        phone_level_feature = []
+        for i in range(len(word2ph)):
+            repeat_feature = np.repeat(res[i:i+1], word2ph[i], axis=0)
+            phone_level_feature.append(repeat_feature)
+        phone_level_feature = np.concatenate(phone_level_feature, axis=0)
+        return phone_level_feature.T
+
+    def clean_text_inf(self, text: str, language: str, version: str = "v2"):
+        language = language.replace("all_", "")
+        phones, word2ph, norm_text = clean_text(text, language, version)
+        phones = cleaned_text_to_sequence(phones, version)
+        return phones, word2ph, norm_text
+
+    def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str):
+        language = language.replace("all_", "")
+        if language == "zh":
+            feature = self.get_bert_feature(norm_text, word2ph)
+        else:
+            feature = np.zeros(
+                (1024, len(phones)),
+                dtype=np.float32,
+            )
+
+        return feature
+
+    def filter_text(self, texts):
+        _text = []
+        if all(text in [None, " ", "\n", ""] for text in texts):
+            raise ValueError(i18n("请输入有效文本"))
+        for text in texts:
+            if text in [None, " ", ""]:
+                pass
+            else:
+                _text.append(text)
+        return _text
+
+    def replace_consecutive_punctuation(self, text):
+        punctuations = "".join(re.escape(p) for p in punctuation)
+        pattern = f"([{punctuations}])([{punctuations}])+"
+        result = re.sub(pattern, r"\1", text)
+        return result
\ No newline at end of file
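For reference, a minimal usage sketch of the new class (illustrative, not part of the patch; the directory handed to the constructor must contain the exported chinese-roberta-wwm-ext-large.onnx and tokenizer.json, as playground/freerun.py below assumes with "playground/bert"):

    from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx

    preprocessor = TextPreprocessorOnnx("playground/bert")
    # phones: list of phoneme ids; bert: (1024, n_phones) array; norm_text: normalized string
    phones, bert, norm_text = preprocessor.segment_and_extract_feature_for_text("今日天气不错。", "all_zh", "v2")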
diff --git a/GPT_SoVITS/onnx_export.py b/GPT_SoVITS/onnx_export.py
index 97b56b93..48d64a7e 100644
--- a/GPT_SoVITS/onnx_export.py
+++ b/GPT_SoVITS/onnx_export.py
@@ -402,15 +402,16 @@ if __name__ == "__main__":
     # version = "v2"
     # export(vits_path, gpt_path, exp_path, version)
 
-    # gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
-    # vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth"
-    # exp_path = "v2pro_export"
-    # version = "v2Pro"
-    gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
     gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
-    vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth"
-    exp_path = "v2proplus_export"
-    version = "v2ProPlus"
+    vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth"
+    exp_path = "v2pro_export"
+    version = "v2Pro"
     export(vits_path, gpt_path, exp_path, version)
+    # gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
+    # vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth"
+    # exp_path = "v2proplus_export"
+    # version = "v2ProPlus"
+    # export(vits_path, gpt_path, exp_path, version)
+
diff --git a/playground/freerun.py b/playground/freerun.py
index 2bad817a..684c664a 100644
--- a/playground/freerun.py
+++ b/playground/freerun.py
@@ -4,6 +4,8 @@ import onnx
 from tqdm import tqdm
 import torchaudio
 import torch
+from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx
+
 
 MODEL_PATH = "playground/v2proplus_export/v2proplus"
 
@@ -30,7 +32,7 @@ def audio_postprocess(
     return audio
 
 def load_and_preprocess_audio(audio_path, max_length=160000):
-    """Load and preprocess audio file"""
+    """Load and preprocess audio file to 16k"""
     waveform, sample_rate = torchaudio.load(audio_path)
 
     # Resample to 16kHz if needed
@@ -49,8 +51,6 @@ def load_and_preprocess_audio(audio_path, max_length=160000):
     # make a zero tensor that has length 3200*0.3
     zero_tensor = torch.zeros((1, 9600), dtype=torch.float32)
 
-    print("waveform shape and zero wave shape", waveform.shape, zero_tensor.shape)
-
     # concate zero_tensor with waveform
     waveform = torch.cat([waveform, zero_tensor], dim=1)
 
@@ -64,13 +64,25 @@ def get_audio_hubert(audio_path):
     hubert_feature = ort_session.run(None, ort_inputs)[0].astype(np.float32)
     # transpose axis 1 and 2 with numpy
     hubert_feature = hubert_feature.transpose(0, 2, 1)
-    print("Hubert feature shape:", hubert_feature.shape)
     return hubert_feature
 
-input_phones = np.load("playground/ref/input_phones.npy")
-input_bert = np.load("playground/ref/input_bert.npy").T.astype(np.float32)
-ref_phones = np.load("playground/ref/ref_phones.npy")
-ref_bert = np.load("playground/ref/ref_bert.npy").T.astype(np.float32)
+def preprocess_text(text: str):
+    preprocessor = TextPreprocessorOnnx("playground/bert")
+    [phones, bert_features, norm_text] = preprocessor.segment_and_extract_feature_for_text(text, 'all_zh', 'v2')
+    phones = np.expand_dims(np.array(phones, dtype=np.int64), axis=0)
+    return phones, bert_features.T.astype(np.float32)
+
+
+# input_phones_saved = np.load("playground/ref/input_phones.npy")
+# input_bert_saved = np.load("playground/ref/input_bert.npy").T.astype(np.float32)
+[input_phones, input_bert] = preprocess_text("震撼视角,感受成都世运会,闭幕式烟花")
+
+
+# ref_phones = np.load("playground/ref/ref_phones.npy")
+# ref_bert = np.load("playground/ref/ref_bert.npy").T.astype(np.float32)
+[ref_phones, ref_bert] = preprocess_text("今日江苏苏州荷花市集开张热闹与浪漫交织")
+
+
 audio_prompt_hubert = get_audio_hubert("playground/ref/audio.wav")
 
@@ -84,8 +96,6 @@ encoder = ort.InferenceSession(MODEL_PATH+"_export_t2s_encoder.onnx")
     "ssl_content": audio_prompt_hubert
 })
 
-print(x.shape, prompts.shape)
-
 fsdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_fsdec.onnx")
 sdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx")
 
@@ -124,6 +134,7 @@ if sample_rate != 32000:
     resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=32000)
     waveform = resampler(waveform)
 ref_audio = waveform.numpy().astype(np.float32)
+
 
 vtis = ort.InferenceSession(MODEL_PATH+"_export_vits.onnx")
 
@@ -131,6 +142,5 @@ vtis = ort.InferenceSession(MODEL_PATH+"_export_vits.onnx")
     "pred_semantic": pred_semantic,
     "ref_audio": ref_audio
 })
-print(audio.shape, audio.dtype, audio.min(), audio.max())
 
 audio_postprocess([audio])
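For clarity, the shape contract preprocess_text establishes for the downstream ONNX sessions; a sanity-check sketch using one of the sentences above (an illustration, not part of the patch):

    import numpy as np

    input_phones, input_bert = preprocess_text("震撼视角,感受成都世运会,闭幕式烟花")
    assert input_phones.dtype == np.int64 and input_phones.ndim == 2  # (1, n_phones) batch of phoneme ids
    assert input_bert.dtype == np.float32
    assert input_bert.shape == (input_phones.shape[1], 1024)  # phone-aligned BERT features, one 1024-dim row per phoneme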
diff --git a/playground/output.wav b/playground/output.wav
index d213ae43..ad75c328 100644
Binary files a/playground/output.wav and b/playground/output.wav differ
diff --git a/requirements.txt b/requirements.txt
index b09d2b79..7f24a84c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -33,7 +33,7 @@ rotary_embedding_torch
 ToJyutping
 g2pk2
 ko_pron
-opencc
+opencc==1.1.6
 python_mecab_ko; sys_platform != 'win32'
 fastapi[standard]>=0.115.2
 x_transformers
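Since onnx_export.py now writes the v2Pro export to v2pro_export/ while MODEL_PATH in playground/freerun.py still points at the earlier v2proplus_export, running the playground against a fresh export means adjusting that constant, along the lines of (a hypothetical path, assuming the export keeps the exp_path/name layout shown above):

    MODEL_PATH = "playground/v2pro_export/v2pro"  # hypothetical; must match the exp_path and file prefix produced by onnx_export.py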