Merge pull request #632 from THUDM/CogVideoX_dev

Refactored the training code of finetune
Yuxuan Zhang 2025-01-02 08:31:25 +08:00 committed by GitHub
commit aa240dc675
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
36 changed files with 2860 additions and 3435 deletions

.gitignore (vendored, 15 changes)
@@ -8,4 +8,17 @@ logs/
.idea
output*
test*
venv
venv
**/.swp
**/*.log
**/*.debug
**/.vscode
**/*debug*
**/.gitignore
**/finetune/*-lora-*
**/finetune/Disney-*
**/wandb
**/results
**/*.mp4
**/validation_set

@@ -1,26 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: true
deepspeed_config:
deepspeed_hostfile: hostfile.txt
deepspeed_multinode_launcher: pdsh
gradient_accumulation_steps: 1
gradient_clipping: 1.0
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: true
zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'yes'
enable_cpu_affinity: true
main_process_ip: 10.250.128.19
main_process_port: 12355
main_training_function: main
mixed_precision: bf16
num_machines: 4
num_processes: 32
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

@@ -1,20 +0,0 @@
compute_environment: LOCAL_MACHINE
gpu_ids: "0"
debug: false
deepspeed_config:
deepspeed_config_file: ds_config.json
zero3_init_flag: false
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
dynamo_backend: 'no'
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

@@ -0,0 +1,69 @@
#!/usr/bin/env bash
# Prevent tokenizer parallelism issues
export TOKENIZERS_PARALLELISM=false
# Model Configuration
MODEL_ARGS=(
--model_path "THUDM/CogVideoX1.5-5B-I2V"
--model_name "cogvideox1.5-i2v"
--model_type "i2v"
--training_type "lora"
)
# Output Configuration
OUTPUT_ARGS=(
--output_dir "/path/to/output/dir"
--report_to "tensorboard"
)
# Data Configuration
DATA_ARGS=(
--data_root "/path/to/data/dir"
--caption_column "prompt.txt"
--video_column "videos.txt"
--image_column "images.txt"
--train_resolution "80x768x1360"
)
# Training Configuration
TRAIN_ARGS=(
--train_epochs 10
--batch_size 1
--gradient_accumulation_steps 1
--mixed_precision "bf16"
--seed 42
)
# System Configuration
SYSTEM_ARGS=(
--num_workers 8
--pin_memory True
--nccl_timeout 1800
)
# Checkpointing Configuration
CHECKPOINT_ARGS=(
--checkpointing_steps 200
--checkpointing_limit 10
)
# Validation Configuration
VALIDATION_ARGS=(
--do_validation False
--validation_dir "/path/to/validation/dir"
--validation_steps 400
--validation_prompts "prompts.txt"
--validation_images "images.txt"
--gen_fps 15
)
# Combine all arguments and launch training
accelerate launch train.py \
"${MODEL_ARGS[@]}" \
"${OUTPUT_ARGS[@]}" \
"${DATA_ARGS[@]}" \
"${TRAIN_ARGS[@]}" \
"${SYSTEM_ARGS[@]}" \
"${CHECKPOINT_ARGS[@]}" \
"${VALIDATION_ARGS[@]}"

@@ -0,0 +1,67 @@
#!/usr/bin/env bash
# Prevent tokenizer parallelism issues
export TOKENIZERS_PARALLELISM=false
# Model Configuration
MODEL_ARGS=(
--model_path "THUDM/CogVideoX1.5-5B"
--model_name "cogvideox1.5-t2v"
--model_type "t2v"
--training_type "lora"
)
# Output Configuration
OUTPUT_ARGS=(
--output_dir "/path/to/output/dir"
--report_to "tensorboard"
)
# Data Configuration
DATA_ARGS=(
--data_root "/path/to/data/dir"
--caption_column "prompt.txt"
--video_column "videos.txt"
--train_resolution "80x768x1360"
)
# Training Configuration
TRAIN_ARGS=(
--train_epochs 10
--batch_size 1
--gradient_accumulation_steps 1
--mixed_precision "bf16"
--seed 42
)
# System Configuration
SYSTEM_ARGS=(
--num_workers 8
--pin_memory True
--nccl_timeout 1800
)
# Checkpointing Configuration
CHECKPOINT_ARGS=(
--checkpointing_steps 200
--checkpointing_limit 10
)
# Validation Configuration
VALIDATION_ARGS=(
--do_validation False
--validation_dir "/path/to/validation/dir"
--validation_steps 400
--validation_prompts "prompts.txt"
--gen_fps 15
)
# Combine all arguments and launch training
accelerate launch train.py \
"${MODEL_ARGS[@]}" \
"${OUTPUT_ARGS[@]}" \
"${DATA_ARGS[@]}" \
"${TRAIN_ARGS[@]}" \
"${SYSTEM_ARGS[@]}" \
"${CHECKPOINT_ARGS[@]}" \
"${VALIDATION_ARGS[@]}"

finetune/constants.py (new file, 2 changes)
@@ -0,0 +1,2 @@
LOG_NAME = "trainer"
LOG_LEVEL = "INFO"

@@ -0,0 +1,12 @@
from .i2v_dataset import I2VDatasetWithResize, I2VDatasetWithBuckets
from .t2v_dataset import T2VDatasetWithResize, T2VDatasetWithBuckets
from .bucket_sampler import BucketSampler
__all__ = [
"I2VDatasetWithResize",
"I2VDatasetWithBuckets",
"T2VDatasetWithResize",
"T2VDatasetWithBuckets",
"BucketSampler"
]

@@ -0,0 +1,73 @@
import random
import logging
from torch.utils.data import Sampler
from torch.utils.data import Dataset
logger = logging.getLogger(__name__)
class BucketSampler(Sampler):
r"""
PyTorch Sampler that groups 3D data by height, width and frames.
Args:
data_source (`VideoDataset`):
A PyTorch dataset object that is an instance of `VideoDataset`.
batch_size (`int`, defaults to `8`):
The batch size to use for training.
shuffle (`bool`, defaults to `True`):
Whether or not to shuffle the data in each batch before dispatching to dataloader.
drop_last (`bool`, defaults to `False`):
Whether or not to drop incomplete buckets of data after completely iterating over all data
in the dataset. If set to True, only batches that have `batch_size` number of entries will
be yielded. If set to False, it is guaranteed that all data in the dataset will be processed
and batches that do not have `batch_size` number of entries will also be yielded.
"""
def __init__(
self, data_source: Dataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False
) -> None:
self.data_source = data_source
self.batch_size = batch_size
self.shuffle = shuffle
self.drop_last = drop_last
self.buckets = {resolution: [] for resolution in data_source.video_resolution_buckets}
self._raised_warning_for_drop_last = False
def __len__(self):
if self.drop_last and not self._raised_warning_for_drop_last:
self._raised_warning_for_drop_last = True
logger.warning(
"Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training."
)
return (len(self.data_source) + self.batch_size - 1) // self.batch_size
def __iter__(self):
for index, data in enumerate(self.data_source):
video_metadata = data["video_metadata"]
f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"]
self.buckets[(f, h, w)].append(data)
if len(self.buckets[(f, h, w)]) == self.batch_size:
if self.shuffle:
random.shuffle(self.buckets[(f, h, w)])
yield self.buckets[(f, h, w)]
del self.buckets[(f, h, w)]
self.buckets[(f, h, w)] = []
if self.drop_last:
return
for fhw, bucket in list(self.buckets.items()):
if len(bucket) == 0:
continue
if self.shuffle:
random.shuffle(bucket)
yield bucket
del self.buckets[fhw]
self.buckets[fhw] = []
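A minimal usage sketch of the sampler (hypothetical: `dataset` here is one of the *WithBuckets datasets defined later, which expose `video_resolution_buckets` and per-sample `video_metadata`); each item the sampler yields is a full bucket, i.e. a list of samples sharing one latent (frames, height, width):

from finetune.datasets import BucketSampler

sampler = BucketSampler(dataset, batch_size=4, shuffle=True, drop_last=False)
for bucket in sampler:
    shapes = {
        (s["video_metadata"]["num_frames"], s["video_metadata"]["height"], s["video_metadata"]["width"])
        for s in bucket
    }
    assert len(shapes) == 1  # every sample in a bucket has the same encoded-video shape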

@@ -0,0 +1,275 @@
import torch
from pathlib import Path
from typing import Any, Dict, List, Tuple, Callable
from typing_extensions import override
from accelerate.logging import get_logger
from torch.utils.data import Dataset
from torchvision import transforms
from finetune.constants import LOG_NAME, LOG_LEVEL
from .utils import (
load_prompts, load_videos, load_images,
preprocess_image_with_resize,
preprocess_video_with_resize,
preprocess_video_with_buckets
)
# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
import decord # isort:skip
decord.bridge.set_bridge("torch")
logger = get_logger(LOG_NAME, LOG_LEVEL)
class BaseI2VDataset(Dataset):
"""
Base dataset class for Image-to-Video (I2V) training.
This dataset loads prompts, videos and corresponding conditioning images for I2V training.
Args:
data_root (str): Root directory containing the dataset files
caption_column (str): Path to file containing text prompts/captions
video_column (str): Path to file containing video paths
image_column (str): Path to file containing image paths
device (torch.device): Device to load the data on
encode_video_fn (Callable[[torch.Tensor], torch.Tensor], optional): Function to encode videos
"""
def __init__(
self,
data_root: str,
caption_column: str,
video_column: str,
image_column: str,
device: torch.device,
encode_video_fn: Callable[[torch.Tensor], torch.Tensor] = None,
*args,
**kwargs
) -> None:
super().__init__()
data_root = Path(data_root)
self.prompts = load_prompts(data_root / caption_column)
self.videos = load_videos(data_root / video_column)
self.images = load_images(data_root / image_column)
self.device = device
self.encode_video_fn = encode_video_fn
# Check if number of prompts matches number of videos and images
if not (len(self.videos) == len(self.prompts) == len(self.images)):
raise ValueError(
f"Expected length of prompts, videos and images to be the same but found {len(self.prompts)=}, {len(self.videos)=} and {len(self.images)=}. Please ensure that the number of caption prompts, videos and images match in your dataset."
)
# Check if all video files exist
if any(not path.is_file() for path in self.videos):
raise ValueError(
f"Some video files were not found. Please ensure that all video files exist in the dataset directory. Missing file: {next(path for path in self.videos if not path.is_file())}"
)
# Check if all image files exist
if any(not path.is_file() for path in self.images):
raise ValueError(
f"Some image files were not found. Please ensure that all image files exist in the dataset directory. Missing file: {next(path for path in self.images if not path.is_file())}"
)
def __len__(self) -> int:
return len(self.videos)
def __getitem__(self, index: int) -> Dict[str, Any]:
if isinstance(index, list):
# Here, index is actually a list of data objects that we need to return.
# The BucketSampler should ideally return indices. But, in the sampler, we'd like
# to have information about num_frames, height and width. Since this is not stored
# as metadata, we need to read the video to get this information. You could read this
# information without loading the full video in memory, but we do it anyway. In order
# to not load the video twice (once to get the metadata, and once to return the loaded video
# based on sampled indices), we cache it in the BucketSampler. When the sampler is
# to yield, we yield the cache data instead of indices. So, this special check ensures
# that data is not loaded a second time. PRs are welcome for improvements.
return index
prompt = self.prompts[index]
video = self.videos[index]
image = self.images[index]
video_latent_dir = video.parent / "latent"
video_latent_dir.mkdir(parents=True, exist_ok=True)
encoded_video_path = video_latent_dir / (video.stem + ".pt")
if encoded_video_path.exists():
encoded_video = torch.load(encoded_video_path, weights_only=True)
# shape of image: [C, H, W]
_, image = self.preprocess(None, self.images[index])
else:
frames, image = self.preprocess(video, image)
frames = frames.to(self.device)
# current shape of frames: [F, C, H, W]
frames = self.video_transform(frames)
# Convert to [B, C, F, H, W]
frames = frames.unsqueeze(0)
frames = frames.permute(0, 2, 1, 3, 4).contiguous()
encoded_video = self.encode_video_fn(frames)
# [B, C, F, H, W] -> [C, F, H, W]
encoded_video = encoded_video[0].cpu()
torch.save(encoded_video, encoded_video_path)
logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)
# shape of encoded_video: [C, F, H, W]
# shape of image: [C, H, W]
return {
"prompt": prompt,
"image": image,
"encoded_video": encoded_video,
"video_metadata": {
"num_frames": encoded_video.shape[1],
"height": encoded_video.shape[2],
"width": encoded_video.shape[3],
},
}
def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Loads and preprocesses a video and an image.
If either path is None, no preprocessing will be done for that input.
Args:
video_path: Path to the video file to load
image_path: Path to the image file to load
Returns:
A tuple containing:
- video(torch.Tensor) of shape [F, C, H, W] where F is number of frames,
C is number of channels, H is height and W is width
- image(torch.Tensor) of shape [C, H, W]
"""
raise NotImplementedError("Subclass must implement this method")
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
"""
Applies transformations to a video.
Args:
frames (torch.Tensor): A 4D tensor representing a video
with shape [F, C, H, W] where:
- F is number of frames
- C is number of channels (3 for RGB)
- H is height
- W is width
Returns:
torch.Tensor: The transformed video tensor
"""
raise NotImplementedError("Subclass must implement this method")
def image_transform(self, image: torch.Tensor) -> torch.Tensor:
"""
Applies transformations to an image.
Args:
image (torch.Tensor): A 3D tensor representing an image
with shape [C, H, W] where:
- C is number of channels (3 for RGB)
- H is height
- W is width
Returns:
torch.Tensor: The transformed image tensor
"""
raise NotImplementedError("Subclass must implement this method")
class I2VDatasetWithResize(BaseI2VDataset):
"""
A dataset class for image-to-video generation that resizes inputs to fixed dimensions.
This class preprocesses videos and images by resizing them to specified dimensions:
- Videos are resized to max_num_frames x height x width
- Images are resized to height x width
Args:
max_num_frames (int): Maximum number of frames to extract from videos
height (int): Target height for resizing videos and images
width (int): Target width for resizing videos and images
"""
def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.max_num_frames = max_num_frames
self.height = height
self.width = width
self.__frame_transforms = transforms.Compose(
[
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
]
)
self.__image_transforms = self.__frame_transforms
@override
def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
if video_path is not None:
video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width)
else:
video = None
if image_path is not None:
image = preprocess_image_with_resize(image_path, self.height, self.width)
else:
image = None
return video, image
@override
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)
@override
def image_transform(self, image: torch.Tensor) -> torch.Tensor:
return self.__image_transforms(image)
class I2VDatasetWithBuckets(BaseI2VDataset):
def __init__(
self,
video_resolution_buckets: List[Tuple[int, int, int]],
vae_temporal_compression_ratio: int,
vae_height_compression_ratio: int,
vae_width_compression_ratio: int,
*args, **kwargs
) -> None:
super().__init__(*args, **kwargs)
self.video_resolution_buckets = [
(
int(b[0] / vae_temporal_compression_ratio),
int(b[1] / vae_height_compression_ratio),
int(b[2] / vae_width_compression_ratio),
)
for b in video_resolution_buckets
]
self.__frame_transforms = transforms.Compose(
[
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
]
)
self.__image_transforms = self.__frame_transforms
@override
def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
video = preprocess_video_with_buckets(video_path, self.video_resolution_buckets)
image = preprocess_image_with_resize(image_path, video.shape[2], video.shape[3])
return video, image
@override
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)
@override
def image_transform(self, image: torch.Tensor) -> torch.Tensor:
return self.__image_transforms(image)
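A hypothetical construction sketch for the resize variant (paths, sizes, and the encode function are placeholders; during training these come from the parsed Args and the trainer's VAE-backed encode_video):

import torch
from finetune.datasets import I2VDatasetWithResize

dataset = I2VDatasetWithResize(
    max_num_frames=81,  # placeholder; mirrors the frames value of --train_resolution
    height=768,
    width=1360,
    data_root="/path/to/data/dir",
    caption_column="prompt.txt",
    video_column="videos.txt",
    image_column="images.txt",
    device=torch.device("cuda"),
    encode_video_fn=my_encode_video,  # hypothetical callable wrapping vae.encode, as in the trainers
)
sample = dataset[0]  # dict with "prompt", "image", "encoded_video", "video_metadata"; the latent is cached next to the video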

@@ -0,0 +1,231 @@
import torch
from pathlib import Path
from typing import Any, Dict, List, Tuple, Callable
from typing_extensions import override
from accelerate.logging import get_logger
from torch.utils.data import Dataset
from torchvision import transforms
from finetune.constants import LOG_NAME, LOG_LEVEL
from .utils import (
load_prompts, load_videos,
preprocess_video_with_resize,
preprocess_video_with_buckets
)
# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
import decord # isort:skip
decord.bridge.set_bridge("torch")
logger = get_logger(LOG_NAME, LOG_LEVEL)
class BaseT2VDataset(Dataset):
"""
Base dataset class for Text-to-Video (T2V) training.
This dataset loads prompts and videos for T2V training.
Args:
data_root (str): Root directory containing the dataset files
caption_column (str): Path to file containing text prompts/captions
video_column (str): Path to file containing video paths
device (torch.device): Device to load the data on
encode_video_fn (Callable[[torch.Tensor], torch.Tensor], optional): Function to encode videos
"""
def __init__(
self,
data_root: str,
caption_column: str,
video_column: str,
device: torch.device = None,
encode_video_fn: Callable[[torch.Tensor], torch.Tensor] = None,
*args,
**kwargs
) -> None:
super().__init__()
data_root = Path(data_root)
self.prompts = load_prompts(data_root / caption_column)
self.videos = load_videos(data_root / video_column)
self.device = device
self.encode_video_fn = encode_video_fn
# Check if all video files exist
if any(not path.is_file() for path in self.videos):
raise ValueError(
f"Some video files were not found. Please ensure that all video files exist in the dataset directory. Missing file: {next(path for path in self.videos if not path.is_file())}"
)
# Check if number of prompts matches number of videos
if len(self.videos) != len(self.prompts):
raise ValueError(
f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.videos)=}. Please ensure that the number of caption prompts and videos match in your dataset."
)
def __len__(self) -> int:
return len(self.videos)
def __getitem__(self, index: int) -> Dict[str, Any]:
if isinstance(index, list):
# Here, index is actually a list of data objects that we need to return.
# The BucketSampler should ideally return indices. But, in the sampler, we'd like
# to have information about num_frames, height and width. Since this is not stored
# as metadata, we need to read the video to get this information. You could read this
# information without loading the full video in memory, but we do it anyway. In order
# to not load the video twice (once to get the metadata, and once to return the loaded video
# based on sampled indices), we cache it in the BucketSampler. When the sampler is
# to yield, we yield the cache data instead of indices. So, this special check ensures
# that data is not loaded a second time. PRs are welcome for improvements.
return index
prompt = self.prompts[index]
video = self.videos[index]
latent_dir = video.parent / "latent"
latent_dir.mkdir(parents=True, exist_ok=True)
encoded_video_path = latent_dir / (video.stem + ".pt")
if encoded_video_path.exists():
# shape of encoded_video: [C, F, H, W]
encoded_video = torch.load(encoded_video_path, weights_only=True)
else:
frames = self.preprocess(video)
frames = frames.to(self.device)
# current shape of frames: [F, C, H, W]
frames = self.video_transform(frames)
# Convert to [B, C, F, H, W]
frames = frames.unsqueeze(0)
frames = frames.permute(0, 2, 1, 3, 4).contiguous()
encoded_video = self.encode_video_fn(frames)
# [B, C, F, H, W] -> [C, F, H, W]
encoded_video = encoded_video[0].cpu()
torch.save(encoded_video, encoded_video_path)
logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)
return {
"prompt": prompt,
"encoded_video": encoded_video,
"video_metadata": {
"num_frames": encoded_video.shape[1],
"height": encoded_video.shape[2],
"width": encoded_video.shape[3],
},
}
def preprocess(self, video_path: Path) -> torch.Tensor:
"""
Loads and preprocesses a video.
Args:
video_path: Path to the video file to load.
Returns:
torch.Tensor: Video tensor of shape [F, C, H, W] where:
- F is number of frames
- C is number of channels (3 for RGB)
- H is height
- W is width
"""
raise NotImplementedError("Subclass must implement this method")
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
"""
Applies transformations to a video.
Args:
frames (torch.Tensor): A 4D tensor representing a video
with shape [F, C, H, W] where:
- F is number of frames
- C is number of channels (3 for RGB)
- H is height
- W is width
Returns:
torch.Tensor: The transformed video tensor with the same shape as the input
"""
raise NotImplementedError("Subclass must implement this method")
class T2VDatasetWithResize(BaseT2VDataset):
"""
A dataset class for text-to-video generation that resizes inputs to fixed dimensions.
This class preprocesses videos by resizing them to specified dimensions:
- Videos are resized to max_num_frames x height x width
Args:
max_num_frames (int): Maximum number of frames to extract from videos
height (int): Target height for resizing videos
width (int): Target width for resizing videos
"""
def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.max_num_frames = max_num_frames
self.height = height
self.width = width
self.__frame_transform = transforms.Compose(
[
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
]
)
@override
def preprocess(self, video_path: Path) -> torch.Tensor:
return preprocess_video_with_resize(
video_path, self.max_num_frames, self.height, self.width,
)
@override
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
return torch.stack([self.__frame_transform(f) for f in frames], dim=0)
class T2VDatasetWithBuckets(BaseT2VDataset):
def __init__(
self,
video_resolution_buckets: List[Tuple[int, int, int]],
vae_temporal_compression_ratio: int,
vae_height_compression_ratio: int,
vae_width_compression_ratio: int,
*args, **kwargs
) -> None:
"""
"""
super().__init__(*args, **kwargs)
self.video_resolution_buckets = [
(
int(b[0] / vae_temporal_compression_ratio),
int(b[1] / vae_height_compression_ratio),
int(b[2] / vae_width_compression_ratio),
)
for b in video_resolution_buckets
]
self.__frame_transform = transforms.Compose(
[
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
]
)
@override
def preprocess(self, video_path: Path) -> torch.Tensor:
return preprocess_video_with_buckets(
video_path, self.video_resolution_buckets
)
@override
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
return torch.stack([self.__frame_transform(f) for f in frames], dim=0)

finetune/datasets/utils.py (new file, 147 changes)
@@ -0,0 +1,147 @@
import torch
import cv2
from typing import List, Tuple
from pathlib import Path
from torchvision import transforms
from torchvision.transforms.functional import resize
# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
import decord # isort:skip
decord.bridge.set_bridge("torch")
########## loaders ##########
def load_prompts(prompt_path: Path) -> List[str]:
with open(prompt_path, "r", encoding="utf-8") as file:
return [line.strip() for line in file.readlines() if len(line.strip()) > 0]
def load_videos(video_path: Path) -> List[Path]:
with open(video_path, "r", encoding="utf-8") as file:
return [video_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]
def load_images(image_path: Path) -> List[Path]:
with open(image_path, "r", encoding="utf-8") as file:
return [image_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]
########## preprocessors ##########
def preprocess_image_with_resize(
image_path: Path | str,
height: int,
width: int,
) -> torch.Tensor:
"""
Loads and resizes a single image.
Args:
image_path: Path to the image file.
height: Target height for resizing.
width: Target width for resizing.
Returns:
torch.Tensor: Image tensor with shape [C, H, W] where:
C = number of channels (3 for RGB)
H = height
W = width
"""
if isinstance(image_path, str):
image_path = Path(image_path)
image = cv2.imread(image_path.as_posix())
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = cv2.resize(image, (width, height))
image = torch.from_numpy(image).float()
image = image.permute(2, 0, 1).contiguous()
return image
def preprocess_video_with_resize(
video_path: Path | str,
max_num_frames: int,
height: int,
width: int,
) -> torch.Tensor:
"""
Loads and resizes a single video.
The function processes the video through these steps:
1. If video frame count > max_num_frames, downsample frames evenly
2. If video dimensions don't match (height, width), resize frames
Args:
video_path: Path to the video file.
max_num_frames: Maximum number of frames to keep.
height: Target height for resizing.
width: Target width for resizing.
Returns:
A torch.Tensor with shape [F, C, H, W] where:
F = number of frames
C = number of channels (3 for RGB)
H = height
W = width
"""
if isinstance(video_path, str):
video_path = Path(video_path)
video_reader = decord.VideoReader(uri=video_path.as_posix(), width=width, height=height)
video_num_frames = len(video_reader)
if video_num_frames < max_num_frames:
raise ValueError(f"video's frames is less than {max_num_frames}.")
indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
frames = video_reader.get_batch(indices)
frames = frames[: max_num_frames].float()
frames = frames.permute(0, 3, 1, 2).contiguous()
return frames
def preprocess_video_with_buckets(
video_path: Path,
resolution_buckets: List[Tuple[int, int, int]],
) -> torch.Tensor:
"""
Args:
video_path: Path to the video file.
resolution_buckets: List of tuples (num_frames, height, width) representing
available resolution buckets.
Returns:
torch.Tensor: Video tensor with shape [F, C, H, W] where:
F = number of frames
C = number of channels (3 for RGB)
H = height
W = width
The function processes the video through these steps:
1. Finds nearest frame bucket <= video frame count
2. Downsamples frames evenly to match bucket size
3. Finds nearest resolution bucket based on dimensions
4. Resizes frames to match bucket resolution
"""
video_reader = decord.VideoReader(uri=video_path.as_posix())
video_num_frames = len(video_reader)
resolution_buckets = [bucket for bucket in resolution_buckets if bucket[0] <= video_num_frames]
if len(resolution_buckets) == 0:
raise ValueError(f"video frame count in {video_path} is less than all frame buckets {resolution_buckets}")
nearest_frame_bucket = min(
resolution_buckets,
key=lambda bucket: video_num_frames - bucket[0],
default=1,
)[0]
frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
frames = video_reader.get_batch(frame_indices)
frames = frames[:nearest_frame_bucket].float()
frames = frames.permute(0, 3, 1, 2).contiguous()
nearest_res = min(resolution_buckets, key=lambda x: abs(x[1] - frames.shape[2]) + abs(x[2] - frames.shape[3]))
nearest_res = (nearest_res[1], nearest_res[2])
frames = torch.stack([resize(f, nearest_res) for f in frames], dim=0)
return frames
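A sketch of the on-disk layout these loaders expect (hypothetical paths; each line in videos.txt and images.txt is resolved relative to the directory containing that text file):

from pathlib import Path
from finetune.datasets.utils import load_prompts, load_videos

# data_root/
#   prompt.txt   one caption per line
#   videos.txt   one relative video path per line, e.g. videos/00001.mp4
#   images.txt   (i2v only) one relative image path per line
#   videos/      the video files themselves
prompts = load_prompts(Path("data_root/prompt.txt"))
videos = load_videos(Path("data_root/videos.txt"))
assert len(prompts) == len(videos)  # the dataset classes enforce this at init time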

@@ -1,20 +0,0 @@
{
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 1e8,
"contiguous_gradients": true
}
}

@@ -1,52 +0,0 @@
#!/bin/bash
export MODEL_PATH="THUDM/CogVideoX-2b"
export CACHE_PATH="~/.cache"
export DATASET_PATH="Disney-VideoGeneration-Dataset"
export OUTPUT_PATH="cogvideox-lora-multi-node"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
# The maximum batch size is 2.
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu --machine_rank 0 \
train_cogvideox_lora.py \
--gradient_checkpointing \
--pretrained_model_name_or_path $MODEL_PATH \
--cache_dir $CACHE_PATH \
--enable_tiling \
--enable_slicing \
--instance_data_root $DATASET_PATH \
--caption_column prompt.txt \
--video_column videos.txt \
--validation_prompt "DISNEY A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance" \
--validation_prompt_separator ::: \
--num_validation_videos 1 \
--validation_epochs 100 \
--seed 42 \
--rank 128 \
--lora_alpha 64 \
--mixed_precision bf16 \
--output_dir $OUTPUT_PATH \
--height 480 \
--width 720 \
--fps 8 \
--max_num_frames 49 \
--skip_frames_start 0 \
--skip_frames_end 0 \
--train_batch_size 1 \
--num_train_epochs 30 \
--checkpointing_steps 1000 \
--gradient_accumulation_steps 1 \
--learning_rate 1e-3 \
--lr_scheduler cosine_with_restarts \
--lr_warmup_steps 200 \
--lr_num_cycles 1 \
--enable_slicing \
--enable_tiling \
--gradient_checkpointing \
--optimizer AdamW \
--adam_beta1 0.9 \
--adam_beta2 0.95 \
--max_grad_norm 1.0 \
--allow_tf32 \
--report_to wandb

@@ -1,52 +0,0 @@
#!/bin/bash
export MODEL_PATH="THUDM/CogVideoX-5b"
export CACHE_PATH="~/.cache"
export DATASET_PATH="Disney-VideoGeneration-Dataset"
export OUTPUT_PATH="cogvideox-lora-single-node"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
# If you are not using 8 GPUs, set num_processes in `accelerate_config_machine_single.yaml` to match your GPU count
accelerate launch --config_file accelerate_config_machine_single.yaml \
train_cogvideox_lora.py \
--gradient_checkpointing \
--pretrained_model_name_or_path $MODEL_PATH \
--cache_dir $CACHE_PATH \
--enable_tiling \
--enable_slicing \
--instance_data_root $DATASET_PATH \
--caption_column prompt.txt \
--video_column videos.txt \
--validation_prompt "DISNEY A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance" \
--validation_prompt_separator ::: \
--num_validation_videos 1 \
--validation_epochs 100 \
--seed 42 \
--rank 128 \
--lora_alpha 64 \
--mixed_precision bf16 \
--output_dir $OUTPUT_PATH \
--height 480 \
--width 720 \
--fps 8 \
--max_num_frames 49 \
--skip_frames_start 0 \
--skip_frames_end 0 \
--train_batch_size 1 \
--num_train_epochs 30 \
--checkpointing_steps 1000 \
--gradient_accumulation_steps 1 \
--learning_rate 1e-3 \
--lr_scheduler cosine_with_restarts \
--lr_warmup_steps 200 \
--lr_num_cycles 1 \
--enable_slicing \
--enable_tiling \
--gradient_checkpointing \
--optimizer AdamW \
--adam_beta1 0.9 \
--adam_beta2 0.95 \
--max_grad_norm 1.0 \
--allow_tf32 \
--report_to wandb

@@ -1,2 +0,0 @@
node1 slots=8
node2 slots=8

@@ -0,0 +1,12 @@
import importlib
from pathlib import Path
package_dir = Path(__file__).parent
for subdir in package_dir.iterdir():
if subdir.is_dir() and not subdir.name.startswith('_'):
for module_path in subdir.glob('*.py'):
module_name = module_path.stem
full_module_name = f".{subdir.name}.{module_name}"
importlib.import_module(full_module_name, package=__name__)
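Importing the package therefore populates the trainer registry as a side effect: each trainer module's register(...) call (see finetune/models/utils.py below) runs when its module is imported. A small hedged check:

import finetune.models  # runs the loop above and, with it, every register(...) call
from finetune.models.utils import show_supported_models

show_supported_models()  # prints the registered model names and their training types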

@@ -0,0 +1,9 @@
from ..utils import register
from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
class CogVideoX1dot5I2VLoraTrainer(CogVideoXI2VLoraTrainer):
pass
register("cogvideox1.5-i2v", "lora", CogVideoX1dot5I2VLoraTrainer)

@@ -0,0 +1,9 @@
from ..cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer
from ..utils import register
class CogVideoX1dot5T2VLoraTrainer(CogVideoXT2VLoraTrainer):
pass
register("cogvideox1.5-t2v", "lora", CogVideoX1dot5T2VLoraTrainer)

@@ -0,0 +1,240 @@
import torch
from typing_extensions import override
from typing import Any, Dict, List, Tuple
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers import (
CogVideoXImageToVideoPipeline,
CogVideoXTransformer3DModel,
AutoencoderKLCogVideoX,
CogVideoXDPMScheduler,
)
from finetune.trainer import Trainer
from finetune.schemas import Components
from finetune.utils import unwrap_model
from ..utils import register
class CogVideoXI2VLoraTrainer(Trainer):
@override
def load_components(self) -> Dict[str, Any]:
components = Components()
model_path = str(self.args.model_path)
components.pipeline_cls = CogVideoXImageToVideoPipeline
components.tokenizer = AutoTokenizer.from_pretrained(
model_path, subfolder="tokenizer"
)
components.text_encoder = T5EncoderModel.from_pretrained(
model_path, subfolder="text_encoder"
)
components.transformer = CogVideoXTransformer3DModel.from_pretrained(
model_path, subfolder="transformer"
)
components.vae = AutoencoderKLCogVideoX.from_pretrained(
model_path, subfolder="vae"
)
components.scheduler = CogVideoXDPMScheduler.from_pretrained(
model_path, subfolder="scheduler"
)
return components
@override
def encode_video(self, video: torch.Tensor) -> torch.Tensor:
# shape of input video: [B, C, F, H, W]
vae = self.components.vae
video = video.to(vae.device, dtype=vae.dtype)
latent_dist = vae.encode(video).latent_dist
latent = latent_dist.sample() * vae.config.scaling_factor
return latent
@override
def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
ret = {
"encoded_videos": [],
"prompt_token_ids": [],
"images": []
}
for sample in samples:
encoded_video = sample["encoded_video"]
prompt = sample["prompt"]
image = sample["image"]
# tokenize prompt
text_inputs = self.components.tokenizer(
prompt,
padding="max_length",
max_length=self.state.transformer_config.max_text_seq_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
ret["encoded_videos"].append(encoded_video)
ret["prompt_token_ids"].append(text_input_ids[0])
ret["images"].append(image)
ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
ret["prompt_token_ids"] = torch.stack(ret["prompt_token_ids"])
ret["images"] = torch.stack(ret["images"])
return ret
@override
def compute_loss(self, batch) -> torch.Tensor:
prompt_token_ids = batch["prompt_token_ids"]
latent = batch["encoded_videos"]
images = batch["images"]
batch_size, num_channels, num_frames, height, width = latent.shape
# Get prompt embeddings
prompt_embeds = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.view(batch_size, seq_len, -1)
# Add frame dimension to images [B,C,H,W] -> [B,C,F,H,W]
images = images.unsqueeze(2)
# Add noise to images
image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device)
image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=images.dtype)
noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
image_latent_dist = self.components.vae.encode(noisy_images).latent_dist
image_latents = image_latent_dist.sample() * self.components.vae.config.scaling_factor
# Sample a random timestep for each sample
timesteps = torch.randint(
0, self.components.scheduler.config.num_train_timesteps,
(batch_size,), device=self.accelerator.device
)
timesteps = timesteps.long()
# from [B, C, F, H, W] to [B, F, C, H, W]
latent = latent.permute(0, 2, 1, 3, 4)
image_latents = image_latents.permute(0, 2, 1, 3, 4)
assert (latent.shape[0], *latent.shape[2:]) == (image_latents.shape[0], *image_latents.shape[2:])
# Padding image_latents to the same frame number as latent
padding_shape = (latent.shape[0], latent.shape[1] - 1, *latent.shape[2:])
latent_padding = image_latents.new_zeros(padding_shape)
image_latents = torch.cat([image_latents, latent_padding], dim=1)
# Add noise to latent
noise = torch.randn_like(latent)
latent_noisy = self.components.scheduler.add_noise(latent, noise, timesteps)
# Concatenate latent and image_latents in the channel dimension
latent_img_noisy = torch.cat([latent_noisy, image_latents], dim=2)
# Prepare rotary embeds
vae_scale_factor_spatial = 2 ** (len(self.components.vae.config.block_out_channels) - 1)
transformer_config = self.state.transformer_config
rotary_emb = (
self.prepare_rotary_positional_embeddings(
height=height * vae_scale_factor_spatial,
width=width * vae_scale_factor_spatial,
num_frames=num_frames,
transformer_config=transformer_config,
vae_scale_factor_spatial=vae_scale_factor_spatial,
device=self.accelerator.device,
)
if transformer_config.use_rotary_positional_embeddings
else None
)
# Predict noise
ofs_emb = None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
predicted_noise = self.components.transformer(
hidden_states=latent_img_noisy,
encoder_hidden_states=prompt_embeds,
timestep=timesteps,
ofs=ofs_emb,
image_rotary_emb=rotary_emb,
return_dict=False,
)[0]
# Denoise
latent_pred = self.components.scheduler.get_velocity(predicted_noise, latent_noisy, timesteps)
alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps]
weights = 1 / (1 - alphas_cumprod)
while len(weights.shape) < len(latent_pred.shape):
weights = weights.unsqueeze(-1)
loss = torch.mean((weights * (latent_pred - latent) ** 2).reshape(batch_size, -1), dim=1)
loss = loss.mean()
return loss
@override
def validation_step(
self, eval_data: Dict[str, Any]
) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
"""
Return the data that needs to be saved. For videos, the data format is List[PIL],
and for images, the data format is PIL
"""
prompt, image, video = eval_data["prompt"], eval_data["image"], eval_data["video"]
pipe = self.components.pipeline_cls(
tokenizer=self.components.tokenizer,
text_encoder=self.components.text_encoder,
vae=self.components.vae,
transformer=unwrap_model(self.accelerator, self.components.transformer),
scheduler=self.components.scheduler
)
video_generate = pipe(
num_frames=self.state.train_frames,
height=self.state.train_height,
width=self.state.train_width,
prompt=prompt,
image=image,
generator=self.state.generator
).frames[0]
return [("video", video_generate)]
def prepare_rotary_positional_embeddings(
self,
height: int,
width: int,
num_frames: int,
transformer_config: Dict,
vae_scale_factor_spatial: int,
device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)
if transformer_config.patch_size_t is None:
base_num_frames = num_frames
else:
base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=transformer_config.attention_head_dim,
crops_coords=None,
grid_size=(grid_height, grid_width),
temporal_size=base_num_frames,
grid_type="slice",
max_size=(grid_height, grid_width),
device=device,
)
return freqs_cos, freqs_sin
register("cogvideox-i2v", "lora", CogVideoXI2VLoraTrainer)

@@ -0,0 +1,214 @@
import torch
from typing_extensions import override
from typing import Any, Dict, List, Tuple
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers import (
CogVideoXPipeline,
CogVideoXTransformer3DModel,
AutoencoderKLCogVideoX,
CogVideoXDPMScheduler,
)
from finetune.trainer import Trainer
from finetune.schemas import Components
from finetune.utils import unwrap_model
from ..utils import register
class CogVideoXT2VLoraTrainer(Trainer):
@override
def load_components(self) -> Components:
components = Components()
model_path = str(self.args.model_path)
components.pipeline_cls = CogVideoXPipeline
components.tokenizer = AutoTokenizer.from_pretrained(
model_path, subfolder="tokenizer"
)
components.text_encoder = T5EncoderModel.from_pretrained(
model_path, subfolder="text_encoder"
)
components.transformer = CogVideoXTransformer3DModel.from_pretrained(
model_path, subfolder="transformer"
)
components.vae = AutoencoderKLCogVideoX.from_pretrained(
model_path, subfolder="vae"
)
components.scheduler = CogVideoXDPMScheduler.from_pretrained(
model_path, subfolder="scheduler"
)
return components
@override
def encode_video(self, video: torch.Tensor) -> torch.Tensor:
# shape of input video: [B, C, F, H, W]
vae = self.components.vae
video = video.to(vae.device, dtype=vae.dtype)
latent_dist = vae.encode(video).latent_dist
latent = latent_dist.sample() * vae.config.scaling_factor
return latent
@override
def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
ret = {
"encoded_videos": [],
"prompt_token_ids": []
}
for sample in samples:
encoded_video = sample["encoded_video"]
prompt = sample["prompt"]
# tokenize prompt
text_inputs = self.components.tokenizer(
prompt,
padding="max_length",
max_length=226,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
ret["encoded_videos"].append(encoded_video)
ret["prompt_token_ids"].append(text_input_ids[0])
ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
ret["prompt_token_ids"] = torch.stack(ret["prompt_token_ids"])
return ret
@override
def compute_loss(self, batch) -> torch.Tensor:
prompt_token_ids = batch["prompt_token_ids"]
latent = batch["encoded_videos"]
batch_size, num_channels, num_frames, height, width = latent.shape
# Get prompt embeddings
prompt_embeds = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.view(batch_size, seq_len, -1)
assert prompt_embeds.requires_grad is False
# Sample a random timestep for each sample
timesteps = torch.randint(
0, self.components.scheduler.config.num_train_timesteps,
(batch_size,), device=self.accelerator.device
)
timesteps = timesteps.long()
# Add noise to latent
latent = latent.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
noise = torch.randn_like(latent)
latent_added_noise = self.components.scheduler.add_noise(latent, noise, timesteps)
# Prepare rotary embeds
vae_scale_factor_spatial = 2 ** (len(self.components.vae.config.block_out_channels) - 1)
transformer_config = self.state.transformer_config
rotary_emb = (
self.prepare_rotary_positional_embeddings(
height=height * vae_scale_factor_spatial,
width=width * vae_scale_factor_spatial,
num_frames=num_frames,
transformer_config=transformer_config,
vae_scale_factor_spatial=vae_scale_factor_spatial,
device=self.accelerator.device,
)
if transformer_config.use_rotary_positional_embeddings
else None
)
# Predict noise
predicted_noise = self.components.transformer(
hidden_states=latent_added_noise,
encoder_hidden_states=prompt_embeds,
timestep=timesteps,
image_rotary_emb=rotary_emb,
return_dict=False,
)[0]
# Denoise
latent_pred = self.components.scheduler.get_velocity(predicted_noise, latent_added_noise, timesteps)
alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps]
weights = 1 / (1 - alphas_cumprod)
while len(weights.shape) < len(latent_pred.shape):
weights = weights.unsqueeze(-1)
loss = torch.mean((weights * (latent_pred - latent) ** 2).reshape(batch_size, -1), dim=1)
loss = loss.mean()
return loss
@override
def validation_step(
self, eval_data: Dict[str, Any]
) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
"""
Return the data that needs to be saved. For videos, the data format is List[PIL],
and for images, the data format is PIL
"""
prompt, image, video = eval_data["prompt"], eval_data["image"], eval_data["video"]
pipe = self.components.pipeline_cls(
tokenizer=self.components.tokenizer,
text_encoder=self.components.text_encoder,
vae=self.components.vae,
transformer=unwrap_model(self.accelerator, self.components.transformer),
scheduler=self.components.scheduler
)
video_generate = pipe(
num_frames=self.state.train_frames,
height=self.state.train_height,
width=self.state.train_width,
prompt=prompt,
generator=self.state.generator
).frames[0]
return [("video", video_generate)]
def prepare_rotary_positional_embeddings(
self,
height: int,
width: int,
num_frames: int,
transformer_config: Dict,
vae_scale_factor_spatial: int,
device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)
if transformer_config.patch_size_t is None:
base_num_frames = num_frames
else:
base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=transformer_config.attention_head_dim,
crops_coords=None,
grid_size=(grid_height, grid_width),
temporal_size=base_num_frames,
grid_type="slice",
max_size=(grid_height, grid_width),
device=device,
)
return freqs_cos, freqs_sin
register("cogvideox-t2v", "lora", CogVideoXT2VLoraTrainer)

finetune/models/utils.py (new file, 62 changes)
@@ -0,0 +1,62 @@
from typing import Literal, Dict
from finetune.trainer import Trainer
SUPPORTED_MODELS: Dict[str, Dict[str, Trainer]] = {}
def register(model_name: str, training_type: Literal["lora", "sft"], trainer_cls: Trainer):
"""Register a model and its associated functions for a specific training type.
Args:
model_name (str): Name of the model to register (e.g. "cogvideox-i2v")
training_type (Literal["lora", "sft"]): Type of training - either "lora" or "sft"
trainer_cls (Trainer): Trainer class to register.
"""
# Check if model_name exists in SUPPORTED_MODELS
if model_name not in SUPPORTED_MODELS:
SUPPORTED_MODELS[model_name] = {}
else:
raise ValueError(f"Model {model_name} already exists")
# Check if training_type exists for this model
if training_type not in SUPPORTED_MODELS[model_name]:
SUPPORTED_MODELS[model_name][training_type] = {}
else:
raise ValueError(f"Training type {training_type} already exists for model {model_name}")
SUPPORTED_MODELS[model_name][training_type] = trainer_cls
def show_supported_models():
"""Print all currently supported models and their training types."""
print("\nSupported Models:")
print("================")
for model_name, training_types in SUPPORTED_MODELS.items():
print(f"\n{model_name}")
print("-" * len(model_name))
for training_type in training_types:
print(f"{training_type}")
def get_model_cls(model_type: str, training_type: Literal["lora", "sft"]) -> Trainer:
"""Get the trainer class for a specific model and training type."""
if model_type not in SUPPORTED_MODELS:
print(f"\nModel '{model_type}' is not supported.")
print("\nSupported models are:")
for supported_model in SUPPORTED_MODELS:
print(f"{supported_model}")
raise ValueError(f"Model '{model_type}' is not supported")
if training_type not in SUPPORTED_MODELS[model_type]:
print(f"\nTraining type '{training_type}' is not supported for model '{model_type}'.")
print(f"\nSupported training types for '{model_type}' are:")
for supported_type in SUPPORTED_MODELS[model_type]:
print(f"{supported_type}")
raise ValueError(f"Training type '{training_type}' is not supported for model '{model_type}'")
return SUPPORTED_MODELS[model_type][training_type]
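A hedged usage sketch of the registry (model and training-type names follow the register(...) calls in the trainer modules above; the trainer construction mirrors how the train entry point is expected to drive it):

from finetune.models.utils import get_model_cls
from finetune.schemas import Args

args = Args.parse_args()
trainer_cls = get_model_cls(args.model_name, args.training_type)
trainer = trainer_cls(args)  # assumed: Trainer subclasses are constructed from the parsed Args
trainer.fit()                # assumed entry point; the Trainer class may name this differently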

@@ -0,0 +1,5 @@
from .args import Args
from .state import State
from .components import Components
__all__ = ["Args", "State", "Components"]

finetune/schemas/args.py (new file, 211 changes)
@@ -0,0 +1,211 @@
import datetime
import argparse
from typing import Dict, Any, Literal, List, Tuple
from pydantic import BaseModel, field_validator, ValidationInfo
from pathlib import Path
class Args(BaseModel):
########## Model ##########
model_path: Path
model_name: str
model_type: Literal["i2v", "t2v"]
training_type: Literal["lora", "sft"] = "lora"
########## Output ##########
output_dir: Path = Path("train_results/{:%Y-%m-%d-%H-%M-%S}".format(datetime.datetime.now()))
report_to: Literal["tensorboard", "wandb", "all"] | None = None
tracker_name: str = "finetrainer-cogvideo"
########## Data ###########
data_root: Path
caption_column: Path
image_column: Path | None = None
video_column: Path
########## Training #########
resume_from_checkpoint: Path | None = None
seed: int | None = None
train_epochs: int
train_steps: int | None = None
checkpointing_steps: int = 200
checkpointing_limit: int = 10
batch_size: int
gradient_accumulation_steps: int = 1
train_resolution: Tuple[int, int, int] # shape: (frames, height, width)
#### deprecated args: video_resolution_buckets
# If buckets are used for training, this should not be None
# Note 1: At least one bucket's frame count must be less than or equal to the frame count of every video in the dataset
# Note 2: For CogVideoX and CogVideoX1.5,
# the frame count set in a bucket must be an integer multiple of 8 (temporal_compression_ratio[4] * patch_size_t[2] = 8)
# and the height and width set in a bucket must be an integer multiple of 8 (spatial_compression_ratio[8])
# video_resolution_buckets: List[Tuple[int, int, int]] | None = None
mixed_precision: Literal["no", "fp16", "bf16"]
learning_rate: float = 2e-5
optimizer: str = "adamw"
beta1: float = 0.9
beta2: float = 0.95
beta3: float = 0.98
epsilon: float = 1e-8
weight_decay: float = 1e-4
max_grad_norm: float = 1.0
lr_scheduler: str = "constant_with_warmup"
lr_warmup_steps: int = 100
lr_num_cycles: int = 1
lr_power: float = 1.0
num_workers: int = 8
pin_memory: bool = True
gradient_checkpointing: bool = True
enable_slicing: bool = True
enable_tiling: bool = True
nccl_timeout: int = 1800
########## Lora ##########
rank: int = 128
lora_alpha: int = 64
target_modules: List[str] = ["to_q", "to_k", "to_v", "to_out.0"]
########## Validation ##########
do_validation: bool = False
validation_steps: int | None = None # if set, should be a multiple of checkpointing_steps
validation_dir: Path | None  # required when do_validation is True
validation_prompts: str | None  # required when do_validation is True
validation_images: str | None  # required when do_validation is True and model_type == "i2v"
validation_videos: str | None  # required when do_validation is True and model_type == "v2v"
gen_fps: int = 15
#### deprecated args: gen_video_resolution
# 1. If do_validation is set, this should not be None
# 2. It is suggested to pick the bucket from `video_resolution_buckets` that is closest to the resolution chosen for fine-tuning,
# or the resolution recommended by the model
# 3. Note: For CogVideoX and CogVideoX1.5,
# the frame count set in the bucket must be an integer multiple of 8 (temporal_compression_ratio[4] * patch_size_t[2] = 8)
# and the height and width set in the bucket must be an integer multiple of 8 (spatial_compression_ratio[8])
# gen_video_resolution: Tuple[int, int, int] | None # shape: (frames, height, width)
@field_validator("image_column")
def validate_image_column(cls, v: str | None, info: ValidationInfo) -> str | None:
values = info.data
if values.get("model_type") == "i2v" and not v:
raise ValueError("image_column must be specified when using i2v model")
return v
@field_validator("validation_dir", "validation_prompts")
def validate_validation_required_fields(cls, v: Any, info: ValidationInfo) -> Any:
values = info.data
if values.get("do_validation") and not v:
field_name = info.field_name
raise ValueError(f"{field_name} must be specified when do_validation is True")
return v
@field_validator("validation_images")
def validate_validation_images(cls, v: str | None, info: ValidationInfo) -> str | None:
values = info.data
if values.get("do_validation") and values.get("model_type") == "i2v" and not v:
raise ValueError("validation_images must be specified when do_validation is True and model_type is i2v")
return v
@field_validator("validation_videos")
def validate_validation_videos(cls, v: str | None, info: ValidationInfo) -> str | None:
values = info.data
if values.get("do_validation") and values.get("model_type") == "v2v" and not v:
raise ValueError("validation_videos must be specified when do_validation is True and model_type is v2v")
return v
@field_validator("validation_steps")
def validate_validation_steps(cls, v: int | None, info: ValidationInfo) -> int | None:
values = info.data
if values.get("do_validation"):
if v is None:
raise ValueError("validation_steps must be specified when do_validation is True")
if values.get("checkpointing_steps") and v % values["checkpointing_steps"] != 0:
raise ValueError("validation_steps must be a multiple of checkpointing_steps")
return v
@classmethod
def parse_args(cls):
"""Parse command line arguments and return Args instance"""
parser = argparse.ArgumentParser()
# Required arguments
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--model_name", type=str, required=True)
parser.add_argument("--model_type", type=str, required=True)
parser.add_argument("--training_type", type=str, required=True)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument("--data_root", type=str, required=True)
parser.add_argument("--caption_column", type=str, required=True)
parser.add_argument("--video_column", type=str, required=True)
parser.add_argument("--train_resolution", type=str, required=True)
parser.add_argument("--report_to", type=str, required=True)
# Training hyperparameters
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--train_epochs", type=int, default=10)
parser.add_argument("--train_steps", type=int, default=None)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--learning_rate", type=float, default=2e-5)
parser.add_argument("--optimizer", type=str, default="adamw")
parser.add_argument("--beta1", type=float, default=0.9)
parser.add_argument("--beta2", type=float, default=0.95)
parser.add_argument("--beta3", type=float, default=0.98)
parser.add_argument("--epsilon", type=float, default=1e-8)
parser.add_argument("--weight_decay", type=float, default=1e-4)
parser.add_argument("--max_grad_norm", type=float, default=1.0)
# Learning rate scheduler
parser.add_argument("--lr_scheduler", type=str, default="constant_with_warmup")
parser.add_argument("--lr_warmup_steps", type=int, default=100)
parser.add_argument("--lr_num_cycles", type=int, default=1)
parser.add_argument("--lr_power", type=float, default=1.0)
# Data loading
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--pin_memory", type=bool, default=True)
parser.add_argument("--image_column", type=str, default=None)
# Model configuration
parser.add_argument("--mixed_precision", type=str, default="no")
parser.add_argument("--gradient_checkpointing", type=bool, default=True)
parser.add_argument("--enable_slicing", type=bool, default=True)
parser.add_argument("--enable_tiling", type=bool, default=True)
parser.add_argument("--nccl_timeout", type=int, default=1800)
# LoRA parameters
parser.add_argument("--rank", type=int, default=128)
parser.add_argument("--lora_alpha", type=int, default=64)
parser.add_argument("--target_modules", type=str, nargs="+",
default=["to_q", "to_k", "to_v", "to_out.0"])
# Checkpointing
parser.add_argument("--checkpointing_steps", type=int, default=200)
parser.add_argument("--checkpointing_limit", type=int, default=10)
parser.add_argument("--resume_from_checkpoint", type=str, default=None)
# Validation
parser.add_argument("--do_validation", type=bool, default=False)
parser.add_argument("--validation_steps", type=int, default=None)
parser.add_argument("--validation_dir", type=str, default=None)
parser.add_argument("--validation_prompts", type=str, default=None)
parser.add_argument("--validation_images", type=str, default=None)
parser.add_argument("--validation_videos", type=str, default=None)
parser.add_argument("--gen_fps", type=int, default=15)
args = parser.parse_args()
# Convert the train_resolution "frames x height x width" string to an (int, int, int) tuple
frames, height, width = args.train_resolution.split("x")
args.train_resolution = (int(frames), int(height), int(width))
return cls(**vars(args))

View File

@ -0,0 +1,27 @@
from typing import Any
from pydantic import BaseModel
class Components(BaseModel):
# pipeline cls
pipeline_cls: Any = None
# Tokenizers
tokenizer: Any = None
tokenizer_2: Any = None
tokenizer_3: Any = None
# Text encoders
text_encoder: Any = None
text_encoder_2: Any = None
text_encoder_3: Any = None
# Autoencoder
vae: Any = None
# Denoiser
transformer: Any = None
unet: Any = None
# Scheduler
scheduler: Any = None

26
finetune/schemas/state.py Normal file
View File

@ -0,0 +1,26 @@
import torch
from pathlib import Path
from typing import List, Dict, Any
from pydantic import BaseModel, field_validator
class State(BaseModel):
model_config = {"arbitrary_types_allowed": True}
train_frames: int
train_height: int
train_width: int
transformer_config: Dict[str, Any] | None = None
weight_dtype: torch.dtype = torch.float32
num_trainable_parameters: int = 0
overwrote_max_train_steps: bool = False
num_update_steps_per_epoch: int = 0
total_batch_size_count: int = 0
generator: torch.Generator | None = None
validation_prompts: List[str] = []
validation_images: List[Path | None] = []
validation_videos: List[Path | None] = []

View File

@ -0,0 +1,52 @@
import argparse
import os
from pathlib import Path
import cv2
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--datadir", type=str, required=True, help="Root directory containing videos.txt and video subdirectory")
return parser.parse_args()
args = parse_args()
# Create data/images directory if it doesn't exist
data_dir = Path(args.datadir)
image_dir = data_dir / "images"
image_dir.mkdir(exist_ok=True)
# Read videos.txt
videos_file = data_dir / "videos.txt"
with open(videos_file, "r") as f:
video_paths = [line.strip() for line in f.readlines() if line.strip()]
# Process each video file and collect image paths
image_paths = []
for video_rel_path in video_paths:
video_path = data_dir / video_rel_path
# Open video
cap = cv2.VideoCapture(str(video_path))
# Read first frame
ret, frame = cap.read()
if not ret:
print(f"Failed to read video: {video_path}")
cap.release()
continue
# Save frame as PNG with same name as video
image_name = f"images/{video_path.stem}.png"
image_path = data_dir / image_name
cv2.imwrite(str(image_path), frame)
# Release video capture
cap.release()
print(f"Extracted first frame from {video_path} to {image_path}")
image_paths.append(image_name)
# Write images.txt
images_file = data_dir / "images.txt"
with open(images_file, "w") as f:
for path in image_paths:
f.write(f"{path}\n")

18
finetune/train.py Normal file
View File

@ -0,0 +1,18 @@
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
from finetune.schemas import Args
from finetune.models.utils import get_model_cls
def main():
args = Args.parse_args()
trainer_cls = get_model_cls(args.model_name, args.training_type)
trainer = trainer_cls(args)
trainer.fit()
if __name__ == "__main__":
main()

File diff suppressed because it is too large

File diff suppressed because it is too large

688
finetune/trainer.py Normal file
View File

@ -0,0 +1,688 @@
import os
import logging
import math
import json
import torch
import transformers
import diffusers
import wandb
from datetime import timedelta
from pathlib import Path
from tqdm import tqdm
from typing import Dict, Any, List, Tuple
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from accelerate.logging import get_logger
from accelerate.accelerator import Accelerator, DistributedType
from accelerate.utils import (
DistributedDataParallelKwargs,
InitProcessGroupKwargs,
ProjectConfiguration,
set_seed,
gather_object,
)
from diffusers.optimization import get_scheduler
from diffusers.utils.export_utils import export_to_video
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from finetune.schemas import Args, State, Components
from finetune.utils import (
unwrap_model, cast_training_params,
get_optimizer,
get_memory_statistics,
free_memory,
get_latest_ckpt_path_to_resume_from,
get_intermediate_ckpt_path,
string_to_filename
)
from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize
from finetune.datasets.utils import (
load_prompts, load_images, load_videos,
preprocess_image_with_resize, preprocess_video_with_resize
)
from finetune.constants import LOG_NAME, LOG_LEVEL
logger = get_logger(LOG_NAME, LOG_LEVEL)
_DTYPE_MAP = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
class Trainer:
def __init__(self, args: Args) -> None:
self.args = args
self.state = State(
weight_dtype=self.__get_training_dtype(),
train_frames=self.args.train_resolution[0],
train_height=self.args.train_resolution[1],
train_width=self.args.train_resolution[2]
)
self.components = Components()
self.accelerator: Accelerator = None
self.dataset: Dataset = None
self.data_loader: DataLoader = None
self.optimizer = None
self.lr_scheduler = None
self._init_distributed()
self._init_logging()
self._init_directories()
def _init_distributed(self):
logging_dir = Path(self.args.output_dir, "logs")
project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
init_process_group_kwargs = InitProcessGroupKwargs(
backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout)
)
mixed_precision = "no" if torch.backends.mps.is_available() else self.args.mixed_precision
report_to = None if self.args.report_to.lower() == "none" else self.args.report_to
accelerator = Accelerator(
project_config=project_config,
gradient_accumulation_steps=self.args.gradient_accumulation_steps,
mixed_precision=mixed_precision,
log_with=report_to,
kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
)
# Disable AMP for MPS.
if torch.backends.mps.is_available():
accelerator.native_amp = False
self.accelerator = accelerator
if self.args.seed is not None:
set_seed(self.args.seed)
def _init_logging(self) -> None:
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=LOG_LEVEL,
)
if self.accelerator.is_local_main_process:
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
logger.info("Initialized Trainer")
logger.info(f"Accelerator state: \n{self.accelerator.state}", main_process_only=False)
def _init_directories(self) -> None:
# Normalize to Path on every process, since validation artifacts are written from all ranks;
# only the main process creates the directory.
self.args.output_dir = Path(self.args.output_dir)
if self.accelerator.is_main_process:
self.args.output_dir.mkdir(parents=True, exist_ok=True)
def prepare_models(self) -> None:
logger.info("Initializing models")
# Initialize model components
self.components = self.load_components()
if self.components.vae is not None:
if self.args.enable_slicing:
self.components.vae.enable_slicing()
if self.args.enable_tiling:
self.components.vae.enable_tiling()
self.state.transformer_config = self.components.transformer.config
def prepare_dataset(self) -> None:
logger.info("Initializing dataset and dataloader")
if self.args.model_type == "i2v":
self.dataset = I2VDatasetWithResize(
**(self.args.model_dump()),
device=self.accelerator.device,
encode_video_fn=self.encode_video,
max_num_frames=self.state.train_frames,
height=self.state.train_height,
width=self.state.train_width
)
elif self.args.model_type == "t2v":
self.dataset = T2VDatasetWithResize(
**(self.args.model_dump()),
device=self.accelerator.device,
encode_video_fn=self.encode_video,
max_num_frames=self.state.train_frames,
height=self.state.train_height,
width=self.state.train_width
)
else:
raise ValueError(f"Invalid model type: {self.args.model_type}")
# Prepare VAE for encoding
self.components.vae = self.components.vae.to(self.accelerator.device)
self.components.vae.requires_grad_(False)
# Precompute latent for video
logger.info("Precomputing latent for video ...")
tmp_data_loader = torch.utils.data.DataLoader(
self.dataset,
collate_fn=self.collate_fn,
batch_size=1,
num_workers=0,
pin_memory=self.args.pin_memory,
)
tmp_data_loader = self.accelerator.prepare_data_loader(tmp_data_loader)
# Iterate once so the dataset encodes (and caches) every video's latents before training starts
for _ in tmp_data_loader: ...
logger.info("Precomputing latent for video ... Done")
self.data_loader = torch.utils.data.DataLoader(
self.dataset,
collate_fn=self.collate_fn,
batch_size=self.args.batch_size,
num_workers=self.args.num_workers,
pin_memory=self.args.pin_memory,
shuffle=True
)
def prepare_trainable_parameters(self):
logger.info("Initializing trainable parameters")
# For now only lora is supported
for attr_name, component in vars(self.components).items():
if hasattr(component, 'requires_grad_'):
component.requires_grad_(False)
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision,
# since these weights are only used for inference and do not need to be kept in full precision.
weight_dtype = self.state.weight_dtype
if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
# due to pytorch#99272, MPS does not yet support bfloat16.
raise ValueError(
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)
self.__move_components_to_device()
if self.args.gradient_checkpointing:
self.components.transformer.enable_gradient_checkpointing()
transformer_lora_config = LoraConfig(
r=self.args.rank,
lora_alpha=self.args.lora_alpha,
init_lora_weights=True,
target_modules=self.args.target_modules,
)
self.components.transformer.add_adapter(transformer_lora_config)
self.__prepare_saving_loading_hooks(transformer_lora_config)
def prepare_optimizer(self) -> None:
logger.info("Initializing optimizer and lr scheduler")
# Make sure the trainable params are in float32
if self.args.mixed_precision == "fp16":
# only upcast trainable parameters (LoRA) into fp32
cast_training_params([self.components.transformer], dtype=torch.float32)
transformer_lora_parameters = list(filter(lambda p: p.requires_grad, self.components.transformer.parameters()))
transformer_parameters_with_lr = {
"params": transformer_lora_parameters,
"lr": self.args.learning_rate,
}
params_to_optimize = [transformer_parameters_with_lr]
self.state.num_trainable_parameters = sum(p.numel() for p in transformer_lora_parameters)
use_deepspeed_opt = (
self.accelerator.state.deepspeed_plugin is not None
and "optimizer" in self.accelerator.state.deepspeed_plugin.deepspeed_config
)
optimizer = get_optimizer(
params_to_optimize=params_to_optimize,
optimizer_name=self.args.optimizer,
learning_rate=self.args.learning_rate,
beta1=self.args.beta1,
beta2=self.args.beta2,
beta3=self.args.beta3,
epsilon=self.args.epsilon,
weight_decay=self.args.weight_decay,
use_deepspeed=use_deepspeed_opt,
)
num_update_steps_per_epoch = math.ceil(len(self.data_loader) / self.args.gradient_accumulation_steps)
if self.args.train_steps is None:
self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch
self.state.overwrote_max_train_steps = True
use_deepspeed_lr_scheduler = (
self.accelerator.state.deepspeed_plugin is not None
and "scheduler" in self.accelerator.state.deepspeed_plugin.deepspeed_config
)
total_training_steps = self.args.train_steps * self.accelerator.num_processes
num_warmup_steps = self.args.lr_warmup_steps * self.accelerator.num_processes
if use_deepspeed_lr_scheduler:
from accelerate.utils import DummyScheduler
lr_scheduler = DummyScheduler(
name=self.args.lr_scheduler,
optimizer=optimizer,
total_num_steps=total_training_steps,
num_warmup_steps=num_warmup_steps,
)
else:
lr_scheduler = get_scheduler(
name=self.args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=total_training_steps,
num_cycles=self.args.lr_num_cycles,
power=self.args.lr_power,
)
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
def prepare_for_training(self) -> None:
self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler = self.accelerator.prepare(
self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(self.data_loader) / self.args.gradient_accumulation_steps)
if self.state.overwrote_max_train_steps:
self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
self.args.train_epochs = math.ceil(self.args.train_steps / num_update_steps_per_epoch)
self.state.num_update_steps_per_epoch = num_update_steps_per_epoch
def prepare_for_validation(self):
validation_prompts = load_prompts(self.args.validation_dir / self.args.validation_prompts)
if self.args.validation_images is not None:
validation_images = load_images(self.args.validation_dir / self.args.validation_images)
else:
validation_images = [None] * len(validation_prompts)
if self.args.validation_videos is not None:
validation_videos = load_videos(self.args.validation_dir / self.args.validation_videos)
else:
validation_videos = [None] * len(validation_prompts)
self.state.validation_prompts = validation_prompts
self.state.validation_images = validation_images
self.state.validation_videos = validation_videos
def prepare_trackers(self) -> None:
logger.info("Initializing trackers")
tracker_name = self.args.tracker_name or "finetrainers-experiment"
self.accelerator.init_trackers(tracker_name, config=self.args.model_dump())
def train(self) -> None:
logger.info("Starting training")
memory_statistics = get_memory_statistics()
logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
self.state.total_batch_size_count = (
self.args.batch_size * self.accelerator.num_processes * self.args.gradient_accumulation_steps
)
info = {
"trainable parameters": self.state.num_trainable_parameters,
"total samples": len(self.dataset),
"train epochs": self.args.train_epochs,
"train steps": self.args.train_steps,
"batches per device": self.args.batch_size,
"total batches observed per epoch": len(self.data_loader),
"train batch size total count": self.state.total_batch_size_count,
"gradient accumulation steps": self.args.gradient_accumulation_steps,
}
logger.info(f"Training configuration: {json.dumps(info, indent=4)}")
global_step = 0
first_epoch = 0
initial_global_step = 0
# Potentially load in the weights and states from a previous save
(
resume_from_checkpoint_path,
initial_global_step,
global_step,
first_epoch,
) = get_latest_ckpt_path_to_resume_from(
resume_from_checkpoint=self.args.resume_from_checkpoint,
num_update_steps_per_epoch=self.state.num_update_steps_per_epoch,
)
if resume_from_checkpoint_path is not None:
self.accelerator.load_state(resume_from_checkpoint_path)
progress_bar = tqdm(
range(0, self.args.train_steps),
initial=initial_global_step,
desc="Training steps",
disable=not self.accelerator.is_local_main_process,
)
accelerator = self.accelerator
generator = torch.Generator(device=accelerator.device)
if self.args.seed is not None:
generator = generator.manual_seed(self.args.seed)
self.state.generator = generator
for epoch in range(first_epoch, self.args.train_epochs):
logger.debug(f"Starting epoch ({epoch + 1}/{self.args.train_epochs})")
self.components.transformer.train()
models_to_accumulate = [self.components.transformer]
for step, batch in enumerate(self.data_loader):
logger.debug(f"Starting step {step + 1}")
logs = {}
with accelerator.accumulate(models_to_accumulate):
# These weighting schemes use a uniform timestep sampling and instead post-weight the loss
loss = self.compute_loss(batch)
accelerator.backward(loss)
if accelerator.sync_gradients:
if accelerator.distributed_type == DistributedType.DEEPSPEED:
grad_norm = self.components.transformer.get_global_grad_norm()
# get_global_grad_norm() may return a tensor rather than a float
if torch.is_tensor(grad_norm):
grad_norm = grad_norm.item()
else:
grad_norm = accelerator.clip_grad_norm_(
self.components.transformer.parameters(), self.args.max_grad_norm
)
if torch.is_tensor(grad_norm):
grad_norm = grad_norm.item()
logs["grad_norm"] = grad_norm
self.optimizer.step()
self.lr_scheduler.step()
self.optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
self.__maybe_save_checkpoint(global_step)
# Maybe run validation
should_run_validation = (
self.args.do_validation
and global_step % self.args.validation_steps == 0
)
if should_run_validation:
self.validate(global_step)
logs["loss"] = loss.detach().item()
logs["lr"] = self.lr_scheduler.get_last_lr()[0]
progress_bar.set_postfix(logs)
accelerator.log(logs, step=global_step)
if global_step >= self.args.train_steps:
break
memory_statistics = get_memory_statistics()
logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}")
accelerator.wait_for_everyone()
self.__maybe_save_checkpoint(global_step, must_save=True)
if self.args.do_validation:
self.validate(global_step)
del self.components
free_memory()
memory_statistics = get_memory_statistics()
logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}")
accelerator.end_training()
def validate(self, step: int) -> None:
logger.info("Starting validation")
accelerator = self.accelerator
num_validation_samples = len(self.state.validation_prompts)
if num_validation_samples == 0:
logger.warning("No validation samples found. Skipping validation.")
return
self.components.transformer.eval()
memory_statistics = get_memory_statistics()
logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")
all_processes_artifacts = []
for i in range(num_validation_samples):
# Round-robin the validation samples across processes; each sample runs on exactly one rank
if i % accelerator.num_processes != accelerator.process_index:
continue
prompt = self.state.validation_prompts[i]
image = self.state.validation_images[i]
video = self.state.validation_videos[i]
if image is not None:
image = preprocess_image_with_resize(
image, self.state.train_height, self.state.train_width
)
# Convert image tensor (C, H, W) to PIL images
image = image.to(torch.uint8)
image = image.permute(1, 2, 0).cpu().numpy()
image = Image.fromarray(image)
if video is not None:
video = preprocess_video_with_resize(
video, self.state.train_frames, self.state.train_height, self.state.train_width
)
# Convert video tensor (F, C, H, W) to list of PIL images
video = (video * 255).round().clamp(0, 255).to(torch.uint8)
video = [Image.fromarray(frame.permute(1,2,0).cpu().numpy()) for frame in video]
logger.debug(
f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
main_process_only=False,
)
validation_artifacts = self.validation_step({
"prompt": prompt,
"image": image,
"video": video
})
prompt_filename = string_to_filename(prompt)[:25]
artifacts = {
"image": {"type": "image", "value": image},
"video": {"type": "video", "value": video},
}
for idx, (artifact_type, artifact_value) in enumerate(validation_artifacts):
artifacts.update({f"artifact_{idx}": {"type": artifact_type, "value": artifact_value}})
logger.debug(
f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
main_process_only=False,
)
for key, value in list(artifacts.items()):
artifact_type = value["type"]
artifact_value = value["value"]
if artifact_type not in ["image", "video"] or artifact_value is None:
continue
extension = "png" if artifact_type == "image" else "mp4"
filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}.{extension}"
validation_path = self.args.output_dir / "validation_res"
validation_path.mkdir(parents=True, exist_ok=True)
filename = str(validation_path / filename)
if artifact_type == "image":
logger.debug(f"Saving image to {filename}")
artifact_value.save(filename)
artifact_value = wandb.Image(filename)
elif artifact_type == "video":
logger.debug(f"Saving video to {filename}")
export_to_video(artifact_value, filename, fps=self.args.gen_fps)
artifact_value = wandb.Video(filename, caption=prompt)
all_processes_artifacts.append(artifact_value)
all_artifacts = gather_object(all_processes_artifacts)
if accelerator.is_main_process:
tracker_key = "validation"
for tracker in accelerator.trackers:
if tracker.name == "wandb":
image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
tracker.log(
{
tracker_key: {"images": image_artifacts, "videos": video_artifacts},
},
step=step,
)
accelerator.wait_for_everyone()
free_memory()
memory_statistics = get_memory_statistics()
logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats(accelerator.device)
self.components.transformer.train()
def fit(self):
self.prepare_models()
self.prepare_dataset()
self.prepare_trainable_parameters()
self.prepare_optimizer()
self.prepare_for_training()
if self.args.do_validation:
self.prepare_for_validation()
self.prepare_trackers()
self.train()
def collate_fn(self, examples: List[Dict[str, Any]]):
raise NotImplementedError
def load_components(self) -> Components:
raise NotImplementedError
def encode_video(self, video: torch.Tensor) -> torch.Tensor:
# shape of input video: [B, C, F, H, W], where B = 1
raise NotImplementedError
def compute_loss(self, batch) -> torch.Tensor:
raise NotImplementedError
def validation_step(self, eval_data: Dict[str, Any]) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
raise NotImplementedError
def __get_training_dtype(self) -> torch.dtype:
if self.args.mixed_precision == "no":
return _DTYPE_MAP["fp32"]
elif self.args.mixed_precision == "fp16":
return _DTYPE_MAP["fp16"]
elif self.args.mixed_precision == "bf16":
return _DTYPE_MAP["bf16"]
else:
raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")
def __move_components_to_device(self):
components = self.components.model_dump()
for name, component in components.items():
if not isinstance(component, type) and hasattr(component, 'to'):
setattr(self.components, name, component.to(self.accelerator.device))
def __prepare_saving_loading_hooks(self, transformer_lora_config):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if self.accelerator.is_main_process:
transformer_lora_layers_to_save = None
for model in models:
if isinstance(
unwrap_model(self.accelerator, model),
type(unwrap_model(self.accelerator, self.components.transformer)),
):
model = unwrap_model(self.accelerator, model)
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
else:
raise ValueError(f"Unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again
if weights:
weights.pop()
self.components.pipeline_cls.save_lora_weights(
output_dir,
transformer_lora_layers=transformer_lora_layers_to_save,
)
def load_model_hook(models, input_dir):
if not self.accelerator.distributed_type == DistributedType.DEEPSPEED:
while len(models) > 0:
model = models.pop()
if isinstance(
unwrap_model(self.accelerator, model),
type(unwrap_model(self.accelerator, self.components.transformer)),
):
transformer_ = unwrap_model(self.accelerator, model)
else:
raise ValueError(
f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}"
)
else:
transformer_ = unwrap_model(self.accelerator, self.components.transformer).__class__.from_pretrained(
self.args.model_path, subfolder="transformer"
)
transformer_.add_adapter(transformer_lora_config)
lora_state_dict = self.components.pipeline_cls.lora_state_dict(input_dir)
transformer_state_dict = {
f'{k.replace("transformer.", "")}': v
for k, v in lora_state_dict.items()
if k.startswith("transformer.")
}
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
# Make sure the trainable params are in float32. This is again needed since the base models
# are in `weight_dtype`. More details:
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
if self.args.mixed_precision == "fp16":
# only upcast trainable parameters (LoRA) into fp32
cast_training_params([transformer_])
self.accelerator.register_save_state_pre_hook(save_model_hook)
self.accelerator.register_load_state_pre_hook(load_model_hook)
def __maybe_save_checkpoint(self, global_step: int, must_save: bool = False):
if self.accelerator.distributed_type == DistributedType.DEEPSPEED or self.accelerator.is_main_process:
if must_save or global_step % self.args.checkpointing_steps == 0:
save_path = get_intermediate_ckpt_path(
checkpointing_limit=self.args.checkpointing_limit,
step=global_step,
output_dir=self.args.output_dir,
)
self.accelerator.save_state(save_path)
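The `Trainer` above is model-agnostic: the hooks that raise `NotImplementedError` (`load_components`, `collate_fn`, `encode_video`, `compute_loss`, `validation_step`) are meant to be filled in by model-specific subclasses. A schematic sketch of that contract (illustrative only, not the actual CogVideoX implementation):

```python
from typing import Any, Dict, List, Tuple

import torch

from finetune.schemas import Components
from finetune.trainer import Trainer


class MyLoraTrainer(Trainer):
    """Illustrative subclass showing which hooks a concrete trainer must provide."""

    def load_components(self) -> Components:
        # Load tokenizer, text encoder, VAE, transformer and scheduler,
        # then return them wrapped in a Components instance.
        raise NotImplementedError

    def collate_fn(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Stack per-sample tensors (latents, prompt embeddings, ...) into a batch dict.
        raise NotImplementedError

    def encode_video(self, video: torch.Tensor) -> torch.Tensor:
        # [B, C, F, H, W] pixel video -> VAE latents, used during latent precomputation.
        raise NotImplementedError

    def compute_loss(self, batch: Dict[str, Any]) -> torch.Tensor:
        # One training step: noise the latents, run the transformer, return the loss.
        raise NotImplementedError

    def validation_step(self, eval_data: Dict[str, Any]) -> List[Tuple[str, Any]]:
        # Run the inference pipeline and return [("video", frames), ...] artifacts.
        raise NotImplementedError
```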

View File

@ -0,0 +1,5 @@
from .torch_utils import *
from .optimizer_utils import *
from .memory_utils import *
from .checkpointing import *
from .file_utils import *

View File

@ -0,0 +1,53 @@
import os
from pathlib import Path
from typing import Tuple
from accelerate.logging import get_logger
from finetune.constants import LOG_NAME, LOG_LEVEL
from ..utils.file_utils import find_files, delete_files
logger = get_logger(LOG_NAME, LOG_LEVEL)
def get_latest_ckpt_path_to_resume_from(
resume_from_checkpoint: str | None, num_update_steps_per_epoch: int
) -> Tuple[str | None, int, int, int]:
if resume_from_checkpoint is None:
initial_global_step = 0
global_step = 0
first_epoch = 0
resume_from_checkpoint_path = None
else:
resume_from_checkpoint_path = Path(resume_from_checkpoint)
if not resume_from_checkpoint_path.exists():
logger.info(f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run.")
initial_global_step = 0
global_step = 0
first_epoch = 0
resume_from_checkpoint_path = None
else:
logger.info(f"Resuming from checkpoint {resume_from_checkpoint}")
global_step = int(resume_from_checkpoint_path.name.split("-")[1])
initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
return resume_from_checkpoint_path, initial_global_step, global_step, first_epoch
def get_intermediate_ckpt_path(checkpointing_limit: int, step: int, output_dir: str) -> str:
# before saving state, check if this save would set us over the `checkpointing_limit`
if checkpointing_limit is not None:
checkpoints = find_files(output_dir, prefix="checkpoint")
# before we save the new checkpoint, keep at most `checkpointing_limit - 1` existing checkpoints
if len(checkpoints) >= checkpointing_limit:
num_to_remove = len(checkpoints) - checkpointing_limit + 1
checkpoints_to_remove = checkpoints[0:num_to_remove]
delete_files(checkpoints_to_remove)
logger.info(f"Checkpointing at step {step}")
save_path = os.path.join(output_dir, f"checkpoint-{step}")
logger.info(f"Saving state to {save_path}")
return save_path
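As a rough illustration of how these two helpers fit together in the trainer (paths and numbers below are placeholder values, not taken from the diff):

```python
from finetune.utils import get_latest_ckpt_path_to_resume_from, get_intermediate_ckpt_path

# Resuming: map --resume_from_checkpoint to step/epoch counters.
ckpt_path, initial_step, global_step, first_epoch = get_latest_ckpt_path_to_resume_from(
    resume_from_checkpoint="outputs/checkpoint-400",  # or None to start from scratch
    num_update_steps_per_epoch=100,
)
# If the directory exists: ckpt_path=Path("outputs/checkpoint-400"), global_step=400, first_epoch=4.

# Saving: prune old checkpoints down to the configured limit and get the next save directory.
save_path = get_intermediate_ckpt_path(checkpointing_limit=10, step=600, output_dir="outputs")
# -> "outputs/checkpoint-600"
```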

View File

@ -0,0 +1,47 @@
import logging
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List, Union
from accelerate.logging import get_logger
from finetune.constants import LOG_NAME, LOG_LEVEL
logger = get_logger(LOG_NAME, LOG_LEVEL)
def find_files(dir: Union[str, Path], prefix: str = "checkpoint") -> List[Path]:
if not isinstance(dir, Path):
dir = Path(dir)
if not dir.exists():
return []
checkpoints = os.listdir(dir.as_posix())
checkpoints = [c for c in checkpoints if c.startswith(prefix)]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
checkpoints = [dir / c for c in checkpoints]
return checkpoints
def delete_files(dirs: Union[str, List[str], Path, List[Path]]) -> None:
if not isinstance(dirs, list):
dirs = [dirs]
dirs = [Path(d) if isinstance(d, str) else d for d in dirs]
logger.info(f"Deleting files: {dirs}")
for dir in dirs:
if not dir.exists():
continue
shutil.rmtree(dir, ignore_errors=True)
def string_to_filename(s: str) -> str:
return (
s.replace(" ", "-")
.replace("/", "-")
.replace(":", "-")
.replace(".", "-")
.replace(",", "-")
.replace(";", "-")
.replace("!", "-")
.replace("?", "-")
)
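For instance (illustrative values, mirroring how the trainer shortens prompts to 25 characters for artifact filenames):

```python
from finetune.utils import string_to_filename

prompt = "A panda, dressed in a small red jacket, playing guitar."
print(string_to_filename(prompt)[:25])
# -> A-panda--dressed-in-a-sma
```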

View File

@ -0,0 +1,60 @@
import gc
import torch
from typing import Any, Dict, Union
from accelerate.logging import get_logger
from finetune.constants import LOG_NAME, LOG_LEVEL
logger = get_logger(LOG_NAME, LOG_LEVEL)
def get_memory_statistics(precision: int = 3) -> Dict[str, Any]:
memory_allocated = None
memory_reserved = None
max_memory_allocated = None
max_memory_reserved = None
if torch.cuda.is_available():
device = torch.cuda.current_device()
memory_allocated = torch.cuda.memory_allocated(device)
memory_reserved = torch.cuda.memory_reserved(device)
max_memory_allocated = torch.cuda.max_memory_allocated(device)
max_memory_reserved = torch.cuda.max_memory_reserved(device)
elif torch.backends.mps.is_available():
memory_allocated = torch.mps.current_allocated_memory()
else:
logger.warning("No CUDA, MPS, or ROCm device found. Memory statistics are not available.")
# Round only when a value is available; counters stay None on devices that do not report them
return {
"memory_allocated": round(bytes_to_gigabytes(memory_allocated), ndigits=precision) if memory_allocated is not None else None,
"memory_reserved": round(bytes_to_gigabytes(memory_reserved), ndigits=precision) if memory_reserved is not None else None,
"max_memory_allocated": round(bytes_to_gigabytes(max_memory_allocated), ndigits=precision) if max_memory_allocated is not None else None,
"max_memory_reserved": round(bytes_to_gigabytes(max_memory_reserved), ndigits=precision) if max_memory_reserved is not None else None,
}
def bytes_to_gigabytes(x: int) -> float:
if x is not None:
return x / 1024**3
def free_memory() -> None:
if torch.cuda.is_available():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
# TODO(aryan): handle non-cuda devices
def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if isinstance(x, torch.Tensor):
return x.contiguous()
elif isinstance(x, dict):
return {k: make_contiguous(v) for k, v in x.items()}
else:
return x

View File

@ -0,0 +1,180 @@
import inspect
import torch
from accelerate.logging import get_logger
from finetune.constants import LOG_NAME, LOG_LEVEL
logger = get_logger(LOG_NAME, LOG_LEVEL)
def get_optimizer(
params_to_optimize,
optimizer_name: str = "adam",
learning_rate: float = 1e-3,
beta1: float = 0.9,
beta2: float = 0.95,
beta3: float = 0.98,
epsilon: float = 1e-8,
weight_decay: float = 1e-4,
prodigy_decouple: bool = False,
prodigy_use_bias_correction: bool = False,
prodigy_safeguard_warmup: bool = False,
use_8bit: bool = False,
use_4bit: bool = False,
use_torchao: bool = False,
use_deepspeed: bool = False,
use_cpu_offload_optimizer: bool = False,
offload_gradients: bool = False,
) -> torch.optim.Optimizer:
optimizer_name = optimizer_name.lower()
# Use DeepSpeed optimizer
if use_deepspeed:
from accelerate.utils import DummyOptim
return DummyOptim(
params_to_optimize,
lr=learning_rate,
betas=(beta1, beta2),
eps=epsilon,
weight_decay=weight_decay,
)
if use_8bit and use_4bit:
raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.")
if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer:
try:
import torchao
torchao.__version__
except ImportError:
raise ImportError(
"To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`."
)
if not use_torchao and use_4bit:
raise ValueError("4-bit Optimizers are only supported with torchao.")
# Optimizer creation
supported_optimizers = ["adam", "adamw", "prodigy", "came"]
if optimizer_name not in supported_optimizers:
logger.warning(
f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`."
)
optimizer_name = "adamw"
if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]:
raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.")
if use_8bit:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
)
if optimizer_name == "adamw":
if use_torchao:
from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit
optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
else:
optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW
init_kwargs = {
"betas": (beta1, beta2),
"eps": epsilon,
"weight_decay": weight_decay,
}
elif optimizer_name == "adam":
if use_torchao:
from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit
optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam
else:
optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam
init_kwargs = {
"betas": (beta1, beta2),
"eps": epsilon,
"weight_decay": weight_decay,
}
elif optimizer_name == "prodigy":
try:
import prodigyopt
except ImportError:
raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
optimizer_class = prodigyopt.Prodigy
if learning_rate <= 0.1:
logger.warning(
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
)
init_kwargs = {
"lr": learning_rate,
"betas": (beta1, beta2),
"beta3": beta3,
"eps": epsilon,
"weight_decay": weight_decay,
"decouple": prodigy_decouple,
"use_bias_correction": prodigy_use_bias_correction,
"safeguard_warmup": prodigy_safeguard_warmup,
}
elif optimizer_name == "came":
try:
import came_pytorch
except ImportError:
raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`")
optimizer_class = came_pytorch.CAME
init_kwargs = {
"lr": learning_rate,
"eps": (1e-30, 1e-16),
"betas": (beta1, beta2, beta3),
"weight_decay": weight_decay,
}
if use_cpu_offload_optimizer:
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer
if "fused" in inspect.signature(optimizer_class.__init__).parameters:
init_kwargs.update({"fused": True})
optimizer = CPUOffloadOptimizer(
params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs
)
else:
optimizer = optimizer_class(params_to_optimize, **init_kwargs)
return optimizer
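A brief usage sketch for this factory, following the parameter-group structure built in `Trainer.prepare_optimizer` (the toy `torch.nn.Linear` below just stands in for the LoRA parameters):

```python
import torch

from finetune.utils import get_optimizer

# Stand-in for the trainable (LoRA) parameters of the transformer.
model = torch.nn.Linear(16, 16)
trainable = [p for p in model.parameters() if p.requires_grad]

optimizer = get_optimizer(
    params_to_optimize=[{"params": trainable, "lr": 2e-5}],
    optimizer_name="adamw",
    learning_rate=2e-5,
    beta1=0.9,
    beta2=0.95,
    epsilon=1e-8,
    weight_decay=1e-4,
    use_deepspeed=False,  # True returns accelerate's DummyOptim so DeepSpeed manages the optimizer
)
assert isinstance(optimizer, torch.optim.AdamW)
```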
def gradient_norm(parameters):
norm = 0
for param in parameters:
if param.grad is None:
continue
local_norm = param.grad.detach().data.norm(2)
norm += local_norm.item() ** 2
norm = norm**0.5
return norm
def max_gradient(parameters):
max_grad_value = float("-inf")
for param in parameters:
if param.grad is None:
continue
local_max_grad = param.grad.detach().data.abs().max()
max_grad_value = max(max_grad_value, local_max_grad.item())
return max_grad_value

View File

@ -0,0 +1,52 @@
from typing import Dict, Optional, Union, List
import torch
from accelerate import Accelerator
from diffusers.utils.torch_utils import is_compiled_module
def unwrap_model(accelerator: Accelerator, model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
def align_device_and_dtype(
x: Union[torch.Tensor, Dict[str, torch.Tensor]],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
if isinstance(x, torch.Tensor):
if device is not None:
x = x.to(device)
if dtype is not None:
x = x.to(dtype)
elif isinstance(x, dict):
if device is not None:
x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
if dtype is not None:
x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
return x
def expand_tensor_to_dims(tensor, ndim):
while len(tensor.shape) < ndim:
tensor = tensor.unsqueeze(-1)
return tensor
def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
"""
Casts the training parameters of the model to the specified data type.
Args:
model: The PyTorch model whose parameters will be cast.
dtype: The data type to which the model parameters will be cast.
"""
if not isinstance(model, list):
model = [model]
for m in model:
for param in m.parameters():
# only upcast trainable parameters into fp32
if param.requires_grad:
param.data = param.to(dtype)
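A minimal sketch of the fp16-training scenario this helper addresses (only parameters that require gradients are upcast to fp32; frozen ones keep their half-precision dtype):

```python
import torch

from finetune.utils import cast_training_params

layer = torch.nn.Linear(8, 8).to(torch.float16)
layer.bias.requires_grad_(False)  # frozen, stays in fp16
cast_training_params(layer, dtype=torch.float32)

assert layer.weight.dtype == torch.float32
assert layer.bias.dtype == torch.float16
```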