Merge branch 'fast_inference' of https://github.com/SapphireLab/GPT-SoVITS-M into fast_inference_

This commit is contained in:
chasonjiang 2024-03-11 17:19:22 +08:00
commit 38dca77477
2 changed files with 18 additions and 21 deletions

View File

@ -100,7 +100,6 @@ class TTS_Config:
return configs return configs
def save_configs(self, configs_path:str=None)->None: def save_configs(self, configs_path:str=None)->None:
configs={ configs={
"default": { "default": {
@ -112,7 +111,15 @@ class TTS_Config:
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
"flash_attn_enabled": True "flash_attn_enabled": True
}, },
"custom": { "custom": self.update_configs()
}
if configs_path is None:
configs_path = self.configs_path
with open(configs_path, 'w') as f:
yaml.dump(configs, f)
def update_configs(self):
config = {
"device" : str(self.device), "device" : str(self.device),
"is_half" : self.is_half, "is_half" : self.is_half,
"t2s_weights_path" : self.t2s_weights_path, "t2s_weights_path" : self.t2s_weights_path,
@ -121,23 +128,14 @@ class TTS_Config:
"cnhuhbert_base_path": self.cnhuhbert_base_path, "cnhuhbert_base_path": self.cnhuhbert_base_path,
"flash_attn_enabled" : self.flash_attn_enabled "flash_attn_enabled" : self.flash_attn_enabled
} }
} return config
if configs_path is None:
configs_path = self.configs_path
with open(configs_path, 'w') as f:
yaml.dump(configs, f)
def __str__(self): def __str__(self):
string = "----------------TTS Config--------------\n" self.configs = self.update_configs()
string += "device: {}\n".format(self.device) string = "TTS Config".center(100, '-') + '\n'
string += "is_half: {}\n".format(self.is_half) for k, v in self.configs.items():
string += "flash_attn_enabled: {}\n".format(self.flash_attn_enabled) string += f"{str(k).ljust(20)}: {str(v)}\n"
string += "bert_base_path: {}\n".format(self.bert_base_path) string += "-" * 100 + '\n'
string += "t2s_weights_path: {}\n".format(self.t2s_weights_path)
string += "vits_weights_path: {}\n".format(self.vits_weights_path)
string += "cnhuhbert_base_path: {}\n".format(self.cnhuhbert_base_path)
string += "----------------------------------------\n"
return string return string

View File

@ -158,7 +158,6 @@ class TextPreprocessor:
bert_feature = torch.cat(bert_feature_list, dim=1) bert_feature = torch.cat(bert_feature_list, dim=1)
# phones = sum(phones_list, []) # phones = sum(phones_list, [])
norm_text = ''.join(norm_text_list) norm_text = ''.join(norm_text_list)
return phones_list, bert_feature, norm_text return phones_list, bert_feature, norm_text