diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 694d4a7..61ba7be 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -206,7 +206,7 @@ class TTS: self.init_vits_weights(self.configs.vits_weights_path) self.init_bert_weights(self.configs.bert_base_path) self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path) - self.enable_half_precision(self.configs.is_half) + # self.enable_half_precision(self.configs.is_half) @@ -215,6 +215,8 @@ class TTS: self.cnhuhbert_model = CNHubert(base_path) self.cnhuhbert_model=self.cnhuhbert_model.eval() self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device) + if self.configs.is_half: + self.cnhuhbert_model = self.cnhuhbert_model.half() @@ -224,6 +226,8 @@ class TTS: self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path) self.bert_model=self.bert_model.eval() self.bert_model = self.bert_model.to(self.configs.device) + if self.configs.is_half: + self.bert_model = self.bert_model.half() @@ -255,6 +259,8 @@ class TTS: vits_model = vits_model.eval() vits_model.load_state_dict(dict_s2["weight"], strict=False) self.vits_model = vits_model + if self.configs.is_half: + self.vits_model = self.vits_model.half() def init_t2s_weights(self, weights_path: str): @@ -271,6 +277,8 @@ class TTS: t2s_model = t2s_model.to(self.configs.device) t2s_model = t2s_model.eval() self.t2s_model = t2s_model + if self.configs.is_half: + self.t2s_model = self.t2s_model.half() def enable_half_precision(self, enable: bool = True): '''