diff --git a/sat/diffusion_video.py b/sat/diffusion_video.py
index 10635b4..226ed6e 100644
--- a/sat/diffusion_video.py
+++ b/sat/diffusion_video.py
@@ -192,13 +192,13 @@ class SATVideoDiffusionEngine(nn.Module):
             for i in range(fake_cp_size):
                 end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0)
 
-                fake_cp_rank0 = True if i == 0 else False
+                use_cp = True if i == 0 else False
                 clear_fake_cp_cache = True if i == fake_cp_size - 1 else False
                 with torch.no_grad():
                     recon = self.first_stage_model.decode(
                         z_now[:, :, start_frame:end_frame].contiguous(),
                         clear_fake_cp_cache=clear_fake_cp_cache,
-                        fake_cp_rank0=fake_cp_rank0,
+                        use_cp=use_cp,
                     )
                 recons.append(recon)
                 start_frame = end_frame
diff --git a/sat/vae_modules/cp_enc_dec.py b/sat/vae_modules/cp_enc_dec.py
index 1d9c34f..094f5ae 100644
--- a/sat/vae_modules/cp_enc_dec.py
+++ b/sat/vae_modules/cp_enc_dec.py
@@ -101,8 +101,6 @@ def _gather(input_, dim):
     group = get_context_parallel_group()
     cp_rank = get_context_parallel_rank()
 
-    # print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
-
     input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
     if cp_rank == 0:
         input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
@@ -127,12 +125,9 @@ def _conv_split(input_, dim, kernel_size):
     cp_world_size = get_context_parallel_world_size()
 
-    # Bypass the function if context parallel is 1
     if cp_world_size == 1:
         return input_
 
-    # print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape)
-
     cp_rank = get_context_parallel_rank()
 
     dim_size = (input_.size()[dim] - kernel_size) // cp_world_size
@@ -140,14 +135,11 @@ def _conv_split(input_, dim, kernel_size):
     if cp_rank == 0:
         output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
     else:
-        # output = input_.transpose(dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0)
         output = input_.transpose(dim, 0)[
             cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size
         ].transpose(dim, 0)
     output = output.contiguous()
 
-    # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)
-
     return output
 
@@ -160,9 +152,6 @@ def _conv_gather(input_, dim, kernel_size):
     group = get_context_parallel_group()
     cp_rank = get_context_parallel_rank()
-
-    # print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
-
     input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous()
     if cp_rank == 0:
         input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous()
@@ -255,17 +244,12 @@ def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=None):
     if recv_rank % cp_world_size == cp_world_size - 1:
         recv_rank += cp_world_size
 
-    # req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
-    # recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous()
-    # req_recv = torch.distributed.recv(recv_buffer, recv_rank, group=group)
-    # req_recv.wait()
     recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
     if cp_rank < cp_world_size - 1:
         req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
     if cp_rank > 0:
         req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
-    # req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
-    # req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
+
     if cp_rank == 0:
         if cache_padding is not None:
@@ -421,7 +405,6 @@ class ContextParallelGroupNorm(torch.nn.GroupNorm):
 
 def Normalize(in_channels, gather=False, **kwargs):
-    # same for 3D and 2D
     if gather:
         return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
     else:
@@ -468,8 +451,8 @@ class SpatialNorm3D(nn.Module):
             kernel_size=1,
         )
 
-    def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp_rank0=True):
-        if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp_rank0:
+    def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp=True):
+        if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp:
             f_first, f_rest = f[:, :, :1], f[:, :, 1:]
             f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
             zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
@@ -531,10 +514,9 @@ class Upsample3D(nn.Module):
         self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.compress_time = compress_time
 
-    def forward(self, x, fake_cp_rank0=True):
+    def forward(self, x, fake_cp=True):
         if self.compress_time and x.shape[2] > 1:
-            if get_context_parallel_rank() == 0 and fake_cp_rank0:
-                # print(x.shape)
+            if get_context_parallel_rank() == 0 and fake_cp:
 
                 # split first frame
                 x_first, x_rest = x[:, :, 0], x[:, :, 1:]
@@ -545,8 +529,6 @@ class Upsample3D(nn.Module):
                     torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
                 ]
                 x_rest = torch.cat(interpolated_splits, dim=1)
-
-                # x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
                 x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
             else:
                 splits = torch.split(x, 32, dim=1)
@@ -555,13 +537,10 @@ class Upsample3D(nn.Module):
                 ]
                 x = torch.cat(interpolated_splits, dim=1)
 
-                # x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
-
         else:
             # only interpolate 2D
             t = x.shape[2]
             x = rearrange(x, "b c t h w -> (b t) c h w")
-            # x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
 
             splits = torch.split(x, 32, dim=1)
             interpolated_splits = [
@@ -590,12 +569,12 @@ class DownSample3D(nn.Module):
             self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
         self.compress_time = compress_time
 
-    def forward(self, x, fake_cp_rank0=True):
+    def forward(self, x, fake_cp=True):
         if self.compress_time and x.shape[2] > 1:
             h, w = x.shape[-2:]
             x = rearrange(x, "b c t h w -> (b h w) c t")
 
-            if get_context_parallel_rank() == 0 and fake_cp_rank0:
+            if get_context_parallel_rank() == 0 and fake_cp:
                 # split first frame
                 x_first, x_rest = x[..., 0], x[..., 1:]
 
@@ -693,17 +672,13 @@ class ContextParallelResnetBlock3D(nn.Module):
                 padding=0,
             )
 
-    def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp_rank0=True):
+    def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp=True):
         h = x
 
-        # if isinstance(self.norm1, torch.nn.GroupNorm):
-        #     h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
         if zq is not None:
-            h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
+            h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=fake_cp)
         else:
             h = self.norm1(h)
-        # if isinstance(self.norm1, torch.nn.GroupNorm):
-        #     h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
 
         h = nonlinearity(h)
         h = self.conv1(h, clear_cache=clear_fake_cp_cache)
@@ -711,14 +686,10 @@ class ContextParallelResnetBlock3D(nn.Module):
         if temb is not None:
             h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
 
-        # if isinstance(self.norm2, torch.nn.GroupNorm):
-        #     h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
         if zq is not None:
-            h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
+            h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=fake_cp)
         else:
             h = self.norm2(h)
-        # if isinstance(self.norm2, torch.nn.GroupNorm):
-        #     h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
 
         h = nonlinearity(h)
         h = self.dropout(h)
@@ -827,32 +798,32 @@ class ContextParallelEncoder3D(nn.Module):
             kernel_size=3,
         )
 
-    def forward(self, x, clear_fake_cp_cache=True, fake_cp_rank0=True):
+    def forward(self, x, use_cp=True):
+        global _USE_CP
+        _USE_CP = use_cp
+
         # timestep embedding
         temb = None
 
         # downsampling
-        h = self.conv_in(x, clear_cache=clear_fake_cp_cache)
+        h = self.conv_in(x)
         for i_level in range(self.num_resolutions):
             for i_block in range(self.num_res_blocks):
-                h = self.down[i_level].block[i_block](h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
+                h = self.down[i_level].block[i_block](h, temb)
                 if len(self.down[i_level].attn) > 0:
-                    print("Attention not implemented")
                     h = self.down[i_level].attn[i_block](h)
             if i_level != self.num_resolutions - 1:
-                h = self.down[i_level].downsample(h, fake_cp_rank0=fake_cp_rank0)
+                h = self.down[i_level].downsample(h)
 
         # middle
-        h = self.mid.block_1(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
-        h = self.mid.block_2(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
+        h = self.mid.block_1(h, temb)
+        h = self.mid.block_2(h, temb)
 
         # end
-        # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
         h = self.norm_out(h)
-        # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
         h = nonlinearity(h)
-        h = self.conv_out(h, clear_cache=clear_fake_cp_cache)
+        h = self.conv_out(h)
 
         return h
@@ -895,11 +867,9 @@ class ContextParallelDecoder3D(nn.Module):
             zq_ch = z_channels
 
         # compute in_ch_mult, block_in and curr_res at lowest res
-        in_ch_mult = (1,) + tuple(ch_mult)
         block_in = ch * ch_mult[self.num_resolutions - 1]
         curr_res = resolution // 2 ** (self.num_resolutions - 1)
         self.z_shape = (1, z_channels, curr_res, curr_res)
-        print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
 
         self.conv_in = ContextParallelCausalConv3d(
             chan_in=z_channels,
@@ -955,11 +925,6 @@ class ContextParallelDecoder3D(nn.Module):
             up.block = block
             up.attn = attn
             if i_level != 0:
-                # # Symmetrical enc-dec
-                if i_level <= self.temporal_compress_level:
-                    up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
-                else:
-                    up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
                 if i_level < self.num_resolutions - self.temporal_compress_level:
                     up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
                 else:
@@ -974,7 +939,9 @@ class ContextParallelDecoder3D(nn.Module):
             kernel_size=3,
         )
 
-    def forward(self, z, clear_fake_cp_cache=True, fake_cp_rank0=True):
+    def forward(self, z, clear_fake_cp_cache=True, use_cp=True):
+        global _USE_CP
+        _USE_CP = use_cp
         self.last_z_shape = z.shape
 
         # timestep embedding
@@ -987,25 +954,25 @@ class ContextParallelDecoder3D(nn.Module):
         h = self.conv_in(z, clear_cache=clear_fake_cp_cache)
 
         # middle
-        h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
-        h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
+        h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)
+        h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)
 
         # upsampling
         for i_level in reversed(range(self.num_resolutions)):
             for i_block in range(self.num_res_blocks + 1):
                 h = self.up[i_level].block[i_block](
-                    h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0
+                    h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp
                 )
                 if len(self.up[i_level].attn) > 0:
                     h = self.up[i_level].attn[i_block](h, zq)
             if i_level != 0:
-                h = self.up[i_level].upsample(h, fake_cp_rank0=fake_cp_rank0)
+                h = self.up[i_level].upsample(h, fake_cp=use_cp)
 
         # end
         if self.give_pre_end:
             return h
 
-        h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
+        h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)
         h = nonlinearity(h)
         h = self.conv_out(h, clear_cache=clear_fake_cp_cache)
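
Note on the decode_first_stage hunk at the top: the latent is split along the time axis and the two flags are toggled per chunk, with use_cp true only for the first chunk and clear_fake_cp_cache true only for the last. A minimal single-process sketch of that pattern follows; chunked_decode and its decode argument are illustrative stand-ins for first_stage_model.decode, not the repo's API.

import torch

def chunked_decode(decode, z, fake_cp_size):
    # z: (B, C, T, H, W) latent; decode: any callable accepting the keyword
    # flags shown in the diff above (clear_fake_cp_cache, use_cp).
    latent_time = z.shape[2]
    recons, start_frame = [], 0
    for i in range(fake_cp_size):
        # Spread the T frames as evenly as possible: the first
        # latent_time % fake_cp_size chunks each get one extra frame.
        end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0)
        with torch.no_grad():
            recons.append(
                decode(
                    z[:, :, start_frame:end_frame].contiguous(),
                    clear_fake_cp_cache=(i == fake_cp_size - 1),  # drop cached frames after the last chunk
                    use_cp=(i == 0),  # True only for the first chunk, as in the loop above
                )
            )
        start_frame = end_frame
    return torch.cat(recons, dim=2)  # stitch the reconstructions back along time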
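
For the _conv_split indexing kept in the second file: rank 0 takes the first dim_size + kernel_size frames, and every later rank starts kernel_size frames past cp_rank * dim_size, so the shards tile the padded sequence exactly once with no overlap. A self-contained check of those slice bounds; conv_split_slices is a hypothetical helper written only for this illustration.

def conv_split_slices(total, kernel_size, world_size):
    # Mirror the slicing in _conv_split: dim_size frames per rank, with the
    # kernel_size head frames folded into rank 0's shard.
    dim_size = (total - kernel_size) // world_size
    slices = []
    for cp_rank in range(world_size):
        if cp_rank == 0:
            slices.append(slice(0, dim_size + kernel_size))
        else:
            slices.append(slice(cp_rank * dim_size + kernel_size, (cp_rank + 1) * dim_size + kernel_size))
    return slices

# 9 padded frames, kernel_size 1, 4 ranks:
# rank 0 -> [0:3), rank 1 -> [3:5), rank 2 -> [5:7), rank 3 -> [7:9)
parts = conv_split_slices(9, kernel_size=1, world_size=4)
assert [i for s in parts for i in range(s.start, s.stop)] == list(range(9))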
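
The cache_padding path in _fake_cp_pass_from_previous_rank exists so that, before a causal temporal convolution, each chunk (or rank) is prefixed with the previous chunk's trailing kernel_size - 1 frames, making chunked decoding match a single pass over the full sequence. A single-process sketch of that invariant, assuming first-frame replication as the causal padding for the very first chunk; causal_conv_chunked is illustrative, not the repo's implementation.

import torch
import torch.nn.functional as F

def causal_conv_chunked(x, weight, n_chunks):
    # x: (B, C, T); weight: (C_out, C, k) with k > 1.
    k = weight.shape[-1]
    cache = None
    outs = []
    for chunk in x.chunk(n_chunks, dim=-1):
        # First chunk: replicate the first frame; later chunks: reuse the
        # trailing frames cached from the previous chunk (the cache_padding role).
        pad = chunk[..., :1].repeat(1, 1, k - 1) if cache is None else cache
        outs.append(F.conv1d(torch.cat([pad, chunk], dim=-1), weight))
        cache = chunk[..., -(k - 1):]
    return torch.cat(outs, dim=-1)

# Chunked output equals a single causal conv over the whole sequence:
torch.manual_seed(0)
x, w = torch.randn(1, 1, 12), torch.randn(1, 1, 3)
full = F.conv1d(torch.cat([x[..., :1].repeat(1, 1, 2), x], dim=-1), w)
assert torch.allclose(causal_conv_chunked(x, w, n_chunks=3), full, atol=1e-6)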