Merge branch 'THUDM:main' into main

This commit is contained in:
Chenxi 2024-09-22 17:43:30 +01:00 committed by GitHub
commit 7a01f14400
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 176 additions and 254 deletions

View File

@ -277,7 +277,10 @@ pipe.vae.enable_tiling()
We highly welcome contributions from the community and actively contribute to the open-source community. The following
works have already been adapted for CogVideoX, and we invite everyone to use them:
+ [CogVideoX-Fun](https://github.com/aigc-apps/CogVideoX-Fun): CogVideoX-Fun is a modified pipeline based on the CogVideoX architecture, supporting flexible resolutions and multiple launch methods.
+ [CogVideoX-Fun](https://github.com/aigc-apps/CogVideoX-Fun): CogVideoX-Fun is a modified pipeline based on the
CogVideoX architecture, supporting flexible resolutions and multiple launch methods.
+ [CogStudio](https://github.com/pinokiofactory/cogstudio): A separate repository for CogVideo's Gradio Web UI, which
supports more functional Web UIs.
+ [Xorbits Inference](https://github.com/xorbitsai/inference): A powerful and comprehensive distributed inference
framework, allowing you to easily deploy your own models or the latest cutting-edge open-source models with just one
click.
@ -288,7 +291,8 @@ works have already been adapted for CogVideoX, and we invite everyone to use the
techniques.
+ [AutoDL Space](https://www.codewithgpu.com/i/THUDM/CogVideo/CogVideoX-5b-demo): A one-click deployment Huggingface
Space image provided by community members.
+ [Interior Design Fine-Tuning Model](https://huggingface.co/collections/bertjiazheng/koolcogvideox-66e4762f53287b7f39f8f3ba): is a fine-tuned model based on CogVideoX, specifically designed for interior design.
+ [Interior Design Fine-Tuning Model](https://huggingface.co/collections/bertjiazheng/koolcogvideox-66e4762f53287b7f39f8f3ba):
is a fine-tuned model based on CogVideoX, specifically designed for interior design.
## Project Structure

View File

@ -262,6 +262,7 @@ pipe.vae.enable_tiling()
コミュニティからの貢献を大歓迎し、私たちもオープンソースコミュニティに積極的に貢献しています。以下の作品はすでにCogVideoXに対応しており、ぜひご利用ください
+ [CogVideoX-Fun](https://github.com/aigc-apps/CogVideoX-Fun): CogVideoX-Funは、CogVideoXアーキテクチャを基にした改良パイプラインで、自由な解像度と複数の起動方法をサポートしています。
+ [CogStudio](https://github.com/pinokiofactory/cogstudio): CogVideo の Gradio Web UI の別のリポジトリ。より高機能な Web UI をサポートします。
+ [Xorbits Inference](https://github.com/xorbitsai/inference):
強力で包括的な分散推論フレームワークであり、ワンクリックで独自のモデルや最新のオープンソースモデルを簡単にデプロイできます。
+ [ComfyUI-CogVideoXWrapper](https://github.com/kijai/ComfyUI-CogVideoXWrapper)

View File

@ -249,6 +249,7 @@ pipe.vae.enable_tiling()
我们非常欢迎来自社区的贡献并积极的贡献开源社区。以下作品已经对CogVideoX进行了适配欢迎大家使用:
+ [CogVideoX-Fun](https://github.com/aigc-apps/CogVideoX-Fun): CogVideoX-Fun是一个基于CogVideoX结构修改后的的pipeline支持自由的分辨率多种启动方式。
+ [CogStudio](https://github.com/pinokiofactory/cogstudio): CogVideo 的 Gradio Web UI单独实现仓库支持更多功能的 Web UI。
+ [Xorbits Inference](https://github.com/xorbitsai/inference): 性能强大且功能全面的分布式推理框架,轻松一键部署你自己的模型或内置的前沿开源模型。
+ [ComfyUI-CogVideoXWrapper](https://github.com/kijai/ComfyUI-CogVideoXWrapper) 使用ComfyUI框架将CogVideoX加入到你的工作流中。
+ [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys): VideoSys 提供了易用且高性能的视频生成基础设施,支持完整的管道,并持续集成最新的模型和技术。

View File

@ -51,82 +51,57 @@ The `accelerate` configuration files are as follows:
The configuration for the `finetune` script is as follows:
```shell
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# This command sets the PyTorch CUDA memory allocation strategy to expandable segments to prevent OOM (Out of Memory) errors.
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu # Launch training using Accelerate with the specified config file for multi-GPU.
train_cogvideox_lora.py # This is the training script for LoRA fine-tuning of the CogVideoX model.
--pretrained_model_name_or_path THUDM/CogVideoX-2b # Path to the pretrained model you want to fine-tune, pointing to the CogVideoX-2b model.
--cache_dir ~/.cache # Directory for caching models downloaded from Hugging Face.
--enable_tiling # Enable VAE tiling to reduce memory usage by processing images in smaller chunks.
--enable_slicing # Enable VAE slicing to split the image into slices along the channel to save memory.
--instance_data_root ~/disney/ # Root directory for instance data, i.e., the dataset used for training.
--caption_column prompts.txt # Specify the column or file containing instance prompts (text descriptions), in this case, the `prompts.txt` file.
--video_column videos.txt # Specify the column or file containing video paths, in this case, the `videos.txt` file.
--validation_prompt "Mickey with the captain and friends:::Mickey and the bear" # Validation prompts; multiple prompts are separated by the specified delimiter (e.g., `:::`).
--validation_prompt_separator ::: # The separator for validation prompts, set to `:::` here.
--num_validation_videos 1 # Number of videos to generate during validation, set to 1.
--validation_epochs 2 # Number of epochs after which validation will be run, set to every 2 epochs.
--seed 3407 # Set a random seed to ensure reproducibility, set to 3407.
--rank 128 # Dimension of the LoRA update matrix, controls the size of the LoRA layers, set to 128.
--mixed_precision bf16 # Use mixed precision training, set to `bf16` (bfloat16) to reduce memory usage and speed up training.
--output_dir cogvideox-lora-single-gpu # Output directory for storing model predictions and checkpoints.
--height 480 # Height of the input videos, all videos will be resized to 480 pixels.
--width 720 # Width of the input videos, all videos will be resized to 720 pixels.
--fps 8 # Frame rate of the input videos, all videos will be processed at 8 frames per second.
--max_num_frames 49 # Maximum number of frames per input video, videos will be truncated to 49 frames.
--skip_frames_start 0 # Number of frames to skip from the start of each video, set to 0 to not skip any frames.
--skip_frames_end 0 # Number of frames to skip from the end of each video, set to 0 to not skip any frames.
--train_batch_size 1 # Training batch size per device, set to 1.
--num_train_epochs 10 # Total number of training epochs, set to 10.
--checkpointing_steps 500 # Save checkpoints every 500 steps.
--gradient_accumulation_steps 1 # Gradient accumulation steps, perform an update every 1 step.
--learning_rate 1e-4 # Initial learning rate, set to 1e-4.
--optimizer AdamW # Optimizer type, using AdamW optimizer.
--adam_beta1 0.9 # Beta1 parameter for the Adam optimizer, set to 0.9.
--adam_beta2 0.95 # Beta2 parameter for the Adam optimizer, set to 0.95.
```
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \ # Use accelerate to launch multi-GPU training with the config file accelerate_config_machine_single.yaml
train_cogvideox_lora.py \ # Training script train_cogvideox_lora.py for LoRA fine-tuning on CogVideoX model
--gradient_checkpointing \ # Enable gradient checkpointing to reduce memory usage
--pretrained_model_name_or_path $MODEL_PATH \ # Path to the pretrained model, specified by $MODEL_PATH
--cache_dir $CACHE_PATH \ # Cache directory for model files, specified by $CACHE_PATH
--enable_tiling \ # Enable tiling technique to process videos in chunks, saving memory
--enable_slicing \ # Enable slicing to further optimize memory by slicing inputs
--instance_data_root $DATASET_PATH \ # Dataset path specified by $DATASET_PATH
--caption_column prompts.txt \ # Specify the file prompts.txt for video descriptions used in training
--video_column videos.txt \ # Specify the file videos.txt for video paths used in training
--validation_prompt "" \ # Prompt used for generating validation videos during training
--validation_prompt_separator ::: \ # Set ::: as the separator for validation prompts
--num_validation_videos 1 \ # Generate 1 validation video per validation round
--validation_epochs 100 \ # Perform validation every 100 training epochs
--seed 42 \ # Set random seed to 42 for reproducibility
--rank 128 \ # Set the rank for LoRA parameters to 128
--lora_alpha 64 \ # Set the alpha parameter for LoRA to 64, adjusting LoRA learning rate
--mixed_precision bf16 \ # Use bf16 mixed precision for training to save memory
--output_dir $OUTPUT_PATH \ # Specify the output directory for the model, defined by $OUTPUT_PATH
--height 480 \ # Set video height to 480 pixels
--width 720 \ # Set video width to 720 pixels
--fps 8 \ # Set video frame rate to 8 frames per second
--max_num_frames 49 \ # Set the maximum number of frames per video to 49
--skip_frames_start 0 \ # Skip 0 frames at the start of the video
--skip_frames_end 0 \ # Skip 0 frames at the end of the video
--train_batch_size 4 \ # Set training batch size to 4
--num_train_epochs 30 \ # Total number of training epochs set to 30
--checkpointing_steps 1000 \ # Save model checkpoint every 1000 steps
--gradient_accumulation_steps 1 \ # Accumulate gradients for 1 step, updating after each batch
--learning_rate 1e-3 \ # Set learning rate to 0.001
--lr_scheduler cosine_with_restarts \ # Use cosine learning rate scheduler with restarts
--lr_warmup_steps 200 \ # Warm up the learning rate for the first 200 steps
--lr_num_cycles 1 \ # Set the number of learning rate cycles to 1
--optimizer AdamW \ # Use the AdamW optimizer
--adam_beta1 0.9 \ # Set Adam optimizer beta1 parameter to 0.9
--adam_beta2 0.95 \ # Set Adam optimizer beta2 parameter to 0.95
--max_grad_norm 1.0 \ # Set maximum gradient clipping value to 1.0
--allow_tf32 \ # Enable TF32 to speed up training
--report_to wandb # Use Weights and Biases (wandb) for logging and monitoring the training
```
## Running the Script to Start Fine-tuning
Single GPU fine-tuning:
Single Node (One GPU or Multi GPU) fine-tuning:
```shell
bash finetune_single_rank.sh
```
Multi-GPU fine-tuning:
Multi-Node fine-tuning:
```shell
bash finetune_multi_rank.sh # Needs to be run on each node
@ -147,5 +122,5 @@ bash finetune_multi_rank.sh # Needs to be run on each node
but regular fine-tuning without such tokens also works.
+ The original repository used `lora_alpha` set to 1. We found this value ineffective across multiple runs, likely due
to differences in the backend and training setup. Our recommendation is to set `lora_alpha` equal to rank or rank //
2.
2.
+ We recommend using a rank of 64 or higher.

View File

@ -47,82 +47,57 @@ pip install -e .
`finetune` スクリプト設定ファイルの例:
```shell
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# このコマンドは、OOMメモリ不足エラーを防ぐために、CUDAメモリ割り当てを拡張セグメントに設定します。
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu # 複数のGPUで `accelerate` を使用してトレーニングを開始します。指定された設定ファイルを使用します。
train_cogvideox_lora.py # LoRA微調整用に CogVideoX モデルをトレーニングするスクリプトです。
--pretrained_model_name_or_path THUDM/CogVideoX-2b # 事前学習済みモデルのパスです。
--cache_dir ~/.cache # Hugging Faceからダウンロードされたモデルとデータセットのキャッシュディレクトリです。
--enable_tiling # VAEタイル化機能を有効にし、メモリ使用量を削減します。
--enable_slicing # VAEスライス機能を有効にして、チャネルでのスライス処理を行い、メモリを節約します。
--instance_data_root ~/disney/ # インスタンスデータのルートディレクトリです。
--caption_column prompts.txt # テキストプロンプトが含まれているファイルや列を指定します。
--video_column videos.txt # ビデオパスが含まれているファイルや列を指定します。
--validation_prompt "Mickey with the captain and friends:::Mickey and the bear" # 検証用のプロンプトを指定します。複数のプロンプトを指定するには `:::` 区切り文字を使用します。
--validation_prompt_separator ::: # 検証プロンプトの区切り文字を `:::` に設定します。
--num_validation_videos 1 # 検証中に生成するビデオの数を1に設定します。
--validation_epochs 2 # 何エポックごとに検証を行うかを2に設定します。
--seed 3407 # ランダムシードを3407に設定し、トレーニングの再現性を確保します。
--rank 128 # LoRAの更新マトリックスの次元を128に設定します。
--mixed_precision bf16 # 混合精度トレーニングを `bf16` (bfloat16) に設定します。
--output_dir cogvideox-lora-single-gpu # 出力ディレクトリを指定します。
--height 480 # 入力ビデオの高さを480ピクセルに設定します。
--width 720 # 入力ビデオの幅を720ピクセルに設定します。
--fps 8 # 入力ビデオのフレームレートを8 fpsに設定します。
--max_num_frames 49 # 入力ビデオの最大フレーム数を49に設定します。
--skip_frames_start 0 # 各ビデオの最初のフレームをスキップしません。
--skip_frames_end 0 # 各ビデオの最後のフレームをスキップしません。
--train_batch_size 1 # トレーニングバッチサイズを1に設定します。
--num_train_epochs 10 # トレーニングのエポック数を10に設定します。
--checkpointing_steps 500 # 500ステップごとにチェックポイントを保存します。
--gradient_accumulation_steps 1 # 1ステップごとに勾配を蓄積して更新します。
--learning_rate 1e-4 # 初期学習率を1e-4に設定します。
--optimizer AdamW # AdamWオプティマイザーを使用します。
--adam_beta1 0.9 # Adamのbeta1パラメータを0.9に設定します。
--adam_beta2 0.95 # Adamのbeta2パラメータを0.95に設定します。
```
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \ # accelerateを使用してmulti-GPUトレーニングを起動、設定ファイルはaccelerate_config_machine_single.yaml
train_cogvideox_lora.py \ # LoRAの微調整用のトレーニングスクリプトtrain_cogvideox_lora.pyを実行
--gradient_checkpointing \ # メモリ使用量を減らすためにgradient checkpointingを有効化
--pretrained_model_name_or_path $MODEL_PATH \ # 事前学習済みモデルのパスを$MODEL_PATHで指定
--cache_dir $CACHE_PATH \ # モデルファイルのキャッシュディレクトリを$CACHE_PATHで指定
--enable_tiling \ # メモリ節約のためにタイル処理を有効化し、動画をチャンク分けして処理
--enable_slicing \ # 入力をスライスしてさらにメモリ最適化
--instance_data_root $DATASET_PATH \ # データセットのパスを$DATASET_PATHで指定
--caption_column prompts.txt \ # トレーニングで使用する動画の説明ファイルをprompts.txtで指定
--video_column videos.txt \ # トレーニングで使用する動画のパスファイルをvideos.txtで指定
--validation_prompt "" \ # トレーニング中に検証用の動画を生成する際のプロンプト
--validation_prompt_separator ::: \ # 検証プロンプトの区切り文字を:::に設定
--num_validation_videos 1 \ # 各検証ラウンドで1本の動画を生成
--validation_epochs 100 \ # 100エポックごとに検証を実施
--seed 42 \ # 再現性を保証するためにランダムシードを42に設定
--rank 128 \ # LoRAのパラメータのランクを128に設定
--lora_alpha 64 \ # LoRAのalphaパラメータを64に設定し、LoRAの学習率を調整
--mixed_precision bf16 \ # bf16混合精度でトレーニングし、メモリを節約
--output_dir $OUTPUT_PATH \ # モデルの出力ディレクトリを$OUTPUT_PATHで指定
--height 480 \ # 動画の高さを480ピクセルに設定
--width 720 \ # 動画の幅を720ピクセルに設定
--fps 8 \ # 動画のフレームレートを1秒あたり8フレームに設定
--max_num_frames 49 \ # 各動画の最大フレーム数を49に設定
--skip_frames_start 0 \ # 動画の最初のフレームを0スキップ
--skip_frames_end 0 \ # 動画の最後のフレームを0スキップ
--train_batch_size 4 \ # トレーニングのバッチサイズを4に設定
--num_train_epochs 30 \ # 総トレーニングエポック数を30に設定
--checkpointing_steps 1000 \ # 1000ステップごとにモデルのチェックポイントを保存
--gradient_accumulation_steps 1 \ # 1ステップの勾配累積を行い、各バッチ後に更新
--learning_rate 1e-3 \ # 学習率を0.001に設定
--lr_scheduler cosine_with_restarts \ # リスタート付きのコサイン学習率スケジューラを使用
--lr_warmup_steps 200 \ # トレーニングの最初の200ステップで学習率をウォームアップ
--lr_num_cycles 1 \ # 学習率のサイクル数を1に設定
--optimizer AdamW \ # AdamWオプティマイザーを使用
--adam_beta1 0.9 \ # Adamオプティマイザーのbeta1パラメータを0.9に設定
--adam_beta2 0.95 \ # Adamオプティマイザーのbeta2パラメータを0.95に設定
--max_grad_norm 1.0 \ # 勾配クリッピングの最大値を1.0に設定
--allow_tf32 \ # トレーニングを高速化するためにTF32を有効化
--report_to wandb # Weights and Biasesを使用してトレーニングの記録とモニタリングを行う
```
## 微調整を開始
単一GPU微調整
単一マシン (シングルGPU、マルチGPU) での微調整:
```shell
bash finetune_single_rank.sh
```
複数GPU微調整
複数マシン・マルチGPUでの微調整
```shell
bash finetune_multi_rank.sh # 各ノードで実行する必要があります。

View File

@ -44,115 +44,60 @@ pip install -e .
+ accelerate_config_machine_multi.yaml 适合多GPU使用
+ accelerate_config_machine_single.yaml 适合单GPU使用
`finetune` 脚本配置文件如下:
`finetune` 脚本配置文件如下
```shell
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# 这条命令设置了 PyTorch 的 CUDA 内存分配策略,将显存扩展为段式内存管理,以防止 OOMOut of Memory错误。
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \
# 使用 Accelerate 启动训练,指定配置文件 `accelerate_config_machine_single.yaml`,并使用多 GPU。
train_cogvideox_lora.py \
# 这是你要执行的训练脚本,用于 LoRA 微调 CogVideoX 模型。
--pretrained_model_name_or_path THUDM/CogVideoX-2b \
# 预训练模型的路径,指向你要微调的 CogVideoX-5b 模型。
--cache_dir ~/.cache \
# 模型缓存的目录,用于存储从 Hugging Face 下载的模型和数据集。
--enable_tiling \
# 启用 VAE tiling 功能,通过将图像划分成更小的区块处理,减少显存占用。
--enable_slicing \
# 启用 VAE slicing 功能,将图像在通道上切片处理,以节省显存。
--instance_data_root ~/disney/ \
# 实例数据的根目录,训练时使用的数据集文件夹。
--caption_column prompts.txt \
# 用于指定包含实例提示(文本描述)的列或文件,在本例中为 `prompts.txt` 文件。
--video_column videos.txt \
# 用于指定包含视频路径的列或文件,在本例中为 `videos.txt` 文件。
--validation_prompt "Mickey with the captain and friends:::Mickey and the bear" \
# 用于验证的提示语,多个提示语用指定分隔符(例如 `:::`)分开。
--validation_prompt_separator ::: \
# 验证提示语的分隔符,在此设置为 `:::`
--num_validation_videos 1 \
# 验证期间生成的视频数量,设置为 1。
--validation_epochs 2 \
# 每隔多少个 epoch 运行一次验证,设置为每 2 个 epoch 验证一次。
--seed 3407 \
# 设置随机数种子,确保训练的可重复性,设置为 3407。
--rank 128 \
# LoRA 更新矩阵的维度,控制 LoRA 层的参数大小,设置为 128。
--mixed_precision bf16 \
# 使用混合精度训练,设置为 `bf16`bfloat16可以减少显存占用并加速训练。
--output_dir cogvideox-lora-single-gpu \
# 输出目录,存放模型预测结果和检查点。
--height 480 \
# 输入视频的高度,所有视频将被调整到 480 像素。
--width 720 \
# 输入视频的宽度,所有视频将被调整到 720 像素。
--fps 8 \
# 输入视频的帧率,所有视频将以每秒 8 帧处理。
--max_num_frames 49 \
# 输入视频的最大帧数,视频将被截取到最多 49 帧。
--skip_frames_start 0 \
# 每个视频从头部开始跳过的帧数,设置为 0表示不跳过帧。
--skip_frames_end 0 \
# 每个视频从尾部跳过的帧数,设置为 0表示不跳过尾帧。
--train_batch_size 1 \
# 训练的批次大小,每个设备的训练批次设置为 1。
--num_train_epochs 10 \
# 训练的总 epoch 数,设置为 10。
--checkpointing_steps 500 \
# 每经过 500 步保存一次检查点。
--gradient_accumulation_steps 1 \
# 梯度累积步数,表示每进行 1 步才进行一次梯度更新。
--learning_rate 1e-4 \
# 初始学习率,设置为 1e-4。
--optimizer AdamW \
# 优化器类型,选择 AdamW 优化器。
--adam_beta1 0.9 \
# Adam 优化器的 beta1 参数,设置为 0.9。
--adam_beta2 0.95 \
# Adam 优化器的 beta2 参数,设置为 0.95。
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \ # 使用 accelerate 启动多GPU训练配置文件为 accelerate_config_machine_single.yaml
train_cogvideox_lora.py \ # 运行的训练脚本为 train_cogvideox_lora.py用于在 CogVideoX 模型上进行 LoRA 微调
--gradient_checkpointing \ # 启用梯度检查点功能,以减少显存使用
--pretrained_model_name_or_path $MODEL_PATH \ # 预训练模型路径,通过 $MODEL_PATH 指定
--cache_dir $CACHE_PATH \ # 模型缓存路径,由 $CACHE_PATH 指定
--enable_tiling \ # 启用tiling技术以分片处理视频节省显存
--enable_slicing \ # 启用slicing技术将输入切片以进一步优化内存
--instance_data_root $DATASET_PATH \ # 数据集路径,由 $DATASET_PATH 指定
--caption_column prompts.txt \ # 指定用于训练的视频描述文件,文件名为 prompts.txt
--video_column videos.txt \ # 指定用于训练的视频路径文件,文件名为 videos.txt
--validation_prompt "" \ # 验证集的提示语 (prompt),用于在训练期间生成验证视频
--validation_prompt_separator ::: \ # 设置验证提示语的分隔符为 :::
--num_validation_videos 1 \ # 每个验证回合生成 1 个视频
--validation_epochs 100 \ # 每 100 个训练epoch进行一次验证
--seed 42 \ # 设置随机种子为 42以保证结果的可复现性
--rank 128 \ # 设置 LoRA 参数的秩 (rank) 为 128
--lora_alpha 64 \ # 设置 LoRA 的 alpha 参数为 64用于调整LoRA的学习率
--mixed_precision bf16 \ # 使用 bf16 混合精度进行训练,减少显存使用
--output_dir $OUTPUT_PATH \ # 指定模型输出目录,由 $OUTPUT_PATH 定义
--height 480 \ # 视频高度为 480 像素
--width 720 \ # 视频宽度为 720 像素
--fps 8 \ # 视频帧率设置为 8 帧每秒
--max_num_frames 49 \ # 每个视频的最大帧数为 49 帧
--skip_frames_start 0 \ # 跳过视频开头的帧数为 0
--skip_frames_end 0 \ # 跳过视频结尾的帧数为 0
--train_batch_size 4 \ # 训练时的 batch size 设置为 4
--num_train_epochs 30 \ # 总训练epoch数为 30
--checkpointing_steps 1000 \ # 每 1000 步保存一次模型检查点
--gradient_accumulation_steps 1 \ # 梯度累计步数为 1即每个 batch 后都会更新梯度
--learning_rate 1e-3 \ # 学习率设置为 0.001
--lr_scheduler cosine_with_restarts \ # 使用带重启的余弦学习率调度器
--lr_warmup_steps 200 \ # 在训练的前 200 步进行学习率预热
--lr_num_cycles 1 \ # 学习率周期设置为 1
--optimizer AdamW \ # 使用 AdamW 优化器
--adam_beta1 0.9 \ # 设置 Adam 优化器的 beta1 参数为 0.9
--adam_beta2 0.95 \ # 设置 Adam 优化器的 beta2 参数为 0.95
--max_grad_norm 1.0 \ # 最大梯度裁剪值设置为 1.0
--allow_tf32 \ # 启用 TF32 以加速训练
--report_to wandb # 使用 Weights and Biases 进行训练记录与监控
```
## 运行脚本,开始微调
单卡微调:
单机(单卡,多卡)微调:
```shell
bash finetune_single_rank.sh
```
多卡微调:
多机多卡微调:
```shell
bash finetune_multi_rank.sh #需要在每个节点运行

View File

@ -3,10 +3,11 @@
export MODEL_PATH="THUDM/CogVideoX-2b"
export CACHE_PATH="~/.cache"
export DATASET_PATH="Disney-VideoGeneration-Dataset"
export OUTPUT_PATH="cogvideox-lora-multi-gpu"
export OUTPUT_PATH="cogvideox-lora-multi-node"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
# max batch-size is 2.
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu --machine_rank 0 \
train_cogvideox_lora.py \
--gradient_checkpointing \
@ -17,12 +18,13 @@ accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gp
--instance_data_root $DATASET_PATH \
--caption_column prompts.txt \
--video_column videos.txt \
--validation_prompt "Mickey with the captain and friends:::Mickey and the bear" \
--validation_prompt "DISNEY A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance" \
--validation_prompt_separator ::: \
--num_validation_videos 1 \
--validation_epochs 2 \
--seed 3407 \
--validation_epochs 100 \
--seed 42 \
--rank 128 \
--lora_alpha 64 \
--mixed_precision bf16 \
--output_dir $OUTPUT_PATH \
--height 480 \
@ -32,10 +34,19 @@ accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gp
--skip_frames_start 0 \
--skip_frames_end 0 \
--train_batch_size 1 \
--num_train_epochs 10 \
--checkpointing_steps 500 \
--num_train_epochs 30 \
--checkpointing_steps 1000 \
--gradient_accumulation_steps 1 \
--learning_rate 1e-4 \
--learning_rate 1e-3 \
--lr_scheduler cosine_with_restarts \
--lr_warmup_steps 200 \
--lr_num_cycles 1 \
--enable_slicing \
--enable_tiling \
--gradient_checkpointing \
--optimizer AdamW \
--adam_beta1 0.9 \
--adam_beta2 0.95
--adam_beta2 0.95 \
--max_grad_norm 1.0 \
--allow_tf32 \
--report_to wandb

View File

@ -3,17 +3,14 @@
export MODEL_PATH="THUDM/CogVideoX-2b"
export CACHE_PATH="~/.cache"
export DATASET_PATH="Disney-VideoGeneration-Dataset"
export OUTPUT_PATH="cogvideox-lora-multi-gpu"
export OUTPUT_PATH="cogvideox-lora-single-node"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
# --use_8bit_adam is necessary for CogVideoX-5B-I2V
# if you are not using wth 8 gus, change `accelerate_config_machine_single.yaml` num_processes as your gpu number
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \
train_cogvideox_lora.py \
--gradient_checkpointing \
--use_8bit_adam \
--pretrained_model_name_or_path $MODEL_PATH \
--cache_dir $CACHE_PATH \
--enable_tiling \
@ -21,12 +18,13 @@ accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gp
--instance_data_root $DATASET_PATH \
--caption_column prompts.txt \
--video_column videos.txt \
--validation_prompt "Mickey with the captain and friends:::Mickey and the bear" \
--validation_prompt "DISNEY A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance" \
--validation_prompt_separator ::: \
--num_validation_videos 1 \
--validation_epochs 2 \
--seed 3407 \
--validation_epochs 100 \
--seed 42 \
--rank 128 \
--lora_alpha 64 \
--mixed_precision bf16 \
--output_dir $OUTPUT_PATH \
--height 480 \
@ -36,10 +34,19 @@ accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gp
--skip_frames_start 0 \
--skip_frames_end 0 \
--train_batch_size 1 \
--num_train_epochs 10 \
--checkpointing_steps 500 \
--num_train_epochs 30 \
--checkpointing_steps 1000 \
--gradient_accumulation_steps 1 \
--learning_rate 1e-4 \
--learning_rate 1e-3 \
--lr_scheduler cosine_with_restarts \
--lr_warmup_steps 200 \
--lr_num_cycles 1 \
--enable_slicing \
--enable_tiling \
--gradient_checkpointing \
--optimizer AdamW \
--adam_beta1 0.9 \
--adam_beta2 0.95
--adam_beta2 0.95 \
--max_grad_norm 1.0 \
--allow_tf32 \
--report_to wandb

View File

@ -308,7 +308,10 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
return t * freqs_cos + rotate_half(t) * freqs_sin
def position_embedding_forward(self, position_ids, **kwargs):
return None
if self.pos_embedding is not None:
return self.pos_embedding[:, :self.text_length + kwargs["seq_length"]]
else:
return None
def attention_fn(
self,