Fix min-SNR loss weighting: replace min() with torch.clamp()

In VideoDiffusionLoss.__call__, the weight tensor `w` is a
multi-dimensional tensor produced by append_dims(). Using Python's
built-in min(w, self.min_snr_value) on a multi-dimensional tensor
and a scalar does not perform element-wise clamping — it either
raises an error or produces incorrect results depending on the
tensor shape.

Replace with torch.clamp(w, max=self.min_snr_value) to correctly
apply element-wise upper-bound clamping, which is the intended
behavior for the min-SNR-gamma loss weighting strategy.
This commit is contained in:
Mr-Neutr0n 2026-02-11 18:00:28 +05:30
parent 7a1af71545
commit 9f548bce11

View File

@ -121,7 +121,7 @@ class VideoDiffusionLoss(StandardDiffusionLoss):
w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim) # v-pred
if self.min_snr_value is not None:
w = min(w, self.min_snr_value)
w = torch.clamp(w, max=self.min_snr_value)
return self.get_loss(model_output, input, w)
def get_loss(self, model_output, target, w):