Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-04-06 03:57:44 +08:00)
Merge pull request #438 from NaruseMioShirakana/main
Update Onnx Export
This commit is contained in:
commit f02aab9892
@@ -57,7 +57,7 @@ def logits_to_probs(
|
|||||||
logits = logits / max(temperature, 1e-5)
|
logits = logits / max(temperature, 1e-5)
|
||||||
|
|
||||||
if top_k is not None:
|
if top_k is not None:
|
||||||
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
v, _ = torch.topk(logits, top_k)
|
||||||
pivot = v.select(-1, -1).unsqueeze(-1)
|
pivot = v.select(-1, -1).unsqueeze(-1)
|
||||||
logits = torch.where(logits < pivot, inf_tensor_value, logits)
|
logits = torch.where(logits < pivot, inf_tensor_value, logits)
|
||||||
|
|
||||||
|
@@ -188,38 +188,27 @@ class MultiHeadAttention(nn.Module):
|
|||||||
query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
|
query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
|
||||||
key = key.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
|
key = key.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
|
||||||
value = value.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
|
value = value.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
|
||||||
|
|
||||||
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
|
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
|
||||||
|
|
||||||
if self.window_size is not None:
|
if self.window_size is not None:
|
||||||
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
|
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
|
||||||
rel_logits = self._matmul_with_relative_keys(
|
rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
|
||||||
query / math.sqrt(self.k_channels), key_relative_embeddings
|
|
||||||
)
|
|
||||||
scores_local = self._relative_position_to_absolute_position(rel_logits)
|
scores_local = self._relative_position_to_absolute_position(rel_logits)
|
||||||
scores = scores + scores_local
|
scores = scores + scores_local
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
scores = scores.masked_fill(mask == 0, -1e4)
|
scores = scores.masked_fill(mask == 0, -1e4)
|
||||||
if self.block_length is not None:
|
|
||||||
block_mask = (
|
|
||||||
torch.ones_like(scores)
|
|
||||||
.triu(-self.block_length)
|
|
||||||
.tril(self.block_length)
|
|
||||||
)
|
|
||||||
scores = scores.masked_fill(block_mask == 0, -1e4)
|
|
||||||
p_attn = F.softmax(scores, dim=-1)
|
p_attn = F.softmax(scores, dim=-1)
|
||||||
p_attn = self.drop(p_attn)
|
p_attn = self.drop(p_attn)
|
||||||
output = torch.matmul(p_attn, value)
|
output = torch.matmul(p_attn, value)
|
||||||
|
|
||||||
if self.window_size is not None:
|
if self.window_size is not None:
|
||||||
relative_weights = self._absolute_position_to_relative_position(p_attn)
|
relative_weights = self._absolute_position_to_relative_position(p_attn)
|
||||||
value_relative_embeddings = self._get_relative_embeddings(
|
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
|
||||||
self.emb_rel_v, t_s
|
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
|
||||||
)
|
|
||||||
output = output + self._matmul_with_relative_values(
|
output = (output.transpose(2, 3).contiguous().view(b, d, -1))
|
||||||
relative_weights, value_relative_embeddings
|
|
||||||
)
|
|
||||||
output = (
|
|
||||||
output.transpose(2, 3).contiguous().view(b, d, -1)
|
|
||||||
)
|
|
||||||
return output, p_attn
|
return output, p_attn
|
||||||
|
|
||||||
def _matmul_with_relative_values(self, x, y):
|
def _matmul_with_relative_values(self, x, y):
|
||||||
@@ -243,16 +232,16 @@ class MultiHeadAttention(nn.Module):
|
|||||||
def _get_relative_embeddings(self, relative_embeddings, length):
|
def _get_relative_embeddings(self, relative_embeddings, length):
|
||||||
max_relative_position = 2 * self.window_size + 1
|
max_relative_position = 2 * self.window_size + 1
|
||||||
# Pad first before slice to avoid using cond ops.
|
# Pad first before slice to avoid using cond ops.
|
||||||
pad_length = max(length - (self.window_size + 1), 0)
|
pad_l = torch.zeros((1), dtype = torch.int64) + length - (self.window_size + 1)
|
||||||
slice_start_position = max((self.window_size + 1) - length, 0)
|
pad_s = torch.zeros((1), dtype = torch.int64) + (self.window_size + 1) - length
|
||||||
|
pad_length = torch.max(pad_l, other=torch.zeros((1), dtype = torch.int64))
|
||||||
|
slice_start_position = torch.max(pad_s, other=torch.zeros((1), dtype = torch.int64))
|
||||||
|
|
||||||
slice_end_position = slice_start_position + 2 * length - 1
|
slice_end_position = slice_start_position + 2 * length - 1
|
||||||
if pad_length > 0:
|
padded_relative_embeddings = F.pad(
|
||||||
padded_relative_embeddings = F.pad(
|
relative_embeddings,
|
||||||
relative_embeddings,
|
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
|
||||||
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
padded_relative_embeddings = relative_embeddings
|
|
||||||
used_relative_embeddings = padded_relative_embeddings[
|
used_relative_embeddings = padded_relative_embeddings[
|
||||||
:, slice_start_position:slice_end_position
|
:, slice_start_position:slice_end_position
|
||||||
]
|
]
|
||||||
|
@@ -896,9 +896,6 @@ class SynthesizerTrn(nn.Module):
|
|||||||
refer_mask = torch.ones_like(refer[:1,:1,:])
|
refer_mask = torch.ones_like(refer[:1,:1,:])
|
||||||
ge = self.ref_enc(refer * refer_mask, refer_mask)
|
ge = self.ref_enc(refer * refer_mask, refer_mask)
|
||||||
|
|
||||||
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
|
|
||||||
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
|
|
||||||
|
|
||||||
quantized = self.quantizer.decode(codes)
|
quantized = self.quantizer.decode(codes)
|
||||||
if self.semantic_frame_rate == "25hz":
|
if self.semantic_frame_rate == "25hz":
|
||||||
dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0)
|
dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0)
|
||||||
@@ -907,6 +904,7 @@ class SynthesizerTrn(nn.Module):
|
|||||||
x, m_p, logs_p, y_mask = self.enc_p(
|
x, m_p, logs_p, y_mask = self.enc_p(
|
||||||
quantized, text, ge
|
quantized, text, ge
|
||||||
)
|
)
|
||||||
|
|
||||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p)
|
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p)
|
||||||
|
|
||||||
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
||||||
|
@@ -140,6 +140,7 @@ class T2SModel(nn.Module):
|
|||||||
)
|
)
|
||||||
onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
|
onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
|
||||||
return
|
return
|
||||||
|
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
self.onnx_encoder,
|
self.onnx_encoder,
|
||||||
(ref_seq, text_seq, ref_bert, text_bert, ssl_content),
|
(ref_seq, text_seq, ref_bert, text_bert, ssl_content),
|
||||||
@@ -147,16 +148,16 @@ class T2SModel(nn.Module):
|
|||||||
input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
|
input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
|
||||||
output_names=["x", "prompts"],
|
output_names=["x", "prompts"],
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"ref_seq": [1],
|
"ref_seq": {1 : "ref_length"},
|
||||||
"text_seq": [1],
|
"text_seq": {1 : "text_length"},
|
||||||
"ref_bert": [0],
|
"ref_bert": {0 : "ref_length"},
|
||||||
"text_bert": [0],
|
"text_bert": {0 : "text_length"},
|
||||||
"ssl_content": [2],
|
"ssl_content": {2 : "ssl_length"},
|
||||||
},
|
},
|
||||||
opset_version=16
|
opset_version=16
|
||||||
)
|
)
|
||||||
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
||||||
torch.exp
|
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
self.first_stage_decoder,
|
self.first_stage_decoder,
|
||||||
(x, prompts),
|
(x, prompts),
|
||||||
@@ -164,10 +165,10 @@ class T2SModel(nn.Module):
|
|||||||
input_names=["x", "prompts"],
|
input_names=["x", "prompts"],
|
||||||
output_names=["y", "k", "v", "y_emb", "x_example"],
|
output_names=["y", "k", "v", "y_emb", "x_example"],
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"x": [1],
|
"x": {1 : "x_length"},
|
||||||
"prompts": [1],
|
"prompts": {1 : "prompts_length"},
|
||||||
},
|
},
|
||||||
verbose=True,
|
verbose=False,
|
||||||
opset_version=16
|
opset_version=16
|
||||||
)
|
)
|
||||||
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
|
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
|
||||||
@@ -179,13 +180,13 @@ class T2SModel(nn.Module):
|
|||||||
input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
|
input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
|
||||||
output_names=["y", "k", "v", "y_emb", "logits", "samples"],
|
output_names=["y", "k", "v", "y_emb", "logits", "samples"],
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"iy": [1],
|
"iy": {1 : "iy_length"},
|
||||||
"ik": [1],
|
"ik": {1 : "ik_length"},
|
||||||
"iv": [1],
|
"iv": {1 : "iv_length"},
|
||||||
"iy_emb": [1],
|
"iy_emb": {1 : "iy_emb_length"},
|
||||||
"ix_example": [1],
|
"ix_example": {1 : "ix_example_length"},
|
||||||
},
|
},
|
||||||
verbose=True,
|
verbose=False,
|
||||||
opset_version=16
|
opset_version=16
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -224,9 +225,19 @@ class GptSoVits(nn.Module):
|
|||||||
self.vits = vits
|
self.vits = vits
|
||||||
self.t2s = t2s
|
self.t2s = t2s
|
||||||
|
|
||||||
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content):
|
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False):
|
||||||
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
||||||
return self.vits(text_seq, pred_semantic, ref_audio)
|
audio = self.vits(text_seq, pred_semantic, ref_audio)
|
||||||
|
if debug:
|
||||||
|
import onnxruntime
|
||||||
|
sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"])
|
||||||
|
audio1 = sess.run(None, {
|
||||||
|
"text_seq" : text_seq.detach().cpu().numpy(),
|
||||||
|
"pred_semantic" : pred_semantic.detach().cpu().numpy(),
|
||||||
|
"ref_audio" : ref_audio.detach().cpu().numpy()
|
||||||
|
})
|
||||||
|
return audio, audio1
|
||||||
|
return audio
|
||||||
|
|
||||||
def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name):
|
def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name):
|
||||||
self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name)
|
self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name)
|
||||||
@@ -238,11 +249,12 @@ class GptSoVits(nn.Module):
|
|||||||
input_names=["text_seq", "pred_semantic", "ref_audio"],
|
input_names=["text_seq", "pred_semantic", "ref_audio"],
|
||||||
output_names=["audio"],
|
output_names=["audio"],
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"text_seq": [1],
|
"text_seq": {1 : "text_length"},
|
||||||
"pred_semantic": [2],
|
"pred_semantic": {2 : "pred_length"},
|
||||||
"ref_audio": [1],
|
"ref_audio": {1 : "audio_length"},
|
||||||
},
|
},
|
||||||
opset_version=17
|
opset_version=17,
|
||||||
|
verbose=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -261,7 +273,7 @@ def export(vits_path, gpt_path, project_name):
|
|||||||
gpt_sovits = GptSoVits(vits, gpt)
|
gpt_sovits = GptSoVits(vits, gpt)
|
||||||
ssl = SSLModel()
|
ssl = SSLModel()
|
||||||
ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
|
ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
|
||||||
text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
|
text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
|
||||||
ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
|
ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
|
||||||
text_bert = torch.randn((text_seq.shape[1], 1024)).float()
|
text_bert = torch.randn((text_seq.shape[1], 1024)).float()
|
||||||
ref_audio = torch.randn((1, 48000 * 5)).float()
|
ref_audio = torch.randn((1, 48000 * 5)).float()
|
||||||
@@ -275,10 +287,18 @@ def export(vits_path, gpt_path, project_name):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
ssl_content = ssl(ref_audio_16k).float()
|
ssl_content = ssl(ref_audio_16k).float()
|
||||||
|
|
||||||
|
debug = False
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)
|
||||||
|
soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate)
|
||||||
|
soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate)
|
||||||
|
return
|
||||||
|
|
||||||
a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
|
a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
|
||||||
|
|
||||||
# soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
|
soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
|
||||||
|
|
||||||
gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
|
gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
|
||||||
|
|
||||||
@@ -306,9 +326,9 @@ if __name__ == "__main__":
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
gpt_path = "pt_model/koharu-e20.ckpt"
|
gpt_path = "GPT_weights/nahida-e25.ckpt"
|
||||||
vits_path = "pt_model/koharu_e20_s4960.pth"
|
vits_path = "SoVITS_weights/nahida_e30_s3930.pth"
|
||||||
exp_path = "koharu"
|
exp_path = "nahida"
|
||||||
export(vits_path, gpt_path, exp_path)
|
export(vits_path, gpt_path, exp_path)
|
||||||
|
|
||||||
# soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
|
# soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
|
Loading…
x
Reference in New Issue
Block a user