From fe2f04bdb86643f711f4cfa70ea944561eddf11e Mon Sep 17 00:00:00 2001 From: KamioRinn <63162909+KamioRinn@users.noreply.github.com> Date: Wed, 5 Mar 2025 17:13:46 +0800 Subject: [PATCH] API for V3 (#2154) --- GPT_SoVITS/module/models.py | 2 + GPT_SoVITS/text/chinese.py | 2 + GPT_SoVITS/text/chinese2.py | 4 +- api.py | 273 +++++++++++++++++++++++++++++------- 4 files changed, 226 insertions(+), 55 deletions(-) diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py index 623da80..33bd607 100644 --- a/GPT_SoVITS/module/models.py +++ b/GPT_SoVITS/module/models.py @@ -1162,6 +1162,7 @@ class SynthesizerTrnV3(nn.Module): use_sdp=True, semantic_frame_rate=None, freeze_quantizer=None, + version="v3", **kwargs): super().__init__() @@ -1182,6 +1183,7 @@ class SynthesizerTrnV3(nn.Module): self.segment_size = segment_size self.n_speakers = n_speakers self.gin_channels = gin_channels + self.version = version self.model_dim=512 self.use_sdp = use_sdp diff --git a/GPT_SoVITS/text/chinese.py b/GPT_SoVITS/text/chinese.py index 2255c6e..55dc997 100644 --- a/GPT_SoVITS/text/chinese.py +++ b/GPT_SoVITS/text/chinese.py @@ -17,6 +17,8 @@ pinyin_to_symbol_map = { for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines() } +import jieba_fast, logging +jieba_fast.setLogLevel(logging.CRITICAL) import jieba_fast.posseg as psg diff --git a/GPT_SoVITS/text/chinese2.py b/GPT_SoVITS/text/chinese2.py index f716b41..2b4599d 100644 --- a/GPT_SoVITS/text/chinese2.py +++ b/GPT_SoVITS/text/chinese2.py @@ -18,13 +18,15 @@ pinyin_to_symbol_map = { for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines() } +import jieba_fast, logging +jieba_fast.setLogLevel(logging.CRITICAL) import jieba_fast.posseg as psg # is_g2pw_str = os.environ.get("is_g2pw", "True")##默认开启 # is_g2pw = False#True if is_g2pw_str.lower() == 'true' else False is_g2pw = True#True if is_g2pw_str.lower() == 'true' else False if is_g2pw: - print("当前使用g2pw进行拼音推理") + # print("当前使用g2pw进行拼音推理") from text.g2pw import G2PWPinyin, correct_pronunciation parent_directory = os.path.dirname(current_file_path) g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source=os.environ.get("bert_path","GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"),v_to_u=False, neutral_tone_with_five=True) diff --git a/api.py b/api.py index c5f7024..d92d9c8 100644 --- a/api.py +++ b/api.py @@ -150,9 +150,9 @@ sys.path.append(now_dir) sys.path.append("%s/GPT_SoVITS" % (now_dir)) import signal -import LangSegment +from text.LangSegmenter import LangSegmenter from time import time as ttime -import torch +import torch, torchaudio import librosa import soundfile as sf from fastapi import FastAPI, Request, Query, HTTPException @@ -162,7 +162,8 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer import numpy as np from feature_extractor import cnhubert from io import BytesIO -from module.models import SynthesizerTrn +from module.models import SynthesizerTrn, SynthesizerTrnV3 +from peft import LoraConfig, PeftModel, get_peft_model from AR.models.t2s_lightning_module import Text2SemanticLightningModule from text import cleaned_text_to_sequence from text.cleaner import clean_text @@ -197,6 +198,61 @@ def is_full(*items): # 任意一项为空返回False return True +def init_bigvgan(): + global bigvgan_model + from BigVGAN import bigvgan + bigvgan_model = bigvgan.BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False) # if True, RuntimeError: Ninja is required to load C++ extensions + # remove weight norm in the model and set to eval mode + bigvgan_model.remove_weight_norm() + bigvgan_model = bigvgan_model.eval() + if is_half == True: + bigvgan_model = bigvgan_model.half().to(device) + else: + bigvgan_model = bigvgan_model.to(device) + + +resample_transform_dict={} +def resample(audio_tensor, sr0): + global resample_transform_dict + if sr0 not in resample_transform_dict: + resample_transform_dict[sr0] = torchaudio.transforms.Resample( + sr0, 24000 + ).to(device) + return resample_transform_dict[sr0](audio_tensor) + + +from module.mel_processing import spectrogram_torch,mel_spectrogram_torch +spec_min = -12 +spec_max = 2 +def norm_spec(x): + return (x - spec_min) / (spec_max - spec_min) * 2 - 1 +def denorm_spec(x): + return (x + 1) / 2 * (spec_max - spec_min) + spec_min +mel_fn=lambda x: mel_spectrogram_torch(x, **{ + "n_fft": 1024, + "win_size": 1024, + "hop_size": 256, + "num_mels": 100, + "sampling_rate": 24000, + "fmin": 0, + "fmax": None, + "center": False +}) + + +sr_model=None +def audio_sr(audio,sr): + global sr_model + if sr_model==None: + from tools.audio_sr import AP_BWE + try: + sr_model=AP_BWE(device,DictToAttrRecursive) + except FileNotFoundError: + logger.info("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载") + return audio.cpu().detach().numpy(),sr + return sr_model(audio,sr) + + class Speaker: def __init__(self, name, gpt, sovits, phones = None, bert = None, prompt = None): self.name = name @@ -214,31 +270,72 @@ class Sovits: self.vq_model = vq_model self.hps = hps +from process_ckpt import get_sovits_version_from_path_fast,load_sovits_new def get_sovits_weights(sovits_path): - dict_s2 = torch.load(sovits_path, map_location="cpu") + path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth" + is_exist_s2gv3=os.path.exists(path_sovits_v3) + + version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path) + if if_lora_v3==True and is_exist_s2gv3==False: + logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重") + + dict_s2 = load_sovits_new(sovits_path) hps = dict_s2["config"] hps = DictToAttrRecursive(hps) hps.model.semantic_frame_rate = "25hz" - if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322: + if 'enc_p.text_embedding.weight' not in dict_s2['weight']: + hps.model.version = "v2"#v3model,v2sybomls + elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322: hps.model.version = "v1" else: hps.model.version = "v2" - logger.info(f"模型版本: {hps.model.version}") + + if model_version == "v3": + hps.model.version = "v3" + model_params_dict = vars(hps.model) - vq_model = SynthesizerTrn( - hps.data.filter_length // 2 + 1, - hps.train.segment_size // hps.data.hop_length, - n_speakers=hps.data.n_speakers, - **model_params_dict - ) + if model_version!="v3": + vq_model = SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **model_params_dict + ) + else: + vq_model = SynthesizerTrnV3( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **model_params_dict + ) + init_bigvgan() + model_version=hps.model.version + logger.info(f"模型版本: {model_version}") if ("pretrained" not in sovits_path): - del vq_model.enc_q + try: + del vq_model.enc_q + except:pass if is_half == True: vq_model = vq_model.half().to(device) else: vq_model = vq_model.to(device) vq_model.eval() - vq_model.load_state_dict(dict_s2["weight"], strict=False) + if if_lora_v3 == False: + vq_model.load_state_dict(dict_s2["weight"], strict=False) + else: + vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False) + lora_rank=dict_s2["lora_rank"] + lora_config = LoraConfig( + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + r=lora_rank, + lora_alpha=lora_rank, + init_lora_weights=True, + ) + vq_model.cfm = get_peft_model(vq_model.cfm, lora_config) + vq_model.load_state_dict(dict_s2["weight"], strict=False) + vq_model.cfm = vq_model.cfm.merge_and_unload() + # torch.save(vq_model.state_dict(),"merge_win.pth") + vq_model.eval() sovits = Sovits(vq_model, hps) return sovits @@ -260,8 +357,8 @@ def get_gpt_weights(gpt_path): t2s_model = t2s_model.half() t2s_model = t2s_model.to(device) t2s_model.eval() - total = sum([param.nelement() for param in t2s_model.parameters()]) - logger.info("Number of parameter: %.2fM" % (total / 1e6)) + # total = sum([param.nelement() for param in t2s_model.parameters()]) + # logger.info("Number of parameter: %.2fM" % (total / 1e6)) gpt = Gpt(max_sec, t2s_model) return gpt @@ -295,6 +392,7 @@ def get_bert_feature(text, word2ph): def clean_text_inf(text, language, version): + language = language.replace("all_","") phones, word2ph, norm_text = clean_text(text, language, version) phones = cleaned_text_to_sequence(phones, version) return phones, word2ph, norm_text @@ -315,16 +413,10 @@ def get_bert_inf(phones, word2ph, norm_text, language): from text import chinese def get_phones_and_bert(text,language,version,final=False): if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}: - language = language.replace("all_","") - if language == "en": - LangSegment.setfilters(["en"]) - formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text)) - else: - # 因无法区别中日韩文汉字,以用户输入为准 - formattext = text + formattext = text while " " in formattext: formattext = formattext.replace(" ", " ") - if language == "zh": + if language == "all_zh": if re.search(r'[A-Za-z]', formattext): formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext) formattext = chinese.mix_text_normalize(formattext) @@ -332,7 +424,7 @@ def get_phones_and_bert(text,language,version,final=False): else: phones, word2ph, norm_text = clean_text_inf(formattext, language, version) bert = get_bert_feature(norm_text, word2ph).to(device) - elif language == "yue" and re.search(r'[A-Za-z]', formattext): + elif language == "all_yue" and re.search(r'[A-Za-z]', formattext): formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext) formattext = chinese.mix_text_normalize(formattext) return get_phones_and_bert(formattext,"yue",version) @@ -345,19 +437,18 @@ def get_phones_and_bert(text,language,version,final=False): elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}: textlist=[] langlist=[] - LangSegment.setfilters(["zh","ja","en","ko"]) if language == "auto": - for tmp in LangSegment.getTexts(text): + for tmp in LangSegmenter.getTexts(text): langlist.append(tmp["lang"]) textlist.append(tmp["text"]) elif language == "auto_yue": - for tmp in LangSegment.getTexts(text): + for tmp in LangSegmenter.getTexts(text): if tmp["lang"] == "zh": tmp["lang"] = "yue" langlist.append(tmp["lang"]) textlist.append(tmp["text"]) else: - for tmp in LangSegment.getTexts(text): + for tmp in LangSegmenter.getTexts(text): if tmp["lang"] == "en": langlist.append(tmp["lang"]) else: @@ -556,10 +647,11 @@ def only_punc(text): splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", } -def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 15, top_p = 0.6, temperature = 0.6, speed = 1, inp_refs = None, spk = "default"): +def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 15, top_p = 0.6, temperature = 0.6, speed = 1, inp_refs = None, sample_steps = 32, if_sr = False, spk = "default"): infer_sovits = speaker_list[spk].sovits vq_model = infer_sovits.vq_model hps = infer_sovits.hps + version = vq_model.version infer_gpt = speaker_list[spk].gpt t2s_model = infer_gpt.t2s_model @@ -587,20 +679,22 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, prompt_semantic = codes[0, 0] prompt = prompt_semantic.unsqueeze(0).to(device) - refers=[] - if(inp_refs): - for path in inp_refs: - try: - refer = get_spepc(hps, path).to(dtype).to(device) - refers.append(refer) - except Exception as e: - logger.error(e) - if(len(refers)==0): - refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)] + if version != "v3": + refers=[] + if(inp_refs): + for path in inp_refs: + try: + refer = get_spepc(hps, path).to(dtype).to(device) + refers.append(refer) + except Exception as e: + logger.error(e) + if(len(refers)==0): + refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)] + else: + refer = get_spepc(hps, ref_wav_path).to(device).to(dtype) t1 = ttime() - version = vq_model.version - os.environ['version'] = version + # os.environ['version'] = version prompt_language = dict_language[prompt_language.lower()] text_language = dict_language[text_language.lower()] phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version) @@ -634,20 +728,82 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, early_stop_num=hz * max_sec) pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) t3 = ttime() - audio = \ - vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), - refers,speed=speed).detach().cpu().numpy()[ - 0, 0] ###试试重建不带上prompt部分 + + if version != "v3": + audio = \ + vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), + refers,speed=speed).detach().cpu().numpy()[ + 0, 0] ###试试重建不带上prompt部分 + else: + phoneme_ids0=torch.LongTensor(phones1).to(device).unsqueeze(0) + phoneme_ids1=torch.LongTensor(phones2).to(device).unsqueeze(0) + # print(11111111, phoneme_ids0, phoneme_ids1) + fea_ref,ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer) + ref_audio, sr = torchaudio.load(ref_wav_path) + ref_audio=ref_audio.to(device).float() + if (ref_audio.shape[0] == 2): + ref_audio = ref_audio.mean(0).unsqueeze(0) + if sr!=24000: + ref_audio=resample(ref_audio,sr) + # print("ref_audio",ref_audio.abs().mean()) + mel2 = mel_fn(ref_audio) + mel2 = norm_spec(mel2) + T_min = min(mel2.shape[2], fea_ref.shape[2]) + mel2 = mel2[:, :, :T_min] + fea_ref = fea_ref[:, :, :T_min] + if (T_min > 468): + mel2 = mel2[:, :, -468:] + fea_ref = fea_ref[:, :, -468:] + T_min = 468 + chunk_len = 934 - T_min + # print("fea_ref",fea_ref,fea_ref.shape) + # print("mel2",mel2) + mel2=mel2.to(dtype) + fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge,speed) + # print("fea_todo",fea_todo) + # print("ge",ge.abs().mean()) + cfm_resss = [] + idx = 0 + while (1): + fea_todo_chunk = fea_todo[:, :, idx:idx + chunk_len] + if (fea_todo_chunk.shape[-1] == 0): break + idx += chunk_len + fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1) + # set_seed(123) + cfm_res = vq_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0) + cfm_res = cfm_res[:, :, mel2.shape[2]:] + mel2 = cfm_res[:, :, -T_min:] + # print("fea", fea) + # print("mel2in", mel2) + fea_ref = fea_todo_chunk[:, :, -T_min:] + cfm_resss.append(cfm_res) + cmf_res = torch.cat(cfm_resss, 2) + cmf_res = denorm_spec(cmf_res) + if bigvgan_model==None:init_bigvgan() + with torch.inference_mode(): + wav_gen = bigvgan_model(cmf_res) + audio=wav_gen[0][0].cpu().detach().numpy() + max_audio=np.abs(audio).max() if max_audio>1: audio/=max_audio audio_opt.append(audio) audio_opt.append(zero_wav) + audio_opt = np.concatenate(audio_opt, 0) t4 = ttime() + + sr = hps.data.sampling_rate if version != "v3" else 24000 + if if_sr and sr == 24000: + audio_opt = torch.from_numpy(audio_opt).float().to(device) + audio_opt,sr=audio_sr(audio_opt.unsqueeze(0),sr) + max_audio=np.abs(audio_opt).max() + if max_audio > 1: audio_opt /= max_audio + sr = 48000 + if is_int32: - audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 2147483647).astype(np.int32),hps.data.sampling_rate) + audio_bytes = pack_audio(audio_bytes,(audio_opt * 2147483647).astype(np.int32),sr) else: - audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate) + audio_bytes = pack_audio(audio_bytes,(audio_opt * 32768).astype(np.int16),sr) # logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) if stream_mode == "normal": audio_bytes, audio_chunk = read_clean_buffer(audio_bytes) @@ -655,7 +811,9 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, if not stream_mode == "normal": if media_type == "wav": - audio_bytes = pack_wav(audio_bytes,hps.data.sampling_rate) + sr = 48000 if if_sr else 24000 + sr = hps.data.sampling_rate if version != "v3" else sr + audio_bytes = pack_wav(audio_bytes,sr) yield audio_bytes.getvalue() @@ -688,7 +846,7 @@ def handle_change(path, text, language): return JSONResponse({"code": 0, "message": "Success"}, status_code=200) -def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs): +def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr): if ( refer_wav_path == "" or refer_wav_path is None or prompt_text == "" or prompt_text is None @@ -702,12 +860,15 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu if not default_refer.is_ready(): return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400) + if not sample_steps in [4,8,16,32]: + sample_steps = 32 + if cut_punc == None: text = cut_text(text,default_cut_punc) else: text = cut_text(text,cut_punc) - return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed, inp_refs), media_type="audio/"+media_type) + return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr), media_type="audio/"+media_type) @@ -915,7 +1076,9 @@ async def tts_endpoint(request: Request): json_post_raw.get("top_p", 1.0), json_post_raw.get("temperature", 1.0), json_post_raw.get("speed", 1.0), - json_post_raw.get("inp_refs", []) + json_post_raw.get("inp_refs", []), + json_post_raw.get("sample_steps", 32), + json_post_raw.get("if_sr", False) ) @@ -931,9 +1094,11 @@ async def tts_endpoint( top_p: float = 1.0, temperature: float = 1.0, speed: float = 1.0, - inp_refs: list = Query(default=[]) + inp_refs: list = Query(default=[]), + sample_steps: int = 32, + if_sr: bool = False ): - return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs) + return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr) if __name__ == "__main__":