From ffb520ee54fafc83db3aa13db327266bab63904c Mon Sep 17 00:00:00 2001
From: Karasukaigan <80465610+Karasukaigan@users.noreply.github.com>
Date: Fri, 9 May 2025 20:14:28 +0800
Subject: [PATCH] Improve api.py compatibility with v4 models
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Improve api.py compatibility with v4 models.
---
 api.py | 93 ++++++++++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 69 insertions(+), 24 deletions(-)

diff --git a/api.py b/api.py
index aa9a6668..71158c5c 100644
--- a/api.py
+++ b/api.py
@@ -163,7 +163,7 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
 import numpy as np
 from feature_extractor import cnhubert
 from io import BytesIO
-from module.models import SynthesizerTrn, SynthesizerTrnV3
+from module.models import SynthesizerTrn, SynthesizerTrnV3, Generator
 from peft import LoraConfig, get_peft_model
 from AR.models.t2s_lightning_module import Text2SemanticLightningModule
 from text import cleaned_text_to_sequence
@@ -214,6 +214,38 @@ def init_bigvgan():
     else:
         bigvgan_model = bigvgan_model.to(device)
 
+def init_vocoder(version: str):
+    global bigvgan_model
+    from BigVGAN import bigvgan
+
+    if version == "v3":
+        bigvgan_model = bigvgan.BigVGAN.from_pretrained(
+            "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
+            use_cuda_kernel=False,
+        )  # if True, RuntimeError: Ninja is required to load C++ extensions
+        # remove weight norm in the model and set to eval mode
+        bigvgan_model.remove_weight_norm()
+        bigvgan_model = bigvgan_model.eval()
+
+    elif version == "v4":
+        bigvgan_model = Generator(
+            initial_channel=100,
+            resblock="1",
+            resblock_kernel_sizes=[3, 7, 11],
+            resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+            upsample_rates=[10, 6, 2, 2, 2],
+            upsample_initial_channel=512,
+            upsample_kernel_sizes=[20, 12, 4, 4, 4],
+            gin_channels=0, is_bias=True,
+        )
+        bigvgan_model.remove_weight_norm()
+        state_dict_g = torch.load("%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu")
+        bigvgan_model.load_state_dict(state_dict_g)
+
+    if is_half == True:
+        bigvgan_model = bigvgan_model.half().to(device)
+    else:
+        bigvgan_model = bigvgan_model.to(device)
 
 resample_transform_dict = {}
 
@@ -253,6 +285,20 @@ mel_fn = lambda x: mel_spectrogram_torch(
     },
 )
 
+mel_fn_v4 = lambda x: mel_spectrogram_torch(
+    x,
+    **{
+        "n_fft": 1280,
+        "win_size": 1280,
+        "hop_size": 320,
+        "num_mels": 100,
+        "sampling_rate": 32000,
+        "fmin": 0,
+        "fmax": None,
+        "center": False,
+    },
+)
+
 
 sr_model = None
 
@@ -297,10 +343,8 @@ def get_sovits_weights(sovits_path):
     path_sovits_v4 = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth"
     version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
 
-    if if_lora_v3 == True and not os.path.exists(path_sovits_v3):
-        logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
-    if model_version == "v4" and not os.path.exists(path_sovits_v4):
-        logger.info("SoVITS V4 底模缺失,无法加载相应 LoRA 权重")
+    if (if_lora_v3 == True and not os.path.exists(path_sovits_v3)) or (model_version == "v4" and not os.path.exists(path_sovits_v4)):
+        logger.info(f"SoVITS {model_version.upper()} 底模缺失,无法加载相应 LoRA 权重")
 
     dict_s2 = load_sovits_new(sovits_path)
     hps = dict_s2["config"]
@@ -312,13 +356,9 @@ def get_sovits_weights(sovits_path):
             hps.model.version = "v1"
         else:
             hps.model.version = "v2"
-    if model_version == "v3":
-        hps.model.version = "v3"
-    if model_version == "v4":
"v4": - hps.model.version = "v4" model_params_dict = vars(hps.model) - if model_version != "v3" and model_version != "v4": + if model_version not in {"v3", "v4"}: vq_model = SynthesizerTrn( hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, @@ -326,14 +366,16 @@ def get_sovits_weights(sovits_path): **model_params_dict, ) else: + model_params_dict["version"]=model_version vq_model = SynthesizerTrnV3( hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, n_speakers=hps.data.n_speakers, **model_params_dict, ) - init_bigvgan() - model_version = hps.model.version + # init_bigvgan() + init_vocoder(model_version) + logger.info(f"模型版本: {model_version}") if "pretrained" not in sovits_path: try: @@ -345,7 +387,7 @@ def get_sovits_weights(sovits_path): else: vq_model = vq_model.to(device) vq_model.eval() - if if_lora_v3 == False or model_version != "v4": + if model_version not in {"v3", "v4"}: vq_model.load_state_dict(dict_s2["weight"], strict=False) else: if model_version == "v4": @@ -763,7 +805,7 @@ def get_tts_wav( prompt_semantic = codes[0, 0] prompt = prompt_semantic.unsqueeze(0).to(device) - if version != "v3" and version != "v4": + if version not in {"v3", "v4"}: refers = [] if inp_refs: for path in inp_refs: @@ -814,8 +856,7 @@ def get_tts_wav( ) pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) t3 = ttime() - - if version != "v3" and version != "v4": + if version not in {"v3", "v4"}: audio = ( vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed) .detach() @@ -834,16 +875,18 @@ def get_tts_wav( if sr != 24000: ref_audio = resample(ref_audio, sr) # print("ref_audio",ref_audio.abs().mean()) - mel2 = mel_fn(ref_audio) + mel2 = mel_fn_v4(ref_audio) if version == "v4" else mel_fn(ref_audio) mel2 = norm_spec(mel2) T_min = min(mel2.shape[2], fea_ref.shape[2]) mel2 = mel2[:, :, :T_min] fea_ref = fea_ref[:, :, :T_min] - if T_min > 468: - mel2 = mel2[:, :, -468:] - fea_ref = fea_ref[:, :, -468:] - T_min = 468 - chunk_len = 934 - T_min + T_ref = 500 if version == "v4" else 468 + T_chunk = 1000 if version == "v4" else 934 + if T_min > T_ref: + mel2 = mel2[:, :, -T_ref:] + fea_ref = fea_ref[:, :, -T_ref:] + T_min = T_ref + chunk_len = T_chunk - T_min # print("fea_ref",fea_ref,fea_ref.shape) # print("mel2",mel2) mel2 = mel2.to(dtype) @@ -871,7 +914,8 @@ def get_tts_wav( cmf_res = torch.cat(cfm_resss, 2) cmf_res = denorm_spec(cmf_res) if bigvgan_model == None: - init_bigvgan() + # init_bigvgan() + init_vocoder(version) with torch.inference_mode(): wav_gen = bigvgan_model(cmf_res) audio = wav_gen[0][0].cpu().detach().numpy() @@ -905,7 +949,7 @@ def get_tts_wav( if not stream_mode == "normal": if media_type == "wav": sr = 48000 if if_sr else 24000 - sr = hps.data.sampling_rate if version != "v3" and version != "v4" else sr + sr = hps.data.sampling_rate if version != "v3" else sr audio_bytes = pack_wav(audio_bytes, sr) yield audio_bytes.getvalue()