feat: remove files unneeded for main

This commit is contained in:
zpeng11 2025-08-25 22:53:15 -04:00
parent 968ac4c264
commit fa84e262ae
4 changed files with 0 additions and 557 deletions

View File

@@ -1,233 +0,0 @@
import os
import sys

from tqdm import tqdm

now_dir = os.getcwd()
sys.path.append(now_dir)

import re

from text.LangSegmenter import LangSegmenter
from typing import Dict, List, Tuple
from text.cleaner import clean_text
from text import cleaned_text_to_sequence
from transformers import PreTrainedTokenizerFast
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
import onnxruntime as ort
import numpy as np

from tools.i18n.i18n import I18nAuto, scan_language_list

language = os.environ.get("language", "Auto")
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language)
punctuation = set(["!", "?", "…", ",", ".", "-"])


def get_first(text: str) -> str:
    pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
    text = re.split(pattern, text)[0].strip()
    return text


def merge_short_text_in_array(texts: list, threshold: int) -> list:
    if len(texts) < 2:
        return texts
    result = []
    text = ""
    for ele in texts:
        text += ele
        if len(text) >= threshold:
            result.append(text)
            text = ""
    if len(text) > 0:
        if len(result) == 0:
            result.append(text)
        else:
            result[len(result) - 1] += text
    return result


class TextPreprocessorOnnx:
    def __init__(self, onnx_package: str):
        self.bert_model = ort.InferenceSession(os.path.join(onnx_package, "chinese-roberta-wwm-ext-large.onnx"))
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=os.path.join(onnx_package, "tokenizer.json"))

    def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
        print(f"############ {i18n('切分文本')} ############")
        text = self.replace_consecutive_punctuation(text)
        texts = self.pre_seg_text(text, lang, text_split_method)
        result = []
        print(f"############ {i18n('提取文本Bert特征')} ############")
        for text in tqdm(texts):
            phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
            if phones is None or norm_text == "":
                continue
            res = {
                "phones": phones,
                "bert_features": bert_features,
                "norm_text": norm_text,
            }
            result.append(res)
        return result

    def pre_seg_text(self, text: str, lang: str, text_split_method: str):
        text = text.strip("\n")
        if len(text) == 0:
            return []
        if text[0] not in splits and len(get_first(text)) < 4:
            text = "。" + text if lang != "en" else "." + text
        print(i18n("实际输入的目标文本:"))
        print(text)

        seg_method = get_seg_method(text_split_method)
        text = seg_method(text)

        while "\n\n" in text:
            text = text.replace("\n\n", "\n")

        _texts = text.split("\n")
        _texts = self.filter_text(_texts)
        _texts = merge_short_text_in_array(_texts, 5)
        texts = []

        for text in _texts:
            # Skip blank lines in the input so they do not raise errors downstream
            if len(text.strip()) == 0:
                continue
            if not re.sub(r"\W+", "", text):
                # Skip segments that consist purely of symbols
                continue
            if text[-1] not in splits:
                text += "。" if lang != "en" else "."

            # Split overly long sentences so BERT does not fail on them
            if len(text) > 510:
                texts.extend(split_big_text(text))
            else:
                texts.append(text)

        print(i18n("实际输入的目标文本(切句后):"))
        print(texts)
        return texts

    def segment_and_extract_feature_for_text(
        self, text: str, language: str, version: str = "v1"
    ) -> Tuple[list, np.ndarray, str]:
        return self.get_phones_and_bert(text, language, version)

    def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False):
        text = re.sub(r' {2,}', ' ', text)
        textlist = []
        langlist = []
        if language == "all_zh":
            for tmp in LangSegmenter.getTexts(text, "zh"):
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        elif language == "all_yue":
            for tmp in LangSegmenter.getTexts(text, "zh"):
                if tmp["lang"] == "zh":
                    tmp["lang"] = "yue"
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        elif language == "all_ja":
            for tmp in LangSegmenter.getTexts(text, "ja"):
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        elif language == "all_ko":
            for tmp in LangSegmenter.getTexts(text, "ko"):
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        elif language == "en":
            langlist.append("en")
            textlist.append(text)
        elif language == "auto":
            for tmp in LangSegmenter.getTexts(text):
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        elif language == "auto_yue":
            for tmp in LangSegmenter.getTexts(text):
                if tmp["lang"] == "zh":
                    tmp["lang"] = "yue"
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        else:
            for tmp in LangSegmenter.getTexts(text):
                if langlist:
                    if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
                        textlist[-1] += tmp["text"]
                        continue
                if tmp["lang"] == "en":
                    langlist.append(tmp["lang"])
                else:
                    # CJK Han characters cannot be told apart reliably, so trust the user-specified language
                    langlist.append(language)
                textlist.append(tmp["text"])
        # print(textlist)
        # print(langlist)
        phones_list = []
        bert_list = []
        norm_text_list = []
        for i in range(len(textlist)):
            lang = langlist[i]
            phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
            bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
            phones_list.append(phones)
            norm_text_list.append(norm_text)
            bert_list.append(bert)
        bert = np.concatenate(bert_list, axis=1)
        phones = sum(phones_list, [])
        norm_text = "".join(norm_text_list)

        if not final and len(phones) < 6:
            return self.get_phones_and_bert("." + text, language, version, final=True)

        return phones, bert, norm_text

    def get_bert_feature(self, text: str, word2ph: list) -> np.ndarray:
        inputs = self.tokenizer(text, return_tensors="np")
        [res] = self.bert_model.run(None, {
            "input_ids": inputs["input_ids"]
        })
        assert len(word2ph) == len(text)
        phone_level_feature = []
        for i in range(len(word2ph)):
            # Repeat each character-level feature once per phone mapped to it
            repeat_feature = np.repeat(res[i:i + 1], word2ph[i], axis=0)
            phone_level_feature.append(repeat_feature)
        phone_level_feature = np.concatenate(phone_level_feature, axis=0)
        return phone_level_feature.T

    def clean_text_inf(self, text: str, language: str, version: str = "v2"):
        language = language.replace("all_", "")
        phones, word2ph, norm_text = clean_text(text, language, version)
        phones = cleaned_text_to_sequence(phones, version)
        return phones, word2ph, norm_text

    def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str):
        language = language.replace("all_", "")
        if language == "zh":
            feature = self.get_bert_feature(norm_text, word2ph)
        else:
            feature = np.zeros(
                (1024, len(phones)),
                dtype=np.float32,
            )
        return feature

    def filter_text(self, texts):
        _text = []
        if all(text in [None, " ", "\n", ""] for text in texts):
            raise ValueError(i18n("请输入有效文本"))
        for text in texts:
            if text in [None, " ", ""]:
                pass
            else:
                _text.append(text)
        return _text

    def replace_consecutive_punctuation(self, text):
        punctuations = "".join(re.escape(p) for p in punctuation)
        pattern = f"([{punctuations}])([{punctuations}])+"
        result = re.sub(pattern, r"\1", text)
        return result
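
For reference, the removed preprocessor was self-contained. A minimal usage sketch follows; the "playground/bert" package path is taken from the inference script later in this diff, while the "cut5" split-method name is an assumption based on the segmentation methods registered elsewhere in the repo:

# Hedged sketch: drive the removed TextPreprocessorOnnx end to end.
# Assumes an ONNX BERT package (chinese-roberta-wwm-ext-large.onnx + tokenizer.json)
# sits under playground/bert, as the inference script below expects.
from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx

preprocessor = TextPreprocessorOnnx("playground/bert")
fragments = preprocessor.preprocess("天上的风筝在天上飞。", "all_zh", "cut5", version="v2")
for frag in fragments:
    # phones: token id list; bert_features: (1024, n_phones); norm_text: normalized string
    print(frag["norm_text"], len(frag["phones"]), frag["bert_features"].shape)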

View File

@@ -1,178 +0,0 @@
import os
import sys
import torch
import torchaudio
import onnxruntime as ort
import numpy as np
import argparse
from transformers import HubertModel, HubertConfig


class HubertONNXExporter:
    """Export and test HuBERT model to ONNX format"""

    def __init__(self, model_path="GPT_SoVITS/pretrained_models/chinese-hubert-base", output_path="playground/hubert/chinese-hubert-base.onnx"):
        self.model_path = model_path
        self.onnx_path = output_path
        self.model = None
        self.config = None

    def setup_model(self):
        """Configure and load the HuBERT model for ONNX export"""
        # Configure for better ONNX compatibility
        self.config = HubertConfig.from_pretrained(self.model_path)
        self.config._attn_implementation = "eager"  # Use standard attention
        self.config.apply_spec_augment = False  # Disable masking for inference
        self.config.layerdrop = 0.0  # Disable layer dropout

        # Load the model
        self.model = HubertModel.from_pretrained(
            self.model_path,
            config=self.config,
            local_files_only=True
        )
        self.model.eval()

    def export_to_onnx(self, dummy_length=16000):
        """Export the model to ONNX format"""
        if self.model is None:
            raise ValueError("Model not initialized. Call setup_model() first.")

        # Create dummy input (1 second at 16kHz)
        dummy_input = torch.rand(1, dummy_length, dtype=torch.float32) - 0.5

        # Export to ONNX
        torch.onnx.export(
            self.model,
            dummy_input,
            self.onnx_path,
            export_params=True,
            opset_version=11,
            do_constant_folding=True,
            input_names=['audio16k'],
            output_names=['last_hidden_state'],
            dynamic_axes={
                'audio16k': {0: 'batch_size', 1: 'sequence_length'},
                'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'}
            }
        )
        print(f"[Success] Model exported to {self.onnx_path}")

    def test_onnx_export_exists(self):
        """Test that the ONNX model file was created"""
        if os.path.exists(self.onnx_path):
            print(f"[Success] ONNX model file exists at {self.onnx_path}")
            return True
        else:
            print(f"[Error] ONNX model not found at {self.onnx_path}")
            return False

    def _load_and_preprocess_audio(self, audio_path, max_length=160000):
        """Load and preprocess audio file"""
        waveform, sample_rate = torchaudio.load(audio_path)

        # Resample to 16kHz if needed
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)

        # Take first channel
        if waveform.shape[0] > 1:
            waveform = waveform[0:1]

        # Limit length for testing (10 seconds at 16kHz)
        if waveform.shape[1] > max_length:
            waveform = waveform[:, :max_length]

        # Make a zero tensor of length 32000 * 0.3 = 9600 samples
        zero_tensor = torch.zeros((1, 9600), dtype=torch.float32)
        print("waveform shape and zero wave shape", waveform.shape, zero_tensor.shape)
        # Concatenate zero_tensor onto the waveform as trailing padding
        waveform = torch.cat([waveform, zero_tensor], dim=1)
        return waveform

    def test_torch_vs_onnx(self, audio_path="playground/ref/audio.wav"):
        """Test that ONNX model outputs match PyTorch model outputs"""
        if not os.path.exists(audio_path):
            print(f"[Skip] Test audio file not found at {audio_path}")
            return False
        if self.model is None:
            raise ValueError("Model not initialized. Call setup_model() first.")

        # Load and preprocess audio
        waveform = self._load_and_preprocess_audio(audio_path)

        # PyTorch inference
        with torch.no_grad():
            torch_output = self.model(waveform)
            torch_hidden_states = torch_output.last_hidden_state

        # ONNX inference
        ort_session = ort.InferenceSession(self.onnx_path)
        input_values = waveform.numpy().astype(np.float32)
        ort_inputs = {ort_session.get_inputs()[0].name: input_values}
        ort_outputs = ort_session.run(None, ort_inputs)
        onnx_hidden_states = ort_outputs[0]

        # Compare outputs
        torch_numpy = torch_hidden_states.numpy()
        diff = np.abs(torch_numpy - onnx_hidden_states).mean()
        success = diff <= 1e-5
        status = "[Success]" if success else "[Fail]"
        print(f"{status} ONNX vs PyTorch comparison")
        print(f" > mean_difference={diff}")
        print(f" > torch_shape={torch_numpy.shape}")
        print(f" > onnx_shape={onnx_hidden_states.shape}")
        return success

    def run_full_export_and_test(self):
        """Run the complete export and testing pipeline"""
        print("Starting HuBERT ONNX export and testing...")

        # Create output directory if it doesn't exist
        os.makedirs(os.path.dirname(self.onnx_path), exist_ok=True)

        self.setup_model()
        self.export_to_onnx()
        self.test_onnx_export_exists()
        self.test_torch_vs_onnx()
        print("Export and testing complete!")


def main():
    """Main execution function"""
    parser = argparse.ArgumentParser(description="Export HuBERT model to ONNX format")
    parser.add_argument(
        "--model_path",
        type=str,
        default="GPT_SoVITS/pretrained_models/chinese-hubert-base",
        help="Path to the HuBERT model directory (default: GPT_SoVITS/pretrained_models/chinese-hubert-base)"
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="playground/hubert/chinese-hubert-base.onnx",
        help="Output path for the ONNX model (default: playground/hubert/chinese-hubert-base.onnx)"
    )
    args = parser.parse_args()

    exporter = HubertONNXExporter(model_path=args.model_path, output_path=args.output_path)
    exporter.run_full_export_and_test()


if __name__ == "__main__":
    main()
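
The exporter above could also be driven programmatically, equivalent to running the script with its argparse defaults. A minimal sketch (the module name for the import is not shown in this diff, so it is left as a comment):

# Hedged sketch: invoke the removed exporter directly instead of via the CLI.
# The source filename is not shown in this diff; import HubertONNXExporter
# from wherever the deleted script lived in the tree.
exporter = HubertONNXExporter(
    model_path="GPT_SoVITS/pretrained_models/chinese-hubert-base",
    output_path="playground/hubert/chinese-hubert-base.onnx",
)
exporter.run_full_export_and_test()  # export, existence check, PyTorch-vs-ONNX diff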

View File

@@ -1,146 +0,0 @@
import onnxruntime as ort
import numpy as np
import onnx
from tqdm import tqdm
import torchaudio
import torch
from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx

MODEL_PATH = "onnx/v1_export/v1"


def audio_postprocess(
    audios,
    fragment_interval: float = 0.3,
):
    zero_wav = np.zeros((int(32000 * fragment_interval),)).astype(np.float32)
    for i, audio in enumerate(audios):
        max_audio = np.abs(audio).max()  # simple guard against 16-bit clipping
        if max_audio > 1:
            audio /= max_audio
        audio = np.concatenate([audio, zero_wav], axis=0)
        audios[i] = audio
    audio = np.concatenate(audios, axis=0)
    # audio = (audio * 32768).astype(np.int16)
    audio_tensor = torch.from_numpy(audio).unsqueeze(0)
    torchaudio.save('playground/output.wav', audio_tensor, 32000)
    return audio


def load_audio(audio_path):
    """Load and preprocess audio file to 32k"""
    waveform, sample_rate = torchaudio.load(audio_path)

    # Resample to 32kHz if needed
    if sample_rate != 32000:
        resampler = torchaudio.transforms.Resample(sample_rate, 32000)
        waveform = resampler(waveform)

    # Take first channel
    if waveform.shape[0] > 1:
        waveform = waveform[0:1]
    return waveform


def audio_preprocess(audio_path):
    """Get HuBERT features for the audio file"""
    waveform = load_audio(audio_path)
    ort_session = ort.InferenceSession(MODEL_PATH + "_export_audio_preprocess.onnx")
    ort_inputs = {ort_session.get_inputs()[0].name: waveform.numpy()}
    [hubert_feature, spectrum, sv_emb] = ort_session.run(None, ort_inputs)
    return hubert_feature, spectrum, sv_emb


def preprocess_text(text: str):
    preprocessor = TextPreprocessorOnnx("playground/bert")
    [phones, bert_features, norm_text] = preprocessor.segment_and_extract_feature_for_text(text, 'all_zh', 'v2')
    phones = np.expand_dims(np.array(phones, dtype=np.int64), axis=0)
    return phones, bert_features.T.astype(np.float32)


# input_phones_saved = np.load("playground/ref/input_phones.npy")
# input_bert_saved = np.load("playground/ref/input_bert.npy").T.astype(np.float32)
[input_phones, input_bert] = preprocess_text("天上的风筝在天上飞,地上的人儿在地上追。")

# ref_phones = np.load("playground/ref/ref_phones.npy")
# ref_bert = np.load("playground/ref/ref_bert.npy").T.astype(np.float32)
[ref_phones, ref_bert] = preprocess_text("近日江苏苏州荷花市集开张热闹与浪漫交织")

[audio_prompt_hubert, spectrum, sv_emb] = audio_preprocess("playground/ref/audio.wav")
# audio_prompt_hubert_saved = np.load("playground/ref/audio_prompt_hubert.npy").astype(np.float32)

top_k = np.array([15], dtype=np.int64)
top_p = np.array([1.0], dtype=np.float32)
repetition_penalty = np.array([1.0], dtype=np.float32)
temperature = np.array([1.0], dtype=np.float32)

t2s_init_stage = ort.InferenceSession(MODEL_PATH + "_export_t2s_init_stage.onnx")
# t2s_init_step = ort.InferenceSession(MODEL_PATH + "_export_t2s_init_step.onnx")
[x, prompts, init_k, init_v, x_seq_len, y_seq_len] = t2s_init_stage.run(None, {
    "input_text_phones": input_phones,
    "input_text_bert": input_bert,
    "ref_text_phones": ref_phones,
    "ref_text_bert": ref_bert,
    "hubert_ssl_content": audio_prompt_hubert,
})

empty_tensor = np.empty((1, 0, 512)).astype(np.float32)
t2s_stage_decoder = ort.InferenceSession(MODEL_PATH + "_export_t2s_stage_decoder.onnx")
y, k, v, y_emb, logits, samples = t2s_stage_decoder.run(None, {
    "ix": x,
    "iy": prompts,
    "ik": init_k,
    "iv": init_v,
    "iy_emb": empty_tensor,
    "top_k": top_k,
    "top_p": top_p,
    "repetition_penalty": repetition_penalty,
    "temperature": temperature,
    "if_init_step": np.array([1]).astype(np.int64),
    "x_seq_len": np.array([x_seq_len]).astype(np.int64),
    "y_seq_len": np.array([y_seq_len]).astype(np.int64)
})

for idx in tqdm(range(1, 1500)):
    # Grow the KV cache by one position per decode step
    k = np.pad(k, ((0, 0), (0, 1), (0, 0), (0, 0)))
    v = np.pad(v, ((0, 0), (0, 1), (0, 0), (0, 0)))
    y_seq_len = np.array([y.shape[1]]).astype(np.int64)
    # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
    [y, k, v, y_emb, logits, samples] = t2s_stage_decoder.run(None, {
        "ix": empty_tensor,
        "iy": y,
        "ik": k,
        "iv": v,
        "iy_emb": y_emb,
        "top_k": top_k,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "temperature": temperature,
        "if_init_step": np.array([0]).astype(np.int64),
        "x_seq_len": np.array([x_seq_len]).astype(np.int64),
        "y_seq_len": y_seq_len
    })
    if np.argmax(logits, axis=-1)[0] == 1024 or samples[0, 0] == 1024:  # 1024 is the EOS token
        break

y = y[:, :-1]  # drop the EOS token
pred_semantic = np.expand_dims(y[:, -idx:], axis=0)

vits = ort.InferenceSession(MODEL_PATH + "_export_vits.onnx")
[audio] = vits.run(None, {
    "input_text_phones": input_phones,
    "pred_semantic": pred_semantic,
    "spectrum": spectrum.astype(np.float32),
    # "sv_emb": sv_emb.astype(np.float32)
})

audio_postprocess([audio])
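
The script writes its result to playground/output.wav at 32 kHz. A quick sanity check of that output, as a hedged sketch (path and sample rate come from audio_postprocess above):

import torchaudio

wav, sr = torchaudio.load("playground/output.wav")
assert sr == 32000  # the VITS decoder in this pipeline emits 32 kHz audio
print(f"{wav.shape[1] / sr:.2f} s synthesized")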

Binary file not shown.