mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-29 17:10:02 +08:00
feat: sampling params working now for export, todo:fold weights clean code
This commit is contained in:
parent
9ed42daa88
commit
b45cbc3561
@ -61,7 +61,7 @@ def logits_to_probs(
|
|||||||
)
|
)
|
||||||
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
||||||
|
|
||||||
logits = logits / max(temperature, 1e-5)
|
logits = logits / torch.max(temperature, torch.tensor(1e-5, device=temperature.device, dtype=temperature.dtype))
|
||||||
|
|
||||||
# if top_k is not None: # To be captured by onnx
|
# if top_k is not None: # To be captured by onnx
|
||||||
v, _ = torch.topk(logits, top_k)
|
v, _ = torch.topk(logits, top_k)
|
||||||
|
@ -94,14 +94,14 @@ class T2SInitStep(nn.Module):
|
|||||||
self.fsdc = t2s.first_stage_decoder
|
self.fsdc = t2s.first_stage_decoder
|
||||||
self.vits = vits
|
self.vits = vits
|
||||||
|
|
||||||
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
|
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
|
||||||
codes = self.vits.extract_latent(ssl_content)
|
codes = self.vits.extract_latent(ssl_content)
|
||||||
prompt_semantic = codes[0, 0]
|
prompt_semantic = codes[0, 0]
|
||||||
bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1)
|
bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1)
|
||||||
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
|
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
|
||||||
bert = bert.unsqueeze(0)
|
bert = bert.unsqueeze(0)
|
||||||
prompt = prompt_semantic.unsqueeze(0)
|
prompt = prompt_semantic.unsqueeze(0)
|
||||||
[y, k, v, y_emb, x_example] = self.fsdc(self.encoder(all_phoneme_ids, bert), prompt)
|
[y, k, v, y_emb, x_example] = self.fsdc(self.encoder(all_phoneme_ids, bert), prompt, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
||||||
fake_logits = torch.randn((1, 1025), dtype=torch.float32) # Dummy logits for ONNX export
|
fake_logits = torch.randn((1, 1025), dtype=torch.float32) # Dummy logits for ONNX export
|
||||||
fake_samples = torch.zeros((1, 1), dtype=torch.int32) # Dummy samples for ONNX export
|
fake_samples = torch.zeros((1, 1), dtype=torch.int32) # Dummy samples for ONNX export
|
||||||
return y, k, v, y_emb, x_example, fake_logits, fake_samples
|
return y, k, v, y_emb, x_example, fake_logits, fake_samples
|
||||||
@ -111,8 +111,8 @@ class T2SStageStep(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.stage_decoder = stage_decoder
|
self.stage_decoder = stage_decoder
|
||||||
|
|
||||||
def forward(self, iy, ik, iv, iy_emb, ix_example):
|
def forward(self, iy, ik, iv, iy_emb, ix_example, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
|
||||||
[y, k, v, y_emb, logits, samples] = self.stage_decoder(iy, ik, iv, iy_emb, ix_example)
|
[y, k, v, y_emb, logits, samples] = self.stage_decoder(iy, ik, iv, iy_emb, ix_example, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
||||||
fake_x_example = torch.randn((1, 512), dtype=torch.float32) # Dummy x_example for ONNX export
|
fake_x_example = torch.randn((1, 512), dtype=torch.float32) # Dummy x_example for ONNX export
|
||||||
return y, k, v, y_emb, fake_x_example, logits, samples
|
return y, k, v, y_emb, fake_x_example, logits, samples
|
||||||
|
|
||||||
@ -136,38 +136,27 @@ class T2SModel(nn.Module):
|
|||||||
self.stage_decoder = self.t2s_model.stage_decoder
|
self.stage_decoder = self.t2s_model.stage_decoder
|
||||||
# self.t2s_model = torch.jit.script(self.t2s_model)
|
# self.t2s_model = torch.jit.script(self.t2s_model)
|
||||||
|
|
||||||
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
|
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
|
||||||
early_stop_num = self.t2s_model.early_stop_num
|
early_stop_num = self.t2s_model.early_stop_num
|
||||||
|
|
||||||
# [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
|
# [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
|
||||||
y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
||||||
|
|
||||||
for idx in tqdm(range(1, 20)): # This is a fake one! do take this as reference
|
for idx in tqdm(range(1, 20)): # This is a fake one! do take this as reference
|
||||||
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
||||||
enco = self.stage_decoder(y, k, v, y_emb, x_example)
|
enco = self.stage_decoder(y, k, v, y_emb, x_example, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
||||||
y, k, v, y_emb, logits, samples = enco
|
y, k, v, y_emb, logits, samples = enco
|
||||||
print(logits.shape, samples.shape)
|
|
||||||
if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
|
if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
|
||||||
break
|
break
|
||||||
y[0, -1] = 0
|
|
||||||
|
|
||||||
return y[:, -idx:].unsqueeze(0)
|
return y[:, -idx:].unsqueeze(0)
|
||||||
|
|
||||||
def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False):
|
def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
|
||||||
# self.init_step = torch.jit.script(self.init_step)
|
|
||||||
if dynamo:
|
|
||||||
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
|
|
||||||
init_step_export_output = torch.onnx.dynamo_export(
|
|
||||||
self.init_step, (ref_seq, text_seq, ref_bert, text_bert, ssl_content), export_options=export_options
|
|
||||||
)
|
|
||||||
init_step_export_output.save(f"onnx/{project_name}/{project_name}_t2s_init_step.onnx")
|
|
||||||
return
|
|
||||||
|
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
self.init_step,
|
self.init_step,
|
||||||
(ref_seq, text_seq, ref_bert, text_bert, ssl_content),
|
(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k, top_p, repetition_penalty, temperature),
|
||||||
f"onnx/{project_name}/{project_name}_t2s_init_step.onnx",
|
f"onnx/{project_name}/{project_name}_t2s_init_step.onnx",
|
||||||
input_names=["ref_text_phones", "input_text_phones", "ref_text_bert", "input_text_bert", "hubert_ssl_content"],
|
input_names=["ref_text_phones", "input_text_phones", "ref_text_bert", "input_text_bert", "hubert_ssl_content", "top_k", "top_p", "repetition_penalty", "temperature"],
|
||||||
output_names=["y", "k", "v", "y_emb", "x_example", 'logits', 'samples'],
|
output_names=["y", "k", "v", "y_emb", "x_example", 'logits', 'samples'],
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"ref_text_phones": {1: "ref_length"},
|
"ref_text_phones": {1: "ref_length"},
|
||||||
@ -178,14 +167,14 @@ class T2SModel(nn.Module):
|
|||||||
},
|
},
|
||||||
opset_version=16,
|
opset_version=16,
|
||||||
)
|
)
|
||||||
y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
||||||
|
|
||||||
stage_step = T2SStageStep(self.stage_decoder)
|
stage_step = T2SStageStep(self.stage_decoder)
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
stage_step,
|
stage_step,
|
||||||
(y, k, v, y_emb, x_example),
|
(y, k, v, y_emb, x_example, top_k, top_p, repetition_penalty, temperature),
|
||||||
f"onnx/{project_name}/{project_name}_t2s_sdec.onnx",
|
f"onnx/{project_name}/{project_name}_t2s_sdec.onnx",
|
||||||
input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
|
input_names=["iy", "ik", "iv", "iy_emb", "ix_example", "top_k", "top_p", "repetition_penalty", "temperature"],
|
||||||
output_names=["y", "k", "v", "y_emb","x_example", "logits", "samples"],
|
output_names=["y", "k", "v", "y_emb","x_example", "logits", "samples"],
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"iy": {1: "iy_length"},
|
"iy": {1: "iy_length"},
|
||||||
@ -237,14 +226,14 @@ class GptSoVits(nn.Module):
|
|||||||
self.vits = vits
|
self.vits = vits
|
||||||
self.t2s = t2s
|
self.t2s = t2s
|
||||||
|
|
||||||
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb):
|
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
|
||||||
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
||||||
audio = self.vits(text_seq, pred_semantic, spectrum, sv_emb)
|
audio = self.vits(text_seq, pred_semantic, spectrum, sv_emb)
|
||||||
return audio
|
return audio
|
||||||
|
|
||||||
def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb, project_name):
|
def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb, project_name, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
|
||||||
self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name)
|
self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
||||||
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
self.vits,
|
self.vits,
|
||||||
(text_seq, pred_semantic, spectrum, sv_emb),
|
(text_seq, pred_semantic, spectrum, sv_emb),
|
||||||
@ -304,6 +293,14 @@ def combineInitStepAndStageStep(init_step_onnx_path, stage_step_onnx_path, combi
|
|||||||
data_inputs_init = [input for input in init_step_model.graph.input]
|
data_inputs_init = [input for input in init_step_model.graph.input]
|
||||||
data_inputs_stage = [input for input in stage_step_model.graph.input]
|
data_inputs_stage = [input for input in stage_step_model.graph.input]
|
||||||
|
|
||||||
|
# Get all names from both lists
|
||||||
|
names_list_init = {obj.name for obj in data_inputs_init}
|
||||||
|
names_list_stage = {obj.name for obj in data_inputs_stage}
|
||||||
|
# Find names that appear in both lists
|
||||||
|
repeated_input_names = names_list_init.intersection(names_list_stage)
|
||||||
|
# Filter out objects with repeated names
|
||||||
|
data_inputs_stage = [obj for obj in data_inputs_stage if obj.name not in repeated_input_names]
|
||||||
|
|
||||||
del then_graph.input[:]
|
del then_graph.input[:]
|
||||||
del else_graph.input[:]
|
del else_graph.input[:]
|
||||||
|
|
||||||
@ -413,13 +410,17 @@ def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
|
|||||||
ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
|
ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
|
||||||
text_bert = torch.randn((text_seq.shape[1], 1024)).float()
|
text_bert = torch.randn((text_seq.shape[1], 1024)).float()
|
||||||
ref_audio32k = torch.randn((1, 32000 * 5)).float() - 0.5
|
ref_audio32k = torch.randn((1, 32000 * 5)).float() - 0.5
|
||||||
|
top_k = torch.LongTensor([15])
|
||||||
|
top_p = torch.FloatTensor([1.0])
|
||||||
|
repetition_penalty = torch.FloatTensor([1.0])
|
||||||
|
temperature = torch.FloatTensor([1.0])
|
||||||
|
|
||||||
os.makedirs(f"onnx/{project_name}", exist_ok=True)
|
os.makedirs(f"onnx/{project_name}", exist_ok=True)
|
||||||
|
|
||||||
[ssl_content, spectrum, sv_emb] = preprocessor(ref_audio32k)
|
[ssl_content, spectrum, sv_emb] = preprocessor(ref_audio32k)
|
||||||
gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float())
|
gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
||||||
# exit()
|
# exit()
|
||||||
gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), project_name)
|
gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), project_name, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
||||||
|
|
||||||
torch.onnx.export(preprocessor, (ref_audio32k,), f"onnx/{project_name}/{project_name}_audio_preprocess.onnx",
|
torch.onnx.export(preprocessor, (ref_audio32k,), f"onnx/{project_name}/{project_name}_audio_preprocess.onnx",
|
||||||
input_names=["audio32k"],
|
input_names=["audio32k"],
|
||||||
@ -444,12 +445,12 @@ if __name__ == "__main__":
|
|||||||
# version = "v1"
|
# version = "v1"
|
||||||
# export(vits_path, gpt_path, exp_path, version)
|
# export(vits_path, gpt_path, exp_path, version)
|
||||||
|
|
||||||
gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
|
# gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
|
||||||
vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
|
# vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
|
||||||
exp_path = "v2_export"
|
# exp_path = "v2_export"
|
||||||
version = "v2"
|
# version = "v2"
|
||||||
export(vits_path, gpt_path, exp_path, version)
|
# export(vits_path, gpt_path, exp_path, version)
|
||||||
combineInitStepAndStageStep('onnx/v2_export/v2_export_t2s_init_step.onnx', 'onnx/v2_export/v2_export_t2s_sdec.onnx', 'onnx/v2_export/v2_export_t2s_combined.onnx')
|
# combineInitStepAndStageStep('onnx/v2_export/v2_export_t2s_init_step.onnx', 'onnx/v2_export/v2_export_t2s_sdec.onnx', 'onnx/v2_export/v2_export_t2s_combined.onnx')
|
||||||
|
|
||||||
# gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt"
|
# gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt"
|
||||||
# vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth"
|
# vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth"
|
||||||
@ -457,11 +458,11 @@ if __name__ == "__main__":
|
|||||||
# version = "v2Pro"
|
# version = "v2Pro"
|
||||||
# export(vits_path, gpt_path, exp_path, version)
|
# export(vits_path, gpt_path, exp_path, version)
|
||||||
|
|
||||||
# gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
|
gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
|
||||||
# vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth"
|
vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth"
|
||||||
# exp_path = "v2proplus_export"
|
exp_path = "v2proplus_export"
|
||||||
# version = "v2ProPlus"
|
version = "v2ProPlus"
|
||||||
# export(vits_path, gpt_path, exp_path, version)
|
export(vits_path, gpt_path, exp_path, version)
|
||||||
# combineInitStepAndStageStep('onnx/v2proplus_export/v2proplus_export_t2s_init_step.onnx', 'onnx/v2proplus_export/v2proplus_export_t2s_sdec.onnx', 'onnx/v2proplus_export/v2proplus_export_t2s_combined.onnx')
|
combineInitStepAndStageStep('onnx/v2proplus_export/v2proplus_export_t2s_init_step.onnx', 'onnx/v2proplus_export/v2proplus_export_t2s_sdec.onnx', 'onnx/v2proplus_export/v2proplus_export_t2s_combined.onnx')
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ import torch
|
|||||||
from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx
|
from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx
|
||||||
|
|
||||||
|
|
||||||
MODEL_PATH = "onnx/v2_export/v2"
|
MODEL_PATH = "onnx/v2proplus_export/v2proplus"
|
||||||
|
|
||||||
def audio_postprocess(
|
def audio_postprocess(
|
||||||
audios,
|
audios,
|
||||||
@ -73,10 +73,13 @@ def preprocess_text(text:str):
|
|||||||
|
|
||||||
[audio_prompt_hubert, spectrum, sv_emb] = audio_preprocess("playground/ref/audio.wav")
|
[audio_prompt_hubert, spectrum, sv_emb] = audio_preprocess("playground/ref/audio.wav")
|
||||||
|
|
||||||
np.save("playground/ref/audio_prompt_hubert.npy", audio_prompt_hubert.astype(np.float16))
|
|
||||||
|
|
||||||
# audio_prompt_hubert_saved = np.load("playground/ref/audio_prompt_hubert.npy").astype(np.float32)
|
# audio_prompt_hubert_saved = np.load("playground/ref/audio_prompt_hubert.npy").astype(np.float32)
|
||||||
|
|
||||||
|
top_k = np.array([15], dtype=np.int64)
|
||||||
|
top_p = np.array([1.0], dtype=np.float32)
|
||||||
|
repetition_penalty = np.array([1.0], dtype=np.float32)
|
||||||
|
temperature = np.array([1.0], dtype=np.float32)
|
||||||
|
|
||||||
t2s_combined = ort.InferenceSession(MODEL_PATH+"_export_t2s_combined.onnx")
|
t2s_combined = ort.InferenceSession(MODEL_PATH+"_export_t2s_combined.onnx")
|
||||||
# t2s_init_step = ort.InferenceSession(MODEL_PATH+"_export_t2s_init_step.onnx")
|
# t2s_init_step = ort.InferenceSession(MODEL_PATH+"_export_t2s_init_step.onnx")
|
||||||
|
|
||||||
@ -91,7 +94,11 @@ t2s_combined = ort.InferenceSession(MODEL_PATH+"_export_t2s_combined.onnx")
|
|||||||
"ik":np.empty((24, 0, 1, 512), dtype=np.float32),
|
"ik":np.empty((24, 0, 1, 512), dtype=np.float32),
|
||||||
"iv":np.empty((24, 0, 1, 512), dtype=np.float32),
|
"iv":np.empty((24, 0, 1, 512), dtype=np.float32),
|
||||||
"iy_emb":np.empty((1, 0, 512), dtype=np.float32),
|
"iy_emb":np.empty((1, 0, 512), dtype=np.float32),
|
||||||
"ix_example":np.empty((1, 0), dtype=np.float32)
|
"ix_example":np.empty((1, 0), dtype=np.float32),
|
||||||
|
"top_k": top_k,
|
||||||
|
"top_p": top_p,
|
||||||
|
"repetition_penalty": repetition_penalty,
|
||||||
|
"temperature": temperature
|
||||||
})
|
})
|
||||||
|
|
||||||
# t2s_stage_step = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx")
|
# t2s_stage_step = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx")
|
||||||
@ -109,7 +116,11 @@ for idx in tqdm(range(1, 1500)):
|
|||||||
"ik": k,
|
"ik": k,
|
||||||
"iv": v,
|
"iv": v,
|
||||||
"iy_emb": y_emb,
|
"iy_emb": y_emb,
|
||||||
"ix_example": x_example
|
"ix_example": x_example,
|
||||||
|
"top_k": top_k,
|
||||||
|
"top_p": top_p,
|
||||||
|
"repetition_penalty": repetition_penalty,
|
||||||
|
"temperature": temperature
|
||||||
})
|
})
|
||||||
if np.argmax(logits, axis=-1)[0] == 1024 or samples[0, 0] == 1024: # 1024 is the EOS token
|
if np.argmax(logits, axis=-1)[0] == 1024 or samples[0, 0] == 1024: # 1024 is the EOS token
|
||||||
break
|
break
|
||||||
@ -124,7 +135,7 @@ vtis = ort.InferenceSession(MODEL_PATH+"_export_vits.onnx")
|
|||||||
"input_text_phones": input_phones,
|
"input_text_phones": input_phones,
|
||||||
"pred_semantic": pred_semantic,
|
"pred_semantic": pred_semantic,
|
||||||
"spectrum": spectrum.astype(np.float32),
|
"spectrum": spectrum.astype(np.float32),
|
||||||
# "sv_emb": sv_emb.astype(np.float32)
|
"sv_emb": sv_emb.astype(np.float32)
|
||||||
})
|
})
|
||||||
|
|
||||||
audio_postprocess([audio])
|
audio_postprocess([audio])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user