feat(args): add train_resolution validation for video frames and resolution

- Add validation to ensure (frames - 1) is multiple of 8
- Add specific resolution check (480x720) for cogvideox-5b models
- Add error handling for invalid resolution format
This commit is contained in:
OleehyO 2025-01-03 08:48:42 +00:00
parent ffb6ee36b4
commit de5bef6611

View File

@ -133,14 +133,19 @@ class Args(BaseModel):
return v
@field_validator("train_resolution")
def validate_train_resolution(cls, v: str, info: ValidationInfo) -> str:
def validate_train_resolution(cls, v: Tuple[int, int, int], info: ValidationInfo) -> str:
try:
# Parse resolution string "frames x height x width"
frames, height, width = map(int, v.split("x"))
frames, height, width = v
# Check if frames is multiple of 8
if frames % 8 != 0:
raise ValueError("Number of frames must be a multiple of 8")
# Check if (frames - 1) is multiple of 8
if (frames - 1) % 8 != 0:
raise ValueError("Number of frames - 1 must be a multiple of 8")
# Check resolution for cogvideox-5b models
model_name = info.data.get("model_name", "")
if model_name in ["cogvideox-5b-i2v", "cogvideox-5b-t2v"]:
if (height, width) != (480, 720):
raise ValueError("For cogvideox-5b models, height must be 480 and width must be 720")
return v