fix inference

This commit is contained in:
KakaruHayate 2025-02-12 22:49:46 +08:00
parent 23c9fb6311
commit 093304575e

View File

@ -1089,9 +1089,9 @@ class CFM(torch.nn.Module):
t_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * t
d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d
# v_pred = model(x, t_tensor, d_tensor, **extra_args)
v_pred = self.estimator(x, prompt_x, x_lens, t_tensor,d_tensor, mu,drop_audio_cond=False,drop_text=False).transpose(2, 1)
v_pred = self.estimator(x, prompt_x, x_lens, t_tensor,d_tensor, mu, use_grad_ckpt=False,drop_audio_cond=False,drop_text=False).transpose(2, 1)
if inference_cfg_rate>1e-5:
neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, drop_audio_cond=True, drop_text=True).transpose(2, 1)
neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1)
v_pred=v_pred+(v_pred-neg)*inference_cfg_rate
x = x + d * v_pred
t = t + d