mirror of
https://github.com/THUDM/CogVideo.git
synced 2026-05-06 22:58:13 +08:00
Fix min-SNR loss weighting: replace min() with torch.clamp()
In VideoDiffusionLoss.__call__, the weight tensor `w` is a multi-dimensional tensor produced by append_dims(). Using Python's built-in min(w, self.min_snr_value) on a multi-dimensional tensor and a scalar does not perform element-wise clamping — it either raises an error or produces incorrect results depending on the tensor shape. Replace with torch.clamp(w, max=self.min_snr_value) to correctly apply element-wise upper-bound clamping, which is the intended behavior for the min-SNR-gamma loss weighting strategy.
This commit is contained in:
parent
7a1af71545
commit
9f548bce11
@ -121,7 +121,7 @@ class VideoDiffusionLoss(StandardDiffusionLoss):
|
||||
w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim) # v-pred
|
||||
|
||||
if self.min_snr_value is not None:
|
||||
w = min(w, self.min_snr_value)
|
||||
w = torch.clamp(w, max=self.min_snr_value)
|
||||
return self.get_loss(model_output, input, w)
|
||||
|
||||
def get_loss(self, model_output, target, w):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user