mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-06 03:57:56 +08:00
Add type conversion and validation checks
This commit is contained in:
parent
fa4659fb2c
commit
2a6cca0656
@ -33,7 +33,7 @@ def load_images(image_path: Path) -> List[Path]:
|
|||||||
########## preprocessors ##########
|
########## preprocessors ##########
|
||||||
|
|
||||||
def preprocess_image_with_resize(
|
def preprocess_image_with_resize(
|
||||||
image_path: Path,
|
image_path: Path | str,
|
||||||
height: int,
|
height: int,
|
||||||
width: int,
|
width: int,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@ -51,6 +51,8 @@ def preprocess_image_with_resize(
|
|||||||
H = height
|
H = height
|
||||||
W = width
|
W = width
|
||||||
"""
|
"""
|
||||||
|
if isinstance(image_path, str):
|
||||||
|
image_path = Path(image_path)
|
||||||
image = cv2.imread(image_path.as_posix())
|
image = cv2.imread(image_path.as_posix())
|
||||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||||
image = cv2.resize(image, (width, height))
|
image = cv2.resize(image, (width, height))
|
||||||
@ -60,7 +62,7 @@ def preprocess_image_with_resize(
|
|||||||
|
|
||||||
|
|
||||||
def preprocess_video_with_resize(
|
def preprocess_video_with_resize(
|
||||||
video_path: Path,
|
video_path: Path | str,
|
||||||
max_num_frames: int,
|
max_num_frames: int,
|
||||||
height: int,
|
height: int,
|
||||||
width: int,
|
width: int,
|
||||||
@ -85,8 +87,12 @@ def preprocess_video_with_resize(
|
|||||||
H = height
|
H = height
|
||||||
W = width
|
W = width
|
||||||
"""
|
"""
|
||||||
|
if isinstance(video_path, str):
|
||||||
|
video_path = Path(video_path)
|
||||||
video_reader = decord.VideoReader(uri=video_path.as_posix(), width=width, height=height)
|
video_reader = decord.VideoReader(uri=video_path.as_posix(), width=width, height=height)
|
||||||
video_num_frames = len(video_reader)
|
video_num_frames = len(video_reader)
|
||||||
|
if video_num_frames < max_num_frames:
|
||||||
|
raise ValueError(f"video's frames is less than {max_num_frames}.")
|
||||||
|
|
||||||
indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
|
indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
|
||||||
frames = video_reader.get_batch(indices)
|
frames = video_reader.get_batch(indices)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user