diff --git a/api.py b/api.py index 4b79d501..9dcbdb0c 100644 --- a/api.py +++ b/api.py @@ -313,9 +313,11 @@ def get_sovits_weights(sovits_path): if model_version == "v3": hps.model.version = "v3" + if model_version == "v4": + hps.model.version = "v4" model_params_dict = vars(hps.model) - if model_version != "v3": + if model_version != "v3" and model_version != "v4": vq_model = SynthesizerTrn( hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, @@ -759,7 +761,7 @@ def get_tts_wav( prompt_semantic = codes[0, 0] prompt = prompt_semantic.unsqueeze(0).to(device) - if version != "v3": + if version != "v3" and version != "v4": refers = [] if inp_refs: for path in inp_refs: @@ -811,7 +813,7 @@ def get_tts_wav( pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) t3 = ttime() - if version != "v3": + if version != "v3" and version != "v4": audio = ( vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed) .detach()