mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
737 lines
24 KiB
Python
737 lines
24 KiB
Python
from typing import Any, Union
|
|
from math import log2
|
|
from beartype import beartype
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch import Tensor
|
|
from torch.autograd import grad as torch_grad
|
|
from torch.cuda.amp import autocast
|
|
|
|
import torchvision
|
|
from torchvision.models import VGG16_Weights
|
|
from einops import rearrange, einsum, repeat
|
|
from einops.layers.torch import Rearrange
|
|
from kornia.filters import filter3d
|
|
|
|
from ..magvit2_pytorch import Residual, FeedForward, LinearSpaceAttention
|
|
from .lpips import LPIPS
|
|
|
|
from sgm.modules.autoencoding.vqvae.movq_enc_3d import CausalConv3d, DownSample3D
|
|
from sgm.util import instantiate_from_config
|
|
|
|
|
|
def exists(v):
    """Tell whether ``v`` holds a value, i.e. is not ``None``.

    Falsy values such as ``0`` or ``""`` still count as existing.
    """
    is_missing = v is None
    return not is_missing
|
|
|
|
|
|
def pair(t):
    """Return ``t`` unchanged if it is a tuple, otherwise duplicate it as ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)
|
|
|
|
|
|
def leaky_relu(p=0.1):
    """Build a fresh ``nn.LeakyReLU`` module whose negative slope is ``p``."""
    activation = nn.LeakyReLU(negative_slope=p)
    return activation
|
|
|
|
|
|
def hinge_discr_loss(fake, real):
    """Hinge loss for the discriminator.

    Penalises fake logits above -1 and real logits below +1, then averages
    the element-wise sum of both terms.
    """
    fake_term = F.relu(1 + fake)
    real_term = F.relu(1 - real)
    combined = fake_term + real_term
    return combined.mean()
|
|
|
|
|
|
def hinge_gen_loss(fake):
    """Hinge loss for the generator: the negated mean of the fake logits."""
    mean_logit = fake.mean()
    return -mean_logit
|
|
|
|
|
|
@autocast(enabled=False)
@beartype
def grad_layer_wrt_loss(loss: Tensor, layer: nn.Parameter):
    """Return the detached gradient of ``loss`` with respect to ``layer``.

    The call runs with CUDA autocast disabled and keeps the autograd graph
    alive (``retain_graph=True``) so ``loss`` can still be back-propagated
    afterwards.
    """
    seed = torch.ones_like(loss)
    (layer_grad,) = torch_grad(
        outputs=loss,
        inputs=layer,
        grad_outputs=seed,
        retain_graph=True,
    )
    return layer_grad.detach()
|
|
|
|
|
|
def pick_video_frame(video, frame_indices):
    """Select one frame per sample from a batch of videos.

    Args:
        video: tensor laid out as ``(batch, channels, frames, ...)``.
        frame_indices: index tensor used alongside a ``(batch, 1)`` batch
            index; the gathered result must have a singleton second axis
            (the caller in this file passes shape ``(batch, 1)``).

    Returns:
        Tensor laid out as ``(batch, channels, ...)``.
    """
    frames_first = rearrange(video, "b c f ... -> b f c ...")

    # Column vector of sample ids so it broadcasts against ``frame_indices``.
    sample_ids = torch.arange(video.shape[0], device=video.device)
    sample_ids = rearrange(sample_ids, "b -> b 1")

    picked = frames_first[sample_ids, frame_indices]
    return rearrange(picked, "b 1 c ... -> b c ...")
|
|
|
|
|
|
def gradient_penalty(images, output):
    """Penalise deviation of the per-sample input-gradient norm from 1.

    Computes ``d output / d images``, flattens it per sample (first axis is
    treated as the batch axis), and returns the batch mean of
    ``(||grad||_2 - 1) ** 2``.

    Args:
        images: tensor the discriminator was evaluated on. It must require
            grad and be part of the graph that produced ``output``.
        output: discriminator output computed from ``images``.

    Returns:
        Scalar tensor. The gradient is taken with ``create_graph=True`` so
        the penalty itself can be back-propagated.
    """
    # ``ones_like`` keeps dtype and device in lock-step with ``output``
    # (e.g. under mixed precision), instead of a float32 tensor placed on
    # the device of ``images``.
    (gradients,) = torch_grad(
        outputs=output,
        inputs=images,
        grad_outputs=torch.ones_like(output),
        create_graph=True,
        retain_graph=True,
    )

    # One flat gradient vector per sample.
    gradients = gradients.reshape(gradients.shape[0], -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()
|
|
|
|
|
|
# discriminator with anti-aliased downsampling (blurpool Zhang et al.)
|
|
|
|
|
|
class Blur(nn.Module):
    """Binomial ``[1, 2, 1]`` blur used for anti-aliased downsampling.

    The 1D taps are stored as the buffer ``f``. Depending on the flags passed
    to ``forward`` they are expanded to a spatial (``1 x 3 x 3``), temporal
    (``3 x 1 x 1``) or full (``3 x 3 x 3``) kernel and applied with
    ``filter3d(..., normalized=True)``.
    """

    def __init__(self):
        super().__init__()
        f = torch.Tensor([1, 2, 1])
        self.register_buffer("f", f)

    def _kernel(self, space_only=False, time_only=False):
        """Expand the 1D taps into the kernel handed to ``filter3d``.

        Returns a tensor of shape ``(1, 1, 3, 3)`` for ``space_only``,
        ``(1, 3, 1, 1)`` for ``time_only`` and ``(1, 3, 3, 3)`` otherwise.
        """
        f = self.f

        # NOTE: the outer products use ``torch.einsum``. The module-level
        # ``einsum`` in this file comes from einops, which expects the
        # pattern as the LAST argument, so the previous
        # ``einsum(pattern, f, f)`` call order could not work with it.
        if space_only:
            return torch.einsum("i,j->ij", f, f)[None, None]
        if time_only:
            return f[None, :, None, None]
        return torch.einsum("i,j,k->ijk", f, f, f)[None]

    def forward(self, x, space_only=False, time_only=False):
        """Blur ``x``.

        Args:
            x: 4D tensor (handled as ``b c h w`` images) or a tensor accepted
                by ``filter3d`` directly.
            space_only: blur height/width only.
            time_only: blur the frame axis only. Mutually exclusive with
                ``space_only``.

        Returns:
            Tensor with the same layout as ``x``.
        """
        assert not (space_only and time_only)

        f = self._kernel(space_only=space_only, time_only=time_only)

        # Images get a temporary singleton frame axis so the 3D filter applies.
        is_images = x.ndim == 4
        if is_images:
            x = rearrange(x, "b c h w -> b c 1 h w")

        out = filter3d(x, f, normalized=True)

        if is_images:
            out = rearrange(out, "b c 1 h w -> b c h w")

        return out
|
|
|
|
|
|
class DiscriminatorBlock(nn.Module):
    """2D residual discriminator block with optional 2x spatial downsampling.

    Main path: two 3x3 convolutions, each followed by a leaky ReLU, then
    (when ``downsample`` is set) an optional blur and a 2x2 space-to-depth
    rearrangement followed by a 1x1 convolution back to ``filters`` channels.
    Shortcut path: a 1x1 convolution, strided by 2 when downsampling.
    The two paths are summed and scaled by ``2 ** -0.5``.
    """

    def __init__(self, input_channels, filters, downsample=True, antialiased_downsample=True):
        super().__init__()

        shortcut_stride = 2 if downsample else 1
        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=shortcut_stride)

        conv_layers = [
            nn.Conv2d(input_channels, filters, 3, padding=1),
            leaky_relu(),
            nn.Conv2d(filters, filters, 3, padding=1),
            leaky_relu(),
        ]
        self.net = nn.Sequential(*conv_layers)

        if antialiased_downsample:
            self.maybe_blur = Blur()
        else:
            self.maybe_blur = None

        if downsample:
            # Fold every 2x2 spatial patch into channels, then mix back down.
            self.downsample = nn.Sequential(
                Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
                nn.Conv2d(filters * 4, filters, 1),
            )
        else:
            self.downsample = None

    def forward(self, x):
        shortcut = self.conv_res(x)
        hidden = self.net(x)

        if self.downsample is not None:
            # The blur is only applied when the block actually downsamples.
            if self.maybe_blur is not None:
                hidden = self.maybe_blur(hidden, space_only=True)
            hidden = self.downsample(hidden)

        return (hidden + shortcut) * (2**-0.5)
|
|
|
|
|
|
class Discriminator(nn.Module):
    """Image discriminator: residual conv blocks interleaved with linear attention.

    ``int(log2(min(image_size)) - 2) + 1`` stages are built. Every stage is a
    ``DiscriminatorBlock`` followed by a residual ``LinearSpaceAttention`` and
    a residual ``FeedForward``; all stages except the last downsample by 2.
    The final feature map is passed through a 3x3 conv, flattened, and
    projected to one logit per sample.

    Args:
        dim: base width; stage ``i`` uses ``min(dim * 4 * 2**i, max_dim)``
            channels.
        image_size: int or ``(height, width)`` tuple. The linear head is
            sized for exactly this resolution.
        channels: number of input channels.
        max_dim: upper bound on the per-stage channel count.
        attn_heads, attn_dim_head: accepted but not used by this class.
        linear_attn_dim_head, linear_attn_heads: forwarded to
            ``LinearSpaceAttention``.
        ff_mult: forwarded to ``FeedForward`` as ``mult``.
        antialiased_downsample: forwarded to every ``DiscriminatorBlock``.
    """

    @beartype
    def __init__(
        self,
        *,
        dim,
        image_size,
        channels=3,
        max_dim=512,
        attn_heads=8,
        attn_dim_head=32,
        linear_attn_dim_head=8,
        linear_attn_heads=16,
        ff_mult=4,
        antialiased_downsample=False,
    ):
        super().__init__()
        image_size = pair(image_size)
        min_image_resolution = min(image_size)

        num_layers = int(log2(min_image_resolution) - 2)

        layer_dims = [channels] + [(dim * 4) * (2**i) for i in range(num_layers + 1)]
        layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
        layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))

        blocks = []
        last_index = len(layer_dims_in_out) - 1

        for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
            block = DiscriminatorBlock(
                in_chan,
                out_chan,
                downsample=ind != last_index,
                antialiased_downsample=antialiased_downsample,
            )

            attn_block = nn.Sequential(
                Residual(
                    LinearSpaceAttention(
                        dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head
                    )
                ),
                Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
            )

            blocks.append(nn.ModuleList([block, attn_block]))

        self.blocks = nn.ModuleList(blocks)

        dim_last = layer_dims[-1]

        # ``num_layers`` of the ``num_layers + 1`` stages halve the resolution.
        downsample_factor = 2**num_layers
        last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size))

        latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last

        self.to_logits = nn.Sequential(
            nn.Conv2d(dim_last, dim_last, 3, padding=1),
            leaky_relu(),
            Rearrange("b ... -> b (...)"),
            nn.Linear(latent_dim, 1),
            Rearrange("b 1 -> b"),
        )

    def forward(self, x):
        """Return one logit per sample for the image batch ``x``."""
        for block, attn_block in self.blocks:
            x = block(x)
            x = attn_block(x)

        return self.to_logits(x)
|
|
|
|
|
|
class DiscriminatorBlock3D(nn.Module):
    """3D residual discriminator block that halves frames, height and width.

    Main path: two 3x3x3 convolutions with leaky ReLU, an optional spatial
    blur, then a 2x2x2 space-to-depth rearrangement followed by a 1x1x1
    convolution back to ``filters`` channels. Shortcut path: a 1x1x1
    convolution with stride 2. The sum is scaled by ``2 ** -0.5``.
    """

    def __init__(
        self,
        input_channels,
        filters,
        antialiased_downsample=True,
    ):
        super().__init__()
        self.conv_res = nn.Conv3d(input_channels, filters, 1, stride=2)

        conv_layers = [
            nn.Conv3d(input_channels, filters, 3, padding=1),
            leaky_relu(),
            nn.Conv3d(filters, filters, 3, padding=1),
            leaky_relu(),
        ]
        self.net = nn.Sequential(*conv_layers)

        if antialiased_downsample:
            self.maybe_blur = Blur()
        else:
            self.maybe_blur = None

        # Fold every 2x2x2 patch into channels, then mix back to ``filters``.
        patch_to_channels = Rearrange(
            "b c (f p1) (h p2) (w p3) -> b (c p1 p2 p3) f h w", p1=2, p2=2, p3=2
        )
        self.downsample = nn.Sequential(
            patch_to_channels,
            nn.Conv3d(filters * 8, filters, 1),
        )

    def forward(self, x):
        shortcut = self.conv_res(x)
        hidden = self.net(x)

        if self.downsample is not None:
            if self.maybe_blur is not None:
                # Only the spatial axes are blurred, not the frame axis.
                hidden = self.maybe_blur(hidden, space_only=True)
            hidden = self.downsample(hidden)

        return (hidden + shortcut) * (2**-0.5)
|
|
|
|
|
|
class DiscriminatorBlock3DWithfirstframe(nn.Module):
    """3D residual discriminator block built from causal convolutions.

    Main path: two ``CausalConv3d`` layers (kernel size 3) with leaky ReLU,
    an optional spatial blur, then a ``DownSample3D``. Shortcut path: a
    ``DownSample3D`` mapping ``input_channels`` to ``filters``. The sum is
    scaled by ``2 ** -0.5``.

    NOTE(review): both ``DownSample3D`` modules are created with
    ``compress_time=True``; presumably they reduce the frame axis while
    treating the first frame specially - confirm against ``movq_enc_3d``.
    """

    def __init__(
        self,
        input_channels,
        filters,
        antialiased_downsample=True,
        pad_mode="first",
    ):
        super().__init__()

        shortcut_cfg = dict(
            in_channels=input_channels,
            out_channels=filters,
            with_conv=True,
            compress_time=True,
        )
        self.downsample_res = DownSample3D(**shortcut_cfg)

        conv_layers = [
            CausalConv3d(input_channels, filters, kernel_size=3, pad_mode=pad_mode),
            leaky_relu(),
            CausalConv3d(filters, filters, kernel_size=3, pad_mode=pad_mode),
            leaky_relu(),
        ]
        self.net = nn.Sequential(*conv_layers)

        if antialiased_downsample:
            self.maybe_blur = Blur()
        else:
            self.maybe_blur = None

        main_cfg = dict(
            in_channels=filters,
            out_channels=filters,
            with_conv=True,
            compress_time=True,
        )
        self.downsample = DownSample3D(**main_cfg)

    def forward(self, x):
        shortcut = self.downsample_res(x)
        hidden = self.net(x)

        if self.downsample is not None:
            if self.maybe_blur is not None:
                # Only the spatial axes are blurred, not the frame axis.
                hidden = self.maybe_blur(hidden, space_only=True)
            hidden = self.downsample(hidden)

        return (hidden + shortcut) * (2**-0.5)
|
|
|
|
|
|
class Discriminator3D(nn.Module):
    """Video discriminator: 3D blocks first, then 2D blocks with linear attention.

    ``int(log2(min(image_size)) - 2) + 1`` stages are built. The first
    ``int(log2(frame_num))`` of them are ``DiscriminatorBlock3D`` (each
    halves frames, height and width). After the last 3D stage the frame axis
    is folded into the batch axis, and the remaining stages are 2D
    ``DiscriminatorBlock`` + residual ``LinearSpaceAttention`` / ``FeedForward``
    pairs. A conv + linear head produces one logit per folded sample.

    NOTE(review): if ``int(log2(frame_num))`` is 0 or exceeds the number of
    stages, the fold in ``forward`` is never reached and the 2D layers would
    receive a 5D tensor - confirm the supported ``frame_num`` range.

    Args:
        dim: base width; stage ``i`` uses ``min(dim * 4 * 2**i, max_dim)``
            channels.
        image_size: int or ``(height, width)`` tuple. The linear head is
            sized for exactly this resolution.
        frame_num: number of input frames; determines how many stages are 3D.
        channels: number of input channels.
        max_dim: upper bound on the per-stage channel count.
        linear_attn_dim_head, linear_attn_heads: forwarded to
            ``LinearSpaceAttention``.
        ff_mult: forwarded to ``FeedForward`` as ``mult``.
        antialiased_downsample: forwarded to every block.
    """

    @beartype
    def __init__(
        self,
        *,
        dim,
        image_size,
        frame_num,
        channels=3,
        max_dim=512,
        linear_attn_dim_head=8,
        linear_attn_heads=16,
        ff_mult=4,
        antialiased_downsample=False,
    ):
        super().__init__()
        image_size = pair(image_size)
        min_image_resolution = min(image_size)

        num_layers = int(log2(min_image_resolution) - 2)
        temporal_num_layers = int(log2(frame_num))
        self.temporal_num_layers = temporal_num_layers

        layer_dims = [channels] + [(dim * 4) * (2**i) for i in range(num_layers + 1)]
        layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
        layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))

        blocks = []
        last_index = len(layer_dims_in_out) - 1

        for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
            if ind < temporal_num_layers:
                # Spatio-temporal stage: stored bare, without attention.
                blocks.append(
                    DiscriminatorBlock3D(
                        in_chan,
                        out_chan,
                        antialiased_downsample=antialiased_downsample,
                    )
                )
                continue

            block = DiscriminatorBlock(
                in_chan,
                out_chan,
                downsample=ind != last_index,
                antialiased_downsample=antialiased_downsample,
            )
            attn_block = nn.Sequential(
                Residual(
                    LinearSpaceAttention(
                        dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head
                    )
                ),
                Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
            )

            blocks.append(nn.ModuleList([block, attn_block]))

        self.blocks = nn.ModuleList(blocks)

        dim_last = layer_dims[-1]

        downsample_factor = 2**num_layers
        last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size))

        latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last

        self.to_logits = nn.Sequential(
            nn.Conv2d(dim_last, dim_last, 3, padding=1),
            leaky_relu(),
            Rearrange("b ... -> b (...)"),
            nn.Linear(latent_dim, 1),
            Rearrange("b 1 -> b"),
        )

    def forward(self, x):
        """Return logits for the video batch ``x`` laid out as ``b c f h w``."""
        last_temporal = self.temporal_num_layers - 1

        for i, layer in enumerate(self.blocks):
            if i < self.temporal_num_layers:
                x = layer(x)
                if i == last_temporal:
                    # Remaining frames become extra batch entries for the 2D stages.
                    x = rearrange(x, "b c f h w -> (b f) c h w")
            else:
                block, attn_block = layer
                x = block(x)
                x = attn_block(x)

        return self.to_logits(x)
|
|
|
|
|
|
class Discriminator3DWithfirstframe(nn.Module):
    """Video discriminator using causal 3D blocks, then 2D blocks with attention.

    ``int(log2(min(image_size)) - 2) + 1`` stages are built. The first
    ``int(log2(frame_num))`` of them are ``DiscriminatorBlock3DWithfirstframe``.
    After the last 3D stage the frame axis is averaged away, and the remaining
    stages are 2D ``DiscriminatorBlock`` + residual ``LinearSpaceAttention`` /
    ``FeedForward`` pairs. A conv + linear head produces one logit per sample.

    NOTE(review): if ``int(log2(frame_num))`` is 0 or exceeds the number of
    stages, the frame reduction in ``forward`` is never reached and the 2D
    layers would receive a 5D tensor - confirm the supported ``frame_num``
    range.

    Args:
        dim: base width; stage ``i`` uses ``min(dim * 4 * 2**i, max_dim)``
            channels.
        image_size: int or ``(height, width)`` tuple. The linear head is
            sized for exactly this resolution.
        frame_num: number of input frames; determines how many stages are 3D.
        channels: number of input channels.
        max_dim: upper bound on the per-stage channel count.
        linear_attn_dim_head, linear_attn_heads: forwarded to
            ``LinearSpaceAttention``.
        ff_mult: forwarded to ``FeedForward`` as ``mult``.
        antialiased_downsample: forwarded to every block.
    """

    @beartype
    def __init__(
        self,
        *,
        dim,
        image_size,
        frame_num,
        channels=3,
        max_dim=512,
        linear_attn_dim_head=8,
        linear_attn_heads=16,
        ff_mult=4,
        antialiased_downsample=False,
    ):
        super().__init__()
        image_size = pair(image_size)
        min_image_resolution = min(image_size)

        num_layers = int(log2(min_image_resolution) - 2)
        temporal_num_layers = int(log2(frame_num))
        self.temporal_num_layers = temporal_num_layers

        layer_dims = [channels] + [(dim * 4) * (2**i) for i in range(num_layers + 1)]
        layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
        layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))

        blocks = []
        last_index = len(layer_dims_in_out) - 1

        for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
            if ind < temporal_num_layers:
                # Spatio-temporal stage: stored bare, without attention.
                blocks.append(
                    DiscriminatorBlock3DWithfirstframe(
                        in_chan,
                        out_chan,
                        antialiased_downsample=antialiased_downsample,
                    )
                )
                continue

            block = DiscriminatorBlock(
                in_chan,
                out_chan,
                downsample=ind != last_index,
                antialiased_downsample=antialiased_downsample,
            )
            attn_block = nn.Sequential(
                Residual(
                    LinearSpaceAttention(
                        dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head
                    )
                ),
                Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
            )

            blocks.append(nn.ModuleList([block, attn_block]))

        self.blocks = nn.ModuleList(blocks)

        dim_last = layer_dims[-1]

        downsample_factor = 2**num_layers
        last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size))

        latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last

        self.to_logits = nn.Sequential(
            nn.Conv2d(dim_last, dim_last, 3, padding=1),
            leaky_relu(),
            Rearrange("b ... -> b (...)"),
            nn.Linear(latent_dim, 1),
            Rearrange("b 1 -> b"),
        )

    def forward(self, x):
        """Return logits for the video batch ``x`` (frame axis at dim 2)."""
        last_temporal = self.temporal_num_layers - 1

        for i, layer in enumerate(self.blocks):
            if i < self.temporal_num_layers:
                x = layer(x)
                if i == last_temporal:
                    # Remaining frames are averaged into a single 2D feature
                    # map per sample rather than folded into the batch axis.
                    x = x.mean(dim=2)
            else:
                block, attn_block = layer
                x = block(x)
                x = attn_block(x)

        return self.to_logits(x)
|
|
|
|
|
|
class VideoAutoencoderLoss(nn.Module):
    """Training loss for a video autoencoder with an optional 3D discriminator.

    ``forward`` serves two optimizers:

    * ``optimizer_idx == 0`` (autoencoder): MSE reconstruction loss, plus an
      LPIPS perceptual loss on one random frame per sample, plus weighted
      auxiliary losses, plus a hinge generator loss from ``discr_3d``.
    * ``optimizer_idx == 1`` (discriminator): hinge discriminator loss on
      ``discr_3d`` plus an optional gradient penalty on the real inputs.

    Args:
        disc_start: first global step at which the adversarial term is added
            to the autoencoder loss.
        perceptual_weight: weight of the LPIPS term; LPIPS is only
            instantiated when this is positive.
        adversarial_loss_weight: weight of the generator hinge loss; 0
            disables the adversarial term for the autoencoder.
        multiscale_adversarial_loss_weight: stored, currently unused (the
            multiscale discriminator path is disabled).
        grad_penalty_loss_weight: weight of the gradient penalty; 0 skips it.
        quantizer_aux_loss_weight: weight applied to ``aux_losses``.
        vgg_weights: accepted for compatibility; not used.
        discr_kwargs: if given, keyword arguments for a 2D ``Discriminator``.
            Its parameters are reported as trainable but ``forward`` only
            uses ``discr_3d``.
        discr_3d_kwargs: if given, a config passed to
            ``instantiate_from_config`` to build the video discriminator.
    """

    def __init__(
        self,
        disc_start,
        perceptual_weight=1,
        adversarial_loss_weight=0,
        multiscale_adversarial_loss_weight=0,
        grad_penalty_loss_weight=0,
        quantizer_aux_loss_weight=0,
        vgg_weights=VGG16_Weights.DEFAULT,
        discr_kwargs=None,
        discr_3d_kwargs=None,
    ):
        super().__init__()

        self.disc_start = disc_start
        self.perceptual_weight = perceptual_weight
        self.adversarial_loss_weight = adversarial_loss_weight
        self.multiscale_adversarial_loss_weight = multiscale_adversarial_loss_weight
        self.grad_penalty_loss_weight = grad_penalty_loss_weight
        self.quantizer_aux_loss_weight = quantizer_aux_loss_weight

        if self.perceptual_weight > 0:
            self.perceptual_model = LPIPS().eval()

        if discr_kwargs is not None:
            self.discr = Discriminator(**discr_kwargs)
        else:
            self.discr = None

        if discr_3d_kwargs is not None:
            self.discr_3d = instantiate_from_config(discr_3d_kwargs)
        else:
            self.discr_3d = None

        # Shared constant for disabled loss terms; follows the module's device.
        self.register_buffer("zero", torch.tensor(0.0), persistent=False)

    def get_trainable_params(self) -> Any:
        """Return the parameters of every configured discriminator as a list."""
        params = []
        if self.discr is not None:
            params.extend(self.discr.parameters())
        if self.discr_3d is not None:
            params.extend(self.discr_3d.parameters())
        return params

    def get_trainable_parameters(self) -> Any:
        """Alias of :meth:`get_trainable_params`."""
        return self.get_trainable_params()

    def forward(
        self,
        inputs,
        reconstructions,
        optimizer_idx,
        global_step,
        aux_losses=None,
        last_layer=None,
        split="train",
    ):
        """Compute the loss for the optimizer selected by ``optimizer_idx``.

        Args:
            inputs: real videos; the first three axes are read as
                ``(batch, channels, frames)``.
            reconstructions: autoencoder outputs, same layout as ``inputs``.
            optimizer_idx: ``0`` for the autoencoder, ``1`` for the
                discriminator.
            global_step: current training step, compared with ``disc_start``.
            aux_losses: optional tensor of auxiliary losses (autoencoder
                step only); treated as zero when ``None``.
            last_layer: optional parameter used to compute the logged
                adaptive weight.
            split: prefix for the keys of the returned log dict.

        Returns:
            ``(total_loss, log)`` where ``log`` maps ``"<split>/<name>"`` to
            detached values.

        Raises:
            ValueError: if ``optimizer_idx`` is neither 0 nor 1.
        """
        batch, _, frames = inputs.shape[:3]

        if optimizer_idx == 0:
            return self._generator_step(
                inputs, reconstructions, global_step, aux_losses, last_layer, split, batch, frames
            )

        if optimizer_idx == 1:
            return self._discriminator_step(inputs, reconstructions, split)

        # Previously fell through and returned ``None``.
        raise ValueError(f"optimizer_idx must be 0 or 1, got {optimizer_idx!r}")

    def _generator_step(
        self, inputs, reconstructions, global_step, aux_losses, last_layer, split, batch, frames
    ):
        """Autoencoder loss: reconstruction + perceptual + aux + adversarial."""
        recon_loss = F.mse_loss(inputs, reconstructions)

        if self.perceptual_weight > 0:
            # One random frame per sample: the arg-max of Gaussian noise
            # drawn over the frame axis.
            frame_indices = torch.randn((batch, frames)).topk(1, dim=-1).indices

            input_frames = pick_video_frame(inputs, frame_indices)
            recon_frames = pick_video_frame(reconstructions, frame_indices)

            perceptual_loss = self.perceptual_model(
                input_frames.contiguous(), recon_frames.contiguous()
            ).mean()
        else:
            perceptual_loss = self.zero

        # The adversarial term is active from ``disc_start`` onward, while
        # training, with a non-zero weight. (The previous condition disabled
        # it for ``global_step >= disc_start``, i.e. exactly when it should
        # have been switched on.)
        adversarial_active = (
            global_step >= self.disc_start
            and self.training
            and self.adversarial_loss_weight != 0
        )

        if not adversarial_active:
            gen_loss = self.zero
            adaptive_weight = 0
        else:
            fake_logits = self.discr_3d(reconstructions)
            gen_loss = hinge_gen_loss(fake_logits)

            adaptive_weight = 1
            if self.perceptual_weight > 0 and last_layer is not None:
                norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(
                    perceptual_loss, last_layer
                ).norm(p=2)
                norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_layer).norm(p=2)

                adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(
                    min=1e-3
                )
                adaptive_weight.clamp_(max=1e3)

                # Only meaningful for the tensor computed above;
                # ``torch.isnan`` rejects the plain-int fallback.
                if torch.isnan(adaptive_weight).any():
                    adaptive_weight = 1

        if aux_losses is None:
            aux_losses = self.zero

        # ``adaptive_weight`` is logged only; the adversarial term is scaled
        # by ``adversarial_loss_weight`` alone.
        total_loss = (
            recon_loss
            + aux_losses * self.quantizer_aux_loss_weight
            + perceptual_loss * self.perceptual_weight
            + gen_loss * self.adversarial_loss_weight
        )

        log = {
            f"{split}/total_loss": total_loss.detach(),
            f"{split}/recon_loss": recon_loss.detach(),
            f"{split}/perceptual_loss": perceptual_loss.detach(),
            f"{split}/gen_loss": gen_loss.detach(),
            f"{split}/aux_losses": aux_losses.detach(),
            f"{split}/adaptive_weight": adaptive_weight,
        }

        return total_loss, log

    def _discriminator_step(self, inputs, reconstructions, split):
        """Discriminator loss: hinge loss plus optional gradient penalty."""
        apply_gradient_penalty = self.grad_penalty_loss_weight > 0
        if apply_gradient_penalty:
            # The penalty differentiates the real logits w.r.t. the inputs.
            # NOTE: this flips ``requires_grad`` on the caller's tensor in place.
            inputs = inputs.requires_grad_()

        real_logits = self.discr_3d(inputs)
        # Detach so no gradient reaches the autoencoder from this step.
        fake_logits = self.discr_3d(reconstructions.detach())

        discr_loss = hinge_discr_loss(fake_logits, real_logits)

        if apply_gradient_penalty:
            gradient_penalty_loss = gradient_penalty(inputs, real_logits)
        else:
            gradient_penalty_loss = self.zero

        total_loss = discr_loss + self.grad_penalty_loss_weight * gradient_penalty_loss

        log = {
            f"{split}/total_disc_loss": total_loss.detach(),
            f"{split}/discr_loss": discr_loss.detach(),
            f"{split}/grad_penalty_loss": gradient_penalty_loss.detach(),
            f"{split}/logits_real": real_logits.detach().mean(),
            f"{split}/logits_fake": fake_logits.detach().mean(),
        }
        return total_loss, log
|