diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index 664e1f3..ce5c3be 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -19,7 +19,8 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
 from AR.models.t2s_lightning_module import Text2SemanticLightningModule
 from feature_extractor.cnhubert import CNHubert
-from module.models import SynthesizerTrn
+from module.models import SynthesizerTrn, SynthesizerTrnV3
+from peft import LoraConfig, get_peft_model
 import librosa
 from time import time as ttime
 from tools.i18n.i18n import I18nAuto, scan_language_list
@@ -29,6 +30,7 @@ from TTS_infer_pack.text_segmentation_method import splits
 from TTS_infer_pack.TextPreprocessor import TextPreprocessor
 from BigVGAN.bigvgan import BigVGAN
 from module.mel_processing import spectrogram_torch,mel_spectrogram_torch
+from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
 language=os.environ.get("language","Auto")
 language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
 i18n = I18nAuto(language=language)
@@ -84,6 +86,14 @@ default_v2:
   t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
   vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
   version: v2
+default_v3:
+  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
+  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
+  device: cpu
+  is_half: false
+  t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
+  vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth
+  version: v3
 """
 
 def set_seed(seed:int):
@@ -110,7 +120,7 @@ def set_seed(seed:int):
 
 class TTS_Config:
     default_configs={
-        "default":{
+        "v1":{
             "device": "cpu",
             "is_half": False,
             "version": "v1",
@@ -119,7 +129,7 @@
             "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
             "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
         },
-        "default_v2":{
+        "v2":{
             "device": "cpu",
             "is_half": False,
             "version": "v2",
@@ -128,6 +138,15 @@
             "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
             "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
         },
+        "v3":{
+            "device": "cpu",
+            "is_half": False,
+            "version": "v3",
+            "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
+            "vits_weights_path": "GPT_SoVITS/pretrained_models/s2Gv3.pth",
+            "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
+            "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
+        },
     }
     configs:dict = None
     v1_languages:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
@@ -164,12 +183,9 @@
 
 
         assert isinstance(configs, dict)
         version = configs.get("version", "v2").lower()
-        assert version in ["v1", "v2"]
-        self.default_configs["default"] = configs.get("default", self.default_configs["default"])
-        self.default_configs["default_v2"] = configs.get("default_v2", self.default_configs["default_v2"])
-
-        default_config_key = "default"if version=="v1" else "default_v2"
-        self.configs:dict = configs.get("custom", deepcopy(self.default_configs[default_config_key]))
+        assert version in ["v1", "v2", "v3"]
+        self.default_configs[version] = configs.get(version, self.default_configs[version])
+        self.configs:dict = configs.get("custom", deepcopy(self.default_configs[version]))
 
         self.device = self.configs.get("device", torch.device("cpu"))
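
The hunk above replaces the hard-coded `"default"`/`"default_v2"` keys with version-string keys, so resolving a config becomes a single dictionary lookup instead of a branch. A minimal standalone sketch of the resulting pattern (the `resolve_config` helper and the trimmed `DEFAULTS` dict are illustrative, not part of the patch):

```python
from copy import deepcopy

# Hypothetical, trimmed stand-in for TTS_Config.default_configs; the point is
# only that defaults are now keyed by the version string itself.
DEFAULTS = {
    "v1": {"version": "v1", "device": "cpu", "is_half": False},
    "v2": {"version": "v2", "device": "cpu", "is_half": False},
    "v3": {"version": "v3", "device": "cpu", "is_half": False},
}

def resolve_config(configs: dict) -> dict:
    # One lookup replaces the old 'default' / 'default_v2' branching, so a
    # future v4 only needs one more DEFAULTS entry.
    version = configs.get("version", "v2").lower()
    assert version in DEFAULTS
    return configs.get("custom", deepcopy(DEFAULTS[version]))

print(resolve_config({"version": "v3"}))
# {'version': 'v3', 'device': 'cpu', 'is_half': False}
```
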
@@ -177,7 +193,7 @@
             print(f"Warning: CUDA is not available, set device to CPU.")
             self.device = torch.device("cpu")
 
-        # self.is_half = self.configs.get("is_half", False)
+        self.is_half = self.configs.get("is_half", False)
         # if str(self.device) == "cpu" and self.is_half:
         #     print(f"Warning: Half precision is not supported on CPU, set is_half to False.")
         #     self.is_half = False
@@ -187,22 +203,22 @@
         self.vits_weights_path = self.configs.get("vits_weights_path", None)
         self.bert_base_path = self.configs.get("bert_base_path", None)
         self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
-        self.languages = self.v2_languages if self.version=="v2" else self.v1_languages
+        self.languages = self.v1_languages if self.version=="v1" else self.v2_languages
+        self.is_v3_synthesizer:bool = False
 
         if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
-            self.t2s_weights_path = self.default_configs[default_config_key]['t2s_weights_path']
+            self.t2s_weights_path = self.default_configs[version]['t2s_weights_path']
             print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}")
         if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)):
-            self.vits_weights_path = self.default_configs[default_config_key]['vits_weights_path']
+            self.vits_weights_path = self.default_configs[version]['vits_weights_path']
             print(f"fall back to default vits_weights_path: {self.vits_weights_path}")
         if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)):
-            self.bert_base_path = self.default_configs[default_config_key]['bert_base_path']
+            self.bert_base_path = self.default_configs[version]['bert_base_path']
             print(f"fall back to default bert_base_path: {self.bert_base_path}")
         if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)):
-            self.cnhuhbert_base_path = self.default_configs[default_config_key]['cnhuhbert_base_path']
+            self.cnhuhbert_base_path = self.default_configs[version]['cnhuhbert_base_path']
             print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}")
 
         self.update_configs()
@@ -254,7 +270,7 @@
 
     def update_version(self, version:str)->None:
         self.version = version
-        self.languages = self.v2_languages if self.version=="v2" else self.v1_languages
+        self.languages = self.v1_languages if self.version=="v1" else self.v2_languages
 
     def __str__(self):
         self.configs = self.update_configs()
@@ -282,7 +298,7 @@ class TTS:
         self.configs:TTS_Config = TTS_Config(configs)
 
         self.t2s_model:Text2SemanticLightningModule = None
-        self.vits_model:SynthesizerTrn = None
+        self.vits_model:Union[SynthesizerTrn, SynthesizerTrnV3] = None
         self.bert_tokenizer:AutoTokenizer = None
         self.bert_model:AutoModelForMaskedLM = None
         self.cnhuhbert_model:CNHubert = None
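
Two details worth noting in the hunks above: `is_half` is read from the config again (it had been commented out), and the language-list test is inverted from `== "v2"` to `== "v1"` so that v3, and any later version, inherits the v2 language set instead of silently falling back to v1. A sketch of that inverted fallback; the `V2_LANGUAGES` extension shown here is an abbreviated stand-in for the real `TTS_Config.v2_languages`:

```python
from typing import List

V1_LANGUAGES: List[str] = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
# Abbreviated stand-in for TTS_Config.v2_languages.
V2_LANGUAGES: List[str] = V1_LANGUAGES + ["ko", "all_ko", "yue", "all_yue"]

def languages_for(version: str) -> List[str]:
    # Testing for the one legacy version instead of enumerating the modern
    # ones means every new version (v3, v4, ...) gets the v2 list by default.
    return V1_LANGUAGES if version == "v1" else V2_LANGUAGES

assert languages_for("v3") == V2_LANGUAGES
```
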
@@ -341,38 +357,81 @@
             self.bert_model = self.bert_model.half()
 
     def init_vits_weights(self, weights_path: str):
-        print(f"Loading VITS weights from {weights_path}")
         self.configs.vits_weights_path = weights_path
+        version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(weights_path)
+        path_sovits_v3 = self.configs.default_configs["v3"]["vits_weights_path"]
+
+        if if_lora_v3==True and os.path.exists(path_sovits_v3)==False:
+            info = path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
+            raise FileExistsError(info)
+
         dict_s2 = torch.load(weights_path, map_location=self.configs.device,weights_only=False)
         hps = dict_s2["config"]
-        if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
-            self.configs.update_version("v1")
-        else:
-            self.configs.update_version("v2")
-        self.configs.save_configs()
-
-        hps["model"]["version"] = self.configs.version
+
+        hps["model"]["semantic_frame_rate"] = "25hz"
+        if 'enc_p.text_embedding.weight' not in dict_s2['weight']:
+            hps["model"]["version"] = "v2"  # v3 model, v2 symbols
+        elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
+            hps["model"]["version"] = "v1"
+        else:
+            hps["model"]["version"] = "v2"
+        version = hps["model"]["version"]
+
         self.configs.filter_length = hps["data"]["filter_length"]
         self.configs.segment_size = hps["train"]["segment_size"]
         self.configs.sampling_rate = hps["data"]["sampling_rate"]
         self.configs.hop_length = hps["data"]["hop_length"]
         self.configs.win_length = hps["data"]["win_length"]
         self.configs.n_speakers = hps["data"]["n_speakers"]
-        self.configs.semantic_frame_rate = "25hz"
+        self.configs.semantic_frame_rate = hps["model"]["semantic_frame_rate"]
         kwargs = hps["model"]
+        self.configs.update_version(version)
 
-        vits_model = SynthesizerTrn(
-            self.configs.filter_length // 2 + 1,
-            self.configs.segment_size // self.configs.hop_length,
-            n_speakers=self.configs.n_speakers,
-            **kwargs
-        )
-        if hasattr(vits_model, "enc_q"):
-            del vits_model.enc_q
+        if model_version!="v3":
+            vits_model = SynthesizerTrn(
+                self.configs.filter_length // 2 + 1,
+                self.configs.segment_size // self.configs.hop_length,
+                n_speakers=self.configs.n_speakers,
+                **kwargs
+            )
+            model_version = version
+            if hasattr(vits_model, "enc_q"):
+                del vits_model.enc_q
+            self.configs.is_v3_synthesizer = False
+        else:
+            vits_model = SynthesizerTrnV3(
+                self.configs.filter_length // 2 + 1,
+                self.configs.segment_size // self.configs.hop_length,
+                n_speakers=self.configs.n_speakers,
+                **kwargs
+            )
+            self.configs.is_v3_synthesizer = True
+            self.init_bigvgan()
+
+        if if_lora_v3==False:
+            print(f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}")
+        else:
+            print(f"Loading VITS pretrained weights from {path_sovits_v3}. {vits_model.load_state_dict(load_sovits_new(path_sovits_v3)['weight'], strict=False)}")
+            lora_rank = dict_s2["lora_rank"]
+            lora_config = LoraConfig(
+                target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+                r=lora_rank,
+                lora_alpha=lora_rank,
+                init_lora_weights=True,
+            )
+            vits_model.cfm = get_peft_model(vits_model.cfm, lora_config)
+            print(f"Loading LoRA weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}")
+
+            vits_model.cfm = vits_model.cfm.merge_and_unload()
 
         vits_model = vits_model.to(self.configs.device)
         vits_model = vits_model.eval()
 
-        vits_model.load_state_dict(dict_s2["weight"], strict=False)
         self.vits_model = vits_model
         if self.configs.is_half and str(self.configs.device)!="cpu":
            self.vits_model = self.vits_model.half()
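
The LoRA branch above loads v3 fine-tuned weights in three steps: initialize the synthesizer from the v3 pretrained checkpoint, wrap its flow-matching submodule (`cfm`) in a PEFT adapter sized from the checkpoint's `lora_rank`, then merge the adapter away so inference runs on plain layers. A condensed sketch of that order, with a hypothetical `load_v3_lora` helper and a checkpoint dict assumed to carry the same `lora_rank`/`weight` keys:

```python
import torch
from peft import LoraConfig, get_peft_model

def load_v3_lora(vits_model: torch.nn.Module, ckpt: dict) -> torch.nn.Module:
    # Wrap the cfm submodule, load the LoRA checkpoint into the full model
    # (strict=False: the checkpoint stores adapter weights under the wrapped
    # names), then fold the low-rank deltas into the base weights and strip
    # the PEFT wrappers so later code sees an ordinary module.
    rank = ckpt["lora_rank"]
    config = LoraConfig(
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
        r=rank,
        lora_alpha=rank,
        init_lora_weights=True,
    )
    vits_model.cfm = get_peft_model(vits_model.cfm, config)
    vits_model.load_state_dict(ckpt["weight"], strict=False)
    vits_model.cfm = vits_model.cfm.merge_and_unload()
    return vits_model
```

Merging up front means the adapter costs nothing per inference call, at the price of no longer being able to detach the LoRA afterwards.
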
@@ -396,6 +455,8 @@
 
     def init_bigvgan(self):
+        if self.bigvgan_model is not None:
+            return
         self.bigvgan_model = BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False)
         # if True, RuntimeError: Ninja is required to load C++ extensions
         # remove weight norm in the model and set to eval mode
         self.bigvgan_model.remove_weight_norm()
@@ -746,6 +807,7 @@
         actual_seed = set_seed(seed)
         parallel_infer = inputs.get("parallel_infer", True)
         repetition_penalty = inputs.get("repetition_penalty", 1.35)
+        sample_steps = inputs.get("sample_steps", 16)
 
         if parallel_infer:
             print(i18n("并行推理模式已开启"))
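
`sample_steps` is the number of flow-matching steps the v3 decoder runs per chunk; it is read with a default of 16, so existing callers are unaffected. An illustrative request dict for `run()` (only the three commented keys are taken verbatim from this hunk; the remaining keys are assumptions for the example):

```python
# Hypothetical inputs dict for TTS.run(); higher sample_steps trades speed
# for cleaner mel reconstruction in the v3 CFM decoder.
inputs = {
    "text": "你好",                 # assumed key, for illustration
    "text_lang": "zh",              # assumed key, for illustration
    "parallel_infer": True,         # inputs.get("parallel_infer", True)
    "repetition_penalty": 1.35,     # inputs.get("repetition_penalty", 1.35)
    "sample_steps": 16,             # inputs.get("sample_steps", 16); v3 only
}
```
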
@@ -944,30 +1006,41 @@
                     # pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spec
                     # ))
 
-                if speed_factor == 1.0:
-                    # ## vits并行推理 method 2
-                    pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
-                    upsample_rate = math.prod(self.vits_model.upsample_rates)
-                    audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))]
-                    audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))]
-                    all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
-                    _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
-                    _batch_audio_fragment = (self.vits_model.decode(
-                            all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
-                        ).detach()[0, 0, :])
-                    audio_frag_end_idx.insert(0, 0)
-                    batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
+                if not self.configs.is_v3_synthesizer:
+                    if speed_factor == 1.0:
+                        # ## vits parallel inference, method 2
+                        pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
+                        upsample_rate = math.prod(self.vits_model.upsample_rates)
+                        audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))]
+                        audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))]
+                        all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
+                        _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
+                        _batch_audio_fragment = (self.vits_model.decode(
+                                all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
+                            ).detach()[0, 0, :])
+                        audio_frag_end_idx.insert(0, 0)
+                        batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
+                    else:
+                        # ## vits serial inference
+                        for i, idx in enumerate(idx_list):
+                            phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
+                            _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0))    # .unsqueeze(0)  # mq needs one extra unsqueeze
+                            audio_fragment = (self.vits_model.decode(
+                                    _pred_semantic, phones, refer_audio_spec, speed=speed_factor
+                                ).detach()[0, 0, :])
+                            batch_audio_fragment.append(
+                                audio_fragment
+                            )   ### try reconstructing without the prompt part
                 else:
-                    # ## vits串行推理
                     for i, idx in enumerate(idx_list):
                         phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
                         _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0))    # .unsqueeze(0)#mq要多unsqueeze一次
-                        audio_fragment =(self.vits_model.decode(
-                                _pred_semantic, phones, refer_audio_spec, speed=speed_factor
-                            ).detach()[0, 0, :])
+                        audio_fragment = self.v3_synthesis(
+                            _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
+                        )
                         batch_audio_fragment.append(
                             audio_fragment
-                        )   ###试试重建不带上prompt部分
+                        )
 
                 t5 = ttime()
                 t_45 += t5 - t4
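
In the non-v3 parallel branch above, fragment boundaries are computed analytically rather than by decoding fragments one at a time: each semantic token expands to `2 * prod(upsample_rates)` waveform samples, so the cut points are just cumulative sums. A standalone sketch of that arithmetic (hypothetical `split_points` helper):

```python
import math

def split_points(token_counts, upsample_rates):
    # Each semantic token becomes 2 * prod(upsample_rates) output samples
    # (the factor 2 matching the *2 in audio_frag_idx above), so cumulative
    # sums of per-fragment sample counts give the slice boundaries in the
    # concatenated waveform.
    samples_per_token = 2 * math.prod(upsample_rates)
    ends, total = [], 0
    for n in token_counts:
        total += n * samples_per_token
        ends.append(total)
    return ends

print(split_points([10, 12], [8, 8, 2, 2]))  # [5120, 11264]
```
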
@@ -1076,54 +1149,55 @@
                       sample_steps:int=16
                       ):
-        prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).to(self.configs.device)
-        prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
-        refer_audio_spec = self.prompt_cache["refer_spec"][0].to(dtype=self.precision, device=self.configs.device)
+        prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).to(self.configs.device)
+        prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
+        refer_audio_spec = self.prompt_cache["refer_spec"][0].to(dtype=self.precision, device=self.configs.device)
 
-        fea_ref,ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
-        ref_audio:torch.Tensor = self.prompt_cache["raw_audio"]
-        ref_sr = self.prompt_cache["raw_sr"]
-        ref_audio=ref_audio.to(self.configs.device).float()
-        if (ref_audio.shape[0] == 2):
-            ref_audio = ref_audio.mean(0).unsqueeze(0)
-        if ref_sr!=24000:
-            ref_audio=resample(ref_audio, ref_sr)
-        # print("ref_audio",ref_audio.abs().mean())
-        mel2 = mel_fn(ref_audio)
-        mel2 = norm_spec(mel2)
-        T_min = min(mel2.shape[2], fea_ref.shape[2])
-        mel2 = mel2[:, :, :T_min]
-        fea_ref = fea_ref[:, :, :T_min]
-        if (T_min > 468):
-            mel2 = mel2[:, :, -468:]
-            fea_ref = fea_ref[:, :, -468:]
-            T_min = 468
-        chunk_len = 934 - T_min
+        fea_ref,ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
+        ref_audio:torch.Tensor = self.prompt_cache["raw_audio"]
+        ref_sr = self.prompt_cache["raw_sr"]
+        ref_audio=ref_audio.to(self.configs.device).float()
+        if (ref_audio.shape[0] == 2):
+            ref_audio = ref_audio.mean(0).unsqueeze(0)
+        if ref_sr!=24000:
+            ref_audio=resample(ref_audio, ref_sr)
+        # print("ref_audio",ref_audio.abs().mean())
+        mel2 = mel_fn(ref_audio)
+        mel2 = norm_spec(mel2)
+        T_min = min(mel2.shape[2], fea_ref.shape[2])
+        mel2 = mel2[:, :, :T_min]
+        fea_ref = fea_ref[:, :, :T_min]
+        if (T_min > 468):
+            mel2 = mel2[:, :, -468:]
+            fea_ref = fea_ref[:, :, -468:]
+            T_min = 468
+        chunk_len = 934 - T_min
 
-        mel2=mel2.to(self.precision)
-        fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
+        mel2=mel2.to(self.precision)
+        fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
 
-        cfm_resss = []
-        idx = 0
-        while (1):
-            fea_todo_chunk = fea_todo[:, :, idx:idx + chunk_len]
-            if (fea_todo_chunk.shape[-1] == 0): break
-            idx += chunk_len
-            fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
+        cfm_resss = []
+        idx = 0
+        while (1):
+            fea_todo_chunk = fea_todo[:, :, idx:idx + chunk_len]
+            if (fea_todo_chunk.shape[-1] == 0): break
+            idx += chunk_len
+            fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
 
-            cfm_res = self.vits_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
-            cfm_res = cfm_res[:, :, mel2.shape[2]:]
-            mel2 = cfm_res[:, :, -T_min:]
-
-            fea_ref = fea_todo_chunk[:, :, -T_min:]
-            cfm_resss.append(cfm_res)
-        cmf_res = torch.cat(cfm_resss, 2)
-        cmf_res = denorm_spec(cmf_res)
-
-        with torch.inference_mode():
-            wav_gen = self.bigvgan_model(cmf_res)
-            audio=wav_gen[0][0]#.cpu().detach().numpy()
+            cfm_res = self.vits_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
+            cfm_res = cfm_res[:, :, mel2.shape[2]:]
+            mel2 = cfm_res[:, :, -T_min:]
+            fea_ref = fea_todo_chunk[:, :, -T_min:]
+            cfm_resss.append(cfm_res)
+        cmf_res = torch.cat(cfm_resss, 2)
+        cmf_res = denorm_spec(cmf_res)
+
+        with torch.inference_mode():
+            wav_gen = self.bigvgan_model(cmf_res)
+            audio = wav_gen[0][0]  # .cpu().detach().numpy()
+
+        return audio
 
 
 def speed_change(input_audio:np.ndarray, speed:float, sr:int):
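
Beyond whitespace, the substantive change in the last hunk is the added `return audio`, which hands the BigVGAN waveform back to the caller. `v3_synthesis` decodes in a sliding window: the reference mel (capped at 468 frames) is prepended to every chunk so the CFM model never sees more than 934 frames, and the last `T_min` frames of each generated chunk become the reference for the next one. The index bookkeeping in isolation (hypothetical `plan_chunks` helper, no model calls):

```python
def plan_chunks(total_frames: int, ref_frames: int):
    # Mirrors the constants above: the reference window is capped at 468
    # frames and chunk_len = 934 - T_min, so reference + chunk <= 934 frames
    # on every CFM inference call.
    t_min = min(ref_frames, 468)
    chunk_len = 934 - t_min
    spans, idx = [], 0
    while idx < total_frames:
        spans.append((idx, min(idx + chunk_len, total_frames)))
        idx += chunk_len
    return spans

print(plan_chunks(1000, 600))  # [(0, 466), (466, 932), (932, 1000)]
```
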