from typing import Dict, Iterator, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torchvision
from einops import rearrange
from matplotlib import colormaps
from matplotlib import pyplot as plt

from ....util import default, instantiate_from_config
from ..lpips.loss.lpips import LPIPS
from ..lpips.model.model import weights_init
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss


class GeneralLPIPSWithDiscriminator(nn.Module):
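    """Reconstruction + LPIPS perceptual + patch-GAN loss for autoencoder training.

    ``forward`` is called once per optimizer: ``optimizer_idx == 0`` computes the
    autoencoder/generator objective (NLL + perceptual + adversarial + regularization
    terms), ``optimizer_idx == 1`` computes the discriminator objective.
    """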

    def __init__(
        self,
        disc_start: int,
        logvar_init: float = 0.0,
        disc_num_layers: int = 3,
        disc_in_channels: int = 3,
        disc_factor: float = 1.0,
        disc_weight: float = 1.0,
        perceptual_weight: float = 1.0,
        disc_loss: str = "hinge",
        scale_input_to_tgt_size: bool = False,
        dims: int = 2,
        learn_logvar: bool = False,
        regularization_weights: Union[None, Dict[str, float]] = None,
        additional_log_keys: Optional[List[str]] = None,
        discriminator_config: Optional[Dict] = None,
    ):
        super().__init__()
        self.dims = dims
        if self.dims > 2:
            print(
                f"running with dims={dims}. This means that for perceptual loss "
                "calculation, the LPIPS loss will be applied to each frame "
                "independently."
            )
        self.scale_input_to_tgt_size = scale_input_to_tgt_size
        assert disc_loss in ["hinge", "vanilla"]
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        # output log variance
        self.logvar = nn.Parameter(torch.full((), logvar_init), requires_grad=learn_logvar)
        self.learn_logvar = learn_logvar

        discriminator_config = default(
            discriminator_config,
            {
                "target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator",
                "params": {
                    "input_nc": disc_in_channels,
                    "n_layers": disc_num_layers,
                    "use_actnorm": False,
                },
            },
        )

        self.discriminator = instantiate_from_config(discriminator_config).apply(weights_init)
        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.regularization_weights = default(regularization_weights, {})

        self.forward_keys = [
            "optimizer_idx",
            "global_step",
            "last_layer",
            "split",
            "regularization_log",
        ]

        self.additional_log_keys = set(default(additional_log_keys, []))
        self.additional_log_keys.update(set(self.regularization_weights.keys()))

    def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
        return self.discriminator.parameters()

    def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
        if self.learn_logvar:
            yield self.logvar
        yield from ()

    @torch.no_grad()
    def log_images(
        self, inputs: torch.Tensor, reconstructions: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
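        """Visualise discriminator logits for real and reconstructed inputs.

        Returns colormapped logit grids (and a blend with the images) scaled to
        [-1, 1] for image logging; returns an empty dict for non-patch discriminators.
        """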
        # calc logits of real/fake
        logits_real = self.discriminator(inputs.contiguous().detach())
        if len(logits_real.shape) < 4:
            # Non patch-discriminator
            return dict()
        logits_fake = self.discriminator(reconstructions.contiguous().detach())
        # -> (b, 1, h, w)

        # parameters for colormapping
        high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
        cmap = colormaps["PiYG"]  # diverging colormap

        def to_colormap(logits: torch.Tensor) -> torch.Tensor:
            """(b, 1, ...) -> (b, 3, ...)"""
            logits = (logits + high) / (2 * high)
            logits_np = cmap(logits.cpu().numpy())[..., :3]  # truncate alpha channel
            # -> (b, 1, ..., 3)
            logits = torch.from_numpy(logits_np).to(logits.device)
            return rearrange(logits, "b 1 ... c -> b c ...")

        logits_real = torch.nn.functional.interpolate(
            logits_real,
            size=inputs.shape[-2:],
            mode="nearest",
            antialias=False,
        )
        logits_fake = torch.nn.functional.interpolate(
            logits_fake,
            size=reconstructions.shape[-2:],
            mode="nearest",
            antialias=False,
        )

        # alpha value of logits for overlay
        alpha_real = torch.abs(logits_real) / high
        alpha_fake = torch.abs(logits_fake) / high
        # -> (b, 1, h, w) in range [0, 1]
        # the alpha values of the grid lines don't really matter, since the values
        # are the same for both images and logits anyway
        grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
        grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
        grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)
        # -> (1, h, w)
        # blend logits and images together

        # prepare logits for plotting
        logits_real = to_colormap(logits_real)
        logits_fake = to_colormap(logits_fake)
        # resize logits
        # -> (b, 3, h, w)

        # make some grids
        # add all logits to one plot
        logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
        logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
        # I just love how torchvision calls the number of columns `nrow`
        grid_logits = torch.cat((logits_real, logits_fake), dim=1)
        # -> (3, h, w)

        grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)
        grid_images_fake = torchvision.utils.make_grid(0.5 * reconstructions + 0.5, nrow=4)
        grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)
        # -> (3, h, w) in range [0, 1]

        grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images

        # Create labeled colorbar
        dpi = 100
        height = 128 / dpi
        width = grid_logits.shape[2] / dpi
        fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
        img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
        plt.colorbar(
            img,
            cax=ax,
            orientation="horizontal",
            fraction=0.9,
            aspect=width / height,
            pad=0.0,
        )
        img.set_visible(False)
        fig.tight_layout()
        fig.canvas.draw()
        # manually convert figure to numpy
        cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,))
        cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0
        cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device)

        # Add colorbar to plot
        annotated_grid = torch.cat((grid_logits, cbar), dim=1)
        blended_grid = torch.cat((grid_blend, cbar), dim=1)
        return {
            "vis_logits": 2 * annotated_grid[None, ...] - 1,
            "vis_logits_blended": 2 * blended_grid[None, ...] - 1,
        }

    def calculate_adaptive_weight(
        self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor
    ) -> torch.Tensor:
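        """Balance the adversarial term against the reconstruction term.

        The weight is the ratio of the gradient norms of both losses w.r.t. the
        decoder's last layer, clamped to [0, 1e4] and scaled by ``disc_weight``,
        so the GAN gradient does not dominate the reconstruction gradient.
        """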
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(
        self,
        inputs: torch.Tensor,
        reconstructions: torch.Tensor,
        *,  # added because I changed the order here
        regularization_log: Dict[str, torch.Tensor],
        optimizer_idx: int,
        global_step: int,
        last_layer: torch.Tensor,
        split: str = "train",
        weights: Union[None, float, torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, dict]:
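        """Compute the loss for one optimizer.

        ``optimizer_idx == 0``: autoencoder/generator loss (weighted NLL + adaptive
        adversarial term + regularization terms). ``optimizer_idx == 1``:
        discriminator loss. Returns ``(loss, log_dict)``.
        """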
        if self.scale_input_to_tgt_size:
            inputs = torch.nn.functional.interpolate(
                inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
            )

        if self.dims > 2:
            inputs, reconstructions = map(
                lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
                (inputs, reconstructions),
            )

        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        # initialise the perceptual term so it can be logged even when it is disabled
        p_loss = torch.tensor(0.0, device=inputs.device)
        if self.perceptual_weight > 0:
            # draw one random index per sample (topk over Gaussian noise is a uniform choice)
            frame_indices = torch.randn((inputs.shape[0], inputs.shape[2])).topk(1, dim=-1).indices

            from sgm.modules.autoencoding.losses.video_loss import pick_video_frame

            input_frames = pick_video_frame(inputs, frame_indices)
            recon_frames = pick_video_frame(reconstructions, frame_indices)

            p_loss = self.perceptual_loss(input_frames.contiguous(), recon_frames.contiguous()).mean()
            rec_loss = rec_loss + self.perceptual_weight * p_loss

        nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if global_step >= self.discriminator_iter_start or not self.training:
                logits_fake = self.discriminator(reconstructions.contiguous())
                g_loss = -torch.mean(logits_fake)
                if self.training:
                    d_weight = self.calculate_adaptive_weight(
                        nll_loss, g_loss, last_layer=last_layer
                    )
                else:
                    d_weight = torch.tensor(1.0)
            else:
                d_weight = torch.tensor(0.0)
                g_loss = torch.tensor(0.0, requires_grad=True)

            loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss
            log = dict()
            for k in regularization_log:
                if k in self.regularization_weights:
                    loss = loss + self.regularization_weights[k] * regularization_log[k]
                if k in self.additional_log_keys:
                    log[f"{split}/{k}"] = regularization_log[k].detach().float().mean()

            log.update(
                {
                    f"{split}/loss/total": loss.clone().detach().mean(),
                    f"{split}/loss/nll": nll_loss.detach().mean(),
                    f"{split}/loss/rec": rec_loss.detach().mean(),
                    f"{split}/loss/percep": p_loss.detach().mean(),
                    f"{split}/loss/g": g_loss.detach().mean(),
                    f"{split}/scalars/logvar": self.logvar.detach(),
                    f"{split}/scalars/d_weight": d_weight.detach(),
                }
            )

            return loss, log
        elif optimizer_idx == 1:
            # second pass for discriminator update
            logits_real = self.discriminator(inputs.contiguous().detach())
            logits_fake = self.discriminator(reconstructions.contiguous().detach())

            if global_step >= self.discriminator_iter_start or not self.training:
                d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake)
            else:
                d_loss = torch.tensor(0.0, requires_grad=True)

            log = {
                f"{split}/loss/disc": d_loss.clone().detach().mean(),
                f"{split}/logits/real": logits_real.detach().mean(),
                f"{split}/logits/fake": logits_fake.detach().mean(),
            }
            return d_loss, log
        else:
            raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}")

    def get_nll_loss(
        self,
        rec_loss: torch.Tensor,
        weights: Optional[Union[float, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
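        """Gaussian negative log-likelihood with a learned (or fixed) log-variance.

        ``nll = rec_loss / exp(logvar) + logvar``; both the plain and the
        optionally re-weighted NLL are summed and divided by the batch size.
        """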
        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss
        if weights is not None:
            weighted_nll_loss = weights * nll_loss
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]

        return nll_loss, weighted_nll_loss
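
# Illustrative usage sketch (assumptions, not part of this module): in a
# Lightning-style training step the loss is called once per optimizer, with
# `decoder.conv_out.weight` standing in for whatever tensor serves as the
# autoencoder's last layer.
#
#   loss_fn = GeneralLPIPSWithDiscriminator(disc_start=50_000, dims=3)
#   ae_loss, log = loss_fn(
#       inputs, reconstructions,
#       regularization_log=reg_log, optimizer_idx=0, global_step=step,
#       last_layer=decoder.conv_out.weight, split="train",
#   )
#   disc_loss, log = loss_fn(
#       inputs, reconstructions,
#       regularization_log=reg_log, optimizer_idx=1, global_step=step,
#       last_layer=decoder.conv_out.weight, split="train",
#   )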