mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-06-04 07:09:17 +08:00
增加了TTS_Config类的健壮性
This commit is contained in:
parent
bfd7286068
commit
511b99e4a9
@ -1,3 +1,4 @@
|
||||
from copy import deepcopy
|
||||
import math
|
||||
import os, sys
|
||||
import random
|
||||
@ -50,18 +51,7 @@ custom:
|
||||
|
||||
|
||||
class TTS_Config:
|
||||
def __init__(self, configs: Union[dict, str]):
|
||||
configs_base_path:str = "GPT_SoVITS/configs/"
|
||||
os.makedirs(configs_base_path, exist_ok=True)
|
||||
self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml")
|
||||
if isinstance(configs, str):
|
||||
self.configs_path = configs
|
||||
configs:dict = self._load_configs(configs)
|
||||
|
||||
# assert isinstance(configs, dict)
|
||||
self.default_configs:dict = configs.get("default", None)
|
||||
if self.default_configs is None:
|
||||
self.default_configs={
|
||||
default_configs={
|
||||
"device": "cpu",
|
||||
"is_half": False,
|
||||
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
|
||||
@ -70,15 +60,54 @@ class TTS_Config:
|
||||
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
||||
"flash_attn_enabled": True
|
||||
}
|
||||
self.configs:dict = configs.get("custom", self.default_configs)
|
||||
configs:dict = None
|
||||
def __init__(self, configs: Union[dict, str]=None):
|
||||
|
||||
self.device = self.configs.get("device")
|
||||
self.is_half = self.configs.get("is_half")
|
||||
self.t2s_weights_path = self.configs.get("t2s_weights_path")
|
||||
self.vits_weights_path = self.configs.get("vits_weights_path")
|
||||
self.bert_base_path = self.configs.get("bert_base_path")
|
||||
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path")
|
||||
self.flash_attn_enabled = self.configs.get("flash_attn_enabled")
|
||||
# Set the default configuration file path
|
||||
configs_base_path:str = "GPT_SoVITS/configs/"
|
||||
os.makedirs(configs_base_path, exist_ok=True)
|
||||
self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml")
|
||||
|
||||
if configs in ["", None]:
|
||||
if not os.path.exists(self.configs_path):
|
||||
self.save_configs()
|
||||
print(f"Create default config file at {self.configs_path}")
|
||||
configs:dict = {"default": deepcopy(self.default_configs)}
|
||||
|
||||
if isinstance(configs, str):
|
||||
self.configs_path = configs
|
||||
configs:dict = self._load_configs(self.configs_path)
|
||||
|
||||
assert isinstance(configs, dict)
|
||||
default_configs:dict = configs.get("default", None)
|
||||
if default_configs is not None:
|
||||
self.default_configs = default_configs
|
||||
|
||||
self.configs:dict = configs.get("custom", deepcopy(self.default_configs))
|
||||
|
||||
|
||||
self.device = self.configs.get("device", torch.device("cpu"))
|
||||
self.is_half = self.configs.get("is_half", False)
|
||||
self.flash_attn_enabled = self.configs.get("flash_attn_enabled", True)
|
||||
self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
|
||||
self.vits_weights_path = self.configs.get("vits_weights_path", None)
|
||||
self.bert_base_path = self.configs.get("bert_base_path", None)
|
||||
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
|
||||
|
||||
|
||||
if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
|
||||
self.t2s_weights_path = self.default_configs['t2s_weights_path']
|
||||
print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}")
|
||||
if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)):
|
||||
self.vits_weights_path = self.default_configs['vits_weights_path']
|
||||
print(f"fall back to default vits_weights_path: {self.vits_weights_path}")
|
||||
if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)):
|
||||
self.bert_base_path = self.default_configs['bert_base_path']
|
||||
print(f"fall back to default bert_base_path: {self.bert_base_path}")
|
||||
if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)):
|
||||
self.cnhuhbert_base_path = self.default_configs['cnhuhbert_base_path']
|
||||
print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}")
|
||||
self.update_configs()
|
||||
|
||||
|
||||
self.max_sec = None
|
||||
@ -92,7 +121,7 @@ class TTS_Config:
|
||||
self.n_speakers:int = 300
|
||||
|
||||
self.langauges:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
|
||||
print(self)
|
||||
# print(self)
|
||||
|
||||
def _load_configs(self, configs_path: str)->dict:
|
||||
with open(configs_path, 'r') as f:
|
||||
@ -102,24 +131,18 @@ class TTS_Config:
|
||||
|
||||
def save_configs(self, configs_path:str=None)->None:
|
||||
configs={
|
||||
"default": {
|
||||
"device": "cpu",
|
||||
"is_half": False,
|
||||
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
|
||||
"vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
|
||||
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
|
||||
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
||||
"flash_attn_enabled": True
|
||||
},
|
||||
"custom": self.update_configs()
|
||||
"default":self.default_configs,
|
||||
}
|
||||
if self.configs is not None:
|
||||
configs["custom"] = self.update_configs()
|
||||
|
||||
if configs_path is None:
|
||||
configs_path = self.configs_path
|
||||
with open(configs_path, 'w') as f:
|
||||
yaml.dump(configs, f)
|
||||
|
||||
def update_configs(self):
|
||||
config = {
|
||||
self.config = {
|
||||
"device" : str(self.device),
|
||||
"is_half" : self.is_half,
|
||||
"t2s_weights_path" : self.t2s_weights_path,
|
||||
@ -128,7 +151,7 @@ class TTS_Config:
|
||||
"cnhuhbert_base_path": self.cnhuhbert_base_path,
|
||||
"flash_attn_enabled" : self.flash_attn_enabled
|
||||
}
|
||||
return config
|
||||
return self.config
|
||||
|
||||
def __str__(self):
|
||||
self.configs = self.update_configs()
|
||||
@ -137,6 +160,9 @@ class TTS_Config:
|
||||
string += f"{str(k).ljust(20)}: {str(v)}\n"
|
||||
string += "-" * 100 + '\n'
|
||||
return string
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
class TTS:
|
||||
@ -253,7 +279,7 @@ class TTS:
|
||||
enable: bool, whether to enable half precision.
|
||||
|
||||
'''
|
||||
if self.configs.device == "cpu":
|
||||
if self.configs.device == "cpu" and enable:
|
||||
print("Half precision is not supported on CPU.")
|
||||
return
|
||||
|
||||
|
@ -80,6 +80,7 @@ if cnhubert_base_path is not None:
|
||||
if bert_path is not None:
|
||||
tts_config.bert_base_path = bert_path
|
||||
|
||||
print(tts_config)
|
||||
tts_pipline = TTS(tts_config)
|
||||
gpt_path = tts_config.t2s_weights_path
|
||||
sovits_path = tts_config.vits_weights_path
|
||||
|
Loading…
x
Reference in New Issue
Block a user