mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-08-23 21:19:47 +08:00
增加了TTS_Config类的健壮性
This commit is contained in:
parent
bfd7286068
commit
511b99e4a9
@ -1,3 +1,4 @@
|
|||||||
|
from copy import deepcopy
|
||||||
import math
|
import math
|
||||||
import os, sys
|
import os, sys
|
||||||
import random
|
import random
|
||||||
@ -50,18 +51,7 @@ custom:
|
|||||||
|
|
||||||
|
|
||||||
class TTS_Config:
|
class TTS_Config:
|
||||||
def __init__(self, configs: Union[dict, str]):
|
default_configs={
|
||||||
configs_base_path:str = "GPT_SoVITS/configs/"
|
|
||||||
os.makedirs(configs_base_path, exist_ok=True)
|
|
||||||
self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml")
|
|
||||||
if isinstance(configs, str):
|
|
||||||
self.configs_path = configs
|
|
||||||
configs:dict = self._load_configs(configs)
|
|
||||||
|
|
||||||
# assert isinstance(configs, dict)
|
|
||||||
self.default_configs:dict = configs.get("default", None)
|
|
||||||
if self.default_configs is None:
|
|
||||||
self.default_configs={
|
|
||||||
"device": "cpu",
|
"device": "cpu",
|
||||||
"is_half": False,
|
"is_half": False,
|
||||||
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
|
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
|
||||||
@ -70,15 +60,54 @@ class TTS_Config:
|
|||||||
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
||||||
"flash_attn_enabled": True
|
"flash_attn_enabled": True
|
||||||
}
|
}
|
||||||
self.configs:dict = configs.get("custom", self.default_configs)
|
configs:dict = None
|
||||||
|
def __init__(self, configs: Union[dict, str]=None):
|
||||||
|
|
||||||
self.device = self.configs.get("device")
|
# 设置默认配置文件路径
|
||||||
self.is_half = self.configs.get("is_half")
|
configs_base_path:str = "GPT_SoVITS/configs/"
|
||||||
self.t2s_weights_path = self.configs.get("t2s_weights_path")
|
os.makedirs(configs_base_path, exist_ok=True)
|
||||||
self.vits_weights_path = self.configs.get("vits_weights_path")
|
self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml")
|
||||||
self.bert_base_path = self.configs.get("bert_base_path")
|
|
||||||
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path")
|
if configs in ["", None]:
|
||||||
self.flash_attn_enabled = self.configs.get("flash_attn_enabled")
|
if not os.path.exists(self.configs_path):
|
||||||
|
self.save_configs()
|
||||||
|
print(f"Create default config file at {self.configs_path}")
|
||||||
|
configs:dict = {"default": deepcopy(self.default_configs)}
|
||||||
|
|
||||||
|
if isinstance(configs, str):
|
||||||
|
self.configs_path = configs
|
||||||
|
configs:dict = self._load_configs(self.configs_path)
|
||||||
|
|
||||||
|
assert isinstance(configs, dict)
|
||||||
|
default_configs:dict = configs.get("default", None)
|
||||||
|
if default_configs is not None:
|
||||||
|
self.default_configs = default_configs
|
||||||
|
|
||||||
|
self.configs:dict = configs.get("custom", deepcopy(self.default_configs))
|
||||||
|
|
||||||
|
|
||||||
|
self.device = self.configs.get("device", torch.device("cpu"))
|
||||||
|
self.is_half = self.configs.get("is_half", False)
|
||||||
|
self.flash_attn_enabled = self.configs.get("flash_attn_enabled", True)
|
||||||
|
self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
|
||||||
|
self.vits_weights_path = self.configs.get("vits_weights_path", None)
|
||||||
|
self.bert_base_path = self.configs.get("bert_base_path", None)
|
||||||
|
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
|
||||||
|
|
||||||
|
|
||||||
|
if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
|
||||||
|
self.t2s_weights_path = self.default_configs['t2s_weights_path']
|
||||||
|
print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}")
|
||||||
|
if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)):
|
||||||
|
self.vits_weights_path = self.default_configs['vits_weights_path']
|
||||||
|
print(f"fall back to default vits_weights_path: {self.vits_weights_path}")
|
||||||
|
if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)):
|
||||||
|
self.bert_base_path = self.default_configs['bert_base_path']
|
||||||
|
print(f"fall back to default bert_base_path: {self.bert_base_path}")
|
||||||
|
if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)):
|
||||||
|
self.cnhuhbert_base_path = self.default_configs['cnhuhbert_base_path']
|
||||||
|
print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}")
|
||||||
|
self.update_configs()
|
||||||
|
|
||||||
|
|
||||||
self.max_sec = None
|
self.max_sec = None
|
||||||
@ -92,7 +121,7 @@ class TTS_Config:
|
|||||||
self.n_speakers:int = 300
|
self.n_speakers:int = 300
|
||||||
|
|
||||||
self.langauges:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
|
self.langauges:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
|
||||||
print(self)
|
# print(self)
|
||||||
|
|
||||||
def _load_configs(self, configs_path: str)->dict:
|
def _load_configs(self, configs_path: str)->dict:
|
||||||
with open(configs_path, 'r') as f:
|
with open(configs_path, 'r') as f:
|
||||||
@ -102,24 +131,18 @@ class TTS_Config:
|
|||||||
|
|
||||||
def save_configs(self, configs_path:str=None)->None:
|
def save_configs(self, configs_path:str=None)->None:
|
||||||
configs={
|
configs={
|
||||||
"default": {
|
"default":self.default_configs,
|
||||||
"device": "cpu",
|
|
||||||
"is_half": False,
|
|
||||||
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
|
|
||||||
"vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
|
|
||||||
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
|
|
||||||
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
|
||||||
"flash_attn_enabled": True
|
|
||||||
},
|
|
||||||
"custom": self.update_configs()
|
|
||||||
}
|
}
|
||||||
|
if self.configs is not None:
|
||||||
|
configs["custom"] = self.update_configs()
|
||||||
|
|
||||||
if configs_path is None:
|
if configs_path is None:
|
||||||
configs_path = self.configs_path
|
configs_path = self.configs_path
|
||||||
with open(configs_path, 'w') as f:
|
with open(configs_path, 'w') as f:
|
||||||
yaml.dump(configs, f)
|
yaml.dump(configs, f)
|
||||||
|
|
||||||
def update_configs(self):
|
def update_configs(self):
|
||||||
config = {
|
self.config = {
|
||||||
"device" : str(self.device),
|
"device" : str(self.device),
|
||||||
"is_half" : self.is_half,
|
"is_half" : self.is_half,
|
||||||
"t2s_weights_path" : self.t2s_weights_path,
|
"t2s_weights_path" : self.t2s_weights_path,
|
||||||
@ -128,7 +151,7 @@ class TTS_Config:
|
|||||||
"cnhuhbert_base_path": self.cnhuhbert_base_path,
|
"cnhuhbert_base_path": self.cnhuhbert_base_path,
|
||||||
"flash_attn_enabled" : self.flash_attn_enabled
|
"flash_attn_enabled" : self.flash_attn_enabled
|
||||||
}
|
}
|
||||||
return config
|
return self.config
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
self.configs = self.update_configs()
|
self.configs = self.update_configs()
|
||||||
@ -137,6 +160,9 @@ class TTS_Config:
|
|||||||
string += f"{str(k).ljust(20)}: {str(v)}\n"
|
string += f"{str(k).ljust(20)}: {str(v)}\n"
|
||||||
string += "-" * 100 + '\n'
|
string += "-" * 100 + '\n'
|
||||||
return string
|
return string
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return self.__str__()
|
||||||
|
|
||||||
|
|
||||||
class TTS:
|
class TTS:
|
||||||
@ -253,7 +279,7 @@ class TTS:
|
|||||||
enable: bool, whether to enable half precision.
|
enable: bool, whether to enable half precision.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
if self.configs.device == "cpu":
|
if self.configs.device == "cpu" and enable:
|
||||||
print("Half precision is not supported on CPU.")
|
print("Half precision is not supported on CPU.")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -80,6 +80,7 @@ if cnhubert_base_path is not None:
|
|||||||
if bert_path is not None:
|
if bert_path is not None:
|
||||||
tts_config.bert_base_path = bert_path
|
tts_config.bert_base_path = bert_path
|
||||||
|
|
||||||
|
print(tts_config)
|
||||||
tts_pipline = TTS(tts_config)
|
tts_pipline = TTS(tts_config)
|
||||||
gpt_path = tts_config.t2s_weights_path
|
gpt_path = tts_config.t2s_weights_path
|
||||||
sovits_path = tts_config.vits_weights_path
|
sovits_path = tts_config.vits_weights_path
|
||||||
|
Loading…
x
Reference in New Issue
Block a user