mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-06 03:57:56 +08:00
🐛 fix frame padding bug
This commit is contained in:
parent
da1af26d57
commit
b10e444d94
@ -145,9 +145,10 @@ def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
|
|||||||
|
|
||||||
def pad_last_frame(tensor, num_frames):
|
def pad_last_frame(tensor, num_frames):
|
||||||
# T, H, W, C
|
# T, H, W, C
|
||||||
if tensor.shape[0] < num_frames:
|
if len(tensor) < num_frames:
|
||||||
last_frame = tensor[-int(num_frames - tensor.shape[1]) :]
|
pad_length = num_frames - len(tensor)
|
||||||
padded_tensor = torch.cat([tensor, last_frame], dim=0)
|
pad_tensor = torch.zeros([pad_length, *tensor.shape[1:]], dtype=tensor.dtype, device=tensor.device)
|
||||||
|
padded_tensor = torch.cat([tensor, pad_tensor], dim=0)
|
||||||
return padded_tensor
|
return padded_tensor
|
||||||
else:
|
else:
|
||||||
return tensor[:num_frames]
|
return tensor[:num_frames]
|
||||||
@ -418,7 +419,7 @@ class SFTDataset(Dataset):
|
|||||||
)
|
)
|
||||||
|
|
||||||
tensor_frms = pad_last_frame(
|
tensor_frms = pad_last_frame(
|
||||||
tensor_frms, num_frames
|
tensor_frms, max_num_frames
|
||||||
) # the len of indices may be less than num_frames, due to round error
|
) # the len of indices may be less than num_frames, due to round error
|
||||||
tensor_frms = tensor_frms.permute(0, 3, 1, 2) # [T, H, W, C] -> [T, C, H, W]
|
tensor_frms = tensor_frms.permute(0, 3, 1, 2) # [T, H, W, C] -> [T, C, H, W]
|
||||||
tensor_frms = resize_for_rectangle_crop(tensor_frms, video_size, reshape_mode="center")
|
tensor_frms = resize_for_rectangle_crop(tensor_frms, video_size, reshape_mode="center")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user