Update cp_enc_dec.py

This commit is contained in:
zR 2024-11-08 23:27:44 +08:00
parent e7bcecf947
commit a8205b575d

View File

@ -517,11 +517,8 @@ class Upsample3D(nn.Module):
def forward(self, x, fake_cp=True):
if self.compress_time and x.shape[2] > 1:
if get_context_parallel_rank() == 0 and fake_cp:
print(x.shape)
breakpoint()
# split first frame
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
splits = torch.split(x_rest, 32, dim=1)