mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
🐛 fix frame padding bug
This commit is contained in:
parent
da1af26d57
commit
b10e444d94
@ -145,9 +145,10 @@ def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
|
||||
|
||||
def pad_last_frame(tensor, num_frames):
|
||||
# T, H, W, C
|
||||
if tensor.shape[0] < num_frames:
|
||||
last_frame = tensor[-int(num_frames - tensor.shape[1]) :]
|
||||
padded_tensor = torch.cat([tensor, last_frame], dim=0)
|
||||
if len(tensor) < num_frames:
|
||||
pad_length = num_frames - len(tensor)
|
||||
pad_tensor = torch.zeros([pad_length, *tensor.shape[1:]], dtype=tensor.dtype, device=tensor.device)
|
||||
padded_tensor = torch.cat([tensor, pad_tensor], dim=0)
|
||||
return padded_tensor
|
||||
else:
|
||||
return tensor[:num_frames]
|
||||
@ -418,7 +419,7 @@ class SFTDataset(Dataset):
|
||||
)
|
||||
|
||||
tensor_frms = pad_last_frame(
|
||||
tensor_frms, num_frames
|
||||
tensor_frms, max_num_frames
|
||||
) # the len of indices may be less than num_frames, due to round error
|
||||
tensor_frms = tensor_frms.permute(0, 3, 1, 2) # [T, H, W, C] -> [T, C, H, W]
|
||||
tensor_frms = resize_for_rectangle_crop(tensor_frms, video_size, reshape_mode="center")
|
||||
|
Loading…
x
Reference in New Issue
Block a user