mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-06 03:57:44 +08:00
增加健壮性,防止在cpu推理时设置半精度报错
This commit is contained in:
parent
85dea3008b
commit
b8ce03fd1b
@ -228,7 +228,7 @@ class TTS:
|
|||||||
self.cnhuhbert_model = CNHubert(base_path)
|
self.cnhuhbert_model = CNHubert(base_path)
|
||||||
self.cnhuhbert_model=self.cnhuhbert_model.eval()
|
self.cnhuhbert_model=self.cnhuhbert_model.eval()
|
||||||
self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
|
self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
|
||||||
if self.configs.is_half:
|
if self.configs.is_half and str(self.configs.device)!="cpu":
|
||||||
self.cnhuhbert_model = self.cnhuhbert_model.half()
|
self.cnhuhbert_model = self.cnhuhbert_model.half()
|
||||||
|
|
||||||
|
|
||||||
@ -239,7 +239,7 @@ class TTS:
|
|||||||
self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
|
self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
|
||||||
self.bert_model=self.bert_model.eval()
|
self.bert_model=self.bert_model.eval()
|
||||||
self.bert_model = self.bert_model.to(self.configs.device)
|
self.bert_model = self.bert_model.to(self.configs.device)
|
||||||
if self.configs.is_half:
|
if self.configs.is_half and str(self.configs.device)!="cpu":
|
||||||
self.bert_model = self.bert_model.half()
|
self.bert_model = self.bert_model.half()
|
||||||
|
|
||||||
|
|
||||||
@ -272,7 +272,7 @@ class TTS:
|
|||||||
vits_model = vits_model.eval()
|
vits_model = vits_model.eval()
|
||||||
vits_model.load_state_dict(dict_s2["weight"], strict=False)
|
vits_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||||
self.vits_model = vits_model
|
self.vits_model = vits_model
|
||||||
if self.configs.is_half:
|
if self.configs.is_half and str(self.configs.device)!="cpu":
|
||||||
self.vits_model = self.vits_model.half()
|
self.vits_model = self.vits_model.half()
|
||||||
|
|
||||||
|
|
||||||
@ -290,7 +290,7 @@ class TTS:
|
|||||||
t2s_model = t2s_model.to(self.configs.device)
|
t2s_model = t2s_model.to(self.configs.device)
|
||||||
t2s_model = t2s_model.eval()
|
t2s_model = t2s_model.eval()
|
||||||
self.t2s_model = t2s_model
|
self.t2s_model = t2s_model
|
||||||
if self.configs.is_half:
|
if self.configs.is_half and str(self.configs.device)!="cpu":
|
||||||
self.t2s_model = self.t2s_model.half()
|
self.t2s_model = self.t2s_model.half()
|
||||||
|
|
||||||
def enable_half_precision(self, enable: bool = True):
|
def enable_half_precision(self, enable: bool = True):
|
||||||
@ -300,7 +300,7 @@ class TTS:
|
|||||||
enable: bool, whether to enable half precision.
|
enable: bool, whether to enable half precision.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
if self.configs.device == "cpu" and enable:
|
if str(self.configs.device) == "cpu" and enable:
|
||||||
print("Half precision is not supported on CPU.")
|
print("Half precision is not supported on CPU.")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user