Add type conversion and validation checks

This commit is contained in:
OleehyO 2024-12-30 06:53:23 +00:00
parent fa4659fb2c
commit 2a6cca0656

View File

@ -33,7 +33,7 @@ def load_images(image_path: Path) -> List[Path]:
########## preprocessors ##########
def preprocess_image_with_resize(
image_path: Path,
image_path: Path | str,
height: int,
width: int,
) -> torch.Tensor:
@ -51,6 +51,8 @@ def preprocess_image_with_resize(
H = height
W = width
"""
if isinstance(image_path, str):
image_path = Path(image_path)
image = cv2.imread(image_path.as_posix())
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = cv2.resize(image, (width, height))
@ -60,7 +62,7 @@ def preprocess_image_with_resize(
def preprocess_video_with_resize(
video_path: Path,
video_path: Path | str,
max_num_frames: int,
height: int,
width: int,
@ -85,8 +87,12 @@ def preprocess_video_with_resize(
H = height
W = width
"""
if isinstance(video_path, str):
video_path = Path(video_path)
video_reader = decord.VideoReader(uri=video_path.as_posix(), width=width, height=height)
video_num_frames = len(video_reader)
if video_num_frames < max_num_frames:
raise ValueError(f"video's frames is less than {max_num_frames}.")
indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
frames = video_reader.get_batch(indices)