perf(export_torch_script): 缓存 Vits 中用到的 hann_window

This commit is contained in:
csh 2025-06-13 17:33:20 +08:00
parent 254b9b0b55
commit 6feafac1df
2 changed files with 20 additions and 12 deletions

View File

@ -129,8 +129,8 @@ def sample(
@torch.jit.script @torch.jit.script
def spectrogram_torch(y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False): def spectrogram_torch(hann_window:Tensor, y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False):
hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype) # hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype)
y = torch.nn.functional.pad( y = torch.nn.functional.pad(
y.unsqueeze(1), y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
@ -349,7 +349,7 @@ class T2STransformer:
class VitsModel(nn.Module): class VitsModel(nn.Module):
def __init__(self, vits_path, version=None): def __init__(self, vits_path, version=None, is_half=True, device="cpu"):
super().__init__() super().__init__()
# dict_s2 = torch.load(vits_path,map_location="cpu") # dict_s2 = torch.load(vits_path,map_location="cpu")
dict_s2 = load_sovits_new(vits_path) dict_s2 = load_sovits_new(vits_path)
@ -374,11 +374,18 @@ class VitsModel(nn.Module):
n_speakers=self.hps.data.n_speakers, n_speakers=self.hps.data.n_speakers,
**self.hps.model, **self.hps.model,
) )
self.vq_model.eval()
self.vq_model.load_state_dict(dict_s2["weight"], strict=False) self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
self.vq_model.dec.remove_weight_norm()
if is_half:
self.vq_model = self.vq_model.half()
self.vq_model = self.vq_model.to(device)
self.vq_model.eval()
self.hann_window = torch.hann_window(self.hps.data.win_length, device=device, dtype= torch.float16 if is_half else torch.float32)
def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0, sv_emb=None): def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0, sv_emb=None):
refer = spectrogram_torch( refer = spectrogram_torch(
self.hann_window,
ref_audio, ref_audio,
self.hps.data.filter_length, self.hps.data.filter_length,
self.hps.data.sampling_rate, self.hps.data.sampling_rate,
@ -668,7 +675,7 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
ssl_content = ssl(ref_audio).to(device) ssl_content = ssl(ref_audio).to(device)
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth" # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
vits = VitsModel(vits_path).to(device) vits = VitsModel(vits_path,device=device,is_half=False)
vits.eval() vits.eval()
# gpt_path = "GPT_weights_v2/xw-e15.ckpt" # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
@ -766,10 +773,7 @@ def export_prov2(
sv_model = ExportERes2NetV2(sv_cn_model) sv_model = ExportERes2NetV2(sv_cn_model)
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth" # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
vits = VitsModel(vits_path, version) vits = VitsModel(vits_path, version,is_half=is_half,device=device)
if is_half:
vits.vq_model = vits.vq_model.half()
vits.to(device)
vits.eval() vits.eval()
# gpt_path = "GPT_weights_v2/xw-e15.ckpt" # gpt_path = "GPT_weights_v2/xw-e15.ckpt"

View File

@ -243,6 +243,7 @@ class ExportGPTSovitsHalf(torch.nn.Module):
self.sampling_rate: int = hps.data.sampling_rate self.sampling_rate: int = hps.data.sampling_rate
self.hop_length: int = hps.data.hop_length self.hop_length: int = hps.data.hop_length
self.win_length: int = hps.data.win_length self.win_length: int = hps.data.win_length
self.hann_window = torch.hann_window(self.win_length, device=device, dtype=torch.float32)
def forward( def forward(
self, self,
@ -255,6 +256,7 @@ class ExportGPTSovitsHalf(torch.nn.Module):
top_k, top_k,
): ):
refer = spectrogram_torch( refer = spectrogram_torch(
self.hann_window,
ref_audio_32k, ref_audio_32k,
self.filter_length, self.filter_length,
self.sampling_rate, self.sampling_rate,
@ -321,6 +323,7 @@ class ExportGPTSovitsV4Half(torch.nn.Module):
self.sampling_rate: int = hps.data.sampling_rate self.sampling_rate: int = hps.data.sampling_rate
self.hop_length: int = hps.data.hop_length self.hop_length: int = hps.data.hop_length
self.win_length: int = hps.data.win_length self.win_length: int = hps.data.win_length
self.hann_window = torch.hann_window(self.win_length, device=device, dtype=torch.float32)
def forward( def forward(
self, self,
@ -333,6 +336,7 @@ class ExportGPTSovitsV4Half(torch.nn.Module):
top_k, top_k,
): ):
refer = spectrogram_torch( refer = spectrogram_torch(
self.hann_window,
ref_audio_32k, ref_audio_32k,
self.filter_length, self.filter_length,
self.sampling_rate, self.sampling_rate,
@ -1149,7 +1153,7 @@ def export_2(version="v3"):
raw_t2s = raw_t2s.half().to(device) raw_t2s = raw_t2s.half().to(device)
t2s_m = T2SModel(raw_t2s).half().to(device) t2s_m = T2SModel(raw_t2s).half().to(device)
t2s_m.eval() t2s_m.eval()
t2s_m = torch.jit.script(t2s_m) t2s_m = torch.jit.script(t2s_m).to(device)
t2s_m.eval() t2s_m.eval()
# t2s_m.top_k = 15 # t2s_m.top_k = 15
logger.info("t2s_m ok") logger.info("t2s_m ok")
@ -1251,6 +1255,6 @@ def test_export_gpt_sovits_v3():
with torch.no_grad(): with torch.no_grad():
export_1("onnx/ad/ref.wav","你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。","v4") # export_1("onnx/ad/ref.wav","你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。","v4")
# export_2("v4") export_2("v4")
# test_export_gpt_sovits_v3() # test_export_gpt_sovits_v3()