增加了TTS_Config类的健壮性

This commit is contained in:
chasonjiang 2024-03-12 15:30:08 +08:00
parent bfd7286068
commit 511b99e4a9
2 changed files with 61 additions and 34 deletions

View File

@ -1,3 +1,4 @@
from copy import deepcopy
import math
import os, sys
import random
@ -50,18 +51,7 @@ custom:
class TTS_Config:
def __init__(self, configs: Union[dict, str]):
configs_base_path:str = "GPT_SoVITS/configs/"
os.makedirs(configs_base_path, exist_ok=True)
self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml")
if isinstance(configs, str):
self.configs_path = configs
configs:dict = self._load_configs(configs)
# assert isinstance(configs, dict)
self.default_configs:dict = configs.get("default", None)
if self.default_configs is None:
self.default_configs={
default_configs={
"device": "cpu",
"is_half": False,
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
@ -70,15 +60,54 @@ class TTS_Config:
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
"flash_attn_enabled": True
}
self.configs:dict = configs.get("custom", self.default_configs)
configs:dict = None
def __init__(self, configs: Union[dict, str]=None):
self.device = self.configs.get("device")
self.is_half = self.configs.get("is_half")
self.t2s_weights_path = self.configs.get("t2s_weights_path")
self.vits_weights_path = self.configs.get("vits_weights_path")
self.bert_base_path = self.configs.get("bert_base_path")
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path")
self.flash_attn_enabled = self.configs.get("flash_attn_enabled")
# 设置默认配置文件路径
configs_base_path:str = "GPT_SoVITS/configs/"
os.makedirs(configs_base_path, exist_ok=True)
self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml")
if configs in ["", None]:
if not os.path.exists(self.configs_path):
self.save_configs()
print(f"Create default config file at {self.configs_path}")
configs:dict = {"default": deepcopy(self.default_configs)}
if isinstance(configs, str):
self.configs_path = configs
configs:dict = self._load_configs(self.configs_path)
assert isinstance(configs, dict)
default_configs:dict = configs.get("default", None)
if default_configs is not None:
self.default_configs = default_configs
self.configs:dict = configs.get("custom", deepcopy(self.default_configs))
self.device = self.configs.get("device", torch.device("cpu"))
self.is_half = self.configs.get("is_half", False)
self.flash_attn_enabled = self.configs.get("flash_attn_enabled", True)
self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
self.vits_weights_path = self.configs.get("vits_weights_path", None)
self.bert_base_path = self.configs.get("bert_base_path", None)
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
self.t2s_weights_path = self.default_configs['t2s_weights_path']
print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}")
if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)):
self.vits_weights_path = self.default_configs['vits_weights_path']
print(f"fall back to default vits_weights_path: {self.vits_weights_path}")
if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)):
self.bert_base_path = self.default_configs['bert_base_path']
print(f"fall back to default bert_base_path: {self.bert_base_path}")
if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)):
self.cnhuhbert_base_path = self.default_configs['cnhuhbert_base_path']
print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}")
self.update_configs()
self.max_sec = None
@ -92,7 +121,7 @@ class TTS_Config:
self.n_speakers:int = 300
self.langauges:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
print(self)
# print(self)
def _load_configs(self, configs_path: str)->dict:
with open(configs_path, 'r') as f:
@ -102,24 +131,18 @@ class TTS_Config:
def save_configs(self, configs_path:str=None)->None:
configs={
"default": {
"device": "cpu",
"is_half": False,
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
"vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
"flash_attn_enabled": True
},
"custom": self.update_configs()
"default":self.default_configs,
}
if self.configs is not None:
configs["custom"] = self.update_configs()
if configs_path is None:
configs_path = self.configs_path
with open(configs_path, 'w') as f:
yaml.dump(configs, f)
def update_configs(self):
config = {
self.config = {
"device" : str(self.device),
"is_half" : self.is_half,
"t2s_weights_path" : self.t2s_weights_path,
@ -128,7 +151,7 @@ class TTS_Config:
"cnhuhbert_base_path": self.cnhuhbert_base_path,
"flash_attn_enabled" : self.flash_attn_enabled
}
return config
return self.config
def __str__(self):
self.configs = self.update_configs()
@ -137,6 +160,9 @@ class TTS_Config:
string += f"{str(k).ljust(20)}: {str(v)}\n"
string += "-" * 100 + '\n'
return string
def __repr__(self):
return self.__str__()
class TTS:
@ -253,7 +279,7 @@ class TTS:
enable: bool, whether to enable half precision.
'''
if self.configs.device == "cpu":
if self.configs.device == "cpu" and enable:
print("Half precision is not supported on CPU.")
return

View File

@ -80,6 +80,7 @@ if cnhubert_base_path is not None:
if bert_path is not None:
tts_config.bert_base_path = bert_path
print(tts_config)
tts_pipline = TTS(tts_config)
gpt_path = tts_config.t2s_weights_path
sovits_path = tts_config.vits_weights_path