diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor_onnx.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor_onnx.py
deleted file mode 100644
index 42a8e2b1..00000000
--- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor_onnx.py
+++ /dev/null
@@ -1,233 +0,0 @@
-import os
-import sys
-
-from tqdm import tqdm
-
-now_dir = os.getcwd()
-sys.path.append(now_dir)
-
-import re
-from text.LangSegmenter import LangSegmenter
-from typing import Dict, List, Tuple
-from text.cleaner import clean_text
-from text import cleaned_text_to_sequence
-from transformers import PreTrainedTokenizerFast
-from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
-import onnxruntime as ort
-import numpy as np
-
-from tools.i18n.i18n import I18nAuto, scan_language_list
-
-language = os.environ.get("language", "Auto")
-language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
-i18n = I18nAuto(language=language)
-punctuation = set(["!", "?", "…", ",", ".", "-"])
-
-
-def get_first(text: str) -> str:
-    pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
-    text = re.split(pattern, text)[0].strip()
-    return text
-
-
-def merge_short_text_in_array(texts: str, threshold: int) -> list:
-    if (len(texts)) < 2:
-        return texts
-    result = []
-    text = ""
-    for ele in texts:
-        text += ele
-        if len(text) >= threshold:
-            result.append(text)
-            text = ""
-    if len(text) > 0:
-        if len(result) == 0:
-            result.append(text)
-        else:
-            result[len(result) - 1] += text
-    return result
-
-
-class TextPreprocessorOnnx:
-    def __init__(self, onnx_package: str):
-        self.bert_model = ort.InferenceSession(os.path.join(onnx_package, "chinese-roberta-wwm-ext-large.onnx"))
-        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=os.path.join(onnx_package, "tokenizer.json"))
-
-    def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
-        print(f"############ {i18n('切分文本')} ############")
-        text = self.replace_consecutive_punctuation(text)
-        texts = self.pre_seg_text(text, lang, text_split_method)
-        result = []
-        print(f"############ {i18n('提取文本Bert特征')} ############")
-        for text in tqdm(texts):
-            phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
-            if phones is None or norm_text == "":
-                continue
-            res = {
-                "phones": phones,
-                "bert_features": bert_features,
-                "norm_text": norm_text,
-            }
-            result.append(res)
-        return result
-
-    def pre_seg_text(self, text: str, lang: str, text_split_method: str):
-        text = text.strip("\n")
-        if len(text) == 0:
-            return []
-        if text[0] not in splits and len(get_first(text)) < 4:
-            text = "。" + text if lang != "en" else "." + text
-        print(i18n("实际输入的目标文本:"))
-        print(text)
-
-        seg_method = get_seg_method(text_split_method)
-        text = seg_method(text)
-
-        while "\n\n" in text:
-            text = text.replace("\n\n", "\n")
-
-        _texts = text.split("\n")
-        _texts = self.filter_text(_texts)
-        _texts = merge_short_text_in_array(_texts, 5)
-        texts = []
-
-        for text in _texts:
-            # 解决输入目标文本的空行导致报错的问题
-            if len(text.strip()) == 0:
-                continue
-            if not re.sub("\W+", "", text):
-                # 检测一下,如果是纯符号,就跳过。
-                continue
-            if text[-1] not in splits:
-                text += "。" if lang != "en" else "."
-
-            # 解决句子过长导致Bert报错的问题
-            if len(text) > 510:
-                texts.extend(split_big_text(text))
-            else:
-                texts.append(text)
-
-        print(i18n("实际输入的目标文本(切句后):"))
-        print(texts)
-        return texts
-
-    def segment_and_extract_feature_for_text(
-        self, text: str, language: str, version: str = "v1"
-    ) -> Tuple[list, np.ndarray, str]:
-        return self.get_phones_and_bert(text, language, version)
-
-    def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False):
-        text = re.sub(r' {2,}', ' ', text)
-        textlist = []
-        langlist = []
-        if language == "all_zh":
-            for tmp in LangSegmenter.getTexts(text,"zh"):
-                langlist.append(tmp["lang"])
-                textlist.append(tmp["text"])
-        elif language == "all_yue":
-            for tmp in LangSegmenter.getTexts(text,"zh"):
-                if tmp["lang"] == "zh":
-                    tmp["lang"] = "yue"
-                langlist.append(tmp["lang"])
-                textlist.append(tmp["text"])
-        elif language == "all_ja":
-            for tmp in LangSegmenter.getTexts(text,"ja"):
-                langlist.append(tmp["lang"])
-                textlist.append(tmp["text"])
-        elif language == "all_ko":
-            for tmp in LangSegmenter.getTexts(text,"ko"):
-                langlist.append(tmp["lang"])
-                textlist.append(tmp["text"])
-        elif language == "en":
-            langlist.append("en")
-            textlist.append(text)
-        elif language == "auto":
-            for tmp in LangSegmenter.getTexts(text):
-                langlist.append(tmp["lang"])
-                textlist.append(tmp["text"])
-        elif language == "auto_yue":
-            for tmp in LangSegmenter.getTexts(text):
-                if tmp["lang"] == "zh":
-                    tmp["lang"] = "yue"
-                langlist.append(tmp["lang"])
-                textlist.append(tmp["text"])
-        else:
-            for tmp in LangSegmenter.getTexts(text):
-                if langlist:
-                    if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
-                        textlist[-1] += tmp["text"]
-                        continue
-                if tmp["lang"] == "en":
-                    langlist.append(tmp["lang"])
-                else:
-                    # 因无法区别中日韩文汉字,以用户输入为准
-                    langlist.append(language)
-                textlist.append(tmp["text"])
-        # print(textlist)
-        # print(langlist)
-        phones_list = []
-        bert_list = []
-        norm_text_list = []
-        for i in range(len(textlist)):
-            lang = langlist[i]
-            phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
-            bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
-            phones_list.append(phones)
-            norm_text_list.append(norm_text)
-            bert_list.append(bert)
-        bert = np.concatenate(bert_list, axis=1)
-        phones = sum(phones_list, [])
-        norm_text = "".join(norm_text_list)
-
-        if not final and len(phones) < 6:
-            return self.get_phones_and_bert("." + text, language, version, final=True)
-
-        return phones, bert, norm_text
-
-    def get_bert_feature(self, text: str, word2ph: list) -> np.ndarray:
-        inputs = self.tokenizer(text, return_tensors="np")
-        [res] = self.bert_model.run(None, {
-            "input_ids": inputs["input_ids"]
-        })
-        assert len(word2ph) == len(text)
-        phone_level_feature = []
-        for i in range(len(word2ph)):
-            repeat_feature = np.repeat(res[i:i+1], word2ph[i], axis=0)
-            phone_level_feature.append(repeat_feature)
-        phone_level_feature = np.concatenate(phone_level_feature, axis=0)
-        return phone_level_feature.T
-
-    def clean_text_inf(self, text: str, language: str, version: str = "v2"):
-        language = language.replace("all_", "")
-        phones, word2ph, norm_text = clean_text(text, language, version)
-        phones = cleaned_text_to_sequence(phones, version)
-        return phones, word2ph, norm_text
-
-    def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str):
-        language = language.replace("all_", "")
-        if language == "zh":
-            feature = self.get_bert_feature(norm_text, word2ph)
-        else:
-            feature = np.zeros(
-                (1024, len(phones)),
-                dtype=np.float32,
-            )
-
-        return feature
-
-    def filter_text(self, texts):
-        _text = []
-        if all(text in [None, " ", "\n", ""] for text in texts):
-            raise ValueError(i18n("请输入有效文本"))
-        for text in texts:
-            if text in [None, " ", ""]:
-                pass
-            else:
-                _text.append(text)
-        return _text
-
-    def replace_consecutive_punctuation(self, text):
-        punctuations = "".join(re.escape(p) for p in punctuation)
-        pattern = f"([{punctuations}])([{punctuations}])+"
-        result = re.sub(pattern, r"\1", text)
-        return result
\ No newline at end of file
diff --git a/playground/export_hubert.py b/playground/export_hubert.py
deleted file mode 100644
index 463c6c9b..00000000
--- a/playground/export_hubert.py
+++ /dev/null
@@ -1,178 +0,0 @@
-import os
-import sys
-import torch
-import torchaudio
-import onnxruntime as ort
-import numpy as np
-import argparse
-from transformers import HubertModel, HubertConfig
-
-
-class HubertONNXExporter:
-    """Export and test HuBERT model to ONNX format"""
-
-    def __init__(self, model_path="GPT_SoVITS/pretrained_models/chinese-hubert-base", output_path="playground/hubert/chinese-hubert-base.onnx"):
-        self.model_path = model_path
-        self.onnx_path = output_path
-        self.model = None
-        self.config = None
-
-    def setup_model(self):
-        """Configure and load the HuBERT model for ONNX export"""
-        # Configure for better ONNX compatibility
-        self.config = HubertConfig.from_pretrained(self.model_path)
-        self.config._attn_implementation = "eager"  # Use standard attention
-        self.config.apply_spec_augment = False  # Disable masking for inference
-        self.config.layerdrop = 0.0  # Disable layer dropout
-
-        # Load the model
-        self.model = HubertModel.from_pretrained(
-            self.model_path,
-            config=self.config,
-            local_files_only=True
-        )
-        self.model.eval()
-
-    def export_to_onnx(self, dummy_length=16000):
-        """Export the model to ONNX format"""
-        if self.model is None:
-            raise ValueError("Model not initialized. Call setup_model() first.")
-
-        # Create dummy input (1 second at 16kHz)
-        dummy_input = torch.rand(1, dummy_length, dtype=torch.float32) - 0.5
-
-        # Export to ONNX
-        torch.onnx.export(
-            self.model,
-            dummy_input,
-            self.onnx_path,
-            export_params=True,
-            opset_version=11,
-            do_constant_folding=True,
-            input_names=['audio16k'],
-            output_names=['last_hidden_state'],
-            dynamic_axes={
-                'audio16k': {0: 'batch_size', 1: 'sequence_length'},
-                'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'}
-            }
-        )
-        print(f"[Success] Model exported to {self.onnx_path}")
-
-    def test_onnx_export_exists(self):
-        """Test that the ONNX model file was created"""
-        if os.path.exists(self.onnx_path):
-            print(f"[Success] ONNX model file exists at {self.onnx_path}")
-            return True
-        else:
-            print(f"[Error] ONNX model not found at {self.onnx_path}")
-            return False
-
-    def _load_and_preprocess_audio(self, audio_path, max_length=160000):
-        """Load and preprocess audio file"""
-        waveform, sample_rate = torchaudio.load(audio_path)
-
-        # Resample to 16kHz if needed
-        if sample_rate != 16000:
-            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
-            waveform = resampler(waveform)
-
-        # Take first channel
-        if waveform.shape[0] > 1:
-            waveform = waveform[0:1]
-
-        # Limit length for testing (10 seconds at 16kHz)
-        if waveform.shape[1] > max_length:
-            waveform = waveform[:, :max_length]
-
-        # make a zero tensor that has length 3200*0.3
-        zero_tensor = torch.zeros((1, 9600), dtype=torch.float32)
-
-        print("waveform shape and zero wave shape", waveform.shape, zero_tensor.shape)
-
-        # concate zero_tensor with waveform
-        waveform = torch.cat([waveform, zero_tensor], dim=1)
-
-        return waveform
-
-    def test_torch_vs_onnx(self, audio_path="playground/ref/audio.wav"):
-        """Test that ONNX model outputs match PyTorch model outputs"""
-        if not os.path.exists(audio_path):
-            print(f"[Skip] Test audio file not found at {audio_path}")
-            return False
-
-        if self.model is None:
-            raise ValueError("Model not initialized. Call setup_model() first.")
-
-        # Load and preprocess audio
-        waveform = self._load_and_preprocess_audio(audio_path)
-
-        # PyTorch inference
-        with torch.no_grad():
-            torch_output = self.model(waveform)
-            torch_hidden_states = torch_output.last_hidden_state
-
-        # ONNX inference
-        ort_session = ort.InferenceSession(self.onnx_path)
-        input_values = waveform.numpy().astype(np.float32)
-        ort_inputs = {ort_session.get_inputs()[0].name: input_values}
-        ort_outputs = ort_session.run(None, ort_inputs)
-        onnx_hidden_states = ort_outputs[0]
-
-        # Compare outputs
-        torch_numpy = torch_hidden_states.numpy()
-        diff = np.abs(torch_numpy - onnx_hidden_states).mean()
-
-        success = diff <= 1e-5
-        status = "[Success]" if success else "[Fail]"
-
-        print(f"{status} ONNX vs PyTorch comparison")
-        print(f" > mean_difference={diff}")
-        print(f" > torch_shape={torch_numpy.shape}")
-        print(f" > onnx_shape={onnx_hidden_states.shape}")
-
-        return success
-
-    def run_full_export_and_test(self):
-        """Run the complete export and testing pipeline"""
-        print("Starting HuBERT ONNX export and testing...")
-
-        # Create output directory if it doesn't exist
-        os.makedirs(os.path.dirname(self.onnx_path), exist_ok=True)
-
-        # Setup model
-        self.setup_model()
-
-        # Export to ONNX
-        self.export_to_onnx()
-
-        # Test export
-        self.test_onnx_export_exists()
-        self.test_torch_vs_onnx()
-
-        print("Export and testing complete!")
-
-
-def main():
-    """Main execution function"""
-    parser = argparse.ArgumentParser(description="Export HuBERT model to ONNX format")
-    parser.add_argument(
-        "--model_path",
-        type=str,
-        default="GPT_SoVITS/pretrained_models/chinese-hubert-base",
-        help="Path to the HuBERT model directory (default: GPT_SoVITS/pretrained_models/chinese-hubert-base)"
-    )
-    parser.add_argument(
-        "--output_path",
-        type=str,
-        default="playground/hubert/chinese-hubert-base.onnx",
-        help="Output path for the ONNX model (default: playground/hubert/chinese-hubert-base.onnx)"
-    )
-
-    args = parser.parse_args()
-
-    exporter = HubertONNXExporter(model_path=args.model_path, output_path=args.output_path)
-    exporter.run_full_export_and_test()
-
-
-if __name__ == "__main__":
-    main()
\ No newline at end of file
diff --git a/playground/freerun.py b/playground/freerun.py
deleted file mode 100644
index 9cfefbeb..00000000
--- a/playground/freerun.py
+++ /dev/null
@@ -1,146 +0,0 @@
-import onnxruntime as ort
-import numpy as np
-import onnx
-from tqdm import tqdm
-import torchaudio
-import torch
-from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx
-
-
-MODEL_PATH = "onnx/v1_export/v1"
-
-def audio_postprocess(
-    audios,
-    fragment_interval: float = 0.3,
-):
-    zero_wav = np.zeros((int(32000 * fragment_interval),)).astype(np.float32)
-    for i, audio in enumerate(audios):
-        max_audio = np.abs(audio).max()  # 简单防止16bit爆音
-        if max_audio > 1:
-            audio /= max_audio
-        audio = np.concatenate([audio, zero_wav], axis=0)
-        audios[i] = audio
-
-    audio = np.concatenate(audios, axis=0)
-
-    # audio = (audio * 32768).astype(np.int16)
-
-    audio_tensor = torch.from_numpy(audio).unsqueeze(0)
-
-    torchaudio.save('playground/output.wav', audio_tensor, 32000)
-
-    return audio
-
-def load_audio(audio_path):
-    """Load and preprocess audio file to 32k"""
-    waveform, sample_rate = torchaudio.load(audio_path)
-
-    # Resample to 32kHz if needed
-    if sample_rate != 32000:
-        resampler = torchaudio.transforms.Resample(sample_rate, 32000)
-        waveform = resampler(waveform)
-
-    # Take first channel
-    if waveform.shape[0] > 1:
-        waveform = waveform[0:1]
-
-    return waveform
-
-def audio_preprocess(audio_path):
-    """Get HuBERT features for the audio file"""
-    waveform = load_audio(audio_path)
-    ort_session = ort.InferenceSession(MODEL_PATH + "_export_audio_preprocess.onnx")
-    ort_inputs = {ort_session.get_inputs()[0].name: waveform.numpy()}
-    [hubert_feature, spectrum, sv_emb] = ort_session.run(None, ort_inputs)
-    return hubert_feature, spectrum, sv_emb
-
-def preprocess_text(text:str):
-    preprocessor = TextPreprocessorOnnx("playground/bert")
-    [phones, bert_features, norm_text] = preprocessor.segment_and_extract_feature_for_text(text, 'all_zh', 'v2')
-    phones = np.expand_dims(np.array(phones, dtype=np.int64), axis=0)
-    return phones, bert_features.T.astype(np.float32)
-
-
-# input_phones_saved = np.load("playground/ref/input_phones.npy")
-# input_bert_saved = np.load("playground/ref/input_bert.npy").T.astype(np.float32)
-[input_phones, input_bert] = preprocess_text("天上的风筝在天上飞,地上的人儿在地上追。")
-
-
-# ref_phones = np.load("playground/ref/ref_phones.npy")
-# ref_bert = np.load("playground/ref/ref_bert.npy").T.astype(np.float32)
-[ref_phones, ref_bert] = preprocess_text("近日江苏苏州荷花市集开张热闹与浪漫交织")
-
-
-[audio_prompt_hubert, spectrum, sv_emb] = audio_preprocess("playground/ref/audio.wav")
-
-# audio_prompt_hubert_saved = np.load("playground/ref/audio_prompt_hubert.npy").astype(np.float32)
-
-top_k = np.array([15], dtype=np.int64)
-top_p = np.array([1.0], dtype=np.float32)
-repetition_penalty = np.array([1.0], dtype=np.float32)
-temperature = np.array([1.0], dtype=np.float32)
-
-t2s_init_stage = ort.InferenceSession(MODEL_PATH+"_export_t2s_init_stage.onnx")
-# t2s_init_step = ort.InferenceSession(MODEL_PATH+"_export_t2s_init_step.onnx")
-
-[x, prompts, init_k, init_v, x_seq_len, y_seq_len] = t2s_init_stage.run(None, {
-    "input_text_phones": input_phones,
-    "input_text_bert": input_bert,
-    "ref_text_phones": ref_phones,
-    "ref_text_bert": ref_bert,
-    "hubert_ssl_content": audio_prompt_hubert,
-})
-empty_tensor = np.empty((1,0,512)).astype(np.float32)
-
-t2s_stage_decoder = ort.InferenceSession(MODEL_PATH+"_export_t2s_stage_decoder.onnx")
-y, k, v, y_emb, logits, samples = t2s_stage_decoder.run(None, {
-    "ix": x,
-    "iy": prompts,
-    "ik": init_k,
-    "iv": init_v,
-    "iy_emb": empty_tensor,
-    "top_k": top_k,
-    "top_p": top_p,
-    "repetition_penalty": repetition_penalty,
-    "temperature": temperature,
-    "if_init_step": np.array([1]).astype(np.int64),
-    "x_seq_len": np.array([x_seq_len]).astype(np.int64),
-    "y_seq_len": np.array([y_seq_len]).astype(np.int64)
-})
-
-for idx in tqdm(range(1, 1500)):
-    k = np.pad(k, ((0,0), (0,1), (0,0), (0,0)))
-    v = np.pad(v, ((0,0), (0,1), (0,0), (0,0)))
-    y_seq_len = np.array([y.shape[1]]).astype(np.int64)
-    # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
-    [y, k, v, y_emb, logits, samples] = t2s_stage_decoder.run(None, {
-        "ix": empty_tensor,
-        "iy": y,
-        "ik": k,
-        "iv": v,
-        "iy_emb": y_emb,
-        "top_k": top_k,
-        "top_p": top_p,
-        "repetition_penalty": repetition_penalty,
-        "temperature": temperature,
-        "if_init_step": np.array([0]).astype(np.int64),
-        "x_seq_len": np.array([x_seq_len]).astype(np.int64),
-        "y_seq_len": y_seq_len
-    })
-    if np.argmax(logits, axis=-1)[0] == 1024 or samples[0, 0] == 1024:  # 1024 is the EOS token
-        break
-y = y[:,:-1]
-
-
-pred_semantic = np.expand_dims(y[:, -idx:], axis=0)
-
-vtis = ort.InferenceSession(MODEL_PATH+"_export_vits.onnx")
-
-[audio] = vtis.run(None, {
-    "input_text_phones": input_phones,
-    "pred_semantic": pred_semantic,
-    "spectrum": spectrum.astype(np.float32),
-    # "sv_emb": sv_emb.astype(np.float32)
-})
-
-audio_postprocess([audio])
diff --git a/playground/ref/audio.wav b/playground/ref/audio.wav
deleted file mode 100644
index 78320d71..00000000
Binary files a/playground/ref/audio.wav and /dev/null differ