Mirror of https://github.com/THUDM/CogVideo.git, synced 2025-04-05 19:41:59 +08:00

remove wrong fake_cp

This commit is contained in:
parent d8ee013842
commit e7bcecf947
@@ -192,13 +192,13 @@ class SATVideoDiffusionEngine(nn.Module):
for i in range(fake_cp_size):
    end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0)

    fake_cp_rank0 = True if i == 0 else False
    use_cp = True if i == 0 else False
    clear_fake_cp_cache = True if i == fake_cp_size - 1 else False
    with torch.no_grad():
        recon = self.first_stage_model.decode(
            z_now[:, :, start_frame:end_frame].contiguous(),
            clear_fake_cp_cache=clear_fake_cp_cache,
            fake_cp_rank0=fake_cp_rank0,
            use_cp=use_cp,
        )
    recons.append(recon)
    start_frame = end_frame
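For context, a minimal standalone sketch of the chunking arithmetic in this hunk: the latent frames are split into `fake_cp_size` near-equal chunks, only the chunk holding the first frame takes the special first-frame path, and the causal-conv cache is cleared after the last chunk. `chunked_decode` and `decode_chunk` are hypothetical stand-ins for the repo's `decode_first_stage` loop and `self.first_stage_model.decode`, not part of the codebase.

```python
from typing import Callable, List


def chunked_decode(latent_time: int, fake_cp_size: int,
                   decode_chunk: Callable[[int, int, bool, bool], object]) -> List[object]:
    recons = []
    start_frame = 0
    for i in range(fake_cp_size):
        # Earlier chunks absorb the remainder, so chunk sizes differ by at most one frame.
        end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0)
        use_cp = i == 0                      # only the chunk containing frame 0 takes the first-frame path
        clear_cache = i == fake_cp_size - 1  # free the cached context after the last chunk
        recons.append(decode_chunk(start_frame, end_frame, use_cp, clear_cache))
        start_frame = end_frame
    return recons


# Example: 13 latent frames over 3 chunks -> boundaries (0, 5), (5, 9), (9, 13)
print(chunked_decode(13, 3, lambda s, e, cp, clear: (s, e, cp, clear)))
```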
@@ -101,8 +101,6 @@ def _gather(input_, dim):
group = get_context_parallel_group()
cp_rank = get_context_parallel_rank()

# print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape)

input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
if cp_rank == 0:
    input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
@@ -127,12 +125,9 @@ def _gather(input_, dim):
def _conv_split(input_, dim, kernel_size):
    cp_world_size = get_context_parallel_world_size()

    # Bypass the function if context parallel is 1
    if cp_world_size == 1:
        return input_

    # print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape)

    cp_rank = get_context_parallel_rank()

    dim_size = (input_.size()[dim] - kernel_size) // cp_world_size
@@ -140,14 +135,11 @@ def _conv_split(input_, dim, kernel_size):
if cp_rank == 0:
    output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
else:
    # output = input_.transpose(dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0)
    output = input_.transpose(dim, 0)[
        cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size
    ].transpose(dim, 0)
output = output.contiguous()

# print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)

return output
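A small illustration (not the repo's API) of the slice boundaries `_conv_split` computes along the split dimension: rank 0 keeps its share plus the leading `kernel_size` frames of context, while every other rank takes an equal contiguous share offset by `kernel_size`.

```python
def conv_split_ranges(total_frames: int, cp_world_size: int, kernel_size: int):
    # Per-rank (start, end) index pairs along the temporal dimension.
    dim_size = (total_frames - kernel_size) // cp_world_size
    ranges = []
    for cp_rank in range(cp_world_size):
        if cp_rank == 0:
            ranges.append((0, dim_size + kernel_size))
        else:
            ranges.append((cp_rank * dim_size + kernel_size, (cp_rank + 1) * dim_size + kernel_size))
    return ranges


# e.g. 17 frames, 4 ranks, kernel_size=1 -> [(0, 5), (5, 9), (9, 13), (13, 17)]
print(conv_split_ranges(17, 4, 1))
```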
@@ -160,9 +152,6 @@ def _conv_gather(input_, dim, kernel_size):

group = get_context_parallel_group()
cp_rank = get_context_parallel_rank()

# print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape)

input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous()
if cp_rank == 0:
    input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous()
@@ -255,17 +244,12 @@ def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=Non
if recv_rank % cp_world_size == cp_world_size - 1:
    recv_rank += cp_world_size

# req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
# recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous()
# req_recv = torch.distributed.recv(recv_buffer, recv_rank, group=group)
# req_recv.wait()
recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
if cp_rank < cp_world_size - 1:
    req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
if cp_rank > 0:
    req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
# req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
# req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)

if cp_rank == 0:
    if cache_padding is not None:
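For illustration, a hedged sketch of the neighbour-exchange pattern this hunk uses: each rank posts a non-blocking send of its trailing frames to the next rank and, except for rank 0, a non-blocking receive from the previous one. It assumes `torch.distributed` is already initialized and a temporal kernel larger than 1; the helper name is invented for the example and is not the repo's function.

```python
import torch
import torch.distributed as dist


def pass_tail_to_next_rank(chunk: torch.Tensor, kernel_size: int, rank: int, world_size: int, group=None):
    # Frames the next rank needs as left temporal context (assumes kernel_size > 1).
    tail = chunk[-kernel_size + 1 :].contiguous()
    recv_buffer = torch.empty_like(tail)
    requests = []
    if rank < world_size - 1:  # every rank but the last sends its tail forward
        requests.append(dist.isend(tail, dst=rank + 1, group=group))
    if rank > 0:  # every rank but the first receives the previous rank's tail
        requests.append(dist.irecv(recv_buffer, src=rank - 1, group=group))
    for req in requests:
        req.wait()  # wait only after both requests are posted, so send and recv can overlap
    return recv_buffer if rank > 0 else None
```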
@@ -421,7 +405,6 @@ class ContextParallelGroupNorm(torch.nn.GroupNorm):


def Normalize(in_channels, gather=False, **kwargs):
    # same for 3D and 2D
    if gather:
        return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
    else:
@@ -468,8 +451,8 @@ class SpatialNorm3D(nn.Module):
    kernel_size=1,
)

def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp_rank0=True):
    if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp_rank0:
def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp=True):
    if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp:
        f_first, f_rest = f[:, :, :1], f[:, :, 1:]
        f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
        zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
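As a rough illustration (not the repo's module), the first-frame split that the renamed `fake_cp` flag guards in `SpatialNorm3D`: on the chunk that holds frame 0, `zq` is resized separately for the first frame and for the remaining frames before the two parts are concatenated back along the time axis.

```python
import torch
import torch.nn.functional as F


def split_first_frame_resize(f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
    # f, zq: (batch, channels, time, height, width), with f.shape[2] > 1
    f_first, f_rest = f[:, :, :1], f[:, :, 1:]
    zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
    # Resize zq to match each part of f, then stitch the time axis back together.
    zq_first = F.interpolate(zq_first, size=f_first.shape[-3:], mode="nearest")
    zq_rest = F.interpolate(zq_rest, size=f_rest.shape[-3:], mode="nearest")
    return torch.cat([zq_first, zq_rest], dim=2)
```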
@@ -531,10 +514,11 @@ class Upsample3D(nn.Module):
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
self.compress_time = compress_time

def forward(self, x, fake_cp_rank0=True):
def forward(self, x, fake_cp=True):
    if self.compress_time and x.shape[2] > 1:
        if get_context_parallel_rank() == 0 and fake_cp_rank0:
            # print(x.shape)
        if get_context_parallel_rank() == 0 and fake_cp:
            print(x.shape)
            breakpoint()
            # split first frame
            x_first, x_rest = x[:, :, 0], x[:, :, 1:]

@@ -545,8 +529,6 @@ class Upsample3D(nn.Module):
        torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
    ]
    x_rest = torch.cat(interpolated_splits, dim=1)

    # x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
    x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
else:
    splits = torch.split(x, 32, dim=1)
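The channel-chunked interpolation in these hunks can be read as the following standalone sketch: it is equivalent to one large nearest-neighbour `F.interpolate` call, but split into chunks of 32 channels so the temporary allocation stays small. The function name is illustrative, not from the repo.

```python
import torch
import torch.nn.functional as F


def chunked_upsample(x: torch.Tensor, chunk: int = 32) -> torch.Tensor:
    # x: (batch, channels, height, width); same result as
    # F.interpolate(x, scale_factor=2.0, mode="nearest"), done 32 channels at a time.
    splits = torch.split(x, chunk, dim=1)
    interpolated = [F.interpolate(s, scale_factor=2.0, mode="nearest") for s in splits]
    return torch.cat(interpolated, dim=1)
```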
@@ -555,13 +537,10 @@ class Upsample3D(nn.Module):
    ]
    x = torch.cat(interpolated_splits, dim=1)

    # x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")

else:
    # only interpolate 2D
    t = x.shape[2]
    x = rearrange(x, "b c t h w -> (b t) c h w")
    # x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")

    splits = torch.split(x, 32, dim=1)
    interpolated_splits = [
@@ -590,12 +569,12 @@ class DownSample3D(nn.Module):
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
self.compress_time = compress_time

def forward(self, x, fake_cp_rank0=True):
def forward(self, x, fake_cp=True):
    if self.compress_time and x.shape[2] > 1:
        h, w = x.shape[-2:]
        x = rearrange(x, "b c t h w -> (b h w) c t")

        if get_context_parallel_rank() == 0 and fake_cp_rank0:
        if get_context_parallel_rank() == 0 and fake_cp:
            # split first frame
            x_first, x_rest = x[..., 0], x[..., 1:]

@@ -693,17 +672,13 @@ class ContextParallelResnetBlock3D(nn.Module):
    padding=0,
)

def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp_rank0=True):
def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp=True):
    h = x

    # if isinstance(self.norm1, torch.nn.GroupNorm):
    # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
    if zq is not None:
        h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
        h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=fake_cp)
    else:
        h = self.norm1(h)
    # if isinstance(self.norm1, torch.nn.GroupNorm):
    # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)

    h = nonlinearity(h)
    h = self.conv1(h, clear_cache=clear_fake_cp_cache)
@@ -711,14 +686,10 @@ class ContextParallelResnetBlock3D(nn.Module):
if temb is not None:
    h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]

# if isinstance(self.norm2, torch.nn.GroupNorm):
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
if zq is not None:
    h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
    h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=fake_cp)
else:
    h = self.norm2(h)
# if isinstance(self.norm2, torch.nn.GroupNorm):
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)

h = nonlinearity(h)
h = self.dropout(h)
@@ -827,32 +798,33 @@ class ContextParallelEncoder3D(nn.Module):
    kernel_size=3,
)

def forward(self, x, clear_fake_cp_cache=True, fake_cp_rank0=True):
def forward(self, x, use_cp=True):
    global _USE_CP
    _USE_CP = use_cp

    # timestep embedding
    temb = None

    # downsampling
    h = self.conv_in(x, clear_cache=clear_fake_cp_cache)
    hs = [self.conv_in(x)]
    for i_level in range(self.num_resolutions):
        for i_block in range(self.num_res_blocks):
            h = self.down[i_level].block[i_block](h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
            h = self.down[i_level].block[i_block](hs[-1], temb)
            if len(self.down[i_level].attn) > 0:
                print("Attention not implemented")
                h = self.down[i_level].attn[i_block](h)
            hs.append(h)
        if i_level != self.num_resolutions - 1:
            h = self.down[i_level].downsample(h, fake_cp_rank0=fake_cp_rank0)
            hs.append(self.down[i_level].downsample(hs[-1]))

    # middle
    h = self.mid.block_1(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
    h = self.mid.block_2(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
    h = hs[-1]
    h = self.mid.block_1(h, temb)
    h = self.mid.block_2(h, temb)

    # end
    # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
    h = self.norm_out(h)
    # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)

    h = nonlinearity(h)
    h = self.conv_out(h, clear_cache=clear_fake_cp_cache)
    h = self.conv_out(h)

    return h
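The encoder rewrite in this hunk replaces the single running `h` with a list of intermediate features. A compressed sketch of that pattern, illustrative only: `blocks` and `downsamples` stand in for the repo's `self.down[...]` modules, and the timestep embedding is omitted since it is `None` here anyway.

```python
import torch
import torch.nn as nn


def encode_with_feature_list(x: torch.Tensor, conv_in: nn.Module, blocks, downsamples):
    # blocks: list of lists of residual blocks, one inner list per resolution level
    # downsamples: one downsampling module per level except the last
    hs = [conv_in(x)]                      # keep every intermediate feature
    for i_level, level_blocks in enumerate(blocks):
        for block in level_blocks:
            hs.append(block(hs[-1]))       # each block consumes the latest feature
        if i_level != len(blocks) - 1:
            hs.append(downsamples[i_level](hs[-1]))
    return hs[-1]                          # the deepest feature feeds the middle blocks
```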
@@ -895,11 +867,9 @@ class ContextParallelDecoder3D(nn.Module):
zq_ch = z_channels

# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))

self.conv_in = ContextParallelCausalConv3d(
    chan_in=z_channels,
@@ -955,11 +925,6 @@ class ContextParallelDecoder3D(nn.Module):
up.block = block
up.attn = attn
if i_level != 0:
    # # Symmetrical enc-dec
    if i_level <= self.temporal_compress_level:
        up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
    else:
        up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
    if i_level < self.num_resolutions - self.temporal_compress_level:
        up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
    else:
@@ -974,7 +939,9 @@ class ContextParallelDecoder3D(nn.Module):
    kernel_size=3,
)

def forward(self, z, clear_fake_cp_cache=True, fake_cp_rank0=True):
def forward(self, z, clear_fake_cp_cache=True, use_cp=True):
    global _USE_CP
    _USE_CP = use_cp
    self.last_z_shape = z.shape

    # timestep embedding
@@ -987,25 +954,25 @@ class ContextParallelDecoder3D(nn.Module):
h = self.conv_in(z, clear_cache=clear_fake_cp_cache)

# middle
h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)
h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)

# upsampling
for i_level in reversed(range(self.num_resolutions)):
    for i_block in range(self.num_res_blocks + 1):
        h = self.up[i_level].block[i_block](
            h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0
            h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp
        )
        if len(self.up[i_level].attn) > 0:
            h = self.up[i_level].attn[i_block](h, zq)
    if i_level != 0:
        h = self.up[i_level].upsample(h, fake_cp_rank0=fake_cp_rank0)
        h = self.up[i_level].upsample(h, fake_cp=use_cp)

# end
if self.give_pre_end:
    return h

h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)
h = nonlinearity(h)
h = self.conv_out(h, clear_cache=clear_fake_cp_cache)