Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-08-25 06:45:29 +08:00)
Improve api.py compatibility with the v4 model
parent fa971a4e09 · commit ffb520ee54
api.py (92 changes)
@@ -163,7 +163,7 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
 import numpy as np
 from feature_extractor import cnhubert
 from io import BytesIO
-from module.models import SynthesizerTrn, SynthesizerTrnV3
+from module.models import SynthesizerTrn, SynthesizerTrnV3, Generator
 from peft import LoraConfig, get_peft_model
 from AR.models.t2s_lightning_module import Text2SemanticLightningModule
 from text import cleaned_text_to_sequence
@@ -214,6 +214,38 @@ def init_bigvgan():
     else:
         bigvgan_model = bigvgan_model.to(device)
 
+def init_vocoder(version: str):
+    global bigvgan_model
+    from BigVGAN import bigvgan
+
+    if version == "v3":
+        bigvgan_model = bigvgan.BigVGAN.from_pretrained(
+            "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
+            use_cuda_kernel=False,
+        )  # if True, RuntimeError: Ninja is required to load C++ extensions
+        # remove weight norm in the model and set to eval mode
+        bigvgan_model.remove_weight_norm()
+        bigvgan_model = bigvgan_model.eval()
+
+    elif version == "v4":
+        bigvgan_model = Generator(
+            initial_channel=100,
+            resblock="1",
+            resblock_kernel_sizes=[3, 7, 11],
+            resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+            upsample_rates=[10, 6, 2, 2, 2],
+            upsample_initial_channel=512,
+            upsample_kernel_sizes=[20, 12, 4, 4, 4],
+            gin_channels=0, is_bias=True
+        )
+        bigvgan_model.remove_weight_norm()
+        state_dict_g = torch.load("%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu")
+        bigvgan_model.load_state_dict(state_dict_g)
+
+    if is_half == True:
+        bigvgan_model = bigvgan_model.half().to(device)
+    else:
+        bigvgan_model = bigvgan_model.to(device)
+
 resample_transform_dict = {}
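Note (reviewer sketch, not part of the commit): the v4 branch builds a plain HiFi-GAN-style Generator instead of pulling NVIDIA's pretrained BigVGAN. Its upsample factors, combined with the mel_fn_v4 settings added further down in this diff, imply a 48 kHz output rate:

    # Sanity check on the v4 Generator config above (illustrative only).
    from math import prod

    upsample_rates = [10, 6, 2, 2, 2]             # from the Generator(...) call in this hunk
    samples_per_frame = prod(upsample_rates)      # 480 output samples per mel frame
    frames_per_second = 32000 / 320               # mel_fn_v4: 32 kHz input, hop_size 320
    print(samples_per_frame * frames_per_second)  # 48000.0 -> the v4 vocoder emits 48 kHz audio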
@@ -253,6 +285,20 @@ mel_fn = lambda x: mel_spectrogram_torch(
     },
 )
 
+mel_fn_v4 = lambda x: mel_spectrogram_torch(
+    x,
+    **{
+        "n_fft": 1280,
+        "win_size": 1280,
+        "hop_size": 320,
+        "num_mels": 100,
+        "sampling_rate": 32000,
+        "fmin": 0,
+        "fmax": None,
+        "center": False,
+    },
+)
+
 
 sr_model = None
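Note (reviewer sketch, not part of the commit): mel_fn_v4 runs at 32 kHz with hop 320, while the existing mel_fn just above this hunk uses 24 kHz with hop 256 (values read from api.py, not shown in this hunk). The frame rates therefore differ slightly, which is what drives the 468 to 500 and 934 to 1000 frame-budget change later in this diff:

    # Mel frame rates implied by the two configs (illustrative only).
    for name, sr, hop in [("v3 mel_fn", 24000, 256), ("v4 mel_fn_v4", 32000, 320)]:
        print(name, sr / hop, "frames/s")  # 93.75 vs 100.0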
@@ -297,10 +343,8 @@ def get_sovits_weights(sovits_path):
     path_sovits_v4 = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth"
 
     version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
-    if if_lora_v3 == True and not os.path.exists(path_sovits_v3):
-        logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
-    if model_version == "v4" and not os.path.exists(path_sovits_v4):
-        logger.info("SoVITS V4 底模缺失,无法加载相应 LoRA 权重")
-
+    if (if_lora_v3 == True and not os.path.exists(path_sovits_v3)) or (model_version == "v4" and not os.path.exists(path_sovits_v4)):
+        logger.info(f"SoVITS {model_version.upper()} 底模缺失,无法加载相应 LoRA 权重")
+
     dict_s2 = load_sovits_new(sovits_path)
     hps = dict_s2["config"]
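Note (reviewer sketch, not part of the commit): the two pre-flight checks collapse into one guard with a version-aware message; the log string translates to "SoVITS {version} base model is missing, cannot load the corresponding LoRA weights". The spelled-out equivalent of the merged condition, as a self-contained sketch:

    import os

    def base_model_missing(if_lora_v3: bool, model_version: str,
                           path_sovits_v3: str, path_sovits_v4: str) -> bool:
        # True when a required pretrained base checkpoint is absent.
        v3_missing = if_lora_v3 and not os.path.exists(path_sovits_v3)
        v4_missing = model_version == "v4" and not os.path.exists(path_sovits_v4)
        return v3_missing or v4_missing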
@@ -312,13 +356,9 @@ def get_sovits_weights(sovits_path):
         hps.model.version = "v1"
     else:
         hps.model.version = "v2"
-    if model_version == "v3":
-        hps.model.version = "v3"
-    if model_version == "v4":
-        hps.model.version = "v4"
 
     model_params_dict = vars(hps.model)
-    if model_version != "v3" and model_version != "v4":
+    if model_version not in {"v3", "v4"}:
         vq_model = SynthesizerTrn(
             hps.data.filter_length // 2 + 1,
             hps.train.segment_size // hps.data.hop_length,
@@ -326,14 +366,16 @@ def get_sovits_weights(sovits_path):
             **model_params_dict,
         )
     else:
+        model_params_dict["version"]=model_version
         vq_model = SynthesizerTrnV3(
             hps.data.filter_length // 2 + 1,
             hps.train.segment_size // hps.data.hop_length,
             n_speakers=hps.data.n_speakers,
             **model_params_dict,
         )
-        init_bigvgan()
-    model_version = hps.model.version
+        # init_bigvgan()
+        init_vocoder(model_version)
+
     logger.info(f"模型版本: {model_version}")
     if "pretrained" not in sovits_path:
         try:
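Note (reviewer sketch, not part of the commit): the hps.model.version patching removed in the previous hunk is replaced by setting model_params_dict["version"] here, and the hardcoded init_bigvgan() gives way to init_vocoder(model_version), so the vocoder always matches the loaded SoVITS checkpoint (the retained log string 模型版本 means "model version"). get_tts_wav keeps a lazy fallback (see the hunk at -871 below) for the case where no vocoder has been set up yet. The overall flow, with a hypothetical stand-in loader:

    # Illustrative dispatch only; the real loaders live in init_vocoder above.
    bigvgan_model = None  # module-level vocoder handle, as in api.py

    def init_vocoder(version: str) -> None:
        global bigvgan_model
        bigvgan_model = {"v3": "BigVGAN @ 24 kHz", "v4": "Generator @ 48 kHz"}.get(version)

    def decode_mels(version: str):
        if bigvgan_model is None:  # lazy fallback mirroring get_tts_wav
            init_vocoder(version)
        return bigvgan_model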
@@ -345,7 +387,7 @@ def get_sovits_weights(sovits_path):
     else:
         vq_model = vq_model.to(device)
     vq_model.eval()
-    if if_lora_v3 == False or model_version != "v4":
+    if model_version not in {"v3", "v4"}:
         vq_model.load_state_dict(dict_s2["weight"], strict=False)
     else:
         if model_version == "v4":
@@ -763,7 +805,7 @@ def get_tts_wav(
         prompt_semantic = codes[0, 0]
         prompt = prompt_semantic.unsqueeze(0).to(device)
 
-    if version != "v3" and version != "v4":
+    if version not in {"v3", "v4"}:
         refers = []
         if inp_refs:
             for path in inp_refs:
@@ -814,8 +856,7 @@ def get_tts_wav(
         )
         pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
         t3 = ttime()
-        if version != "v3" and version != "v4":
-
+        if version not in {"v3", "v4"}:
             audio = (
                 vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)
                 .detach()
@@ -834,16 +875,18 @@ def get_tts_wav(
             if sr != 24000:
                 ref_audio = resample(ref_audio, sr)
             # print("ref_audio",ref_audio.abs().mean())
-            mel2 = mel_fn(ref_audio)
+            mel2 = mel_fn_v4(ref_audio) if version == "v4" else mel_fn(ref_audio)
             mel2 = norm_spec(mel2)
             T_min = min(mel2.shape[2], fea_ref.shape[2])
             mel2 = mel2[:, :, :T_min]
             fea_ref = fea_ref[:, :, :T_min]
-            if T_min > 468:
-                mel2 = mel2[:, :, -468:]
-                fea_ref = fea_ref[:, :, -468:]
-                T_min = 468
-            chunk_len = 934 - T_min
+            T_ref = 500 if version == "v4" else 468
+            T_chunk = 1000 if version == "v4" else 934
+            if T_min > T_ref:
+                mel2 = mel2[:, :, -T_ref:]
+                fea_ref = fea_ref[:, :, -T_ref:]
+                T_min = T_ref
+            chunk_len = T_chunk - T_min
             # print("fea_ref",fea_ref,fea_ref.shape)
             # print("mel2",mel2)
             mel2 = mel2.to(dtype)
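Note (reviewer sketch, not part of the commit): 468/934 and 500/1000 are frame budgets for the reference window and the CFM chunk; at each version's mel frame rate they cover roughly the same audio span, so the change keeps timing behavior while matching the v4 mel front end:

    # ~5 s reference window and ~10 s chunk in both versions (illustrative only).
    for ver, fps, T_ref, T_chunk in [("v3", 93.75, 468, 934), ("v4", 100.0, 500, 1000)]:
        print(ver, round(T_ref / fps, 2), round(T_chunk / fps, 2))  # 4.99/9.96 vs 5.0/10.0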
@@ -871,7 +914,8 @@ def get_tts_wav(
             cmf_res = torch.cat(cfm_resss, 2)
             cmf_res = denorm_spec(cmf_res)
             if bigvgan_model == None:
-                init_bigvgan()
+                # init_bigvgan()
+                init_vocoder(version)
             with torch.inference_mode():
                 wav_gen = bigvgan_model(cmf_res)
             audio = wav_gen[0][0].cpu().detach().numpy()
@@ -905,7 +949,7 @@ def get_tts_wav(
     if not stream_mode == "normal":
        if media_type == "wav":
            sr = 48000 if if_sr else 24000
-           sr = hps.data.sampling_rate if version != "v3" and version != "v4" else sr
+           sr = hps.data.sampling_rate if version != "v3" else sr
            audio_bytes = pack_wav(audio_bytes, sr)
        yield audio_bytes.getvalue()