Fix AttributeError when prompt_cache['refer_spec'][0] is a tuple

tzrain 2025-06-05 05:26:23 +08:00
parent 31c0cdd640
commit 13f13b2e7f


@@ -1407,7 +1407,10 @@ class TTS:
     ):
         prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
         prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
-        refer_audio_spec = self.prompt_cache["refer_spec"][0].to(dtype=self.precision, device=self.configs.device)
+        raw_entry = self.prompt_cache["refer_spec"][0]
+        if isinstance(raw_entry, tuple):
+            raw_entry = raw_entry[0]
+        refer_audio_spec = raw_entry.to(dtype=self.precision, device=self.configs.device)
         fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
         ref_audio: torch.Tensor = self.prompt_cache["raw_audio"]
@@ -1474,7 +1477,10 @@ class TTS:
     ) -> List[torch.Tensor]:
         prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
         prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
-        refer_audio_spec = self.prompt_cache["refer_spec"][0].to(dtype=self.precision, device=self.configs.device)
+        raw_entry = self.prompt_cache["refer_spec"][0]
+        if isinstance(raw_entry, tuple):
+            raw_entry = raw_entry[0]
+        refer_audio_spec = raw_entry.to(dtype=self.precision, device=self.configs.device)
         fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
         ref_audio: torch.Tensor = self.prompt_cache["raw_audio"]
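For reference, a minimal standalone sketch of the unwrapping pattern this commit applies in both hunks. The unwrap_refer_spec helper and the cache entries below are illustrative only, not part of the actual TTS class; the sketch assumes a refer_spec entry may be either a bare spectrogram tensor or a tuple whose first element is that tensor.

import torch

def unwrap_refer_spec(entry):
    # Illustrative helper mirroring the inline fix: accept either a bare
    # spectrogram tensor or a tuple whose first element is the spectrogram.
    if isinstance(entry, tuple):
        entry = entry[0]
    return entry

# Hypothetical cache entries covering both layouts.
spec = torch.randn(1, 1025, 200)
for entry in (spec, (spec, torch.randn(1, 32000))):
    refer_audio_spec = unwrap_refer_spec(entry).to(dtype=torch.float16)
    print(refer_audio_spec.shape, refer_audio_spec.dtype)

Without the isinstance check, calling .to(...) directly on a tuple entry raises AttributeError: 'tuple' object has no attribute 'to', which is the failure the commit title describes.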