diff --git a/.gitignore b/.gitignore
index 6be6f4b..2ee7d3d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,4 +8,17 @@ logs/
.idea
output*
test*
-venv
\ No newline at end of file
+venv
+**/.swp
+**/*.log
+**/*.debug
+**/.vscode
+
+**/*debug*
+**/.gitignore
+**/finetune/*-lora-*
+**/finetune/Disney-*
+**/wandb
+**/results
+**/*.mp4
+**/validation_set
diff --git a/finetune/accelerate_config_machine_multi.yaml b/finetune/accelerate_config_machine_multi.yaml
deleted file mode 100644
index 856db57..0000000
--- a/finetune/accelerate_config_machine_multi.yaml
+++ /dev/null
@@ -1,26 +0,0 @@
-compute_environment: LOCAL_MACHINE
-debug: true
-deepspeed_config:
- deepspeed_hostfile: hostfile.txt
- deepspeed_multinode_launcher: pdsh
- gradient_accumulation_steps: 1
- gradient_clipping: 1.0
- offload_optimizer_device: none
- offload_param_device: none
- zero3_init_flag: true
- zero_stage: 3
-distributed_type: DEEPSPEED
-downcast_bf16: 'yes'
-enable_cpu_affinity: true
-main_process_ip: 10.250.128.19
-main_process_port: 12355
-main_training_function: main
-mixed_precision: bf16
-num_machines: 4
-num_processes: 32
-rdzv_backend: static
-same_network: true
-tpu_env: []
-tpu_use_cluster: false
-tpu_use_sudo: false
-use_cpu: false
diff --git a/finetune/accelerate_config_machine_single.yaml b/finetune/accelerate_config_machine_single.yaml
deleted file mode 100644
index 90d993a..0000000
--- a/finetune/accelerate_config_machine_single.yaml
+++ /dev/null
@@ -1,20 +0,0 @@
-compute_environment: LOCAL_MACHINE
-gpu_ids: "0"
-debug: false
-deepspeed_config:
- deepspeed_config_file: ds_config.json
- zero3_init_flag: false
-distributed_type: DEEPSPEED
-downcast_bf16: 'no'
-enable_cpu_affinity: false
-machine_rank: 0
-main_training_function: main
-dynamo_backend: 'no'
-num_machines: 1
-num_processes: 1
-rdzv_backend: static
-same_network: true
-tpu_env: []
-tpu_use_cluster: false
-tpu_use_sudo: false
-use_cpu: false
\ No newline at end of file
diff --git a/finetune/accelerate_train_i2v.sh b/finetune/accelerate_train_i2v.sh
new file mode 100644
index 0000000..ec3922e
--- /dev/null
+++ b/finetune/accelerate_train_i2v.sh
@@ -0,0 +1,69 @@
+#!/usr/bin/env bash
+
+# Prevent tokenizer parallelism issues
+export TOKENIZERS_PARALLELISM=false
+
+# Model Configuration
+MODEL_ARGS=(
+ --model_path "THUDM/CogVideoX1.5-5B-I2V"
+ --model_name "cogvideox1.5-i2v"
+ --model_type "i2v"
+ --training_type "lora"
+)
+
+# Output Configuration
+OUTPUT_ARGS=(
+ --output_dir "/path/to/output/dir"
+ --report_to "tensorboard"
+)
+
+# Data Configuration
+DATA_ARGS=(
+ --data_root "/path/to/data/dir"
+ --caption_column "prompt.txt"
+ --video_column "videos.txt"
+ --image_column "images.txt"
+ --train_resolution "80x768x1360"
+)
+
+# Training Configuration
+TRAIN_ARGS=(
+ --train_epochs 10
+ --batch_size 1
+ --gradient_accumulation_steps 1
+ --mixed_precision "bf16"
+ --seed 42
+)
+
+# System Configuration
+SYSTEM_ARGS=(
+ --num_workers 8
+ --pin_memory True
+ --nccl_timeout 1800
+)
+
+# Checkpointing Configuration
+CHECKPOINT_ARGS=(
+ --checkpointing_steps 200
+ --checkpointing_limit 10
+)
+
+# Validation Configuration
+VALIDATION_ARGS=(
+ --do_validation False
+ --validation_dir "/path/to/validation/dir"
+ --validation_steps 400
+ --validation_prompts "prompts.txt"
+ --validation_images "images.txt"
+ --gen_fps 15
+)
+
+# Combine all arguments and launch training
+accelerate launch train.py \
+ "${MODEL_ARGS[@]}" \
+ "${OUTPUT_ARGS[@]}" \
+ "${DATA_ARGS[@]}" \
+ "${TRAIN_ARGS[@]}" \
+ "${SYSTEM_ARGS[@]}" \
+ "${CHECKPOINT_ARGS[@]}" \
+ "${VALIDATION_ARGS[@]}"
\ No newline at end of file
diff --git a/finetune/accelerate_train_t2v.sh b/finetune/accelerate_train_t2v.sh
new file mode 100644
index 0000000..0d2b7f6
--- /dev/null
+++ b/finetune/accelerate_train_t2v.sh
@@ -0,0 +1,67 @@
+#!/usr/bin/env bash
+
+# Prevent tokenizer parallelism issues
+export TOKENIZERS_PARALLELISM=false
+
+# Model Configuration
+MODEL_ARGS=(
+ --model_path "THUDM/CogVideoX1.5-5B"
+ --model_name "cogvideox1.5-t2v"
+ --model_type "t2v"
+ --training_type "lora"
+)
+
+# Output Configuration
+OUTPUT_ARGS=(
+ --output_dir "/path/to/output/dir"
+ --report_to "tensorboard"
+)
+
+# Data Configuration
+DATA_ARGS=(
+ --data_root "/path/to/data/dir"
+ --caption_column "prompt.txt"
+ --video_column "videos.txt"
+ --train_resolution "80x768x1360"
+)
+
+# Training Configuration
+TRAIN_ARGS=(
+ --train_epochs 10
+ --batch_size 1
+ --gradient_accumulation_steps 1
+ --mixed_precision "bf16"
+ --seed 42
+)
+
+# System Configuration
+SYSTEM_ARGS=(
+ --num_workers 8
+ --pin_memory True
+ --nccl_timeout 1800
+)
+
+# Checkpointing Configuration
+CHECKPOINT_ARGS=(
+ --checkpointing_steps 200
+ --checkpointing_limit 10
+)
+
+# Validation Configuration
+VALIDATION_ARGS=(
+ --do_validation False
+ --validation_dir "/path/to/validation/dir"
+ --validation_steps 400
+ --validation_prompts "prompts.txt"
+ --gen_fps 15
+)
+
+# Combine all arguments and launch training
+accelerate launch train.py \
+ "${MODEL_ARGS[@]}" \
+ "${OUTPUT_ARGS[@]}" \
+ "${DATA_ARGS[@]}" \
+ "${TRAIN_ARGS[@]}" \
+ "${SYSTEM_ARGS[@]}" \
+ "${CHECKPOINT_ARGS[@]}" \
+ "${VALIDATION_ARGS[@]}"
\ No newline at end of file
diff --git a/finetune/constants.py b/finetune/constants.py
new file mode 100644
index 0000000..30ef0eb
--- /dev/null
+++ b/finetune/constants.py
@@ -0,0 +1,2 @@
+LOG_NAME = "trainer"
+LOG_LEVEL = "INFO"
\ No newline at end of file
diff --git a/finetune/datasets/__init__.py b/finetune/datasets/__init__.py
new file mode 100644
index 0000000..8c9b61a
--- /dev/null
+++ b/finetune/datasets/__init__.py
@@ -0,0 +1,12 @@
+from .i2v_dataset import I2VDatasetWithResize, I2VDatasetWithBuckets
+from .t2v_dataset import T2VDatasetWithResize, T2VDatasetWithBuckets
+from .bucket_sampler import BucketSampler
+
+
+__all__ = [
+ "I2VDatasetWithResize",
+ "I2VDatasetWithBuckets",
+ "T2VDatasetWithResize",
+ "T2VDatasetWithBuckets",
+ "BucketSampler"
+]
diff --git a/finetune/datasets/bucket_sampler.py b/finetune/datasets/bucket_sampler.py
new file mode 100644
index 0000000..bf1beb1
--- /dev/null
+++ b/finetune/datasets/bucket_sampler.py
@@ -0,0 +1,73 @@
+import random
+import logging
+
+from torch.utils.data import Sampler
+from torch.utils.data import Dataset
+
+logger = logging.getLogger(__name__)
+
+
+class BucketSampler(Sampler):
+ r"""
+ PyTorch Sampler that groups 3D data by height, width and frames.
+
+ Args:
+ data_source (`VideoDataset`):
+ A PyTorch dataset object that is an instance of `VideoDataset`.
+ batch_size (`int`, defaults to `8`):
+ The batch size to use for training.
+ shuffle (`bool`, defaults to `True`):
+ Whether or not to shuffle the data in each batch before dispatching to dataloader.
+ drop_last (`bool`, defaults to `False`):
+ Whether or not to drop incomplete buckets of data after completely iterating over all data
+ in the dataset. If set to True, only batches that have `batch_size` number of entries will
+ be yielded. If set to False, it is guaranteed that all data in the dataset will be processed
+ and batches that do not have `batch_size` number of entries will also be yielded.
+ """
+
+ def __init__(
+ self, data_source: Dataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False
+ ) -> None:
+ self.data_source = data_source
+ self.batch_size = batch_size
+ self.shuffle = shuffle
+ self.drop_last = drop_last
+
+ self.buckets = {resolution: [] for resolution in data_source.video_resolution_buckets}
+
+ self._raised_warning_for_drop_last = False
+
+
+ def __len__(self):
+ if self.drop_last and not self._raised_warning_for_drop_last:
+ self._raised_warning_for_drop_last = True
+ logger.warning(
+ "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training."
+ )
+ return (len(self.data_source) + self.batch_size - 1) // self.batch_size
+
+
+ def __iter__(self):
+ for index, data in enumerate(self.data_source):
+ video_metadata = data["video_metadata"]
+ f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"]
+
+ self.buckets[(f, h, w)].append(data)
+ if len(self.buckets[(f, h, w)]) == self.batch_size:
+ if self.shuffle:
+ random.shuffle(self.buckets[(f, h, w)])
+ yield self.buckets[(f, h, w)]
+ del self.buckets[(f, h, w)]
+ self.buckets[(f, h, w)] = []
+
+ if self.drop_last:
+ return
+
+ for fhw, bucket in list(self.buckets.items()):
+ if len(bucket) == 0:
+ continue
+ if self.shuffle:
+ random.shuffle(bucket)
+ yield bucket
+ del self.buckets[fhw]
+ self.buckets[fhw] = []
diff --git a/finetune/datasets/i2v_dataset.py b/finetune/datasets/i2v_dataset.py
new file mode 100644
index 0000000..e993c96
--- /dev/null
+++ b/finetune/datasets/i2v_dataset.py
@@ -0,0 +1,275 @@
+import torch
+
+from pathlib import Path
+from typing import Any, Dict, List, Tuple, Callable
+from typing_extensions import override
+
+from accelerate.logging import get_logger
+from torch.utils.data import Dataset
+from torchvision import transforms
+from finetune.constants import LOG_NAME, LOG_LEVEL
+
+from .utils import (
+ load_prompts, load_videos, load_images,
+
+ preprocess_image_with_resize,
+ preprocess_video_with_resize,
+ preprocess_video_with_buckets
+)
+
+# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
+# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
+import decord # isort:skip
+
+decord.bridge.set_bridge("torch")
+
+logger = get_logger(LOG_NAME, LOG_LEVEL)
+
+
+class BaseI2VDataset(Dataset):
+ """
+ Base dataset class for Image-to-Video (I2V) training.
+
+ This dataset loads prompts, videos and corresponding conditioning images for I2V training.
+
+ Args:
+ data_root (str): Root directory containing the dataset files
+ caption_column (str): Path to file containing text prompts/captions
+ video_column (str): Path to file containing video paths
+ image_column (str): Path to file containing image paths
+ device (torch.device): Device to load the data on
+ encode_video_fn (Callable[[torch.Tensor], torch.Tensor], optional): Function to encode videos
+ """
+ def __init__(
+ self,
+ data_root: str,
+ caption_column: str,
+ video_column: str,
+ image_column: str,
+ device: torch.device,
+ encode_video_fn: Callable[[torch.Tensor], torch.Tensor] = None,
+ *args,
+ **kwargs
+ ) -> None:
+ super().__init__()
+
+ data_root = Path(data_root)
+ self.prompts = load_prompts(data_root / caption_column)
+ self.videos = load_videos(data_root / video_column)
+ self.images = load_images(data_root / image_column)
+
+ self.device = device
+ self.encode_video_fn = encode_video_fn
+
+ # Check if number of prompts matches number of videos and images
+ if not (len(self.videos) == len(self.prompts) == len(self.images)):
+ raise ValueError(
+ f"Expected length of prompts, videos and images to be the same but found {len(self.prompts)=}, {len(self.videos)=} and {len(self.images)=}. Please ensure that the number of caption prompts, videos and images match in your dataset."
+ )
+
+ # Check if all video files exist
+ if any(not path.is_file() for path in self.videos):
+ raise ValueError(
+ f"Some video files were not found. Please ensure that all video files exist in the dataset directory. Missing file: {next(path for path in self.videos if not path.is_file())}"
+ )
+
+ # Check if all image files exist
+ if any(not path.is_file() for path in self.images):
+ raise ValueError(
+ f"Some image files were not found. Please ensure that all image files exist in the dataset directory. Missing file: {next(path for path in self.images if not path.is_file())}"
+ )
+
+ def __len__(self) -> int:
+ return len(self.videos)
+
+ def __getitem__(self, index: int) -> Dict[str, Any]:
+ if isinstance(index, list):
+ # Here, index is actually a list of data objects that we need to return.
+ # The BucketSampler should ideally return indices. But, in the sampler, we'd like
+ # to have information about num_frames, height and width. Since this is not stored
+ # as metadata, we need to read the video to get this information. You could read this
+ # information without loading the full video in memory, but we do it anyway. In order
+ # to not load the video twice (once to get the metadata, and once to return the loaded video
+ # based on sampled indices), we cache it in the BucketSampler. When the sampler is
+ # to yield, we yield the cache data instead of indices. So, this special check ensures
+ # that data is not loaded a second time. PRs are welcome for improvements.
+ return index
+
+ prompt = self.prompts[index]
+ video = self.videos[index]
+ image = self.images[index]
+
+ video_latent_dir = video.parent / "latent"
+ video_latent_dir.mkdir(parents=True, exist_ok=True)
+ encoded_video_path = video_latent_dir / (video.stem + ".pt")
+
+ if encoded_video_path.exists():
+ encoded_video = torch.load(encoded_video_path, weights_only=True)
+ # shape of image: [C, H, W]
+ _, image = self.preprocess(None, self.images[index])
+ else:
+ frames, image = self.preprocess(video, image)
+ frames = frames.to(self.device)
+ # current shape of frames: [F, C, H, W]
+ frames = self.video_transform(frames)
+ # Convert to [B, C, F, H, W]
+ frames = frames.unsqueeze(0)
+ frames = frames.permute(0, 2, 1, 3, 4).contiguous()
+ encoded_video = self.encode_video_fn(frames)
+ # [B, C, F, H, W] -> [C, F, H, W]
+ encoded_video = encoded_video[0].cpu()
+ torch.save(encoded_video, encoded_video_path)
+ logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)
+
+ # shape of encoded_video: [C, F, H, W]
+ # shape of image: [C, H, W]
+ return {
+ "prompt": prompt,
+ "image": image,
+ "encoded_video": encoded_video,
+ "video_metadata": {
+ "num_frames": encoded_video.shape[1],
+ "height": encoded_video.shape[2],
+ "width": encoded_video.shape[3],
+ },
+ }
+
+ def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Loads and preprocesses a video and an image.
+ If either path is None, no preprocessing will be done for that input.
+
+ Args:
+ video_path: Path to the video file to load
+ image_path: Path to the image file to load
+
+ Returns:
+ A tuple containing:
+ - video(torch.Tensor) of shape [F, C, H, W] where F is number of frames,
+ C is number of channels, H is height and W is width
+ - image(torch.Tensor) of shape [C, H, W]
+ """
+ raise NotImplementedError("Subclass must implement this method")
+
+ def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
+ """
+ Applies transformations to a video.
+
+ Args:
+ frames (torch.Tensor): A 4D tensor representing a video
+ with shape [F, C, H, W] where:
+ - F is number of frames
+ - C is number of channels (3 for RGB)
+ - H is height
+ - W is width
+
+ Returns:
+ torch.Tensor: The transformed video tensor
+ """
+ raise NotImplementedError("Subclass must implement this method")
+
+ def image_transform(self, image: torch.Tensor) -> torch.Tensor:
+ """
+ Applies transformations to an image.
+
+ Args:
+ image (torch.Tensor): A 3D tensor representing an image
+ with shape [C, H, W] where:
+ - C is number of channels (3 for RGB)
+ - H is height
+ - W is width
+
+ Returns:
+ torch.Tensor: The transformed image tensor
+ """
+ raise NotImplementedError("Subclass must implement this method")
+
+
+class I2VDatasetWithResize(BaseI2VDataset):
+ """
+ A dataset class for image-to-video generation that resizes inputs to fixed dimensions.
+
+ This class preprocesses videos and images by resizing them to specified dimensions:
+ - Videos are resized to max_num_frames x height x width
+ - Images are resized to height x width
+
+ Args:
+ max_num_frames (int): Maximum number of frames to extract from videos
+ height (int): Target height for resizing videos and images
+ width (int): Target width for resizing videos and images
+ """
+ def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+
+ self.max_num_frames = max_num_frames
+ self.height = height
+ self.width = width
+
+ self.__frame_transforms = transforms.Compose(
+ [
+ transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
+ ]
+ )
+ self.__image_transforms = self.__frame_transforms
+
+ @override
+ def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
+ if video_path is not None:
+ video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width)
+ else:
+ video = None
+ if image_path is not None:
+ image = preprocess_image_with_resize(image_path, self.height, self.width)
+ else:
+ image = None
+ return video, image
+
+ @override
+ def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
+ return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)
+
+ @override
+ def image_transform(self, image: torch.Tensor) -> torch.Tensor:
+ return self.__image_transforms(image)
+
+
+class I2VDatasetWithBuckets(BaseI2VDataset):
+
+ def __init__(
+ self,
+ video_resolution_buckets: List[Tuple[int, int, int]],
+ vae_temporal_compression_ratio: int,
+ vae_height_compression_ratio: int,
+ vae_width_compression_ratio: int,
+ *args, **kwargs
+ ) -> None:
+ super().__init__(*args, **kwargs)
+
+ self.video_resolution_buckets = [
+ (
+ int(b[0] / vae_temporal_compression_ratio),
+ int(b[1] / vae_height_compression_ratio),
+ int(b[2] / vae_width_compression_ratio),
+ )
+ for b in video_resolution_buckets
+ ]
+ self.__frame_transforms = transforms.Compose(
+ [
+ transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
+ ]
+ )
+ self.__image_transforms = self.__frame_transforms
+
+ @override
+ def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
+ video = preprocess_video_with_buckets(video_path, self.video_resolution_buckets)
+ image = preprocess_image_with_resize(image_path, video.shape[2], video.shape[3])
+ return video, image
+
+ @override
+ def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
+ return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)
+
+ @override
+ def image_transform(self, image: torch.Tensor) -> torch.Tensor:
+ return self.__image_transforms(image)
diff --git a/finetune/datasets/t2v_dataset.py b/finetune/datasets/t2v_dataset.py
new file mode 100644
index 0000000..9afe53a
--- /dev/null
+++ b/finetune/datasets/t2v_dataset.py
@@ -0,0 +1,231 @@
+import torch
+
+from pathlib import Path
+from typing import Any, Dict, List, Tuple, Callable
+from typing_extensions import override
+
+from accelerate.logging import get_logger
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+from finetune.constants import LOG_NAME, LOG_LEVEL
+
+from .utils import (
+ load_prompts, load_videos,
+ preprocess_video_with_resize,
+ preprocess_video_with_buckets
+)
+
+# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
+# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
+import decord # isort:skip
+
+decord.bridge.set_bridge("torch")
+
+logger = get_logger(LOG_NAME, LOG_LEVEL)
+
+
+class BaseT2VDataset(Dataset):
+ """
+ Base dataset class for Text-to-Video (T2V) training.
+
+ This dataset loads prompts and videos for T2V training.
+
+ Args:
+ data_root (str): Root directory containing the dataset files
+ caption_column (str): Path to file containing text prompts/captions
+ video_column (str): Path to file containing video paths
+ device (torch.device): Device to load the data on
+ encode_video_fn (Callable[[torch.Tensor], torch.Tensor], optional): Function to encode videos
+ """
+
+ def __init__(
+ self,
+ data_root: str,
+ caption_column: str,
+ video_column: str,
+ device: torch.device = None,
+ encode_video_fn: Callable[[torch.Tensor], torch.Tensor] = None,
+ *args,
+ **kwargs
+ ) -> None:
+ super().__init__()
+
+ data_root = Path(data_root)
+ self.prompts = load_prompts(data_root / caption_column)
+ self.videos = load_videos(data_root / video_column)
+ self.device = device
+ self.encode_video_fn = encode_video_fn
+
+ # Check if all video files exist
+ if any(not path.is_file() for path in self.videos):
+ raise ValueError(
+ f"Some video files were not found. Please ensure that all video files exist in the dataset directory. Missing file: {next(path for path in self.videos if not path.is_file())}"
+ )
+
+ # Check if number of prompts matches number of videos
+ if len(self.videos) != len(self.prompts):
+ raise ValueError(
+ f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.videos)=}. Please ensure that the number of caption prompts and videos match in your dataset."
+ )
+
+ def __len__(self) -> int:
+ return len(self.videos)
+
+ def __getitem__(self, index: int) -> Dict[str, Any]:
+ if isinstance(index, list):
+ # Here, index is actually a list of data objects that we need to return.
+ # The BucketSampler should ideally return indices. But, in the sampler, we'd like
+ # to have information about num_frames, height and width. Since this is not stored
+ # as metadata, we need to read the video to get this information. You could read this
+ # information without loading the full video in memory, but we do it anyway. In order
+ # to not load the video twice (once to get the metadata, and once to return the loaded video
+ # based on sampled indices), we cache it in the BucketSampler. When the sampler is
+ # to yield, we yield the cache data instead of indices. So, this special check ensures
+ # that data is not loaded a second time. PRs are welcome for improvements.
+ return index
+
+ prompt = self.prompts[index]
+ video = self.videos[index]
+
+ latent_dir = video.parent / "latent"
+ latent_dir.mkdir(parents=True, exist_ok=True)
+ encoded_video_path = latent_dir / (video.stem + ".pt")
+
+ if encoded_video_path.exists():
+ # shape of encoded_video: [C, F, H, W]
+ encoded_video = torch.load(encoded_video_path, weights_only=True)
+ else:
+ frames = self.preprocess(video)
+ frames = frames.to(self.device)
+ # current shape of frames: [F, C, H, W]
+ frames = self.video_transform(frames)
+ # Convert to [B, C, F, H, W]
+ frames = frames.unsqueeze(0)
+ frames = frames.permute(0, 2, 1, 3, 4).contiguous()
+ encoded_video = self.encode_video_fn(frames)
+ # [B, C, F, H, W] -> [C, F, H, W]
+ encoded_video = encoded_video[0].cpu()
+ torch.save(encoded_video, encoded_video_path)
+ logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)
+
+ return {
+ "prompt": prompt,
+ "encoded_video": encoded_video,
+ "video_metadata": {
+ "num_frames": encoded_video.shape[1],
+ "height": encoded_video.shape[2],
+ "width": encoded_video.shape[3],
+ },
+ }
+
+ def preprocess(self, video_path: Path) -> torch.Tensor:
+ """
+ Loads and preprocesses a video.
+
+ Args:
+ video_path: Path to the video file to load.
+
+ Returns:
+ torch.Tensor: Video tensor of shape [F, C, H, W] where:
+ - F is number of frames
+ - C is number of channels (3 for RGB)
+ - H is height
+ - W is width
+ """
+ raise NotImplementedError("Subclass must implement this method")
+
+ def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
+ """
+ Applies transformations to a video.
+
+ Args:
+ frames (torch.Tensor): A 4D tensor representing a video
+ with shape [F, C, H, W] where:
+ - F is number of frames
+ - C is number of channels (3 for RGB)
+ - H is height
+ - W is width
+
+ Returns:
+ torch.Tensor: The transformed video tensor with the same shape as the input
+ """
+ raise NotImplementedError("Subclass must implement this method")
+
+
+class T2VDatasetWithResize(BaseT2VDataset):
+ """
+ A dataset class for text-to-video generation that resizes inputs to fixed dimensions.
+
+ This class preprocesses videos by resizing them to specified dimensions:
+ - Videos are resized to max_num_frames x height x width
+
+ Args:
+ max_num_frames (int): Maximum number of frames to extract from videos
+ height (int): Target height for resizing videos
+ width (int): Target width for resizing videos
+ """
+
+ def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+
+ self.max_num_frames = max_num_frames
+ self.height = height
+ self.width = width
+
+ self.__frame_transform = transforms.Compose(
+ [
+ transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
+ ]
+ )
+
+ @override
+ def preprocess(self, video_path: Path) -> torch.Tensor:
+ return preprocess_video_with_resize(
+ video_path, self.max_num_frames, self.height, self.width,
+ )
+
+ @override
+ def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
+ return torch.stack([self.__frame_transform(f) for f in frames], dim=0)
+
+
+class T2VDatasetWithBuckets(BaseT2VDataset):
+
+ def __init__(
+ self,
+ video_resolution_buckets: List[Tuple[int, int, int]],
+ vae_temporal_compression_ratio: int,
+ vae_height_compression_ratio: int,
+ vae_width_compression_ratio: int,
+ *args, **kwargs
+ ) -> None:
+ """
+
+ """
+ super().__init__(*args, **kwargs)
+
+ self.video_resolution_buckets = [
+ (
+ int(b[0] / vae_temporal_compression_ratio),
+ int(b[1] / vae_height_compression_ratio),
+ int(b[2] / vae_width_compression_ratio),
+ )
+ for b in video_resolution_buckets
+ ]
+
+ self.__frame_transform = transforms.Compose(
+ [
+ transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
+ ]
+ )
+
+ @override
+ def preprocess(self, video_path: Path) -> torch.Tensor:
+ return preprocess_video_with_buckets(
+ video_path, self.video_resolution_buckets
+ )
+
+ @override
+ def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
+ return torch.stack([self.__frame_transform(f) for f in frames], dim=0)
diff --git a/finetune/datasets/utils.py b/finetune/datasets/utils.py
new file mode 100644
index 0000000..ba7bddf
--- /dev/null
+++ b/finetune/datasets/utils.py
@@ -0,0 +1,147 @@
+import torch
+import cv2
+
+from typing import List, Tuple
+from pathlib import Path
+from torchvision import transforms
+from torchvision.transforms.functional import resize
+
+# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
+# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
+import decord # isort:skip
+
+decord.bridge.set_bridge("torch")
+
+
+########## loaders ##########
+
+def load_prompts(prompt_path: Path) -> List[str]:
+ with open(prompt_path, "r", encoding="utf-8") as file:
+ return [line.strip() for line in file.readlines() if len(line.strip()) > 0]
+
+
+def load_videos(video_path: Path) -> List[Path]:
+ with open(video_path, "r", encoding="utf-8") as file:
+ return [video_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]
+
+
+def load_images(image_path: Path) -> List[Path]:
+ with open(image_path, "r", encoding="utf-8") as file:
+ return [image_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]
+
+
+########## preprocessors ##########
+
+def preprocess_image_with_resize(
+ image_path: Path | str,
+ height: int,
+ width: int,
+) -> torch.Tensor:
+ """
+ Loads and resizes a single image.
+
+ Args:
+ image_path: Path to the image file.
+ height: Target height for resizing.
+ width: Target width for resizing.
+
+ Returns:
+ torch.Tensor: Image tensor with shape [C, H, W] where:
+ C = number of channels (3 for RGB)
+ H = height
+ W = width
+ """
+ if isinstance(image_path, str):
+ image_path = Path(image_path)
+ image = cv2.imread(image_path.as_posix())
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ image = cv2.resize(image, (width, height))
+ image = torch.from_numpy(image).float()
+ image = image.permute(2, 0, 1).contiguous()
+ return image
+
+
+def preprocess_video_with_resize(
+ video_path: Path | str,
+ max_num_frames: int,
+ height: int,
+ width: int,
+) -> torch.Tensor:
+ """
+ Loads and resizes a single video.
+
+ The function processes the video through these steps:
+ 1. If video frame count > max_num_frames, downsample frames evenly
+ 2. If video dimensions don't match (height, width), resize frames
+
+ Args:
+ video_path: Path to the video file.
+ max_num_frames: Maximum number of frames to keep.
+ height: Target height for resizing.
+ width: Target width for resizing.
+
+ Returns:
+ A torch.Tensor with shape [F, C, H, W] where:
+ F = number of frames
+ C = number of channels (3 for RGB)
+ H = height
+ W = width
+ """
+ if isinstance(video_path, str):
+ video_path = Path(video_path)
+ video_reader = decord.VideoReader(uri=video_path.as_posix(), width=width, height=height)
+ video_num_frames = len(video_reader)
+ if video_num_frames < max_num_frames:
+ raise ValueError(f"video's frames is less than {max_num_frames}.")
+
+ indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
+ frames = video_reader.get_batch(indices)
+ frames = frames[: max_num_frames].float()
+ frames = frames.permute(0, 3, 1, 2).contiguous()
+ return frames
+
+
+def preprocess_video_with_buckets(
+ video_path: Path,
+ resolution_buckets: List[Tuple[int, int, int]],
+) -> torch.Tensor:
+ """
+ Args:
+ video_path: Path to the video file.
+ resolution_buckets: List of tuples (num_frames, height, width) representing
+ available resolution buckets.
+
+ Returns:
+ torch.Tensor: Video tensor with shape [F, C, H, W] where:
+ F = number of frames
+ C = number of channels (3 for RGB)
+ H = height
+ W = width
+
+ The function processes the video through these steps:
+ 1. Finds nearest frame bucket <= video frame count
+ 2. Downsamples frames evenly to match bucket size
+ 3. Finds nearest resolution bucket based on dimensions
+ 4. Resizes frames to match bucket resolution
+ """
+ video_reader = decord.VideoReader(uri=video_path.as_posix())
+ video_num_frames = len(video_reader)
+ resolution_buckets = [bucket for bucket in resolution_buckets if bucket[0] <= video_num_frames]
+ if len(resolution_buckets) == 0:
+ raise ValueError(f"video frame count in {video_path} is less than all frame buckets {resolution_buckets}")
+
+ nearest_frame_bucket = min(
+ resolution_buckets,
+ key=lambda bucket: video_num_frames - bucket[0],
+ default=1,
+ )[0]
+ frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
+ frames = video_reader.get_batch(frame_indices)
+ frames = frames[:nearest_frame_bucket].float()
+ frames = frames.permute(0, 3, 1, 2).contiguous()
+
+ nearest_res = min(resolution_buckets, key=lambda x: abs(x[1] - frames.shape[2]) + abs(x[2] - frames.shape[3]))
+ nearest_res = (nearest_res[1], nearest_res[2])
+ frames = torch.stack([resize(f, nearest_res) for f in frames], dim=0)
+
+ return frames
\ No newline at end of file
diff --git a/finetune/ds_config.json b/finetune/ds_config.json
deleted file mode 100644
index 9a708a7..0000000
--- a/finetune/ds_config.json
+++ /dev/null
@@ -1,20 +0,0 @@
-{
- "scheduler": {
- "type": "WarmupDecayLR",
- "params": {
- "warmup_min_lr": "auto",
- "warmup_max_lr": "auto",
- "warmup_num_steps": "auto",
- "total_num_steps": "auto"
- }
- },
- "zero_optimization": {
- "stage": 2,
- "allgather_partitions": true,
- "allgather_bucket_size": 2e8,
- "overlap_comm": true,
- "reduce_scatter": true,
- "reduce_bucket_size": 1e8,
- "contiguous_gradients": true
- }
-}
\ No newline at end of file
diff --git a/finetune/finetune_multi_rank.sh b/finetune/finetune_multi_rank.sh
deleted file mode 100644
index f6c34de..0000000
--- a/finetune/finetune_multi_rank.sh
+++ /dev/null
@@ -1,52 +0,0 @@
-#!/bin/bash
-
-export MODEL_PATH="THUDM/CogVideoX-2b"
-export CACHE_PATH="~/.cache"
-export DATASET_PATH="Disney-VideoGeneration-Dataset"
-export OUTPUT_PATH="cogvideox-lora-multi-node"
-export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
-export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
-
-# max batch-size is 2.
-accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu --machine_rank 0 \
- train_cogvideox_lora.py \
- --gradient_checkpointing \
- --pretrained_model_name_or_path $MODEL_PATH \
- --cache_dir $CACHE_PATH \
- --enable_tiling \
- --enable_slicing \
- --instance_data_root $DATASET_PATH \
- --caption_column prompt.txt \
- --video_column videos.txt \
- --validation_prompt "DISNEY A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance" \
- --validation_prompt_separator ::: \
- --num_validation_videos 1 \
- --validation_epochs 100 \
- --seed 42 \
- --rank 128 \
- --lora_alpha 64 \
- --mixed_precision bf16 \
- --output_dir $OUTPUT_PATH \
- --height 480 \
- --width 720 \
- --fps 8 \
- --max_num_frames 49 \
- --skip_frames_start 0 \
- --skip_frames_end 0 \
- --train_batch_size 1 \
- --num_train_epochs 30 \
- --checkpointing_steps 1000 \
- --gradient_accumulation_steps 1 \
- --learning_rate 1e-3 \
- --lr_scheduler cosine_with_restarts \
- --lr_warmup_steps 200 \
- --lr_num_cycles 1 \
- --enable_slicing \
- --enable_tiling \
- --gradient_checkpointing \
- --optimizer AdamW \
- --adam_beta1 0.9 \
- --adam_beta2 0.95 \
- --max_grad_norm 1.0 \
- --allow_tf32 \
- --report_to wandb
\ No newline at end of file
diff --git a/finetune/finetune_single_rank.sh b/finetune/finetune_single_rank.sh
deleted file mode 100644
index 8b45876..0000000
--- a/finetune/finetune_single_rank.sh
+++ /dev/null
@@ -1,52 +0,0 @@
-#!/bin/bash
-
-export MODEL_PATH="THUDM/CogVideoX-5b"
-export CACHE_PATH="~/.cache"
-export DATASET_PATH="Disney-VideoGeneration-Dataset"
-export OUTPUT_PATH="cogvideox-lora-single-node"
-export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
-export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
-
-# if you are not using wth 8 gus, change `accelerate_config_machine_single.yaml` num_processes as your gpu number
-accelerate launch --config_file accelerate_config_machine_single.yaml \
- train_cogvideox_lora.py \
- --gradient_checkpointing \
- --pretrained_model_name_or_path $MODEL_PATH \
- --cache_dir $CACHE_PATH \
- --enable_tiling \
- --enable_slicing \
- --instance_data_root $DATASET_PATH \
- --caption_column prompt.txt \
- --video_column videos.txt \
- --validation_prompt "DISNEY A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance" \
- --validation_prompt_separator ::: \
- --num_validation_videos 1 \
- --validation_epochs 100 \
- --seed 42 \
- --rank 128 \
- --lora_alpha 64 \
- --mixed_precision bf16 \
- --output_dir $OUTPUT_PATH \
- --height 480 \
- --width 720 \
- --fps 8 \
- --max_num_frames 49 \
- --skip_frames_start 0 \
- --skip_frames_end 0 \
- --train_batch_size 1 \
- --num_train_epochs 30 \
- --checkpointing_steps 1000 \
- --gradient_accumulation_steps 1 \
- --learning_rate 1e-3 \
- --lr_scheduler cosine_with_restarts \
- --lr_warmup_steps 200 \
- --lr_num_cycles 1 \
- --enable_slicing \
- --enable_tiling \
- --gradient_checkpointing \
- --optimizer AdamW \
- --adam_beta1 0.9 \
- --adam_beta2 0.95 \
- --max_grad_norm 1.0 \
- --allow_tf32 \
- --report_to wandb
\ No newline at end of file
diff --git a/finetune/hostfile.txt b/finetune/hostfile.txt
deleted file mode 100644
index d0b8045..0000000
--- a/finetune/hostfile.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-node1 slots=8
-node2 slots=8
\ No newline at end of file
diff --git a/finetune/models/__init__.py b/finetune/models/__init__.py
new file mode 100644
index 0000000..b315ff5
--- /dev/null
+++ b/finetune/models/__init__.py
@@ -0,0 +1,12 @@
+import importlib
+from pathlib import Path
+
+
+package_dir = Path(__file__).parent
+
+for subdir in package_dir.iterdir():
+ if subdir.is_dir() and not subdir.name.startswith('_'):
+ for module_path in subdir.glob('*.py'):
+ module_name = module_path.stem
+ full_module_name = f".{subdir.name}.{module_name}"
+ importlib.import_module(full_module_name, package=__name__)
diff --git a/finetune/models/cogvideox1dot5_i2v/lora_trainer.py b/finetune/models/cogvideox1dot5_i2v/lora_trainer.py
new file mode 100644
index 0000000..09d4b70
--- /dev/null
+++ b/finetune/models/cogvideox1dot5_i2v/lora_trainer.py
@@ -0,0 +1,9 @@
+from ..utils import register
+from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
+
+
+class CogVideoX1dot5I2VLoraTrainer(CogVideoXI2VLoraTrainer):
+ pass
+
+
+register("cogvideox1.5-i2v", "lora", CogVideoX1dot5I2VLoraTrainer)
diff --git a/finetune/models/cogvideox1dot5_t2v/lora_trainer.py b/finetune/models/cogvideox1dot5_t2v/lora_trainer.py
new file mode 100644
index 0000000..79504bc
--- /dev/null
+++ b/finetune/models/cogvideox1dot5_t2v/lora_trainer.py
@@ -0,0 +1,9 @@
+from ..cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer
+from ..utils import register
+
+
+class CogVideoX1dot5T2VLoraTrainer(CogVideoXT2VLoraTrainer):
+ pass
+
+
+register("cogvideox1.5-t2v", "lora", CogVideoX1dot5T2VLoraTrainer)
diff --git a/finetune/models/cogvideox_i2v/lora_trainer.py b/finetune/models/cogvideox_i2v/lora_trainer.py
new file mode 100644
index 0000000..442f769
--- /dev/null
+++ b/finetune/models/cogvideox_i2v/lora_trainer.py
@@ -0,0 +1,240 @@
+import torch
+
+from typing_extensions import override
+from typing import Any, Dict, List, Tuple
+from PIL import Image
+
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
+from diffusers.models.embeddings import get_3d_rotary_pos_embed
+from diffusers import (
+ CogVideoXImageToVideoPipeline,
+ CogVideoXTransformer3DModel,
+ AutoencoderKLCogVideoX,
+ CogVideoXDPMScheduler,
+)
+
+from finetune.trainer import Trainer
+from finetune.schemas import Components
+from finetune.utils import unwrap_model
+from ..utils import register
+
+
+class CogVideoXI2VLoraTrainer(Trainer):
+
+ @override
+ def load_components(self) -> Dict[str, Any]:
+ components = Components()
+ model_path = str(self.args.model_path)
+
+ components.pipeline_cls = CogVideoXImageToVideoPipeline
+
+ components.tokenizer = AutoTokenizer.from_pretrained(
+ model_path, subfolder="tokenizer"
+ )
+
+ components.text_encoder = T5EncoderModel.from_pretrained(
+ model_path, subfolder="text_encoder"
+ )
+
+ components.transformer = CogVideoXTransformer3DModel.from_pretrained(
+ model_path, subfolder="transformer"
+ )
+
+ components.vae = AutoencoderKLCogVideoX.from_pretrained(
+ model_path, subfolder="vae"
+ )
+
+ components.scheduler = CogVideoXDPMScheduler.from_pretrained(
+ model_path, subfolder="scheduler"
+ )
+
+ return components
+
+ @override
+ def encode_video(self, video: torch.Tensor) -> torch.Tensor:
+ # shape of input video: [B, C, F, H, W]
+ vae = self.components.vae
+ video = video.to(vae.device, dtype=vae.dtype)
+ latent_dist = vae.encode(video).latent_dist
+ latent = latent_dist.sample() * vae.config.scaling_factor
+ return latent
+
+ @override
+ def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
+ ret = {
+ "encoded_videos": [],
+ "prompt_token_ids": [],
+ "images": []
+ }
+
+ for sample in samples:
+ encoded_video = sample["encoded_video"]
+ prompt = sample["prompt"]
+ image = sample["image"]
+
+ # tokenize prompt
+ text_inputs = self.components.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.state.transformer_config.max_text_seq_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ ret["encoded_videos"].append(encoded_video)
+ ret["prompt_token_ids"].append(text_input_ids[0])
+ ret["images"].append(image)
+
+ ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
+ ret["prompt_token_ids"] = torch.stack(ret["prompt_token_ids"])
+ ret["images"] = torch.stack(ret["images"])
+
+ return ret
+
+ @override
+ def compute_loss(self, batch) -> torch.Tensor:
+ prompt_token_ids = batch["prompt_token_ids"]
+ latent = batch["encoded_videos"]
+ images = batch["images"]
+
+ batch_size, num_channels, num_frames, height, width = latent.shape
+
+ # Get prompt embeddings
+ prompt_embeds = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(batch_size, seq_len, -1)
+
+ # Add frame dimension to images [B,C,H,W] -> [B,C,F,H,W]
+ images = images.unsqueeze(2)
+ # Add noise to images
+ image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device)
+ image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=images.dtype)
+ noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
+ image_latent_dist = self.components.vae.encode(noisy_images).latent_dist
+ image_latents = image_latent_dist.sample() * self.components.vae.config.scaling_factor
+
+ # Sample a random timestep for each sample
+ timesteps = torch.randint(
+ 0, self.components.scheduler.config.num_train_timesteps,
+ (batch_size,), device=self.accelerator.device
+ )
+ timesteps = timesteps.long()
+
+ # from [B, C, F, H, W] to [B, F, C, H, W]
+ latent = latent.permute(0, 2, 1, 3, 4)
+ image_latents = image_latents.permute(0, 2, 1, 3, 4)
+ assert (latent.shape[0], *latent.shape[2:]) == (image_latents.shape[0], *image_latents.shape[2:])
+
+ # Padding image_latents to the same frame number as latent
+ padding_shape = (latent.shape[0], latent.shape[1] - 1, *latent.shape[2:])
+ latent_padding = image_latents.new_zeros(padding_shape)
+ image_latents = torch.cat([image_latents, latent_padding], dim=1)
+
+ # Add noise to latent
+ noise = torch.randn_like(latent)
+ latent_noisy = self.components.scheduler.add_noise(latent, noise, timesteps)
+
+ # Concatenate latent and image_latents in the channel dimension
+ latent_img_noisy = torch.cat([latent_noisy, image_latents], dim=2)
+
+ # Prepare rotary embeds
+ vae_scale_factor_spatial = 2 ** (len(self.components.vae.config.block_out_channels) - 1)
+ transformer_config = self.state.transformer_config
+ rotary_emb = (
+ self.prepare_rotary_positional_embeddings(
+ height=height * vae_scale_factor_spatial,
+ width=width * vae_scale_factor_spatial,
+ num_frames=num_frames,
+ transformer_config=transformer_config,
+ vae_scale_factor_spatial=vae_scale_factor_spatial,
+ device=self.accelerator.device,
+ )
+ if transformer_config.use_rotary_positional_embeddings
+ else None
+ )
+
+ # Predict noise
+ ofs_emb = None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
+ predicted_noise = self.components.transformer(
+ hidden_states=latent_img_noisy,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timesteps,
+ ofs=ofs_emb,
+ image_rotary_emb=rotary_emb,
+ return_dict=False,
+ )[0]
+
+ # Denoise
+ latent_pred = self.components.scheduler.get_velocity(predicted_noise, latent_noisy, timesteps)
+
+ alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps]
+ weights = 1 / (1 - alphas_cumprod)
+ while len(weights.shape) < len(latent_pred.shape):
+ weights = weights.unsqueeze(-1)
+
+ loss = torch.mean((weights * (latent_pred - latent) ** 2).reshape(batch_size, -1), dim=1)
+ loss = loss.mean()
+
+ return loss
+
+ @override
+ def validation_step(
+ self, eval_data: Dict[str, Any]
+ ) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
+ """
+ Return the data that needs to be saved. For videos, the data format is List[PIL],
+ and for images, the data format is PIL
+ """
+ prompt, image, video = eval_data["prompt"], eval_data["image"], eval_data["video"]
+
+ pipe = self.components.pipeline_cls(
+ tokenizer=self.components.tokenizer,
+ text_encoder=self.components.text_encoder,
+ vae=self.components.vae,
+ transformer=unwrap_model(self.accelerator, self.components.transformer),
+ scheduler=self.components.scheduler
+ )
+ video_generate = pipe(
+ num_frames=self.state.train_frames,
+ height=self.state.train_height,
+ width=self.state.train_width,
+ prompt=prompt,
+ image=image,
+ generator=self.state.generator
+ ).frames[0]
+ return [("video", video_generate)]
+
+ def prepare_rotary_positional_embeddings(
+ self,
+ height: int,
+ width: int,
+ num_frames: int,
+ transformer_config: Dict,
+ vae_scale_factor_spatial: int,
+ device: torch.device
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
+ grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)
+
+ if transformer_config.patch_size_t is None:
+ base_num_frames = num_frames
+ else:
+ base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t
+
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=transformer_config.attention_head_dim,
+ crops_coords=None,
+ grid_size=(grid_height, grid_width),
+ temporal_size=base_num_frames,
+ grid_type="slice",
+ max_size=(grid_height, grid_width),
+ device=device,
+ )
+
+ return freqs_cos, freqs_sin
+
+
+register("cogvideox-i2v", "lora", CogVideoXI2VLoraTrainer)
\ No newline at end of file
diff --git a/finetune/models/cogvideox_t2v/lora_trainer.py b/finetune/models/cogvideox_t2v/lora_trainer.py
new file mode 100644
index 0000000..2e92486
--- /dev/null
+++ b/finetune/models/cogvideox_t2v/lora_trainer.py
@@ -0,0 +1,214 @@
+import torch
+
+from typing_extensions import override
+from typing import Any, Dict, List, Tuple
+
+from PIL import Image
+
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
+from diffusers.models.embeddings import get_3d_rotary_pos_embed
+from diffusers import (
+ CogVideoXPipeline,
+ CogVideoXTransformer3DModel,
+ AutoencoderKLCogVideoX,
+ CogVideoXDPMScheduler,
+)
+
+from finetune.trainer import Trainer
+from finetune.schemas import Components
+from finetune.utils import unwrap_model
+from ..utils import register
+
+
+class CogVideoXT2VLoraTrainer(Trainer):
+
+ @override
+ def load_components(self) -> Components:
+ components = Components()
+ model_path = str(self.args.model_path)
+
+ components.pipeline_cls = CogVideoXPipeline
+
+ components.tokenizer = AutoTokenizer.from_pretrained(
+ model_path, subfolder="tokenizer"
+ )
+
+ components.text_encoder = T5EncoderModel.from_pretrained(
+ model_path, subfolder="text_encoder"
+ )
+
+ components.transformer = CogVideoXTransformer3DModel.from_pretrained(
+ model_path, subfolder="transformer"
+ )
+
+ components.vae = AutoencoderKLCogVideoX.from_pretrained(
+ model_path, subfolder="vae"
+ )
+
+ components.scheduler = CogVideoXDPMScheduler.from_pretrained(
+ model_path, subfolder="scheduler"
+ )
+
+ return components
+
+ @override
+ def encode_video(self, video: torch.Tensor) -> torch.Tensor:
+ # shape of input video: [B, C, F, H, W]
+ vae = self.components.vae
+ video = video.to(vae.device, dtype=vae.dtype)
+ latent_dist = vae.encode(video).latent_dist
+ latent = latent_dist.sample() * vae.config.scaling_factor
+ return latent
+
+ @override
+ def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
+ ret = {
+ "encoded_videos": [],
+ "prompt_token_ids": []
+ }
+
+ for sample in samples:
+ encoded_video = sample["encoded_video"]
+ prompt = sample["prompt"]
+
+ # tokenize prompt
+ text_inputs = self.components.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=226,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ ret["encoded_videos"].append(encoded_video)
+ ret["prompt_token_ids"].append(text_input_ids[0])
+
+ ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
+ ret["prompt_token_ids"] = torch.stack(ret["prompt_token_ids"])
+
+ return ret
+
+ @override
+ def compute_loss(self, batch) -> torch.Tensor:
+ prompt_token_ids = batch["prompt_token_ids"]
+ latent = batch["encoded_videos"]
+
+ batch_size, num_channels, num_frames, height, width = latent.shape
+
+ # Get prompt embeddings
+ prompt_embeds = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(batch_size, seq_len, -1)
+ assert prompt_embeds.requires_grad is False
+
+ # Sample a random timestep for each sample
+ timesteps = torch.randint(
+ 0, self.components.scheduler.config.num_train_timesteps,
+ (batch_size,), device=self.accelerator.device
+ )
+ timesteps = timesteps.long()
+
+ # Add noise to latent
+ latent = latent.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
+ noise = torch.randn_like(latent)
+ latent_added_noise = self.components.scheduler.add_noise(latent, noise, timesteps)
+
+ # Prepare rotary embeds
+ vae_scale_factor_spatial = 2 ** (len(self.components.vae.config.block_out_channels) - 1)
+ transformer_config = self.state.transformer_config
+ rotary_emb = (
+ self.prepare_rotary_positional_embeddings(
+ height=height * vae_scale_factor_spatial,
+ width=width * vae_scale_factor_spatial,
+ num_frames=num_frames,
+ transformer_config=transformer_config,
+ vae_scale_factor_spatial=vae_scale_factor_spatial,
+ device=self.accelerator.device,
+ )
+ if transformer_config.use_rotary_positional_embeddings
+ else None
+ )
+
+ # Predict noise
+ predicted_noise = self.components.transformer(
+ hidden_states=latent_added_noise,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timesteps,
+ image_rotary_emb=rotary_emb,
+ return_dict=False,
+ )[0]
+
+ # Denoise
+ latent_pred = self.components.scheduler.get_velocity(predicted_noise, latent_added_noise, timesteps)
+
+ alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps]
+ weights = 1 / (1 - alphas_cumprod)
+ while len(weights.shape) < len(latent_pred.shape):
+ weights = weights.unsqueeze(-1)
+
+ loss = torch.mean((weights * (latent_pred - latent) ** 2).reshape(batch_size, -1), dim=1)
+ loss = loss.mean()
+
+ return loss
+
+ @override
+ def validation_step(
+ self, eval_data: Dict[str, Any]
+ ) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
+ """
+ Return the data that needs to be saved. For videos, the data format is List[PIL],
+ and for images, the data format is PIL
+ """
+ prompt, image, video = eval_data["prompt"], eval_data["image"], eval_data["video"]
+
+ pipe = self.components.pipeline_cls(
+ tokenizer=self.components.tokenizer,
+ text_encoder=self.components.text_encoder,
+ vae=self.components.vae,
+ transformer=unwrap_model(self.accelerator, self.components.transformer),
+ scheduler=self.components.scheduler
+ )
+ video_generate = pipe(
+ num_frames=self.state.train_frames,
+ height=self.state.train_height,
+ width=self.state.train_width,
+ prompt=prompt,
+ generator=self.state.generator
+ ).frames[0]
+ return [("video", video_generate)]
+
+ def prepare_rotary_positional_embeddings(
+ self,
+ height: int,
+ width: int,
+ num_frames: int,
+ transformer_config: Dict,
+ vae_scale_factor_spatial: int,
+ device: torch.device
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
+ grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)
+
+ if transformer_config.patch_size_t is None:
+ base_num_frames = num_frames
+ else:
+ base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t
+
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=transformer_config.attention_head_dim,
+ crops_coords=None,
+ grid_size=(grid_height, grid_width),
+ temporal_size=base_num_frames,
+ grid_type="slice",
+ max_size=(grid_height, grid_width),
+ device=device,
+ )
+
+ return freqs_cos, freqs_sin
+
+
+register("cogvideox-t2v", "lora", CogVideoXT2VLoraTrainer)
diff --git a/finetune/models/utils.py b/finetune/models/utils.py
new file mode 100644
index 0000000..fd5a455
--- /dev/null
+++ b/finetune/models/utils.py
@@ -0,0 +1,62 @@
+from typing import Literal, Dict
+
+from finetune.trainer import Trainer
+
+
+SUPPORTED_MODELS: Dict[str, Dict[str, Trainer]] = {}
+
+
+def register(model_name: str, training_type: Literal["lora", "sft"], trainer_cls: Trainer):
+ """Register a model and its associated functions for a specific training type.
+
+ Args:
+ model_name (str): Name of the model to register (e.g. "cogvideox-5b")
+ training_type (Literal["lora", "sft"]): Type of training - either "lora" or "sft"
+ trainer_cls (Trainer): Trainer class to register.
+ """
+
+ # Check if model_name exists in SUPPORTED_MODELS
+ if model_name not in SUPPORTED_MODELS:
+ SUPPORTED_MODELS[model_name] = {}
+ else:
+ raise ValueError(f"Model {model_name} already exists")
+
+ # Check if training_type exists for this model
+ if training_type not in SUPPORTED_MODELS[model_name]:
+ SUPPORTED_MODELS[model_name][training_type] = {}
+ else:
+ raise ValueError(f"Training type {training_type} already exists for model {model_name}")
+
+ SUPPORTED_MODELS[model_name][training_type] = trainer_cls
+
+
+def show_supported_models():
+ """Print all currently supported models and their training types."""
+
+ print("\nSupported Models:")
+ print("================")
+
+ for model_name, training_types in SUPPORTED_MODELS.items():
+ print(f"\n{model_name}")
+ print("-" * len(model_name))
+ for training_type in training_types:
+ print(f" • {training_type}")
+
+
+def get_model_cls(model_type: str, training_type: Literal["lora", "sft"]) -> Trainer:
+ """Get the trainer class for a specific model and training type."""
+ if model_type not in SUPPORTED_MODELS:
+ print(f"\nModel '{model_type}' is not supported.")
+ print("\nSupported models are:")
+ for supported_model in SUPPORTED_MODELS:
+ print(f" • {supported_model}")
+ raise ValueError(f"Model '{model_type}' is not supported")
+
+ if training_type not in SUPPORTED_MODELS[model_type]:
+ print(f"\nTraining type '{training_type}' is not supported for model '{model_type}'.")
+ print(f"\nSupported training types for '{model_type}' are:")
+ for supported_type in SUPPORTED_MODELS[model_type]:
+ print(f" • {supported_type}")
+ raise ValueError(f"Training type '{training_type}' is not supported for model '{model_type}'")
+
+ return SUPPORTED_MODELS[model_type][training_type]
diff --git a/finetune/schemas/__init__.py b/finetune/schemas/__init__.py
new file mode 100644
index 0000000..76a6bf8
--- /dev/null
+++ b/finetune/schemas/__init__.py
@@ -0,0 +1,5 @@
+from .args import Args
+from .state import State
+from .components import Components
+
+__all__ = ["Args", "State", "Components"]
\ No newline at end of file
diff --git a/finetune/schemas/args.py b/finetune/schemas/args.py
new file mode 100644
index 0000000..dca76e3
--- /dev/null
+++ b/finetune/schemas/args.py
@@ -0,0 +1,211 @@
+import datetime
+import argparse
+from typing import Dict, Any, Literal, List, Tuple
+from pydantic import BaseModel, field_validator, ValidationInfo
+
+from pathlib import Path
+
+
+class Args(BaseModel):
+ ########## Model ##########
+ model_path: Path
+ model_name: str
+ model_type: Literal["i2v", "t2v"]
+ training_type: Literal["lora", "sft"] = "lora"
+
+ ########## Output ##########
+ output_dir: Path = Path("train_results/{:%Y-%m-%d-%H-%M-%S}".format(datetime.datetime.now()))
+ report_to: Literal["tensorboard", "wandb", "all"] | None = None
+ tracker_name: str = "finetrainer-cogvideo"
+
+ ########## Data ###########
+ data_root: Path
+ caption_column: Path
+ image_column: Path | None = None
+ video_column: Path
+
+ ########## Training #########
+ resume_from_checkpoint: Path | None = None
+
+ seed: int | None = None
+ train_epochs: int
+ train_steps: int | None = None
+ checkpointing_steps: int = 200
+ checkpointing_limit: int = 10
+
+ batch_size: int
+ gradient_accumulation_steps: int = 1
+
+ train_resolution: Tuple[int, int, int] # shape: (frames, height, width)
+
+ #### deprecated args: video_resolution_buckets
+ # if use bucket for training, should not be None
+ # Note1: At least one frame rate in the bucket must be less than or equal to the frame rate of any video in the dataset
+ # Note2: For cogvideox, cogvideox1.5
+ # The frame rate set in the bucket must be an integer multiple of 8 (spatial_compression_rate[4] * path_t[2] = 8)
+ # The height and width set in the bucket must be an integer multiple of 8 (temporal_compression_rate[8])
+ # video_resolution_buckets: List[Tuple[int, int, int]] | None = None
+
+ mixed_precision: Literal["no", "fp16", "bf16"]
+
+ learning_rate: float = 2e-5
+ optimizer: str = "adamw"
+ beta1: float = 0.9
+ beta2: float = 0.95
+ beta3: float = 0.98
+ epsilon: float = 1e-8
+ weight_decay: float = 1e-4
+ max_grad_norm: float = 1.0
+
+ lr_scheduler: str = "constant_with_warmup"
+ lr_warmup_steps: int = 100
+ lr_num_cycles: int = 1
+ lr_power: float = 1.0
+
+ num_workers: int = 8
+ pin_memory: bool = True
+
+ gradient_checkpointing: bool = True
+ enable_slicing: bool = True
+ enable_tiling: bool = True
+ nccl_timeout: int = 1800
+
+ ########## Lora ##########
+ rank: int = 128
+ lora_alpha: int = 64
+ target_modules: List[str] = ["to_q", "to_k", "to_v", "to_out.0"]
+
+ ########## Validation ##########
+ do_validation: bool = False
+ validation_steps: int | None = None # if set, should be a multiple of checkpointing_steps
+ validation_dir: Path | None # if set do_validation, should not be None
+ validation_prompts: str | None # if set do_validation, should not be None
+ validation_images: str | None # if set do_validation and model_type == i2v, should not be None
+ validation_videos: str | None # if set do_validation and model_type == v2v, should not be None
+ gen_fps: int = 15
+
+ #### deprecated args: gen_video_resolution
+ # 1. If set do_validation, should not be None
+ # 2. Suggest selecting the bucket from `video_resolution_buckets` that is closest to the resolution you have chosen for fine-tuning
+ # or the resolution recommended by the model
+ # 3. Note: For cogvideox, cogvideox1.5
+ # The frame rate set in the bucket must be an integer multiple of 8 (spatial_compression_rate[4] * path_t[2] = 8)
+ # The height and width set in the bucket must be an integer multiple of 8 (temporal_compression_rate[8])
+ # gen_video_resolution: Tuple[int, int, int] | None # shape: (frames, height, width)
+
+ @field_validator("image_column")
+ def validate_image_column(cls, v: str | None, info: ValidationInfo) -> str | None:
+ values = info.data
+ if values.get("model_type") == "i2v" and not v:
+ raise ValueError("image_column must be specified when using i2v model")
+ return v
+
+ @field_validator("validation_dir", "validation_prompts")
+ def validate_validation_required_fields(cls, v: Any, info: ValidationInfo) -> Any:
+ values = info.data
+ if values.get("do_validation") and not v:
+ field_name = info.field_name
+ raise ValueError(f"{field_name} must be specified when do_validation is True")
+ return v
+
+ @field_validator("validation_images")
+ def validate_validation_images(cls, v: str | None, info: ValidationInfo) -> str | None:
+ values = info.data
+ if values.get("do_validation") and values.get("model_type") == "i2v" and not v:
+ raise ValueError("validation_images must be specified when do_validation is True and model_type is i2v")
+ return v
+
+ @field_validator("validation_videos")
+ def validate_validation_videos(cls, v: str | None, info: ValidationInfo) -> str | None:
+ values = info.data
+ if values.get("do_validation") and values.get("model_type") == "v2v" and not v:
+ raise ValueError("validation_videos must be specified when do_validation is True and model_type is v2v")
+ return v
+
+ @field_validator("validation_steps")
+ def validate_validation_steps(cls, v: int | None, info: ValidationInfo) -> int | None:
+ values = info.data
+ if values.get("do_validation"):
+ if v is None:
+ raise ValueError("validation_steps must be specified when do_validation is True")
+ if values.get("checkpointing_steps") and v % values["checkpointing_steps"] != 0:
+ raise ValueError("validation_steps must be a multiple of checkpointing_steps")
+ return v
+
+
+ @classmethod
+ def parse_args(cls):
+ """Parse command line arguments and return Args instance"""
+ parser = argparse.ArgumentParser()
+ # Required arguments
+ parser.add_argument("--model_path", type=str, required=True)
+ parser.add_argument("--model_name", type=str, required=True)
+ parser.add_argument("--model_type", type=str, required=True)
+ parser.add_argument("--training_type", type=str, required=True)
+ parser.add_argument("--output_dir", type=str, required=True)
+ parser.add_argument("--data_root", type=str, required=True)
+ parser.add_argument("--caption_column", type=str, required=True)
+ parser.add_argument("--video_column", type=str, required=True)
+ parser.add_argument("--train_resolution", type=str, required=True)
+ parser.add_argument("--report_to", type=str, required=True)
+
+ # Training hyperparameters
+ parser.add_argument("--seed", type=int, default=42)
+ parser.add_argument("--train_epochs", type=int, default=10)
+ parser.add_argument("--train_steps", type=int, default=None)
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
+ parser.add_argument("--batch_size", type=int, default=1)
+ parser.add_argument("--learning_rate", type=float, default=2e-5)
+ parser.add_argument("--optimizer", type=str, default="adamw")
+ parser.add_argument("--beta1", type=float, default=0.9)
+ parser.add_argument("--beta2", type=float, default=0.95)
+ parser.add_argument("--beta3", type=float, default=0.98)
+ parser.add_argument("--epsilon", type=float, default=1e-8)
+ parser.add_argument("--weight_decay", type=float, default=1e-4)
+ parser.add_argument("--max_grad_norm", type=float, default=1.0)
+
+ # Learning rate scheduler
+ parser.add_argument("--lr_scheduler", type=str, default="constant_with_warmup")
+ parser.add_argument("--lr_warmup_steps", type=int, default=100)
+ parser.add_argument("--lr_num_cycles", type=int, default=1)
+ parser.add_argument("--lr_power", type=float, default=1.0)
+
+ # Data loading
+ parser.add_argument("--num_workers", type=int, default=8)
+ parser.add_argument("--pin_memory", type=bool, default=True)
+ parser.add_argument("--image_column", type=str, default=None)
+
+ # Model configuration
+ parser.add_argument("--mixed_precision", type=str, default="no")
+ parser.add_argument("--gradient_checkpointing", type=bool, default=True)
+ parser.add_argument("--enable_slicing", type=bool, default=True)
+ parser.add_argument("--enable_tiling", type=bool, default=True)
+ parser.add_argument("--nccl_timeout", type=int, default=1800)
+
+ # LoRA parameters
+ parser.add_argument("--rank", type=int, default=128)
+ parser.add_argument("--lora_alpha", type=int, default=64)
+ parser.add_argument("--target_modules", type=str, nargs="+",
+ default=["to_q", "to_k", "to_v", "to_out.0"])
+
+ # Checkpointing
+ parser.add_argument("--checkpointing_steps", type=int, default=200)
+ parser.add_argument("--checkpointing_limit", type=int, default=10)
+ parser.add_argument("--resume_from_checkpoint", type=str, default=None)
+
+ # Validation
+ parser.add_argument("--do_validation", type=bool, default=False)
+ parser.add_argument("--validation_steps", type=int, default=None)
+ parser.add_argument("--validation_dir", type=str, default=None)
+ parser.add_argument("--validation_prompts", type=str, default=None)
+ parser.add_argument("--validation_images", type=str, default=None)
+ parser.add_argument("--validation_videos", type=str, default=None)
+ parser.add_argument("--gen_fps", type=int, default=15)
+
+ args = parser.parse_args()
+
+ # Convert video_resolution_buckets string to list of tuples
+ frames, height, width = args.train_resolution.split("x")
+ args.train_resolution = (int(frames), int(height), int(width))
+
+ return cls(**vars(args))
diff --git a/finetune/schemas/components.py b/finetune/schemas/components.py
new file mode 100644
index 0000000..2d3fef5
--- /dev/null
+++ b/finetune/schemas/components.py
@@ -0,0 +1,27 @@
+from typing import Any
+from pydantic import BaseModel
+
+
+class Components(BaseModel):
+ # pipeline cls
+ pipeline_cls: Any = None
+
+ # Tokenizers
+ tokenizer: Any = None
+ tokenizer_2: Any = None
+ tokenizer_3: Any = None
+
+ # Text encoders
+ text_encoder: Any = None
+ text_encoder_2: Any = None
+ text_encoder_3: Any = None
+
+ # Autoencoder
+ vae: Any = None
+
+ # Denoiser
+ transformer: Any = None
+ unet: Any = None
+
+ # Scheduler
+ scheduler: Any = None
diff --git a/finetune/schemas/state.py b/finetune/schemas/state.py
new file mode 100644
index 0000000..9d5363c
--- /dev/null
+++ b/finetune/schemas/state.py
@@ -0,0 +1,26 @@
+import torch
+
+from pathlib import Path
+from typing import List, Dict, Any
+from pydantic import BaseModel, field_validator
+
+class State(BaseModel):
+ model_config = {"arbitrary_types_allowed": True}
+
+ train_frames: int
+ train_height: int
+ train_width: int
+
+ transformer_config: Dict[str, Any] = None
+
+ weight_dtype: torch.dtype = torch.float32
+ num_trainable_parameters: int = 0
+ overwrote_max_train_steps: bool = False
+ num_update_steps_per_epoch: int = 0
+ total_batch_size_count: int = 0
+
+ generator: torch.Generator | None = None
+
+ validation_prompts: List[str] = []
+ validation_images: List[Path | None] = []
+ validation_videos: List[Path | None] = []
diff --git a/finetune/scripts/extract_images.py b/finetune/scripts/extract_images.py
new file mode 100644
index 0000000..a5cabae
--- /dev/null
+++ b/finetune/scripts/extract_images.py
@@ -0,0 +1,52 @@
+import argparse
+import os
+from pathlib import Path
+import cv2
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--datadir", type=str, required=True, help="Root directory containing videos.txt and video subdirectory")
+ return parser.parse_args()
+
+args = parse_args()
+
+# Create data/images directory if it doesn't exist
+data_dir = Path(args.datadir)
+image_dir = data_dir / "images"
+image_dir.mkdir(exist_ok=True)
+
+# Read videos.txt
+videos_file = data_dir / "videos.txt"
+with open(videos_file, "r") as f:
+ video_paths = [line.strip() for line in f.readlines() if line.strip()]
+
+# Process each video file and collect image paths
+image_paths = []
+for video_rel_path in video_paths:
+ video_path = data_dir / video_rel_path
+
+ # Open video
+ cap = cv2.VideoCapture(str(video_path))
+
+ # Read first frame
+ ret, frame = cap.read()
+ if not ret:
+ print(f"Failed to read video: {video_path}")
+ continue
+
+ # Save frame as PNG with same name as video
+ image_name = f"images/{video_path.stem}.png"
+ image_path = data_dir / image_name
+ cv2.imwrite(str(image_path), frame)
+
+ # Release video capture
+ cap.release()
+
+ print(f"Extracted first frame from {video_path} to {image_path}")
+ image_paths.append(image_name)
+
+# Write images.txt
+images_file = data_dir / "images.txt"
+with open(images_file, "w") as f:
+ for path in image_paths:
+ f.write(f"{path}\n")
\ No newline at end of file
diff --git a/finetune/train.py b/finetune/train.py
new file mode 100644
index 0000000..5f49f4b
--- /dev/null
+++ b/finetune/train.py
@@ -0,0 +1,18 @@
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).parent.parent))
+
+from finetune.schemas import Args
+from finetune.models.utils import get_model_cls
+
+
+def main():
+ args = Args.parse_args()
+ trainer_cls = get_model_cls(args.model_name, args.training_type)
+ trainer = trainer_cls(args)
+ trainer.fit()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/finetune/train_cogvideox_image_to_video_lora.py b/finetune/train_cogvideox_image_to_video_lora.py
deleted file mode 100644
index abf245f..0000000
--- a/finetune/train_cogvideox_image_to_video_lora.py
+++ /dev/null
@@ -1,1689 +0,0 @@
-# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import argparse
-import logging
-import math
-import os
-import random
-import shutil
-from datetime import timedelta
-from pathlib import Path
-from typing import List, Optional, Tuple, Union
-
-import torch
-import transformers
-from accelerate import Accelerator
-from accelerate.logging import get_logger
-from accelerate.utils import DistributedDataParallelKwargs, InitProcessGroupKwargs, ProjectConfiguration, set_seed
-from huggingface_hub import create_repo, upload_folder
-from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
-from torch.utils.data import DataLoader, Dataset
-from torchvision import transforms
-from tqdm.auto import tqdm
-from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
-
-import diffusers
-from diffusers import (
- AutoencoderKLCogVideoX,
- CogVideoXDPMScheduler,
- CogVideoXImageToVideoPipeline,
- CogVideoXTransformer3DModel,
-)
-from diffusers.models.embeddings import get_3d_rotary_pos_embed
-from diffusers.optimization import get_scheduler
-from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
-from diffusers.training_utils import cast_training_params, free_memory
-from diffusers.utils import (
- check_min_version,
- convert_unet_state_dict_to_peft,
- export_to_video,
- is_wandb_available,
- load_image,
-)
-from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
-from diffusers.utils.torch_utils import is_compiled_module
-from torchvision.transforms.functional import center_crop, resize
-from torchvision.transforms import InterpolationMode
-import torchvision.transforms as TT
-import numpy as np
-from diffusers.image_processor import VaeImageProcessor
-
-
-if is_wandb_available():
- import wandb
-
-# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
-
-logger = get_logger(__name__)
-
-
-def get_args():
- parser = argparse.ArgumentParser(description="Simple example of a training script for CogVideoX.")
-
- # Model information
- parser.add_argument(
- "--pretrained_model_name_or_path",
- type=str,
- default=None,
- required=True,
- help="Path to pretrained model or model identifier from huggingface.co/models.",
- )
- parser.add_argument(
- "--revision",
- type=str,
- default=None,
- required=False,
- help="Revision of pretrained model identifier from huggingface.co/models.",
- )
- parser.add_argument(
- "--variant",
- type=str,
- default=None,
- help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
- )
- parser.add_argument(
- "--cache_dir",
- type=str,
- default=None,
- help="The directory where the downloaded models and datasets will be stored.",
- )
-
- # Dataset information
- parser.add_argument(
- "--dataset_name",
- type=str,
- default=None,
- help=(
- "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
- " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
- " or to a folder containing files that 🤗 Datasets can understand."
- ),
- )
- parser.add_argument(
- "--dataset_config_name",
- type=str,
- default=None,
- help="The config of the Dataset, leave as None if there's only one config.",
- )
- parser.add_argument(
- "--instance_data_root",
- type=str,
- default=None,
- help=("A folder containing the training data."),
- )
- parser.add_argument(
- "--video_column",
- type=str,
- default="video",
- help="The column of the dataset containing videos. Or, the name of the file in `--instance_data_root` folder containing the line-separated path to video data.",
- )
- parser.add_argument(
- "--caption_column",
- type=str,
- default="text",
- help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--instance_data_root` folder containing the line-separated instance prompts.",
- )
- parser.add_argument(
- "--id_token", type=str, default=None, help="Identifier token appended to the start of each prompt if provided."
- )
- parser.add_argument(
- "--dataloader_num_workers",
- type=int,
- default=0,
- help=(
- "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
- ),
- )
-
- # Validation
- parser.add_argument(
- "--validation_prompt",
- type=str,
- default=None,
- help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
- )
- parser.add_argument(
- "--validation_images",
- type=str,
- default=None,
- help="One or more image path(s) that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
- )
- parser.add_argument(
- "--validation_prompt_separator",
- type=str,
- default=":::",
- help="String that separates multiple validation prompts",
- )
- parser.add_argument(
- "--num_validation_videos",
- type=int,
- default=1,
- help="Number of videos that should be generated during validation per `validation_prompt`.",
- )
- parser.add_argument(
- "--validation_epochs",
- type=int,
- default=50,
- help=(
- "Run validation every X epochs. Validation consists of running the prompt `args.validation_prompt` multiple times: `args.num_validation_videos`."
- ),
- )
- parser.add_argument(
- "--guidance_scale",
- type=float,
- default=6,
- help="The guidance scale to use while sampling validation videos.",
- )
- parser.add_argument(
- "--use_dynamic_cfg",
- action="store_true",
- default=False,
- help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.",
- )
-
- # Training information
- parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
- parser.add_argument(
- "--rank",
- type=int,
- default=128,
- help=("The dimension of the LoRA update matrices."),
- )
- parser.add_argument(
- "--lora_alpha",
- type=float,
- default=128,
- help=("The scaling factor to scale LoRA weight update. The actual scaling factor is `lora_alpha / rank`"),
- )
- parser.add_argument(
- "--mixed_precision",
- type=str,
- default=None,
- choices=["no", "fp16", "bf16"],
- help=(
- "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
- " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
- " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
- ),
- )
- parser.add_argument(
- "--output_dir",
- type=str,
- default="cogvideox-i2v-lora",
- help="The output directory where the model predictions and checkpoints will be written.",
- )
- parser.add_argument(
- "--height",
- type=int,
- default=480,
- help="All input videos are resized to this height.",
- )
- parser.add_argument(
- "--width",
- type=int,
- default=720,
- help="All input videos are resized to this width.",
- )
- parser.add_argument(
- "--video_reshape_mode",
- type=str,
- default="center",
- help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
- )
- parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
- parser.add_argument(
- "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
- )
- parser.add_argument(
- "--skip_frames_start",
- type=int,
- default=0,
- help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.",
- )
- parser.add_argument(
- "--skip_frames_end",
- type=int,
- default=0,
- help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.",
- )
- parser.add_argument(
- "--random_flip",
- action="store_true",
- help="whether to randomly flip videos horizontally",
- )
- parser.add_argument(
- "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
- )
- parser.add_argument("--num_train_epochs", type=int, default=1)
- parser.add_argument(
- "--max_train_steps",
- type=int,
- default=None,
- help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.",
- )
- parser.add_argument(
- "--checkpointing_steps",
- type=int,
- default=500,
- help=(
- "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
- " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
- " training using `--resume_from_checkpoint`."
- ),
- )
- parser.add_argument(
- "--checkpoints_total_limit",
- type=int,
- default=None,
- help=("Max number of checkpoints to store."),
- )
- parser.add_argument(
- "--resume_from_checkpoint",
- type=str,
- default=None,
- help=(
- "Whether training should be resumed from a previous checkpoint. Use a path saved by"
- ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
- ),
- )
- parser.add_argument(
- "--gradient_accumulation_steps",
- type=int,
- default=1,
- help="Number of updates steps to accumulate before performing a backward/update pass.",
- )
- parser.add_argument(
- "--gradient_checkpointing",
- action="store_true",
- help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
- )
- parser.add_argument(
- "--learning_rate",
- type=float,
- default=1e-4,
- help="Initial learning rate (after the potential warmup period) to use.",
- )
- parser.add_argument(
- "--scale_lr",
- action="store_true",
- default=False,
- help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
- )
- parser.add_argument(
- "--lr_scheduler",
- type=str,
- default="constant",
- help=(
- 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
- ' "constant", "constant_with_warmup"]'
- ),
- )
- parser.add_argument(
- "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
- )
- parser.add_argument(
- "--lr_num_cycles",
- type=int,
- default=1,
- help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
- )
- parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
- parser.add_argument(
- "--enable_slicing",
- action="store_true",
- default=False,
- help="Whether or not to use VAE slicing for saving memory.",
- )
- parser.add_argument(
- "--enable_tiling",
- action="store_true",
- default=False,
- help="Whether or not to use VAE tiling for saving memory.",
- )
- parser.add_argument(
- "--noised_image_dropout",
- type=float,
- default=0.05,
- help="Image condition dropout probability.",
- )
-
- # Optimizer
- parser.add_argument(
- "--optimizer",
- type=lambda s: s.lower(),
- default="adam",
- choices=["adam", "adamw", "prodigy"],
- help=("The optimizer type to use."),
- )
- parser.add_argument(
- "--use_8bit_adam",
- action="store_true",
- help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
- )
- parser.add_argument(
- "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
- )
- parser.add_argument(
- "--adam_beta2", type=float, default=0.95, help="The beta2 parameter for the Adam and Prodigy optimizers."
- )
- parser.add_argument(
- "--prodigy_beta3",
- type=float,
- default=None,
- help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.",
- )
- parser.add_argument("--prodigy_decouple", action="store_true", help="Use AdamW style decoupled weight decay")
- parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
- parser.add_argument(
- "--adam_epsilon",
- type=float,
- default=1e-08,
- help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
- )
- parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
- parser.add_argument("--prodigy_use_bias_correction", action="store_true", help="Turn on Adam's bias correction.")
- parser.add_argument(
- "--prodigy_safeguard_warmup",
- action="store_true",
- help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.",
- )
-
- # Other information
- parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name")
- parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
- parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
- parser.add_argument(
- "--hub_model_id",
- type=str,
- default=None,
- help="The name of the repository to keep in sync with the local `output_dir`.",
- )
- parser.add_argument(
- "--logging_dir",
- type=str,
- default="logs",
- help="Directory where logs are stored.",
- )
- parser.add_argument(
- "--allow_tf32",
- action="store_true",
- help=(
- "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
- " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
- ),
- )
- parser.add_argument(
- "--report_to",
- type=str,
- default=None,
- help=(
- 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
- ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
- ),
- )
- parser.add_argument("--nccl_timeout", type=int, default=600, help="NCCL backend timeout in seconds.")
-
- return parser.parse_args()
-
-
-class VideoDataset(Dataset):
- def __init__(
- self,
- instance_data_root: Optional[str] = None,
- dataset_name: Optional[str] = None,
- dataset_config_name: Optional[str] = None,
- caption_column: str = "text",
- video_column: str = "video",
- height: int = 480,
- width: int = 720,
- video_reshape_mode: str = "center",
- fps: int = 8,
- max_num_frames: int = 49,
- skip_frames_start: int = 0,
- skip_frames_end: int = 0,
- cache_dir: Optional[str] = None,
- id_token: Optional[str] = None,
- ) -> None:
- super().__init__()
-
- self.instance_data_root = Path(instance_data_root) if instance_data_root is not None else None
- self.dataset_name = dataset_name
- self.dataset_config_name = dataset_config_name
- self.caption_column = caption_column
- self.video_column = video_column
- self.height = height
- self.width = width
- self.video_reshape_mode = video_reshape_mode
- self.fps = fps
- self.max_num_frames = max_num_frames
- self.skip_frames_start = skip_frames_start
- self.skip_frames_end = skip_frames_end
- self.cache_dir = cache_dir
- self.id_token = id_token or ""
-
- if dataset_name is not None:
- self.instance_prompts, self.instance_video_paths = self._load_dataset_from_hub()
- else:
- self.instance_prompts, self.instance_video_paths = self._load_dataset_from_local_path()
-
- self.instance_prompts = [self.id_token + prompt for prompt in self.instance_prompts]
-
- self.num_instance_videos = len(self.instance_video_paths)
- if self.num_instance_videos != len(self.instance_prompts):
- raise ValueError(
- f"Expected length of instance prompts and videos to be the same but found {len(self.instance_prompts)=} and {len(self.instance_video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset."
- )
-
- self.instance_videos = self._preprocess_data()
-
- def __len__(self):
- return self.num_instance_videos
-
- def __getitem__(self, index):
- return {
- "instance_prompt": self.instance_prompts[index],
- "instance_video": self.instance_videos[index],
- }
-
- def _load_dataset_from_hub(self):
- try:
- from datasets import load_dataset
- except ImportError:
- raise ImportError(
- "You are trying to load your data using the datasets library. If you wish to train using custom "
- "captions please install the datasets library: `pip install datasets`. If you wish to load a "
- "local folder containing images only, specify --instance_data_root instead."
- )
-
- # Downloading and loading a dataset from the hub. See more about loading custom images at
- # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
- dataset = load_dataset(
- self.dataset_name,
- self.dataset_config_name,
- cache_dir=self.cache_dir,
- )
- column_names = dataset["train"].column_names
-
- if self.video_column is None:
- video_column = column_names[0]
- logger.info(f"`video_column` defaulting to {video_column}")
- else:
- video_column = self.video_column
- if video_column not in column_names:
- raise ValueError(
- f"`--video_column` value '{video_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
- )
-
- if self.caption_column is None:
- caption_column = column_names[1]
- logger.info(f"`caption_column` defaulting to {caption_column}")
- else:
- caption_column = self.caption_column
- if self.caption_column not in column_names:
- raise ValueError(
- f"`--caption_column` value '{self.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
- )
-
- instance_prompts = dataset["train"][caption_column]
- instance_videos = [Path(self.instance_data_root, filepath) for filepath in dataset["train"][video_column]]
-
- return instance_prompts, instance_videos
-
- def _load_dataset_from_local_path(self):
- if not self.instance_data_root.exists():
- raise ValueError("Instance videos root folder does not exist")
-
- prompt_path = self.instance_data_root.joinpath(self.caption_column)
- video_path = self.instance_data_root.joinpath(self.video_column)
-
- if not prompt_path.exists() or not prompt_path.is_file():
- raise ValueError(
- "Expected `--caption_column` to be path to a file in `--instance_data_root` containing line-separated text prompts."
- )
- if not video_path.exists() or not video_path.is_file():
- raise ValueError(
- "Expected `--video_column` to be path to a file in `--instance_data_root` containing line-separated paths to video data in the same directory."
- )
-
- with open(prompt_path, "r", encoding="utf-8") as file:
- instance_prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0]
- with open(video_path, "r", encoding="utf-8") as file:
- instance_videos = [
- self.instance_data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0
- ]
-
- if any(not path.is_file() for path in instance_videos):
- raise ValueError(
- "Expected '--video_column' to be a path to a file in `--instance_data_root` containing line-separated paths to video data but found atleast one path that is not a valid file."
- )
-
- return instance_prompts, instance_videos
-
- def _resize_for_rectangle_crop(self, arr):
- image_size = self.height, self.width
- reshape_mode = self.video_reshape_mode
- if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
- arr = resize(
- arr,
- size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
- interpolation=InterpolationMode.BICUBIC,
- )
- else:
- arr = resize(
- arr,
- size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
- interpolation=InterpolationMode.BICUBIC,
- )
-
- h, w = arr.shape[2], arr.shape[3]
- arr = arr.squeeze(0)
-
- delta_h = h - image_size[0]
- delta_w = w - image_size[1]
-
- if reshape_mode == "random" or reshape_mode == "none":
- top = np.random.randint(0, delta_h + 1)
- left = np.random.randint(0, delta_w + 1)
- elif reshape_mode == "center":
- top, left = delta_h // 2, delta_w // 2
- else:
- raise NotImplementedError
- arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
- return arr
-
- def _preprocess_data(self):
- try:
- import decord
- except ImportError:
- raise ImportError(
- "The `decord` package is required for loading the video dataset. Install with `pip install decord`"
- )
-
- decord.bridge.set_bridge("torch")
-
- progress_dataset_bar = tqdm(
- range(0, len(self.instance_video_paths)),
- desc="Loading progress resize and crop videos",
- )
-
- videos = []
-
- for filename in self.instance_video_paths:
- video_reader = decord.VideoReader(uri=filename.as_posix())
- video_num_frames = len(video_reader)
-
- start_frame = min(self.skip_frames_start, video_num_frames)
- end_frame = max(0, video_num_frames - self.skip_frames_end)
- if end_frame <= start_frame:
- frames = video_reader.get_batch([start_frame])
- elif end_frame - start_frame <= self.max_num_frames:
- frames = video_reader.get_batch(list(range(start_frame, end_frame)))
- else:
- indices = list(range(start_frame, end_frame, (end_frame - start_frame) // self.max_num_frames))
- frames = video_reader.get_batch(indices)
-
- # Ensure that we don't go over the limit
- frames = frames[: self.max_num_frames]
- selected_num_frames = frames.shape[0]
-
- # Choose first (4k + 1) frames as this is how many is required by the VAE
- remainder = (3 + (selected_num_frames % 4)) % 4
- if remainder != 0:
- frames = frames[:-remainder]
- selected_num_frames = frames.shape[0]
-
- assert (selected_num_frames - 1) % 4 == 0
-
- # Training transforms
- frames = (frames - 127.5) / 127.5
- frames = frames.permute(0, 3, 1, 2) # [F, C, H, W]
- progress_dataset_bar.set_description(
- f"Loading progress Resizing video from {frames.shape[2]}x{frames.shape[3]} to {self.height}x{self.width}"
- )
- frames = self._resize_for_rectangle_crop(frames)
- videos.append(frames.contiguous()) # [F, C, H, W]
- progress_dataset_bar.update(1)
-
- progress_dataset_bar.close()
-
- return videos
-
-
-def save_model_card(
- repo_id: str,
- videos=None,
- base_model: str = None,
- validation_prompt=None,
- repo_folder=None,
- fps=8,
-):
- widget_dict = []
- if videos is not None:
- for i, video in enumerate(videos):
- video_path = f"final_video_{i}.mp4"
- export_to_video(video, os.path.join(repo_folder, video_path, fps=fps))
- widget_dict.append(
- {"text": validation_prompt if validation_prompt else " ", "output": {"url": video_path}},
- )
-
- model_description = f"""
-# CogVideoX LoRA - {repo_id}
-
-
-
-## Model description
-
-These are {repo_id} LoRA weights for {base_model}.
-
-The weights were trained using the [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_image_to_video_lora.py).
-
-Was LoRA for the text encoder enabled? No.
-
-## Download model
-
-[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
-
-## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
-
-```py
-import torch
-from diffusers import CogVideoXImageToVideoPipeline
-from diffusers.utils import load_image, export_to_video
-
-pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
-pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name=["cogvideox-i2v-lora"])
-
-# The LoRA adapter weights are determined by what was used for training.
-# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64.
-# It can be made lower or higher from what was used in training to decrease or amplify the effect
-# of the LoRA upto a tolerance, beyond which one might notice no effect at all or overflows.
-pipe.set_adapters(["cogvideox-i2v-lora"], [32 / 64])
-
-image = load_image("/path/to/image")
-video = pipe(image=image, "{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0]
-export_to_video(video, "output.mp4", fps=8)
-```
-
-For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
-
-## License
-
-Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b-I2V/blob/main/LICENSE).
-"""
- model_card = load_or_create_model_card(
- repo_id_or_path=repo_id,
- from_training=True,
- license="other",
- base_model=base_model,
- prompt=validation_prompt,
- model_description=model_description,
- widget=widget_dict,
- )
- tags = [
- "image-to-video",
- "diffusers-training",
- "diffusers",
- "lora",
- "cogvideox",
- "cogvideox-diffusers",
- "template:sd-lora",
- ]
-
- model_card = populate_model_card(model_card, tags=tags)
- model_card.save(os.path.join(repo_folder, "README.md"))
-
-
-def log_validation(
- pipe,
- args,
- accelerator,
- pipeline_args,
- epoch,
- is_final_validation: bool = False,
-):
- logger.info(
- f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
- )
- # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
- scheduler_args = {}
-
- if "variance_type" in pipe.scheduler.config:
- variance_type = pipe.scheduler.config.variance_type
-
- if variance_type in ["learned", "learned_range"]:
- variance_type = "fixed_small"
-
- scheduler_args["variance_type"] = variance_type
-
- pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
- pipe = pipe.to(accelerator.device)
- # pipe.set_progress_bar_config(disable=True)
-
- # run inference
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
-
- videos = []
- for _ in range(args.num_validation_videos):
- pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0]
- pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])])
-
- image_np = VaeImageProcessor.pt_to_numpy(pt_images)
- image_pil = VaeImageProcessor.numpy_to_pil(image_np)
-
- videos.append(image_pil)
-
- for tracker in accelerator.trackers:
- phase_name = "test" if is_final_validation else "validation"
- if tracker.name == "wandb":
- video_filenames = []
- for i, video in enumerate(videos):
- prompt = (
- pipeline_args["prompt"][:25]
- .replace(" ", "_")
- .replace(" ", "_")
- .replace("'", "_")
- .replace('"', "_")
- .replace("/", "_")
- )
- filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
- export_to_video(video, filename, fps=8)
- video_filenames.append(filename)
-
- tracker.log(
- {
- phase_name: [
- wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}")
- for i, filename in enumerate(video_filenames)
- ]
- }
- )
-
- del pipe
- free_memory()
-
- return videos
-
-
-def _get_t5_prompt_embeds(
- tokenizer: T5Tokenizer,
- text_encoder: T5EncoderModel,
- prompt: Union[str, List[str]],
- num_videos_per_prompt: int = 1,
- max_sequence_length: int = 226,
- device: Optional[torch.device] = None,
- dtype: Optional[torch.dtype] = None,
- text_input_ids=None,
-):
- prompt = [prompt] if isinstance(prompt, str) else prompt
- batch_size = len(prompt)
-
- if tokenizer is not None:
- text_inputs = tokenizer(
- prompt,
- padding="max_length",
- max_length=max_sequence_length,
- truncation=True,
- add_special_tokens=True,
- return_tensors="pt",
- )
- text_input_ids = text_inputs.input_ids
- else:
- if text_input_ids is None:
- raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
-
- prompt_embeds = text_encoder(text_input_ids.to(device))[0]
- prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-
- # duplicate text embeddings for each generation per prompt, using mps friendly method
- _, seq_len, _ = prompt_embeds.shape
- prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
- prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
-
- return prompt_embeds
-
-
-def encode_prompt(
- tokenizer: T5Tokenizer,
- text_encoder: T5EncoderModel,
- prompt: Union[str, List[str]],
- num_videos_per_prompt: int = 1,
- max_sequence_length: int = 226,
- device: Optional[torch.device] = None,
- dtype: Optional[torch.dtype] = None,
- text_input_ids=None,
-):
- prompt = [prompt] if isinstance(prompt, str) else prompt
- prompt_embeds = _get_t5_prompt_embeds(
- tokenizer,
- text_encoder,
- prompt=prompt,
- num_videos_per_prompt=num_videos_per_prompt,
- max_sequence_length=max_sequence_length,
- device=device,
- dtype=dtype,
- text_input_ids=text_input_ids,
- )
- return prompt_embeds
-
-
-def compute_prompt_embeddings(
- tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
-):
- if requires_grad:
- prompt_embeds = encode_prompt(
- tokenizer,
- text_encoder,
- prompt,
- num_videos_per_prompt=1,
- max_sequence_length=max_sequence_length,
- device=device,
- dtype=dtype,
- )
- else:
- with torch.no_grad():
- prompt_embeds = encode_prompt(
- tokenizer,
- text_encoder,
- prompt,
- num_videos_per_prompt=1,
- max_sequence_length=max_sequence_length,
- device=device,
- dtype=dtype,
- )
- return prompt_embeds
-
-
-def prepare_rotary_positional_embeddings(
- height: int,
- width: int,
- num_frames: int,
- vae_scale_factor_spatial: int = 8,
- patch_size: int = 2,
- patch_size_t: int = 1,
- attention_head_dim: int = 64,
- device: Optional[torch.device] = None,
- base_height: int = 480,
- base_width: int = 720,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- grid_height = height // (vae_scale_factor_spatial * patch_size)
- grid_width = width // (vae_scale_factor_spatial * patch_size)
- base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
- base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
-
- p_t = patch_size_t
- base_num_frames = (num_frames + p_t - 1) // p_t
-
- grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
- freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
- embed_dim=attention_head_dim,
- crops_coords=grid_crops_coords,
- grid_size=(grid_height, grid_width),
- temporal_size=base_num_frames,
- )
-
- freqs_cos = freqs_cos.to(device=device)
- freqs_sin = freqs_sin.to(device=device)
- return freqs_cos, freqs_sin
-
-
-def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
- # Use DeepSpeed optimzer
- if use_deepspeed:
- from accelerate.utils import DummyOptim
-
- return DummyOptim(
- params_to_optimize,
- lr=args.learning_rate,
- betas=(args.adam_beta1, args.adam_beta2),
- eps=args.adam_epsilon,
- weight_decay=args.adam_weight_decay,
- )
-
- # Optimizer creation
- supported_optimizers = ["adam", "adamw", "prodigy"]
- if args.optimizer not in supported_optimizers:
- logger.warning(
- f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW"
- )
- args.optimizer = "adamw"
-
- if args.use_8bit_adam and args.optimizer.lower() not in ["adam", "adamw"]:
- logger.warning(
- f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was "
- f"set to {args.optimizer.lower()}"
- )
-
- if args.use_8bit_adam:
- try:
- import bitsandbytes as bnb
- except ImportError:
- raise ImportError(
- "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
- )
-
- if args.optimizer.lower() == "adamw":
- optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW
-
- optimizer = optimizer_class(
- params_to_optimize,
- betas=(args.adam_beta1, args.adam_beta2),
- eps=args.adam_epsilon,
- weight_decay=args.adam_weight_decay,
- )
- elif args.optimizer.lower() == "adam":
- optimizer_class = bnb.optim.Adam8bit if args.use_8bit_adam else torch.optim.Adam
-
- optimizer = optimizer_class(
- params_to_optimize,
- betas=(args.adam_beta1, args.adam_beta2),
- eps=args.adam_epsilon,
- weight_decay=args.adam_weight_decay,
- )
- elif args.optimizer.lower() == "prodigy":
- try:
- import prodigyopt
- except ImportError:
- raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
-
- optimizer_class = prodigyopt.Prodigy
-
- if args.learning_rate <= 0.1:
- logger.warning(
- "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
- )
-
- optimizer = optimizer_class(
- params_to_optimize,
- lr=args.learning_rate,
- betas=(args.adam_beta1, args.adam_beta2),
- beta3=args.prodigy_beta3,
- weight_decay=args.adam_weight_decay,
- eps=args.adam_epsilon,
- decouple=args.prodigy_decouple,
- use_bias_correction=args.prodigy_use_bias_correction,
- safeguard_warmup=args.prodigy_safeguard_warmup,
- )
-
- return optimizer
-
-
-def main(args):
- if args.report_to == "wandb" and args.hub_token is not None:
- raise ValueError(
- "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
- )
-
- if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
- # due to pytorch#99272, MPS does not yet support bfloat16.
- raise ValueError(
- "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
- )
-
- logging_dir = Path(args.output_dir, args.logging_dir)
-
- accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
- ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
- init_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout))
- accelerator = Accelerator(
- gradient_accumulation_steps=args.gradient_accumulation_steps,
- mixed_precision=args.mixed_precision,
- log_with=args.report_to,
- project_config=accelerator_project_config,
- kwargs_handlers=[ddp_kwargs, init_kwargs],
- )
-
- # Disable AMP for MPS.
- if torch.backends.mps.is_available():
- accelerator.native_amp = False
-
- if args.report_to == "wandb":
- if not is_wandb_available():
- raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
-
- # Make one log on every process with the configuration for debugging.
- logging.basicConfig(
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
- datefmt="%m/%d/%Y %H:%M:%S",
- level=logging.INFO,
- )
- logger.info(accelerator.state, main_process_only=False)
- if accelerator.is_local_main_process:
- transformers.utils.logging.set_verbosity_warning()
- diffusers.utils.logging.set_verbosity_info()
- else:
- transformers.utils.logging.set_verbosity_error()
- diffusers.utils.logging.set_verbosity_error()
-
- # If passed along, set the training seed now.
- if args.seed is not None:
- set_seed(args.seed)
-
- # Handle the repository creation
- if accelerator.is_main_process:
- if args.output_dir is not None:
- os.makedirs(args.output_dir, exist_ok=True)
-
- if args.push_to_hub:
- repo_id = create_repo(
- repo_id=args.hub_model_id or Path(args.output_dir).name,
- exist_ok=True,
- ).repo_id
-
- # Prepare models and scheduler
- tokenizer = AutoTokenizer.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
- )
-
- text_encoder = T5EncoderModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
- )
-
- # CogVideoX-2b weights are stored in float16
- # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
- load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16
- transformer = CogVideoXTransformer3DModel.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="transformer",
- torch_dtype=load_dtype,
- revision=args.revision,
- variant=args.variant,
- )
-
- vae = AutoencoderKLCogVideoX.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
- )
-
- scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
-
- if args.enable_slicing:
- vae.enable_slicing()
- if args.enable_tiling:
- vae.enable_tiling()
-
- # We only train the additional adapter LoRA layers
- text_encoder.requires_grad_(False)
- transformer.requires_grad_(False)
- vae.requires_grad_(False)
-
- # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
- # as these weights are only used for inference, keeping weights in full precision is not required.
- weight_dtype = torch.float32
- if accelerator.state.deepspeed_plugin:
- # DeepSpeed is handling precision, use what's in the DeepSpeed config
- if (
- "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
- and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
- ):
- weight_dtype = torch.float16
- if (
- "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
- and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
- ):
- weight_dtype = torch.float16
- else:
- if accelerator.mixed_precision == "fp16":
- weight_dtype = torch.float16
- elif accelerator.mixed_precision == "bf16":
- weight_dtype = torch.bfloat16
-
- if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
- # due to pytorch#99272, MPS does not yet support bfloat16.
- raise ValueError(
- "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
- )
-
- text_encoder.to(accelerator.device, dtype=weight_dtype)
- transformer.to(accelerator.device, dtype=weight_dtype)
- vae.to(accelerator.device, dtype=weight_dtype)
-
- if args.gradient_checkpointing:
- transformer.enable_gradient_checkpointing()
-
- # now we will add new LoRA weights to the attention layers
- transformer_lora_config = LoraConfig(
- r=args.rank,
- lora_alpha=args.lora_alpha,
- init_lora_weights=True,
- target_modules=["to_k", "to_q", "to_v", "to_out.0"],
- )
- transformer.add_adapter(transformer_lora_config)
-
- def unwrap_model(model):
- model = accelerator.unwrap_model(model)
- model = model._orig_mod if is_compiled_module(model) else model
- return model
-
- # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
- def save_model_hook(models, weights, output_dir):
- if accelerator.is_main_process:
- transformer_lora_layers_to_save = None
-
- for model in models:
- if isinstance(model, type(unwrap_model(transformer))):
- transformer_lora_layers_to_save = get_peft_model_state_dict(model)
- else:
- raise ValueError(f"unexpected save model: {model.__class__}")
-
- # make sure to pop weight so that corresponding model is not saved again
- weights.pop()
-
- CogVideoXImageToVideoPipeline.save_lora_weights(
- output_dir,
- transformer_lora_layers=transformer_lora_layers_to_save,
- )
-
- def load_model_hook(models, input_dir):
- transformer_ = None
-
- while len(models) > 0:
- model = models.pop()
-
- if isinstance(model, type(unwrap_model(transformer))):
- transformer_ = model
- else:
- raise ValueError(f"Unexpected save model: {model.__class__}")
-
- lora_state_dict = CogVideoXImageToVideoPipeline.lora_state_dict(input_dir)
-
- transformer_state_dict = {
- f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
- }
- transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
- incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
- if incompatible_keys is not None:
- # check only for unexpected keys
- unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
- if unexpected_keys:
- logger.warning(
- f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
- f" {unexpected_keys}. "
- )
-
- # Make sure the trainable params are in float32. This is again needed since the base models
- # are in `weight_dtype`. More details:
- # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
- if args.mixed_precision == "fp16":
- # only upcast trainable parameters (LoRA) into fp32
- cast_training_params([transformer_])
-
- accelerator.register_save_state_pre_hook(save_model_hook)
- accelerator.register_load_state_pre_hook(load_model_hook)
-
- # Enable TF32 for faster training on Ampere GPUs,
- # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
- if args.allow_tf32 and torch.cuda.is_available():
- torch.backends.cuda.matmul.allow_tf32 = True
-
- if args.scale_lr:
- args.learning_rate = (
- args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
- )
-
- # Make sure the trainable params are in float32.
- if args.mixed_precision == "fp16":
- # only upcast trainable parameters (LoRA) into fp32
- cast_training_params([transformer], dtype=torch.float32)
-
- transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
-
- # Optimization parameters
- transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
- params_to_optimize = [transformer_parameters_with_lr]
-
- use_deepspeed_optimizer = (
- accelerator.state.deepspeed_plugin is not None
- and accelerator.state.deepspeed_plugin.deepspeed_config.get("optimizer", "none").lower() != "none"
- )
- use_deepspeed_scheduler = (
- accelerator.state.deepspeed_plugin is not None
- and accelerator.state.deepspeed_plugin.deepspeed_config.get("scheduler", "none").lower() != "none"
- )
-
- optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
-
- # Dataset and DataLoader
- train_dataset = VideoDataset(
- instance_data_root=args.instance_data_root,
- dataset_name=args.dataset_name,
- dataset_config_name=args.dataset_config_name,
- caption_column=args.caption_column,
- video_column=args.video_column,
- height=args.height,
- width=args.width,
- video_reshape_mode=args.video_reshape_mode,
- fps=args.fps,
- max_num_frames=args.max_num_frames,
- skip_frames_start=args.skip_frames_start,
- skip_frames_end=args.skip_frames_end,
- cache_dir=args.cache_dir,
- id_token=args.id_token,
- )
-
- def encode_video(video, bar):
- bar.update(1)
- video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
- video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
- image = video[:, :, :1].clone()
-
- latent_dist = vae.encode(video).latent_dist
-
- image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=image.device)
- image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=image.dtype)
- noisy_image = image + torch.randn_like(image) * image_noise_sigma[:, None, None, None, None]
- image_latent_dist = vae.encode(noisy_image).latent_dist
-
- return latent_dist, image_latent_dist
-
- train_dataset.instance_prompts = [
- compute_prompt_embeddings(
- tokenizer,
- text_encoder,
- [prompt],
- transformer.config.max_text_seq_length,
- accelerator.device,
- weight_dtype,
- requires_grad=False,
- )
- for prompt in train_dataset.instance_prompts
- ]
-
- progress_encode_bar = tqdm(
- range(0, len(train_dataset.instance_videos)),
- desc="Loading Encode videos",
- )
- train_dataset.instance_videos = [encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos]
- progress_encode_bar.close()
-
- def collate_fn(examples):
- videos = []
- images = []
- for example in examples:
- latent_dist, image_latent_dist = example["instance_video"]
-
- video_latents = latent_dist.sample() * vae.config.scaling_factor
- image_latents = image_latent_dist.sample() * vae.config.scaling_factor
- video_latents = video_latents.permute(0, 2, 1, 3, 4)
- image_latents = image_latents.permute(0, 2, 1, 3, 4)
-
- padding_shape = (video_latents.shape[0], video_latents.shape[1] - 1, *video_latents.shape[2:])
- latent_padding = image_latents.new_zeros(padding_shape)
- image_latents = torch.cat([image_latents, latent_padding], dim=1)
-
- if random.random() < args.noised_image_dropout:
- image_latents = torch.zeros_like(image_latents)
-
- videos.append(video_latents)
- images.append(image_latents)
-
- videos = torch.cat(videos)
- images = torch.cat(images)
- videos = videos.to(memory_format=torch.contiguous_format).float()
- images = images.to(memory_format=torch.contiguous_format).float()
-
- prompts = [example["instance_prompt"] for example in examples]
- prompts = torch.cat(prompts)
-
- return {
- "videos": (videos, images),
- "prompts": prompts,
- }
-
- train_dataloader = DataLoader(
- train_dataset,
- batch_size=args.train_batch_size,
- shuffle=True,
- collate_fn=collate_fn,
- num_workers=args.dataloader_num_workers,
- )
-
- # Scheduler and math around the number of training steps.
- overrode_max_train_steps = False
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- if args.max_train_steps is None:
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- overrode_max_train_steps = True
-
- if use_deepspeed_scheduler:
- from accelerate.utils import DummyScheduler
-
- lr_scheduler = DummyScheduler(
- optimizer=optimizer,
- total_num_steps=args.max_train_steps * accelerator.num_processes,
- warmup_num_steps=args.lr_warmup_steps * accelerator.num_processes,
- )
- else:
- lr_scheduler = get_scheduler(
- args.lr_scheduler,
- optimizer=optimizer,
- num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
- num_training_steps=args.max_train_steps * accelerator.num_processes,
- num_cycles=args.lr_num_cycles,
- power=args.lr_power,
- )
-
- # Prepare everything with our `accelerator`.
- transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
- transformer, optimizer, train_dataloader, lr_scheduler
- )
-
- # We need to recalculate our total training steps as the size of the training dataloader may have changed.
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- if overrode_max_train_steps:
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- # Afterwards we recalculate our number of training epochs
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
-
- # We need to initialize the trackers we use, and also store our configuration.
- # The trackers initializes automatically on the main process.
- if accelerator.is_main_process:
- tracker_name = args.tracker_name or "cogvideox-i2v-lora"
- accelerator.init_trackers(tracker_name, config=vars(args))
-
- # Train!
- total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
- num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
-
- logger.info("***** Running training *****")
- logger.info(f" Num trainable parameters = {num_trainable_parameters}")
- logger.info(f" Num examples = {len(train_dataset)}")
- logger.info(f" Num batches each epoch = {len(train_dataloader)}")
- logger.info(f" Num epochs = {args.num_train_epochs}")
- logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
- logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
- logger.info(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
- logger.info(f" Total optimization steps = {args.max_train_steps}")
- global_step = 0
- first_epoch = 0
-
- # Potentially load in the weights and states from a previous save
- if not args.resume_from_checkpoint:
- initial_global_step = 0
- else:
- if args.resume_from_checkpoint != "latest":
- path = os.path.basename(args.resume_from_checkpoint)
- else:
- # Get the mos recent checkpoint
- dirs = os.listdir(args.output_dir)
- dirs = [d for d in dirs if d.startswith("checkpoint")]
- dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
- path = dirs[-1] if len(dirs) > 0 else None
-
- if path is None:
- accelerator.print(
- f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
- )
- args.resume_from_checkpoint = None
- initial_global_step = 0
- else:
- accelerator.print(f"Resuming from checkpoint {path}")
- accelerator.load_state(os.path.join(args.output_dir, path))
- global_step = int(path.split("-")[1])
-
- initial_global_step = global_step
- first_epoch = global_step // num_update_steps_per_epoch
-
- progress_bar = tqdm(
- range(0, args.max_train_steps),
- initial=initial_global_step,
- desc="Steps",
- # Only show the progress bar once on each machine.
- disable=not accelerator.is_local_main_process,
- )
- vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)
-
- # For DeepSpeed training
- model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
-
- for epoch in range(first_epoch, args.num_train_epochs):
- transformer.train()
-
- for step, batch in enumerate(train_dataloader):
- models_to_accumulate = [transformer]
-
- with accelerator.accumulate(models_to_accumulate):
- video_latents, image_latents = batch["videos"]
- prompt_embeds = batch["prompts"]
-
- video_latents = video_latents.to(dtype=weight_dtype) # [B, F, C, H, W]
- image_latents = image_latents.to(dtype=weight_dtype) # [B, F, C, H, W]
-
- batch_size, num_frames, num_channels, height, width = video_latents.shape
-
- # Sample a random timestep for each image
- timesteps = torch.randint(
- 0, scheduler.config.num_train_timesteps, (batch_size,), device=video_latents.device
- )
- timesteps = timesteps.long()
-
- # Sample noise that will be added to the latents
- noise = torch.randn_like(video_latents)
-
- # Add noise to the model input according to the noise magnitude at each timestep
- # (this is the forward diffusion process)
- noisy_video_latents = scheduler.add_noise(video_latents, noise, timesteps)
- noisy_model_input = torch.cat([noisy_video_latents, image_latents], dim=2)
-
- # Prepare rotary embeds
- image_rotary_emb = (
- prepare_rotary_positional_embeddings(
- height=args.height,
- width=args.width,
- num_frames=num_frames,
- vae_scale_factor_spatial=vae_scale_factor_spatial,
- patch_size=model_config.patch_size,
- patch_size_t=model_config.patch_size_t,
- attention_head_dim=model_config.attention_head_dim,
- device=accelerator.device,
- )
- if model_config.use_rotary_positional_embeddings
- else None
- )
-
- # Predict the noise residual
- model_output = transformer(
- hidden_states=noisy_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timesteps,
- image_rotary_emb=image_rotary_emb,
- return_dict=False,
- )[0]
- model_pred = scheduler.get_velocity(model_output, noisy_video_latents, timesteps)
-
- alphas_cumprod = scheduler.alphas_cumprod[timesteps]
- weights = 1 / (1 - alphas_cumprod)
- while len(weights.shape) < len(model_pred.shape):
- weights = weights.unsqueeze(-1)
-
- target = video_latents
-
- loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1)
- loss = loss.mean()
- accelerator.backward(loss)
-
- if accelerator.sync_gradients:
- params_to_clip = transformer.parameters()
- accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
-
- if accelerator.state.deepspeed_plugin is None:
- optimizer.step()
- optimizer.zero_grad()
-
- lr_scheduler.step()
-
- # Checks if the accelerator has performed an optimization step behind the scenes
- if accelerator.sync_gradients:
- progress_bar.update(1)
- global_step += 1
-
- if accelerator.is_main_process:
- if global_step % args.checkpointing_steps == 0:
- # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
- if args.checkpoints_total_limit is not None:
- checkpoints = os.listdir(args.output_dir)
- checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
- checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
-
- # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
- if len(checkpoints) >= args.checkpoints_total_limit:
- num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
- removing_checkpoints = checkpoints[0:num_to_remove]
-
- logger.info(
- f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
- )
- logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}")
-
- for removing_checkpoint in removing_checkpoints:
- removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
- shutil.rmtree(removing_checkpoint)
-
- save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
- accelerator.save_state(save_path)
- logger.info(f"Saved state to {save_path}")
-
- logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
- progress_bar.set_postfix(**logs)
- accelerator.log(logs, step=global_step)
-
- if global_step >= args.max_train_steps:
- break
-
- if accelerator.is_main_process:
- if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
- # Create pipeline
- pipe = CogVideoXImageToVideoPipeline.from_pretrained(
- args.pretrained_model_name_or_path,
- transformer=unwrap_model(transformer),
- scheduler=scheduler,
- revision=args.revision,
- variant=args.variant,
- torch_dtype=weight_dtype,
- )
-
- validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
- validation_images = args.validation_images.split(args.validation_prompt_separator)
-
- for validation_image, validation_prompt in zip(validation_images, validation_prompts):
- pipeline_args = {
- "image": load_image(validation_image),
- "prompt": validation_prompt,
- "guidance_scale": args.guidance_scale,
- "use_dynamic_cfg": args.use_dynamic_cfg,
- "height": args.height,
- "width": args.width,
- }
-
- validation_outputs = log_validation(
- pipe=pipe,
- args=args,
- accelerator=accelerator,
- pipeline_args=pipeline_args,
- epoch=epoch,
- )
-
- # Save the lora layers
- accelerator.wait_for_everyone()
- if accelerator.is_main_process:
- transformer = unwrap_model(transformer)
- dtype = (
- torch.float16
- if args.mixed_precision == "fp16"
- else torch.bfloat16
- if args.mixed_precision == "bf16"
- else torch.float32
- )
- transformer = transformer.to(dtype)
- transformer_lora_layers = get_peft_model_state_dict(transformer)
-
- CogVideoXImageToVideoPipeline.save_lora_weights(
- save_directory=args.output_dir,
- transformer_lora_layers=transformer_lora_layers,
- )
-
- # Cleanup trained models to save memory
- del transformer
- free_memory()
-
- # Final test inference
- pipe = CogVideoXImageToVideoPipeline.from_pretrained(
- args.pretrained_model_name_or_path,
- revision=args.revision,
- variant=args.variant,
- torch_dtype=weight_dtype,
- )
- pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
-
- if args.enable_slicing:
- pipe.vae.enable_slicing()
- if args.enable_tiling:
- pipe.vae.enable_tiling()
-
- # Load LoRA weights
- lora_scaling = args.lora_alpha / args.rank
- pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-i2v-lora")
- pipe.set_adapters(["cogvideox-i2v-lora"], [lora_scaling])
-
- # Run inference
- validation_outputs = []
- if args.validation_prompt and args.num_validation_videos > 0:
- validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
- validation_images = args.validation_images.split(args.validation_prompt_separator)
-
- for validation_image, validation_prompt in zip(validation_images, validation_prompts):
- pipeline_args = {
- "image": load_image(validation_image),
- "prompt": validation_prompt,
- "guidance_scale": args.guidance_scale,
- "use_dynamic_cfg": args.use_dynamic_cfg,
- "height": args.height,
- "width": args.width,
- }
-
- video = log_validation(
- pipe=pipe,
- args=args,
- accelerator=accelerator,
- pipeline_args=pipeline_args,
- epoch=epoch,
- is_final_validation=True,
- )
- validation_outputs.extend(video)
-
- if args.push_to_hub:
- validation_prompt = args.validation_prompt or ""
- validation_prompt = validation_prompt.split(args.validation_prompt_separator)[0]
- save_model_card(
- repo_id,
- videos=validation_outputs,
- base_model=args.pretrained_model_name_or_path,
- validation_prompt=validation_prompt,
- repo_folder=args.output_dir,
- fps=args.fps,
- )
- upload_folder(
- repo_id=repo_id,
- folder_path=args.output_dir,
- commit_message="End of training",
- ignore_patterns=["step_*", "epoch_*"],
- )
-
- accelerator.end_training()
-
-
-if __name__ == "__main__":
- args = get_args()
- main(args)
diff --git a/finetune/train_cogvideox_lora.py b/finetune/train_cogvideox_lora.py
deleted file mode 100644
index e12b3d5..0000000
--- a/finetune/train_cogvideox_lora.py
+++ /dev/null
@@ -1,1573 +0,0 @@
-# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import argparse
-import logging
-import math
-import os
-import shutil
-from pathlib import Path
-from typing import List, Optional, Tuple, Union
-
-import torch
-import transformers
-from accelerate import Accelerator
-from accelerate.logging import get_logger
-from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
-from huggingface_hub import create_repo, upload_folder
-from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
-from torch.utils.data import DataLoader, Dataset
-from torchvision import transforms
-from tqdm.auto import tqdm
-from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
-
-import diffusers
-from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
-from diffusers.models.embeddings import get_3d_rotary_pos_embed
-from diffusers.optimization import get_scheduler
-from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
-from diffusers.training_utils import (
- cast_training_params,
- free_memory,
-)
-from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available
-from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
-from diffusers.utils.torch_utils import is_compiled_module
-
-
-if is_wandb_available():
- import wandb
-
-# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
-
-logger = get_logger(__name__)
-
-
-def get_args():
- parser = argparse.ArgumentParser(description="Simple example of a training script for CogVideoX.")
-
- # Model information
- parser.add_argument(
- "--pretrained_model_name_or_path",
- type=str,
- default=None,
- required=True,
- help="Path to pretrained model or model identifier from huggingface.co/models.",
- )
- parser.add_argument(
- "--revision",
- type=str,
- default=None,
- required=False,
- help="Revision of pretrained model identifier from huggingface.co/models.",
- )
- parser.add_argument(
- "--variant",
- type=str,
- default=None,
- help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
- )
- parser.add_argument(
- "--cache_dir",
- type=str,
- default=None,
- help="The directory where the downloaded models and datasets will be stored.",
- )
-
- # Dataset information
- parser.add_argument(
- "--dataset_name",
- type=str,
- default=None,
- help=(
- "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
- " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
- " or to a folder containing files that 🤗 Datasets can understand."
- ),
- )
- parser.add_argument(
- "--dataset_config_name",
- type=str,
- default=None,
- help="The config of the Dataset, leave as None if there's only one config.",
- )
- parser.add_argument(
- "--instance_data_root",
- type=str,
- default=None,
- help=("A folder containing the training data."),
- )
- parser.add_argument(
- "--video_column",
- type=str,
- default="video",
- help="The column of the dataset containing videos. Or, the name of the file in `--instance_data_root` folder containing the line-separated path to video data.",
- )
- parser.add_argument(
- "--caption_column",
- type=str,
- default="text",
- help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--instance_data_root` folder containing the line-separated instance prompts.",
- )
- parser.add_argument(
- "--id_token", type=str, default=None, help="Identifier token appended to the start of each prompt if provided."
- )
- parser.add_argument(
- "--dataloader_num_workers",
- type=int,
- default=0,
- help=(
- "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
- ),
- )
-
- # Validation
- parser.add_argument(
- "--validation_prompt",
- type=str,
- default=None,
- help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
- )
- parser.add_argument(
- "--validation_prompt_separator",
- type=str,
- default=":::",
- help="String that separates multiple validation prompts",
- )
- parser.add_argument(
- "--num_validation_videos",
- type=int,
- default=1,
- help="Number of videos that should be generated during validation per `validation_prompt`.",
- )
- parser.add_argument(
- "--validation_epochs",
- type=int,
- default=50,
- help=(
- "Run validation every X epochs. Validation consists of running the prompt `args.validation_prompt` multiple times: `args.num_validation_videos`."
- ),
- )
- parser.add_argument(
- "--guidance_scale",
- type=float,
- default=6,
- help="The guidance scale to use while sampling validation videos.",
- )
- parser.add_argument(
- "--use_dynamic_cfg",
- action="store_true",
- default=False,
- help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.",
- )
-
- # Training information
- parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
- parser.add_argument(
- "--rank",
- type=int,
- default=128,
- help=("The dimension of the LoRA update matrices."),
- )
- parser.add_argument(
- "--lora_alpha",
- type=float,
- default=128,
- help=("The scaling factor to scale LoRA weight update. The actual scaling factor is `lora_alpha / rank`"),
- )
- parser.add_argument(
- "--mixed_precision",
- type=str,
- default=None,
- choices=["no", "fp16", "bf16"],
- help=(
- "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
- " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
- " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
- ),
- )
- parser.add_argument(
- "--output_dir",
- type=str,
- default="cogvideox-lora",
- help="The output directory where the model predictions and checkpoints will be written.",
- )
- parser.add_argument(
- "--height",
- type=int,
- default=480,
- help="All input videos are resized to this height.",
- )
- parser.add_argument(
- "--width",
- type=int,
- default=720,
- help="All input videos are resized to this width.",
- )
- parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
- parser.add_argument(
- "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
- )
- parser.add_argument(
- "--skip_frames_start",
- type=int,
- default=0,
- help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.",
- )
- parser.add_argument(
- "--skip_frames_end",
- type=int,
- default=0,
- help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.",
- )
- parser.add_argument(
- "--random_flip",
- action="store_true",
- help="whether to randomly flip videos horizontally",
- )
- parser.add_argument(
- "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
- )
- parser.add_argument("--num_train_epochs", type=int, default=1)
- parser.add_argument(
- "--max_train_steps",
- type=int,
- default=None,
- help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.",
- )
- parser.add_argument(
- "--checkpointing_steps",
- type=int,
- default=500,
- help=(
- "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
- " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
- " training using `--resume_from_checkpoint`."
- ),
- )
- parser.add_argument(
- "--checkpoints_total_limit",
- type=int,
- default=None,
- help=("Max number of checkpoints to store."),
- )
- parser.add_argument(
- "--resume_from_checkpoint",
- type=str,
- default=None,
- help=(
- "Whether training should be resumed from a previous checkpoint. Use a path saved by"
- ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
- ),
- )
- parser.add_argument(
- "--gradient_accumulation_steps",
- type=int,
- default=1,
- help="Number of updates steps to accumulate before performing a backward/update pass.",
- )
- parser.add_argument(
- "--gradient_checkpointing",
- action="store_true",
- help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
- )
- parser.add_argument(
- "--learning_rate",
- type=float,
- default=1e-4,
- help="Initial learning rate (after the potential warmup period) to use.",
- )
- parser.add_argument(
- "--scale_lr",
- action="store_true",
- default=False,
- help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
- )
- parser.add_argument(
- "--lr_scheduler",
- type=str,
- default="constant",
- help=(
- 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
- ' "constant", "constant_with_warmup"]'
- ),
- )
- parser.add_argument(
- "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
- )
- parser.add_argument(
- "--lr_num_cycles",
- type=int,
- default=1,
- help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
- )
- parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
- parser.add_argument(
- "--enable_slicing",
- action="store_true",
- default=False,
- help="Whether or not to use VAE slicing for saving memory.",
- )
- parser.add_argument(
- "--enable_tiling",
- action="store_true",
- default=False,
- help="Whether or not to use VAE tiling for saving memory.",
- )
-
- # Optimizer
- parser.add_argument(
- "--optimizer",
- type=lambda s: s.lower(),
- default="adam",
- choices=["adam", "adamw", "prodigy"],
- help=("The optimizer type to use."),
- )
- parser.add_argument(
- "--use_8bit_adam",
- action="store_true",
- help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
- )
- parser.add_argument(
- "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
- )
- parser.add_argument(
- "--adam_beta2", type=float, default=0.95, help="The beta2 parameter for the Adam and Prodigy optimizers."
- )
- parser.add_argument(
- "--prodigy_beta3",
- type=float,
- default=None,
- help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.",
- )
- parser.add_argument("--prodigy_decouple", action="store_true", help="Use AdamW style decoupled weight decay")
- parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
- parser.add_argument(
- "--adam_epsilon",
- type=float,
- default=1e-08,
- help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
- )
- parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
- parser.add_argument("--prodigy_use_bias_correction", action="store_true", help="Turn on Adam's bias correction.")
- parser.add_argument(
- "--prodigy_safeguard_warmup",
- action="store_true",
- help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.",
- )
-
- # Other information
- parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name")
- parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
- parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
- parser.add_argument(
- "--hub_model_id",
- type=str,
- default=None,
- help="The name of the repository to keep in sync with the local `output_dir`.",
- )
- parser.add_argument(
- "--logging_dir",
- type=str,
- default="logs",
- help="Directory where logs are stored.",
- )
- parser.add_argument(
- "--allow_tf32",
- action="store_true",
- help=(
- "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
- " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
- ),
- )
- parser.add_argument(
- "--report_to",
- type=str,
- default=None,
- help=(
- 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
- ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
- ),
- )
-
- return parser.parse_args()
-
-
-class VideoDataset(Dataset):
- def __init__(
- self,
- instance_data_root: Optional[str] = None,
- dataset_name: Optional[str] = None,
- dataset_config_name: Optional[str] = None,
- caption_column: str = "text",
- video_column: str = "video",
- height: int = 480,
- width: int = 720,
- fps: int = 8,
- max_num_frames: int = 49,
- skip_frames_start: int = 0,
- skip_frames_end: int = 0,
- cache_dir: Optional[str] = None,
- id_token: Optional[str] = None,
- ) -> None:
- super().__init__()
-
- self.instance_data_root = Path(instance_data_root) if instance_data_root is not None else None
- self.dataset_name = dataset_name
- self.dataset_config_name = dataset_config_name
- self.caption_column = caption_column
- self.video_column = video_column
- self.height = height
- self.width = width
- self.fps = fps
- self.max_num_frames = max_num_frames
- self.skip_frames_start = skip_frames_start
- self.skip_frames_end = skip_frames_end
- self.cache_dir = cache_dir
- self.id_token = id_token or ""
-
- if dataset_name is not None:
- self.instance_prompts, self.instance_video_paths = self._load_dataset_from_hub()
- else:
- self.instance_prompts, self.instance_video_paths = self._load_dataset_from_local_path()
-
- self.num_instance_videos = len(self.instance_video_paths)
- if self.num_instance_videos != len(self.instance_prompts):
- raise ValueError(
- f"Expected length of instance prompts and videos to be the same but found {len(self.instance_prompts)=} and {len(self.instance_video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset."
- )
-
- self.instance_videos = self._preprocess_data()
-
- def __len__(self):
- return self.num_instance_videos
-
- def __getitem__(self, index):
- return {
- "instance_prompt": self.id_token + self.instance_prompts[index],
- "instance_video": self.instance_videos[index],
- }
-
- def _load_dataset_from_hub(self):
- try:
- from datasets import load_dataset
- except ImportError:
- raise ImportError(
- "You are trying to load your data using the datasets library. If you wish to train using custom "
- "captions please install the datasets library: `pip install datasets`. If you wish to load a "
- "local folder containing images only, specify --instance_data_root instead."
- )
-
- # Downloading and loading a dataset from the hub. See more about loading custom images at
- # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
- dataset = load_dataset(
- self.dataset_name,
- self.dataset_config_name,
- cache_dir=self.cache_dir,
- )
- column_names = dataset["train"].column_names
-
- if self.video_column is None:
- video_column = column_names[0]
- logger.info(f"`video_column` defaulting to {video_column}")
- else:
- video_column = self.video_column
- if video_column not in column_names:
- raise ValueError(
- f"`--video_column` value '{video_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
- )
-
- if self.caption_column is None:
- caption_column = column_names[1]
- logger.info(f"`caption_column` defaulting to {caption_column}")
- else:
- caption_column = self.caption_column
- if self.caption_column not in column_names:
- raise ValueError(
- f"`--caption_column` value '{self.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
- )
-
- instance_prompts = dataset["train"][caption_column]
- instance_videos = [Path(self.instance_data_root, filepath) for filepath in dataset["train"][video_column]]
-
- return instance_prompts, instance_videos
-
- def _load_dataset_from_local_path(self):
- if not self.instance_data_root.exists():
- raise ValueError("Instance videos root folder does not exist")
-
- prompt_path = self.instance_data_root.joinpath(self.caption_column)
- video_path = self.instance_data_root.joinpath(self.video_column)
-
- if not prompt_path.exists() or not prompt_path.is_file():
- raise ValueError(
- "Expected `--caption_column` to be path to a file in `--instance_data_root` containing line-separated text prompts."
- )
- if not video_path.exists() or not video_path.is_file():
- raise ValueError(
- "Expected `--video_column` to be path to a file in `--instance_data_root` containing line-separated paths to video data in the same directory."
- )
-
- with open(prompt_path, "r", encoding="utf-8") as file:
- instance_prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0]
- with open(video_path, "r", encoding="utf-8") as file:
- instance_videos = [
- self.instance_data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0
- ]
-
- if any(not path.is_file() for path in instance_videos):
- raise ValueError(
- "Expected '--video_column' to be a path to a file in `--instance_data_root` containing line-separated paths to video data but found atleast one path that is not a valid file."
- )
-
- return instance_prompts, instance_videos
-
- def _preprocess_data(self):
- try:
- import decord
- except ImportError:
- raise ImportError(
- "The `decord` package is required for loading the video dataset. Install with `pip install decord`"
- )
-
- decord.bridge.set_bridge("torch")
-
- videos = []
- train_transforms = transforms.Compose(
- [
- transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
- ]
- )
-
- for filename in self.instance_video_paths:
- video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
- video_num_frames = len(video_reader)
-
- start_frame = min(self.skip_frames_start, video_num_frames)
- end_frame = max(0, video_num_frames - self.skip_frames_end)
- if end_frame <= start_frame:
- frames = video_reader.get_batch([start_frame])
- elif end_frame - start_frame <= self.max_num_frames:
- frames = video_reader.get_batch(list(range(start_frame, end_frame)))
- else:
- indices = list(range(start_frame, end_frame, (end_frame - start_frame) // self.max_num_frames))
- frames = video_reader.get_batch(indices)
-
- # Ensure that we don't go over the limit
- frames = frames[: self.max_num_frames]
- selected_num_frames = frames.shape[0]
-
- # Choose first (4k + 1) frames as this is how many is required by the VAE
- remainder = (3 + (selected_num_frames % 4)) % 4
- if remainder != 0:
- frames = frames[:-remainder]
- selected_num_frames = frames.shape[0]
-
- assert (selected_num_frames - 1) % 4 == 0
-
- # Training transforms
- frames = frames.float()
- frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
- videos.append(frames.permute(0, 3, 1, 2).contiguous()) # [F, C, H, W]
-
- return videos
-
-
-def save_model_card(
- repo_id: str,
- videos=None,
- base_model: str = None,
- validation_prompt=None,
- repo_folder=None,
- fps=8,
-):
- widget_dict = []
- if videos is not None:
- for i, video in enumerate(videos):
- export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4", fps=fps))
- widget_dict.append(
- {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"video_{i}.mp4"}}
- )
-
- model_description = f"""
-# CogVideoX LoRA - {repo_id}
-
-
-
-## Model description
-
-These are {repo_id} LoRA weights for {base_model}.
-
-The weights were trained using the [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
-
-Was LoRA for the text encoder enabled? No.
-
-## Download model
-
-[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
-
-## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
-
-```py
-from diffusers import CogVideoXPipeline
-import torch
-
-pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
-pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name=["cogvideox-lora"])
-
-# The LoRA adapter weights are determined by what was used for training.
-# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64.
-# It can be made lower or higher from what was used in training to decrease or amplify the effect
-# of the LoRA upto a tolerance, beyond which one might notice no effect at all or overflows.
-pipe.set_adapters(["cogvideox-lora"], [32 / 64])
-
-video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0]
-```
-
-For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
-
-## License
-
-Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE) and [here](https://huggingface.co/THUDM/CogVideoX-2b/blob/main/LICENSE).
-"""
- model_card = load_or_create_model_card(
- repo_id_or_path=repo_id,
- from_training=True,
- license="other",
- base_model=base_model,
- prompt=validation_prompt,
- model_description=model_description,
- widget=widget_dict,
- )
- tags = [
- "text-to-video",
- "diffusers-training",
- "diffusers",
- "lora",
- "cogvideox",
- "cogvideox-diffusers",
- "template:sd-lora",
- ]
-
- model_card = populate_model_card(model_card, tags=tags)
- model_card.save(os.path.join(repo_folder, "README.md"))
-
-
-def log_validation(
- pipe,
- args,
- accelerator,
- pipeline_args,
- epoch,
- is_final_validation: bool = False,
-):
- logger.info(
- f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
- )
- # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
- scheduler_args = {}
-
- if "variance_type" in pipe.scheduler.config:
- variance_type = pipe.scheduler.config.variance_type
-
- if variance_type in ["learned", "learned_range"]:
- variance_type = "fixed_small"
-
- scheduler_args["variance_type"] = variance_type
-
- pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
- pipe = pipe.to(accelerator.device)
- # pipe.set_progress_bar_config(disable=True)
-
- # run inference
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
-
- videos = []
- for _ in range(args.num_validation_videos):
- video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
- videos.append(video)
-
- for tracker in accelerator.trackers:
- phase_name = "test" if is_final_validation else "validation"
- if tracker.name == "wandb":
- video_filenames = []
- for i, video in enumerate(videos):
- prompt = (
- pipeline_args["prompt"][:25]
- .replace(" ", "_")
- .replace(" ", "_")
- .replace("'", "_")
- .replace('"', "_")
- .replace("/", "_")
- )
- filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
- export_to_video(video, filename, fps=8)
- video_filenames.append(filename)
-
- tracker.log(
- {
- phase_name: [
- wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}")
- for i, filename in enumerate(video_filenames)
- ]
- }
- )
-
- free_memory()
-
- return videos
-
-
-def _get_t5_prompt_embeds(
- tokenizer: T5Tokenizer,
- text_encoder: T5EncoderModel,
- prompt: Union[str, List[str]],
- num_videos_per_prompt: int = 1,
- max_sequence_length: int = 226,
- device: Optional[torch.device] = None,
- dtype: Optional[torch.dtype] = None,
- text_input_ids=None,
-):
- prompt = [prompt] if isinstance(prompt, str) else prompt
- batch_size = len(prompt)
-
- if tokenizer is not None:
- text_inputs = tokenizer(
- prompt,
- padding="max_length",
- max_length=max_sequence_length,
- truncation=True,
- add_special_tokens=True,
- return_tensors="pt",
- )
- text_input_ids = text_inputs.input_ids
- else:
- if text_input_ids is None:
- raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
-
- prompt_embeds = text_encoder(text_input_ids.to(device))[0]
- prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-
- # duplicate text embeddings for each generation per prompt, using mps friendly method
- _, seq_len, _ = prompt_embeds.shape
- prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
- prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
-
- return prompt_embeds
-
-
-def encode_prompt(
- tokenizer: T5Tokenizer,
- text_encoder: T5EncoderModel,
- prompt: Union[str, List[str]],
- num_videos_per_prompt: int = 1,
- max_sequence_length: int = 226,
- device: Optional[torch.device] = None,
- dtype: Optional[torch.dtype] = None,
- text_input_ids=None,
-):
- prompt = [prompt] if isinstance(prompt, str) else prompt
- prompt_embeds = _get_t5_prompt_embeds(
- tokenizer,
- text_encoder,
- prompt=prompt,
- num_videos_per_prompt=num_videos_per_prompt,
- max_sequence_length=max_sequence_length,
- device=device,
- dtype=dtype,
- text_input_ids=text_input_ids,
- )
- return prompt_embeds
-
-
-def compute_prompt_embeddings(
- tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
-):
- if requires_grad:
- prompt_embeds = encode_prompt(
- tokenizer,
- text_encoder,
- prompt,
- num_videos_per_prompt=1,
- max_sequence_length=max_sequence_length,
- device=device,
- dtype=dtype,
- )
- else:
- with torch.no_grad():
- prompt_embeds = encode_prompt(
- tokenizer,
- text_encoder,
- prompt,
- num_videos_per_prompt=1,
- max_sequence_length=max_sequence_length,
- device=device,
- dtype=dtype,
- )
- return prompt_embeds
-
-
-def prepare_rotary_positional_embeddings(
- height: int,
- width: int,
- num_frames: int,
- vae_scale_factor_spatial: int = 8,
- patch_size: int = 2,
- patch_size_t: int = 1,
- attention_head_dim: int = 64,
- device: Optional[torch.device] = None,
- base_height: int = 480,
- base_width: int = 720,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- grid_height = height // (vae_scale_factor_spatial * patch_size)
- grid_width = width // (vae_scale_factor_spatial * patch_size)
- base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
- base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
-
- p_t = patch_size_t
- base_num_frames = (num_frames + p_t - 1) // p_t
-
- grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
- freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
- embed_dim=attention_head_dim,
- crops_coords=grid_crops_coords,
- grid_size=(grid_height, grid_width),
- temporal_size=base_num_frames,
- )
-
- freqs_cos = freqs_cos.to(device=device)
- freqs_sin = freqs_sin.to(device=device)
- return freqs_cos, freqs_sin
-
-
-def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
- # Use DeepSpeed optimzer
- if use_deepspeed:
- from accelerate.utils import DummyOptim
-
- return DummyOptim(
- params_to_optimize,
- lr=args.learning_rate,
- betas=(args.adam_beta1, args.adam_beta2),
- eps=args.adam_epsilon,
- weight_decay=args.adam_weight_decay,
- )
-
- # Optimizer creation
- supported_optimizers = ["adam", "adamw", "prodigy"]
- if args.optimizer not in supported_optimizers:
- logger.warning(
- f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW"
- )
- args.optimizer = "adamw"
-
- if args.use_8bit_adam and not (args.optimizer.lower() not in ["adam", "adamw"]):
- logger.warning(
- f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was "
- f"set to {args.optimizer.lower()}"
- )
-
- if args.use_8bit_adam:
- try:
- import bitsandbytes as bnb
- except ImportError:
- raise ImportError(
- "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
- )
-
- if args.optimizer.lower() == "adamw":
- optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW
-
- optimizer = optimizer_class(
- params_to_optimize,
- betas=(args.adam_beta1, args.adam_beta2),
- eps=args.adam_epsilon,
- weight_decay=args.adam_weight_decay,
- )
- elif args.optimizer.lower() == "adam":
- optimizer_class = bnb.optim.Adam8bit if args.use_8bit_adam else torch.optim.Adam
-
- optimizer = optimizer_class(
- params_to_optimize,
- betas=(args.adam_beta1, args.adam_beta2),
- eps=args.adam_epsilon,
- weight_decay=args.adam_weight_decay,
- )
- elif args.optimizer.lower() == "prodigy":
- try:
- import prodigyopt
- except ImportError:
- raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
-
- optimizer_class = prodigyopt.Prodigy
-
- if args.learning_rate <= 0.1:
- logger.warning(
- "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
- )
-
- optimizer = optimizer_class(
- params_to_optimize,
- lr=args.learning_rate,
- betas=(args.adam_beta1, args.adam_beta2),
- beta3=args.prodigy_beta3,
- weight_decay=args.adam_weight_decay,
- eps=args.adam_epsilon,
- decouple=args.prodigy_decouple,
- use_bias_correction=args.prodigy_use_bias_correction,
- safeguard_warmup=args.prodigy_safeguard_warmup,
- )
-
- return optimizer
-
-
-def main(args):
- if args.report_to == "wandb" and args.hub_token is not None:
- raise ValueError(
- "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
- )
-
- if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
- # due to pytorch#99272, MPS does not yet support bfloat16.
- raise ValueError(
- "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
- )
-
- logging_dir = Path(args.output_dir, args.logging_dir)
-
- expected_midxed_precision = "bf16" if "5b" in args.pretrained_model_name_or_path.lower() else "fp16"
- if args.mixed_precision != expected_midxed_precision:
- raise ValueError(f"Mixed precision {args.mixed_precision} does not match the model precision, should be {expected_midxed_precision}")
-
- accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
- kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
- accelerator = Accelerator(
- gradient_accumulation_steps=args.gradient_accumulation_steps,
- mixed_precision=args.mixed_precision,
- log_with=args.report_to,
- project_config=accelerator_project_config,
- kwargs_handlers=[kwargs],
- )
-
- if accelerator.state.deepspeed_plugin:
- # Set deepspeed config according to args
- config = {
- 'optimizer': {
- 'type': args.optimizer,
- 'params': {
- 'lr': args.learning_rate,
- 'betas': [args.adam_beta1, args.adam_beta2]
- },
- 'torch_adam': True
- },
- 'bf16': {
- 'enabled': True if args.mixed_precision == "bf16" else False
- },
- 'fp16': {
- 'enabled': True if args.mixed_precision == "fp16" else False
- },
- 'gradient_accumulation_steps': args.gradient_accumulation_steps,
- 'train_batch_size': args.train_batch_size
- }
- accelerator.state.deepspeed_plugin.deepspeed_config.update(config)
-
- # Disable AMP for MPS.
- if torch.backends.mps.is_available():
- accelerator.native_amp = False
-
- if args.report_to == "wandb":
- if not is_wandb_available():
- raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
-
- # Make one log on every process with the configuration for debugging.
- logging.basicConfig(
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
- datefmt="%m/%d/%Y %H:%M:%S",
- level=logging.INFO,
- )
- logger.info(accelerator.state, main_process_only=False)
- if accelerator.is_local_main_process:
- transformers.utils.logging.set_verbosity_warning()
- diffusers.utils.logging.set_verbosity_info()
- else:
- transformers.utils.logging.set_verbosity_error()
- diffusers.utils.logging.set_verbosity_error()
-
- # If passed along, set the training seed now.
- if args.seed is not None:
- set_seed(args.seed)
-
- # Handle the repository creation
- if accelerator.is_main_process:
- if args.output_dir is not None:
- os.makedirs(args.output_dir, exist_ok=True)
-
- if args.push_to_hub:
- repo_id = create_repo(
- repo_id=args.hub_model_id or Path(args.output_dir).name,
- exist_ok=True,
- ).repo_id
-
- # Prepare models and scheduler
- tokenizer = AutoTokenizer.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
- )
-
- text_encoder = T5EncoderModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
- )
-
- # CogVideoX-2b weights are stored in float16
- # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
- load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16
- transformer = CogVideoXTransformer3DModel.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="transformer",
- torch_dtype=load_dtype,
- revision=args.revision,
- variant=args.variant,
- )
-
- vae = AutoencoderKLCogVideoX.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
- )
-
- scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
-
- if args.enable_slicing:
- vae.enable_slicing()
- if args.enable_tiling:
- vae.enable_tiling()
-
- # We only train the additional adapter LoRA layers
- text_encoder.requires_grad_(False)
- transformer.requires_grad_(False)
- vae.requires_grad_(False)
-
- # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
- # as these weights are only used for inference, keeping weights in full precision is not required.
- weight_dtype = torch.float32
- if accelerator.state.deepspeed_plugin:
- # DeepSpeed is handling precision, use what's in the DeepSpeed config
- if (
- "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
- and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
- ):
- weight_dtype = torch.float16
- if (
- "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
- and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
- ):
- weight_dtype = torch.bfloat16
- else:
- if accelerator.mixed_precision == "fp16":
- weight_dtype = torch.float16
- elif accelerator.mixed_precision == "bf16":
- weight_dtype = torch.bfloat16
-
- if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
- # due to pytorch#99272, MPS does not yet support bfloat16.
- raise ValueError(
- "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
- )
-
- text_encoder.to(accelerator.device, dtype=weight_dtype)
- transformer.to(accelerator.device, dtype=weight_dtype)
- vae.to(accelerator.device, dtype=weight_dtype)
-
- if args.gradient_checkpointing:
- transformer.enable_gradient_checkpointing()
-
- # now we will add new LoRA weights to the attention layers
- transformer_lora_config = LoraConfig(
- r=args.rank,
- lora_alpha=args.lora_alpha,
- init_lora_weights=True,
- target_modules=["to_k", "to_q", "to_v", "to_out.0"],
- )
- transformer.add_adapter(transformer_lora_config)
-
- def unwrap_model(model):
- model = accelerator.unwrap_model(model)
- model = model._orig_mod if is_compiled_module(model) else model
- return model
-
- # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
- def save_model_hook(models, weights, output_dir):
- if accelerator.is_main_process:
- transformer_lora_layers_to_save = None
-
- for model in models:
- if isinstance(model, type(unwrap_model(transformer))):
- transformer_lora_layers_to_save = get_peft_model_state_dict(model)
- else:
- raise ValueError(f"unexpected save model: {model.__class__}")
-
- # make sure to pop weight so that corresponding model is not saved again
- weights.pop()
-
- CogVideoXPipeline.save_lora_weights(
- output_dir,
- transformer_lora_layers=transformer_lora_layers_to_save,
- )
-
- def load_model_hook(models, input_dir):
- transformer_ = None
-
- while len(models) > 0:
- model = models.pop()
-
- if isinstance(model, type(unwrap_model(transformer))):
- transformer_ = model
- else:
- raise ValueError(f"Unexpected save model: {model.__class__}")
-
- lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
-
- transformer_state_dict = {
- f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
- }
- transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
- incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
- if incompatible_keys is not None:
- # check only for unexpected keys
- unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
- if unexpected_keys:
- logger.warning(
- f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
- f" {unexpected_keys}. "
- )
-
- # Make sure the trainable params are in float32. This is again needed since the base models
- # are in `weight_dtype`. More details:
- # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
- if args.mixed_precision == "fp16":
- # only upcast trainable parameters (LoRA) into fp32
- cast_training_params([transformer_])
-
- accelerator.register_save_state_pre_hook(save_model_hook)
- accelerator.register_load_state_pre_hook(load_model_hook)
-
- # Enable TF32 for faster training on Ampere GPUs,
- # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
- if args.allow_tf32 and torch.cuda.is_available():
- torch.backends.cuda.matmul.allow_tf32 = True
-
- if args.scale_lr:
- args.learning_rate = (
- args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
- )
-
- # Make sure the trainable params are in float32.
- if args.mixed_precision == "fp16":
- # only upcast trainable parameters (LoRA) into fp32
- cast_training_params([transformer], dtype=torch.float32)
-
- transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
-
- # Optimization parameters
- transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
- params_to_optimize = [transformer_parameters_with_lr]
-
- use_deepspeed_optimizer = (
- accelerator.state.deepspeed_plugin is not None
- and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
- )
- use_deepspeed_scheduler = (
- accelerator.state.deepspeed_plugin is not None
- and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
- )
-
- optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
-
- # Dataset and DataLoader
- train_dataset = VideoDataset(
- instance_data_root=args.instance_data_root,
- dataset_name=args.dataset_name,
- dataset_config_name=args.dataset_config_name,
- caption_column=args.caption_column,
- video_column=args.video_column,
- height=args.height,
- width=args.width,
- fps=args.fps,
- max_num_frames=args.max_num_frames,
- skip_frames_start=args.skip_frames_start,
- skip_frames_end=args.skip_frames_end,
- cache_dir=args.cache_dir,
- id_token=args.id_token,
- )
-
- def encode_video(video):
- video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
- video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
- latent_dist = vae.encode(video).latent_dist
- return latent_dist
-
- train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
-
- def collate_fn(examples):
- videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
- prompts = [example["instance_prompt"] for example in examples]
-
- videos = torch.cat(videos)
- videos = videos.to(memory_format=torch.contiguous_format).float()
-
- return {
- "videos": videos,
- "prompts": prompts,
- }
-
- train_dataloader = DataLoader(
- train_dataset,
- batch_size=args.train_batch_size,
- shuffle=True,
- collate_fn=collate_fn,
- num_workers=args.dataloader_num_workers,
- )
-
- # Scheduler and math around the number of training steps.
- overrode_max_train_steps = False
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- if args.max_train_steps is None:
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- overrode_max_train_steps = True
-
- if use_deepspeed_scheduler:
- from accelerate.utils import DummyScheduler
-
- lr_scheduler = DummyScheduler(
- optimizer=optimizer,
- total_num_steps=args.max_train_steps * accelerator.num_processes,
- warmup_num_steps=args.lr_warmup_steps * accelerator.num_processes,
- )
- else:
- lr_scheduler = get_scheduler(
- args.lr_scheduler,
- optimizer=optimizer,
- num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
- num_training_steps=args.max_train_steps * accelerator.num_processes,
- num_cycles=args.lr_num_cycles,
- power=args.lr_power,
- )
-
- # Prepare everything with our `accelerator`.
- transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
- transformer, optimizer, train_dataloader, lr_scheduler
- )
-
- # We need to recalculate our total training steps as the size of the training dataloader may have changed.
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- if overrode_max_train_steps:
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- # Afterwards we recalculate our number of training epochs
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
-
- # We need to initialize the trackers we use, and also store our configuration.
- # The trackers initializes automatically on the main process.
- if accelerator.is_main_process:
- tracker_name = args.tracker_name or "cogvideox-lora"
- accelerator.init_trackers(tracker_name, config=vars(args))
-
- # Train!
- total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
- num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
-
- logger.info("***** Running training *****")
- logger.info(f" Num trainable parameters = {num_trainable_parameters}")
- logger.info(f" Num examples = {len(train_dataset)}")
- logger.info(f" Num batches each epoch = {len(train_dataloader)}")
- logger.info(f" Num epochs = {args.num_train_epochs}")
- logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
- logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
- logger.info(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
- logger.info(f" Total optimization steps = {args.max_train_steps}")
- global_step = 0
- first_epoch = 0
-
- # Potentially load in the weights and states from a previous save
- if not args.resume_from_checkpoint:
- initial_global_step = 0
- else:
- if args.resume_from_checkpoint != "latest":
- path = os.path.basename(args.resume_from_checkpoint)
- else:
- # Get the mos recent checkpoint
- dirs = os.listdir(args.output_dir)
- dirs = [d for d in dirs if d.startswith("checkpoint")]
- dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
- path = dirs[-1] if len(dirs) > 0 else None
-
- if path is None:
- accelerator.print(
- f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
- )
- args.resume_from_checkpoint = None
- initial_global_step = 0
- else:
- accelerator.print(f"Resuming from checkpoint {path}")
- accelerator.load_state(os.path.join(args.output_dir, path))
- global_step = int(path.split("-")[1])
-
- initial_global_step = global_step
- first_epoch = global_step // num_update_steps_per_epoch
-
- progress_bar = tqdm(
- range(0, args.max_train_steps),
- initial=initial_global_step,
- desc="Steps",
- # Only show the progress bar once on each machine.
- disable=not accelerator.is_local_main_process,
- )
- vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)
-
- # For DeepSpeed training
- model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
-
- for epoch in range(first_epoch, args.num_train_epochs):
- transformer.train()
-
- for step, batch in enumerate(train_dataloader):
- models_to_accumulate = [transformer]
-
- with accelerator.accumulate(models_to_accumulate):
- model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype) # [B, F, C, H, W]
- prompts = batch["prompts"]
-
- # encode prompts
- prompt_embeds = compute_prompt_embeddings(
- tokenizer,
- text_encoder,
- prompts,
- model_config.max_text_seq_length,
- accelerator.device,
- weight_dtype,
- requires_grad=False,
- )
-
- # Sample noise that will be added to the latents
- noise = torch.randn_like(model_input)
- batch_size, num_frames, num_channels, height, width = model_input.shape
-
- # Sample a random timestep for each image
- timesteps = torch.randint(
- 0, scheduler.config.num_train_timesteps, (batch_size,), device=model_input.device
- )
- timesteps = timesteps.long()
-
- # Prepare rotary embeds
- image_rotary_emb = (
- prepare_rotary_positional_embeddings(
- height=args.height,
- width=args.width,
- num_frames=num_frames,
- vae_scale_factor_spatial=vae_scale_factor_spatial,
- patch_size=model_config.patch_size,
- patch_size_t=model_config.patch_size_t if model_config.patch_size_t is not None else 1,
- attention_head_dim=model_config.attention_head_dim,
- device=accelerator.device,
- )
- if model_config.use_rotary_positional_embeddings
- else None
- )
-
- # Add noise to the model input according to the noise magnitude at each timestep
- # (this is the forward diffusion process)
- noisy_model_input = scheduler.add_noise(model_input, noise, timesteps)
-
- # Predict the noise residual
- model_output = transformer(
- hidden_states=noisy_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timesteps,
- image_rotary_emb=image_rotary_emb,
- return_dict=False,
- )[0]
- model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps)
-
- alphas_cumprod = scheduler.alphas_cumprod[timesteps]
- weights = 1 / (1 - alphas_cumprod)
- while len(weights.shape) < len(model_pred.shape):
- weights = weights.unsqueeze(-1)
-
- target = model_input
-
- loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1)
- loss = loss.mean()
- accelerator.backward(loss)
-
- if accelerator.sync_gradients:
- params_to_clip = transformer.parameters()
- accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
-
- if accelerator.state.deepspeed_plugin is None:
- optimizer.step()
- optimizer.zero_grad()
-
- lr_scheduler.step()
-
- # Checks if the accelerator has performed an optimization step behind the scenes
- if accelerator.sync_gradients:
- progress_bar.update(1)
- global_step += 1
-
- if accelerator.is_main_process:
- if global_step % args.checkpointing_steps == 0:
- # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
- if args.checkpoints_total_limit is not None:
- checkpoints = os.listdir(args.output_dir)
- checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
- checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
-
- # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
- if len(checkpoints) >= args.checkpoints_total_limit:
- num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
- removing_checkpoints = checkpoints[0:num_to_remove]
-
- logger.info(
- f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
- )
- logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}")
-
- for removing_checkpoint in removing_checkpoints:
- removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
- shutil.rmtree(removing_checkpoint)
-
- save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
- accelerator.save_state(save_path)
- logger.info(f"Saved state to {save_path}")
-
- logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
- progress_bar.set_postfix(**logs)
- accelerator.log(logs, step=global_step)
-
- if global_step >= args.max_train_steps:
- break
-
- if accelerator.is_main_process:
- if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
- # Create pipeline
- pipe = CogVideoXPipeline.from_pretrained(
- args.pretrained_model_name_or_path,
- transformer=unwrap_model(transformer),
- text_encoder=unwrap_model(text_encoder),
- vae=unwrap_model(vae),
- scheduler=scheduler,
- revision=args.revision,
- variant=args.variant,
- torch_dtype=weight_dtype,
- )
-
- validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
- for validation_prompt in validation_prompts:
- pipeline_args = {
- "prompt": validation_prompt,
- "guidance_scale": args.guidance_scale,
- "use_dynamic_cfg": args.use_dynamic_cfg,
- "height": args.height,
- "width": args.width,
- }
-
- validation_outputs = log_validation(
- pipe=pipe,
- args=args,
- accelerator=accelerator,
- pipeline_args=pipeline_args,
- epoch=epoch,
- )
-
- # Save the lora layers
- accelerator.wait_for_everyone()
- if accelerator.is_main_process:
- transformer = unwrap_model(transformer)
- dtype = (
- torch.float16
- if args.mixed_precision == "fp16"
- else torch.bfloat16
- if args.mixed_precision == "bf16"
- else torch.float32
- )
- transformer = transformer.to(dtype)
- transformer_lora_layers = get_peft_model_state_dict(transformer)
-
- CogVideoXPipeline.save_lora_weights(
- save_directory=args.output_dir,
- transformer_lora_layers=transformer_lora_layers,
- )
-
- # Final test inference
- pipe = CogVideoXPipeline.from_pretrained(
- args.pretrained_model_name_or_path,
- revision=args.revision,
- variant=args.variant,
- torch_dtype=weight_dtype,
- )
- pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
-
- if args.enable_slicing:
- pipe.vae.enable_slicing()
- if args.enable_tiling:
- pipe.vae.enable_tiling()
-
- # Load LoRA weights
- lora_scaling = args.lora_alpha / args.rank
- pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora")
- pipe.set_adapters(["cogvideox-lora"], [lora_scaling])
-
- # Run inference
- validation_outputs = []
- if args.validation_prompt and args.num_validation_videos > 0:
- validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
- for validation_prompt in validation_prompts:
- pipeline_args = {
- "prompt": validation_prompt,
- "guidance_scale": args.guidance_scale,
- "use_dynamic_cfg": args.use_dynamic_cfg,
- "height": args.height,
- "width": args.width,
- }
-
- video = log_validation(
- pipe=pipe,
- args=args,
- accelerator=accelerator,
- pipeline_args=pipeline_args,
- epoch=epoch,
- is_final_validation=True,
- )
- validation_outputs.extend(video)
-
- if args.push_to_hub:
- save_model_card(
- repo_id,
- videos=validation_outputs,
- base_model=args.pretrained_model_name_or_path,
- validation_prompt=args.validation_prompt,
- repo_folder=args.output_dir,
- fps=args.fps,
- )
- upload_folder(
- repo_id=repo_id,
- folder_path=args.output_dir,
- commit_message="End of training",
- ignore_patterns=["step_*", "epoch_*"],
- )
-
- accelerator.end_training()
-
-
-if __name__ == "__main__":
- args = get_args()
- main(args)
diff --git a/finetune/trainer.py b/finetune/trainer.py
new file mode 100644
index 0000000..6c9ec82
--- /dev/null
+++ b/finetune/trainer.py
@@ -0,0 +1,688 @@
+import os
+import logging
+import math
+import json
+
+import torch
+import transformers
+import diffusers
+import wandb
+
+from datetime import timedelta
+from pathlib import Path
+from tqdm import tqdm
+from typing import Dict, Any, List, Tuple
+from PIL import Image
+
+from torch.utils.data import Dataset, DataLoader
+from accelerate.logging import get_logger
+from accelerate.accelerator import Accelerator, DistributedType
+from accelerate.utils import (
+ DistributedDataParallelKwargs,
+ InitProcessGroupKwargs,
+ ProjectConfiguration,
+ set_seed,
+ gather_object,
+)
+
+from diffusers.optimization import get_scheduler
+from diffusers.utils.export_utils import export_to_video
+from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
+
+from finetune.schemas import Args, State, Components
+from finetune.utils import (
+ unwrap_model, cast_training_params,
+ get_optimizer,
+
+ get_memory_statistics,
+ free_memory,
+
+ get_latest_ckpt_path_to_resume_from,
+ get_intermediate_ckpt_path,
+ get_latest_ckpt_path_to_resume_from,
+ get_intermediate_ckpt_path,
+
+ string_to_filename
+)
+from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize
+from finetune.datasets.utils import (
+ load_prompts, load_images, load_videos,
+ preprocess_image_with_resize, preprocess_video_with_resize
+)
+
+from finetune.constants import LOG_NAME, LOG_LEVEL
+
+
+logger = get_logger(LOG_NAME, LOG_LEVEL)
+
+_DTYPE_MAP = {
+ "fp32": torch.float32,
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+}
+
+
+class Trainer:
+
+ def __init__(self, args: Args) -> None:
+ self.args = args
+ self.state = State(
+ weight_dtype=self.__get_training_dtype(),
+ train_frames=self.args.train_resolution[0],
+ train_height=self.args.train_resolution[1],
+ train_width=self.args.train_resolution[2]
+ )
+
+ self.components = Components()
+ self.accelerator: Accelerator = None
+ self.dataset: Dataset = None
+ self.data_loader: DataLoader = None
+
+ self.optimizer = None
+ self.lr_scheduler = None
+
+ self._init_distributed()
+ self._init_logging()
+ self._init_directories()
+
+ def _init_distributed(self):
+ logging_dir = Path(self.args.output_dir, "logs")
+ project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ init_process_group_kwargs = InitProcessGroupKwargs(
+ backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout)
+ )
+ mixed_precision = "no" if torch.backends.mps.is_available() else self.args.mixed_precision
+ report_to = None if self.args.report_to.lower() == "none" else self.args.report_to
+
+ accelerator = Accelerator(
+ project_config=project_config,
+ gradient_accumulation_steps=self.args.gradient_accumulation_steps,
+ mixed_precision=mixed_precision,
+ log_with=report_to,
+ kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ self.accelerator = accelerator
+
+ if self.args.seed is not None:
+ set_seed(self.args.seed)
+
+ def _init_logging(self) -> None:
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=LOG_LEVEL,
+ )
+ if self.accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ logger.info("Initialized Trainer")
+ logger.info(f"Accelerator state: \n{self.accelerator.state}", main_process_only=False)
+
+ def _init_directories(self) -> None:
+ if self.accelerator.is_main_process:
+ self.args.output_dir = Path(self.args.output_dir)
+ self.args.output_dir.mkdir(parents=True, exist_ok=True)
+
+ def prepare_models(self) -> None:
+ logger.info("Initializing models")
+
+ # Initialize model components
+ self.components = self.load_components()
+
+ if self.components.vae is not None:
+ if self.args.enable_slicing:
+ self.components.vae.enable_slicing()
+ if self.args.enable_tiling:
+ self.components.vae.enable_tiling()
+
+ self.state.transformer_config = self.components.transformer.config
+
+ def prepare_dataset(self) -> None:
+ logger.info("Initializing dataset and dataloader")
+
+ if self.args.model_type == "i2v":
+ self.dataset = I2VDatasetWithResize(
+ **(self.args.model_dump()),
+ device=self.accelerator.device,
+ encode_video_fn=self.encode_video,
+ max_num_frames=self.state.train_frames,
+ height=self.state.train_height,
+ width=self.state.train_width
+ )
+ elif self.args.model_type == "t2v":
+ self.dataset = T2VDatasetWithResize(
+ **(self.args.model_dump()),
+ device=self.accelerator.device,
+ encode_video_fn=self.encode_video,
+ max_num_frames=self.state.train_frames,
+ height=self.state.train_height,
+ width=self.state.train_width
+ )
+ else:
+ raise ValueError(f"Invalid model type: {self.args.model_type}")
+
+ # Prepare VAE for encoding
+ self.components.vae = self.components.vae.to(self.accelerator.device)
+ self.components.vae.requires_grad_(False)
+
+ # Precompute latent for video
+ logger.info("Precomputing latent for video ...")
+ tmp_data_loader = torch.utils.data.DataLoader(
+ self.dataset,
+ collate_fn=self.collate_fn,
+ batch_size=1,
+ num_workers=0,
+ pin_memory=self.args.pin_memory,
+ )
+ tmp_data_loader = self.accelerator.prepare_data_loader(tmp_data_loader)
+ for _ in tmp_data_loader: ...
+ logger.info("Precomputing latent for video ... Done")
+
+ self.data_loader = torch.utils.data.DataLoader(
+ self.dataset,
+ collate_fn=self.collate_fn,
+ batch_size=self.args.batch_size,
+ num_workers=self.args.num_workers,
+ pin_memory=self.args.pin_memory,
+ shuffle=True
+ )
+
+
+ def prepare_trainable_parameters(self):
+ logger.info("Initializing trainable parameters")
+
+ # For now only lora is supported
+ for attr_name, component in vars(self.components).items():
+ if hasattr(component, 'requires_grad_'):
+ component.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = self.state.weight_dtype
+
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ self.__move_components_to_device()
+
+ if self.args.gradient_checkpointing:
+ self.components.transformer.enable_gradient_checkpointing()
+
+ transformer_lora_config = LoraConfig(
+ r=self.args.rank,
+ lora_alpha=self.args.lora_alpha,
+ init_lora_weights=True,
+ target_modules=self.args.target_modules,
+ )
+ self.components.transformer.add_adapter(transformer_lora_config)
+ self.__prepare_saving_loading_hooks(transformer_lora_config)
+
+ def prepare_optimizer(self) -> None:
+ logger.info("Initializing optimizer and lr scheduler")
+
+ # Make sure the trainable params are in float32
+ if self.args.mixed_precision == "fp16":
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params([self.components.transformer], dtype=torch.float32)
+
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, self.components.transformer.parameters()))
+ transformer_parameters_with_lr = {
+ "params": transformer_lora_parameters,
+ "lr": self.args.learning_rate,
+ }
+ params_to_optimize = [transformer_parameters_with_lr]
+ self.state.num_trainable_parameters = sum(p.numel() for p in transformer_lora_parameters)
+
+ use_deepspeed_opt = (
+ self.accelerator.state.deepspeed_plugin is not None
+ and "optimizer" in self.accelerator.state.deepspeed_plugin.deepspeed_config
+ )
+ optimizer = get_optimizer(
+ params_to_optimize=params_to_optimize,
+ optimizer_name=self.args.optimizer,
+ learning_rate=self.args.learning_rate,
+ beta1=self.args.beta1,
+ beta2=self.args.beta2,
+ beta3=self.args.beta3,
+ epsilon=self.args.epsilon,
+ weight_decay=self.args.weight_decay,
+ use_deepspeed=use_deepspeed_opt,
+ )
+
+ num_update_steps_per_epoch = math.ceil(len(self.data_loader) / self.args.gradient_accumulation_steps)
+ if self.args.train_steps is None:
+ self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch
+ self.state.overwrote_max_train_steps = True
+
+ use_deepspeed_lr_scheduler = (
+ self.accelerator.state.deepspeed_plugin is not None
+ and "scheduler" in self.accelerator.state.deepspeed_plugin.deepspeed_config
+ )
+ total_training_steps = self.args.train_steps * self.accelerator.num_processes
+ num_warmup_steps = self.args.lr_warmup_steps * self.accelerator.num_processes
+
+ if use_deepspeed_lr_scheduler:
+ from accelerate.utils import DummyScheduler
+
+ lr_scheduler = DummyScheduler(
+ name=self.args.lr_scheduler,
+ optimizer=optimizer,
+ total_num_steps=total_training_steps,
+ num_warmup_steps=num_warmup_steps,
+ )
+ else:
+ lr_scheduler = get_scheduler(
+ name=self.args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=total_training_steps,
+ num_cycles=self.args.lr_num_cycles,
+ power=self.args.lr_power,
+ )
+
+ self.optimizer = optimizer
+ self.lr_scheduler = lr_scheduler
+
+ def prepare_for_training(self) -> None:
+ self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler = self.accelerator.prepare(
+ self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(self.data_loader) / self.args.gradient_accumulation_steps)
+ if self.state.overwrote_max_train_steps:
+ self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ self.args.train_epochs = math.ceil(self.args.train_steps / num_update_steps_per_epoch)
+ self.state.num_update_steps_per_epoch = num_update_steps_per_epoch
+
+ def prepare_for_validation(self):
+ validation_prompts = load_prompts(self.args.validation_dir / self.args.validation_prompts)
+
+ if self.args.validation_images is not None:
+ validation_images = load_images(self.args.validation_dir / self.args.validation_images)
+ else:
+ validation_images = [None] * len(validation_prompts)
+
+ if self.args.validation_videos is not None:
+ validation_videos = load_videos(self.args.validation_dir / self.args.validation_videos)
+ else:
+ validation_videos = [None] * len(validation_prompts)
+
+ self.state.validation_prompts = validation_prompts
+ self.state.validation_images = validation_images
+ self.state.validation_videos = validation_videos
+
+ def prepare_trackers(self) -> None:
+ logger.info("Initializing trackers")
+
+ tracker_name = self.args.tracker_name or "finetrainers-experiment"
+ self.accelerator.init_trackers(tracker_name, config=self.args.model_dump())
+
+ def train(self) -> None:
+ logger.info("Starting training")
+
+ memory_statistics = get_memory_statistics()
+ logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
+
+ self.state.total_batch_size_count = (
+ self.args.batch_size * self.accelerator.num_processes * self.args.gradient_accumulation_steps
+ )
+ info = {
+ "trainable parameters": self.state.num_trainable_parameters,
+ "total samples": len(self.dataset),
+ "train epochs": self.args.train_epochs,
+ "train steps": self.args.train_steps,
+ "batches per device": self.args.batch_size,
+ "total batches observed per epoch": len(self.data_loader),
+ "train batch size total count": self.state.total_batch_size_count,
+ "gradient accumulation steps": self.args.gradient_accumulation_steps,
+ }
+ logger.info(f"Training configuration: {json.dumps(info, indent=4)}")
+
+ global_step = 0
+ first_epoch = 0
+ initial_global_step = 0
+
+ # Potentially load in the weights and states from a previous save
+ (
+ resume_from_checkpoint_path,
+ initial_global_step,
+ global_step,
+ first_epoch,
+ ) = get_latest_ckpt_path_to_resume_from(
+ resume_from_checkpoint=self.args.resume_from_checkpoint,
+ num_update_steps_per_epoch=self.state.num_update_steps_per_epoch,
+ )
+ if resume_from_checkpoint_path is not None:
+ self.accelerator.load_state(resume_from_checkpoint_path)
+
+ progress_bar = tqdm(
+ range(0, self.args.train_steps),
+ initial=initial_global_step,
+ desc="Training steps",
+ disable=not self.accelerator.is_local_main_process,
+ )
+
+ accelerator = self.accelerator
+ generator = torch.Generator(device=accelerator.device)
+ if self.args.seed is not None:
+ generator = generator.manual_seed(self.args.seed)
+ self.state.generator = generator
+
+ for epoch in range(first_epoch, self.args.train_epochs):
+ logger.debug(f"Starting epoch ({epoch + 1}/{self.args.train_epochs})")
+
+ self.components.transformer.train()
+ models_to_accumulate = [self.components.transformer]
+
+ for step, batch in enumerate(self.data_loader):
+ logger.debug(f"Starting step {step + 1}")
+ logs = {}
+
+ with accelerator.accumulate(models_to_accumulate):
+ # These weighting schemes use a uniform timestep sampling and instead post-weight the loss
+ loss = self.compute_loss(batch)
+ accelerator.backward(loss)
+
+ if accelerator.sync_gradients:
+ if accelerator.distributed_type == DistributedType.DEEPSPEED:
+ grad_norm = self.components.transformer.get_global_grad_norm()
+ # In some cases the grad norm may not return a float
+ if torch.is_tensor(grad_norm):
+ grad_norm = grad_norm.item()
+ else:
+ grad_norm = accelerator.clip_grad_norm_(
+ self.components.transformer.parameters(), self.args.max_grad_norm
+ )
+ if torch.is_tensor(grad_norm):
+ grad_norm = grad_norm.item()
+
+ logs["grad_norm"] = grad_norm
+
+ self.optimizer.step()
+ self.lr_scheduler.step()
+ self.optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ self.__maybe_save_checkpoint(global_step)
+
+ # Maybe run validation
+ should_run_validation = (
+ self.args.do_validation
+ and global_step % self.args.validation_steps == 0
+ )
+ if should_run_validation:
+ self.validate(global_step)
+
+ logs["loss"] = loss.detach().item()
+ logs["lr"] = self.lr_scheduler.get_last_lr()[0]
+ progress_bar.set_postfix(logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= self.args.train_steps:
+ break
+
+ memory_statistics = get_memory_statistics()
+ logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}")
+
+ accelerator.wait_for_everyone()
+ self.__maybe_save_checkpoint(global_step, must_save=True)
+ if self.args.do_validation:
+ self.validate(global_step)
+
+ del self.components
+ free_memory()
+ memory_statistics = get_memory_statistics()
+ logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}")
+
+ accelerator.end_training()
+
+ def validate(self, step: int) -> None:
+ logger.info("Starting validation")
+
+ accelerator = self.accelerator
+ num_validation_samples = len(self.state.validation_prompts)
+
+ if num_validation_samples == 0:
+ logger.warning("No validation samples found. Skipping validation.")
+ return
+
+ self.components.transformer.eval()
+
+ memory_statistics = get_memory_statistics()
+ logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")
+
+ all_processes_artifacts = []
+ for i in range(num_validation_samples):
+ # Skip current validation on all processes but one
+ if i % accelerator.num_processes != accelerator.process_index:
+ continue
+
+ prompt = self.state.validation_prompts[i]
+ image = self.state.validation_images[i]
+ video = self.state.validation_videos[i]
+
+ if image is not None:
+ image = preprocess_image_with_resize(
+ image, self.state.train_height, self.state.train_width
+ )
+ # Convert image tensor (C, H, W) to PIL images
+ image = image.to(torch.uint8)
+ image = image.permute(1, 2, 0).cpu().numpy()
+ image = Image.fromarray(image)
+
+ if video is not None:
+ video = preprocess_video_with_resize(
+ video, self.state.train_frames, self.state.train_height, self.state.train_width
+ )
+ # Convert video tensor (F, C, H, W) to list of PIL images
+ video = (video * 255).round().clamp(0, 255).to(torch.uint8)
+ video = [Image.fromarray(frame.permute(1,2,0).cpu().numpy()) for frame in video]
+
+ logger.debug(
+ f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
+ main_process_only=False,
+ )
+ validation_artifacts = self.validation_step({
+ "prompt": prompt,
+ "image": image,
+ "video": video
+ })
+ prompt_filename = string_to_filename(prompt)[:25]
+ artifacts = {
+ "image": {"type": "image", "value": image},
+ "video": {"type": "video", "value": video},
+ }
+ for i, (artifact_type, artifact_value) in enumerate(validation_artifacts):
+ artifacts.update({f"artifact_{i}": {"type": artifact_type, "value": artifact_value}})
+ logger.debug(
+ f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
+ main_process_only=False,
+ )
+
+ for key, value in list(artifacts.items()):
+ artifact_type = value["type"]
+ artifact_value = value["value"]
+ if artifact_type not in ["image", "video"] or artifact_value is None:
+ continue
+
+ extension = "png" if artifact_type == "image" else "mp4"
+ filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}.{extension}"
+ validation_path = self.args.output_dir / "validation_res"
+ validation_path.mkdir(parents=True, exist_ok=True)
+ filename = str(validation_path / filename)
+
+ if artifact_type == "image":
+ logger.debug(f"Saving image to {filename}")
+ artifact_value.save(filename)
+ artifact_value = wandb.Image(filename)
+ elif artifact_type == "video":
+ logger.debug(f"Saving video to {filename}")
+ export_to_video(artifact_value, filename, fps=self.args.gen_fps)
+ artifact_value = wandb.Video(filename, caption=prompt)
+
+ all_processes_artifacts.append(artifact_value)
+
+ all_artifacts = gather_object(all_processes_artifacts)
+
+ if accelerator.is_main_process:
+ tracker_key = "validation"
+ for tracker in accelerator.trackers:
+ if tracker.name == "wandb":
+ image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
+ video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
+ tracker.log(
+ {
+ tracker_key: {"images": image_artifacts, "videos": video_artifacts},
+ },
+ step=step,
+ )
+
+ accelerator.wait_for_everyone()
+
+ free_memory()
+ memory_statistics = get_memory_statistics()
+ logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
+ torch.cuda.reset_peak_memory_stats(accelerator.device)
+
+ self.components.transformer.train()
+
+ def fit(self):
+ self.prepare_models()
+ self.prepare_dataset()
+ self.prepare_trainable_parameters()
+ self.prepare_optimizer()
+ self.prepare_for_training()
+ if self.args.do_validation:
+ self.prepare_for_validation()
+ self.prepare_trackers()
+ self.train()
+
+ def collate_fn(self, examples: List[Dict[str, Any]]):
+ raise NotImplementedError
+
+ def load_components(self) -> Components:
+ raise NotImplementedError
+
+ def encode_video(self, video: torch.Tensor) -> torch.Tensor:
+ # shape of input video: [B, C, F, H, W], where B = 1
+ raise NotImplementedError
+
+ def compute_loss(self, batch) -> torch.Tensor:
+ raise NotImplementedError
+
+ def validation_step(self) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
+ raise NotImplementedError
+
+ def __get_training_dtype(self) -> torch.dtype:
+ if self.args.mixed_precision == "no":
+ return _DTYPE_MAP["fp32"]
+ elif self.args.mixed_precision == "fp16":
+ return _DTYPE_MAP["fp16"]
+ elif self.args.mixed_precision == "bf16":
+ return _DTYPE_MAP["bf16"]
+ else:
+ raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")
+
+ def __move_components_to_device(self):
+ components = self.components.model_dump()
+ for name, component in components.items():
+ if not isinstance(component, type) and hasattr(component, 'to'):
+ setattr(self.components, name, component.to(self.accelerator.device))
+
+ def __prepare_saving_loading_hooks(self, transformer_lora_config):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if self.accelerator.is_main_process:
+ transformer_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(
+ unwrap_model(self.accelerator, model),
+ type(unwrap_model(self.accelerator, self.components.transformer)),
+ ):
+ model = unwrap_model(self.accelerator, model)
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+ else:
+ raise ValueError(f"Unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ if weights:
+ weights.pop()
+
+ self.components.pipeline_cls.save_lora_weights(
+ output_dir,
+ transformer_lora_layers=transformer_lora_layers_to_save,
+ )
+
+ def load_model_hook(models, input_dir):
+ if not self.accelerator.distributed_type == DistributedType.DEEPSPEED:
+ while len(models) > 0:
+ model = models.pop()
+ if isinstance(
+ unwrap_model(self.accelerator, model),
+ type(unwrap_model(self.accelerator, self.components.transformer)),
+ ):
+ transformer_ = unwrap_model(self.accelerator, model)
+ else:
+ raise ValueError(
+ f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}"
+ )
+ else:
+ transformer_ = unwrap_model(self.accelerator, self.components.transformer).__class__.from_pretrained(
+ self.args.model_path, subfolder="transformer"
+ )
+ transformer_.add_adapter(transformer_lora_config)
+
+ lora_state_dict = self.components.pipeline_cls.lora_state_dict(input_dir)
+ transformer_state_dict = {
+ f'{k.replace("transformer.", "")}': v
+ for k, v in lora_state_dict.items()
+ if k.startswith("transformer.")
+ }
+ incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if self.args.mixed_precision == "fp16":
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params([transformer_])
+
+ self.accelerator.register_save_state_pre_hook(save_model_hook)
+ self.accelerator.register_load_state_pre_hook(load_model_hook)
+
+ def __maybe_save_checkpoint(self, global_step: int, must_save: bool = False):
+ if self.accelerator.distributed_type == DistributedType.DEEPSPEED or self.accelerator.is_main_process:
+ if must_save or global_step % self.args.checkpointing_steps == 0:
+ save_path = get_intermediate_ckpt_path(
+ checkpointing_limit=self.args.checkpointing_limit,
+ step=global_step,
+ output_dir=self.args.output_dir,
+ )
+ self.accelerator.save_state(save_path)
diff --git a/finetune/utils/__init__.py b/finetune/utils/__init__.py
new file mode 100644
index 0000000..7becfd5
--- /dev/null
+++ b/finetune/utils/__init__.py
@@ -0,0 +1,5 @@
+from .torch_utils import *
+from .optimizer_utils import *
+from .memory_utils import *
+from .checkpointing import *
+from .file_utils import *
diff --git a/finetune/utils/checkpointing.py b/finetune/utils/checkpointing.py
new file mode 100644
index 0000000..1797153
--- /dev/null
+++ b/finetune/utils/checkpointing.py
@@ -0,0 +1,53 @@
+import os
+from pathlib import Path
+from typing import Tuple
+from accelerate.logging import get_logger
+
+from finetune.constants import LOG_NAME, LOG_LEVEL
+from ..utils.file_utils import find_files, delete_files
+
+
+logger = get_logger(LOG_NAME, LOG_LEVEL)
+
+
+def get_latest_ckpt_path_to_resume_from(
+ resume_from_checkpoint: str | None, num_update_steps_per_epoch: int
+) -> Tuple[str | None, int, int, int]:
+ if resume_from_checkpoint is None:
+ initial_global_step = 0
+ global_step = 0
+ first_epoch = 0
+ resume_from_checkpoint_path = None
+ else:
+ resume_from_checkpoint_path = Path(resume_from_checkpoint)
+ if not resume_from_checkpoint_path.exists():
+ logger.info(f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run.")
+ initial_global_step = 0
+ global_step = 0
+ first_epoch = 0
+ resume_from_checkpoint_path = None
+ else:
+ logger.info(f"Resuming from checkpoint {resume_from_checkpoint}")
+ global_step = int(resume_from_checkpoint_path.name.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ return resume_from_checkpoint_path, initial_global_step, global_step, first_epoch
+
+
+def get_intermediate_ckpt_path(checkpointing_limit: int, step: int, output_dir: str) -> str:
+ # before saving state, check if this save would set us over the `checkpointing_limit`
+ if checkpointing_limit is not None:
+ checkpoints = find_files(output_dir, prefix="checkpoint")
+
+ # before we save the new checkpoint, we need to have at_most `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= checkpointing_limit:
+ num_to_remove = len(checkpoints) - checkpointing_limit + 1
+ checkpoints_to_remove = checkpoints[0:num_to_remove]
+ delete_files(checkpoints_to_remove)
+
+ logger.info(f"Checkpointing at step {step}")
+ save_path = os.path.join(output_dir, f"checkpoint-{step}")
+ logger.info(f"Saving state to {save_path}")
+ return save_path
diff --git a/finetune/utils/file_utils.py b/finetune/utils/file_utils.py
new file mode 100644
index 0000000..f04dd85
--- /dev/null
+++ b/finetune/utils/file_utils.py
@@ -0,0 +1,47 @@
+import logging
+import os
+import shutil
+
+from pathlib import Path
+from typing import Any, Dict, List, Union
+from accelerate.logging import get_logger
+from finetune.constants import LOG_NAME, LOG_LEVEL
+
+
+logger = get_logger(LOG_NAME, LOG_LEVEL)
+
+
+def find_files(dir: Union[str, Path], prefix: str = "checkpoint") -> List[str]:
+ if not isinstance(dir, Path):
+ dir = Path(dir)
+ if not dir.exists():
+ return []
+ checkpoints = os.listdir(dir.as_posix())
+ checkpoints = [c for c in checkpoints if c.startswith(prefix)]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+ checkpoints = [dir / c for c in checkpoints]
+ return checkpoints
+
+
+def delete_files(dirs: Union[str, List[str], Path, List[Path]]) -> None:
+ if not isinstance(dirs, list):
+ dirs = [dirs]
+ dirs = [Path(d) if isinstance(d, str) else d for d in dirs]
+ logger.info(f"Deleting files: {dirs}")
+ for dir in dirs:
+ if not dir.exists():
+ continue
+ shutil.rmtree(dir, ignore_errors=True)
+
+
+def string_to_filename(s: str) -> str:
+ return (
+ s.replace(" ", "-")
+ .replace("/", "-")
+ .replace(":", "-")
+ .replace(".", "-")
+ .replace(",", "-")
+ .replace(";", "-")
+ .replace("!", "-")
+ .replace("?", "-")
+ )
diff --git a/finetune/utils/memory_utils.py b/finetune/utils/memory_utils.py
new file mode 100644
index 0000000..a7a136b
--- /dev/null
+++ b/finetune/utils/memory_utils.py
@@ -0,0 +1,60 @@
+import gc
+import torch
+
+from typing import Any, Dict, Union
+from accelerate.logging import get_logger
+
+from finetune.constants import LOG_NAME, LOG_LEVEL
+
+
+logger = get_logger(LOG_NAME, LOG_LEVEL)
+
+
+def get_memory_statistics(precision: int = 3) -> Dict[str, Any]:
+ memory_allocated = None
+ memory_reserved = None
+ max_memory_allocated = None
+ max_memory_reserved = None
+
+ if torch.cuda.is_available():
+ device = torch.cuda.current_device()
+ memory_allocated = torch.cuda.memory_allocated(device)
+ memory_reserved = torch.cuda.memory_reserved(device)
+ max_memory_allocated = torch.cuda.max_memory_allocated(device)
+ max_memory_reserved = torch.cuda.max_memory_reserved(device)
+
+ elif torch.mps.is_available():
+ memory_allocated = torch.mps.current_allocated_memory()
+
+ else:
+ logger.warning("No CUDA, MPS, or ROCm device found. Memory statistics are not available.")
+
+ return {
+ "memory_allocated": round(bytes_to_gigabytes(memory_allocated), ndigits=precision),
+ "memory_reserved": round(bytes_to_gigabytes(memory_reserved), ndigits=precision),
+ "max_memory_allocated": round(bytes_to_gigabytes(max_memory_allocated), ndigits=precision),
+ "max_memory_reserved": round(bytes_to_gigabytes(max_memory_reserved), ndigits=precision),
+ }
+
+
+def bytes_to_gigabytes(x: int) -> float:
+ if x is not None:
+ return x / 1024**3
+
+
+def free_memory() -> None:
+ if torch.cuda.is_available():
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
+
+ # TODO(aryan): handle non-cuda devices
+
+
+def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ if isinstance(x, torch.Tensor):
+ return x.contiguous()
+ elif isinstance(x, dict):
+ return {k: make_contiguous(v) for k, v in x.items()}
+ else:
+ return x
diff --git a/finetune/utils/optimizer_utils.py b/finetune/utils/optimizer_utils.py
new file mode 100644
index 0000000..bd93f9c
--- /dev/null
+++ b/finetune/utils/optimizer_utils.py
@@ -0,0 +1,180 @@
+import inspect
+import torch
+
+from accelerate.logging import get_logger
+
+from finetune.constants import LOG_NAME, LOG_LEVEL
+
+
+logger = get_logger(LOG_NAME, LOG_LEVEL)
+
+
+def get_optimizer(
+ params_to_optimize,
+ optimizer_name: str = "adam",
+ learning_rate: float = 1e-3,
+ beta1: float = 0.9,
+ beta2: float = 0.95,
+ beta3: float = 0.98,
+ epsilon: float = 1e-8,
+ weight_decay: float = 1e-4,
+ prodigy_decouple: bool = False,
+ prodigy_use_bias_correction: bool = False,
+ prodigy_safeguard_warmup: bool = False,
+ use_8bit: bool = False,
+ use_4bit: bool = False,
+ use_torchao: bool = False,
+ use_deepspeed: bool = False,
+ use_cpu_offload_optimizer: bool = False,
+ offload_gradients: bool = False,
+) -> torch.optim.Optimizer:
+ optimizer_name = optimizer_name.lower()
+
+ # Use DeepSpeed optimzer
+ if use_deepspeed:
+ from accelerate.utils import DummyOptim
+
+ return DummyOptim(
+ params_to_optimize,
+ lr=learning_rate,
+ betas=(beta1, beta2),
+ eps=epsilon,
+ weight_decay=weight_decay,
+ )
+
+ if use_8bit and use_4bit:
+ raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.")
+
+ if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer:
+ try:
+ import torchao
+
+ torchao.__version__
+ except ImportError:
+ raise ImportError(
+ "To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`."
+ )
+
+ if not use_torchao and use_4bit:
+ raise ValueError("4-bit Optimizers are only supported with torchao.")
+
+ # Optimizer creation
+ supported_optimizers = ["adam", "adamw", "prodigy", "came"]
+ if optimizer_name not in supported_optimizers:
+ logger.warning(
+ f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`."
+ )
+ optimizer_name = "adamw"
+
+ if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]:
+ raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.")
+
+ if use_8bit:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ if optimizer_name == "adamw":
+ if use_torchao:
+ from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit
+
+ optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
+ else:
+ optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW
+
+ init_kwargs = {
+ "betas": (beta1, beta2),
+ "eps": epsilon,
+ "weight_decay": weight_decay,
+ }
+
+ elif optimizer_name == "adam":
+ if use_torchao:
+ from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit
+
+ optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam
+ else:
+ optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam
+
+ init_kwargs = {
+ "betas": (beta1, beta2),
+ "eps": epsilon,
+ "weight_decay": weight_decay,
+ }
+
+ elif optimizer_name == "prodigy":
+ try:
+ import prodigyopt
+ except ImportError:
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+ optimizer_class = prodigyopt.Prodigy
+
+ if learning_rate <= 0.1:
+ logger.warning(
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+ )
+
+ init_kwargs = {
+ "lr": learning_rate,
+ "betas": (beta1, beta2),
+ "beta3": beta3,
+ "eps": epsilon,
+ "weight_decay": weight_decay,
+ "decouple": prodigy_decouple,
+ "use_bias_correction": prodigy_use_bias_correction,
+ "safeguard_warmup": prodigy_safeguard_warmup,
+ }
+
+ elif optimizer_name == "came":
+ try:
+ import came_pytorch
+ except ImportError:
+ raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`")
+
+ optimizer_class = came_pytorch.CAME
+
+ init_kwargs = {
+ "lr": learning_rate,
+ "eps": (1e-30, 1e-16),
+ "betas": (beta1, beta2, beta3),
+ "weight_decay": weight_decay,
+ }
+
+ if use_cpu_offload_optimizer:
+ from torchao.prototype.low_bit_optim import CPUOffloadOptimizer
+
+ if "fused" in inspect.signature(optimizer_class.__init__).parameters:
+ init_kwargs.update({"fused": True})
+
+ optimizer = CPUOffloadOptimizer(
+ params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs
+ )
+ else:
+ optimizer = optimizer_class(params_to_optimize, **init_kwargs)
+
+ return optimizer
+
+
+def gradient_norm(parameters):
+ norm = 0
+ for param in parameters:
+ if param.grad is None:
+ continue
+ local_norm = param.grad.detach().data.norm(2)
+ norm += local_norm.item() ** 2
+ norm = norm**0.5
+ return norm
+
+
+def max_gradient(parameters):
+ max_grad_value = float("-inf")
+ for param in parameters:
+ if param.grad is None:
+ continue
+ local_max_grad = param.grad.detach().data.abs().max()
+ max_grad_value = max(max_grad_value, local_max_grad.item())
+ return max_grad_value
diff --git a/finetune/utils/torch_utils.py b/finetune/utils/torch_utils.py
new file mode 100644
index 0000000..8a74271
--- /dev/null
+++ b/finetune/utils/torch_utils.py
@@ -0,0 +1,52 @@
+from typing import Dict, Optional, Union, List
+
+import torch
+from accelerate import Accelerator
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+def unwrap_model(accelerator: Accelerator, model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+
+def align_device_and_dtype(
+ x: Union[torch.Tensor, Dict[str, torch.Tensor]],
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+):
+ if isinstance(x, torch.Tensor):
+ if device is not None:
+ x = x.to(device)
+ if dtype is not None:
+ x = x.to(dtype)
+ elif isinstance(x, dict):
+ if device is not None:
+ x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
+ if dtype is not None:
+ x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
+ return x
+
+
+def expand_tensor_to_dims(tensor, ndim):
+ while len(tensor.shape) < ndim:
+ tensor = tensor.unsqueeze(-1)
+ return tensor
+
+
+def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
+ """
+ Casts the training parameters of the model to the specified data type.
+
+ Args:
+ model: The PyTorch model whose parameters will be cast.
+ dtype: The data type to which the model parameters will be cast.
+ """
+ if not isinstance(model, list):
+ model = [model]
+ for m in model:
+ for param in m.parameters():
+ # only upcast trainable parameters into fp32
+ if param.requires_grad:
+ param.data = param.to(dtype)
\ No newline at end of file