mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-08-07 15:33:29 +08:00
优化TTS_Config的代码逻辑 (#2536)
* 优化TTS_Config的代码逻辑 * 在载入vits权重之后保存tts_config
This commit is contained in:
parent
cefafee32c
commit
b9211657d8
@ -304,10 +304,10 @@ class TTS_Config:
|
||||
configs: dict = self._load_configs(self.configs_path)
|
||||
|
||||
assert isinstance(configs, dict)
|
||||
version = configs.get("version", "v2").lower()
|
||||
assert version in ["v1", "v2", "v3", "v4", "v2pro", "v2proplus"]
|
||||
self.default_configs[version] = configs.get(version, self.default_configs[version])
|
||||
self.configs: dict = configs.get("custom", deepcopy(self.default_configs[version]))
|
||||
configs_ = deepcopy(self.default_configs)
|
||||
configs_.update(configs)
|
||||
self.configs: dict = configs_.get("custom", configs_["v2"])
|
||||
self.default_configs = deepcopy(configs_)
|
||||
|
||||
self.device = self.configs.get("device", torch.device("cpu"))
|
||||
if "cuda" in str(self.device) and not torch.cuda.is_available():
|
||||
@ -315,11 +315,13 @@ class TTS_Config:
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
self.is_half = self.configs.get("is_half", False)
|
||||
# if str(self.device) == "cpu" and self.is_half:
|
||||
# print(f"Warning: Half precision is not supported on CPU, set is_half to False.")
|
||||
# self.is_half = False
|
||||
if str(self.device) == "cpu" and self.is_half:
|
||||
print(f"Warning: Half precision is not supported on CPU, set is_half to False.")
|
||||
self.is_half = False
|
||||
|
||||
version = self.configs.get("version", None)
|
||||
self.version = version
|
||||
assert self.version in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"], "Invalid version!"
|
||||
self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
|
||||
self.vits_weights_path = self.configs.get("vits_weights_path", None)
|
||||
self.bert_base_path = self.configs.get("bert_base_path", None)
|
||||
@ -576,6 +578,10 @@ class TTS:
|
||||
if self.configs.is_half and str(self.configs.device) != "cpu":
|
||||
self.vits_model = self.vits_model.half()
|
||||
|
||||
self.configs.save_configs()
|
||||
|
||||
|
||||
|
||||
def init_t2s_weights(self, weights_path: str):
|
||||
print(f"Loading Text2Semantic weights from {weights_path}")
|
||||
self.configs.t2s_weights_path = weights_path
|
||||
|
@ -1,4 +1,3 @@
|
||||
version: v2ProPlus
|
||||
custom:
|
||||
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
|
||||
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
|
||||
|
@ -125,7 +125,8 @@ is_exist_s2gv4 = os.path.exists(path_sovits_v4)
|
||||
tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
|
||||
tts_config.device = device
|
||||
tts_config.is_half = is_half
|
||||
tts_config.version = version
|
||||
# tts_config.version = version
|
||||
tts_config.update_version(version)
|
||||
if gpt_path is not None:
|
||||
if "!" in gpt_path or "!" in gpt_path:
|
||||
gpt_path = name2gpt_path[gpt_path]
|
||||
|
Loading…
x
Reference in New Issue
Block a user