优化TTS_Config的代码逻辑

This commit is contained in:
ChasonJiang 2025-07-18 11:38:13 +08:00
parent cefafee32c
commit 5cf55434eb
3 changed files with 12 additions and 10 deletions

View File

@ -304,10 +304,10 @@ class TTS_Config:
configs: dict = self._load_configs(self.configs_path) configs: dict = self._load_configs(self.configs_path)
assert isinstance(configs, dict) assert isinstance(configs, dict)
version = configs.get("version", "v2").lower() configs_ = deepcopy(self.default_configs)
assert version in ["v1", "v2", "v3", "v4", "v2pro", "v2proplus"] configs_.update(configs)
self.default_configs[version] = configs.get(version, self.default_configs[version]) self.configs: dict = configs_.get("custom", configs_["v2"])
self.configs: dict = configs.get("custom", deepcopy(self.default_configs[version])) self.default_configs = deepcopy(configs_)
self.device = self.configs.get("device", torch.device("cpu")) self.device = self.configs.get("device", torch.device("cpu"))
if "cuda" in str(self.device) and not torch.cuda.is_available(): if "cuda" in str(self.device) and not torch.cuda.is_available():
@ -315,11 +315,13 @@ class TTS_Config:
self.device = torch.device("cpu") self.device = torch.device("cpu")
self.is_half = self.configs.get("is_half", False) self.is_half = self.configs.get("is_half", False)
# if str(self.device) == "cpu" and self.is_half: if str(self.device) == "cpu" and self.is_half:
# print(f"Warning: Half precision is not supported on CPU, set is_half to False.") print(f"Warning: Half precision is not supported on CPU, set is_half to False.")
# self.is_half = False self.is_half = False
version = self.configs.get("version", None)
self.version = version self.version = version
assert self.version in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"], "Invalid version!"
self.t2s_weights_path = self.configs.get("t2s_weights_path", None) self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
self.vits_weights_path = self.configs.get("vits_weights_path", None) self.vits_weights_path = self.configs.get("vits_weights_path", None)
self.bert_base_path = self.configs.get("bert_base_path", None) self.bert_base_path = self.configs.get("bert_base_path", None)

View File

@ -1,4 +1,3 @@
version: v2ProPlus
custom: custom:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base

View File

@ -125,7 +125,8 @@ is_exist_s2gv4 = os.path.exists(path_sovits_v4)
tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml") tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
tts_config.device = device tts_config.device = device
tts_config.is_half = is_half tts_config.is_half = is_half
tts_config.version = version # tts_config.version = version
tts_config.update_version(version)
if gpt_path is not None: if gpt_path is not None:
if "" in gpt_path or "!" in gpt_path: if "" in gpt_path or "!" in gpt_path:
gpt_path = name2gpt_path[gpt_path] gpt_path = name2gpt_path[gpt_path]