diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index c768fb3..10988c1 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -4,6 +4,7 @@
 import os, sys, gc
 import random
 import traceback
+import torchaudio
 from tqdm import tqdm
 now_dir = os.getcwd()
 sys.path.append(now_dir)
@@ -26,10 +27,37 @@
 from tools.my_utils import load_audio
 from module.mel_processing import spectrogram_torch
 from TTS_infer_pack.text_segmentation_method import splits
 from TTS_infer_pack.TextPreprocessor import TextPreprocessor
+from BigVGAN.bigvgan import BigVGAN
+from module.mel_processing import spectrogram_torch,mel_spectrogram_torch
 language=os.environ.get("language","Auto")
 language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
 i18n = I18nAuto(language=language)
+
+
+spec_min = -12
+spec_max = 2
+def norm_spec(x):
+    return (x - spec_min) / (spec_max - spec_min) * 2 - 1
+def denorm_spec(x):
+    return (x + 1) / 2 * (spec_max - spec_min) + spec_min
+mel_fn=lambda x: mel_spectrogram_torch(x, **{
+    "n_fft": 1024,
+    "win_size": 1024,
+    "hop_size": 256,
+    "num_mels": 100,
+    "sampling_rate": 24000,
+    "fmin": 0,
+    "fmax": None,
+    "center": False
+})
+
+
+
+
+
+
+
 # configs/tts_infer.yaml
 """
 custom:
@@ -157,6 +185,8 @@ class TTS_Config:
         self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
         self.languages = self.v2_languages if self.version=="v2" else self.v1_languages

+        self.is_v3_synthesizer:bool = False
+
         if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
             self.t2s_weights_path = self.default_configs[default_config_key]['t2s_weights_path']
@@ -252,6 +282,7 @@ class TTS:
         self.bert_tokenizer:AutoTokenizer = None
         self.bert_model:AutoModelForMaskedLM = None
         self.cnhuhbert_model:CNHubert = None
+        self.bigvgan_model:BigVGAN = None

         self._init_models()
@@ -359,6 +390,19 @@ class TTS:
         if self.configs.is_half and str(self.configs.device)!="cpu":
             self.t2s_model = self.t2s_model.half()

+    def init_bigvgan(self):
+        self.bigvgan_model = BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False)  # if True, RuntimeError: Ninja is required to load C++ extensions
+        # remove weight norm in the model and set to eval mode
+        self.bigvgan_model.remove_weight_norm()
+        self.bigvgan_model = self.bigvgan_model.eval()
+        if self.configs.is_half == True:
+            self.bigvgan_model = self.bigvgan_model.half().to(self.configs.device)
+        else:
+            self.bigvgan_model = self.bigvgan_model.to(self.configs.device)
+
+
     def enable_half_precision(self, enable: bool = True, save: bool = True):
         '''
             To enable half precision for the TTS model.
@@ -383,6 +427,8 @@ class TTS:
                 self.bert_model =self.bert_model.half()
             if self.cnhuhbert_model is not None:
                 self.cnhuhbert_model = self.cnhuhbert_model.half()
+            if self.bigvgan_model is not None:
+                self.bigvgan_model = self.bigvgan_model.half()
         else:
             if self.t2s_model is not None:
                 self.t2s_model = self.t2s_model.float()
@@ -392,6 +438,8 @@ class TTS:
                 self.bert_model = self.bert_model.float()
             if self.cnhuhbert_model is not None:
                 self.cnhuhbert_model = self.cnhuhbert_model.float()
+            if self.bigvgan_model is not None:
+                self.bigvgan_model = self.bigvgan_model.float()

     def set_device(self, device: torch.device, save: bool = True):
         '''
@@ -728,6 +776,9 @@ class TTS:
         if not no_prompt_text:
             assert prompt_lang in self.configs.languages

+        if no_prompt_text and self.configs.is_v3_synthesizer:
+            raise RuntimeError("prompt_text cannot be empty when using SoVITS_V3")
+
         if ref_audio_path in [None, ""] and \
             ((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] in [None, []])):
             raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()")
@@ -1014,6 +1065,62 @@
         return sr, audio

+
+    def v3_synthesis(self,
+                     semantic_tokens:torch.Tensor,
+                     phones:torch.Tensor,
+                     speed:float=1.0,
+                     sample_steps:int=16
+                     ):
+
+        prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).to(self.configs.device)
+        prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
+        refer_audio_spec = self.prompt_cache["refer_spec"][0].to(dtype=self.precision, device=self.configs.device)
+
+        fea_ref,ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
+        ref_audio:torch.Tensor = self.prompt_cache["raw_audio"]
+        ref_sr = self.prompt_cache["raw_sr"]
+        ref_audio=ref_audio.to(self.configs.device).float()
+        if (ref_audio.shape[0] == 2):
+            ref_audio = ref_audio.mean(0).unsqueeze(0)
+        if ref_sr!=24000:
+            ref_audio=resample(ref_audio, ref_sr, self.configs.device)
+        # print("ref_audio",ref_audio.abs().mean())
+        mel2 = mel_fn(ref_audio)
+        mel2 = norm_spec(mel2)
+        T_min = min(mel2.shape[2], fea_ref.shape[2])
+        mel2 = mel2[:, :, :T_min]
+        fea_ref = fea_ref[:, :, :T_min]
+        if (T_min > 468):
+            mel2 = mel2[:, :, -468:]
+            fea_ref = fea_ref[:, :, -468:]
+            T_min = 468
+        chunk_len = 934 - T_min
+
+        mel2=mel2.to(self.precision)
+        fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
+
+        cfm_resss = []
+        idx = 0
+        while (1):
+            fea_todo_chunk = fea_todo[:, :, idx:idx + chunk_len]
+            if (fea_todo_chunk.shape[-1] == 0): break
+            idx += chunk_len
+            fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
+
+            cfm_res = self.vits_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
+            cfm_res = cfm_res[:, :, mel2.shape[2]:]
+            mel2 = cfm_res[:, :, -T_min:]
+
+            fea_ref = fea_todo_chunk[:, :, -T_min:]
+            cfm_resss.append(cfm_res)
+        cmf_res = torch.cat(cfm_resss, 2)
+        cmf_res = denorm_spec(cmf_res)
+
+        with torch.inference_mode():
+            wav_gen = self.bigvgan_model(cmf_res)
+            audio=wav_gen[0][0]#.cpu().detach().numpy()
+
+        return audio


 def speed_change(input_audio:np.ndarray, speed:float, sr:int):
@@ -1036,3 +1143,14 @@ def speed_change(input_audio:np.ndarray, speed:float, sr:int):
     processed_audio = np.frombuffer(out, np.int16)

     return processed_audio
+
+
+
+resample_transform_dict={}
+def resample(audio_tensor, sr0, device):
+    global resample_transform_dict
+    if sr0 not in resample_transform_dict:
+        resample_transform_dict[sr0] = torchaudio.transforms.Resample(
+            sr0, 24000
+        ).to(device)
+    return resample_transform_dict[sr0](audio_tensor)
\ No newline at end of file
diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
index 9def3da..16ff27d 100644
--- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
+++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
@@ -118,11 +118,11 @@ class TextPreprocessor:

     def get_phones_and_bert(self, text:str, language:str, version:str, final:bool=False):
         if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
-            language = language.replace("all_","")
+            # language = language.replace("all_","")
             formattext = text
             while "  " in formattext:
                 formattext = formattext.replace("  ", " ")
-            if language == "zh":
+            if language == "all_zh":
                 if re.search(r'[A-Za-z]', formattext):
                     formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
                     formattext = chinese.mix_text_normalize(formattext)
@@ -130,7 +130,7 @@ class TextPreprocessor:
                 else:
                     phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
                     bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
-            elif language == "yue" and re.search(r'[A-Za-z]', formattext):
+            elif language == "all_yue" and re.search(r'[A-Za-z]', formattext):
                 formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
                 formattext = chinese.mix_text_normalize(formattext)
                 return self.get_phones_and_bert(formattext,"yue",version)
diff --git a/GPT_SoVITS/inference_webui_fast.py b/GPT_SoVITS/inference_webui_fast.py
index 5a6910d..921ca87 100644
--- a/GPT_SoVITS/inference_webui_fast.py
+++ b/GPT_SoVITS/inference_webui_fast.py
@@ -194,10 +194,20 @@ def get_weights_names(GPT_weight_root, SoVITS_weight_root):

 SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)

-
+from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
 def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
-    tts_pipeline.init_vits_weights(sovits_path)
     global version, dict_language
+    version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path)
+    # print(sovits_path,version, model_version, if_lora_v3)
+    path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth"
+
+    if if_lora_v3 and not os.path.exists(path_sovits_v3):
+        info= path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
+        gr.Warning(info)
+        raise FileExistsError(info)
+
+    tts_pipeline.init_vits_weights(sovits_path)
+
     dict_language = dict_language_v1 if tts_pipeline.configs.version =='v1' else dict_language_v2
     if prompt_language is not None and text_language is not None:
         if prompt_language in list(dict_language.keys()):
@@ -210,7 +220,13 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
         else:
             text_update = {'__type__':'update', 'value':''}
             text_language_update = {'__type__':'update', 'value':i18n("中文")}
-    return {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update
+    if model_version=="v3":
+        visible_sample_steps=True
+        visible_inp_refs=False
+    else:
+        visible_sample_steps=False
+        visible_inp_refs=True
+    yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update,{"__type__": "update", "visible": visible_sample_steps},{"__type__": "update", "visible": visible_inp_refs},{"__type__": "update", "value": False,"interactive":True if model_version!="v3"else False},{"__type__": "update", "visible":True if model_version=="v3"else False}