mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-29 08:49:59 +08:00
feat: voice and text preprocess system verified; todo: disassemble onnx export of gsv
This commit is contained in:
parent
dd156f15aa
commit
5c08328cf3
233
GPT_SoVITS/TTS_infer_pack/TextPreprocessor_onnx.py
Normal file
233
GPT_SoVITS/TTS_infer_pack/TextPreprocessor_onnx.py
Normal file
@ -0,0 +1,233 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
now_dir = os.getcwd()
|
||||||
|
sys.path.append(now_dir)
|
||||||
|
|
||||||
|
import re
|
||||||
|
from text.LangSegmenter import LangSegmenter
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
from text.cleaner import clean_text
|
||||||
|
from text import cleaned_text_to_sequence
|
||||||
|
from transformers import PreTrainedTokenizerFast
|
||||||
|
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
|
||||||
|
import onnxruntime as ort
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||||
|
|
||||||
|
# UI language for i18n messages: read from the "language" environment variable,
# but allow a trailing CLI argument to override it when it names a supported locale.
language = os.environ.get("language", "Auto")
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language)
# Punctuation characters whose consecutive runs get collapsed before splitting.
punctuation = set(["!", "?", "…", ",", ".", "-"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_first(text: str) -> str:
    """Return the stripped text preceding the first separator in ``splits``."""
    separator_class = "[%s]" % "".join(re.escape(sep) for sep in splits)
    head, *_ = re.split(separator_class, text)
    return head.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def merge_short_text_in_array(texts: List[str], threshold: int) -> List[str]:
    """Merge consecutive segments so each emitted segment has at least
    ``threshold`` characters.

    Segments are concatenated in order; whenever the running buffer reaches
    ``threshold`` characters it is flushed. A trailing leftover shorter than
    the threshold is appended to the last emitted segment, or kept alone when
    nothing was emitted at all.

    Note: the original annotated ``texts`` as ``str``; it is in fact a list of
    strings (callers pass ``text.split("\n")`` results).
    """
    if len(texts) < 2:
        # Nothing to merge with.
        return texts
    result: List[str] = []
    buffer = ""
    for segment in texts:
        buffer += segment
        if len(buffer) >= threshold:
            result.append(buffer)
            buffer = ""
    if buffer:
        if not result:
            result.append(buffer)
        else:
            result[-1] += buffer
    return result
|
||||||
|
|
||||||
|
|
||||||
|
class TextPreprocessorOnnx:
    """Text front-end for ONNX-exported GPT-SoVITS inference.

    Splits raw input text into sentences, normalizes them, converts them to
    phoneme-ID sequences and extracts BERT features, running the Chinese
    RoBERTa model through ONNX Runtime instead of a torch model.
    """

    def __init__(self, onnx_package: str):
        # onnx_package: directory containing the exported BERT model
        # ("chinese-roberta-wwm-ext-large.onnx") and its "tokenizer.json".
        self.bert_model = ort.InferenceSession(os.path.join(onnx_package, "chinese-roberta-wwm-ext-large.onnx"))
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=os.path.join(onnx_package, "tokenizer.json"))

    def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
        """Split *text* into sentences and extract features for each.

        Returns a list of dicts with keys "phones" (phoneme-ID list),
        "bert_features" (ndarray) and "norm_text" (normalized sentence).
        Sentences that produce no phones or an empty normalized text are
        dropped.
        """
        print(f"############ {i18n('切分文本')} ############")  # "splitting text"
        text = self.replace_consecutive_punctuation(text)
        texts = self.pre_seg_text(text, lang, text_split_method)
        result = []
        print(f"############ {i18n('提取文本Bert特征')} ############")  # "extracting BERT features"
        for text in tqdm(texts):
            phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
            if phones is None or norm_text == "":
                continue
            res = {
                "phones": phones,
                "bert_features": bert_features,
                "norm_text": norm_text,
            }
            result.append(res)
        return result

    def pre_seg_text(self, text: str, lang: str, text_split_method: str):
        """Cut *text* into a list of sentence strings.

        Applies the configured split method, drops blank / punctuation-only
        fragments, merges very short fragments, and splits sentences longer
        than 510 characters so BERT's input limit is not exceeded.
        """
        text = text.strip("\n")
        if len(text) == 0:
            return []
        # Prepend a period when the text starts mid-sentence and the first
        # fragment before any separator is very short.
        if text[0] not in splits and len(get_first(text)) < 4:
            text = "。" + text if lang != "en" else "." + text
        print(i18n("实际输入的目标文本:"))  # "actual input target text:"
        print(text)

        seg_method = get_seg_method(text_split_method)
        text = seg_method(text)

        # Collapse repeated blank lines.
        while "\n\n" in text:
            text = text.replace("\n\n", "\n")

        _texts = text.split("\n")
        _texts = self.filter_text(_texts)
        _texts = merge_short_text_in_array(_texts, 5)
        texts = []

        for text in _texts:
            # Skip blank lines in the target text (they used to cause errors).
            if len(text.strip()) == 0:
                continue
            if not re.sub("\W+", "", text):
                # Skip segments that consist purely of symbols/punctuation.
                continue
            # Ensure every sentence ends with a terminal punctuation mark.
            if text[-1] not in splits:
                text += "。" if lang != "en" else "."

            # Split overly long sentences so BERT does not fail on them.
            if len(text) > 510:
                texts.extend(split_big_text(text))
            else:
                texts.append(text)

        print(i18n("实际输入的目标文本(切句后):"))  # "target text after sentence split:"
        print(texts)
        return texts

    def segment_and_extract_feature_for_text(
        self, text: str, language: str, version: str = "v1"
    ) -> Tuple[list, np.ndarray, str]:
        """Thin wrapper over :meth:`get_phones_and_bert` for one sentence."""
        return self.get_phones_and_bert(text, language, version)

    def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False):
        """Convert *text* to (phoneme IDs, BERT feature matrix, normalized text).

        *language* selects how the text is split into per-language runs:
        "all_*" forces one language family, "auto"/"auto_yue" auto-detect, and
        anything else merges English runs with the user-specified language.
        *final* guards the single retry performed for very short sentences.
        """
        # Collapse runs of multiple spaces.
        text = re.sub(r' {2,}', ' ', text)
        textlist = []
        langlist = []
        if language == "all_zh":
            for tmp in LangSegmenter.getTexts(text,"zh"):
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        elif language == "all_yue":
            for tmp in LangSegmenter.getTexts(text,"zh"):
                # Cantonese shares written form with Mandarin; relabel zh runs.
                if tmp["lang"] == "zh":
                    tmp["lang"] = "yue"
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        elif language == "all_ja":
            for tmp in LangSegmenter.getTexts(text,"ja"):
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        elif language == "all_ko":
            for tmp in LangSegmenter.getTexts(text,"ko"):
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        elif language == "en":
            langlist.append("en")
            textlist.append(text)
        elif language == "auto":
            for tmp in LangSegmenter.getTexts(text):
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        elif language == "auto_yue":
            for tmp in LangSegmenter.getTexts(text):
                if tmp["lang"] == "zh":
                    tmp["lang"] = "yue"
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        else:
            for tmp in LangSegmenter.getTexts(text):
                if langlist:
                    # Merge consecutive runs of the same kind (en vs non-en).
                    if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
                        textlist[-1] += tmp["text"]
                        continue
                if tmp["lang"] == "en":
                    langlist.append(tmp["lang"])
                else:
                    # CJK Han characters cannot be reliably told apart, so
                    # trust the language the user specified.
                    langlist.append(language)
                textlist.append(tmp["text"])
        # print(textlist)
        # print(langlist)
        phones_list = []
        bert_list = []
        norm_text_list = []
        for i in range(len(textlist)):
            lang = langlist[i]
            phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
            bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
            phones_list.append(phones)
            norm_text_list.append(norm_text)
            bert_list.append(bert)
        bert = np.concatenate(bert_list, axis=1)
        phones = sum(phones_list, [])
        norm_text = "".join(norm_text_list)

        # Very short phone sequences synthesize poorly; retry once with a
        # leading period to pad the sentence.
        if not final and len(phones) < 6:
            return self.get_phones_and_bert("." + text, language, version, final=True)

        return phones, bert, norm_text

    def get_bert_feature(self, text: str, word2ph: list) -> np.ndarray:
        """Run the ONNX BERT model on *text* and expand the per-character
        hidden states to phone level using the *word2ph* repeat counts.

        Returns an array of shape (hidden_dim, total_phones) — the transpose
        of the concatenated per-phone features.
        """
        inputs = self.tokenizer(text, return_tensors="np")
        [res] = self.bert_model.run(None, {
            "input_ids": inputs["input_ids"]
        })
        # One repeat count per character of the normalized text.
        assert len(word2ph) == len(text)
        phone_level_feature = []
        for i in range(len(word2ph)):
            repeat_feature = np.repeat(res[i:i+1], word2ph[i], axis=0)
            phone_level_feature.append(repeat_feature)
        phone_level_feature = np.concatenate(phone_level_feature, axis=0)
        return phone_level_feature.T

    def clean_text_inf(self, text: str, language: str, version: str = "v2"):
        """Normalize *text* and convert it to phoneme IDs.

        Returns (phones, word2ph, norm_text); "all_*" language tags are
        reduced to their base language before cleaning.
        """
        language = language.replace("all_", "")
        phones, word2ph, norm_text = clean_text(text, language, version)
        phones = cleaned_text_to_sequence(phones, version)
        return phones, word2ph, norm_text

    def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str):
        """Return BERT features for one language run.

        Real features are computed only for Chinese ("zh"); other languages
        get a zero matrix of shape (1024, len(phones)) as a placeholder.
        """
        language = language.replace("all_", "")
        if language == "zh":
            feature = self.get_bert_feature(norm_text, word2ph)
        else:
            feature = np.zeros(
                (1024, len(phones)),
                dtype=np.float32,
            )

        return feature

    def filter_text(self, texts):
        """Drop None/blank entries; raise ValueError when *texts* contains no
        usable content at all."""
        _text = []
        if all(text in [None, " ", "\n", ""] for text in texts):
            raise ValueError(i18n("请输入有效文本"))  # "please enter valid text"
        for text in texts:
            if text in [None, " ", ""]:
                pass
            else:
                _text.append(text)
        return _text

    def replace_consecutive_punctuation(self, text):
        """Collapse each run of consecutive punctuation marks down to its
        first character (e.g. "!!!" -> "!")."""
        punctuations = "".join(re.escape(p) for p in punctuation)
        pattern = f"([{punctuations}])([{punctuations}])+"
        result = re.sub(pattern, r"\1", text)
        return result
