mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-09-19 20:33:13 +08:00
feat(args): add validation for training resolution
- Add validation check to ensure number of frames is multiple of 8 - Add format validation for train_resolution string (frames x height x width)
This commit is contained in:
parent
362b7bf273
commit
a88c1ede69
@ -131,6 +131,24 @@ class Args(BaseModel):
|
|||||||
if values.get("checkpointing_steps") and v % values["checkpointing_steps"] != 0:
|
if values.get("checkpointing_steps") and v % values["checkpointing_steps"] != 0:
|
||||||
raise ValueError("validation_steps must be a multiple of checkpointing_steps")
|
raise ValueError("validation_steps must be a multiple of checkpointing_steps")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
@field_validator("train_resolution")
|
||||||
|
def validate_train_resolution(cls, v: str, info: ValidationInfo) -> str:
|
||||||
|
try:
|
||||||
|
# Parse resolution string "frames x height x width"
|
||||||
|
frames, height, width = map(int, v.split("x"))
|
||||||
|
|
||||||
|
# Check if frames is multiple of 8
|
||||||
|
if frames % 8 != 0:
|
||||||
|
raise ValueError("Number of frames must be a multiple of 8")
|
||||||
|
|
||||||
|
return v
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
if str(e) == "not enough values to unpack (expected 3, got 0)" or \
|
||||||
|
str(e) == "invalid literal for int() with base 10":
|
||||||
|
raise ValueError("train_resolution must be in format 'frames x height x width'")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
Loading…
x
Reference in New Issue
Block a user