From f35f6e9b5e529f7bd3c6db36d163a97715ce5f24 Mon Sep 17 00:00:00 2001 From: ChasonJiang <46401978+ChasonJiang@users.noreply.github.com> Date: Thu, 29 Aug 2024 00:33:07 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96tts=5Fconfig=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E9=80=BB=E8=BE=91=20(#1538)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 优化tts_config * fix * 优化报错提示 * 优化报错提示 --- GPT_SoVITS/TTS_infer_pack/TTS.py | 11 ++++++++--- api_v2.py | 6 +++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index c677d77..a1eeb28 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -213,6 +213,10 @@ class TTS_Config: "cnhuhbert_base_path": self.cnhuhbert_base_path, } return self.config + + def update_version(self, version:str)->None: + self.version = version + self.languages = self.v2_languages if self.version=="v2" else self.v1_languages def __str__(self): self.configs = self.update_configs() @@ -300,13 +304,14 @@ class TTS: def init_vits_weights(self, weights_path: str): print(f"Loading VITS weights from {weights_path}") self.configs.vits_weights_path = weights_path - self.configs.save_configs() dict_s2 = torch.load(weights_path, map_location=self.configs.device) hps = dict_s2["config"] if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322: - self.configs.version = "v1" + self.configs.update_version("v1") else: - self.configs.version = "v2" + self.configs.update_version("v2") + self.configs.save_configs() + hps["model"]["version"] = self.configs.version self.configs.filter_length = hps["data"]["filter_length"] self.configs.segment_size = hps["train"]["segment_size"] diff --git a/api_v2.py b/api_v2.py index a9faaeb..ea1d0c7 100644 --- a/api_v2.py +++ b/api_v2.py @@ -253,13 +253,13 @@ def check_params(req:dict): if (text_lang in [None, ""]) : return JSONResponse(status_code=400, content={"message": "text_lang is required"}) elif text_lang.lower() not in tts_config.languages: - return JSONResponse(status_code=400, content={"message": "text_lang is not supported"}) + return JSONResponse(status_code=400, content={"message": f"text_lang: {text_lang} is not supported in version {tts_config.version}"}) if (prompt_lang in [None, ""]) : return JSONResponse(status_code=400, content={"message": "prompt_lang is required"}) elif prompt_lang.lower() not in tts_config.languages: - return JSONResponse(status_code=400, content={"message": "prompt_lang is not supported"}) + return JSONResponse(status_code=400, content={"message": f"prompt_lang: {prompt_lang} is not supported in version {tts_config.version}"}) if media_type not in ["wav", "raw", "ogg", "aac"]: - return JSONResponse(status_code=400, content={"message": "media_type is not supported"}) + return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"}) elif media_type == "ogg" and not streaming_mode: return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})