diff --git a/GPT_SoVITS/module/models_onnx.py b/GPT_SoVITS/module/models_onnx.py
index b62b8b71..9ee78d60 100644
--- a/GPT_SoVITS/module/models_onnx.py
+++ b/GPT_SoVITS/module/models_onnx.py
@@ -205,6 +205,8 @@ class TextEncoder(nn.Module):
         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
 
     def forward(self, y, text, ge, speed=1):
+        if isinstance(speed, float):
+            speed = torch.FloatTensor([speed])
         y_mask = torch.ones_like(y[:1, :1, :])
 
         y = self.ssl_proj(y * y_mask) * y_mask
@@ -217,9 +219,8 @@ class TextEncoder(nn.Module):
         y = self.mrte(y, y_mask, text, text_mask, ge)
 
         y = self.encoder2(y * y_mask, y_mask)
-        if speed != 1:
-            y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
-            y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
+        y = F.interpolate(y, size=(y.shape[-1] / speed).to(torch.int) + 1, mode="linear")
+        y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
 
         stats = self.proj(y) * y_mask
         m, logs = torch.split(stats, self.out_channels, dim=1)
diff --git a/GPT_SoVITS/onnx_export_v1v2.py b/GPT_SoVITS/onnx_export_v1v2.py
index d2d444e1..45f4679b 100644
--- a/GPT_SoVITS/onnx_export_v1v2.py
+++ b/GPT_SoVITS/onnx_export_v1v2.py
@@ -233,7 +233,7 @@ class VitsModel(nn.Module):
 
         self.hps = DictToAttrRecursive(self.hps)
         self.hps.model.semantic_frame_rate = "25hz"
-        self.vq_model = SynthesizerTrn(
+        self.vq_model: SynthesizerTrn = SynthesizerTrn(
            self.hps.data.filter_length // 2 + 1,
            self.hps.train.segment_size // self.hps.data.hop_length,
            n_speakers=self.hps.data.n_speakers,
@@ -243,32 +243,27 @@ class VitsModel(nn.Module):
         self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
         # print(f"filter_length:{self.hps.data.filter_length} sampling_rate:{self.hps.data.sampling_rate} hop_length:{self.hps.data.hop_length} win_length:{self.hps.data.win_length}") #v2 filter_length: 2048 sampling_rate: 32000 hop_length: 640 win_length: 2048
 
-    def forward(self, text_seq, pred_semantic, spectrum, sv_emb):
+    def forward(self, text_seq, pred_semantic, spectrum, sv_emb, speed):
         if self.is_v2p:
-            return self.vq_model(pred_semantic, text_seq, spectrum, sv_emb=sv_emb)[0, 0]
+            return self.vq_model(pred_semantic, text_seq, spectrum, sv_emb=sv_emb, speed=speed)[0, 0]
         else:
-            return self.vq_model(pred_semantic, text_seq, spectrum)[0, 0]
+            return self.vq_model(pred_semantic, text_seq, spectrum, speed=speed)[0, 0]
 
 
-class GptSoVits(nn.Module):
+class GptSoVits:
     def __init__(self, vits, t2s):
         super().__init__()
         self.vits = vits
         self.t2s = t2s
 
-    def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
-        pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
-        audio = self.vits(text_seq, pred_semantic, spectrum, sv_emb)
-        return audio
-
-    def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb, project_name, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
+    def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb, speed, project_name, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
         self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
         pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
         torch.onnx.export(
             self.vits,
-            (text_seq, pred_semantic, spectrum, sv_emb),
+            (text_seq, pred_semantic, spectrum, sv_emb, speed),
             f"onnx/{project_name}/{project_name}_vits.onnx",
-            input_names=["input_text_phones", "pred_semantic", "spectrum", "sv_emb"],
+            input_names=["input_text_phones", "pred_semantic", "spectrum", "sv_emb", "speed"],
             output_names=["audio"],
             dynamic_axes={
                 "input_text_phones": {1: "text_length"},
@@ -379,13 +374,12 @@ def export(vits_path, gpt_path, project_name, voice_model_version, export_audio_
     top_p = torch.FloatTensor([1.0])
     repetition_penalty = torch.FloatTensor([1.0])
     temperature = torch.FloatTensor([1.0])
+    speed = torch.FloatTensor([1.0])
 
     os.makedirs(f"onnx/{project_name}", exist_ok=True)
 
     [ssl_content, spectrum, sv_emb] = preprocessor(ref_audio32k)
-    gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
-    # exit()
-    gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), project_name, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
+    gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), speed, project_name, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
 
     if export_audio_preprocessor:
         torch.onnx.export(preprocessor, (ref_audio32k,), f"onnx/{project_name}/{project_name}_audio_preprocess.onnx",
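
Dropping the `if speed != 1:` guard is what makes the new input usable at runtime: `torch.onnx.export` traces the model, so a Python-level comparison would be evaluated once and the chosen branch frozen into the graph, whereas deriving the target size from the `speed` tensor keeps it a genuine graph input. A minimal sketch of driving the exported VITS graph with onnxruntime follows. The input and output names match the `input_names`/`output_names` in the diff, but every shape and dtype below is an assumption for illustration (the real tensors come from the exported T2S and audio-preprocessor graphs), and `my_project` is a hypothetical project name.

    import numpy as np
    import onnxruntime as ort

    # Hypothetical project path from the f"onnx/{project_name}/..." pattern above.
    sess = ort.InferenceSession("onnx/my_project/my_project_vits.onnx")

    # Dummy tensors -- shapes and dtypes are assumptions for illustration only;
    # in practice these come from the upstream T2S and preprocessor graphs.
    text_seq = np.zeros((1, 20), dtype=np.int64)           # "input_text_phones": phone ids
    pred_semantic = np.zeros((1, 1, 100), dtype=np.int64)  # semantic tokens from the T2S stage
    spectrum = np.zeros((1, 1025, 200), dtype=np.float32)  # reference spectrogram
    sv_emb = np.zeros((1, 20480), dtype=np.float32)        # speaker-verification embedding

    # The new input: a 1-element float32 tensor, mirroring the
    # torch.FloatTensor([1.0]) example input used at export time.
    speed = np.array([1.2], dtype=np.float32)

    (audio,) = sess.run(
        ["audio"],
        {
            "input_text_phones": text_seq,
            "pred_semantic": pred_semantic,
            "spectrum": spectrum,
            "sv_emb": sv_emb,
            "speed": speed,
        },
    )

Because the encoder output is resized to `y.shape[-1] / speed + 1` frames, a `speed` above 1.0 shortens the sequence (faster speech) and a value below 1.0 lengthens it.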