diff --git a/finetune/schemas/__init__.py b/finetune/schemas/__init__.py new file mode 100644 index 0000000..76a6bf8 --- /dev/null +++ b/finetune/schemas/__init__.py @@ -0,0 +1,5 @@ +from .args import Args +from .state import State +from .components import Components + +__all__ = ["Args", "State", "Components"] \ No newline at end of file diff --git a/finetune/schemas/args.py b/finetune/schemas/args.py new file mode 100644 index 0000000..4745f7c --- /dev/null +++ b/finetune/schemas/args.py @@ -0,0 +1,147 @@ +import datetime +import argparse +from typing import Dict, Any, Literal, List, Tuple +from pydantic import BaseModel, field_validator + +from pathlib import Path + + +class Args(BaseModel): + ########## Model ########## + model_path: Path + model_name: str + model_type: Literal["i2v", "t2v"] + training_type: Literal["lora", "sft"] = "lora" + + ########## Output ########## + output_dir: Path = Path("train_results/{:%Y-%m-%d-%H-%M-%S}".format(datetime.datetime.now())) + report_to: Literal["tensorboard", "wandb", "all"] | None = None + tracker_name: str = "finetrainer-cogvideo" + + ########## Data ########### + data_root: Path + caption_column: Path + image_column: Path | None = None + video_column: Path + + ########## Training ######### + resume_from_checkpoint: Path | None = None + + seed: int | None = None + train_epochs: int + train_steps: int | None = None + checkpointing_steps: int = 500 + checkpointing_limit: int = 10 + + batch_size: int + gradient_accumulation_steps: int = 1 + + train_resolution: Tuple[int, int, int] # shape: (frames, height, width) + + #### deprecated args: video_resolution_buckets + # if use bucket for training, should not be None + # Note1: At least one frame rate in the bucket must be less than or equal to the frame rate of any video in the dataset + # Note2: For cogvideox, cogvideox1.5 + # The frame rate set in the bucket must be an integer multiple of 8 (spatial_compression_rate[4] * path_t[2] = 8) + # The height and width set in the bucket must be an integer multiple of 8 (temporal_compression_rate[8]) + # video_resolution_buckets: List[Tuple[int, int, int]] | None = None + + mixed_precision: Literal["no", "fp16", "bf16"] + + learning_rate: float = 2e-5 + optimizer: str = "adamw" + beta1: float = 0.9 + beta2: float = 0.95 + beta3: float = 0.98 + epsilon: float = 1e-8 + weight_decay: float = 1e-4 + max_grad_norm: float = 1.0 + + lr_scheduler: str = "constant_with_warmup" + lr_warmup_steps: int = 100 + lr_num_cycles: int = 1 + lr_power: float = 1.0 + + num_workers: int = 8 + pin_memory: bool = True + + gradient_checkpointing: bool = True + enable_slicing: bool = True + enable_tiling: bool = True + nccl_timeout: int = 1800 + + ########## Lora ########## + rank: int = 128 + lora_alpha: int = 64 + target_modules: List[str] = ["to_q", "to_k", "to_v", "to_out.0"] + + ########## Validation ########## + do_validation: bool = False + validation_steps: int | None = None # if set, should be a multiple of checkpointing_steps + validation_dir: Path | None # if set do_validation, should not be None + validation_prompts: str | None # if set do_validation, should not be None + validation_images: str | None # if set do_validation and model_type == i2v, should not be None + validation_videos: str | None # if set do_validation and model_type == v2v, should not be None + gen_fps: int = 15 + + #### deprecated args: gen_video_resolution + # 1. If set do_validation, should not be None + # 2. Suggest selecting the bucket from `video_resolution_buckets` that is closest to the resolution you have chosen for fine-tuning + # or the resolution recommended by the model + # 3. Note: For cogvideox, cogvideox1.5 + # The frame rate set in the bucket must be an integer multiple of 8 (spatial_compression_rate[4] * path_t[2] = 8) + # The height and width set in the bucket must be an integer multiple of 8 (temporal_compression_rate[8]) + # gen_video_resolution: Tuple[int, int, int] | None # shape: (frames, height, width) + + + @classmethod + def parse_args(cls): + """Parse command line arguments and return Args instance""" + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, required=True) + parser.add_argument("--model_name", type=str, required=True) + parser.add_argument("--model_type", type=str, required=True) + parser.add_argument("--training_type", type=str, required=True) + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument("--seed", type=int, required=True) + parser.add_argument("--nccl_timeout", type=int, required=True) + parser.add_argument("--mixed_precision", type=str, required=True) + parser.add_argument("--gradient_accumulation_steps", type=int, required=True) + parser.add_argument("--data_root", type=str, required=True) + parser.add_argument("--caption_column", type=str, required=True) + parser.add_argument("--video_column", type=str, required=True) + parser.add_argument("--image_column", type=str) + parser.add_argument("--train_resolution", type=str, required=True) + parser.add_argument("--batch_size", type=int, required=True) + parser.add_argument("--num_workers", type=int, required=True) + parser.add_argument("--pin_memory", type=str, required=True) + parser.add_argument("--report_to", type=str, required=True) + parser.add_argument("--train_epochs", type=int, required=True) + parser.add_argument("--checkpointing_steps", type=int, required=True) + parser.add_argument("--checkpointing_limit", type=int, required=True) + + parser.add_argument("--do_validation", type=bool) + parser.add_argument("--validation_steps", type=int) + parser.add_argument("--validation_dir", type=str) + parser.add_argument("--validation_prompts", type=str) + parser.add_argument("--validation_images", type=str) + parser.add_argument("--validation_videos", type=str) + parser.add_argument("--gen_fps", type=int) + + parser.add_argument("--resume_from_checkpoint", type=str) + + args = parser.parse_args() + + # Convert video_resolution_buckets string to list of tuples + frames, height, width = args.train_resolution.split("x") + args.train_resolution = (int(frames), int(height), int(width)) + + return cls(**vars(args)) + + # @field_validator("...", mode="after") + # def foo(cls, foobar): + # ... + + # @field_validator("...", mode="before") + # def bar(cls, barbar): + # ... diff --git a/finetune/schemas/components.py b/finetune/schemas/components.py new file mode 100644 index 0000000..2d3fef5 --- /dev/null +++ b/finetune/schemas/components.py @@ -0,0 +1,27 @@ +from typing import Any +from pydantic import BaseModel + + +class Components(BaseModel): + # pipeline cls + pipeline_cls: Any = None + + # Tokenizers + tokenizer: Any = None + tokenizer_2: Any = None + tokenizer_3: Any = None + + # Text encoders + text_encoder: Any = None + text_encoder_2: Any = None + text_encoder_3: Any = None + + # Autoencoder + vae: Any = None + + # Denoiser + transformer: Any = None + unet: Any = None + + # Scheduler + scheduler: Any = None diff --git a/finetune/schemas/state.py b/finetune/schemas/state.py new file mode 100644 index 0000000..9d5363c --- /dev/null +++ b/finetune/schemas/state.py @@ -0,0 +1,26 @@ +import torch + +from pathlib import Path +from typing import List, Dict, Any +from pydantic import BaseModel, field_validator + +class State(BaseModel): + model_config = {"arbitrary_types_allowed": True} + + train_frames: int + train_height: int + train_width: int + + transformer_config: Dict[str, Any] = None + + weight_dtype: torch.dtype = torch.float32 + num_trainable_parameters: int = 0 + overwrote_max_train_steps: bool = False + num_update_steps_per_epoch: int = 0 + total_batch_size_count: int = 0 + + generator: torch.Generator | None = None + + validation_prompts: List[str] = [] + validation_images: List[Path | None] = [] + validation_videos: List[Path | None] = []