diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index 3924ff65..b925a382 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -214,7 +214,7 @@ def resample(audio_tensor, sr0): ###todo:put them to process_ckpt and modify my_save func (save sovits weights), gpt save weights use my_save in process_ckpt #symbol_version-model_version-if_lora_v3 from process_ckpt import get_sovits_version_from_path_fast,load_sovits_new -def change_sovits_weights(sovits_path,prompt_language=None,text_language=None): +def change_sovits_weights(sovits_path,prompt_language=None,text_language=None, cat=None): print("the sovits model is updated!") global vq_model, hps, version, model_version, dict_language,if_lora_v3 version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path) @@ -254,7 +254,7 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None): else: hps.model.version = "v2" version=hps.model.version - # print("sovits版本:",hps.model.version) + print("sovits版本:",hps.model.version) if model_version!="v3": vq_model = SynthesizerTrn( hps.data.filter_length // 2 + 1, @@ -279,6 +279,7 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None): else: vq_model = vq_model.to(device) vq_model.eval() + if if_lora_v3==False: print("loading sovits_%s"%model_version,vq_model.load_state_dict(dict_s2["weight"], strict=False)) else: