mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
31 lines
871 B
Python
31 lines
871 B
Python
from abc import abstractmethod
|
|
from typing import Any, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from ....modules.distributions.distributions import DiagonalGaussianDistribution
|
|
from .base import AbstractRegularizer
|
|
|
|
|
|
class DiagonalGaussianRegularizer(AbstractRegularizer):
    """Regularizer that wraps a latent in a ``DiagonalGaussianDistribution``.

    ``forward`` turns the incoming tensor into a distribution object, replaces
    the latent with either a draw from it or its mode, and reports the
    distribution's KL term in a log dictionary.
    """

    def __init__(self, sample: bool = True):
        """
        Args:
            sample: if True, ``forward`` returns ``posterior.sample()``;
                if False, it returns ``posterior.mode()``.
        """
        super().__init__()
        self.sample = sample

    def get_trainable_parameters(self) -> Any:
        """Return an empty generator: this class registers no parameters itself."""
        yield from ()

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """Regularize ``z`` through a diagonal Gaussian.

        Args:
            z: tensor handed to ``DiagonalGaussianDistribution``. How it is
                interpreted (presumably split into mean / log-variance) is
                defined by that class — confirm there.

        Returns:
            A pair ``(latent, log)``:
            - ``latent`` is a sample or the mode of the distribution,
              depending on ``self.sample``;
            - ``log`` is ``{"kl_loss": <scalar tensor>}``.
        """
        posterior = DiagonalGaussianDistribution(z)

        # Draw (stochastic) or take the mode (deterministic) before computing KL,
        # keeping the same call order as before.
        latent = posterior.sample() if self.sample else posterior.mode()

        # Sum every element of the KL output, then divide by the size of its
        # first dimension (presumably the batch axis — TODO confirm).
        kl_terms = posterior.kl()
        leading_dim = kl_terms.shape[0]
        kl_loss = torch.sum(kl_terms) / leading_dim

        return latent, {"kl_loss": kl_loss}
|