增加健壮性,防止在cpu推理时设置半精度报错

This commit is contained in:
chasonjiang 2024-03-14 11:24:10 +08:00
parent 85dea3008b
commit b8ce03fd1b

View File

@ -228,7 +228,7 @@ class TTS:
self.cnhuhbert_model = CNHubert(base_path)
self.cnhuhbert_model=self.cnhuhbert_model.eval()
self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
if self.configs.is_half:
if self.configs.is_half and str(self.configs.device)!="cpu":
self.cnhuhbert_model = self.cnhuhbert_model.half()
@ -239,7 +239,7 @@ class TTS:
self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
self.bert_model=self.bert_model.eval()
self.bert_model = self.bert_model.to(self.configs.device)
if self.configs.is_half:
if self.configs.is_half and str(self.configs.device)!="cpu":
self.bert_model = self.bert_model.half()
@ -272,7 +272,7 @@ class TTS:
vits_model = vits_model.eval()
vits_model.load_state_dict(dict_s2["weight"], strict=False)
self.vits_model = vits_model
if self.configs.is_half:
if self.configs.is_half and str(self.configs.device)!="cpu":
self.vits_model = self.vits_model.half()
@ -290,7 +290,7 @@ class TTS:
t2s_model = t2s_model.to(self.configs.device)
t2s_model = t2s_model.eval()
self.t2s_model = t2s_model
if self.configs.is_half:
if self.configs.is_half and str(self.configs.device)!="cpu":
self.t2s_model = self.t2s_model.half()
def enable_half_precision(self, enable: bool = True):
@ -300,7 +300,7 @@ class TTS:
enable: bool, whether to enable half precision.
'''
if self.configs.device == "cpu" and enable:
if str(self.configs.device) == "cpu" and enable:
print("Half precision is not supported on CPU.")
return