mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-29 17:10:02 +08:00
feat: solved problem, export works
This commit is contained in:
parent
419909b443
commit
968ac4c264
@ -105,36 +105,29 @@ class DictToAttrRecursive(dict):
|
|||||||
raise AttributeError(f"Attribute {item} not found")
|
raise AttributeError(f"Attribute {item} not found")
|
||||||
|
|
||||||
|
|
||||||
class T2SInitStep(nn.Module):
|
class T2SInitStage(nn.Module):
|
||||||
def __init__(self, t2s, vits):
|
def __init__(self, t2s, vits):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.encoder = t2s.onnx_encoder
|
self.encoder = t2s.onnx_encoder
|
||||||
self.fsdc = t2s.first_stage_decoder
|
|
||||||
self.vits = vits
|
self.vits = vits
|
||||||
|
self.num_layers = t2s.num_layers
|
||||||
|
|
||||||
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=None, top_p=None, repetition_penalty=None, temperature=None, first_infer=None):
|
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
|
||||||
first_infer = first_infer.to(torch.int64)
|
|
||||||
codes = self.vits.extract_latent(ssl_content)
|
codes = self.vits.extract_latent(ssl_content)
|
||||||
prompt_semantic = codes[0, 0]
|
prompt_semantic = codes[0, 0]
|
||||||
bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1)
|
bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1)
|
||||||
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
|
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
|
||||||
bert = bert.unsqueeze(0)
|
bert = bert.unsqueeze(0)
|
||||||
prompt = prompt_semantic.unsqueeze(0)
|
prompt = prompt_semantic.unsqueeze(0)
|
||||||
[y, k, v, y_emb, x_example] = self.fsdc(self.encoder(all_phoneme_ids, bert), prompt, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature, first_infer=first_infer)
|
x = self.encoder(all_phoneme_ids, bert)
|
||||||
fake_logits = torch.zeros((1, 1025), dtype=torch.float32) # Dummy logits for ONNX export
|
|
||||||
fake_samples = torch.zeros((1, 1), dtype=torch.int32) # Dummy samples for ONNX export
|
|
||||||
return y, k, v, y_emb, x_example, fake_logits, fake_samples
|
|
||||||
|
|
||||||
class T2SStageStep(nn.Module):
|
x_seq_len = torch.onnx.operators.shape_as_tensor(x)[1]
|
||||||
def __init__(self, stage_decoder):
|
y_seq_len = torch.onnx.operators.shape_as_tensor(prompt)[1]
|
||||||
super().__init__()
|
|
||||||
self.stage_decoder = stage_decoder
|
|
||||||
|
|
||||||
def forward(self, iy, ik, iv, iy_emb, ix_example, top_k=None, top_p=None, repetition_penalty=None, temperature=None, first_infer=None):
|
init_k = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
|
||||||
first_infer = first_infer.to(torch.int64)
|
init_v = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float)
|
||||||
[y, k, v, y_emb, logits, samples] = self.stage_decoder(iy, ik, iv, iy_emb, ix_example, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature, first_infer=first_infer)
|
|
||||||
fake_x_example = torch.zeros((1, 512), dtype=torch.float32) # Dummy x_example for ONNX export
|
return x, prompt, init_k, init_v, x_seq_len, y_seq_len
|
||||||
return y, k, v, y_emb, fake_x_example, logits, samples
|
|
||||||
|
|
||||||
class T2SModel(nn.Module):
|
class T2SModel(nn.Module):
|
||||||
def __init__(self, t2s_path, vits_model):
|
def __init__(self, t2s_path, vits_model):
|
||||||
@ -151,17 +144,26 @@ class T2SModel(nn.Module):
|
|||||||
self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
|
self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
|
||||||
self.t2s_model = self.t2s_model.model
|
self.t2s_model = self.t2s_model.model
|
||||||
self.t2s_model.init_onnx()
|
self.t2s_model.init_onnx()
|
||||||
self.init_step = T2SInitStep(self.t2s_model, self.vits_model)
|
self.init_stage = T2SInitStage(self.t2s_model, self.vits_model)
|
||||||
self.first_stage_decoder = self.t2s_model.first_stage_decoder
|
|
||||||
self.stage_decoder = self.t2s_model.stage_decoder
|
self.stage_decoder = self.t2s_model.stage_decoder
|
||||||
|
|
||||||
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
|
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
|
||||||
# [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
|
x, prompt, init_k, init_v, x_seq_len, y_seq_len = self.init_stage(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
||||||
y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature, first_infer=torch.LongTensor([1]))
|
empty_tensor = torch.empty((1,0,512)).to(torch.float)
|
||||||
|
# first step
|
||||||
|
y, k, v, y_emb, logits, samples = self.stage_decoder(x, prompt, init_k, init_v,
|
||||||
|
empty_tensor,
|
||||||
|
top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature,
|
||||||
|
first_infer=torch.LongTensor([1]), x_seq_len=x_seq_len, y_seq_len=y_seq_len)
|
||||||
|
|
||||||
for idx in range(5): # This is a fake one! DO NOT take this as reference
|
for idx in range(5): # This is a fake one! DO NOT take this as reference
|
||||||
enco = self.stage_decoder(y, k, v, y_emb, x_example, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature, first_infer=torch.LongTensor([0]))
|
k = torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1))
|
||||||
y, k, v, y_emb, logits, samples = enco
|
v = torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1))
|
||||||
|
y_seq_len = y.shape[1]
|
||||||
|
y, k, v, y_emb, logits, samples = self.stage_decoder(empty_tensor, y, k, v,
|
||||||
|
y_emb,
|
||||||
|
top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature,
|
||||||
|
first_infer=torch.LongTensor([0]), x_seq_len=x_seq_len, y_seq_len=y_seq_len)
|
||||||
# if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
|
# if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
|
||||||
# break
|
# break
|
||||||
|
|
||||||
@ -169,11 +171,11 @@ class T2SModel(nn.Module):
|
|||||||
|
|
||||||
def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
|
def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
self.init_step,
|
self.init_stage,
|
||||||
(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k, top_p, repetition_penalty, temperature, torch.Tensor([True]).to(torch.bool)),
|
(ref_seq, text_seq, ref_bert, text_bert, ssl_content),
|
||||||
f"onnx/{project_name}/{project_name}_t2s_init_step.onnx",
|
f"onnx/{project_name}/{project_name}_t2s_init_stage.onnx",
|
||||||
input_names=["ref_text_phones", "input_text_phones", "ref_text_bert", "input_text_bert", "hubert_ssl_content", "top_k", "top_p", "repetition_penalty", "temperature", "if_init_step"],
|
input_names=["ref_text_phones", "input_text_phones", "ref_text_bert", "input_text_bert", "hubert_ssl_content"],
|
||||||
output_names=["y", "k", "v", "y_emb", "x_example", 'logits', 'samples'],
|
output_names=["x", "prompt", "init_k", "init_v", 'x_seq_len', 'y_seq_len'],
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"ref_text_phones": {1: "ref_length"},
|
"ref_text_phones": {1: "ref_length"},
|
||||||
"input_text_phones": {1: "text_length"},
|
"input_text_phones": {1: "text_length"},
|
||||||
@ -184,28 +186,38 @@ class T2SModel(nn.Module):
|
|||||||
opset_version=16,
|
opset_version=16,
|
||||||
do_constant_folding=False
|
do_constant_folding=False
|
||||||
)
|
)
|
||||||
# simplify_onnx_model(f"onnx/{project_name}/{project_name}_t2s_init_step.onnx")
|
simplify_onnx_model(f"onnx/{project_name}/{project_name}_t2s_init_stage.onnx")
|
||||||
y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature, first_infer=torch.Tensor([True]).to(torch.bool))
|
x, prompt, init_k, init_v, x_seq_len, y_seq_len = self.init_stage(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
||||||
|
empty_tensor = torch.empty((1,0,512)).to(torch.float)
|
||||||
|
x_seq_len = torch.Tensor([x_seq_len]).to(torch.int64)
|
||||||
|
y_seq_len = torch.Tensor([y_seq_len]).to(torch.int64)
|
||||||
|
|
||||||
|
y, k, v, y_emb, logits, samples = self.stage_decoder(x, prompt, init_k, init_v,
|
||||||
|
empty_tensor,
|
||||||
|
top_k, top_p, repetition_penalty, temperature,
|
||||||
|
torch.LongTensor([1]), x_seq_len, y_seq_len)
|
||||||
|
print(y.shape, k.shape, v.shape, y_emb.shape, logits.shape, samples.shape)
|
||||||
|
k = torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1))
|
||||||
|
v = torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1))
|
||||||
|
y_seq_len = torch.Tensor([y.shape[1]]).to(torch.int64)
|
||||||
|
|
||||||
stage_step = T2SStageStep(self.stage_decoder)
|
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
stage_step,
|
self.stage_decoder,
|
||||||
(y, k, v, y_emb, x_example, top_k, top_p, repetition_penalty, temperature, torch.Tensor([False]).to(torch.bool)),
|
(x, y, k, v, y_emb, top_k, top_p, repetition_penalty, temperature, torch.LongTensor([0]), x_seq_len, y_seq_len),
|
||||||
f"onnx/{project_name}/{project_name}_t2s_stage_step.onnx",
|
f"onnx/{project_name}/{project_name}_t2s_stage_decoder.onnx",
|
||||||
input_names=["iy", "ik", "iv", "iy_emb", "ix_example", "top_k", "top_p", "repetition_penalty", "temperature", "if_init_step"],
|
input_names=["ix", "iy", "ik", "iv", "iy_emb", "top_k", "top_p", "repetition_penalty", "temperature", "if_init_step", "x_seq_len", "y_seq_len"],
|
||||||
output_names=["y", "k", "v", "y_emb","x_example", "logits", "samples"],
|
output_names=["y", "k", "v", "y_emb", "logits", "samples"],
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
|
"ix": {1: "ix_length"},
|
||||||
"iy": {1: "iy_length"},
|
"iy": {1: "iy_length"},
|
||||||
"ik": {1: "ik_length"},
|
"ik": {1: "ik_length"},
|
||||||
"iv": {1: "iv_length"},
|
"iv": {1: "iv_length"},
|
||||||
"iy_emb": {1: "iy_emb_length"},
|
"iy_emb": {1: "iy_emb_length"},
|
||||||
"ix_example": {1: "ix_example_length"},
|
|
||||||
"x_example": {1: "x_example_length"}
|
|
||||||
},
|
},
|
||||||
verbose=False,
|
verbose=False,
|
||||||
opset_version=16,
|
opset_version=16,
|
||||||
)
|
)
|
||||||
simplify_onnx_model(f"onnx/{project_name}/{project_name}_t2s_stage_step.onnx")
|
simplify_onnx_model(f"onnx/{project_name}/{project_name}_t2s_stage_decoder.onnx")
|
||||||
|
|
||||||
|
|
||||||
class VitsModel(nn.Module):
|
class VitsModel(nn.Module):
|
||||||
@ -301,73 +313,7 @@ class AudioPreprocess(nn.Module):
|
|||||||
|
|
||||||
return ssl_content, spectrum, sv_emb
|
return ssl_content, spectrum, sv_emb
|
||||||
|
|
||||||
def combineInitStepAndStageStep(init_step_onnx_path, stage_step_onnx_path, combined_onnx_path):
|
def export(vits_path, gpt_path, project_name, voice_model_version, export_audio_preprocessor=True, half_precision=False):
|
||||||
init_step_model = onnx.load(init_step_onnx_path)
|
|
||||||
stage_step_model = onnx.load(stage_step_onnx_path)
|
|
||||||
|
|
||||||
then_graph = init_step_model.graph
|
|
||||||
then_graph.name = "init_step_graph"
|
|
||||||
else_graph = stage_step_model.graph
|
|
||||||
else_graph.name = "stage_step_graph"
|
|
||||||
|
|
||||||
data_inputs_init = [input for input in init_step_model.graph.input]
|
|
||||||
data_inputs_stage = [input for input in stage_step_model.graph.input]
|
|
||||||
|
|
||||||
# Get all names from both lists
|
|
||||||
names_list_init = {obj.name for obj in data_inputs_init}
|
|
||||||
names_list_stage = {obj.name for obj in data_inputs_stage}
|
|
||||||
# Find names that appear in both lists
|
|
||||||
repeated_input_names = names_list_init.intersection(names_list_stage)
|
|
||||||
# Filter out objects with repeated names
|
|
||||||
data_inputs_stage = [obj for obj in data_inputs_stage if obj.name not in repeated_input_names]
|
|
||||||
|
|
||||||
del then_graph.input[:]
|
|
||||||
del else_graph.input[:]
|
|
||||||
|
|
||||||
# The output names of the subgraphs must be the same.
|
|
||||||
# The 'If' node will have an output with this same name.
|
|
||||||
subgraph_output_names = [output.name for output in then_graph.output]
|
|
||||||
for i, output in enumerate(else_graph.output):
|
|
||||||
assert subgraph_output_names[i] == output.name, "Subgraph output names must match"
|
|
||||||
|
|
||||||
# Define the inputs for the main graph
|
|
||||||
# 1. The boolean condition to select the branch
|
|
||||||
# cond_input = helper.make_tensor_value_info('if_init_step', TensorProto.BOOL, [])
|
|
||||||
|
|
||||||
main_outputs = [output for output in init_step_model.graph.output]
|
|
||||||
|
|
||||||
# Create the 'If' node
|
|
||||||
if_node = helper.make_node(
|
|
||||||
'If',
|
|
||||||
inputs=['if_init_step'],
|
|
||||||
outputs=subgraph_output_names, # This name MUST match the subgraph's output name
|
|
||||||
then_branch=then_graph,
|
|
||||||
else_branch=else_graph
|
|
||||||
)
|
|
||||||
|
|
||||||
# Combine the models (this is a simplified example; actual combination logic may vary)
|
|
||||||
main_graph = helper.make_graph(
|
|
||||||
nodes=[if_node],
|
|
||||||
name="t2s_combined_graph",
|
|
||||||
inputs= data_inputs_init + data_inputs_stage,
|
|
||||||
outputs=main_outputs
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create the final combined model, specifying the opset and IR version
|
|
||||||
opset_version = 16
|
|
||||||
final_model = helper.make_model(main_graph,
|
|
||||||
producer_name='GSV-ONNX-Exporter',
|
|
||||||
ir_version=9, # For compatibility with older onnxruntime
|
|
||||||
opset_imports=[helper.make_opsetid("", opset_version)])
|
|
||||||
# Check the model for correctness
|
|
||||||
onnx.checker.check_model(final_model)
|
|
||||||
|
|
||||||
# Save the combined model
|
|
||||||
onnx.save(final_model, combined_onnx_path)
|
|
||||||
print(f"Combined model saved to {combined_onnx_path}")
|
|
||||||
|
|
||||||
|
|
||||||
def export(vits_path, gpt_path, project_name, voice_model_version, t2s_model_combine=False, export_audio_preprocessor=True, half_precision=False):
|
|
||||||
vits = VitsModel(vits_path, version=voice_model_version)
|
vits = VitsModel(vits_path, version=voice_model_version)
|
||||||
gpt = T2SModel(gpt_path, vits)
|
gpt = T2SModel(gpt_path, vits)
|
||||||
gpt_sovits = GptSoVits(vits, gpt)
|
gpt_sovits = GptSoVits(vits, gpt)
|
||||||
@ -453,12 +399,7 @@ def export(vits_path, gpt_path, project_name, voice_model_version, t2s_model_com
|
|||||||
})
|
})
|
||||||
simplify_onnx_model(f"onnx/{project_name}/{project_name}_audio_preprocess.onnx")
|
simplify_onnx_model(f"onnx/{project_name}/{project_name}_audio_preprocess.onnx")
|
||||||
|
|
||||||
if t2s_model_combine:
|
|
||||||
combineInitStepAndStageStep(f'onnx/{project_name}/{project_name}_t2s_init_step.onnx', f'onnx/{project_name}/{project_name}_t2s_stage_step.onnx', f'onnx/{project_name}/{project_name}_t2s_combined.onnx')
|
|
||||||
|
|
||||||
if half_precision:
|
if half_precision:
|
||||||
if t2s_model_combine:
|
|
||||||
convert_onnx_to_half(f"onnx/{project_name}/{project_name}_t2s_combined.onnx")
|
|
||||||
if export_audio_preprocessor:
|
if export_audio_preprocessor:
|
||||||
convert_onnx_to_half(f"onnx/{project_name}/{project_name}_audio_preprocess.onnx")
|
convert_onnx_to_half(f"onnx/{project_name}/{project_name}_audio_preprocess.onnx")
|
||||||
convert_onnx_to_half(f"onnx/{project_name}/{project_name}_vits.onnx")
|
convert_onnx_to_half(f"onnx/{project_name}/{project_name}_vits.onnx")
|
||||||
@ -467,7 +408,7 @@ def export(vits_path, gpt_path, project_name, voice_model_version, t2s_model_com
|
|||||||
|
|
||||||
configJson = {
|
configJson = {
|
||||||
"project_name": project_name,
|
"project_name": project_name,
|
||||||
"type": "GPTSoVits",
|
"type": "GPTSoVITS",
|
||||||
"version" : voice_model_version,
|
"version" : voice_model_version,
|
||||||
"bert_base_path": 'GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large',
|
"bert_base_path": 'GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large',
|
||||||
"cnhuhbert_base_path": 'GPT_SoVITS/pretrained_models/chinese-hubert-base',
|
"cnhuhbert_base_path": 'GPT_SoVITS/pretrained_models/chinese-hubert-base',
|
||||||
@ -486,29 +427,29 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# 因为io太频繁,可能导致模型导出出错(wsl非常明显),请自行重试
|
# 因为io太频繁,可能导致模型导出出错(wsl非常明显),请自行重试
|
||||||
|
|
||||||
# gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
|
gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
|
||||||
# vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
|
vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
|
||||||
# exp_path = "v1_export"
|
exp_path = "v1_export"
|
||||||
# version = "v1"
|
version = "v1"
|
||||||
# export(vits_path, gpt_path, exp_path, version)
|
export(vits_path, gpt_path, exp_path, version)
|
||||||
|
|
||||||
gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
|
gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
|
||||||
vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
|
vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
|
||||||
exp_path = "v2_export"
|
exp_path = "v2_export"
|
||||||
version = "v2"
|
version = "v2"
|
||||||
export(vits_path, gpt_path, exp_path, version, t2s_model_combine = True)
|
export(vits_path, gpt_path, exp_path, version)
|
||||||
|
|
||||||
|
|
||||||
# gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt"
|
gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt"
|
||||||
# vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth"
|
vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth"
|
||||||
# exp_path = "v2pro_export"
|
exp_path = "v2pro_export"
|
||||||
# version = "v2Pro"
|
version = "v2Pro"
|
||||||
# export(vits_path, gpt_path, exp_path, version)
|
export(vits_path, gpt_path, exp_path, version)
|
||||||
|
|
||||||
# gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
|
gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
|
||||||
# vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth"
|
vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth"
|
||||||
# exp_path = "v2proplus_export"
|
exp_path = "v2proplus_export"
|
||||||
# version = "v2ProPlus"
|
version = "v2ProPlus"
|
||||||
# export(vits_path, gpt_path, exp_path, version, t2s_model_combine = True, half_precision=True)
|
export(vits_path, gpt_path, exp_path, version)
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ import torch
|
|||||||
from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx
|
from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx
|
||||||
|
|
||||||
|
|
||||||
MODEL_PATH = "onnx/v2_export/v2"
|
MODEL_PATH = "onnx/v1_export/v1"
|
||||||
|
|
||||||
def audio_postprocess(
|
def audio_postprocess(
|
||||||
audios,
|
audios,
|
||||||
@ -80,49 +80,52 @@ top_p = np.array([1.0], dtype=np.float32)
|
|||||||
repetition_penalty = np.array([1.0], dtype=np.float32)
|
repetition_penalty = np.array([1.0], dtype=np.float32)
|
||||||
temperature = np.array([1.0], dtype=np.float32)
|
temperature = np.array([1.0], dtype=np.float32)
|
||||||
|
|
||||||
t2s_combined = ort.InferenceSession(MODEL_PATH+"_export_t2s_combined.onnx")
|
t2s_init_stage = ort.InferenceSession(MODEL_PATH+"_export_t2s_init_stage.onnx")
|
||||||
# t2s_init_step = ort.InferenceSession(MODEL_PATH+"_export_t2s_init_step.onnx")
|
# t2s_init_step = ort.InferenceSession(MODEL_PATH+"_export_t2s_init_step.onnx")
|
||||||
|
|
||||||
[y, k, v, y_emb, x_example, fake_logits, fake_samples] = t2s_combined.run(None, {
|
[x, prompts, init_k, init_v, x_seq_len, y_seq_len] = t2s_init_stage.run(None, {
|
||||||
"if_init_step": np.array(True, dtype=bool),
|
|
||||||
"input_text_phones": input_phones,
|
"input_text_phones": input_phones,
|
||||||
"input_text_bert": input_bert,
|
"input_text_bert": input_bert,
|
||||||
"ref_text_phones": ref_phones,
|
"ref_text_phones": ref_phones,
|
||||||
"ref_text_bert": ref_bert,
|
"ref_text_bert": ref_bert,
|
||||||
"hubert_ssl_content": audio_prompt_hubert,
|
"hubert_ssl_content": audio_prompt_hubert,
|
||||||
"iy":np.empty((1, 0), dtype=np.int64),
|
})
|
||||||
"ik":np.empty((24, 0, 1, 512), dtype=np.float32),
|
empty_tensor = np.empty((1,0,512)).astype(np.float32)
|
||||||
"iv":np.empty((24, 0, 1, 512), dtype=np.float32),
|
|
||||||
"iy_emb":np.empty((1, 0, 512), dtype=np.float32),
|
t2s_stage_decoder = ort.InferenceSession(MODEL_PATH+"_export_t2s_stage_decoder.onnx")
|
||||||
"ix_example":np.empty((1, 0), dtype=np.float32),
|
y, k, v, y_emb, logits, samples = t2s_stage_decoder.run(None, {
|
||||||
|
"ix": x,
|
||||||
|
"iy": prompts,
|
||||||
|
"ik": init_k,
|
||||||
|
"iv": init_v,
|
||||||
|
"iy_emb": empty_tensor,
|
||||||
"top_k": top_k,
|
"top_k": top_k,
|
||||||
"top_p": top_p,
|
"top_p": top_p,
|
||||||
"repetition_penalty": repetition_penalty,
|
"repetition_penalty": repetition_penalty,
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"if_init_step": np.array([True], dtype=bool)
|
"if_init_step": np.array([1]).astype(np.int64),
|
||||||
|
"x_seq_len": np.array([x_seq_len]).astype(np.int64),
|
||||||
|
"y_seq_len": np.array([y_seq_len]).astype(np.int64)
|
||||||
})
|
})
|
||||||
|
|
||||||
# t2s_stage_step = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx")
|
|
||||||
|
|
||||||
for idx in tqdm(range(1, 1500)):
|
for idx in tqdm(range(1, 1500)):
|
||||||
|
k = np.pad(k, ((0,0), (0,1), (0,0), (0,0)))
|
||||||
|
v = np.pad(v, ((0,0), (0,1), (0,0), (0,0)))
|
||||||
|
y_seq_len = np.array([y.shape[1]]).astype(np.int64)
|
||||||
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
||||||
[y, k, v, y_emb, fake_x_example, logits, samples] = t2s_combined.run(None, {
|
[y, k, v, y_emb, logits, samples] = t2s_stage_decoder.run(None, {
|
||||||
"if_init_step": np.array(False, dtype=bool),
|
"ix": empty_tensor,
|
||||||
"input_text_phones": np.empty((1, 0), dtype=np.int64),
|
|
||||||
"input_text_bert": np.empty((0, 1024), dtype=np.float32),
|
|
||||||
"ref_text_phones": np.empty((1, 0), dtype=np.int64),
|
|
||||||
"ref_text_bert": np.empty((0, 1024), dtype=np.float32),
|
|
||||||
"hubert_ssl_content": np.empty((1, 768, 0), dtype=np.float32),
|
|
||||||
"iy": y,
|
"iy": y,
|
||||||
"ik": k,
|
"ik": k,
|
||||||
"iv": v,
|
"iv": v,
|
||||||
"iy_emb": y_emb,
|
"iy_emb": y_emb,
|
||||||
"ix_example": x_example,
|
|
||||||
"top_k": top_k,
|
"top_k": top_k,
|
||||||
"top_p": top_p,
|
"top_p": top_p,
|
||||||
"repetition_penalty": repetition_penalty,
|
"repetition_penalty": repetition_penalty,
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"if_init_step": np.array([False], dtype=bool)
|
"if_init_step": np.array([0]).astype(np.int64),
|
||||||
|
"x_seq_len": np.array([x_seq_len]).astype(np.int64),
|
||||||
|
"y_seq_len": y_seq_len
|
||||||
})
|
})
|
||||||
if np.argmax(logits, axis=-1)[0] == 1024 or samples[0, 0] == 1024: # 1024 is the EOS token
|
if np.argmax(logits, axis=-1)[0] == 1024 or samples[0, 0] == 1024: # 1024 is the EOS token
|
||||||
break
|
break
|
||||||
|
Loading…
x
Reference in New Issue
Block a user