diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 04287312..8b432b44 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -469,9 +469,13 @@ class TTS: if self.configs.is_half and str(self.configs.device) != "cpu": self.bert_model = self.bert_model.half() - def init_vits_weights(self, weights_path: str, vocoder_path: str = None): + def init_vits_weights(self, weights_path: str, vocoder_path: str = None, model_version: str = None): self.configs.vits_weights_path = weights_path - version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(weights_path) + if model_version: + version, _, if_lora_v3 = get_sovits_version_from_path_fast(weights_path) + else: + version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(weights_path) + path_sovits = self.configs.default_configs[model_version]["vits_weights_path"] print(if_lora_v3)