diff --git a/requirements.txt b/requirements.txt index 99a6fc0..c58aff5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ -diffusers>=0.30.0 -transformers>=4.43.4 -accelerate>=0.33.0 +diffusers==0.30.0 +transformers==4.44.0 +accelerate==0.33.0 sentencepiece==0.2.0 # T5 SwissArmyTransformer==0.4.11 # Inference torch==2.4.0 # Tested in 2.2 2.3 2.4 and 2.5 diff --git a/sat/requirements.txt b/sat/requirements.txt index 3ab573f..253a85a 100644 --- a/sat/requirements.txt +++ b/sat/requirements.txt @@ -1,17 +1,16 @@ SwissArmyTransformer==0.4.11 -diffusers>=0.29.2 -omegaconf>=2.3.0 -torch>=2.3.1 -torchvision>=0.19.0 -pytorch_lightning>=2.3.3 -kornia>=0.7.3 -beartype>=0.18.5 -numpy>=2.0.1 -fsspec>=2024.5.0 -safetensors>=0.4.3 -imageio-ffmpeg>=0.5.1 -imageio>=2.34.2 -scipy>=1.14.0 -decord>=0.6.0 -wandb>=0.17.5 -deepspeed>=0.14.4 +omegaconf==2.3.0 +torch==2.4.0 +torchvision==0.19.0 +pytorch_lightning==2.3.3 +kornia==0.7.3 +beartype==0.18.5 +numpy==2.0.1 +fsspec==2024.5.0 +safetensors==0.4.3 +imageio-ffmpeg==0.5.1 +imageio==2.34.2 +scipy==1.14.0 +decord==0.6.0 +wandb==0.17.5 +deepspeed==0.14.4 diff --git a/sat/sgm/modules/cp_enc_dec.py b/sat/sgm/modules/cp_enc_dec.py index 9a65d61..469595d 100644 --- a/sat/sgm/modules/cp_enc_dec.py +++ b/sat/sgm/modules/cp_enc_dec.py @@ -2,25 +2,12 @@ import math import torch import torch.distributed import torch.nn as nn -import torch.nn.functional as F -import numpy as np - -from beartype import beartype -from beartype.typing import Union, Tuple, Optional, List -from einops import rearrange - from ..util import ( get_context_parallel_group, get_context_parallel_rank, get_context_parallel_world_size, - get_context_parallel_group_rank, -) -# try: -from ..util import SafeConv3d as Conv3d -# except: -# # Degrade to normal Conv3d if SafeConv3d is not available -# from torch.nn import Conv3d +) _USE_CP = True @@ -192,706 +179,4 @@ def _conv_gather(input_, dim, kernel_size): # print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape) - return output - - -def _pass_from_previous_rank(input_, dim, kernel_size): - # Bypass the function if kernel size is 1 - if kernel_size == 1: - return input_ - - group = get_context_parallel_group() - cp_rank = get_context_parallel_rank() - cp_group_rank = get_context_parallel_group_rank() - cp_world_size = get_context_parallel_world_size() - - # print('in _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape) - - global_rank = torch.distributed.get_rank() - global_world_size = torch.distributed.get_world_size() - - input_ = input_.transpose(0, dim) - - # pass from last rank - send_rank = global_rank + 1 - recv_rank = global_rank - 1 - if send_rank % cp_world_size == 0: - send_rank -= cp_world_size - if recv_rank % cp_world_size == cp_world_size - 1: - recv_rank += cp_world_size - - if cp_rank < cp_world_size - 1: - req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group) - if cp_rank > 0: - recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous() - req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group) - - if cp_rank == 0: - input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_], dim=0) - else: - req_recv.wait() - input_ = torch.cat([recv_buffer, input_], dim=0) - - input_ = input_.transpose(0, dim).contiguous() - - # print('out _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape) - - return input_ - - -def _drop_from_previous_rank(input_, dim, kernel_size): - input_ = 
input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim) - return input_ - - -class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function): - @staticmethod - def forward(ctx, input_, dim, kernel_size): - ctx.dim = dim - ctx.kernel_size = kernel_size - return _conv_split(input_, dim, kernel_size) - - @staticmethod - def backward(ctx, grad_output): - return _conv_gather(grad_output, ctx.dim, ctx.kernel_size), None, None - - -class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function): - @staticmethod - def forward(ctx, input_, dim, kernel_size): - ctx.dim = dim - ctx.kernel_size = kernel_size - return _conv_gather(input_, dim, kernel_size) - - @staticmethod - def backward(ctx, grad_output): - return _conv_split(grad_output, ctx.dim, ctx.kernel_size), None, None - - -class _ConvolutionPassFromPreviousRank(torch.autograd.Function): - @staticmethod - def forward(ctx, input_, dim, kernel_size): - ctx.dim = dim - ctx.kernel_size = kernel_size - return _pass_from_previous_rank(input_, dim, kernel_size) - - @staticmethod - def backward(ctx, grad_output): - return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None - - -def conv_scatter_to_context_parallel_region(input_, dim, kernel_size): - return _ConvolutionScatterToContextParallelRegion.apply(input_, dim, kernel_size) - - -def conv_gather_from_context_parallel_region(input_, dim, kernel_size): - return _ConvolutionGatherFromContextParallelRegion.apply(input_, dim, kernel_size) - - -def conv_pass_from_last_rank(input_, dim, kernel_size): - return _ConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size) - - -class ContextParallelCausalConv3d(nn.Module): - def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], stride=1, **kwargs): - super().__init__() - kernel_size = cast_tuple(kernel_size, 3) - - time_kernel_size, height_kernel_size, width_kernel_size = kernel_size - - assert is_odd(height_kernel_size) and is_odd(width_kernel_size) - - time_pad = time_kernel_size - 1 - height_pad = height_kernel_size // 2 - width_pad = width_kernel_size // 2 - - self.height_pad = height_pad - self.width_pad = width_pad - self.time_pad = time_pad - self.time_kernel_size = time_kernel_size - self.temporal_dim = 2 - - stride = (stride, stride, stride) - dilation = (1, 1, 1) - self.conv = Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) - - def forward(self, input_): - # temporal padding inside - if _USE_CP: - input_parallel = conv_pass_from_last_rank(input_, self.temporal_dim, self.time_kernel_size) - else: - input_ = input_.transpose(0, self.temporal_dim) - input_parallel = torch.cat([input_[:1]] * (self.time_kernel_size - 1) + [input_], dim=0) - input_parallel = input_parallel.transpose(0, self.temporal_dim) - padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) - input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0) - output_parallel = self.conv(input_parallel) - output = output_parallel - return output - - -class ContextParallelGroupNorm(torch.nn.GroupNorm): - def forward(self, input_): - if _USE_CP: - input_ = conv_gather_from_context_parallel_region(input_, dim=2, kernel_size=1) - output = super().forward(input_) - if _USE_CP: - output = conv_scatter_to_context_parallel_region(output, dim=2, kernel_size=1) - return output - - -def Normalize(in_channels, gather=False, **kwargs): # same for 3D and 2D - if gather: - return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, 
eps=1e-6, affine=True) - else: - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - - -class SpatialNorm3D(nn.Module): - def __init__( - self, - f_channels, - zq_channels, - freeze_norm_layer=False, - add_conv=False, - pad_mode="constant", - gather=False, - **norm_layer_params, - ): - super().__init__() - if gather: - self.norm_layer = ContextParallelGroupNorm(num_channels=f_channels, **norm_layer_params) - else: - self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, **norm_layer_params) - # self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params) - if freeze_norm_layer: - for p in self.norm_layer.parameters: - p.requires_grad = False - - self.add_conv = add_conv - if add_conv: - self.conv = ContextParallelCausalConv3d( - chan_in=zq_channels, - chan_out=zq_channels, - kernel_size=3, - ) - - self.conv_y = ContextParallelCausalConv3d( - chan_in=zq_channels, - chan_out=f_channels, - kernel_size=1, - ) - self.conv_b = ContextParallelCausalConv3d( - chan_in=zq_channels, - chan_out=f_channels, - kernel_size=1, - ) - - def forward(self, f, zq): - if f.shape[2] == 1 and not _USE_CP: - zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest") - elif get_context_parallel_rank() == 0: - f_first, f_rest = f[:, :, :1], f[:, :, 1:] - f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] - zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:] - zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest") - zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest") - zq = torch.cat([zq_first, zq_rest], dim=2) - else: - zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest") - - if self.add_conv: - zq = self.conv(zq) - - # f = conv_gather_from_context_parallel_region(f, dim=2, kernel_size=1) - norm_f = self.norm_layer(f) - # norm_f = conv_scatter_to_context_parallel_region(norm_f, dim=2, kernel_size=1) - - new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) - return new_f - - -def Normalize3D( - in_channels, - zq_ch, - add_conv, - gather=False, -): - return SpatialNorm3D( - in_channels, - zq_ch, - gather=gather, - # norm_layer=nn.GroupNorm, - freeze_norm_layer=False, - add_conv=add_conv, - num_groups=32, - eps=1e-6, - affine=True, - ) - - -class Upsample3D(nn.Module): - def __init__( - self, - in_channels, - with_conv, - compress_time=False, - ): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) - self.compress_time = compress_time - - def forward(self, x): - if self.compress_time: - if x.shape[2] == 1 and not _USE_CP: - x = torch.nn.functional.interpolate(x[:, :, 0], scale_factor=2.0, mode="nearest")[:, :, None, :, :] - elif get_context_parallel_rank() == 0: - # split first frame - x_first, x_rest = x[:, :, 0], x[:, :, 1:] - - x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest") - x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest") - x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2) - else: - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - - else: - # only interpolate 2D - t = x.shape[2] - x = rearrange(x, "b c t h w -> (b t) c h w") - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - x = rearrange(x, "(b t) c h w -> b c t h w", t=t) - - if self.with_conv: - t = x.shape[2] - x = rearrange(x, "b c t h w -> (b t) c 
h w") - x = self.conv(x) - x = rearrange(x, "(b t) c h w -> b c t h w", t=t) - return x - - -class DownSample3D(nn.Module): - def __init__(self, in_channels, with_conv, compress_time=False, out_channels=None): - super().__init__() - self.with_conv = with_conv - if out_channels is None: - out_channels = in_channels - if self.with_conv: - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) - self.compress_time = compress_time - - def forward(self, x): - if self.compress_time and x.shape[2] > 1: - h, w = x.shape[-2:] - x = rearrange(x, "b c t h w -> (b h w) c t") - - if x.shape[-1] % 2 == 1: - # split first frame - x_first, x_rest = x[..., 0], x[..., 1:] - - if x_rest.shape[-1] > 0: - x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2) - x = torch.cat([x_first[..., None], x_rest], dim=-1) - x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w) - else: - x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2) - x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w) - - if self.with_conv: - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - t = x.shape[2] - x = rearrange(x, "b c t h w -> (b t) c h w") - x = self.conv(x) - x = rearrange(x, "(b t) c h w -> b c t h w", t=t) - else: - t = x.shape[2] - x = rearrange(x, "b c t h w -> (b t) c h w") - x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) - x = rearrange(x, "(b t) c h w -> b c t h w", t=t) - return x - - -class ContextParallelResnetBlock3D(nn.Module): - def __init__( - self, - *, - in_channels, - out_channels=None, - conv_shortcut=False, - dropout, - temb_channels=512, - zq_ch=None, - add_conv=False, - gather_norm=False, - normalization=Normalize, - ): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - self.norm1 = normalization( - in_channels, - zq_ch=zq_ch, - add_conv=add_conv, - gather=gather_norm, - ) - - self.conv1 = ContextParallelCausalConv3d( - chan_in=in_channels, - chan_out=out_channels, - kernel_size=3, - ) - if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, out_channels) - self.norm2 = normalization( - out_channels, - zq_ch=zq_ch, - add_conv=add_conv, - gather=gather_norm, - ) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = ContextParallelCausalConv3d( - chan_in=out_channels, - chan_out=out_channels, - kernel_size=3, - ) - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = ContextParallelCausalConv3d( - chan_in=in_channels, - chan_out=out_channels, - kernel_size=3, - ) - else: - self.nin_shortcut = Conv3d( - in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=0, - ) - - def forward(self, x, temb, zq=None): - h = x - - # if isinstance(self.norm1, torch.nn.GroupNorm): - # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1) - if zq is not None: - h = self.norm1(h, zq) - else: - h = self.norm1(h) - # if isinstance(self.norm1, torch.nn.GroupNorm): - # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1) - - h = nonlinearity(h) - h = self.conv1(h) - - if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None] - - # if isinstance(self.norm2, torch.nn.GroupNorm): - # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1) - if zq is not None: - 
h = self.norm2(h, zq) - else: - h = self.norm2(h) - # if isinstance(self.norm2, torch.nn.GroupNorm): - # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1) - - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) - - return x + h - - -class ContextParallelEncoder3D(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - double_z=True, - pad_mode="first", - temporal_compress_times=4, - gather_norm=False, - **ignore_kwargs, - ): - super().__init__() - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - # log2 of temporal_compress_times - self.temporal_compress_level = int(np.log2(temporal_compress_times)) - - self.conv_in = ContextParallelCausalConv3d( - chan_in=in_channels, - chan_out=self.ch, - kernel_size=3, - ) - - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ContextParallelResnetBlock3D( - in_channels=block_in, - out_channels=block_out, - dropout=dropout, - temb_channels=self.temb_ch, - gather_norm=gather_norm, - ) - ) - block_in = block_out - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - if i_level < self.temporal_compress_level: - down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=True) - else: - down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=False) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ContextParallelResnetBlock3D( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - gather_norm=gather_norm, - ) - - self.mid.block_2 = ContextParallelResnetBlock3D( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - gather_norm=gather_norm, - ) - - # end - self.norm_out = Normalize(block_in, gather=gather_norm) - - self.conv_out = ContextParallelCausalConv3d( - chan_in=block_in, - chan_out=2 * z_channels if double_z else z_channels, - kernel_size=3, - ) - - def forward(self, x, use_cp=True): - global _USE_CP - _USE_CP = use_cp - - # timestep embedding - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.block_2(h, temb) - - # end - # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1) - h = self.norm_out(h) - # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1) - - h = nonlinearity(h) - h = self.conv_out(h) - - return h - - -class ContextParallelDecoder3D(nn.Module): - def __init__( - self, - 
*, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - give_pre_end=False, - zq_ch=None, - add_conv=False, - pad_mode="first", - temporal_compress_times=4, - gather_norm=False, - **ignorekwargs, - ): - super().__init__() - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - self.give_pre_end = give_pre_end - - # log2 of temporal_compress_times - self.temporal_compress_level = int(np.log2(temporal_compress_times)) - - if zq_ch is None: - zq_ch = z_channels - - # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,) + tuple(ch_mult) - block_in = ch * ch_mult[self.num_resolutions - 1] - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.z_shape = (1, z_channels, curr_res, curr_res) - print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) - - self.conv_in = ContextParallelCausalConv3d( - chan_in=z_channels, - chan_out=block_in, - kernel_size=3, - ) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ContextParallelResnetBlock3D( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - zq_ch=zq_ch, - add_conv=add_conv, - normalization=Normalize3D, - gather_norm=gather_norm, - ) - - self.mid.block_2 = ContextParallelResnetBlock3D( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - zq_ch=zq_ch, - add_conv=add_conv, - normalization=Normalize3D, - gather_norm=gather_norm, - ) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - block.append( - ContextParallelResnetBlock3D( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - zq_ch=zq_ch, - add_conv=add_conv, - normalization=Normalize3D, - gather_norm=gather_norm, - ) - ) - block_in = block_out - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - if i_level < self.num_resolutions - self.temporal_compress_level: - up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False) - else: - up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True) - self.up.insert(0, up) - - self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm) - - self.conv_out = ContextParallelCausalConv3d( - chan_in=block_in, - chan_out=out_ch, - kernel_size=3, - ) - - def forward(self, z, use_cp=True): - global _USE_CP - _USE_CP = use_cp - self.last_z_shape = z.shape - - # timestep embedding - temb = None - - t = z.shape[2] - # z to block_in - - zq = z - h = self.conv_in(z) - - # middle - h = self.mid.block_1(h, temb, zq) - h = self.mid.block_2(h, temb, zq) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](h, temb, zq) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h, zq) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - if self.give_pre_end: - return h - - h = self.norm_out(h, zq) - h = nonlinearity(h) - h = self.conv_out(h) - _USE_CP = True - return h - - def get_last_layer(self): - return 
self.conv_out.conv.weight + return output \ No newline at end of file
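
Reviewer note on the removed module: the core trick in the deleted ContextParallelCausalConv3d is causal temporal padding — frame t must only see frames <= t, so the module pads time by replicating the first frame (kernel_size - 1) times instead of zero-padding both sides, while height and width keep ordinary symmetric zero padding. Below is a minimal single-process sketch of that fallback branch (the path taken when _USE_CP is False); the class name is illustrative, not part of this repo's API:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalConv3d(nn.Module):
    # Illustrative stand-in for ContextParallelCausalConv3d with CP disabled.
    def __init__(self, chan_in, chan_out, kernel_size=3):
        super().__init__()
        t_k = h_k = w_k = kernel_size
        self.time_pad = t_k - 1                  # pad only toward the past
        self.h_pad, self.w_pad = h_k // 2, w_k // 2
        self.conv = nn.Conv3d(chan_in, chan_out, (t_k, h_k, w_k))

    def forward(self, x):                        # x: (B, C, T, H, W)
        # causal pad: repeat the first frame so frame t never sees frame t+1
        first = x[:, :, :1].expand(-1, -1, self.time_pad, -1, -1)
        x = torch.cat([first, x], dim=2)
        # ordinary symmetric zero padding in space
        x = F.pad(x, (self.w_pad, self.w_pad, self.h_pad, self.h_pad))
        return self.conv(x)

y = CausalConv3d(8, 8, 3)(torch.randn(1, 8, 5, 16, 16))
print(y.shape)                                   # torch.Size([1, 8, 5, 16, 16])

Because only the past is padded, output and input share the same (T, H, W), which is what lets the VAE treat the first frame of a video exactly like a standalone image.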
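
When context parallelism is on, the removed _pass_from_previous_rank replaces that replication pad on every rank except rank 0: each rank isend's its last (kernel_size - 1) frames to the next rank, which prepends them as a halo before convolving. The equivalence to convolving the full clip can be checked on a single process; in this sketch, chunk() stands in for _conv_split and the trailing slice stands in for the isend/irecv exchange (a 1x1 spatial kernel keeps the check purely temporal):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
k = 3                                        # temporal kernel size
w = torch.randn(4, 4, k, 1, 1)               # 1x1 in space: time-only conv
x = torch.randn(1, 4, 8, 8, 8)               # T = 8, split over 2 fake ranks

def pad_first(t):                            # rank-0-style causal pad
    return torch.cat([t[:, :, :1].expand(-1, -1, k - 1, -1, -1), t], dim=2)

full = F.conv3d(pad_first(x), w)             # reference: conv on the full clip

chunks = x.chunk(2, dim=2)                   # what _conv_split hands each rank
outs = []
for i, c in enumerate(chunks):
    if i == 0:
        inp = pad_first(c)                   # rank 0 pads with its first frame
    else:
        halo = chunks[i - 1][:, :, -(k - 1):]  # frames the prev rank sends
        inp = torch.cat([halo, c], dim=2)
    outs.append(F.conv3d(inp, w))

print(torch.allclose(torch.cat(outs, dim=2), full))  # True

This is also why the paired autograd Functions above pair a forward split with a backward gather (and vice versa): the two ops are adjoint, so gradients flow back to the rank that owns each temporal slice.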
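
The removed DownSample3D has a related wrinkle in its compress_time branch: when T is odd it holds the first frame out, so the causal "image" frame survives temporal pooling intact, and only the even-length tail is average-pooled in pairs. In the module a rearrange first folds (b, h, w) into the batch dimension; the shape below assumes that flattening has already happened:

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 5)                     # (b*h*w, C, T) with odd T = 5
first, rest = x[..., :1], x[..., 1:]         # keep frame 0 untouched
rest = F.avg_pool1d(rest, kernel_size=2, stride=2)
out = torch.cat([first, rest], dim=-1)
print(out.shape)                             # torch.Size([2, 4, 3]): 5 -> 1 + 2

With an even T the whole sequence is pooled directly, halving T; the odd case yields (T - 1) / 2 + 1 frames, matching the 8x temporal compression schedule (temporal_compress_times = 4 over the encoder levels) on 4k + 1 frame clips.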