mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
feat: add schemas module for configuration and state management
Add Pydantic models to handle: - CLI arguments and configuration (Args) - Model components and pipeline (Components) - Training state and parameters (State)
This commit is contained in:
parent
45d40450a1
commit
91d79fd9a4
5
finetune/schemas/__init__.py
Normal file
5
finetune/schemas/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
from .args import Args
|
||||
from .state import State
|
||||
from .components import Components
|
||||
|
||||
__all__ = ["Args", "State", "Components"]
|
147
finetune/schemas/args.py
Normal file
147
finetune/schemas/args.py
Normal file
@ -0,0 +1,147 @@
|
||||
import datetime
|
||||
import argparse
|
||||
from typing import Dict, Any, Literal, List, Tuple
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class Args(BaseModel):
|
||||
########## Model ##########
|
||||
model_path: Path
|
||||
model_name: str
|
||||
model_type: Literal["i2v", "t2v"]
|
||||
training_type: Literal["lora", "sft"] = "lora"
|
||||
|
||||
########## Output ##########
|
||||
output_dir: Path = Path("train_results/{:%Y-%m-%d-%H-%M-%S}".format(datetime.datetime.now()))
|
||||
report_to: Literal["tensorboard", "wandb", "all"] | None = None
|
||||
tracker_name: str = "finetrainer-cogvideo"
|
||||
|
||||
########## Data ###########
|
||||
data_root: Path
|
||||
caption_column: Path
|
||||
image_column: Path | None = None
|
||||
video_column: Path
|
||||
|
||||
########## Training #########
|
||||
resume_from_checkpoint: Path | None = None
|
||||
|
||||
seed: int | None = None
|
||||
train_epochs: int
|
||||
train_steps: int | None = None
|
||||
checkpointing_steps: int = 500
|
||||
checkpointing_limit: int = 10
|
||||
|
||||
batch_size: int
|
||||
gradient_accumulation_steps: int = 1
|
||||
|
||||
train_resolution: Tuple[int, int, int] # shape: (frames, height, width)
|
||||
|
||||
#### deprecated args: video_resolution_buckets
|
||||
# if use bucket for training, should not be None
|
||||
# Note1: At least one frame rate in the bucket must be less than or equal to the frame rate of any video in the dataset
|
||||
# Note2: For cogvideox, cogvideox1.5
|
||||
# The frame rate set in the bucket must be an integer multiple of 8 (spatial_compression_rate[4] * path_t[2] = 8)
|
||||
# The height and width set in the bucket must be an integer multiple of 8 (temporal_compression_rate[8])
|
||||
# video_resolution_buckets: List[Tuple[int, int, int]] | None = None
|
||||
|
||||
mixed_precision: Literal["no", "fp16", "bf16"]
|
||||
|
||||
learning_rate: float = 2e-5
|
||||
optimizer: str = "adamw"
|
||||
beta1: float = 0.9
|
||||
beta2: float = 0.95
|
||||
beta3: float = 0.98
|
||||
epsilon: float = 1e-8
|
||||
weight_decay: float = 1e-4
|
||||
max_grad_norm: float = 1.0
|
||||
|
||||
lr_scheduler: str = "constant_with_warmup"
|
||||
lr_warmup_steps: int = 100
|
||||
lr_num_cycles: int = 1
|
||||
lr_power: float = 1.0
|
||||
|
||||
num_workers: int = 8
|
||||
pin_memory: bool = True
|
||||
|
||||
gradient_checkpointing: bool = True
|
||||
enable_slicing: bool = True
|
||||
enable_tiling: bool = True
|
||||
nccl_timeout: int = 1800
|
||||
|
||||
########## Lora ##########
|
||||
rank: int = 128
|
||||
lora_alpha: int = 64
|
||||
target_modules: List[str] = ["to_q", "to_k", "to_v", "to_out.0"]
|
||||
|
||||
########## Validation ##########
|
||||
do_validation: bool = False
|
||||
validation_steps: int | None = None # if set, should be a multiple of checkpointing_steps
|
||||
validation_dir: Path | None # if set do_validation, should not be None
|
||||
validation_prompts: str | None # if set do_validation, should not be None
|
||||
validation_images: str | None # if set do_validation and model_type == i2v, should not be None
|
||||
validation_videos: str | None # if set do_validation and model_type == v2v, should not be None
|
||||
gen_fps: int = 15
|
||||
|
||||
#### deprecated args: gen_video_resolution
|
||||
# 1. If set do_validation, should not be None
|
||||
# 2. Suggest selecting the bucket from `video_resolution_buckets` that is closest to the resolution you have chosen for fine-tuning
|
||||
# or the resolution recommended by the model
|
||||
# 3. Note: For cogvideox, cogvideox1.5
|
||||
# The frame rate set in the bucket must be an integer multiple of 8 (spatial_compression_rate[4] * path_t[2] = 8)
|
||||
# The height and width set in the bucket must be an integer multiple of 8 (temporal_compression_rate[8])
|
||||
# gen_video_resolution: Tuple[int, int, int] | None # shape: (frames, height, width)
|
||||
|
||||
|
||||
@classmethod
|
||||
def parse_args(cls):
|
||||
"""Parse command line arguments and return Args instance"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_path", type=str, required=True)
|
||||
parser.add_argument("--model_name", type=str, required=True)
|
||||
parser.add_argument("--model_type", type=str, required=True)
|
||||
parser.add_argument("--training_type", type=str, required=True)
|
||||
parser.add_argument("--output_dir", type=str, required=True)
|
||||
parser.add_argument("--seed", type=int, required=True)
|
||||
parser.add_argument("--nccl_timeout", type=int, required=True)
|
||||
parser.add_argument("--mixed_precision", type=str, required=True)
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, required=True)
|
||||
parser.add_argument("--data_root", type=str, required=True)
|
||||
parser.add_argument("--caption_column", type=str, required=True)
|
||||
parser.add_argument("--video_column", type=str, required=True)
|
||||
parser.add_argument("--image_column", type=str)
|
||||
parser.add_argument("--train_resolution", type=str, required=True)
|
||||
parser.add_argument("--batch_size", type=int, required=True)
|
||||
parser.add_argument("--num_workers", type=int, required=True)
|
||||
parser.add_argument("--pin_memory", type=str, required=True)
|
||||
parser.add_argument("--report_to", type=str, required=True)
|
||||
parser.add_argument("--train_epochs", type=int, required=True)
|
||||
parser.add_argument("--checkpointing_steps", type=int, required=True)
|
||||
parser.add_argument("--checkpointing_limit", type=int, required=True)
|
||||
|
||||
parser.add_argument("--do_validation", type=bool)
|
||||
parser.add_argument("--validation_steps", type=int)
|
||||
parser.add_argument("--validation_dir", type=str)
|
||||
parser.add_argument("--validation_prompts", type=str)
|
||||
parser.add_argument("--validation_images", type=str)
|
||||
parser.add_argument("--validation_videos", type=str)
|
||||
parser.add_argument("--gen_fps", type=int)
|
||||
|
||||
parser.add_argument("--resume_from_checkpoint", type=str)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Convert video_resolution_buckets string to list of tuples
|
||||
frames, height, width = args.train_resolution.split("x")
|
||||
args.train_resolution = (int(frames), int(height), int(width))
|
||||
|
||||
return cls(**vars(args))
|
||||
|
||||
# @field_validator("...", mode="after")
|
||||
# def foo(cls, foobar):
|
||||
# ...
|
||||
|
||||
# @field_validator("...", mode="before")
|
||||
# def bar(cls, barbar):
|
||||
# ...
|
27
finetune/schemas/components.py
Normal file
27
finetune/schemas/components.py
Normal file
@ -0,0 +1,27 @@
|
||||
from typing import Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Components(BaseModel):
|
||||
# pipeline cls
|
||||
pipeline_cls: Any = None
|
||||
|
||||
# Tokenizers
|
||||
tokenizer: Any = None
|
||||
tokenizer_2: Any = None
|
||||
tokenizer_3: Any = None
|
||||
|
||||
# Text encoders
|
||||
text_encoder: Any = None
|
||||
text_encoder_2: Any = None
|
||||
text_encoder_3: Any = None
|
||||
|
||||
# Autoencoder
|
||||
vae: Any = None
|
||||
|
||||
# Denoiser
|
||||
transformer: Any = None
|
||||
unet: Any = None
|
||||
|
||||
# Scheduler
|
||||
scheduler: Any = None
|
26
finetune/schemas/state.py
Normal file
26
finetune/schemas/state.py
Normal file
@ -0,0 +1,26 @@
|
||||
import torch
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
class State(BaseModel):
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
train_frames: int
|
||||
train_height: int
|
||||
train_width: int
|
||||
|
||||
transformer_config: Dict[str, Any] = None
|
||||
|
||||
weight_dtype: torch.dtype = torch.float32
|
||||
num_trainable_parameters: int = 0
|
||||
overwrote_max_train_steps: bool = False
|
||||
num_update_steps_per_epoch: int = 0
|
||||
total_batch_size_count: int = 0
|
||||
|
||||
generator: torch.Generator | None = None
|
||||
|
||||
validation_prompts: List[str] = []
|
||||
validation_images: List[Path | None] = []
|
||||
validation_videos: List[Path | None] = []
|
Loading…
x
Reference in New Issue
Block a user