diff --git a/finetune/schemas/args.py b/finetune/schemas/args.py index 4745f7c..dca76e3 100644 --- a/finetune/schemas/args.py +++ b/finetune/schemas/args.py @@ -1,7 +1,7 @@ import datetime import argparse from typing import Dict, Any, Literal, List, Tuple -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, field_validator, ValidationInfo from pathlib import Path @@ -30,7 +30,7 @@ class Args(BaseModel): seed: int | None = None train_epochs: int train_steps: int | None = None - checkpointing_steps: int = 500 + checkpointing_steps: int = 200 checkpointing_limit: int = 10 batch_size: int @@ -93,42 +93,114 @@ class Args(BaseModel): # The height and width set in the bucket must be an integer multiple of 8 (temporal_compression_rate[8]) # gen_video_resolution: Tuple[int, int, int] | None # shape: (frames, height, width) + @field_validator("image_column") + def validate_image_column(cls, v: str | None, info: ValidationInfo) -> str | None: + values = info.data + if values.get("model_type") == "i2v" and not v: + raise ValueError("image_column must be specified when using i2v model") + return v + + @field_validator("validation_dir", "validation_prompts") + def validate_validation_required_fields(cls, v: Any, info: ValidationInfo) -> Any: + values = info.data + if values.get("do_validation") and not v: + field_name = info.field_name + raise ValueError(f"{field_name} must be specified when do_validation is True") + return v + + @field_validator("validation_images") + def validate_validation_images(cls, v: str | None, info: ValidationInfo) -> str | None: + values = info.data + if values.get("do_validation") and values.get("model_type") == "i2v" and not v: + raise ValueError("validation_images must be specified when do_validation is True and model_type is i2v") + return v + + @field_validator("validation_videos") + def validate_validation_videos(cls, v: str | None, info: ValidationInfo) -> str | None: + values = info.data + if values.get("do_validation") and values.get("model_type") == "v2v" and not v: + raise ValueError("validation_videos must be specified when do_validation is True and model_type is v2v") + return v + + @field_validator("validation_steps") + def validate_validation_steps(cls, v: int | None, info: ValidationInfo) -> int | None: + values = info.data + if values.get("do_validation"): + if v is None: + raise ValueError("validation_steps must be specified when do_validation is True") + if values.get("checkpointing_steps") and v % values["checkpointing_steps"] != 0: + raise ValueError("validation_steps must be a multiple of checkpointing_steps") + return v + @classmethod def parse_args(cls): """Parse command line arguments and return Args instance""" parser = argparse.ArgumentParser() + # Required arguments parser.add_argument("--model_path", type=str, required=True) parser.add_argument("--model_name", type=str, required=True) parser.add_argument("--model_type", type=str, required=True) parser.add_argument("--training_type", type=str, required=True) parser.add_argument("--output_dir", type=str, required=True) - parser.add_argument("--seed", type=int, required=True) - parser.add_argument("--nccl_timeout", type=int, required=True) - parser.add_argument("--mixed_precision", type=str, required=True) - parser.add_argument("--gradient_accumulation_steps", type=int, required=True) parser.add_argument("--data_root", type=str, required=True) parser.add_argument("--caption_column", type=str, required=True) parser.add_argument("--video_column", type=str, required=True) - parser.add_argument("--image_column", type=str) parser.add_argument("--train_resolution", type=str, required=True) - parser.add_argument("--batch_size", type=int, required=True) - parser.add_argument("--num_workers", type=int, required=True) - parser.add_argument("--pin_memory", type=str, required=True) parser.add_argument("--report_to", type=str, required=True) - parser.add_argument("--train_epochs", type=int, required=True) - parser.add_argument("--checkpointing_steps", type=int, required=True) - parser.add_argument("--checkpointing_limit", type=int, required=True) - parser.add_argument("--do_validation", type=bool) - parser.add_argument("--validation_steps", type=int) - parser.add_argument("--validation_dir", type=str) - parser.add_argument("--validation_prompts", type=str) - parser.add_argument("--validation_images", type=str) - parser.add_argument("--validation_videos", type=str) - parser.add_argument("--gen_fps", type=int) + # Training hyperparameters + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--train_epochs", type=int, default=10) + parser.add_argument("--train_steps", type=int, default=None) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--learning_rate", type=float, default=2e-5) + parser.add_argument("--optimizer", type=str, default="adamw") + parser.add_argument("--beta1", type=float, default=0.9) + parser.add_argument("--beta2", type=float, default=0.95) + parser.add_argument("--beta3", type=float, default=0.98) + parser.add_argument("--epsilon", type=float, default=1e-8) + parser.add_argument("--weight_decay", type=float, default=1e-4) + parser.add_argument("--max_grad_norm", type=float, default=1.0) - parser.add_argument("--resume_from_checkpoint", type=str) + # Learning rate scheduler + parser.add_argument("--lr_scheduler", type=str, default="constant_with_warmup") + parser.add_argument("--lr_warmup_steps", type=int, default=100) + parser.add_argument("--lr_num_cycles", type=int, default=1) + parser.add_argument("--lr_power", type=float, default=1.0) + + # Data loading + parser.add_argument("--num_workers", type=int, default=8) + parser.add_argument("--pin_memory", type=bool, default=True) + parser.add_argument("--image_column", type=str, default=None) + + # Model configuration + parser.add_argument("--mixed_precision", type=str, default="no") + parser.add_argument("--gradient_checkpointing", type=bool, default=True) + parser.add_argument("--enable_slicing", type=bool, default=True) + parser.add_argument("--enable_tiling", type=bool, default=True) + parser.add_argument("--nccl_timeout", type=int, default=1800) + + # LoRA parameters + parser.add_argument("--rank", type=int, default=128) + parser.add_argument("--lora_alpha", type=int, default=64) + parser.add_argument("--target_modules", type=str, nargs="+", + default=["to_q", "to_k", "to_v", "to_out.0"]) + + # Checkpointing + parser.add_argument("--checkpointing_steps", type=int, default=200) + parser.add_argument("--checkpointing_limit", type=int, default=10) + parser.add_argument("--resume_from_checkpoint", type=str, default=None) + + # Validation + parser.add_argument("--do_validation", type=bool, default=False) + parser.add_argument("--validation_steps", type=int, default=None) + parser.add_argument("--validation_dir", type=str, default=None) + parser.add_argument("--validation_prompts", type=str, default=None) + parser.add_argument("--validation_images", type=str, default=None) + parser.add_argument("--validation_videos", type=str, default=None) + parser.add_argument("--gen_fps", type=int, default=15) args = parser.parse_args() @@ -137,11 +209,3 @@ class Args(BaseModel): args.train_resolution = (int(frames), int(height), int(width)) return cls(**vars(args)) - - # @field_validator("...", mode="after") - # def foo(cls, foobar): - # ... - - # @field_validator("...", mode="before") - # def bar(cls, barbar): - # ...