mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
Add type conversion and validation checks
This commit is contained in:
parent
fa4659fb2c
commit
2a6cca0656
@ -33,7 +33,7 @@ def load_images(image_path: Path) -> List[Path]:
|
||||
########## preprocessors ##########
|
||||
|
||||
def preprocess_image_with_resize(
|
||||
image_path: Path,
|
||||
image_path: Path | str,
|
||||
height: int,
|
||||
width: int,
|
||||
) -> torch.Tensor:
|
||||
@ -51,6 +51,8 @@ def preprocess_image_with_resize(
|
||||
H = height
|
||||
W = width
|
||||
"""
|
||||
if isinstance(image_path, str):
|
||||
image_path = Path(image_path)
|
||||
image = cv2.imread(image_path.as_posix())
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
image = cv2.resize(image, (width, height))
|
||||
@ -60,7 +62,7 @@ def preprocess_image_with_resize(
|
||||
|
||||
|
||||
def preprocess_video_with_resize(
|
||||
video_path: Path,
|
||||
video_path: Path | str,
|
||||
max_num_frames: int,
|
||||
height: int,
|
||||
width: int,
|
||||
@ -85,8 +87,12 @@ def preprocess_video_with_resize(
|
||||
H = height
|
||||
W = width
|
||||
"""
|
||||
if isinstance(video_path, str):
|
||||
video_path = Path(video_path)
|
||||
video_reader = decord.VideoReader(uri=video_path.as_posix(), width=width, height=height)
|
||||
video_num_frames = len(video_reader)
|
||||
if video_num_frames < max_num_frames:
|
||||
raise ValueError(f"video's frames is less than {max_num_frames}.")
|
||||
|
||||
indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
|
||||
frames = video_reader.get_batch(indices)
|
||||
|
Loading…
x
Reference in New Issue
Block a user