feat(args): add validation for training resolution

- Add validation check to ensure number of frames is multiple of 8
- Add format validation for train_resolution string (frames x height x width)
This commit is contained in:
OleehyO 2025-01-02 03:12:09 +00:00
parent 362b7bf273
commit a88c1ede69

View File

@ -131,6 +131,24 @@ class Args(BaseModel):
if values.get("checkpointing_steps") and v % values["checkpointing_steps"] != 0:
raise ValueError("validation_steps must be a multiple of checkpointing_steps")
return v
@field_validator("train_resolution")
def validate_train_resolution(cls, v: str, info: ValidationInfo) -> str:
try:
# Parse resolution string "frames x height x width"
frames, height, width = map(int, v.split("x"))
# Check if frames is multiple of 8
if frames % 8 != 0:
raise ValueError("Number of frames must be a multiple of 8")
return v
except ValueError as e:
if str(e) == "not enough values to unpack (expected 3, got 0)" or \
str(e) == "invalid literal for int() with base 10":
raise ValueError("train_resolution must be in format 'frames x height x width'")
raise e
@classmethod