remove wrong fake_cp

This commit is contained in:
zR 2024-11-08 22:54:17 +08:00
parent d8ee013842
commit e7bcecf947
2 changed files with 34 additions and 67 deletions

View File

@@ -192,13 +192,13 @@ class SATVideoDiffusionEngine(nn.Module):
for i in range(fake_cp_size):
end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0)
fake_cp_rank0 = True if i == 0 else False
use_cp = True if i == 0 else False
clear_fake_cp_cache = True if i == fake_cp_size - 1 else False
with torch.no_grad():
recon = self.first_stage_model.decode(
z_now[:, :, start_frame:end_frame].contiguous(),
clear_fake_cp_cache=clear_fake_cp_cache,
fake_cp_rank0=fake_cp_rank0,
use_cp=use_cp,
)
recons.append(recon)
start_frame = end_frame
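
For reference, a minimal sketch (plain Python, illustrative helper name) of the chunking logic above: the latent time axis is split into near-equal chunks, only the first chunk is decoded with use_cp=True, and the conv cache is cleared only on the last chunk.

def fake_cp_chunks(latent_time: int, fake_cp_size: int):
    # Mirrors the loop above: near-equal frame chunks along the latent time axis.
    start = 0
    for i in range(fake_cp_size):
        end = start + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0)
        yield start, end, i == 0, i == fake_cp_size - 1  # (start, end, use_cp, clear_fake_cp_cache)
        start = end

# list(fake_cp_chunks(13, 4)) ->
# [(0, 4, True, False), (4, 7, False, False), (7, 10, False, False), (10, 13, False, True)]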

View File

@@ -101,8 +101,6 @@ def _gather(input_, dim):
group = get_context_parallel_group()
cp_rank = get_context_parallel_rank()
# print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
if cp_rank == 0:
input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
@@ -127,12 +125,9 @@ def _gather(input_, dim):
def _conv_split(input_, dim, kernel_size):
cp_world_size = get_context_parallel_world_size()
# Bypass the function if context parallel is 1
if cp_world_size == 1:
return input_
# print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape)
cp_rank = get_context_parallel_rank()
dim_size = (input_.size()[dim] - kernel_size) // cp_world_size
@@ -140,14 +135,11 @@ def _conv_split(input_, dim, kernel_size):
if cp_rank == 0:
output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
else:
# output = input_.transpose(dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0)
output = input_.transpose(dim, 0)[
cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size
].transpose(dim, 0)
output = output.contiguous()
# print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)
return output
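
A worked example of the split arithmetic above, assuming 17 frames on the split dim, kernel_size = 1 and 4 CP ranks (the helper name is illustrative): rank 0 keeps the leading kernel frames plus its share, every other rank keeps a plain dim_size-sized window offset by kernel_size.

def conv_split_slices(total_frames: int, kernel_size: int, cp_world_size: int):
    # Mirror of the index math in _conv_split.
    dim_size = (total_frames - kernel_size) // cp_world_size
    slices = [(0, dim_size + kernel_size)]  # rank 0
    for r in range(1, cp_world_size):
        slices.append((r * dim_size + kernel_size, (r + 1) * dim_size + kernel_size))
    return slices

# conv_split_slices(17, 1, 4) -> [(0, 5), (5, 9), (9, 13), (13, 17)]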
@@ -160,9 +152,6 @@ def _conv_gather(input_, dim, kernel_size):
group = get_context_parallel_group()
cp_rank = get_context_parallel_rank()
# print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous()
if cp_rank == 0:
input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous()
@@ -255,17 +244,12 @@ def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=None):
if recv_rank % cp_world_size == cp_world_size - 1:
recv_rank += cp_world_size
# req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
# recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous()
# req_recv = torch.distributed.recv(recv_buffer, recv_rank, group=group)
# req_recv.wait()
recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
if cp_rank < cp_world_size - 1:
req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
if cp_rank > 0:
req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
# req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
# req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
if cp_rank == 0:
if cache_padding is not None:
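
A minimal sketch of the non-blocking neighbour exchange above, assuming torch.distributed is initialized and using the default process group (the real code maps context-parallel group ranks to global send/recv ranks): every rank except the last sends its trailing kernel_size - 1 frames forward, every rank except rank 0 receives from the previous rank, and rank 0 falls back to cache_padding (or zeros) instead of the receive buffer.

import torch
import torch.distributed as dist

def pass_tail_to_next_rank(x: torch.Tensor, kernel_size: int) -> torch.Tensor:
    # Sketch of the isend/irecv handshake above; assumes kernel_size >= 2.
    rank, world = dist.get_rank(), dist.get_world_size()
    tail = x[-kernel_size + 1:].contiguous()  # frames handed to the next rank
    recv_buffer = torch.empty_like(tail)
    reqs = []
    if rank < world - 1:
        reqs.append(dist.isend(tail, dst=rank + 1))
    if rank > 0:
        reqs.append(dist.irecv(recv_buffer, src=rank - 1))
    for req in reqs:
        req.wait()  # ensure the buffer is filled before it is consumed
    return recv_buffer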
@@ -421,7 +405,6 @@ class ContextParallelGroupNorm(torch.nn.GroupNorm):
def Normalize(in_channels, gather=False, **kwargs):
# same for 3D and 2D
if gather:
return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
else:
@@ -468,8 +451,8 @@ class SpatialNorm3D(nn.Module):
kernel_size=1,
)
def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp_rank0=True):
if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp_rank0:
def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp=True):
if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp:
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
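
For reference, a sketch of the first-frame handling above, assuming nearest-neighbour resizing for zq (the helper name is illustrative): the leading frame and the remaining frames are resized to their own target sizes and re-joined along the time dim.

import torch
import torch.nn.functional as F

def resize_zq_with_first_frame(f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
    # f, zq: (B, C, T, H, W); the first frame gets its own target size.
    f_first_size, f_rest_size = f[:, :, :1].shape[-3:], f[:, :, 1:].shape[-3:]
    zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
    zq_first = F.interpolate(zq_first, size=f_first_size, mode="nearest")
    zq_rest = F.interpolate(zq_rest, size=f_rest_size, mode="nearest")
    return torch.cat([zq_first, zq_rest], dim=2)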
@@ -531,10 +514,11 @@ class Upsample3D(nn.Module):
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
self.compress_time = compress_time
def forward(self, x, fake_cp_rank0=True):
def forward(self, x, fake_cp=True):
if self.compress_time and x.shape[2] > 1:
if get_context_parallel_rank() == 0 and fake_cp_rank0:
# print(x.shape)
if get_context_parallel_rank() == 0 and fake_cp:
print(x.shape)
breakpoint()
# split first frame
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
@@ -545,8 +529,6 @@ class Upsample3D(nn.Module):
torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
]
x_rest = torch.cat(interpolated_splits, dim=1)
# x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
else:
splits = torch.split(x, 32, dim=1)
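
The interpolation in these branches is chunked: channels are split into groups of 32 and each group is upsampled separately, presumably to bound the peak memory of the nearest-neighbour upsample. A minimal sketch of the pattern (illustrative helper name; works for the 4D and 5D tensors used here):

import torch
import torch.nn.functional as F

def chunked_nearest_upsample(x: torch.Tensor, chunk: int = 32) -> torch.Tensor:
    # Upsample `chunk` channels at a time to keep peak activation memory small.
    splits = torch.split(x, chunk, dim=1)
    return torch.cat(
        [F.interpolate(s, scale_factor=2.0, mode="nearest") for s in splits],
        dim=1,
    )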
@@ -555,13 +537,10 @@
]
x = torch.cat(interpolated_splits, dim=1)
# x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
else:
# only interpolate 2D
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
# x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
splits = torch.split(x, 32, dim=1)
interpolated_splits = [
@@ -590,12 +569,12 @@ class DownSample3D(nn.Module):
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
self.compress_time = compress_time
def forward(self, x, fake_cp_rank0=True):
def forward(self, x, fake_cp=True):
if self.compress_time and x.shape[2] > 1:
h, w = x.shape[-2:]
x = rearrange(x, "b c t h w -> (b h w) c t")
if get_context_parallel_rank() == 0 and fake_cp_rank0:
if get_context_parallel_rank() == 0 and fake_cp:
# split first frame
x_first, x_rest = x[..., 0], x[..., 1:]
@@ -693,17 +672,13 @@ class ContextParallelResnetBlock3D(nn.Module):
padding=0,
)
def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp_rank0=True):
def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp=True):
h = x
# if isinstance(self.norm1, torch.nn.GroupNorm):
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
if zq is not None:
h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=fake_cp)
else:
h = self.norm1(h)
# if isinstance(self.norm1, torch.nn.GroupNorm):
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
h = self.conv1(h, clear_cache=clear_fake_cp_cache)
@@ -711,14 +686,10 @@ class ContextParallelResnetBlock3D(nn.Module):
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
# if isinstance(self.norm2, torch.nn.GroupNorm):
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
if zq is not None:
h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=fake_cp)
else:
h = self.norm2(h)
# if isinstance(self.norm2, torch.nn.GroupNorm):
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
h = self.dropout(h)
@@ -827,32 +798,33 @@ class ContextParallelEncoder3D(nn.Module):
kernel_size=3,
)
def forward(self, x, clear_fake_cp_cache=True, fake_cp_rank0=True):
def forward(self, x, use_cp=True):
global _USE_CP
_USE_CP = use_cp
# timestep embedding
temb = None
# downsampling
h = self.conv_in(x, clear_cache=clear_fake_cp_cache)
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
print("Attention not implemented")
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
h = self.down[i_level].downsample(h, fake_cp_rank0=fake_cp_rank0)
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = self.mid.block_1(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
h = self.mid.block_2(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.block_2(h, temb)
# end
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
h = self.norm_out(h)
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
h = self.conv_out(h, clear_cache=clear_fake_cp_cache)
h = self.conv_out(h)
return h
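
The rewritten encoder forward switches the context-parallel path through a module-level flag (global _USE_CP) instead of threading a per-call argument through every layer. A minimal sketch of that pattern (the helper names here are hypothetical; only _USE_CP comes from the diff):

_USE_CP = True

def set_use_cp(flag: bool) -> None:
    # What `global _USE_CP; _USE_CP = use_cp` does at the top of forward().
    global _USE_CP
    _USE_CP = flag

def cp_enabled() -> bool:
    # Downstream conv/norm layers can consult this instead of a per-call flag.
    return _USE_CP

The trade-off is implicit state: layers no longer need the flag passed explicitly, but concurrent encode/decode calls with different settings would race on the global.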
@@ -895,11 +867,9 @@ class ContextParallelDecoder3D(nn.Module):
zq_ch = z_channels
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
self.conv_in = ContextParallelCausalConv3d(
chan_in=z_channels,
@@ -955,11 +925,6 @@ class ContextParallelDecoder3D(nn.Module):
up.block = block
up.attn = attn
if i_level != 0:
# # Symmetrical enc-dec
if i_level <= self.temporal_compress_level:
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
else:
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
if i_level < self.num_resolutions - self.temporal_compress_level:
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
else:
@@ -974,7 +939,9 @@ class ContextParallelDecoder3D(nn.Module):
kernel_size=3,
)
def forward(self, z, clear_fake_cp_cache=True, fake_cp_rank0=True):
def forward(self, z, clear_fake_cp_cache=True, use_cp=True):
global _USE_CP
_USE_CP = use_cp
self.last_z_shape = z.shape
# timestep embedding
@@ -987,25 +954,25 @@ class ContextParallelDecoder3D(nn.Module):
h = self.conv_in(z, clear_cache=clear_fake_cp_cache)
# middle
h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)
h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](
h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0
h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp
)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0:
h = self.up[i_level].upsample(h, fake_cp_rank0=fake_cp_rank0)
h = self.up[i_level].upsample(h, fake_cp=use_cp)
# end
if self.give_pre_end:
return h
h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)
h = nonlinearity(h)
h = self.conv_out(h, clear_cache=clear_fake_cp_cache)