finetune and infer upload

zR 2024-09-16 12:02:27 +08:00
parent 098640337d
commit 6e64359524
11 changed files with 2107 additions and 19 deletions

finetune/README.md Normal file

@ -0,0 +1,192 @@
# CogVideoX diffusers Fine-tuning Guide
If you want to fine-tune the SAT version instead, please see [here](../sat/README.md); note that its dataset format
differs from this version's.
This tutorial shows how to quickly fine-tune the diffusers version of the CogVideoX model.
### Hardware Requirements
+ CogVideoX-2B LoRA: 1 * A100
+ CogVideoX-2B SFT: 8 * A100
+ CogVideoX-5B/5B-I2V: not yet supported
### Prepare the Dataset
First, you need to prepare the dataset. The format of the dataset is as follows, where `videos.txt` contains paths to
the videos in the `videos` directory.
```
.
├── prompts.txt
├── videos
└── videos.txt
```
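Below is a hypothetical example of the two text files (file names and captions are made up); line N of `prompts.txt` is assumed to be the caption for the video path on line N of `videos.txt`.
```
# videos.txt
videos/00001.mp4
videos/00002.mp4

# prompts.txt
A cartoon mouse walks along a sunny beach, waving at the camera.
A cartoon mouse and a bear share a picnic in a forest clearing.
```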
### Configuration Files and Execution
`accelerate` configuration files are as follows:
+ `accelerate_config_machine_multi.yaml` for multi-machine (multi-node) training
+ `accelerate_config_machine_single.yaml` for single-machine training
The `finetune` script configuration is as follows:
```shell
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# This command sets PyTorch's CUDA memory allocation strategy to segment-based memory management to prevent OOM (Out of Memory) errors.
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \
# Use Accelerate to start training, specifying the `accelerate_config_machine_single.yaml` configuration file, and using multiple GPUs.
train_cogvideox_lora.py \
# This is the training script you will execute for LoRA fine-tuning of the CogVideoX model.
--pretrained_model_name_or_path THUDM/CogVideoX-2b \
# The path to the pretrained model, pointing to the CogVideoX-2b model you want to fine-tune.
--cache_dir ~/.cache \
# The directory where downloaded models and datasets will be stored.
--enable_tiling \
# Enable VAE tiling functionality, which reduces memory usage by processing smaller blocks of the image.
--enable_slicing \
# Enable VAE slicing functionality, which slices the image across channels to save memory.
--instance_data_root ~/disney/ \
# The root directory of the instance data, the folder of the dataset used during training.
--caption_column prompts.txt \
# Specifies the column or file containing instance prompts (text descriptions), in this case, the `prompts.txt` file.
--video_column videos.txt \
# Specifies the column or file containing paths to videos, in this case, the `videos.txt` file.
--validation_prompt "Mickey with the captain and friends:::Mickey and the bear" \
# The prompt(s) used for validation, multiple prompts should be separated by the specified delimiter (`:::`).
--validation_prompt_separator ::: \
# The delimiter for validation prompts, set here as `:::`.
--num_validation_videos 1 \
# The number of videos to be generated during validation, set to 1.
--validation_epochs 2 \
# How many epochs to run validation, set to validate every 2 epochs.
--seed 3407 \
# Sets the random seed for reproducible training, set to 3407.
--rank 128 \
# The dimension of the LoRA update matrices, controlling the size of the LoRA layer parameters, set to 128.
--mixed_precision bf16 \
# Use mixed precision training, set to `bf16` (bfloat16), which can reduce memory usage and speed up training.
--output_dir cogvideox-lora-single-gpu \
# Output directory, where model predictions and checkpoints will be stored.
--height 480 \
# The height of input videos, all videos will be resized to 480 pixels.
--width 720 \
# The width of input videos, all videos will be resized to 720 pixels.
--fps 8 \
# The frame rate of input videos, all videos will be processed at 8 frames per second.
--max_num_frames 49 \
# The maximum number of frames for input videos, videos will be truncated to a maximum of 49 frames.
--skip_frames_start 0 \
# The number of frames to skip at the beginning of each video, set to 0, indicating no frames are skipped.
--skip_frames_end 0 \
# The number of frames to skip at the end of each video, set to 0, indicating no frames are skipped.
--train_batch_size 1 \
# The batch size for training, set to 1 per device.
--num_train_epochs 10 \
# The total number of epochs for training, set to 10.
--checkpointing_steps 500 \
# Save a checkpoint every 500 steps.
--gradient_accumulation_steps 1 \
# The number of gradient accumulation steps, indicating that a gradient update is performed every 1 step.
--learning_rate 1e-4 \
# The initial learning rate, set to 1e-4.
--optimizer AdamW \
# The type of optimizer, choosing AdamW.
--adam_beta1 0.9 \
# The beta1 parameter for the Adam optimizer, set to 0.9.
--adam_beta2 0.95 \
# The beta2 parameter for the Adam optimizer, set to 0.95.
```
### Run the script to start fine-tuning
Single GPU fine-tuning:
```shell
bash finetune_single_gpu.sh
```
Multi-GPU fine-tuning:
```shell
bash finetune_multi_gpus_1.sh # needs to be run on each node
```
### Best Practices
+ We used 70 training videos with a resolution of `200 x 480 x 720` (frames x height x width). By skipping frames during
  data preprocessing, we created two smaller datasets of 49 and 16 frames to speed up experimentation, since the
  CogVideoX team recommends a maximum of 49 frames. We divided the 70 videos into three groups of 10, 25, and 50
  conceptually similar videos.
+ 25 or more videos work best when training new concepts and styles.
+ Training results improve when an identifier token, specified via `--id_token`, is used. This is similar to Dreambooth
  training, but regular fine-tuning without the token also works.
+ The original repository uses a `lora_alpha` of 1. We found this value to be ineffective across multiple runs, likely
  due to differences in the model backend and training setup. Our recommendation is to set `lora_alpha` equal to `rank`
  or `rank // 2` (see the sketch after this list).
+ Using settings with a rank of 64 or above is recommended.
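To see why `lora_alpha = 1` under-weights the adapter, note that LoRA updates are conventionally scaled by `lora_alpha / rank` before being added to the base weights (the exact rule can vary by backend). A minimal illustrative sketch, not part of the training script:
```python
# Illustrative only: LoRA output is usually base(x) + (lora_alpha / rank) * B(A(x)).
rank = 128

for lora_alpha in (1, rank // 2, rank):
    scaling = lora_alpha / rank
    print(f"lora_alpha={lora_alpha:<3} -> update scaled by {scaling}")

# lora_alpha=1   -> update scaled by 0.0078125 (the adapter barely contributes)
# lora_alpha=64  -> update scaled by 0.5
# lora_alpha=128 -> update scaled by 1.0
```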

finetune/README_zh.md Normal file

@ -0,0 +1,157 @@
# CogVideoX diffusers Fine-tuning Guide
If you want to fine-tune the SAT version instead, please see [here](../sat/README_zh.md); note that its dataset format
differs from this version's.
This tutorial shows how to quickly fine-tune the diffusers version of the CogVideoX model.
### Hardware Requirements
+ CogVideoX-2B LoRA: 1 * A100
+ CogVideoX-2B SFT: 8 * A100
+ CogVideoX-5B/5B-I2V: not yet supported
### Prepare the Dataset
First, you need to prepare the dataset. The dataset format is as follows, where `videos.txt` lists the videos in the
`videos` directory.
```
.
├── prompts.txt
├── videos
└── videos.txt
```
### Configuration Files and Execution
The `accelerate` configuration files are as follows:
+ `accelerate_config_machine_multi.yaml` for multi-machine (multi-node) training
+ `accelerate_config_machine_single.yaml` for single-machine training
The `finetune` script configuration is as follows:
```shell
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# Sets PyTorch's CUDA memory allocation strategy to expandable segments to help prevent OOM (Out of Memory) errors.
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \
# Launch training with Accelerate, using the `accelerate_config_machine_single.yaml` configuration file and multiple GPUs.
train_cogvideox_lora.py \
# The training script to run for LoRA fine-tuning of the CogVideoX model.
--pretrained_model_name_or_path THUDM/CogVideoX-2b \
# Path to the pretrained model, pointing to the CogVideoX-2b model you want to fine-tune.
--cache_dir ~/.cache \
# Directory used to cache models and datasets downloaded from Hugging Face.
--enable_tiling \
# Enable VAE tiling, which reduces GPU memory usage by processing the image in smaller tiles.
--enable_slicing \
# Enable VAE slicing, which slices the image across channels to save GPU memory.
--instance_data_root ~/disney/ \
# Root directory of the instance data, i.e. the dataset folder used for training.
--caption_column prompts.txt \
# Column or file containing the instance prompts (text descriptions); here, the `prompts.txt` file.
--video_column videos.txt \
# Column or file containing the video paths; here, the `videos.txt` file.
--validation_prompt "Mickey with the captain and friends:::Mickey and the bear" \
# Prompts used for validation; multiple prompts are separated by the specified delimiter (e.g. `:::`).
--validation_prompt_separator ::: \
# Delimiter for validation prompts, set here to `:::`.
--num_validation_videos 1 \
# Number of videos generated during validation, set to 1.
--validation_epochs 2 \
# Run validation every 2 epochs.
--seed 3407 \
# Random seed for reproducible training, set to 3407.
--rank 128 \
# Dimension of the LoRA update matrices, controlling the size of the LoRA layer parameters, set to 128.
--mixed_precision bf16 \
# Use mixed-precision training with `bf16` (bfloat16), which reduces GPU memory usage and speeds up training.
--output_dir cogvideox-lora-single-gpu \
# Output directory for model predictions and checkpoints.
--height 480 \
# Height of the input videos; all videos are resized to 480 pixels.
--width 720 \
# Width of the input videos; all videos are resized to 720 pixels.
--fps 8 \
# Frame rate of the input videos; all videos are processed at 8 frames per second.
--max_num_frames 49 \
# Maximum number of frames per input video; videos are truncated to at most 49 frames.
--skip_frames_start 0 \
# Number of frames to skip at the start of each video; 0 means no frames are skipped.
--skip_frames_end 0 \
# Number of frames to skip at the end of each video; 0 means no frames are skipped.
--train_batch_size 1 \
# Training batch size, set to 1 per device.
--num_train_epochs 10 \
# Total number of training epochs, set to 10.
--checkpointing_steps 500 \
# Save a checkpoint every 500 steps.
--gradient_accumulation_steps 1 \
# Number of gradient accumulation steps; a gradient update is performed every step.
--learning_rate 1e-4 \
# Initial learning rate, set to 1e-4.
--optimizer AdamW \
# Optimizer type, using AdamW.
--adam_beta1 0.9 \
# beta1 parameter of the Adam optimizer, set to 0.9.
--adam_beta2 0.95 \
# beta2 parameter of the Adam optimizer, set to 0.95.
```
### Run the Script to Start Fine-tuning
Single-GPU fine-tuning:
```shell
bash finetune_single_gpu.sh
```
Multi-GPU fine-tuning:
```shell
bash finetune_multi_gpus_1.sh # needs to be run on each node
```
### Best Practices
+ We used 70 training videos with a resolution of `200 x 480 x 720` (frames x height x width). By skipping frames during
  data preprocessing, we created two smaller datasets of 49 and 16 frames to speed up experimentation, since the
  CogVideoX team recommends a maximum of 49 frames. We divided the 70 videos into three groups of 10, 25, and 50
  conceptually similar videos.
+ 25 or more videos work best when training new concepts and styles.
+ Training results improve when an identifier token, specified via `--id_token`, is used. This is similar to Dreambooth
  training, but regular fine-tuning without the token also works.
+ The original repository sets `lora_alpha` to 1. We found this value to be ineffective across multiple runs, likely due
  to differences in the model backend and training setup. Our recommendation is to set `lora_alpha` equal to `rank` or
  `rank // 2`.
+ Using a rank of 64 or above is recommended.

finetune/accelerate_config_machine_multi.yaml Normal file

@ -0,0 +1,26 @@
compute_environment: LOCAL_MACHINE
debug: true
deepspeed_config:
  deepspeed_hostfile: hostfile.txt
  deepspeed_multinode_launcher: pdsh
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'yes'
enable_cpu_affinity: true
main_process_ip: 10.250.128.19 # replace with the IP address of your own main node
main_process_port: 12355
main_training_function: main
mixed_precision: bf16
num_machines: 4
num_processes: 32
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

finetune/accelerate_config_machine_single.yaml Normal file

@ -0,0 +1,24 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
dynamo_backend: 'no'
mixed_precision: 'no'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

finetune/finetune_multi_gpus_1.sh Normal file

@ -0,0 +1,39 @@
#!/bin/bash
export MODEL_PATH="THUDM/CogVideoX-2b"
export CACHE_PATH="~/.cache"
export DATASET_PATH="disney"
export OUTPUT_PATH="cogvideox-lora-multi-gpu"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu --machine_rank 0 \
train_cogvideox_lora.py \
--pretrained_model_name_or_path $MODEL_PATH \
--cache_dir $CACHE_PATH \
--enable_tiling \
--enable_slicing \
--instance_data_root $DATASET_PATH \
--caption_column prompts.txt \
--video_column videos.txt \
--validation_prompt "Mickey with the captain and friends:::Mickey and the bear" \
--validation_prompt_separator ::: \
--num_validation_videos 1 \
--validation_epochs 2 \
--seed 3407 \
--rank 128 \
--mixed_precision bf16 \
--output_dir $OUTPUT_PATH \
--height 480 \
--width 720 \
--fps 8 \
--max_num_frames 49 \
--skip_frames_start 0 \
--skip_frames_end 0 \
--train_batch_size 1 \
--num_train_epochs 10 \
--checkpointing_steps 500 \
--gradient_accumulation_steps 1 \
--learning_rate 1e-4 \
--optimizer AdamW \
--adam_beta1 0.9 \
--adam_beta2 0.95

finetune/finetune_single_gpu.sh Normal file

@ -0,0 +1,39 @@
#!/bin/bash
export MODEL_PATH="THUDM/CogVideoX-2b"
export CACHE_PATH="~/.cache"
export DATASET_PATH="disney"
export OUTPUT_PATH="cogvideox-lora-single-gpu"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \
train_cogvideox_lora.py \
--pretrained_model_name_or_path $MODEL_PATH \
--cache_dir $CACHE_PATH \
--enable_tiling \
--enable_slicing \
--instance_data_root $DATASET_PATH \
--caption_column prompts.txt \
--video_column videos.txt \
--validation_prompt "Mickey with the captain and friends:::Mickey and the bear" \
--validation_prompt_separator ::: \
--num_validation_videos 1 \
--validation_epochs 2 \
--seed 3407 \
--rank 128 \
--mixed_precision bf16 \
--output_dir $OUTPUT_PATH \
--height 480 \
--width 720 \
--fps 8 \
--max_num_frames 49 \
--skip_frames_start 0 \
--skip_frames_end 0 \
--train_batch_size 1 \
--num_train_epochs 10 \
--checkpointing_steps 500 \
--gradient_accumulation_steps 1 \
--learning_rate 1e-4 \
--optimizer AdamW \
--adam_beta1 0.9 \
--adam_beta2 0.95

finetune/hostfile.txt Normal file

@ -0,0 +1,4 @@
node1 slots=8
node2 slots=8
node3 slots=8
node4 slots=8

finetune/train_cogvideox_lora.py Normal file
File diff suppressed because it is too large

inference/cli_demo.py

@ -77,7 +77,9 @@ def generate_video(
# 2. Set Scheduler.
# Can be changed to `CogVideoXDPMScheduler` or `CogVideoXDDIMScheduler`.
# We recommend using `CogVideoXDDIMScheduler` for CogVideoX-2B and `CogVideoXDPMScheduler` for CogVideoX-5B.
# We recommend using `CogVideoXDDIMScheduler` for CogVideoX-2B.
# and using `CogVideoXDPMScheduler` for CogVideoX-5B / CogVideoX-5B-I2V.
# pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
@ -85,14 +87,16 @@ def generate_video(
# Turn this off if you have multiple GPUs or enough GPU memory (such as an H100); inference will then take less time,
# and enable pipe.to("cuda") instead.
# pipe.enable_sequential_cpu_offload()
pipe.to("cuda")
# pipe.to("cuda")
pipe.enable_sequential_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
# 4. Generate the video frames based on the prompt.
# `num_frames` is the Number of frames to generate.
# This is the default value for 6 seconds video and 8 fps,so 48 frames and will plus 1 frame for the first frame and 49 frames.
# This is the default value for a 6-second video at 8 fps (48 frames), plus 1 frame for the first frame, giving 49 frames in total.
if generate_type == "i2v":
video_generate = pipe(
prompt=prompt,
@ -100,7 +104,7 @@ def generate_video(
num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt
num_inference_steps=num_inference_steps, # Number of inference steps
num_frames=49, # Number of frames to generate, changed to 49 for diffusers version `0.30.3` and later.
use_dynamic_cfg=True, ## This id used for DPM Sechduler, for DDIM scheduler, it should be False
use_dynamic_cfg=True, # This is used for the DPM scheduler; for the DDIM scheduler, it should be False
guidance_scale=guidance_scale,
generator=torch.Generator().manual_seed(seed), # Set the seed for reproducibility
).frames[0]
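To make the scheduler recommendation above easier to apply, here is a minimal sketch (using the `diffusers` CogVideoX scheduler classes shown in the diff; the helper function itself is made up, not part of this commit):
```python
from diffusers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler


def pick_scheduler(pipe, model_name: str):
    """Hypothetical helper: return the scheduler recommended for a CogVideoX checkpoint."""
    if "5b" in model_name.lower():
        # CogVideoX-5B / CogVideoX-5B-I2V: DPM scheduler (pair with use_dynamic_cfg=True).
        return CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
    # CogVideoX-2B: DDIM scheduler (use_dynamic_cfg should be False).
    return CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")


# Usage sketch: pipe.scheduler = pick_scheduler(pipe, "THUDM/CogVideoX-2b")
```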

inference/cli_demo_quantization.py

@ -6,7 +6,7 @@ Note:
Must install the `torchao`, `torch`, `diffusers`, and `accelerate` libraries FROM SOURCE to use the quantization feature.
Only NVIDIA GPUs such as the H100 or newer support FP8 quantization.
ALL quantization schemes must using with NVIDIA GPUs.
All quantization schemes must be used with NVIDIA GPUs.
# Run the script:
@ -83,8 +83,9 @@ def generate_video(
# Using this with compile will run faster; the first inference will take ~30 minutes to compile.
# pipe.transformer.to(memory_format=torch.channels_last)
# for FP8 should remove pipe.enable_model_cpu_offload()
pipe.enable_model_cpu_offload()
# For FP8, remove pipe.enable_sequential_cpu_offload()
pipe.enable_sequential_cpu_offload()
# This is not for FP8 or INT8; remove this line for those quantization schemes
# pipe.enable_sequential_cpu_offload()
@ -95,7 +96,7 @@ def generate_video(
num_videos_per_prompt=num_videos_per_prompt,
num_inference_steps=num_inference_steps,
num_frames=49,
use_dynamic_cfg=True, ## This id used for DPM Sechduler, for DDIM scheduler, it should be False
use_dynamic_cfg=True,
guidance_scale=guidance_scale,
generator=torch.Generator(device="cuda").manual_seed(42),
).frames[0]
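As a companion to the quantization notes above, here is a minimal sketch of INT8 weight-only quantization of the transformer (assuming torchao's `quantize_` / `int8_weight_only` API and the `THUDM/CogVideoX-5b` checkpoint; not the exact code from this commit):
```python
import torch
from diffusers import CogVideoXPipeline
from torchao.quantization import quantize_, int8_weight_only

# Assumed checkpoint; INT8 weight-only runs on more GPUs than FP8 (which needs H100-class hardware).
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
quantize_(pipe.transformer, int8_weight_only())

pipe.enable_sequential_cpu_offload()  # fine for INT8; for FP8, remove this and keep the pipeline on the GPU
```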

inference/cli_vae_demo.py

@ -1,18 +1,22 @@
"""
This script is designed to demonstrate how to use the CogVideoX-2b VAE model for video encoding and decoding.
It allows you to encode a video into a latent representation, decode it back into a video, or perform both operations sequentially.
Before running the script, make sure to clone the CogVideoX Hugging Face model repository and set the `{your local diffusers path}` argument to the path of the cloned repository.
Before running the script, make sure to clone the CogVideoX Hugging Face model repository and set the
`{your local diffusers path}` argument to the path of the cloned repository.
Command 1: Encoding Video
Encodes the video located at ../resources/videos/1.mp4 using the CogVideoX-2b VAE model.
Memory Usage: ~34GB of GPU memory for encoding.
If you do not have enough GPU memory, we provide a pre-encoded tensor file (encoded.pt) in the resources folder and you can still run the decoding command.
Encodes the video located at ../resources/videos/1.mp4 using the CogVideoX-5b VAE model.
Memory Usage: ~18GB of GPU memory for encoding.
If you do not have enough GPU memory, we provide a pre-encoded tensor file (encoded.pt) in the resources folder,
and you can still run the decoding command.
$ python cli_vae_demo.py --model_path {your local diffusers path}/CogVideoX-2b/vae/ --video_path ../resources/videos/1.mp4 --mode encode
Command 2: Decoding Video
Decodes the latent representation stored in encoded.pt back into a video.
Memory Usage: ~19GB of GPU memory for decoding.
Memory Usage: ~4GB of GPU memory for decoding.
$ python cli_vae_demo.py --model_path {your local diffusers path}/CogVideoX-2b/vae/ --encoded_path ./encoded.pt --mode decode
Command 3: Encoding and Decoding Video
@ -24,9 +28,9 @@ $ python cli_vae_demo.py --model_path {your local diffusers path}/CogVideoX-2b/v
import argparse
import torch
import imageio
import numpy as np
from diffusers import AutoencoderKLCogVideoX
from torchvision import transforms
import numpy as np
def encode_video(model_path, video_path, dtype, device):
@ -42,7 +46,12 @@ def encode_video(model_path, video_path, dtype, device):
Returns:
- torch.Tensor: The encoded video frames.
"""
model = AutoencoderKLCogVideoX.from_pretrained(model_path, torch_dtype=dtype).to(device)
model.enable_slicing()
model.enable_tiling()
video_reader = imageio.get_reader(video_path, "ffmpeg")
frames = [transforms.ToTensor()(frame) for frame in video_reader]
@ -80,13 +89,13 @@ def save_video(tensor, output_path):
Saves the video frames to a video file.
Parameters:
- tensor (torch.Tensor): The video frames tensor.
- tensor (torch.Tensor): The video frames' tensor.
- output_path (str): The path to save the output video.
"""
tensor = tensor.to(dtype=torch.float32)
frames = tensor[0].squeeze(0).permute(1, 2, 3, 0).cpu().numpy()
frames = np.clip(frames, 0, 1) * 255
frames = frames.astype(np.uint8)
writer = imageio.get_writer(output_path + "/output.mp4", fps=8)
for frame in frames:
writer.append_data(frame)
@ -103,7 +112,7 @@ if __name__ == "__main__":
"--mode", type=str, choices=["encode", "decode", "both"], required=True, help="Mode: encode, decode, or both"
)
parser.add_argument(
"--dtype", type=str, default="float16", help="The data type for computation (e.g., 'float16' or 'float32')"
"--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')"
)
parser.add_argument(
"--device", type=str, default="cuda", help="The device to use for computation (e.g., 'cuda' or 'cpu')"
@ -111,7 +120,7 @@ if __name__ == "__main__":
args = parser.parse_args()
device = torch.device(args.device)
dtype = torch.float16 if args.dtype == "float16" else torch.float32
dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16
if args.mode == "encode":
assert args.video_path, "Video path must be provided for encoding."