Yuxuan Zhang 39c6562dc8 format
2025-03-22 15:14:06 +08:00

135 lines
5.0 KiB
Python

from typing import List, Optional, Union
import torch
import torch.nn as nn
from omegaconf import ListConfig
from ...util import append_dims, instantiate_from_config
from ...modules.autoencoding.lpips.loss.lpips import LPIPS
from sat import mpu
class StandardDiffusionLoss(nn.Module):
    """Standard denoising-score-matching loss for diffusion training.

    Samples per-example noise levels from a configured sigma sampler, perturbs
    the input with (optionally offset) Gaussian noise, runs the denoiser, and
    reduces the error with an l2, l1, or LPIPS criterion.
    """

    def __init__(
        self,
        sigma_sampler_config,
        type="l2",
        offset_noise_level=0.0,
        batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
    ):
        """
        Args:
            sigma_sampler_config: config instantiated into the sigma sampler.
            type: loss criterion, one of "l2", "l1", "lpips".
            offset_noise_level: scale of the per-sample constant noise offset
                (0.0 disables offset noise).
            batch2model_keys: batch key(s) forwarded verbatim to the denoiser.
        """
        super().__init__()

        assert type in ["l2", "l1", "lpips"]

        self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
        self.type = type
        self.offset_noise_level = offset_noise_level

        # LPIPS is a learned perceptual metric; keep its weights frozen.
        if type == "lpips":
            self.lpips = LPIPS().eval()

        # Normalize batch2model_keys (None / str / sequence) to a set of names.
        if not batch2model_keys:
            batch2model_keys = []
        if isinstance(batch2model_keys, str):
            batch2model_keys = [batch2model_keys]
        self.batch2model_keys = set(batch2model_keys)

    def __call__(self, network, denoiser, conditioner, input, batch):
        """Compute the per-sample diffusion loss for one batch."""
        cond = conditioner(batch)
        # Forward only the whitelisted batch entries to the denoiser.
        extra_inputs = {k: batch[k] for k in self.batch2model_keys.intersection(batch)}

        batch_size = input.shape[0]
        sigmas = self.sigma_sampler(batch_size).to(input.device)

        noise = torch.randn_like(input)
        if self.offset_noise_level > 0.0:
            # Offset noise: add a per-sample constant shift, broadcast over all
            # non-batch dims, scaled by offset_noise_level.
            offset = torch.randn(batch_size).to(input.device)
            noise = noise + append_dims(offset, input.ndim) * self.offset_noise_level
            noise = noise.to(input.dtype)

        noised_input = input.float() + noise * append_dims(sigmas, input.ndim)
        model_output = denoiser(network, noised_input, sigmas, cond, **extra_inputs)
        w = append_dims(denoiser.w(sigmas), input.ndim)
        return self.get_loss(model_output, input, w)

    def get_loss(self, model_output, target, w):
        """Reduce the weighted error to one scalar per batch element."""
        n = target.shape[0]
        if self.type == "l2":
            err = w * (model_output - target) ** 2
            return err.reshape(n, -1).mean(dim=1)
        if self.type == "l1":
            err = w * (model_output - target).abs()
            return err.reshape(n, -1).mean(dim=1)
        if self.type == "lpips":
            return self.lpips(model_output, target).reshape(-1)
class VideoDiffusionLoss(StandardDiffusionLoss):
    """Diffusion loss for video models on a DDPM-style alphas_cumprod schedule,
    with the sampled timestep/noise broadcast across model-parallel ranks so
    every rank of a group trains on identical inputs.

    Args:
        block_scale: stored on the instance; not used in this class.
        block_size: stored on the instance; not used in this class.
        min_snr_value: if set, caps the per-element loss weight (min-SNR-style
            weight clipping).
        fixed_frames: stored on the instance; not used in this class.
        **kwargs: forwarded to StandardDiffusionLoss.__init__.
    """

    def __init__(
        self, block_scale=None, block_size=None, min_snr_value=None, fixed_frames=0, **kwargs
    ):
        self.fixed_frames = fixed_frames
        self.block_scale = block_scale
        self.block_size = block_size
        self.min_snr_value = min_snr_value
        super().__init__(**kwargs)

    def __call__(self, network, denoiser, conditioner, input, batch):
        """Compute the per-sample loss; returns one value per batch element."""
        cond = conditioner(batch)
        additional_model_inputs = {
            key: batch[key] for key in self.batch2model_keys.intersection(batch)
        }

        # Sample sqrt(alpha_cumprod) noise levels plus their schedule indices.
        alphas_cumprod_sqrt, idx = self.sigma_sampler(input.shape[0], return_idx=True)
        alphas_cumprod_sqrt = alphas_cumprod_sqrt.to(input.device)
        idx = idx.to(input.device)

        noise = torch.randn_like(input)

        # Broadcast idx/noise/levels from the first rank of each model-parallel
        # group so all ranks in the group see identical training noise.
        mp_size = mpu.get_model_parallel_world_size()
        global_rank = torch.distributed.get_rank() // mp_size
        src = global_rank * mp_size
        group = mpu.get_model_parallel_group()
        torch.distributed.broadcast(idx, src=src, group=group)
        torch.distributed.broadcast(noise, src=src, group=group)
        torch.distributed.broadcast(alphas_cumprod_sqrt, src=src, group=group)

        additional_model_inputs["idx"] = idx

        if self.offset_noise_level > 0.0:
            noise = (
                noise
                + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim)
                * self.offset_noise_level
            )
            # NOTE(review): unlike StandardDiffusionLoss, noise is not cast back
            # to input.dtype here — confirm whether that is intentional.

        # DDPM forward process: x_t = sqrt(a_t) * x_0 + sqrt(1 - a_t) * eps.
        noised_input = input.float() * append_dims(
            alphas_cumprod_sqrt, input.ndim
        ) + noise * append_dims((1 - alphas_cumprod_sqrt**2) ** 0.5, input.ndim)

        if "concat_images" in batch.keys():
            cond["concat"] = batch["concat_images"]

        # [2, 13, 16, 60, 90],[2] dict_keys(['crossattn', 'concat']) dict_keys(['idx'])
        model_output = denoiser(
            network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs
        )

        w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim)  # v-pred
        if self.min_snr_value is not None:
            # BUGFIX: builtin min() on a multi-element tensor raises
            # "Boolean value of Tensor ... is ambiguous"; cap the weight
            # elementwise instead.
            w = torch.clamp(w, max=self.min_snr_value)

        return self.get_loss(model_output, input, w)

    # get_loss is inherited from StandardDiffusionLoss; the previous duplicate
    # copy here was byte-identical and has been removed.