mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
Update cp_enc_dec.py
This commit is contained in:
parent
e7bcecf947
commit
a8205b575d
@ -517,11 +517,8 @@ class Upsample3D(nn.Module):
|
|||||||
def forward(self, x, fake_cp=True):
|
def forward(self, x, fake_cp=True):
|
||||||
if self.compress_time and x.shape[2] > 1:
|
if self.compress_time and x.shape[2] > 1:
|
||||||
if get_context_parallel_rank() == 0 and fake_cp:
|
if get_context_parallel_rank() == 0 and fake_cp:
|
||||||
print(x.shape)
|
|
||||||
breakpoint()
|
|
||||||
# split first frame
|
# split first frame
|
||||||
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
|
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
|
||||||
|
|
||||||
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
|
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
|
||||||
|
|
||||||
splits = torch.split(x_rest, 32, dim=1)
|
splits = torch.split(x_rest, 32, dim=1)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user