From 7dec5f5bb0395a7ae752e2a15a246a62c03bbc8d Mon Sep 17 00:00:00 2001
From: zzz <458761603@qq.com>
Date: Fri, 13 Jun 2025 22:10:11 +0800
Subject: [PATCH] Merge pull request #2460 from L-jasmine/export_v2pro
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Optimize the torch_script model export
---
 GPT_SoVITS/export_torch_script.py      | 31 +++++++++++++++-----------
 GPT_SoVITS/export_torch_script_v3v4.py | 10 ++++++---
 2 files changed, 25 insertions(+), 16 deletions(-)

diff --git a/GPT_SoVITS/export_torch_script.py b/GPT_SoVITS/export_torch_script.py
index 6a13c2d..bf32ed6 100644
--- a/GPT_SoVITS/export_torch_script.py
+++ b/GPT_SoVITS/export_torch_script.py
@@ -103,7 +103,7 @@ def logits_to_probs(
 @torch.jit.script
 def multinomial_sample_one_no_sync(probs_sort):
     # Does multinomial sampling without a cuda synchronization
-    q = torch.randn_like(probs_sort)
+    q = torch.empty_like(probs_sort).exponential_(1.0)
     return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
 
 
@@ -114,7 +114,7 @@ def sample(
     temperature: float = 1.0,
     top_k: Optional[int] = None,
     top_p: Optional[int] = None,
-    repetition_penalty: float = 1.0,
+    repetition_penalty: float = 1.35,
 ):
     probs = logits_to_probs(
         logits=logits,
@@ -129,8 +129,8 @@ def sample(
 
 
 @torch.jit.script
-def spectrogram_torch(y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False):
-    hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype)
+def spectrogram_torch(hann_window:Tensor, y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False):
+    # hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype)
     y = torch.nn.functional.pad(
         y.unsqueeze(1),
         (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
@@ -309,8 +309,9 @@ class T2SBlock:
 
         attn = F.scaled_dot_product_attention(q, k, v)
 
-        attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
-        attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
+        # attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
+        # attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
+        attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
         attn = F.linear(attn, self.out_w, self.out_b)
 
         x = x + attn
@@ -348,7 +349,7 @@ class T2STransformer:
 
 
 class VitsModel(nn.Module):
-    def __init__(self, vits_path, version=None):
+    def __init__(self, vits_path, version=None, is_half=True, device="cpu"):
         super().__init__()
         # dict_s2 = torch.load(vits_path,map_location="cpu")
         dict_s2 = load_sovits_new(vits_path)
@@ -373,11 +374,18 @@ class VitsModel(nn.Module):
             n_speakers=self.hps.data.n_speakers,
             **self.hps.model,
         )
-        self.vq_model.eval()
         self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
+        self.vq_model.dec.remove_weight_norm()
+        if is_half:
+            self.vq_model = self.vq_model.half()
+        self.vq_model = self.vq_model.to(device)
+        self.vq_model.eval()
+        self.hann_window = torch.hann_window(self.hps.data.win_length, device=device, dtype= torch.float16 if is_half else torch.float32)
+
 
     def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0, sv_emb=None):
         refer = spectrogram_torch(
+            self.hann_window,
             ref_audio,
             self.hps.data.filter_length,
             self.hps.data.sampling_rate,
@@ -667,7 +675,7 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
     ssl_content = ssl(ref_audio).to(device)
 
     # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
-    vits = VitsModel(vits_path).to(device)
+    vits = VitsModel(vits_path,device=device,is_half=False)
     vits.eval()
 
     # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
@@ -765,10 +773,7 @@ def export_prov2(
     sv_model = ExportERes2NetV2(sv_cn_model)
 
     # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
-    vits = VitsModel(vits_path, version)
-    if is_half:
-        vits.vq_model = vits.vq_model.half()
-    vits.to(device)
+    vits = VitsModel(vits_path, version,is_half=is_half,device=device)
     vits.eval()
 
     # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
diff --git a/GPT_SoVITS/export_torch_script_v3v4.py b/GPT_SoVITS/export_torch_script_v3v4.py
index 55d2728..1fd63b3 100644
--- a/GPT_SoVITS/export_torch_script_v3v4.py
+++ b/GPT_SoVITS/export_torch_script_v3v4.py
@@ -243,6 +243,7 @@ class ExportGPTSovitsHalf(torch.nn.Module):
         self.sampling_rate: int = hps.data.sampling_rate
         self.hop_length: int = hps.data.hop_length
         self.win_length: int = hps.data.win_length
+        self.hann_window = torch.hann_window(self.win_length, device=device, dtype=torch.float32)
 
     def forward(
         self,
@@ -255,6 +256,7 @@ def forward(
         top_k,
     ):
         refer = spectrogram_torch(
+            self.hann_window,
             ref_audio_32k,
             self.filter_length,
             self.sampling_rate,
@@ -321,6 +323,7 @@ class ExportGPTSovitsV4Half(torch.nn.Module):
         self.sampling_rate: int = hps.data.sampling_rate
         self.hop_length: int = hps.data.hop_length
         self.win_length: int = hps.data.win_length
+        self.hann_window = torch.hann_window(self.win_length, device=device, dtype=torch.float32)
 
     def forward(
         self,
@@ -333,6 +336,7 @@ def forward(
         top_k,
     ):
         refer = spectrogram_torch(
+            self.hann_window,
             ref_audio_32k,
             self.filter_length,
             self.sampling_rate,
@@ -1149,7 +1153,7 @@ def export_2(version="v3"):
     raw_t2s = raw_t2s.half().to(device)
     t2s_m = T2SModel(raw_t2s).half().to(device)
     t2s_m.eval()
-    t2s_m = torch.jit.script(t2s_m)
+    t2s_m = torch.jit.script(t2s_m).to(device)
     t2s_m.eval()
     # t2s_m.top_k = 15
     logger.info("t2s_m ok")
@@ -1251,6 +1255,6 @@ def test_export_gpt_sovits_v3():
 
 with torch.no_grad():
-    export_1("onnx/ad/ref.wav","你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。","v4")
-    # export_2("v4")
+    # export_1("onnx/ad/ref.wav","你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。","v4")
+    export_2("v4")
 
 # test_export_gpt_sovits_v3()
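
Note on the sampling fix (reviewer sketch, not part of the patch): drawing q
from Exponential(1) and taking argmax(probs / q) is the exponential-race form
of the Gumbel-max trick, so it returns index i with probability probs[i]; the
old torch.randn_like noise carried no such guarantee. A minimal sanity check,
assuming only stock PyTorch:

    import torch

    probs = torch.tensor([0.1, 0.2, 0.7])
    counts = torch.zeros_like(probs)
    for _ in range(100_000):
        # Same noise the patch uses: i.i.d. Exponential(1) draws.
        q = torch.empty_like(probs).exponential_(1.0)
        counts[torch.argmax(probs / q)] += 1
    print(counts / counts.sum())  # ~= tensor([0.1, 0.2, 0.7])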
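Note on the hann_window change (reviewer sketch, not part of the patch): the
window is now built once in __init__ at a fixed device/dtype and passed into
spectrogram_torch, so the scripted graph no longer allocates a window on every
call. A minimal sketch of the same pattern; the names and sizes here (spec,
1024, 256) are illustrative, not from the patch:

    import torch
    from torch import Tensor

    @torch.jit.script
    def spec(hann_window: Tensor, y: Tensor, n_fft: int, hop_size: int, win_size: int) -> Tensor:
        # The caller owns the window; the scripted function only consumes it.
        return torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window,
            center=False,
            return_complex=True,
        ).abs()

    hann = torch.hann_window(1024)  # built once, reused on every call
    mag = spec(hann, torch.randn(1, 16000), 1024, 256, 1024)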