mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-06-14 20:09:18 +08:00
feat(args): add validation for training resolution
- Add validation check to ensure number of frames is multiple of 8 - Add format validation for train_resolution string (frames x height x width)
This commit is contained in:
parent
362b7bf273
commit
a88c1ede69
@ -131,6 +131,24 @@ class Args(BaseModel):
|
||||
if values.get("checkpointing_steps") and v % values["checkpointing_steps"] != 0:
|
||||
raise ValueError("validation_steps must be a multiple of checkpointing_steps")
|
||||
return v
|
||||
|
||||
@field_validator("train_resolution")
|
||||
def validate_train_resolution(cls, v: str, info: ValidationInfo) -> str:
|
||||
try:
|
||||
# Parse resolution string "frames x height x width"
|
||||
frames, height, width = map(int, v.split("x"))
|
||||
|
||||
# Check if frames is multiple of 8
|
||||
if frames % 8 != 0:
|
||||
raise ValueError("Number of frames must be a multiple of 8")
|
||||
|
||||
return v
|
||||
|
||||
except ValueError as e:
|
||||
if str(e) == "not enough values to unpack (expected 3, got 0)" or \
|
||||
str(e) == "invalid literal for int() with base 10":
|
||||
raise ValueError("train_resolution must be in format 'frames x height x width'")
|
||||
raise e
|
||||
|
||||
|
||||
@classmethod
|
||||
|
Loading…
x
Reference in New Issue
Block a user