Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
update init_step name
This commit is contained in:
  parent aafa0561d8
  commit 4e0cc57052
@@ -89,7 +89,7 @@ class DictToAttrRecursive(dict):
         raise AttributeError(f"Attribute {item} not found")


-class T2SEncoder(nn.Module):
+class T2SInitStep(nn.Module):
     def __init__(self, t2s, vits):
         super().__init__()
         self.encoder = t2s.onnx_encoder
@@ -122,7 +122,7 @@ class T2SModel(nn.Module):
         self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
         self.t2s_model = self.t2s_model.model
         self.t2s_model.init_onnx()
-        self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model)
+        self.init_step = T2SInitStep(self.t2s_model, self.vits_model)
         self.first_stage_decoder = self.t2s_model.first_stage_decoder
         self.stage_decoder = self.t2s_model.stage_decoder
         # self.t2s_model = torch.jit.script(self.t2s_model)
@@ -131,7 +131,7 @@ class T2SModel(nn.Module):
         early_stop_num = self.t2s_model.early_stop_num

         # [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
-        y, k, v, y_emb, x_example = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
+        y, k, v, y_emb, x_example = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content)

         for idx in tqdm(range(1, 20)): # This is a fake one! do take this as reference
             # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
@@ -144,19 +144,19 @@ class T2SModel(nn.Module):
         return y[:, -idx:].unsqueeze(0)

     def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False):
-        # self.onnx_encoder = torch.jit.script(self.onnx_encoder)
+        # self.init_step = torch.jit.script(self.init_step)
         if dynamo:
             export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
-            onnx_encoder_export_output = torch.onnx.dynamo_export(
-                self.onnx_encoder, (ref_seq, text_seq, ref_bert, text_bert, ssl_content), export_options=export_options
+            init_step_export_output = torch.onnx.dynamo_export(
+                self.init_step, (ref_seq, text_seq, ref_bert, text_bert, ssl_content), export_options=export_options
             )
-            onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
+            init_step_export_output.save(f"onnx/{project_name}/{project_name}_t2s_init_step.onnx")
             return

         torch.onnx.export(
-            self.onnx_encoder,
+            self.init_step,
             (ref_seq, text_seq, ref_bert, text_bert, ssl_content),
-            f"onnx/{project_name}/{project_name}_t2s_encoder.onnx",
+            f"onnx/{project_name}/{project_name}_t2s_init_step.onnx",
             input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
             output_names=["y", "k", "v", "y_emb", "x_example"],
             dynamic_axes={
@@ -168,7 +168,7 @@ class T2SModel(nn.Module):
             },
             opset_version=16,
         )
-        y, k, v, y_emb, x_example = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
+        y, k, v, y_emb, x_example = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content)

         # torch.onnx.export(
         #     self.first_stage_decoder,
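For orientation, a minimal sketch of driving the export path renamed above with dummy inputs. Tensor shapes follow the shape comment in the diff (ref_seq/text_seq are [1, N], ref_bert/text_bert are [N, 1024], ssl_content is [1, 768, N]); the helper name export_init_step, the dummy lengths, and the project name are illustrative assumptions, not part of the commit.

import torch

def export_init_step(t2s_model, project_name: str) -> None:
    # Dummy tensors shaped per the diff's shape comment; real inputs come from
    # the text front end and the SSL feature extractor.
    ref_seq = torch.randint(0, 100, (1, 32), dtype=torch.int64)
    text_seq = torch.randint(0, 100, (1, 48), dtype=torch.int64)
    ref_bert = torch.randn(32, 1024)
    text_bert = torch.randn(48, 1024)
    ssl_content = torch.randn(1, 768, 128)
    # dynamo=False takes the classic torch.onnx.export branch shown above and
    # writes onnx/{project_name}/{project_name}_t2s_init_step.onnx.
    t2s_model.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False)

Small dummy lengths keep the trace quick, since the export call above declares dynamic_axes rather than baking fixed shapes into the graph.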
@@ -63,7 +63,7 @@ def preprocess_text(text:str):

 # input_phones_saved = np.load("playground/ref/input_phones.npy")
 # input_bert_saved = np.load("playground/ref/input_bert.npy").T.astype(np.float32)
-[input_phones, input_bert] = preprocess_text("像大雨匆匆打击过的屋檐")
+[input_phones, input_bert] = preprocess_text("天上的风筝在天上飞,地上的人儿在地上追")


 # ref_phones = np.load("playground/ref/ref_phones.npy")
@@ -74,9 +74,9 @@ def preprocess_text(text:str):
 [audio_prompt_hubert, spectrum, sv_emb] = audio_preprocess("playground/ref/audio.wav")


-encoder = ort.InferenceSession(MODEL_PATH+"_export_t2s_encoder.onnx")
+init_step = ort.InferenceSession(MODEL_PATH+"_export_t2s_init_step.onnx")

-[y, k, v, y_emb, x_example] = encoder.run(None, {
+[y, k, v, y_emb, x_example] = init_step.run(None, {
     "text_seq": input_phones,
     "text_bert": input_bert,
     "ref_seq": ref_phones,
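A quick way to confirm the renamed graph exposes the interface declared in the export call is to list the session's inputs and outputs with ONNX Runtime. This is a minimal sketch; the MODEL_PATH prefix below is a placeholder, while the playground script defines its own.

import onnxruntime as ort

MODEL_PATH = "playground/my_voice"  # placeholder prefix, not from the commit

sess = ort.InferenceSession(MODEL_PATH + "_export_t2s_init_step.onnx")
# Expect ref_seq, text_seq, ref_bert, text_bert, ssl_content, per input_names above.
print([i.name for i in sess.get_inputs()])
# Expect y, k, v, y_emb, x_example, per output_names above.
print([o.name for o in sess.get_outputs()])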