From 9f548bce116de35a2e658d5e6a18558df1a99ea6 Mon Sep 17 00:00:00 2001
From: Mr-Neutr0n <64578610+Mr-Neutr0n@users.noreply.github.com>
Date: Wed, 11 Feb 2026 18:00:28 +0530
Subject: [PATCH] Fix min-SNR loss weighting: replace min() with torch.clamp()
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

In VideoDiffusionLoss.__call__, the weight tensor `w` is a
multi-dimensional tensor produced by append_dims(). Using Python's
built-in min(w, self.min_snr_value) on a multi-dimensional tensor and a
scalar does not perform element-wise clamping — it either raises an
error or produces incorrect results depending on the tensor shape.

Replace with torch.clamp(w, max=self.min_snr_value) to correctly apply
element-wise upper-bound clamping, which is the intended behavior for
the min-SNR-gamma loss weighting strategy.
---
 sat/sgm/modules/diffusionmodules/loss.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sat/sgm/modules/diffusionmodules/loss.py b/sat/sgm/modules/diffusionmodules/loss.py
index 589e441..a237e3a 100644
--- a/sat/sgm/modules/diffusionmodules/loss.py
+++ b/sat/sgm/modules/diffusionmodules/loss.py
@@ -121,7 +121,7 @@ class VideoDiffusionLoss(StandardDiffusionLoss):
         w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim)  # v-pred
 
         if self.min_snr_value is not None:
-            w = min(w, self.min_snr_value)
+            w = torch.clamp(w, max=self.min_snr_value)
         return self.get_loss(model_output, input, w)
 
     def get_loss(self, model_output, target, w):