mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-09-21 22:00:00 +08:00
feat(args): add train_resolution validation for video frames and resolution
- Add validation to ensure (frames - 1) is multiple of 8 - Add specific resolution check (480x720) for cogvideox-5b models - Add error handling for invalid resolution format
This commit is contained in:
parent
ffb6ee36b4
commit
de5bef6611
@ -133,14 +133,19 @@ class Args(BaseModel):
|
||||
return v
|
||||
|
||||
@field_validator("train_resolution")
|
||||
def validate_train_resolution(cls, v: str, info: ValidationInfo) -> str:
|
||||
def validate_train_resolution(cls, v: Tuple[int, int, int], info: ValidationInfo) -> str:
|
||||
try:
|
||||
# Parse resolution string "frames x height x width"
|
||||
frames, height, width = map(int, v.split("x"))
|
||||
frames, height, width = v
|
||||
|
||||
# Check if frames is multiple of 8
|
||||
if frames % 8 != 0:
|
||||
raise ValueError("Number of frames must be a multiple of 8")
|
||||
# Check if (frames - 1) is multiple of 8
|
||||
if (frames - 1) % 8 != 0:
|
||||
raise ValueError("Number of frames - 1 must be a multiple of 8")
|
||||
|
||||
# Check resolution for cogvideox-5b models
|
||||
model_name = info.data.get("model_name", "")
|
||||
if model_name in ["cogvideox-5b-i2v", "cogvideox-5b-t2v"]:
|
||||
if (height, width) != (480, 720):
|
||||
raise ValueError("For cogvideox-5b models, height must be 480 and width must be 720")
|
||||
|
||||
return v
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user