mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-05 19:41:56 +08:00
增加健壮性,防止在cpu推理时设置半精度报错
This commit is contained in:
parent
85dea3008b
commit
b8ce03fd1b
@ -228,7 +228,7 @@ class TTS:
|
||||
self.cnhuhbert_model = CNHubert(base_path)
|
||||
self.cnhuhbert_model=self.cnhuhbert_model.eval()
|
||||
self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
|
||||
if self.configs.is_half:
|
||||
if self.configs.is_half and str(self.configs.device)!="cpu":
|
||||
self.cnhuhbert_model = self.cnhuhbert_model.half()
|
||||
|
||||
|
||||
@ -239,7 +239,7 @@ class TTS:
|
||||
self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
|
||||
self.bert_model=self.bert_model.eval()
|
||||
self.bert_model = self.bert_model.to(self.configs.device)
|
||||
if self.configs.is_half:
|
||||
if self.configs.is_half and str(self.configs.device)!="cpu":
|
||||
self.bert_model = self.bert_model.half()
|
||||
|
||||
|
||||
@ -272,7 +272,7 @@ class TTS:
|
||||
vits_model = vits_model.eval()
|
||||
vits_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||
self.vits_model = vits_model
|
||||
if self.configs.is_half:
|
||||
if self.configs.is_half and str(self.configs.device)!="cpu":
|
||||
self.vits_model = self.vits_model.half()
|
||||
|
||||
|
||||
@ -290,7 +290,7 @@ class TTS:
|
||||
t2s_model = t2s_model.to(self.configs.device)
|
||||
t2s_model = t2s_model.eval()
|
||||
self.t2s_model = t2s_model
|
||||
if self.configs.is_half:
|
||||
if self.configs.is_half and str(self.configs.device)!="cpu":
|
||||
self.t2s_model = self.t2s_model.half()
|
||||
|
||||
def enable_half_precision(self, enable: bool = True):
|
||||
@ -300,7 +300,7 @@ class TTS:
|
||||
enable: bool, whether to enable half precision.
|
||||
|
||||
'''
|
||||
if self.configs.device == "cpu" and enable:
|
||||
if str(self.configs.device) == "cpu" and enable:
|
||||
print("Half precision is not supported on CPU.")
|
||||
return
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user