From de5bef66112d01a7e8c11f257ebfa8f17c8420c3 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Fri, 3 Jan 2025 08:48:42 +0000 Subject: [PATCH] feat(args): add train_resolution validation for video frames and resolution - Add validation to ensure (frames - 1) is multiple of 8 - Add specific resolution check (480x720) for cogvideox-5b models - Add error handling for invalid resolution format --- finetune/schemas/args.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/finetune/schemas/args.py b/finetune/schemas/args.py index fab1139..e96ce91 100644 --- a/finetune/schemas/args.py +++ b/finetune/schemas/args.py @@ -133,14 +133,19 @@ class Args(BaseModel): return v @field_validator("train_resolution") - def validate_train_resolution(cls, v: str, info: ValidationInfo) -> str: + def validate_train_resolution(cls, v: Tuple[int, int, int], info: ValidationInfo) -> str: try: - # Parse resolution string "frames x height x width" - frames, height, width = map(int, v.split("x")) + frames, height, width = v - # Check if frames is multiple of 8 - if frames % 8 != 0: - raise ValueError("Number of frames must be a multiple of 8") + # Check if (frames - 1) is multiple of 8 + if (frames - 1) % 8 != 0: + raise ValueError("Number of frames - 1 must be a multiple of 8") + + # Check resolution for cogvideox-5b models + model_name = info.data.get("model_name", "") + if model_name in ["cogvideox-5b-i2v", "cogvideox-5b-t2v"]: + if (height, width) != (480, 720): + raise ValueError("For cogvideox-5b models, height must be 480 and width must be 720") return v