diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
index 426929f..f03183a 100644
--- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
+++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
@@ -159,6 +159,10 @@ class TextPreprocessor:
                     textlist.append(tmp["text"])
             else:
                 for tmp in LangSegmenter.getTexts(text):
+                    if langlist:
+                        if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
+                            textlist[-1] += tmp["text"]
+                            continue
                     if tmp["lang"] == "en":
                         langlist.append(tmp["lang"])
                     else:
diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index ce6cf9d..5c7d010 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -623,6 +623,10 @@ def get_phones_and_bert(text, language, version, final=False):
             textlist.append(tmp["text"])
     else:
         for tmp in LangSegmenter.getTexts(text):
+            if langlist:
+                if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
+                    textlist[-1] += tmp["text"]
+                    continue
             if tmp["lang"] == "en":
                 langlist.append(tmp["lang"])
             else:
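Review note: both hunks above add the same guard, so the segmenter's output is merged into maximal same-bucket runs ("en" vs. everything else) before phoneme/BERT extraction. A minimal standalone sketch of that rule, with a hard-coded stand-in for `LangSegmenter.getTexts(text)`; the condensed test `(a == "en") == (b == "en")` is equivalent to the patch's two-clause condition, and the non-"en" branch is simplified to a single `append`:

```python
segments = [
    {"lang": "en", "text": "Hello "},
    {"lang": "en", "text": "world. "},      # "en" after "en": merged
    {"lang": "zh", "text": "你好。"},
    {"lang": "ja", "text": "こんにちは。"},  # non-"en" after non-"en": merged too
]

textlist, langlist = [], []
for tmp in segments:
    # Same bucket as the previous run ("en" vs. not-"en")? Extend it in place.
    if langlist and ((tmp["lang"] == "en") == (langlist[-1] == "en")):
        textlist[-1] += tmp["text"]
        continue
    # Otherwise start a new run (the real code branches further for non-"en").
    langlist.append(tmp["lang"])
    textlist.append(tmp["text"])

print(list(zip(langlist, textlist)))
# [('en', 'Hello world. '), ('zh', '你好。こんにちは。')]
```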
diff --git a/api.py b/api.py
index 7354ff7..b7e94e7 100644
--- a/api.py
+++ b/api.py
@@ -163,7 +163,7 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
 import numpy as np
 from feature_extractor import cnhubert
 from io import BytesIO
-from module.models import SynthesizerTrn, SynthesizerTrnV3
+from module.models import Generator, SynthesizerTrn, SynthesizerTrnV3
 from peft import LoraConfig, get_peft_model
 from AR.models.t2s_lightning_module import Text2SemanticLightningModule
 from text import cleaned_text_to_sequence
@@ -198,8 +198,38 @@ def is_full(*items):  # 任意一项为空返回False
     return True


-def init_bigvgan():
+bigvgan_model = hifigan_model = sv_cn_model = None
+
+
+def clean_hifigan_model():
+    global hifigan_model
+    if hifigan_model:
+        hifigan_model = hifigan_model.cpu()
+        hifigan_model = None
+        try:
+            torch.cuda.empty_cache()
+        except:
+            pass
+
+
+def clean_bigvgan_model():
     global bigvgan_model
+    if bigvgan_model:
+        bigvgan_model = bigvgan_model.cpu()
+        bigvgan_model = None
+        try:
+            torch.cuda.empty_cache()
+        except:
+            pass
+
+
+def clean_sv_cn_model():
+    global sv_cn_model
+    if sv_cn_model:
+        sv_cn_model.embedding_model = sv_cn_model.embedding_model.cpu()
+        sv_cn_model = None
+        try:
+            torch.cuda.empty_cache()
+        except:
+            pass
+
+
+def init_bigvgan():
+    global bigvgan_model, hifigan_model, sv_cn_model
     from BigVGAN import bigvgan

     bigvgan_model = bigvgan.BigVGAN.from_pretrained(
@@ -209,20 +239,53 @@
     # remove weight norm in the model and set to eval mode
     bigvgan_model.remove_weight_norm()
     bigvgan_model = bigvgan_model.eval()
+
     if is_half == True:
         bigvgan_model = bigvgan_model.half().to(device)
     else:
         bigvgan_model = bigvgan_model.to(device)


-resample_transform_dict = {}
+def init_hifigan():
+    global hifigan_model, bigvgan_model, sv_cn_model
+    hifigan_model = Generator(
+        initial_channel=100,
+        resblock="1",
+        resblock_kernel_sizes=[3, 7, 11],
+        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+        upsample_rates=[10, 6, 2, 2, 2],
+        upsample_initial_channel=512,
+        upsample_kernel_sizes=[20, 12, 4, 4, 4],
+        gin_channels=0,
+        is_bias=True,
+    )
+    hifigan_model.eval()
+    hifigan_model.remove_weight_norm()
+    state_dict_g = torch.load(
+        "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,),
+        map_location="cpu",
+        weights_only=False,
+    )
+    print("loading vocoder", hifigan_model.load_state_dict(state_dict_g))
+    if is_half == True:
+        hifigan_model = hifigan_model.half().to(device)
+    else:
+        hifigan_model = hifigan_model.to(device)


-def resample(audio_tensor, sr0):
+from sv import SV
+
+
+def init_sv_cn():
+    global hifigan_model, bigvgan_model, sv_cn_model
+    sv_cn_model = SV(device, is_half)
+
+
+resample_transform_dict = {}
+
+
+def resample(audio_tensor, sr0, sr1, device):
     global resample_transform_dict
-    if sr0 not in resample_transform_dict:
-        resample_transform_dict[sr0] = torchaudio.transforms.Resample(sr0, 24000).to(device)
-    return resample_transform_dict[sr0](audio_tensor)
+    key = "%s-%s-%s" % (sr0, sr1, str(device))
+    if key not in resample_transform_dict:
+        resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
+    return resample_transform_dict[key](audio_tensor)


 from module.mel_processing import mel_spectrogram_torch
@@ -252,6 +315,19 @@ mel_fn = lambda x: mel_spectrogram_torch(
         "center": False,
     },
 )
+mel_fn_v4 = lambda x: mel_spectrogram_torch(
+    x,
+    **{
+        "n_fft": 1280,
+        "win_size": 1280,
+        "hop_size": 320,
+        "num_mels": 100,
+        "sampling_rate": 32000,
+        "fmin": 0,
+        "fmax": None,
+        "center": False,
+    },
+)


 sr_model = None
@@ -293,12 +369,18 @@ from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new


 def get_sovits_weights(sovits_path):
-    path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
+    from config import pretrained_sovits_name
+
+    path_sovits_v3 = pretrained_sovits_name["v3"]
+    path_sovits_v4 = pretrained_sovits_name["v4"]
     is_exist_s2gv3 = os.path.exists(path_sovits_v3)
+    is_exist_s2gv4 = os.path.exists(path_sovits_v4)

     version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
-    if if_lora_v3 == True and is_exist_s2gv3 == False:
-        logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
+    is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4
+    path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
+
+    if if_lora_v3 == True and is_exist == False:
+        logger.info("SoVITS %s 底模缺失,无法加载相应 LoRA 权重" % model_version)

     dict_s2 = load_sovits_new(sovits_path)
     hps = dict_s2["config"]
@@ -311,11 +393,13 @@ def get_sovits_weights(sovits_path):
     else:
         hps.model.version = "v2"

-    if model_version == "v3":
-        hps.model.version = "v3"
-
     model_params_dict = vars(hps.model)
-    if model_version != "v3":
+    if model_version not in {"v3", "v4"}:
+        if "Pro" in model_version:
+            hps.model.version = model_version
+            if sv_cn_model == None:
+                init_sv_cn()
+
         vq_model = SynthesizerTrn(
             hps.data.filter_length // 2 + 1,
             hps.train.segment_size // hps.data.hop_length,
@@ -323,13 +407,18 @@ def get_sovits_weights(sovits_path):
             **model_params_dict,
         )
     else:
+        hps.model.version = model_version
         vq_model = SynthesizerTrnV3(
             hps.data.filter_length // 2 + 1,
             hps.train.segment_size // hps.data.hop_length,
             n_speakers=hps.data.n_speakers,
             **model_params_dict,
         )
-        init_bigvgan()
+        if model_version == "v3":
+            init_bigvgan()
+        if model_version == "v4":
+            init_hifigan()
+
     model_version = hps.model.version
     logger.info(f"模型版本: {model_version}")
     if "pretrained" not in sovits_path:
@@ -345,7 +434,8 @@ def get_sovits_weights(sovits_path):
     if if_lora_v3 == False:
         vq_model.load_state_dict(dict_s2["weight"], strict=False)
     else:
-        vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False)
+        path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
+        vq_model.load_state_dict(load_sovits_new(path_sovits)["weight"], strict=False)
         lora_rank = dict_s2["lora_rank"]
         lora_config = LoraConfig(
             target_modules=["to_k", "to_q", "to_v", "to_out.0"],
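Review note: `resample` is now keyed on `(sr0, sr1, device)` instead of `sr0` alone, which matters once 24 kHz (v3) and 32 kHz (v4) targets can coexist in one process; the old cache would have reused its hard-wired 24 kHz resampler for any request. A self-contained sketch of the same caching pattern (names here are illustrative, not from the patch):

```python
import torch
import torchaudio

# One lazily built Resample module per (source rate, target rate, device) triple.
_resamplers: dict = {}


def cached_resample(audio: torch.Tensor, sr0: int, sr1: int, device: torch.device) -> torch.Tensor:
    key = (sr0, sr1, str(device))
    if key not in _resamplers:
        _resamplers[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
    return _resamplers[key](audio)


# E.g. a 44.1 kHz reference clip prepared for the v4 mel pipeline (32 kHz):
wav = torch.randn(1, 44100)
print(cached_resample(wav, 44100, 32000, torch.device("cpu")).shape)  # torch.Size([1, 32000])
```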
@@ -479,6 +569,10 @@ def get_phones_and_bert(text, language, version, final=False):
                 textlist.append(tmp["text"])
         else:
             for tmp in LangSegmenter.getTexts(text):
+                if langlist:
+                    if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
+                        textlist[-1] += tmp["text"]
+                        continue
                 if tmp["lang"] == "en":
                     langlist.append(tmp["lang"])
                 else:
@@ -533,23 +627,32 @@ class DictToAttrRecursive(dict):
             raise AttributeError(f"Attribute {item} not found")


-def get_spepc(hps, filename):
-    audio, _ = librosa.load(filename, sr=int(hps.data.sampling_rate))
-    audio = torch.FloatTensor(audio)
+def get_spepc(hps, filename, dtype, device, is_v2pro=False):
+    sr1 = int(hps.data.sampling_rate)
+    audio, sr0 = torchaudio.load(filename)
+    if sr0 != sr1:
+        audio = audio.to(device)
+        if audio.shape[0] == 2: audio = audio.mean(0).unsqueeze(0)
+        audio = resample(audio, sr0, sr1, device)
+    else:
+        audio = audio.to(device)
+        if audio.shape[0] == 2: audio = audio.mean(0).unsqueeze(0)
+
     maxx = audio.abs().max()
     if maxx > 1:
         audio /= min(2, maxx)
-    audio_norm = audio
-    audio_norm = audio_norm.unsqueeze(0)
     spec = spectrogram_torch(
-        audio_norm,
+        audio,
         hps.data.filter_length,
         hps.data.sampling_rate,
         hps.data.hop_length,
         hps.data.win_length,
         center=False,
     )
-    return spec
+    spec = spec.to(dtype)
+    if is_v2pro == True:
+        audio = resample(audio, sr1, 16000, device).to(dtype)
+    return spec, audio


 def pack_audio(audio_bytes, data, rate):
@@ -736,6 +839,16 @@ def get_tts_wav(
     t2s_model = infer_gpt.t2s_model
     max_sec = infer_gpt.max_sec

+    if version == "v3":
+        if sample_steps not in [4, 8, 16, 32, 64, 128]:
+            sample_steps = 32
+    elif version == "v4":
+        if sample_steps not in [4, 8, 16, 32]:
+            sample_steps = 8
+
+    if if_sr and version != "v3":
+        if_sr = False
+
     t0 = ttime()
     prompt_text = prompt_text.strip("\n")
     if prompt_text[-1] not in splits:
@@ -759,19 +872,29 @@ def get_tts_wav(
         prompt_semantic = codes[0, 0]
         prompt = prompt_semantic.unsqueeze(0).to(device)

-    if version != "v3":
+    is_v2pro = version in {"v2Pro", "v2ProPlus"}
+    if version not in {"v3", "v4"}:
         refers = []
+        if is_v2pro:
+            sv_emb = []
+            if sv_cn_model == None:
+                init_sv_cn()
         if inp_refs:
             for path in inp_refs:
-                try:
-                    refer = get_spepc(hps, path).to(dtype).to(device)
+                try:  # also extract the SV embedding here: either a list of SV embeddings paired with a list of refers, or a single SV with a single refer
+                    refer, audio_tensor = get_spepc(hps, path.name, dtype, device, is_v2pro)
                     refers.append(refer)
+                    if is_v2pro:
+                        sv_emb.append(sv_cn_model.compute_embedding3(audio_tensor))
                 except Exception as e:
                     logger.error(e)
         if len(refers) == 0:
-            refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
+            refers, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device, is_v2pro)
+            refers = [refers]
+            if is_v2pro:
+                sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)]
     else:
-        refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
+        refer, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device)

     t1 = ttime()
     # os.environ['version'] = version
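Review note: with `get_spepc` returning `(spec, audio)`, the v2Pro branch above keeps `refers` and `sv_emb` index-aligned: one speaker-verification embedding per reference spectrogram, with the prompt audio as a single fallback pair. A runnable toy of that pairing; the `*_stub` functions and tensor shapes are stand-ins, not the project's real signatures:

```python
import torch


def get_spepc_stub(path):
    # Stand-in for the patched get_spepc: returns (spectrogram, 16 kHz audio).
    return torch.zeros(1, 1025, 50), torch.zeros(1, 16000)


def compute_embedding3_stub(audio_16k):
    # Stand-in for sv_cn_model.compute_embedding3: one speaker embedding.
    return torch.zeros(1, 192)


def collect_references(inp_refs, ref_wav_path):
    refers, sv_emb = [], []
    for path in inp_refs or []:
        refer, audio_16k = get_spepc_stub(path)
        refers.append(refer)
        sv_emb.append(compute_embedding3_stub(audio_16k))
    if not refers:  # no usable extra references: fall back to the prompt audio
        refer, audio_16k = get_spepc_stub(ref_wav_path)
        refers, sv_emb = [refer], [compute_embedding3_stub(audio_16k)]
    return refers, sv_emb


refers, sv_emb = collect_references([], "prompt.wav")
assert len(refers) == len(sv_emb) == 1  # always index-aligned
```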
@@ -811,41 +934,48 @@ def get_tts_wav(
             pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
         t3 = ttime()

-        if version != "v3":
-            audio = (
-                vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)
-                .detach()
-                .cpu()
-                .numpy()[0, 0]
-            )  ###试试重建不带上prompt部分
+        if version not in {"v3", "v4"}:
+            if is_v2pro:
+                audio = (
+                    vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed, sv_emb=sv_emb)
+                    .detach()
+                    .cpu()
+                    .numpy()[0, 0]
+                )
+            else:
+                audio = (
+                    vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)
+                    .detach()
+                    .cpu()
+                    .numpy()[0, 0]
+                )
         else:
             phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
             phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
-            # print(11111111, phoneme_ids0, phoneme_ids1)
+
             fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
             ref_audio, sr = torchaudio.load(ref_wav_path)
             ref_audio = ref_audio.to(device).float()
             if ref_audio.shape[0] == 2:
                 ref_audio = ref_audio.mean(0).unsqueeze(0)
-            if sr != 24000:
-                ref_audio = resample(ref_audio, sr)
-            # print("ref_audio",ref_audio.abs().mean())
-            mel2 = mel_fn(ref_audio)
+
+            tgt_sr = 24000 if version == "v3" else 32000
+            if sr != tgt_sr:
+                ref_audio = resample(ref_audio, sr, tgt_sr, device)
+            mel2 = mel_fn(ref_audio) if version == "v3" else mel_fn_v4(ref_audio)
             mel2 = norm_spec(mel2)
             T_min = min(mel2.shape[2], fea_ref.shape[2])
             mel2 = mel2[:, :, :T_min]
             fea_ref = fea_ref[:, :, :T_min]
-            if T_min > 468:
-                mel2 = mel2[:, :, -468:]
-                fea_ref = fea_ref[:, :, -468:]
-                T_min = 468
-            chunk_len = 934 - T_min
-            # print("fea_ref",fea_ref,fea_ref.shape)
-            # print("mel2",mel2)
+            Tref = 468 if version == "v3" else 500
+            Tchunk = 934 if version == "v3" else 1000
+            if T_min > Tref:
+                mel2 = mel2[:, :, -Tref:]
+                fea_ref = fea_ref[:, :, -Tref:]
+                T_min = Tref
+            chunk_len = Tchunk - T_min
             mel2 = mel2.to(dtype)
             fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed)
-            # print("fea_todo",fea_todo)
-            # print("ge",ge.abs().mean())
             cfm_resss = []
             idx = 0
             while 1:
@@ -854,22 +984,24 @@ def get_tts_wav(
                     break
                 idx += chunk_len
                 fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
-                # set_seed(123)
                 cfm_res = vq_model.cfm.inference(
                     fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0
                 )
                 cfm_res = cfm_res[:, :, mel2.shape[2] :]
                 mel2 = cfm_res[:, :, -T_min:]
-                # print("fea", fea)
-                # print("mel2in", mel2)
                 fea_ref = fea_todo_chunk[:, :, -T_min:]
                 cfm_resss.append(cfm_res)
-            cmf_res = torch.cat(cfm_resss, 2)
-            cmf_res = denorm_spec(cmf_res)
-            if bigvgan_model == None:
-                init_bigvgan()
+            cfm_res = torch.cat(cfm_resss, 2)
+            cfm_res = denorm_spec(cfm_res)
+            if version == "v3":
+                if bigvgan_model == None:
+                    init_bigvgan()
+            else:  # v4
+                if hifigan_model == None:
+                    init_hifigan()
+            vocoder_model = bigvgan_model if version == "v3" else hifigan_model
             with torch.inference_mode():
-                wav_gen = bigvgan_model(cmf_res)
+                wav_gen = vocoder_model(cfm_res)
                 audio = wav_gen[0][0].cpu().detach().numpy()

         max_audio = np.abs(audio).max()
@@ -880,7 +1012,13 @@ def get_tts_wav(
     audio_opt = np.concatenate(audio_opt, 0)
     t4 = ttime()

-    sr = hps.data.sampling_rate if version != "v3" else 24000
+    if version in {"v1", "v2", "v2Pro", "v2ProPlus"}:
+        sr = 32000
+    elif version == "v3":
+        sr = 24000
+    else:
+        sr = 48000  # v4
+
     if if_sr and sr == 24000:
         audio_opt = torch.from_numpy(audio_opt).float().to(device)
         audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr)
@@ -900,8 +1038,12 @@ def get_tts_wav(

     if not stream_mode == "normal":
         if media_type == "wav":
-            sr = 48000 if if_sr else 24000
-            sr = hps.data.sampling_rate if version != "v3" else sr
+            if version in {"v1", "v2", "v2Pro", "v2ProPlus"}:
+                sr = 32000
+            elif version == "v3":
+                sr = 48000 if if_sr else 24000
+            else:
+                sr = 48000  # v4
             audio_bytes = pack_wav(audio_bytes, sr)
         yield audio_bytes.getvalue()
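Review note: the v3/v4 branch above streams the CFM decoder over fixed windows: each pass feeds `[reference tail | next feature chunk]`, keeps only the newly generated mel frames, and recycles the tails as the next window's reference. A runnable sketch of that loop; `fake_cfm` is a stand-in for `vq_model.cfm.inference` (the transpose and spec normalization steps are omitted), and the window sizes follow the patch (468/934 for v3, 500/1000 for v4):

```python
import torch


def fake_cfm(fea, mel_prefix):
    # Pretend to generate one mel frame (100 bins) per input feature frame.
    new = torch.randn(1, 100, fea.shape[2] - mel_prefix.shape[2])
    return torch.cat([mel_prefix, new], 2)


Tref, Tchunk = 500, 1000               # v4 numbers; v3 uses 468/934
T_min = 120                            # frames of usable reference
chunk_len = Tchunk - T_min
fea_ref = torch.randn(1, 512, T_min)   # reference features (tail)
mel2 = torch.randn(1, 100, T_min)      # reference mel (tail)
fea_todo = torch.randn(1, 512, 2300)   # features still to be rendered

outs, idx = [], 0
while idx < fea_todo.shape[2]:
    fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
    idx += chunk_len
    fea = torch.cat([fea_ref, fea_todo_chunk], 2)
    res = fake_cfm(fea, mel2)[:, :, mel2.shape[2] :]  # drop the prompt frames
    mel2 = res[:, :, -T_min:]                 # new mel reference tail
    fea_ref = fea_todo_chunk[:, :, -T_min:]   # new feature reference tail
    outs.append(res)

mel = torch.cat(outs, 2)
print(mel.shape)  # torch.Size([1, 100, 2300]) -- one mel frame per input frame
```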
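Review note, a suggestion rather than part of the patch: the version-to-sample-rate mapping is now spelled out twice (once for the decoded audio, once for the WAV header), so a single helper would keep the two sites from drifting. Rates follow the diff: v1/v2/v2Pro/v2ProPlus decode at 32 kHz, v3 at 24 kHz (48 kHz once the optional super-resolution has run), v4 at 48 kHz:

```python
def output_sample_rate(version: str, if_sr: bool) -> int:
    # Hypothetical helper mirroring the two if/elif chains in the diff.
    if version in {"v1", "v2", "v2Pro", "v2ProPlus"}:
        return 32000
    if version == "v3":
        return 48000 if if_sr else 24000
    return 48000  # v4
```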
@@ -966,8 +1108,6 @@ def handle(
         if not default_refer.is_ready():
             return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)

-    if sample_steps not in [4, 8, 16, 32]:
-        sample_steps = 32
     if cut_punc == None:
         text = cut_text(text, default_cut_punc)
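Review note: the new `clean_bigvgan_model` / `clean_hifigan_model` / `clean_sv_cn_model` helpers pair with the lazy `init_*` calls so only the vocoder the current checkpoint needs stays on the GPU. A sketch of the lifecycle they enable; `switch_vocoder` and the `load_vocoder` callable (which builds and returns a model, whereas the patch's `init_*` functions set globals instead) are illustrative names, not code from the patch:

```python
import torch

_current = {"name": None, "model": None}


def switch_vocoder(name, load_vocoder):
    """Keep at most one vocoder resident; evict the previous one first."""
    if _current["name"] == name:
        return _current["model"]  # already resident, reuse it
    if _current["model"] is not None:
        _current["model"] = _current["model"].cpu()  # move off the GPU first
        _current["model"] = None
        try:
            torch.cuda.empty_cache()  # hand freed blocks back to the driver
        except Exception:
            pass  # no CUDA available; nothing to release
    _current["name"], _current["model"] = name, load_vocoder()
    return _current["model"]
```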