From 150cd842df8909af840ca1d812fb48dc2259eb60 Mon Sep 17 00:00:00 2001
From: csh <458761603@qq.com>
Date: Thu, 29 May 2025 02:41:21 +0800
Subject: [PATCH] feat: add v4 export script
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 GPT_SoVITS/export_torch_script_v3.py | 318 ++++++++++++++++++++++-----
 GPT_SoVITS/module/models_onnx.py     |   3 +-
 2 files changed, 270 insertions(+), 51 deletions(-)

diff --git a/GPT_SoVITS/export_torch_script_v3.py b/GPT_SoVITS/export_torch_script_v3.py
index b34495a7..dd8464be 100644
--- a/GPT_SoVITS/export_torch_script_v3.py
+++ b/GPT_SoVITS/export_torch_script_v3.py
@@ -10,7 +10,7 @@ from inference_webui import get_phones_and_bert
 import librosa
 from module import commons
 from module.mel_processing import mel_spectrogram_torch
-from module.models_onnx import CFM, SynthesizerTrnV3
+from module.models_onnx import CFM, Generator, SynthesizerTrnV3
 import numpy as np
 import torch._dynamo.config
 import torchaudio
@@ -46,7 +46,7 @@ class MelSpectrgram(torch.nn.Module):
         center=False,
     ):
         super().__init__()
-        self.hann_window = torch.hann_window(1024).to(device=device, dtype=dtype)
+        self.hann_window = torch.hann_window(win_size).to(device=device, dtype=dtype)
         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
         self.mel_basis = torch.from_numpy(mel).to(dtype=dtype, device=device)
         self.n_fft: int = n_fft
@@ -189,6 +189,19 @@ mel_fn = lambda x: mel_spectrogram_torch(
         "center": False,
     },
 )
+mel_fn_v4 = lambda x: mel_spectrogram_torch(
+    x,
+    **{
+        "n_fft": 1280,
+        "win_size": 1280,
+        "hop_size": 320,
+        "num_mels": 100,
+        "sampling_rate": 32000,
+        "fmin": 0,
+        "fmax": None,
+        "center": False,
+    },
+)
 
 spec_min = -12
 spec_max = 2
@@ -285,6 +298,84 @@ class ExportGPTSovitsHalf(torch.nn.Module):
         return fea_ref, fea_todo, mel2
 
 
+class ExportGPTSovitsV4Half(torch.nn.Module):
+    def __init__(self, hps, t2s_m: T2SModel, vq_model: SynthesizerTrnV3):
+        super().__init__()
+        self.hps = hps
+        self.t2s_m = t2s_m
+        self.vq_model = vq_model
+        self.mel2 = MelSpectrgram(
+            dtype=torch.float32,
+            device=device,
+            n_fft=1280,
+            num_mels=100,
+            sampling_rate=32000,
+            hop_size=320,
+            win_size=1280,
+            fmin=0,
+            fmax=None,
+            center=False,
+        )
+        # self.dtype = dtype
+        self.filter_length: int = hps.data.filter_length
+        self.sampling_rate: int = hps.data.sampling_rate
+        self.hop_length: int = hps.data.hop_length
+        self.win_length: int = hps.data.win_length
+
+    def forward(
+        self,
+        ssl_content,
+        ref_audio_32k: torch.FloatTensor,
+        phoneme_ids0,
+        phoneme_ids1,
+        bert1,
+        bert2,
+        top_k,
+    ):
+        refer = spectrogram_torch(
+            ref_audio_32k,
+            self.filter_length,
+            self.sampling_rate,
+            self.hop_length,
+            self.win_length,
+            center=False,
+        ).to(ssl_content.dtype)
+
+        codes = self.vq_model.extract_latent(ssl_content)
+        prompt_semantic = codes[0, 0]
+        prompt = prompt_semantic.unsqueeze(0)
+        # print('extract_latent',codes.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
+
+        pred_semantic = self.t2s_m(prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
+        # print('t2s_m',pred_semantic.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
+
+        ge = self.vq_model.create_ge(refer)
+        # print('create_ge',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
+
+        prompt_ = prompt.unsqueeze(0)
+        fea_ref = self.vq_model(prompt_, phoneme_ids0, ge)
+        # print('fea_ref',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
+        # print(prompt_.shape, phoneme_ids0.shape, ge.shape)
+        # print(fea_ref.shape)
+
+        ref_32k = ref_audio_32k
+        mel2 = norm_spec(self.mel2(ref_32k)).to(ssl_content.dtype)
+        T_min = min(mel2.shape[2], fea_ref.shape[2])
+        mel2 = mel2[:, :, :T_min]
+        fea_ref = fea_ref[:, :, :T_min]
+        if T_min > 500:
+            mel2 = mel2[:, :, -500:]
+            fea_ref = fea_ref[:, :, -500:]
+            T_min = 500
+
+        fea_todo = self.vq_model(pred_semantic, phoneme_ids1, ge)
+        # print('fea_todo',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
+        # print(pred_semantic.shape, phoneme_ids1.shape, ge.shape)
+        # print(fea_todo.shape)
+
+        return fea_ref, fea_todo, mel2
+
+
 class GPTSoVITSV3(torch.nn.Module):
     def __init__(self, gpt_sovits_half, cfm, bigvgan):
         super().__init__()
@@ -311,6 +402,7 @@ class GPTSoVITSV3(torch.nn.Module):
         chunk_len = 934 - fea_ref.shape[2]
         wav_gen_list = []
         idx = 0
+        fea_todo = fea_todo[:, :, :-5]
         wav_gen_length = fea_todo.shape[2] * 256
         while 1:
             # current_time = datetime.now()
@@ -342,6 +434,65 @@ class GPTSoVITSV3(torch.nn.Module):
         wav_gen = torch.cat(wav_gen_list, 2)
         return wav_gen[0][0][:wav_gen_length]
 
+
+class GPTSoVITSV4(torch.nn.Module):
+    def __init__(self, gpt_sovits_half, cfm, hifigan):
+        super().__init__()
+        self.gpt_sovits_half = gpt_sovits_half
+        self.cfm = cfm
+        self.hifigan = hifigan
+
+    def forward(
+        self,
+        ssl_content,
+        ref_audio_32k: torch.FloatTensor,
+        phoneme_ids0: torch.LongTensor,
+        phoneme_ids1: torch.LongTensor,
+        bert1,
+        bert2,
+        top_k: torch.LongTensor,
+        sample_steps: torch.LongTensor,
+    ):
+        # current_time = datetime.now()
+        # print("gpt_sovits_half",current_time.strftime("%Y-%m-%d %H:%M:%S"))
+        fea_ref, fea_todo, mel2 = self.gpt_sovits_half(
+            ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
+        )
+        chunk_len = 1000 - fea_ref.shape[2]
+        wav_gen_list = []
+        idx = 0
+        fea_todo = fea_todo[:, :, :-10]
+        wav_gen_length = fea_todo.shape[2] * 480
+        while 1:
+            # current_time = datetime.now()
+            # print("idx:",idx,current_time.strftime("%Y-%m-%d %H:%M:%S"))
+            fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
+            if fea_todo_chunk.shape[-1] == 0:
+                break
+
+            # The exported model stalls for ~10 s whenever the input shape changes
+            # (it appears to recompile), so zero-pad the last chunk to keep the
+            # shape constant. The padding makes the generated audio too long, so
+            # trim it at the end: after the vocoder the audio length is
+            # fea_todo.shape[2] * 480.
+            complete_len = chunk_len - fea_todo_chunk.shape[-1]
+            if complete_len != 0:
+                fea_todo_chunk = torch.cat(
+                    [
+                        fea_todo_chunk,
+                        torch.zeros(1, 512, complete_len).to(fea_todo_chunk.device).to(fea_todo_chunk.dtype),
+                    ],
+                    2,
+                )
+
+            cfm_res, fea_ref, mel2 = self.cfm(fea_ref, fea_todo_chunk, mel2, sample_steps)
+            idx += chunk_len
+
+            cfm_res = denorm_spec(cfm_res)
+            hifigan_res = self.hifigan(cfm_res)
+            wav_gen_list.append(hifigan_res)
+
+        wav_gen = torch.cat(wav_gen_list, 2)
+        return wav_gen[0][0][:wav_gen_length]
+
 
 def init_bigvgan():
@@ -361,6 +512,31 @@ def init_bigvgan():
     bigvgan_model = bigvgan_model.to(device)
 
 
+def init_hifigan():
+    global hifigan_model, bigvgan_model
+    hifigan_model = Generator(
+        initial_channel=100,
+        resblock="1",
+        resblock_kernel_sizes=[3, 7, 11],
+        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+        upsample_rates=[10, 6, 2, 2, 2],
+        upsample_initial_channel=512,
+        upsample_kernel_sizes=[20, 12, 4, 4, 4],
+        gin_channels=0,
+        is_bias=True,
+    )
+    hifigan_model.eval()
+    hifigan_model.remove_weight_norm()
+    state_dict_g = torch.load(
+        "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu"
+    )
+    print("loading vocoder", hifigan_model.load_state_dict(state_dict_g))
+    if is_half == True:
+        hifigan_model = hifigan_model.half().to(device)
+    else:
+        hifigan_model = hifigan_model.to(device)
+
+
 class Sovits:
     def __init__(self, vq_model: SynthesizerTrnV3, cfm: CFM, hps):
         self.vq_model = vq_model
@@ -399,6 +575,7 @@ class DictToAttrRecursive(dict):
 
 from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
 
+v3v4set = {"v3", "v4"}
 
 def get_sovits_weights(sovits_path):
     path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
@@ -419,8 +596,8 @@ def get_sovits_weights(sovits_path):
     else:
         hps.model.version = "v2"
 
-    if model_version == "v3":
-        hps.model.version = "v3"
+    if model_version in v3v4set:
+        hps.model.version = model_version
 
     logger.info(f"hps: {hps}")
@@ -522,10 +699,14 @@ def export_cfm(
     return export_cfm
 
 
-def export():
-    sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
-
-    init_bigvgan()
+def export(version="v3"):
+    if version == "v3":
+        sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
+        init_bigvgan()
+    else:
+        sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth")
+        init_hifigan()
+
     dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt")
     raw_t2s = get_raw_t2s_model(dict_s1).to(device)
@@ -542,7 +723,7 @@ def export():
     hps = sovits.hps
     ref_wav_path = "onnx/ad/ref.wav"
     speed = 1.0
-    sample_steps = 32
+    sample_steps = 8
     dtype = torch.float16 if is_half == True else torch.float32
     refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
     zero_wav = np.zeros(
@@ -634,25 +815,33 @@ def export():
 
     # vq_model = sovits.vq_model
     vq_model = trace_vq_model
 
-    gpt_sovits_half = ExportGPTSovitsHalf(sovits.hps, script_t2s, trace_vq_model)
-    torch.jit.script(gpt_sovits_half).save("onnx/ad/gpt_sovits_v3_half.pt")
+    if version == "v3":
+        gpt_sovits_half = ExportGPTSovitsHalf(sovits.hps, script_t2s, trace_vq_model)
+        torch.jit.script(gpt_sovits_half).save("onnx/ad/gpt_sovits_v3_half.pt")
+    else:
+        gpt_sovits_half = ExportGPTSovitsV4Half(sovits.hps, script_t2s, trace_vq_model)
+        torch.jit.script(gpt_sovits_half).save("onnx/ad/gpt_sovits_v4_half.pt")
 
     ref_audio, sr = torchaudio.load(ref_wav_path)
     ref_audio = ref_audio.to(device).float()
     if ref_audio.shape[0] == 2:
         ref_audio = ref_audio.mean(0).unsqueeze(0)
-    if sr != 24000:
-        ref_audio = resample(ref_audio, sr)
+    tgt_sr = 24000 if version == "v3" else 32000
+    if sr != tgt_sr:
+        ref_audio = resample(ref_audio, sr, tgt_sr)
     # mel2 = mel_fn(ref_audio)
-    mel2 = norm_spec(mel_fn(ref_audio))
+    mel2 = mel_fn(ref_audio) if version == "v3" else mel_fn_v4(ref_audio)
+    mel2 = norm_spec(mel2)
     T_min = min(mel2.shape[2], fea_ref.shape[2])
     fea_ref = fea_ref[:, :, :T_min]
     print("fea_ref:", fea_ref.shape, T_min)
-    if T_min > 468:
-        mel2 = mel2[:, :, -468:]
-        fea_ref = fea_ref[:, :, -468:]
-        T_min = 468
-    chunk_len = 934 - T_min
+    Tref = 468 if version == "v3" else 500
+    Tchunk = 934 if version == "v3" else 1000
+    if T_min > Tref:
+        mel2 = mel2[:, :, -Tref:]
+        fea_ref = fea_ref[:, :, -Tref:]
+        T_min = Tref
+    chunk_len = Tchunk - T_min
     mel2 = mel2.to(dtype)
 
     # fea_todo, ge = sovits.vq_model(pred_semantic,y_lengths, phoneme_ids1, ge)
@@ -714,13 +903,19 @@ def export():
     with torch.inference_mode():
         cmf_res_rand = torch.randn(1, 100, 934).to(device).to(dtype)
         torch._dynamo.mark_dynamic(cmf_res_rand, 2)
-        bigvgan_model_ = torch.jit.trace(bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,))
-        bigvgan_model_.save("onnx/ad/bigvgan_model.pt")
-        wav_gen = bigvgan_model(cmf_res)
+        if version == "v3":
+            bigvgan_model_ = torch.jit.trace(bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,))
+            bigvgan_model_.save("onnx/ad/bigvgan_model.pt")
+            wav_gen = bigvgan_model(cmf_res)
+        else:
+            hifigan_model_ = torch.jit.trace(hifigan_model, optimize=True, example_inputs=(cmf_res_rand,))
+            hifigan_model_.save("onnx/ad/hifigan_model.pt")
+            wav_gen = hifigan_model(cmf_res)
+
     print("wav_gen:", wav_gen.shape, wav_gen.dtype)
 
     audio = wav_gen[0][0].cpu().detach().numpy()
-    sr = 24000
+    sr = 24000 if version == "v3" else 48000
     soundfile.write("out.export.wav", (audio * 32768).astype(np.int16), sr)
@@ -848,8 +1043,9 @@ def test_export(
 def test_export1(
     todo_text,
-    gpt_sovits_v3,
+    gpt_sovits_v3v4,
     output,
+    out_sr=24000,
 ):
     # hps = sovits.hps
     ref_wav_path = "onnx/ad/ref.wav"
     speed = 1.0
@@ -859,7 +1055,7 @@ def test_export1(
 
     dtype = torch.float16 if is_half == True else torch.float32
 
     zero_wav = np.zeros(
-        int(24000 * 0.3),
+        int(out_sr * 0.3),
         dtype=np.float16 if is_half == True else np.float32,
     )
@@ -907,22 +1103,26 @@ def test_export1(
         bert2.shape,
         top_k.shape,
     )
-    wav_gen = gpt_sovits_v3(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps)
+    wav_gen = gpt_sovits_v3v4(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps)
     print("wav_gen:", wav_gen.shape, wav_gen.dtype)
 
     wav_gen = torch.cat([wav_gen, zero_wav_torch], 0)
 
     audio = wav_gen.cpu().detach().numpy()
     logger.info("end bigvgan %s", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
-    sr = 24000
-    soundfile.write(output, (audio * 32768).astype(np.int16), sr)
+    soundfile.write(output, (audio * 32768).astype(np.int16), out_sr)
 
 
 import time
 
 
-def test_():
-    sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
+def test_(version="v3"):
+    if version == "v3":
+        sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
+        # init_bigvgan()
+    else:
+        sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth")
+        # init_hifigan()
 
     # cfm = ExportCFM(sovits.cfm)
     # cfm.cfm.estimator = dit
@@ -963,25 +1163,41 @@ def test_():
     # gpt_sovits_v3_half = gpt_sovits_v3_half.half()
     # gpt_sovits_v3_half = gpt_sovits_v3_half.cuda()
     # gpt_sovits_v3_half.eval()
-    gpt_sovits_v3_half = ExportGPTSovitsHalf(sovits.hps, t2s_m, vq_model)
-    logger.info("gpt_sovits_v3_half ok")
+    if version == "v3":
+        gpt_sovits_v3_half = ExportGPTSovitsHalf(sovits.hps, t2s_m, vq_model)
+        logger.info("gpt_sovits_v3_half ok")
+        # init_bigvgan()
+        # global bigvgan_model
+        bigvgan_model = torch.jit.load("onnx/ad/bigvgan_model.pt")
+        # bigvgan_model = torch.jit.optimize_for_inference(bigvgan_model)
+        bigvgan_model = bigvgan_model.half()
+        bigvgan_model = bigvgan_model.cuda()
+        bigvgan_model.eval()
 
-    # init_bigvgan()
-    # global bigvgan_model
-    bigvgan_model = torch.jit.load("onnx/ad/bigvgan_model.pt")
-    # bigvgan_model = torch.jit.optimize_for_inference(bigvgan_model)
-    bigvgan_model = bigvgan_model.half()
-    bigvgan_model = bigvgan_model.cuda()
-    bigvgan_model.eval()
+        logger.info("bigvgan ok")
+        gpt_sovits_v3 = GPTSoVITSV3(gpt_sovits_v3_half, cfm, bigvgan_model)
+        gpt_sovits_v3 = torch.jit.script(gpt_sovits_v3)
+        gpt_sovits_v3.save("onnx/ad/gpt_sovits_v3.pt")
+        gpt_sovits_v3 = gpt_sovits_v3.half().to(device)
+        gpt_sovits_v3.eval()
+        print("save gpt_sovits_v3 ok")
+    else:
+        gpt_sovits_v4_half = ExportGPTSovitsV4Half(sovits.hps, t2s_m, vq_model)
+        logger.info("gpt_sovits_v4 ok")
 
-    logger.info("bigvgan ok")
+        hifigan_model = torch.jit.load("onnx/ad/hifigan_model.pt")
+        hifigan_model = hifigan_model.half()
+        hifigan_model = hifigan_model.cuda()
+        hifigan_model.eval()
+        logger.info("hifigan ok")
+        gpt_sovits_v4 = GPTSoVITSV4(gpt_sovits_v4_half, cfm, hifigan_model)
+        gpt_sovits_v4 = torch.jit.script(gpt_sovits_v4)
+        gpt_sovits_v4.save("onnx/ad/gpt_sovits_v4.pt")
+        print("save gpt_sovits_v4 ok")
+
+    gpt_sovits_v3v4 = gpt_sovits_v3 if version == "v3" else gpt_sovits_v4
+    sr = 24000 if version == "v3" else 48000
 
-    gpt_sovits_v3 = GPTSoVITSV3(gpt_sovits_v3_half, cfm, bigvgan_model)
-    gpt_sovits_v3 = torch.jit.script(gpt_sovits_v3)
-    gpt_sovits_v3.save("onnx/ad/gpt_sovits_v3.pt")
-    gpt_sovits_v3 = gpt_sovits_v3.half().to(device)
-    gpt_sovits_v3.eval()
-    print("save gpt_sovits_v3 ok")
     time.sleep(5)
 
     # print("thread:", torch.get_num_threads())
@@ -991,14 +1207,16 @@ def test_():
 
     test_export1(
         "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....",
-        gpt_sovits_v3,
+        gpt_sovits_v3v4,
         "out.wav",
+        sr,
     )
 
     test_export1(
         "你小子是什么来路.汗流浃背了呀!老弟~ My uncle has two dogs. He is very happy with them. 最后还是我得了 MVP!",
-        gpt_sovits_v3,
+        gpt_sovits_v3v4,
         "out2.wav",
+        sr,
     )
 
     # test_export(
@@ -1030,6 +1248,6 @@ def test_export_gpt_sovits_v3():
 
 
 with torch.no_grad():
-    # export()
-    test_()
+    export("v4")
+    # test_("v4")
     # test_export_gpt_sovits_v3()
diff --git a/GPT_SoVITS/module/models_onnx.py b/GPT_SoVITS/module/models_onnx.py
index 8a3ad13f..028db5f5 100644
--- a/GPT_SoVITS/module/models_onnx.py
+++ b/GPT_SoVITS/module/models_onnx.py
@@ -391,6 +391,7 @@ class Generator(torch.nn.Module):
         upsample_initial_channel,
         upsample_kernel_sizes,
         gin_channels=0,
+        is_bias=False,
     ):
         super(Generator, self).__init__()
         self.num_kernels = len(resblock_kernel_sizes)
@@ -418,7 +419,7 @@ class Generator(torch.nn.Module):
         for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
             self.resblocks.append(resblock(ch, k, d))
 
-        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
+        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=is_bias)
         self.ups.apply(init_weights)
 
         if gin_channels != 0:
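Usage note (not part of the patch): once export("v4") and test_("v4") have run, the TorchScript artifacts can be reloaded without any of this Python source. The sketch below is a minimal, illustrative round-trip under stated assumptions: it uses the paths this script writes ("onnx/ad/gpt_sovits_v4.pt", "onnx/ad/hifigan_model.pt"), assumes a CUDA machine because the modules above are scripted/traced in half precision on GPU, and smoke-tests only the vocoder with a random mel shaped like the cmf_res_rand tracing input. Real synthesis needs the cnhubert, BERT, and phoneme inputs prepared exactly as in test_export1.

import torch

device = "cuda"  # assumption: artifacts were produced from half-precision CUDA modules

# Full v4 pipeline (T2S + CFM + vocoder) saved by test_("v4").
gpt_sovits_v4 = torch.jit.load("onnx/ad/gpt_sovits_v4.pt", map_location=device)
gpt_sovits_v4.eval()
# Prints the scripted forward, whose eight arguments match the call in
# test_export1: (ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1,
# bert1, bert2, top_k, sample_steps) -> 48 kHz waveform.
print(gpt_sovits_v4.code)

# Stand-alone vocoder saved by export("v4"). Like cmf_res_rand, it takes a
# (1, 100, T) mel; with 10*6*2*2*2 upsampling it returns about (1, 1, T * 480)
# samples at 48 kHz. A random mel exercises shapes only, not speech quality.
hifigan = torch.jit.load("onnx/ad/hifigan_model.pt", map_location=device)
mel = torch.randn(1, 100, 100, device=device, dtype=torch.float16)
with torch.inference_mode():
    wav = hifigan(mel)
print(wav.shape)  # expect roughly (1, 1, 100 * 480)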