# Mirror of https://github.com/THUDM/CogVideo.git
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def gauss_kernel(size=5, channels=3):
    # 5x5 binomial (Gaussian) kernel, normalized and replicated per channel
    # so it can be used as the weight of a depthwise (grouped) convolution.
    kernel = torch.tensor(
        [
            [1.0, 4.0, 6.0, 4.0, 1.0],
            [4.0, 16.0, 24.0, 16.0, 4.0],
            [6.0, 24.0, 36.0, 24.0, 6.0],
            [4.0, 16.0, 24.0, 16.0, 4.0],
            [1.0, 4.0, 6.0, 4.0, 1.0],
        ]
    )
    kernel /= 256.0
    kernel = kernel.repeat(channels, 1, 1, 1)
    kernel = kernel.to(device)
    return kernel


def downsample(x):
    # Keep every other pixel along height and width (stride-2 decimation).
    return x[:, :, ::2, ::2]


def upsample(x):
    # Insert zeros between pixels in both spatial dimensions to double the
    # height and width, then blur with 4x the Gaussian kernel to fill the holes.
    cc = torch.cat(
        [x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)], dim=3
    )
    cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3])
    cc = cc.permute(0, 1, 3, 2)
    cc = torch.cat(
        [cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2).to(device)], dim=3
    )
    cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2)
    x_up = cc.permute(0, 1, 3, 2)
    return conv_gauss(x_up, 4 * gauss_kernel(channels=x.shape[1]))


def conv_gauss(img, kernel):
    # Reflect-pad by 2 and apply the 5x5 kernel as a depthwise (grouped) convolution.
    img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode="reflect")
    out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1])
    return out


def laplacian_pyramid(img, kernel, max_levels=3):
    # Build a Laplacian pyramid: at each level store the detail lost by
    # blurring and downsampling, then recurse on the downsampled image.
    current = img
    pyr = []
    for level in range(max_levels):
        filtered = conv_gauss(current, kernel)
        down = downsample(filtered)
        up = upsample(down)
        diff = current - up
        pyr.append(diff)
        current = down
    return pyr


class LapLoss(torch.nn.Module):
    # L1 distance between the Laplacian pyramids of the prediction and the target.
    def __init__(self, max_levels=5, channels=3):
        super(LapLoss, self).__init__()
        self.max_levels = max_levels
        self.gauss_kernel = gauss_kernel(channels=channels)

    def forward(self, input, target):
        pyr_input = laplacian_pyramid(
            img=input, kernel=self.gauss_kernel, max_levels=self.max_levels
        )
        pyr_target = laplacian_pyramid(
            img=target, kernel=self.gauss_kernel, max_levels=self.max_levels
        )
        return sum(torch.nn.functional.l1_loss(a, b) for a, b in zip(pyr_input, pyr_target))
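

# Minimal usage sketch (illustrative, not from the original repo): LapLoss
# expects 4D batches of shape (B, C, H, W); the random tensors and the 64x64
# resolution below are assumptions chosen so that five pyramid levels fit.
if __name__ == "__main__":
    criterion = LapLoss(max_levels=5, channels=3)
    pred = torch.rand(2, 3, 64, 64, device=device)
    target = torch.rand(2, 3, 64, 64, device=device)
    loss = criterion(pred, target)
    print(loss.item())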