Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git, synced 2025-09-29 00:30:15 +08:00
feat: remove unneeded files for main
This commit is contained in:
parent 968ac4c264
commit fa84e262ae
@@ -1,233 +0,0 @@
import os
import sys

from tqdm import tqdm

now_dir = os.getcwd()
sys.path.append(now_dir)

import re
from text.LangSegmenter import LangSegmenter
from typing import Dict, List, Tuple
from text.cleaner import clean_text
from text import cleaned_text_to_sequence
from transformers import PreTrainedTokenizerFast
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
import onnxruntime as ort
import numpy as np

from tools.i18n.i18n import I18nAuto, scan_language_list

language = os.environ.get("language", "Auto")
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language)
punctuation = set(["!", "?", "…", ",", ".", "-"])


def get_first(text: str) -> str:
    pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
    text = re.split(pattern, text)[0].strip()
    return text


def merge_short_text_in_array(texts: list, threshold: int) -> list:
    if len(texts) < 2:
        return texts
    result = []
    text = ""
    for ele in texts:
        text += ele
        if len(text) >= threshold:
            result.append(text)
            text = ""
    if len(text) > 0:
        if len(result) == 0:
            result.append(text)
        else:
            result[len(result) - 1] += text
    return result


class TextPreprocessorOnnx:
    def __init__(self, onnx_package: str):
        self.bert_model = ort.InferenceSession(os.path.join(onnx_package, "chinese-roberta-wwm-ext-large.onnx"))
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=os.path.join(onnx_package, "tokenizer.json"))

    def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
        print(f"############ {i18n('切分文本')} ############")  # i18n key: "Split the text"
        text = self.replace_consecutive_punctuation(text)
        texts = self.pre_seg_text(text, lang, text_split_method)
        result = []
        print(f"############ {i18n('提取文本Bert特征')} ############")  # i18n key: "Extract BERT features of the text"
        for text in tqdm(texts):
            phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
            if phones is None or norm_text == "":
                continue
            res = {
                "phones": phones,
                "bert_features": bert_features,
                "norm_text": norm_text,
            }
            result.append(res)
        return result

    def pre_seg_text(self, text: str, lang: str, text_split_method: str):
        text = text.strip("\n")
        if len(text) == 0:
            return []
        if text[0] not in splits and len(get_first(text)) < 4:
            text = "。" + text if lang != "en" else "." + text
        print(i18n("实际输入的目标文本:"))  # i18n key: "Actual input target text:"
        print(text)

        seg_method = get_seg_method(text_split_method)
        text = seg_method(text)

        while "\n\n" in text:
            text = text.replace("\n\n", "\n")

        _texts = text.split("\n")
        _texts = self.filter_text(_texts)
        _texts = merge_short_text_in_array(_texts, 5)
        texts = []

        for text in _texts:
            # Skip blank lines in the input text, which would otherwise cause errors.
            if len(text.strip()) == 0:
                continue
            if not re.sub(r"\W+", "", text):
                # Skip fragments that are pure punctuation/symbols.
                continue
            if text[-1] not in splits:
                text += "。" if lang != "en" else "."

            # Split overly long sentences so BERT does not error out.
            if len(text) > 510:
                texts.extend(split_big_text(text))
            else:
                texts.append(text)

        print(i18n("实际输入的目标文本(切句后):"))  # i18n key: "Actual input target text (after sentence splitting):"
        print(texts)
        return texts

    def segment_and_extract_feature_for_text(
        self, text: str, language: str, version: str = "v1"
    ) -> Tuple[list, np.ndarray, str]:
        return self.get_phones_and_bert(text, language, version)

    def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False):
        text = re.sub(r" {2,}", " ", text)
        textlist = []
        langlist = []
        if language == "all_zh":
            for tmp in LangSegmenter.getTexts(text, "zh"):
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        elif language == "all_yue":
            for tmp in LangSegmenter.getTexts(text, "zh"):
                if tmp["lang"] == "zh":
                    tmp["lang"] = "yue"
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        elif language == "all_ja":
            for tmp in LangSegmenter.getTexts(text, "ja"):
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        elif language == "all_ko":
            for tmp in LangSegmenter.getTexts(text, "ko"):
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        elif language == "en":
            langlist.append("en")
            textlist.append(text)
        elif language == "auto":
            for tmp in LangSegmenter.getTexts(text):
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        elif language == "auto_yue":
            for tmp in LangSegmenter.getTexts(text):
                if tmp["lang"] == "zh":
                    tmp["lang"] = "yue"
                langlist.append(tmp["lang"])
                textlist.append(tmp["text"])
        else:
            for tmp in LangSegmenter.getTexts(text):
                if langlist:
                    if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
                        textlist[-1] += tmp["text"]
                        continue
                if tmp["lang"] == "en":
                    langlist.append(tmp["lang"])
                else:
                    # CJK hanzi/kanji cannot be told apart reliably, so trust the user-specified language.
                    langlist.append(language)
                textlist.append(tmp["text"])
        # print(textlist)
        # print(langlist)
        phones_list = []
        bert_list = []
        norm_text_list = []
        for i in range(len(textlist)):
            lang = langlist[i]
            phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
            bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
            phones_list.append(phones)
            norm_text_list.append(norm_text)
            bert_list.append(bert)
        bert = np.concatenate(bert_list, axis=1)
        phones = sum(phones_list, [])
        norm_text = "".join(norm_text_list)

        if not final and len(phones) < 6:
            return self.get_phones_and_bert("." + text, language, version, final=True)

        return phones, bert, norm_text

    def get_bert_feature(self, text: str, word2ph: list) -> np.ndarray:
        inputs = self.tokenizer(text, return_tensors="np")
        [res] = self.bert_model.run(None, {
            "input_ids": inputs["input_ids"]
        })
        assert len(word2ph) == len(text)
        phone_level_feature = []
        for i in range(len(word2ph)):
            # Repeat each character-level feature once per phone mapped to that character.
            repeat_feature = np.repeat(res[i : i + 1], word2ph[i], axis=0)
            phone_level_feature.append(repeat_feature)
        phone_level_feature = np.concatenate(phone_level_feature, axis=0)
        return phone_level_feature.T

    def clean_text_inf(self, text: str, language: str, version: str = "v2"):
        language = language.replace("all_", "")
        phones, word2ph, norm_text = clean_text(text, language, version)
        phones = cleaned_text_to_sequence(phones, version)
        return phones, word2ph, norm_text

    def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str):
        language = language.replace("all_", "")
        if language == "zh":
            feature = self.get_bert_feature(norm_text, word2ph)
        else:
            # Only Chinese text gets real BERT features; other languages get zeros.
            feature = np.zeros(
                (1024, len(phones)),
                dtype=np.float32,
            )

        return feature

    def filter_text(self, texts):
        _text = []
        if all(text in [None, " ", "\n", ""] for text in texts):
            raise ValueError(i18n("请输入有效文本"))  # i18n key: "Please enter valid text"
        for text in texts:
            if text in [None, " ", ""]:
                pass
            else:
                _text.append(text)
        return _text

    def replace_consecutive_punctuation(self, text):
        punctuations = "".join(re.escape(p) for p in punctuation)
        pattern = f"([{punctuations}])([{punctuations}])+"
        result = re.sub(pattern, r"\1", text)
        return result
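For orientation, a minimal usage sketch of this preprocessor. The "playground/bert" package path and the "all_zh" language code are taken from the playground script further down; the "cut2" split method name is an assumption (one of the segmentation methods registered in TTS_infer_pack.text_segmentation_method):

# Sketch only: assumes an ONNX package directory containing
# chinese-roberta-wwm-ext-large.onnx and tokenizer.json, as above.
from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx

preprocessor = TextPreprocessorOnnx("playground/bert")
fragments = preprocessor.preprocess("天上的风筝在天上飞,地上的人儿在地上追。", "all_zh", "cut2")
for frag in fragments:
    # Each fragment carries phone ids, a (1024, n_phones) BERT feature matrix,
    # and the normalized text.
    print(frag["norm_text"], len(frag["phones"]), frag["bert_features"].shape)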
@@ -1,178 +0,0 @@
import os
import argparse
import torch
import torchaudio
import onnxruntime as ort
import numpy as np
from transformers import HubertModel, HubertConfig


class HubertONNXExporter:
    """Export a HuBERT model to ONNX format and test the result."""

    def __init__(
        self,
        model_path="GPT_SoVITS/pretrained_models/chinese-hubert-base",
        output_path="playground/hubert/chinese-hubert-base.onnx",
    ):
        self.model_path = model_path
        self.onnx_path = output_path
        self.model = None
        self.config = None

    def setup_model(self):
        """Configure and load the HuBERT model for ONNX export."""
        # Configure for better ONNX compatibility
        self.config = HubertConfig.from_pretrained(self.model_path)
        self.config._attn_implementation = "eager"  # Use standard attention
        self.config.apply_spec_augment = False  # Disable masking for inference
        self.config.layerdrop = 0.0  # Disable layer dropout

        # Load the model
        self.model = HubertModel.from_pretrained(
            self.model_path,
            config=self.config,
            local_files_only=True,
        )
        self.model.eval()

    def export_to_onnx(self, dummy_length=16000):
        """Export the model to ONNX format."""
        if self.model is None:
            raise ValueError("Model not initialized. Call setup_model() first.")

        # Create dummy input (1 second at 16 kHz)
        dummy_input = torch.rand(1, dummy_length, dtype=torch.float32) - 0.5

        # Export to ONNX
        torch.onnx.export(
            self.model,
            dummy_input,
            self.onnx_path,
            export_params=True,
            opset_version=11,
            do_constant_folding=True,
            input_names=["audio16k"],
            output_names=["last_hidden_state"],
            dynamic_axes={
                "audio16k": {0: "batch_size", 1: "sequence_length"},
                "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
            },
        )
        print(f"[Success] Model exported to {self.onnx_path}")

    def test_onnx_export_exists(self):
        """Test that the ONNX model file was created."""
        if os.path.exists(self.onnx_path):
            print(f"[Success] ONNX model file exists at {self.onnx_path}")
            return True
        else:
            print(f"[Error] ONNX model not found at {self.onnx_path}")
            return False

    def _load_and_preprocess_audio(self, audio_path, max_length=160000):
        """Load and preprocess an audio file."""
        waveform, sample_rate = torchaudio.load(audio_path)

        # Resample to 16 kHz if needed
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)

        # Take the first channel only
        if waveform.shape[0] > 1:
            waveform = waveform[0:1]

        # Limit length for testing (10 seconds at 16 kHz)
        if waveform.shape[1] > max_length:
            waveform = waveform[:, :max_length]

        # Append trailing silence: a zero tensor of 9600 samples (0.6 s at 16 kHz)
        zero_tensor = torch.zeros((1, 9600), dtype=torch.float32)

        print("waveform shape and zero wave shape", waveform.shape, zero_tensor.shape)

        # Concatenate the silence onto the waveform
        waveform = torch.cat([waveform, zero_tensor], dim=1)

        return waveform

    def test_torch_vs_onnx(self, audio_path="playground/ref/audio.wav"):
        """Test that ONNX model outputs match PyTorch model outputs."""
        if not os.path.exists(audio_path):
            print(f"[Skip] Test audio file not found at {audio_path}")
            return False

        if self.model is None:
            raise ValueError("Model not initialized. Call setup_model() first.")

        # Load and preprocess audio
        waveform = self._load_and_preprocess_audio(audio_path)

        # PyTorch inference
        with torch.no_grad():
            torch_output = self.model(waveform)
            torch_hidden_states = torch_output.last_hidden_state

        # ONNX inference
        ort_session = ort.InferenceSession(self.onnx_path)
        input_values = waveform.numpy().astype(np.float32)
        ort_inputs = {ort_session.get_inputs()[0].name: input_values}
        ort_outputs = ort_session.run(None, ort_inputs)
        onnx_hidden_states = ort_outputs[0]

        # Compare outputs
        torch_numpy = torch_hidden_states.numpy()
        diff = np.abs(torch_numpy - onnx_hidden_states).mean()

        success = diff <= 1e-5
        status = "[Success]" if success else "[Fail]"

        print(f"{status} ONNX vs PyTorch comparison")
        print(f" > mean_difference={diff}")
        print(f" > torch_shape={torch_numpy.shape}")
        print(f" > onnx_shape={onnx_hidden_states.shape}")

        return success

    def run_full_export_and_test(self):
        """Run the complete export and testing pipeline."""
        print("Starting HuBERT ONNX export and testing...")

        # Create the output directory if it doesn't exist
        os.makedirs(os.path.dirname(self.onnx_path), exist_ok=True)

        # Set up the model
        self.setup_model()

        # Export to ONNX
        self.export_to_onnx()

        # Test the export
        self.test_onnx_export_exists()
        self.test_torch_vs_onnx()

        print("Export and testing complete!")


def main():
    """Main execution function."""
    parser = argparse.ArgumentParser(description="Export a HuBERT model to ONNX format")
    parser.add_argument(
        "--model_path",
        type=str,
        default="GPT_SoVITS/pretrained_models/chinese-hubert-base",
        help="Path to the HuBERT model directory (default: GPT_SoVITS/pretrained_models/chinese-hubert-base)",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="playground/hubert/chinese-hubert-base.onnx",
        help="Output path for the ONNX model (default: playground/hubert/chinese-hubert-base.onnx)",
    )

    args = parser.parse_args()

    exporter = HubertONNXExporter(model_path=args.model_path, output_path=args.output_path)
    exporter.run_full_export_and_test()


if __name__ == "__main__":
    main()
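Once exported, the model can be consumed without PyTorch. A minimal sketch follows; the "audio16k" input and "last_hidden_state" output names come from the torch.onnx.export call above, while the 768-wide hidden state is the usual hubert-base size and is stated here as an assumption:

import numpy as np
import onnxruntime as ort

# Sketch: run the exported model on one second of silence.
sess = ort.InferenceSession("playground/hubert/chinese-hubert-base.onnx")
audio16k = np.zeros((1, 16000), dtype=np.float32)  # [batch, samples] at 16 kHz
[last_hidden_state] = sess.run(["last_hidden_state"], {"audio16k": audio16k})
print(last_hidden_state.shape)  # roughly (1, ~50 frames, 768)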
@@ -1,146 +0,0 @@
import onnxruntime as ort
import numpy as np
from tqdm import tqdm
import torchaudio
import torch
from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx


MODEL_PATH = "onnx/v1_export/v1"


def audio_postprocess(
    audios,
    fragment_interval: float = 0.3,
):
    zero_wav = np.zeros((int(32000 * fragment_interval),)).astype(np.float32)
    for i, audio in enumerate(audios):
        max_audio = np.abs(audio).max()  # Simple guard against 16-bit clipping
        if max_audio > 1:
            audio /= max_audio
        audio = np.concatenate([audio, zero_wav], axis=0)
        audios[i] = audio

    audio = np.concatenate(audios, axis=0)

    # audio = (audio * 32768).astype(np.int16)

    audio_tensor = torch.from_numpy(audio).unsqueeze(0)

    torchaudio.save("playground/output.wav", audio_tensor, 32000)

    return audio


def load_audio(audio_path):
    """Load an audio file and preprocess it to 32 kHz."""
    waveform, sample_rate = torchaudio.load(audio_path)

    # Resample to 32 kHz if needed
    if sample_rate != 32000:
        resampler = torchaudio.transforms.Resample(sample_rate, 32000)
        waveform = resampler(waveform)

    # Take the first channel only
    if waveform.shape[0] > 1:
        waveform = waveform[0:1]

    return waveform


def audio_preprocess(audio_path):
    """Get HuBERT features, spectrum, and speaker embedding for the audio file."""
    waveform = load_audio(audio_path)
    ort_session = ort.InferenceSession(MODEL_PATH + "_export_audio_preprocess.onnx")
    ort_inputs = {ort_session.get_inputs()[0].name: waveform.numpy()}
    [hubert_feature, spectrum, sv_emb] = ort_session.run(None, ort_inputs)
    return hubert_feature, spectrum, sv_emb


def preprocess_text(text: str):
    preprocessor = TextPreprocessorOnnx("playground/bert")
    [phones, bert_features, norm_text] = preprocessor.segment_and_extract_feature_for_text(text, "all_zh", "v2")
    phones = np.expand_dims(np.array(phones, dtype=np.int64), axis=0)
    return phones, bert_features.T.astype(np.float32)


# input_phones_saved = np.load("playground/ref/input_phones.npy")
# input_bert_saved = np.load("playground/ref/input_bert.npy").T.astype(np.float32)
[input_phones, input_bert] = preprocess_text("天上的风筝在天上飞,地上的人儿在地上追。")

# ref_phones = np.load("playground/ref/ref_phones.npy")
# ref_bert = np.load("playground/ref/ref_bert.npy").T.astype(np.float32)
[ref_phones, ref_bert] = preprocess_text("近日江苏苏州荷花市集开张热闹与浪漫交织")

[audio_prompt_hubert, spectrum, sv_emb] = audio_preprocess("playground/ref/audio.wav")

# audio_prompt_hubert_saved = np.load("playground/ref/audio_prompt_hubert.npy").astype(np.float32)

top_k = np.array([15], dtype=np.int64)
top_p = np.array([1.0], dtype=np.float32)
repetition_penalty = np.array([1.0], dtype=np.float32)
temperature = np.array([1.0], dtype=np.float32)

t2s_init_stage = ort.InferenceSession(MODEL_PATH + "_export_t2s_init_stage.onnx")
# t2s_init_step = ort.InferenceSession(MODEL_PATH + "_export_t2s_init_step.onnx")

[x, prompts, init_k, init_v, x_seq_len, y_seq_len] = t2s_init_stage.run(None, {
    "input_text_phones": input_phones,
    "input_text_bert": input_bert,
    "ref_text_phones": ref_phones,
    "ref_text_bert": ref_bert,
    "hubert_ssl_content": audio_prompt_hubert,
})
empty_tensor = np.empty((1, 0, 512)).astype(np.float32)

t2s_stage_decoder = ort.InferenceSession(MODEL_PATH + "_export_t2s_stage_decoder.onnx")
y, k, v, y_emb, logits, samples = t2s_stage_decoder.run(None, {
    "ix": x,
    "iy": prompts,
    "ik": init_k,
    "iv": init_v,
    "iy_emb": empty_tensor,
    "top_k": top_k,
    "top_p": top_p,
    "repetition_penalty": repetition_penalty,
    "temperature": temperature,
    "if_init_step": np.array([1]).astype(np.int64),
    "x_seq_len": np.array([x_seq_len]).astype(np.int64),
    "y_seq_len": np.array([y_seq_len]).astype(np.int64),
})

for idx in tqdm(range(1, 1500)):
    # Grow the KV cache by one slot along the sequence axis for the new step
    k = np.pad(k, ((0, 0), (0, 1), (0, 0), (0, 0)))
    v = np.pad(v, ((0, 0), (0, 1), (0, 0), (0, 0)))
    y_seq_len = np.array([y.shape[1]]).astype(np.int64)
    # Shapes: [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
    [y, k, v, y_emb, logits, samples] = t2s_stage_decoder.run(None, {
        "ix": empty_tensor,
        "iy": y,
        "ik": k,
        "iv": v,
        "iy_emb": y_emb,
        "top_k": top_k,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "temperature": temperature,
        "if_init_step": np.array([0]).astype(np.int64),
        "x_seq_len": np.array([x_seq_len]).astype(np.int64),
        "y_seq_len": y_seq_len,
    })
    if np.argmax(logits, axis=-1)[0] == 1024 or samples[0, 0] == 1024:  # 1024 is the EOS token
        break
y = y[:, :-1]  # Drop the trailing EOS token

pred_semantic = np.expand_dims(y[:, -idx:], axis=0)

vits = ort.InferenceSession(MODEL_PATH + "_export_vits.onnx")

[audio] = vits.run(None, {
    "input_text_phones": input_phones,
    "pred_semantic": pred_semantic,
    "spectrum": spectrum.astype(np.float32),
    # "sv_emb": sv_emb.astype(np.float32)
})

audio_postprocess([audio])
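The decode loop above grows the ONNX-side KV cache by hand: before each step, k and v are padded with one empty slot along the sequence axis, which the decoder then fills for the newly sampled token. A toy illustration of just that bookkeeping (the sizes are made up; real values come from the t2s decoder outputs):

import numpy as np

# Toy sizes only, to show how the cache's sequence axis (axis=1) grows.
n_layer, seq_len, batch, dim = 24, 10, 1, 512
k = np.zeros((n_layer, seq_len, batch, dim), dtype=np.float32)
for _ in range(3):
    k = np.pad(k, ((0, 0), (0, 1), (0, 0), (0, 0)))  # one extra slot per step
print(k.shape)  # (24, 13, 1, 512)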
Binary file not shown.