ChasonJiang 2025-04-21 23:12:44 +08:00
parent 4bd2625e7c
commit cbc872a28b


@@ -656,8 +656,8 @@ class TTS:
                 self.bert_model = self.bert_model.half()
             if self.cnhuhbert_model is not None:
                 self.cnhuhbert_model = self.cnhuhbert_model.half()
-            if self.bigvgan_model is not None:
-                self.bigvgan_model = self.bigvgan_model.half()
+            if self.vocoder is not None:
+                self.vocoder = self.vocoder.half()
         else:
             if self.t2s_model is not None:
                 self.t2s_model = self.t2s_model.float()
@@ -667,8 +667,8 @@ class TTS:
                 self.bert_model = self.bert_model.float()
             if self.cnhuhbert_model is not None:
                 self.cnhuhbert_model = self.cnhuhbert_model.float()
-            if self.bigvgan_model is not None:
-                self.bigvgan_model = self.bigvgan_model.float()
+            if self.vocoder is not None:
+                self.vocoder = self.vocoder.float()

     def set_device(self, device: torch.device, save: bool = True):
         """
@@ -687,8 +687,8 @@ class TTS:
             self.bert_model = self.bert_model.to(device)
         if self.cnhuhbert_model is not None:
             self.cnhuhbert_model = self.cnhuhbert_model.to(device)
-        if self.bigvgan_model is not None:
-            self.bigvgan_model = self.bigvgan_model.to(device)
+        if self.vocoder is not None:
+            self.vocoder = self.vocoder.to(device)
         if self.sr_model is not None:
             self.sr_model = self.sr_model.to(device)
@@ -1358,7 +1358,7 @@ class TTS:
         return sr, audio

-    def useing_vocoder_synthesis(
+    def using_vocoder_synthesis(
         self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32
     ):
         prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
@@ -1420,7 +1420,7 @@ class TTS:
         return audio

-    def useing_vocoder_synthesis_batched_infer(
+    def using_vocoder_synthesis_batched_infer(
         self,
         idx_list: List[int],
         semantic_tokens_list: List[torch.Tensor],
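
For context, a minimal usage sketch of the renamed single-sample method, based only on the signature visible in the hunk above; it is not part of this commit, and the `tts` instance, its prior prompt-cache setup, and the input tensors are assumptions.

# Hypothetical usage sketch: `tts` is a configured TTS instance whose prompt
# cache has already been filled from reference audio; semantic_tokens and
# phones stand in for real outputs of the upstream text-to-semantic and
# text-front-end steps (both torch.Tensor, per the annotated signature).
audio = tts.using_vocoder_synthesis(
    semantic_tokens=semantic_tokens,  # predicted semantic token tensor
    phones=phones,                    # phoneme id tensor for the target text
    speed=1.0,                        # default per the signature shown above
    sample_steps=32,                  # default per the signature shown above
)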