Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-04-05 19:41:56 +08:00)

Commit 8c01e275ec: "Add files via upload"
Parent: 57dde51e11
@@ -140,6 +140,7 @@ class T2SModel(nn.Module):
             )
             onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
+            return

         torch.onnx.export(
             self.onnx_encoder,
             (ref_seq, text_seq, ref_bert, text_bert, ssl_content),
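Note: the added `return` closes off the `dynamo` branch, so after `torch.onnx.dynamo_export` saves its output the method exits instead of falling through to the legacy TorchScript-based exporter below. A minimal sketch of the same guard pattern (function and file names here are illustrative, not from this repo):

    import torch

    def export_onnx(model: torch.nn.Module, args: tuple, path: str, dynamo: bool = False):
        if dynamo:
            # TorchDynamo-based exporter (torch >= 2.1)
            onnx_program = torch.onnx.dynamo_export(model, *args)
            onnx_program.save(path)
            return  # stop here so the legacy exporter does not also run
        # legacy TorchScript-based exporter
        torch.onnx.export(model, args, path, opset_version=16)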
@@ -147,16 +148,16 @@ class T2SModel(nn.Module):
             input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
             output_names=["x", "prompts"],
             dynamic_axes={
-                "ref_seq": [1],
-                "text_seq": [1],
-                "ref_bert": [0],
-                "text_bert": [0],
-                "ssl_content": [2],
+                "ref_seq": {1 : "ref_length"},
+                "text_seq": {1 : "text_length"},
+                "ref_bert": {0 : "ref_length"},
+                "text_bert": {0 : "text_length"},
+                "ssl_content": {2 : "ssl_length"},
             },
             opset_version=16
         )
         x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
-        torch.exp
+
         torch.onnx.export(
             self.first_stage_decoder,
             (x, prompts),
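The substantive change in this hunk is the `dynamic_axes` format. The bare-list form (`"ref_seq": [1]`) marks axis 1 as dynamic but lets the exporter auto-generate its name; the dict form (`{1 : "ref_length"}`) marks the same axis dynamic and gives it a readable name, making the exported graph's input signature self-documenting. A minimal self-contained sketch (toy module and file name are illustrative):

    import torch

    class Toy(torch.nn.Module):
        def forward(self, ref_seq):
            return ref_seq * 2

    ref_seq = torch.zeros((1, 13), dtype=torch.long)  # axis 1 is the token length
    torch.onnx.export(
        Toy(), (ref_seq,), "toy.onnx",
        input_names=["ref_seq"],
        output_names=["out"],
        dynamic_axes={"ref_seq": {1: "ref_length"}},  # named dynamic axis
        opset_version=16,
    )

The stray `torch.exp` statement removed at the end of the hunk was dead code: a bare attribute access that was never called or assigned.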
@@ -164,10 +165,10 @@ class T2SModel(nn.Module):
             input_names=["x", "prompts"],
             output_names=["y", "k", "v", "y_emb", "x_example"],
             dynamic_axes={
-                "x": [1],
-                "prompts": [1],
+                "x": {1 : "x_length"},
+                "prompts": {1 : "prompts_length"},
             },
-            verbose=True,
+            verbose=False,
             opset_version=16
         )
         y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
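In this and the next hunk, `verbose=True` is also flipped to `verbose=False`; with `verbose=True`, `torch.onnx.export` prints a description of the exported graph during export, so the change simply quiets the console output.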
@@ -179,13 +180,13 @@ class T2SModel(nn.Module):
             input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
             output_names=["y", "k", "v", "y_emb", "logits", "samples"],
             dynamic_axes={
-                "iy": [1],
-                "ik": [1],
-                "iv": [1],
-                "iy_emb": [1],
-                "ix_example": [1],
+                "iy": {1 : "iy_length"},
+                "ik": {1 : "ik_length"},
+                "iv": {1 : "iv_length"},
+                "iy_emb": {1 : "iy_emb_length"},
+                "ix_example": {1 : "ix_example_length"},
             },
-            verbose=True,
+            verbose=False,
             opset_version=16
         )

@@ -224,9 +225,19 @@ class GptSoVits(nn.Module):
         self.vits = vits
         self.t2s = t2s

-    def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content):
+    def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False):
         pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
-        return self.vits(text_seq, pred_semantic, ref_audio)
+        audio = self.vits(text_seq, pred_semantic, ref_audio)
+        if debug:
+            import onnxruntime
+            sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"])
+            audio1 = sess.run(None, {
+                "text_seq" : text_seq.detach().cpu().numpy(),
+                "pred_semantic" : pred_semantic.detach().cpu().numpy(),
+                "ref_audio" : ref_audio.detach().cpu().numpy()
+            })
+            return audio, audio1
+        return audio

     def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name):
         self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name)
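The new `debug` path runs the previously exported VITS graph with onnxruntime alongside the PyTorch forward pass, so the two audio outputs can be compared. Two caveats: the ONNX path is hard-coded to the `koharu` project, and recent onnxruntime releases expect the full provider name `"CPUExecutionProvider"` rather than `"CPU"`. A hedged sketch of the same parity check (function name, path, and tolerance are illustrative):

    import numpy as np
    import onnxruntime

    def check_parity(torch_audio, text_seq, pred_semantic, ref_audio,
                     onnx_path="onnx/koharu/koharu_vits.onnx"):
        sess = onnxruntime.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
        (onnx_audio,) = sess.run(None, {
            "text_seq": text_seq.detach().cpu().numpy(),
            "pred_semantic": pred_semantic.detach().cpu().numpy(),
            "ref_audio": ref_audio.detach().cpu().numpy(),
        })
        # exact equality is not expected across runtimes; compare element-wise
        diff = np.abs(torch_audio.detach().cpu().numpy() - onnx_audio).max()
        print(f"max abs difference: {diff:.2e}")  # small values (~1e-4) are typical for fp32
        return onnx_audio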
@@ -238,11 +249,12 @@ class GptSoVits(nn.Module):
             input_names=["text_seq", "pred_semantic", "ref_audio"],
             output_names=["audio"],
             dynamic_axes={
-                "text_seq": [1],
-                "pred_semantic": [2],
-                "ref_audio": [1],
+                "text_seq": {1 : "text_length"},
+                "pred_semantic": {2 : "pred_length"},
+                "ref_audio": {1 : "audio_length"},
             },
-            opset_version=17
+            opset_version=17,
+            verbose=False
         )


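Note the VITS graph is exported at opset 17 while the T2S graphs above use opset 16; that is harmless here because each .onnx file is a separate model and is loaded independently at inference time.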
@@ -261,7 +273,7 @@ def export(vits_path, gpt_path, project_name):
     gpt_sovits = GptSoVits(vits, gpt)
     ssl = SSLModel()
     ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
-    text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
+    text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
     ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
     text_bert = torch.randn((text_seq.shape[1], 1024)).float()
     ref_audio = torch.randn((1, 48000 * 5)).float()
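The test sentence in `text_seq` is repeated three times, presumably to exercise the now-named dynamic `text_length` axis with an input noticeably longer than the reference sequence.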
@@ -275,10 +287,18 @@ def export(vits_path, gpt_path, project_name):
         pass

     ssl_content = ssl(ref_audio_16k).float()

+    debug = False
+
+    if debug:
+        a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)
+        soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate)
+        soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate)
+        return
+
     a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()

-    # soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
+    soundfile.write("out.wav", a, vits.hps.data.sampling_rate)

     gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)

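With `debug = True`, the script writes the PyTorch output to out1.wav and the onnxruntime output to out2.wav for an A/B listen, then returns before exporting; `b` is the list returned by `sess.run`, hence `b[0]`. Note the debug path only works once a previous run has already produced `onnx/koharu/koharu_vits.onnx`, since `GptSoVits.forward` loads that file from disk.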
@@ -306,9 +326,9 @@ if __name__ == "__main__":
     except:
         pass

-    gpt_path = "pt_model/koharu-e20.ckpt"
-    vits_path = "pt_model/koharu_e20_s4960.pth"
-    exp_path = "koharu"
+    gpt_path = "GPT_weights/nahida-e25.ckpt"
+    vits_path = "SoVITS_weights/nahida_e30_s3930.pth"
+    exp_path = "nahida"
     export(vits_path, gpt_path, exp_path)

     # soundfile.write("out.wav", a, vits.hps.data.sampling_rate)