mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-06 03:57:56 +08:00
feat(args): add validation and arg interface for training parameters
- Add field validators for model type and validation settings - Implement command line argument parsing with argparse - Add type hints and documentation for training parameters - Support configuration of model, training, and validation parameters
This commit is contained in:
parent
04a60e7435
commit
26b87cd4ff
@ -1,7 +1,7 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import argparse
|
import argparse
|
||||||
from typing import Dict, Any, Literal, List, Tuple
|
from typing import Dict, Any, Literal, List, Tuple
|
||||||
from pydantic import BaseModel, field_validator
|
from pydantic import BaseModel, field_validator, ValidationInfo
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@ -30,7 +30,7 @@ class Args(BaseModel):
|
|||||||
seed: int | None = None
|
seed: int | None = None
|
||||||
train_epochs: int
|
train_epochs: int
|
||||||
train_steps: int | None = None
|
train_steps: int | None = None
|
||||||
checkpointing_steps: int = 500
|
checkpointing_steps: int = 200
|
||||||
checkpointing_limit: int = 10
|
checkpointing_limit: int = 10
|
||||||
|
|
||||||
batch_size: int
|
batch_size: int
|
||||||
@ -93,42 +93,114 @@ class Args(BaseModel):
|
|||||||
# The height and width set in the bucket must be an integer multiple of 8 (temporal_compression_rate[8])
|
# The height and width set in the bucket must be an integer multiple of 8 (temporal_compression_rate[8])
|
||||||
# gen_video_resolution: Tuple[int, int, int] | None # shape: (frames, height, width)
|
# gen_video_resolution: Tuple[int, int, int] | None # shape: (frames, height, width)
|
||||||
|
|
||||||
|
@field_validator("image_column")
|
||||||
|
def validate_image_column(cls, v: str | None, info: ValidationInfo) -> str | None:
|
||||||
|
values = info.data
|
||||||
|
if values.get("model_type") == "i2v" and not v:
|
||||||
|
raise ValueError("image_column must be specified when using i2v model")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("validation_dir", "validation_prompts")
|
||||||
|
def validate_validation_required_fields(cls, v: Any, info: ValidationInfo) -> Any:
|
||||||
|
values = info.data
|
||||||
|
if values.get("do_validation") and not v:
|
||||||
|
field_name = info.field_name
|
||||||
|
raise ValueError(f"{field_name} must be specified when do_validation is True")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("validation_images")
|
||||||
|
def validate_validation_images(cls, v: str | None, info: ValidationInfo) -> str | None:
|
||||||
|
values = info.data
|
||||||
|
if values.get("do_validation") and values.get("model_type") == "i2v" and not v:
|
||||||
|
raise ValueError("validation_images must be specified when do_validation is True and model_type is i2v")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("validation_videos")
|
||||||
|
def validate_validation_videos(cls, v: str | None, info: ValidationInfo) -> str | None:
|
||||||
|
values = info.data
|
||||||
|
if values.get("do_validation") and values.get("model_type") == "v2v" and not v:
|
||||||
|
raise ValueError("validation_videos must be specified when do_validation is True and model_type is v2v")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("validation_steps")
|
||||||
|
def validate_validation_steps(cls, v: int | None, info: ValidationInfo) -> int | None:
|
||||||
|
values = info.data
|
||||||
|
if values.get("do_validation"):
|
||||||
|
if v is None:
|
||||||
|
raise ValueError("validation_steps must be specified when do_validation is True")
|
||||||
|
if values.get("checkpointing_steps") and v % values["checkpointing_steps"] != 0:
|
||||||
|
raise ValueError("validation_steps must be a multiple of checkpointing_steps")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parse_args(cls):
|
def parse_args(cls):
|
||||||
"""Parse command line arguments and return Args instance"""
|
"""Parse command line arguments and return Args instance"""
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
# Required arguments
|
||||||
parser.add_argument("--model_path", type=str, required=True)
|
parser.add_argument("--model_path", type=str, required=True)
|
||||||
parser.add_argument("--model_name", type=str, required=True)
|
parser.add_argument("--model_name", type=str, required=True)
|
||||||
parser.add_argument("--model_type", type=str, required=True)
|
parser.add_argument("--model_type", type=str, required=True)
|
||||||
parser.add_argument("--training_type", type=str, required=True)
|
parser.add_argument("--training_type", type=str, required=True)
|
||||||
parser.add_argument("--output_dir", type=str, required=True)
|
parser.add_argument("--output_dir", type=str, required=True)
|
||||||
parser.add_argument("--seed", type=int, required=True)
|
|
||||||
parser.add_argument("--nccl_timeout", type=int, required=True)
|
|
||||||
parser.add_argument("--mixed_precision", type=str, required=True)
|
|
||||||
parser.add_argument("--gradient_accumulation_steps", type=int, required=True)
|
|
||||||
parser.add_argument("--data_root", type=str, required=True)
|
parser.add_argument("--data_root", type=str, required=True)
|
||||||
parser.add_argument("--caption_column", type=str, required=True)
|
parser.add_argument("--caption_column", type=str, required=True)
|
||||||
parser.add_argument("--video_column", type=str, required=True)
|
parser.add_argument("--video_column", type=str, required=True)
|
||||||
parser.add_argument("--image_column", type=str)
|
|
||||||
parser.add_argument("--train_resolution", type=str, required=True)
|
parser.add_argument("--train_resolution", type=str, required=True)
|
||||||
parser.add_argument("--batch_size", type=int, required=True)
|
|
||||||
parser.add_argument("--num_workers", type=int, required=True)
|
|
||||||
parser.add_argument("--pin_memory", type=str, required=True)
|
|
||||||
parser.add_argument("--report_to", type=str, required=True)
|
parser.add_argument("--report_to", type=str, required=True)
|
||||||
parser.add_argument("--train_epochs", type=int, required=True)
|
|
||||||
parser.add_argument("--checkpointing_steps", type=int, required=True)
|
|
||||||
parser.add_argument("--checkpointing_limit", type=int, required=True)
|
|
||||||
|
|
||||||
parser.add_argument("--do_validation", type=bool)
|
# Training hyperparameters
|
||||||
parser.add_argument("--validation_steps", type=int)
|
parser.add_argument("--seed", type=int, default=42)
|
||||||
parser.add_argument("--validation_dir", type=str)
|
parser.add_argument("--train_epochs", type=int, default=10)
|
||||||
parser.add_argument("--validation_prompts", type=str)
|
parser.add_argument("--train_steps", type=int, default=None)
|
||||||
parser.add_argument("--validation_images", type=str)
|
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
|
||||||
parser.add_argument("--validation_videos", type=str)
|
parser.add_argument("--batch_size", type=int, default=1)
|
||||||
parser.add_argument("--gen_fps", type=int)
|
parser.add_argument("--learning_rate", type=float, default=2e-5)
|
||||||
|
parser.add_argument("--optimizer", type=str, default="adamw")
|
||||||
|
parser.add_argument("--beta1", type=float, default=0.9)
|
||||||
|
parser.add_argument("--beta2", type=float, default=0.95)
|
||||||
|
parser.add_argument("--beta3", type=float, default=0.98)
|
||||||
|
parser.add_argument("--epsilon", type=float, default=1e-8)
|
||||||
|
parser.add_argument("--weight_decay", type=float, default=1e-4)
|
||||||
|
parser.add_argument("--max_grad_norm", type=float, default=1.0)
|
||||||
|
|
||||||
parser.add_argument("--resume_from_checkpoint", type=str)
|
# Learning rate scheduler
|
||||||
|
parser.add_argument("--lr_scheduler", type=str, default="constant_with_warmup")
|
||||||
|
parser.add_argument("--lr_warmup_steps", type=int, default=100)
|
||||||
|
parser.add_argument("--lr_num_cycles", type=int, default=1)
|
||||||
|
parser.add_argument("--lr_power", type=float, default=1.0)
|
||||||
|
|
||||||
|
# Data loading
|
||||||
|
parser.add_argument("--num_workers", type=int, default=8)
|
||||||
|
parser.add_argument("--pin_memory", type=bool, default=True)
|
||||||
|
parser.add_argument("--image_column", type=str, default=None)
|
||||||
|
|
||||||
|
# Model configuration
|
||||||
|
parser.add_argument("--mixed_precision", type=str, default="no")
|
||||||
|
parser.add_argument("--gradient_checkpointing", type=bool, default=True)
|
||||||
|
parser.add_argument("--enable_slicing", type=bool, default=True)
|
||||||
|
parser.add_argument("--enable_tiling", type=bool, default=True)
|
||||||
|
parser.add_argument("--nccl_timeout", type=int, default=1800)
|
||||||
|
|
||||||
|
# LoRA parameters
|
||||||
|
parser.add_argument("--rank", type=int, default=128)
|
||||||
|
parser.add_argument("--lora_alpha", type=int, default=64)
|
||||||
|
parser.add_argument("--target_modules", type=str, nargs="+",
|
||||||
|
default=["to_q", "to_k", "to_v", "to_out.0"])
|
||||||
|
|
||||||
|
# Checkpointing
|
||||||
|
parser.add_argument("--checkpointing_steps", type=int, default=200)
|
||||||
|
parser.add_argument("--checkpointing_limit", type=int, default=10)
|
||||||
|
parser.add_argument("--resume_from_checkpoint", type=str, default=None)
|
||||||
|
|
||||||
|
# Validation
|
||||||
|
parser.add_argument("--do_validation", type=bool, default=False)
|
||||||
|
parser.add_argument("--validation_steps", type=int, default=None)
|
||||||
|
parser.add_argument("--validation_dir", type=str, default=None)
|
||||||
|
parser.add_argument("--validation_prompts", type=str, default=None)
|
||||||
|
parser.add_argument("--validation_images", type=str, default=None)
|
||||||
|
parser.add_argument("--validation_videos", type=str, default=None)
|
||||||
|
parser.add_argument("--gen_fps", type=int, default=15)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -137,11 +209,3 @@ class Args(BaseModel):
|
|||||||
args.train_resolution = (int(frames), int(height), int(width))
|
args.train_resolution = (int(frames), int(height), int(width))
|
||||||
|
|
||||||
return cls(**vars(args))
|
return cls(**vars(args))
|
||||||
|
|
||||||
# @field_validator("...", mode="after")
|
|
||||||
# def foo(cls, foobar):
|
|
||||||
# ...
|
|
||||||
|
|
||||||
# @field_validator("...", mode="before")
|
|
||||||
# def bar(cls, barbar):
|
|
||||||
# ...
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user