From 6972d02444372b493e221d5c99a8ed0efda3d3d6 Mon Sep 17 00:00:00 2001 From: ChasonJiang <1440499136@qq.com> Date: Mon, 26 Aug 2024 16:03:04 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96tts=5Fconfig?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/TTS_infer_pack/TTS.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index c677d77..fcb8fb8 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -213,6 +213,10 @@ class TTS_Config: "cnhuhbert_base_path": self.cnhuhbert_base_path, } return self.config + + def update_version(self, version:str)->None: + self.version = version + self.languages = self.v2_languages if self.version=="v2" else self.v1_languages def __str__(self): self.configs = self.update_configs() @@ -304,9 +308,11 @@ class TTS: dict_s2 = torch.load(weights_path, map_location=self.configs.device) hps = dict_s2["config"] if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322: - self.configs.version = "v1" + self.configs.update_version("v1") else: - self.configs.version = "v2" + self.configs.update_version("v2") + self.configs.save_configs() + hps["model"]["version"] = self.configs.version self.configs.filter_length = hps["data"]["filter_length"] self.configs.segment_size = hps["train"]["segment_size"]