From 2a6cca0656e1c91db466d77ec58cd77af1ffc510 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Mon, 30 Dec 2024 06:53:23 +0000 Subject: [PATCH] Add type conversion and validation checks --- finetune/datasets/utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/finetune/datasets/utils.py b/finetune/datasets/utils.py index cf0525e..ba7bddf 100644 --- a/finetune/datasets/utils.py +++ b/finetune/datasets/utils.py @@ -33,7 +33,7 @@ def load_images(image_path: Path) -> List[Path]: ########## preprocessors ########## def preprocess_image_with_resize( - image_path: Path, + image_path: Path | str, height: int, width: int, ) -> torch.Tensor: @@ -51,6 +51,8 @@ def preprocess_image_with_resize( H = height W = width """ + if isinstance(image_path, str): + image_path = Path(image_path) image = cv2.imread(image_path.as_posix()) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) image = cv2.resize(image, (width, height)) @@ -60,7 +62,7 @@ def preprocess_image_with_resize( def preprocess_video_with_resize( - video_path: Path, + video_path: Path | str, max_num_frames: int, height: int, width: int, @@ -85,8 +87,12 @@ def preprocess_video_with_resize( H = height W = width """ + if isinstance(video_path, str): + video_path = Path(video_path) video_reader = decord.VideoReader(uri=video_path.as_posix(), width=width, height=height) video_num_frames = len(video_reader) + if video_num_frames < max_num_frames: + raise ValueError(f"video's frames is less than {max_num_frames}.") indices = list(range(0, video_num_frames, video_num_frames // max_num_frames)) frames = video_reader.get_batch(indices)