mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-06 03:57:56 +08:00
commit
2825d9b707
@ -378,8 +378,9 @@ class SFTDataset(Dataset):
|
||||
num_frames = max_num_frames
|
||||
start = int(skip_frms_num)
|
||||
end = int(start + num_frames / fps * actual_fps)
|
||||
indices = np.arange(start, end, (end - start) / num_frames).astype(int)
|
||||
temp_frms = vr.get_batch(np.arange(start, end))
|
||||
end_safty = min(int(start + num_frames / fps * actual_fps), int(ori_vlen))
|
||||
indices = np.arange(start, end, (end - start) // num_frames).astype(int)
|
||||
temp_frms = vr.get_batch(np.arange(start, end_safty))
|
||||
assert temp_frms is not None
|
||||
tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
|
||||
tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
|
||||
@ -388,7 +389,7 @@ class SFTDataset(Dataset):
|
||||
num_frames = max_num_frames
|
||||
start = int(skip_frms_num)
|
||||
end = int(ori_vlen - skip_frms_num)
|
||||
indices = np.arange(start, end, (end - start) / num_frames).astype(int)
|
||||
indices = np.arange(start, end, (end - start) // num_frames).astype(int)
|
||||
temp_frms = vr.get_batch(np.arange(start, end))
|
||||
assert temp_frms is not None
|
||||
tensor_frms = (
|
||||
|
Loading…
x
Reference in New Issue
Block a user