|
@ -402,15 +402,16 @@ if __name__ == "__main__":
|
|||||||
# version = "v2"
|
# version = "v2"
|
||||||
# export(vits_path, gpt_path, exp_path, version)
|
# export(vits_path, gpt_path, exp_path, version)
|
||||||
|
|
||||||
# gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
|
|
||||||
# vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth"
|
|
||||||
# exp_path = "v2pro_export"
|
|
||||||
# version = "v2Pro"
|
|
||||||
|
|
||||||
gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
|
gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
|
||||||
vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth"
|
vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth"
|
||||||
exp_path = "v2proplus_export"
|
exp_path = "v2pro_export"
|
||||||
version = "v2ProPlus"
|
version = "v2Pro"
|
||||||
export(vits_path, gpt_path, exp_path, version)
|
export(vits_path, gpt_path, exp_path, version)
|
||||||
|
|
||||||
|
# gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
|
||||||
|
# vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth"
|
||||||
|
# exp_path = "v2proplus_export"
|
||||||
|
# version = "v2ProPlus"
|
||||||
|
# export(vits_path, gpt_path, exp_path, version)
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,6 +4,8 @@ import onnx
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torchaudio
|
import torchaudio
|
||||||
import torch
|
import torch
|
||||||
|
from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx
|
||||||
|
|
||||||
|
|
||||||
MODEL_PATH = "playground/v2proplus_export/v2proplus"
|
MODEL_PATH = "playground/v2proplus_export/v2proplus"
|
||||||
|
|
||||||
@ -30,7 +32,7 @@ def audio_postprocess(
|
|||||||
return audio
|
return audio
|
||||||
|
|
||||||
def load_and_preprocess_audio(audio_path, max_length=160000):
|
def load_and_preprocess_audio(audio_path, max_length=160000):
|
||||||
"""Load and preprocess audio file"""
|
"""Load and preprocess audio file to 16k"""
|
||||||
waveform, sample_rate = torchaudio.load(audio_path)
|
waveform, sample_rate = torchaudio.load(audio_path)
|
||||||
|
|
||||||
# Resample to 16kHz if needed
|
# Resample to 16kHz if needed
|
||||||
@ -49,8 +51,6 @@ def load_and_preprocess_audio(audio_path, max_length=160000):
|
|||||||
# make a zero tensor that has length 3200*0.3
|
# make a zero tensor that has length 3200*0.3
|
||||||
zero_tensor = torch.zeros((1, 9600), dtype=torch.float32)
|
zero_tensor = torch.zeros((1, 9600), dtype=torch.float32)
|
||||||
|
|
||||||
print("waveform shape and zero wave shape", waveform.shape, zero_tensor.shape)
|
|
||||||
|
|
||||||
# concate zero_tensor with waveform
|
# concate zero_tensor with waveform
|
||||||
waveform = torch.cat([waveform, zero_tensor], dim=1)
|
waveform = torch.cat([waveform, zero_tensor], dim=1)
|
||||||
|
|
||||||
@ -64,13 +64,25 @@ def get_audio_hubert(audio_path):
|
|||||||
hubert_feature = ort_session.run(None, ort_inputs)[0].astype(np.float32)
|
hubert_feature = ort_session.run(None, ort_inputs)[0].astype(np.float32)
|
||||||
# transpose axis 1 and 2 with numpy
|
# transpose axis 1 and 2 with numpy
|
||||||
hubert_feature = hubert_feature.transpose(0, 2, 1)
|
hubert_feature = hubert_feature.transpose(0, 2, 1)
|
||||||
print("Hubert feature shape:", hubert_feature.shape)
|
|
||||||
return hubert_feature
|
return hubert_feature
|
||||||
|
|
||||||
input_phones = np.load("playground/ref/input_phones.npy")
|
def preprocess_text(text:str):
|
||||||
input_bert = np.load("playground/ref/input_bert.npy").T.astype(np.float32)
|
preprocessor = TextPreprocessorOnnx("playground/bert")
|
||||||
ref_phones = np.load("playground/ref/ref_phones.npy")
|
[phones, bert_features, norm_text] = preprocessor.segment_and_extract_feature_for_text(text, 'all_zh', 'v2')
|
||||||
ref_bert = np.load("playground/ref/ref_bert.npy").T.astype(np.float32)
|
phones = np.expand_dims(np.array(phones, dtype=np.int64), axis=0)
|
||||||
|
return phones, bert_features.T.astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
# input_phones_saved = np.load("playground/ref/input_phones.npy")
|
||||||
|
# input_bert_saved = np.load("playground/ref/input_bert.npy").T.astype(np.float32)
|
||||||
|
[input_phones, input_bert] = preprocess_text("震撼视角,感受成都世运会,闭幕式烟花")
|
||||||
|
|
||||||
|
|
||||||
|
# ref_phones = np.load("playground/ref/ref_phones.npy")
|
||||||
|
# ref_bert = np.load("playground/ref/ref_bert.npy").T.astype(np.float32)
|
||||||
|
[ref_phones, ref_bert] = preprocess_text("今日江苏苏州荷花市集开张热闹与浪漫交织")
|
||||||
|
|
||||||
|
|
||||||
audio_prompt_hubert = get_audio_hubert("playground/ref/audio.wav")
|
audio_prompt_hubert = get_audio_hubert("playground/ref/audio.wav")
|
||||||
|
|
||||||
|
|
||||||
@ -84,8 +96,6 @@ encoder = ort.InferenceSession(MODEL_PATH+"_export_t2s_encoder.onnx")
|
|||||||
"ssl_content": audio_prompt_hubert
|
"ssl_content": audio_prompt_hubert
|
||||||
})
|
})
|
||||||
|
|
||||||
print(x.shape, prompts.shape)
|
|
||||||
|
|
||||||
fsdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_fsdec.onnx")
|
fsdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_fsdec.onnx")
|
||||||
sdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx")
|
sdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx")
|
||||||
|
|
||||||
@ -124,6 +134,7 @@ if sample_rate != 32000:
|
|||||||
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=32000)
|
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=32000)
|
||||||
waveform = resampler(waveform)
|
waveform = resampler(waveform)
|
||||||
ref_audio = waveform.numpy().astype(np.float32)
|
ref_audio = waveform.numpy().astype(np.float32)
|
||||||
|
|
||||||
vtis = ort.InferenceSession(MODEL_PATH+"_export_vits.onnx")
|
vtis = ort.InferenceSession(MODEL_PATH+"_export_vits.onnx")
|
||||||
|
|
||||||
[audio] = vtis.run(None, {
|
[audio] = vtis.run(None, {
|
||||||
@ -131,6 +142,5 @@ vtis = ort.InferenceSession(MODEL_PATH+"_export_vits.onnx")
|
|||||||
"pred_semantic": pred_semantic,
|
"pred_semantic": pred_semantic,
|
||||||
"ref_audio": ref_audio
|
"ref_audio": ref_audio
|
||||||
})
|
})
|
||||||
print(audio.shape, audio.dtype, audio.min(), audio.max())
|
|
||||||
|
|
||||||
audio_postprocess([audio])
|
audio_postprocess([audio])
|
||||||
|
Binary file not shown.
@ -33,7 +33,7 @@ rotary_embedding_torch
|
|||||||
ToJyutping
|
ToJyutping
|
||||||
g2pk2
|
g2pk2
|
||||||
ko_pron
|
ko_pron
|
||||||
opencc
|
opencc==1.1.6
|
||||||
python_mecab_ko; sys_platform != 'win32'
|
python_mecab_ko; sys_platform != 'win32'
|
||||||
fastapi[standard]>=0.115.2
|
fastapi[standard]>=0.115.2
|
||||||
x_transformers
|
x_transformers
|
||||||
|
Loading…
x
Reference in New Issue
Block a user