diff --git a/sat/sgm/modules/diffusionmodules/loss.py b/sat/sgm/modules/diffusionmodules/loss.py index 589e441..a237e3a 100644 --- a/sat/sgm/modules/diffusionmodules/loss.py +++ b/sat/sgm/modules/diffusionmodules/loss.py @@ -121,7 +121,7 @@ class VideoDiffusionLoss(StandardDiffusionLoss): w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim) # v-pred if self.min_snr_value is not None: - w = min(w, self.min_snr_value) + w = torch.clamp(w, max=self.min_snr_value) return self.get_loss(model_output, input, w) def get_loss(self, model_output, target, w):