修复了热切换模型时,精度不匹配导致的错误。

This commit is contained in:
chasonjiang 2024-03-12 16:08:50 +08:00
parent 511b99e4a9
commit 345f3203f8

View File

@ -206,7 +206,7 @@ class TTS:
self.init_vits_weights(self.configs.vits_weights_path)
self.init_bert_weights(self.configs.bert_base_path)
self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path)
self.enable_half_precision(self.configs.is_half)
# self.enable_half_precision(self.configs.is_half)
@ -215,6 +215,8 @@ class TTS:
self.cnhuhbert_model = CNHubert(base_path)
self.cnhuhbert_model=self.cnhuhbert_model.eval()
self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
if self.configs.is_half:
self.cnhuhbert_model = self.cnhuhbert_model.half()
@ -224,6 +226,8 @@ class TTS:
self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
self.bert_model=self.bert_model.eval()
self.bert_model = self.bert_model.to(self.configs.device)
if self.configs.is_half:
self.bert_model = self.bert_model.half()
@ -255,6 +259,8 @@ class TTS:
vits_model = vits_model.eval()
vits_model.load_state_dict(dict_s2["weight"], strict=False)
self.vits_model = vits_model
if self.configs.is_half:
self.vits_model = self.vits_model.half()
def init_t2s_weights(self, weights_path: str):
@ -271,6 +277,8 @@ class TTS:
t2s_model = t2s_model.to(self.configs.device)
t2s_model = t2s_model.eval()
self.t2s_model = t2s_model
if self.configs.is_half:
self.t2s_model = self.t2s_model.half()
def enable_half_precision(self, enable: bool = True):
'''