diff --git a/api.py b/api.py index 4b79d501..aa9a6668 100644 --- a/api.py +++ b/api.py @@ -176,9 +176,9 @@ import subprocess class DefaultRefer: def __init__(self, path, text, language): - self.path = args.default_refer_path - self.text = args.default_refer_text - self.language = args.default_refer_language + self.path = path + self.text = text + self.language = language def is_ready(self) -> bool: return is_full(self.path, self.text, self.language) @@ -294,11 +294,13 @@ from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new def get_sovits_weights(sovits_path): path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth" - is_exist_s2gv3 = os.path.exists(path_sovits_v3) + path_sovits_v4 = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth" version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path) - if if_lora_v3 == True and is_exist_s2gv3 == False: + if if_lora_v3 == True and not os.path.exists(path_sovits_v3): logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重") + if model_version == "v4" and not os.path.exists(path_sovits_v4): + logger.info("SoVITS V4 底模缺失,无法加载相应 LoRA 权重") dict_s2 = load_sovits_new(sovits_path) hps = dict_s2["config"] @@ -310,12 +312,13 @@ def get_sovits_weights(sovits_path): hps.model.version = "v1" else: hps.model.version = "v2" - if model_version == "v3": hps.model.version = "v3" + if model_version == "v4": + hps.model.version = "v4" model_params_dict = vars(hps.model) - if model_version != "v3": + if model_version != "v3" and model_version != "v4": vq_model = SynthesizerTrn( hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, @@ -342,10 +345,13 @@ def get_sovits_weights(sovits_path): else: vq_model = vq_model.to(device) vq_model.eval() - if if_lora_v3 == False: + if if_lora_v3 == False or model_version != "v4": vq_model.load_state_dict(dict_s2["weight"], strict=False) else: - vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False) + if model_version == "v4": + vq_model.load_state_dict(load_sovits_new(path_sovits_v4)["weight"], strict=False) + else: + vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False) lora_rank = dict_s2["lora_rank"] lora_config = LoraConfig( target_modules=["to_k", "to_q", "to_v", "to_out.0"], @@ -394,13 +400,11 @@ def change_gpt_sovits_weights(gpt_path, sovits_path): try: gpt = get_gpt_weights(gpt_path) sovits = get_sovits_weights(sovits_path) + speaker_list["default"] = Speaker(name="default", gpt=gpt, sovits=sovits) + return JSONResponse({"code": 0, "message": "Success"}, status_code=200) except Exception as e: return JSONResponse({"code": 400, "message": str(e)}, status_code=400) - - speaker_list["default"] = Speaker(name="default", gpt=gpt, sovits=sovits) - return JSONResponse({"code": 0, "message": "Success"}, status_code=200) - - + def get_bert_feature(text, word2ph): with torch.no_grad(): inputs = tokenizer(text, return_tensors="pt") @@ -759,7 +763,7 @@ def get_tts_wav( prompt_semantic = codes[0, 0] prompt = prompt_semantic.unsqueeze(0).to(device) - if version != "v3": + if version != "v3" and version != "v4": refers = [] if inp_refs: for path in inp_refs: @@ -811,7 +815,7 @@ def get_tts_wav( pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) t3 = ttime() - if version != "v3": + if version != "v3" and version != "v4": audio = ( vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed) .detach() @@ -880,7 +884,7 @@ def get_tts_wav( audio_opt = np.concatenate(audio_opt, 0) t4 = ttime() - sr = hps.data.sampling_rate if version != "v3" else 24000 + sr = hps.data.sampling_rate if version != "v3" and version != "v4" else 24000 if if_sr and sr == 24000: audio_opt = torch.from_numpy(audio_opt).float().to(device) audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr) @@ -901,7 +905,7 @@ def get_tts_wav( if not stream_mode == "normal": if media_type == "wav": sr = 48000 if if_sr else 24000 - sr = hps.data.sampling_rate if version != "v3" else sr + sr = hps.data.sampling_rate if version != "v3" and version != "v4" else sr audio_bytes = pack_wav(audio_bytes, sr) yield audio_bytes.getvalue()