Compare commits

...

173 Commits
v1.0 ... main

Author SHA1 Message Date
OleehyO
5ab1e2449f
Merge pull request #719 from holma91/fix-lora-scale
fix scale bug
2025-03-25 18:59:49 +08:00
OleehyO
a01ffd9aba
Update cli_demo.py 2025-03-25 18:59:11 +08:00
OleehyO
9be282d461
Merge branch 'main' into fix-lora-scale 2025-03-25 18:58:33 +08:00
Yuxuan Zhang
c624cb0d91
Merge pull request #743 from THUDM/CogVideoX_dev
Format
2025-03-24 11:30:17 +08:00
Yuxuan Zhang
39c6562dc8 format 2025-03-22 15:14:06 +08:00
Yuxuan Zhang
b9b0539dbe update 2025-03-15 14:27:12 +08:00
Yuxuan Zhang
129c375c85
Merge pull request #729 from zhuhz22/riflex
Add friendly link of RIFLEx
2025-03-03 10:57:29 +08:00
Yuxuan Zhang
536b705105
Merge pull request #730 from THUDM/CogVideoX_dev
Update wechat.jpg
2025-03-03 10:56:19 +08:00
Yuxuan Zhang
a691a6dd35 Update wechat.jpg 2025-03-03 10:52:22 +08:00
zhuhz22
6454293a1d add friendly link of RIFLEx 2025-03-01 22:06:39 +08:00
Yuxuan Zhang
887a4c7365
Merge pull request #726 from THUDM/CogVideoX_dev
Readme of ddim inverse
2025-03-01 09:29:09 +08:00
Yuxuan Zhang
a494fa50cd Merge branch 'CogVideoX_dev' of https://github.com/THUDM/CogVideo into CogVideoX_dev 2025-02-27 17:33:23 +08:00
Yuxuan Zhang
4fb6766d7c fix import decord error 2025-02-27 17:33:20 +08:00
Yuxuan Zhang
8d90381ba8
Merge pull request #722 from THUDM/main
merge
2025-02-27 17:32:24 +08:00
Yuxuan Zhang
eb66c9c6dc
Merge pull request #709 from LittleNyima/feature/ddim-inversion
Implement DDIM Inversion for CogVideoX
2025-02-27 13:24:24 +08:00
LittleNyima
2c33c0982b fix import order and deprecate for CVX 2B models 2025-02-26 15:54:58 +08:00
LittleNyima
d6bb910697
Merge branch 'THUDM:main' into feature/ddim-inversion 2025-02-26 15:22:08 +08:00
holma91
84766d02e8 fix scale bug 2025-02-24 20:08:27 +01:00
LittleNyima
e0bf395458
make the style of argparser consistent with repo 2025-02-23 19:41:21 +08:00
Yuxuan Zhang
e44c9f2c83
Merge pull request #716 from THUDM/CogVideoX_dev
Update gitignore patterns and project dependencies
2025-02-22 17:09:50 +08:00
OleehyO
5be6c0512f Merge remote-tracking branch 'upstream/CogVideoX_dev' into dev 2025-02-22 06:06:03 +00:00
OleehyO
4dac252c63 [chore] Update gitignore patterns and project dependencies 2025-02-22 06:03:53 +00:00
LittleNyima
250a0bce45 stable version 2025-02-20 05:03:15 +00:00
LittleNyima
58d66c8a08
Implement an unverified version that should be further tested 2025-02-20 01:39:12 +08:00
LittleNyima
dd76b2b9ea Initialize DDIM Inversion script 2025-02-18 09:50:55 +00:00
Yuxuan Zhang
34c6ba22ab
Merge pull request #682 from THUDM/main
Synchronize two branches.
2025-01-22 09:49:34 +08:00
Yuxuan Zhang
bbe909d7f7
Merge pull request #678 from THUDM/CogVideoX_dev
docs: clarify frame number requirements for CogVideoX models
2025-01-22 09:47:24 +08:00
Yuxuan Zhang
ea994c75c2
Merge pull request #652 from erfanasgari21/moviepy-v2
Update code and requirements to support Moviepy v2
2025-01-21 22:29:15 +08:00
Yuxuan Zhang
aa12ed37f5
Merge branch 'main' into moviepy-v2 2025-01-20 21:46:07 +08:00
OleehyO
d9e2a415e8 fix: fix resolution handling for different model types 2025-01-20 09:48:17 +00:00
OleehyO
0e26f54cbe docs: clarify frame number requirements for CogVideoX models
Specify that frame numbers must be:
- 16N + 1 (N <= 10) for CogVideoX1.5-5B models
- 8N + 1 (N <= 6) for CogVideoX-2B/5B models
2025-01-20 09:43:45 +00:00
Yuxuan Zhang
c1ca70ba67
Merge pull request #654 from THUDM/CogVideoX_dev
Support SFT using ZeRO
2025-01-20 11:15:50 +08:00
OleehyO
bf73742c05 docs: enhance CLI demo documentation 2025-01-16 09:34:52 +00:00
OleehyO
bf9c351a10 deps: upgrade diffusers to >=0.32.1 2025-01-16 09:08:44 +00:00
OleehyO
0e78f20629 Merge remote-tracking branch 'upstream/CogVideoX_dev' into dev 2025-01-14 04:00:11 +00:00
Yuxuan Zhang
4615479b51 move to tools 2025-01-14 11:33:02 +08:00
Yuxuan Zhang
7993670957 zero_to_bf16 2025-01-14 11:31:25 +08:00
OleehyO
4878edd0cf fix: correct do_validation argument parsing 2025-01-13 12:48:21 +00:00
Yuxuan Zhang
78275b0480 add comment of bash scripts 2025-01-13 20:02:06 +08:00
OleehyO
455b44a7b5 chore: code cleanup and parameter optimization
- Remove redundant comments and debug information
- Adjust default parameters in training scripts
- Clean up code in lora_trainer and trainer implementations
2025-01-13 11:56:28 +00:00
OleehyO
954ba28d3c Merge remote-tracking branch 'upstream/CogVideoX_dev' into dev 2025-01-13 11:48:24 +00:00
OleehyO
4f1cc66815 fix: correct LoRA loading and resolution dimensions
- Fix LoRA loading by specifying 'transformer' component
- Swap width/height order in RESOLUTION_MAP to match actual usage
2025-01-13 10:49:46 +00:00
zR
1534bf33eb add pipeline 2025-01-12 19:27:21 +08:00
OleehyO
86a0226f80 Merge remote-tracking branch 'upstream/CogVideoX_dev' into dev 2025-01-12 08:52:07 +00:00
OleehyO
70c899f444 chore: update default training configurations 2025-01-12 08:50:15 +00:00
OleehyO
b362663679 fix: normalize image tensors in I2VDataset 2025-01-12 06:01:48 +00:00
OleehyO
30ba1085ff Merge remote-tracking branch 'upstream/main' into dev 2025-01-12 05:58:07 +00:00
OleehyO
3252614569 Add pydantic dependency 2025-01-12 05:56:24 +00:00
OleehyO
f66f1647e2
Merge pull request #657 from ZGCTroy/main
fix bug of i2v finetune
2025-01-12 13:55:12 +08:00
OleehyO
f5169385bd docs: add SFT support documentation in multilingual README 2025-01-12 05:53:13 +00:00
OleehyO
795dd144a4 Rename lora training scripts as ddp 2025-01-12 05:36:32 +00:00
OleehyO
fdb9820949 feat: support DeepSpeed ZeRO-3 and optimize peak memory usage
- Add DeepSpeed ZeRO-3 configuration support
- Optimize memory usage during training
- Rename training scripts to reflect ZeRO usage
- Update related configuration files and trainers
2025-01-12 05:33:56 +00:00
Zheng Guang Cong
09a49d3546
fix bug of i2v; video is already 0-255
video is already 0-255 and should not be multiplied 255 any more
2025-01-11 17:29:27 +08:00
Zheng Guang Cong
cd861bbe1e
Update i2v_dataset.py
image should also be transformed to [-1, 1]
2025-01-11 17:24:35 +08:00
Zheng Guang Cong
35383e2db3
fix potential bug of i2v
Image value is in [0, 255] and should be transformed into [-1, 1], similar to video.
2025-01-11 17:08:25 +08:00
zR
7dc8516bcb add comment as #653 2025-01-11 12:53:32 +08:00
OleehyO
2f275e82b5 Merge remote-tracking branch 'upstream/CogVideoX_dev' into dev 2025-01-11 02:16:09 +00:00
OleehyO
caa24bdc36 feat: add SFT support with ZeRO optimization strategies
- Add SFT (Supervised Fine-Tuning) trainers for all model variants:
  - CogVideoX I2V and T2V
  - CogVideoX-1.5 I2V and T2V
- Add DeepSpeed ZeRO configuration files:
  - ZeRO-2 with and without CPU offload
  - ZeRO-3 with and without CPU offload
- Add base accelerate config for distributed training
- Update trainer.py to support SFT training mode

This enables full-parameter fine-tuning with memory-efficient distributed training using DeepSpeed ZeRO optimization.
2025-01-11 02:13:32 +00:00
OleehyO
e213b6c083 fix: pad latent frames to match patch_size_t requirements 2025-01-11 02:08:07 +00:00
Erfan Asgari
70ca65300c
upgrade to moviepy v2 2025-01-11 00:18:24 +03:30
OleehyO
f6d722cec7 fix: remove copying first video frame as conditioning image 2025-01-09 15:52:51 +00:00
OleehyO
07766001f6 feat(dataset): pad short videos by repeating last frame
When loading videos with fewer frames than max_num_frames, repeat the last
frame to reach the required length instead of failing. This ensures consistent
tensor dimensions across the dataset while preserving as much original video
content as possible.
2025-01-08 02:14:56 +00:00
Yuxuan Zhang
8f1829f1cd
Merge pull request #642 from THUDM/CogVideoX_dev
New Lora 20250108
2025-01-08 09:51:39 +08:00
zR
045e1b308b readme 2025-01-08 09:50:08 +08:00
OleehyO
249fadfb76 docs: add hardware requirements for model training
Add a table in README files showing hardware requirements for training
different CogVideoX models, including:
- Memory requirements for each model variant
- Supported training types (LoRA)
- Training resolutions
- Mixed precision settings

Updated in all language versions (EN/ZH/JA).
2025-01-08 01:39:37 +00:00
OleehyO
10de04fc08 perf: cast VAE and text encoder to target dtype before precomputing cache
Before precomputing the latent cache and text embeddings, cast the VAE and
text encoder to the target training dtype (fp16/bf16) instead of keeping them
in fp32. This reduces memory usage during the precomputation phase.

The change occurs in prepare_dataset() where the models are moved to device
and cast to weight_dtype before being used to generate the cache.
2025-01-08 01:38:13 +00:00
OleehyO
0e21d41b12 Merge remote-tracking branch 'upstream/CogVideoX_dev' into dev 2025-01-07 09:51:48 +00:00
OleehyO
392e37021a Add video path to error message for better debugging 2025-01-07 09:50:21 +00:00
zR
11935892ae remove --image_column 2025-01-07 16:37:11 +08:00
OleehyO
ee1f666206 docs: update READMEs with auto first-frame extraction feature 2025-01-07 06:45:10 +00:00
OleehyO
e084a4a270 feat: auto-extract first frames as conditioning images for i2v model
When training i2v models without specifying image_column, automatically extract
and use first frames from training videos as conditioning images. This includes:

- Add load_images_from_videos() utility function to extract and cache first frames
- Update BaseI2VDataset to support auto-extraction when image_column is None
- Add validation and warning message in Args schema for i2v without image_column

The first frames are extracted once and cached to avoid repeated video loading.
2025-01-07 06:43:26 +00:00
OleehyO
96e511b413 feat: add warning for fp16 mixed precision training 2025-01-07 06:00:38 +00:00
OleehyO
36427274d6 style: format import statements across finetune module 2025-01-07 05:54:52 +00:00
zR
1789f07256 format and check fp16 for cogvideox2b 2025-01-07 13:16:18 +08:00
OleehyO
1b886326b2 Merge remote-tracking branch 'upstream/CogVideoX_dev' into dev 2025-01-06 10:47:56 +00:00
OleehyO
9157e0cbc8 Adapt dataset for text embeddings and add noise padding
- Add text embedding support in dataset collation
- Pad 2 random noise frames at the beginning of latent space during training
2025-01-06 10:44:58 +00:00
OleehyO
49dc370de6 fix: remove pipeline hooks after validation
- Add pipe.remove_all_hooks() after validation to prevent memory leaks
- Clean up validation pipeline properly to avoid potential issues in subsequent training steps
2025-01-04 06:21:17 +00:00
OleehyO
93b906b3fb docs: clarify train_frames includes padding frame
Add docstring to train_frames field in State schema to explicitly indicate
that it includes one image padding frame
2025-01-04 06:20:25 +00:00
OleehyO
7e1ac76847 feat(cogvideox): add prompt embedding caching support
This change enables caching of prompt embeddings in the CogVideoX text-to-video
LoRA trainer, which can improve training efficiency by avoiding redundant text
encoding operations.
2025-01-04 06:17:56 +00:00
OleehyO
66e4ba2592 fix(cogvideox): add prompt embedding caching and fix frame padding
- Add support for cached prompt embeddings in dataset
- Fix bug where first frame wasn't properly padded in latent space
2025-01-04 06:16:42 +00:00
OleehyO
de5bef6611 feat(args): add train_resolution validation for video frames and resolution
- Add validation to ensure (frames - 1) is multiple of 8
- Add specific resolution check (480x720) for cogvideox-5b models
- Add error handling for invalid resolution format
2025-01-04 06:16:42 +00:00
OleehyO
ffb6ee36b4 docs: update finetune documentation in all languages 2025-01-04 06:16:42 +00:00
OleehyO
c817e7f062 chore: update default training parameters for t2v and i2v scripts 2025-01-04 06:16:42 +00:00
OleehyO
e5b8f9a2ee feat: add caching for prompt embeddings
- Add caching for prompt embeddings
- Store cached files using safetensors format
- Add cache directory structure under data_root/cache
- Optimize memory usage by moving tensors to CPU after caching
- Add debug logging for cache hits
- Add info logging for cache writes

The caching system helps reduce redundant computation and memory usage during training by:
1. Caching prompt embeddings based on prompt text hash
2. Caching encoded video latents based on video filename
3. Moving tensors to CPU after caching to free GPU memory
2025-01-04 06:16:31 +00:00
OleehyO
f731c35f70 Add unload_model function 2025-01-03 08:21:27 +00:00
zR
ce2c299c1f Update diffusion_video.py 2025-01-03 08:45:42 +08:00
zR
b080c6a010 put lora back(sat), unavailable running 2025-01-02 11:48:18 +08:00
OleehyO
a88c1ede69 feat(args): add validation for training resolution
- Add validation check to ensure number of frames is multiple of 8
- Add format validation for train_resolution string (frames x height x width)
2025-01-02 03:12:09 +00:00
OleehyO
362b7bf273 docs: update README in multiple languages 2025-01-02 03:07:34 +00:00
Yuxuan Zhang
aa240dc675
Merge pull request #632 from THUDM/CogVideoX_dev
Refactored the training code of finetune
2025-01-02 08:31:25 +08:00
OleehyO
cf2fff7e55 Merge remote-tracking branch 'upstream/main' into dev 2025-01-01 16:03:51 +00:00
OleehyO
7fa1bb48be refactor: remove deprecated training scripts 2025-01-01 15:56:14 +00:00
OleehyO
48ad178818 Reorganize training script arguments 2025-01-01 15:52:39 +00:00
三洋三洋
6ef15dd2a5 docs: update TOC and add friendly link in README files
- Update table of contents in README.md, README_ja.md and README_zh.md
- Add friendly link section to all README files
2025-01-01 15:10:55 +00:00
OleehyO
6e79472417 feat: add training launch scripts for I2V and T2V models
Add two shell scripts to simplify model training:
- accelerate_train_i2v.sh: Launch script for Image-to-Video training
- accelerate_train_t2v.sh: Launch script for Text-to-Video training

Both scripts provide comprehensive configurations for:
- Model settings
- Data pipeline
- Training parameters
- System resources
- Checkpointing
- Validation
2025-01-01 15:10:55 +00:00
OleehyO
26b87cd4ff feat(args): add validation and arg interface for training parameters
- Add field validators for model type and validation settings
- Implement command line argument parsing with argparse
- Add type hints and documentation for training parameters
- Support configuration of model, training, and validation parameters
2025-01-01 15:10:55 +00:00
OleehyO
04a60e7435 Change logger name to trainer 2025-01-01 15:10:55 +00:00
OleehyO
a001842834 feat: implement CogVideoX trainers for I2V and T2V tasks
Add and refactor trainers for CogVideoX model variants:
- Implement CogVideoXT2VLoraTrainer for text-to-video generation
- Refactor CogVideoXI2VLoraTrainer for image-to-video generation

Both trainers support LoRA fine-tuning with proper handling of:
- Model components loading and initialization
- Video encoding and batch collation
- Loss computation with noise prediction
- Validation step for generation
2025-01-01 15:10:54 +00:00
OleehyO
91d79fd9a4 feat: add schemas module for configuration and state management
Add Pydantic models to handle:
- CLI arguments and configuration (Args)
- Model components and pipeline (Components)
- Training state and parameters (State)
2025-01-01 15:10:54 +00:00
OleehyO
45d40450a1 refactor: simplify dataset implementation and add latent precomputation
- Replace bucket-based dataset with simpler resize-based implementation
- Add video latent precomputation during dataset initialization
- Improve code readability and user experience
- Remove complexity of bucket sampling for better maintainability

This change makes the codebase more straightforward and easier to use while
maintaining functionality through resize-based video processing.
2025-01-01 15:10:54 +00:00
OleehyO
6eae5c201e feat: add latent caching for video encodings
- Add caching mechanism to store VAE-encoded video latents to disk
- Cache latents in a "latent" subdirectory alongside video files
- Skip re-encoding when cached latent file exists
- Add logging for successful cache saves
- Minor code cleanup and formatting improvements

This change improves training efficiency by avoiding redundant video encoding operations.
2025-01-01 15:10:42 +00:00
OleehyO
2a6cca0656 Add type conversion and validation checks 2025-01-01 15:10:42 +00:00
OleehyO
fa4659fb2c feat(trainer): add validation functionality to Trainer class
Add validation capabilities to the Trainer class including:
- Support for validating images and videos during training
- Periodic validation based on validation_steps parameter
- Artifact logging to wandb for validation results
- Memory tracking during validation process
2025-01-01 15:10:41 +00:00
OleehyO
6971364591 Export file_utils.py 2025-01-01 15:10:41 +00:00
OleehyO
60f6a3d7ee feat: add base trainer implementation and training script
- Add Trainer base class with core training loop functionality
- Implement distributed training setup with Accelerate
- Add training script with model/trainer initialization
- Support LoRA fine-tuning with checkpointing and validation
2025-01-01 15:10:41 +00:00
OleehyO
a505f2e312 Add constants.py 2025-01-01 15:10:40 +00:00
OleehyO
78f655a9a4 Add utils 2025-01-01 15:10:40 +00:00
OleehyO
85e00a1082 feat(models): add scaffolding 2025-01-01 15:10:40 +00:00
OleehyO
918ebb5a54 feat(datasets): implement video dataset modules
- Add dataset implementations for text-to-video and image-to-video
- Include bucket sampler for efficient batch processing
- Add utility functions for data processing
- Create dataset package structure with proper initialization
2025-01-01 15:10:40 +00:00
OleehyO
e3f6def234 feat: add video frame extraction tool
Add utility script to extract first frames from videos, helping users convert T2V datasets to I2V format
2025-01-01 15:10:39 +00:00
OleehyO
7b282246dd chore: remove unused configuration files after refactoring
Delete accelerate configs, deepspeed config and host file that are no longer needed
2025-01-01 15:10:39 +00:00
OleehyO
5cb9303286 chore: update .gitignore
- Add new ignore patterns for dataset and model directories
- Update rules for development files
2025-01-01 15:10:32 +00:00
OleehyO
ba85627577 [docs] improve help messages in argument parser
Fix and clarify help documentation in parser.add_argument() to better describe command-line arguments.
2025-01-01 15:10:31 +00:00
OleehyO
2508c8353b [bugfix] fix specific resolution setting
Different models use different resolutions, for example, for the CogVideoX1.5 series models, the optimal generation resolution is 1360x768, But for CogVideoX, the best resolution is 720x480.
2025-01-01 15:10:31 +00:00
Gforky
48ac9c1066 [fix]fix typo in train_cogvideox_image_to_video_lora.py 2025-01-01 15:10:30 +00:00
Zheng Guang Cong
21693ca770 fix bugs of image-to-video without image-condition 2025-01-01 15:10:30 +00:00
三洋三洋
a6e611e354 docs: update TOC and add friendly link in README files
- Update table of contents in README.md, README_ja.md and README_zh.md
- Add friendly link section to all README files
2024-12-27 19:37:08 +08:00
Yuxuan.Zhang
7935bd58a1
Merge pull request #615 from THUDM/CogVideoX_dev
Cog video x dev
2024-12-19 12:57:56 +08:00
OleehyO
1811c50e73 [docs] improve help messages in argument parser
Fix and clarify help documentation in parser.add_argument() to better describe command-line arguments.
2024-12-18 12:30:13 +00:00
OleehyO
92a589240f [bugfix] fix specific resolution setting
Different models use different resolutions, for example, for the CogVideoX1.5 series models, the optimal generation resolution is 1360x768, But for CogVideoX, the best resolution is 720x480.
2024-12-18 12:25:43 +00:00
OleehyO
7add8f8437
Merge pull request #607 from THUDM/CogVideoX_dev
Add resolution warning
2024-12-17 09:58:10 +08:00
OleehyO
cfaca91cde Merge remote-tracking branch 'upstream/main' into dev 2024-12-16 11:38:26 +00:00
OleehyO
d3a7d2dc91 Add resolution warning 2024-12-16 11:34:51 +00:00
Yuxuan.Zhang
46098f446b
Merge pull request #603 from Gforky/fix-demo-issue
[fix]fix typo in train_cogvideox_image_to_video_lora.py
2024-12-15 22:00:41 +08:00
Gforky
5a03e6fa79 [fix]fix typo in train_cogvideox_image_to_video_lora.py 2024-12-14 16:12:57 +08:00
Yuxuan.Zhang
1605e95033
Merge pull request #599 from THUDM/CogVideoX_dev
Cog video x dev
2024-12-13 15:03:48 +08:00
OleehyO
7b4c9db6d9 Fix for CogVideoX-{2B,5B}
When loading CogVideX-{2B,5B}, `patch_size_t` is None,
which results in the `prepare_rotary_position_embeddings` function.
2024-12-13 04:02:27 +00:00
OleehyO
36f1333788 Fix for deepspeed training 2024-12-13 04:02:26 +00:00
OleehyO
4d1b9fd166 Fix for Disney video dataset 2024-12-13 04:02:21 +00:00
OleehyO
3ff9d3049d docs: change "read this in English" to "中文阅读"
Update README.md to use Chinese text for language switch link
2024-12-11 05:10:28 +00:00
Yuxuan.Zhang
496e220463
Merge pull request #585 from ZGCTroy/patch-1
fix bugs of image-to-video without image-condition
2024-12-08 19:31:59 +08:00
Zheng Guang Cong
a46d762cd9
fix bugs of image-to-video without image-condition 2024-12-06 20:14:43 +08:00
Yuxuan.Zhang
87ccd38cea
Merge pull request #567 from THUDM/main
New Finetune
2024-12-02 11:30:20 +08:00
Yuxuan.Zhang
5aa6d3a9ee
Merge pull request #515 from Gforky/fix_finetune_demo
[fix]fix deepspeed initialization issue in finetune examples
2024-12-02 11:29:42 +08:00
Yuxuan.Zhang
a094b34425
Merge pull request #565 from THUDM/CogVideoX_dev
Cog video x dev
2024-11-30 12:45:25 +08:00
zR
0fe46df21f new jobs of friendly link 2024-11-30 12:40:07 +08:00
Yuxuan.Zhang
f1a2b48974
Merge pull request #556 from THUDM/main
new announced
2024-11-27 12:11:12 +08:00
Yuxuan.Zhang
d82922cc79
Merge pull request #538 from spacegoing/fix_rope_finetune_shape
[Fix] fix rope temporal patch size
2024-11-23 21:24:39 +08:00
spacegoing
2fb763d25f [Fix] fix rope temporal patch size 2024-11-21 16:26:45 +00:00
luwen.miao
ac2f2c78f7 [fix]fix deepspeed initialization issue in finetune examples 2024-11-18 09:49:31 +00:00
Yuxuan.Zhang
2fdc59c3ce
Merge pull request #507 from THUDM/CogVideoX_dev
diffusers version
2024-11-17 21:54:47 +08:00
zR
17996f11f8 update 2024-11-16 10:06:22 +08:00
Yuxuan.Zhang
5e3e3aabe0
Merge pull request #500 from THUDM/main
Merge
2024-11-13 21:15:49 +08:00
zR
e7a35ea33b update friendly link 2024-11-13 17:06:16 +08:00
zR
cd5ceca22b fix resolution docs 2024-11-12 00:41:23 +08:00
zR
bb2cb130a0 add width and height 2024-11-12 00:17:19 +08:00
zR
2151a3bdfb update with diffusers 2024-11-11 22:41:28 +08:00
zR
68d93ce8fc fix 2024-11-09 22:51:39 +08:00
zR
155456befa update 2024-11-09 22:49:03 +08:00
zR
2475902027 friendly link 2024-11-09 22:43:02 +08:00
zR
fb806eecce update table 2024-11-09 22:29:36 +08:00
zR
c8c7b62aa1 update diffusers code 2024-11-09 22:07:32 +08:00
Yuxuan.Zhang
e2987ff565
Merge pull request #474 from THUDM/CogVideoX_dev
Fix #472 #473
2024-11-09 00:18:01 +08:00
zR
a8205b575d Update cp_enc_dec.py 2024-11-08 23:27:44 +08:00
zR
e7bcecf947 remove wrong fake_cp 2024-11-08 22:54:17 +08:00
zR
d8ee013842 add 10 second comment 2024-11-08 22:31:39 +08:00
zR
e43a7645fd Update autoencoder.py 2024-11-08 21:49:02 +08:00
zR
8e07303dbf Update inference.sh 2024-11-08 21:39:23 +08:00
zR
0c6dc7b5d5 fix #472 2024-11-08 21:37:43 +08:00
Yuxuan.Zhang
74baf5ceef
Merge pull request #470 from THUDM/CogVideoX_dev
rope
2024-11-08 13:56:43 +08:00
zR
2360393b99 3d_rope_pos_embed 2024-11-08 13:53:30 +08:00
Yuxuan.Zhang
ddd3dcd7eb
Merge pull request #469 from THUDM/CogVideoX_dev
CogVideoX1.5-SAT
2024-11-08 13:50:19 +08:00
zR
42417ad3bc 3d_rope_pos_embed 2024-11-08 13:49:58 +08:00
Yuxuan.Zhang
075fad4dae
Merge pull request #465 from DefTruth/main
[Bug] fix parallel rand device & set dynamic cfg
2024-11-08 13:32:57 +08:00
zR
abe334f196 Update test.txt 2024-11-08 12:44:54 +08:00
zR
f1f539a9da remove pic 2024-11-08 12:09:01 +08:00
zR
494296b063 Merge branch 'CogVideoX_dev' of github.com:THUDM/CogVideo into CogVideoX_dev 2024-11-08 11:52:56 +08:00
Yuxuan.Zhang
d1e45fbb86
Merge pull request #468 from THUDM/main
merge
2024-11-07 23:59:21 +08:00
zR
806a7f609f update cogvideox1.5 2024-11-07 23:43:46 +08:00
DefTruth
1e6d1bbb82 fix parallel rand device 2024-11-07 02:40:13 +00:00
zR
0ae12e3ea3 use original up.upsample 2024-11-05 22:06:05 +08:00
zR
4a3035d64e update 1105 sst test code with fake cp 2024-11-05 12:55:54 +08:00
zR
3a9af5bdd9 update with test code 2024-11-04 14:34:36 +08:00
177 changed files with 8887 additions and 5250 deletions

View File

@ -0,0 +1,28 @@
# Contribution Guide
We welcome your contributions to this repository. To ensure elegant code style and better code quality, we have prepared the following contribution guidelines.
## What We Accept
+ This PR fixes a typo or improves the documentation (if this is the case, you may skip the other checks).
+ This PR fixes a specific issue — please reference the issue number in the PR description. Make sure your code strictly follows the coding standards below.
+ This PR introduces a new feature — please clearly explain the necessity and implementation of the feature. Make sure your code strictly follows the coding standards below.
## Code Style Guide
Good code style is an art. We have prepared a `pyproject.toml` and a `pre-commit` hook to enforce consistent code formatting across the project. You can clean up your code following the steps below:
1. Install the required dependencies:
```shell
pip install ruff pre-commit
```
2. Then, run the following command:
```shell
pre-commit run --all-files
```
If your code complies with the standards, you should not see any errors.
## Naming Conventions
- Please use **English** for naming; do not use Pinyin or other languages. All comments should also be in English.
- Follow **PEP8** naming conventions strictly, and use underscores to separate words. Avoid meaningless names such as `a`, `b`, `c`.

View File

@ -1,34 +0,0 @@
# Raise valuable PR / 提出有价值的PR
## Caution / 注意事项:
Users should keep the following points in mind when submitting PRs:
1. Ensure that your code meets the requirements in the [specification](../../resources/contribute.md).
2. the proposed PR should be relevant, if there are multiple ideas and optimizations, they should be assigned to different PRs.
用户在提交PR时候应该注意以下几点:
1. 确保您的代码符合 [规范](../../resources/contribute_zh.md) 中的要求。
2. 提出的PR应该具有针对性,如果具有多个不同的想法和优化方案,应该分配到不同的PR中。
## 不应该提出的PR / PRs that should not be proposed
If a developer proposes a PR about any of the following, it may be closed or Rejected.
1. those that don't describe improvement options.
2. multiple issues of different types combined in one PR.
3. The proposed PR is highly duplicative of already existing PRs.
如果开发者提出关于以下方面的PR,则可能会被直接关闭或拒绝通过。
1. 没有说明改进方案的。
2. 多个不同类型的问题合并在一个PR中的。
3. 提出的PR与已经存在的PR高度重复的。
# 检查您的PR
- [ ] Have you read the Contributor Guidelines, Pull Request section? / 您是否阅读了贡献者指南、Pull Request 部分?
- [ ] Has this been discussed/approved via a Github issue or forum? If so, add a link. / 是否通过 Github 问题或论坛讨论/批准过?如果是,请添加链接。
- [ ] Did you make sure you updated the documentation with your changes? Here are the Documentation Guidelines, and here are the Documentation Formatting Tips. /您是否确保根据您的更改更新了文档?这里是文档指南,这里是文档格式化技巧。
- [ ] Did you write new required tests? / 您是否编写了新的必要测试?
- [ ] Are your PRs for only one issue? / 您的PR是否仅针对一个问题?

27
.gitignore vendored
View File

@ -8,3 +8,30 @@ logs/
.idea
output*
test*
venv
**/.swp
**/*.log
**/*.debug
**/.vscode
**/*debug*
**/.gitignore
**/finetune/*-lora-*
**/finetune/Disney-*
**/wandb
**/results
**/*.mp4
**/validation_set
CogVideo-1.0
**/**foo**
**/sat
**/train_results
**/train_res*
**/**foo**
**/sat
**/train_results
**/train_res*
**/uv.lock

19
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,19 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.5
hooks:
- id: ruff
args: [--fix, --respect-gitignore, --config=pyproject.toml]
- id: ruff-format
args: [--config=pyproject.toml]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-toml
- id: check-case-conflict
- id: check-merge-conflict
- id: debug-statements

130
README.md
View File

@ -22,7 +22,14 @@ Experience the CogVideoX-5B model online at <a href="https://huggingface.co/spac
## Project Updates
- 🔥🔥 **News**: ```2024/10/13```: A more cost-effective fine-tuning framework for `CogVideoX-5B` that works with a single
- 🔥🔥 **News**: ```2025/03/24```: We have launched [CogKit](https://github.com/THUDM/CogKit), a fine-tuning and inference framework for the **CogView4** and **CogVideoX** series. This toolkit allows you to fully explore and utilize our multimodal generation models.
- 🔥 **News**: ```2025/02/28```: DDIM Inverse is now supported in `CogVideoX-5B` and `CogVideoX1.5-5B`. Check [here](inference/ddim_inversion.py).
- 🔥 **News**: ```2025/01/08```: We have updated the code for `Lora` fine-tuning based on the `diffusers` version model, which uses less GPU memory. For more details, please see [here](finetune/README.md).
- 🔥 **News**: ```2024/11/15```: We released the `CogVideoX1.5` model in the diffusers version. Only minor parameter adjustments are needed to continue using previous code.
- 🔥 **News**: ```2024/11/08```: We have released the CogVideoX1.5 model. CogVideoX1.5 is an upgraded version of the open-source model CogVideoX.
The CogVideoX1.5-5B series supports 10-second videos with higher resolution, and CogVideoX1.5-5B-I2V supports video generation at any resolution.
The SAT code has already been updated, while the diffusers version is still under adaptation. Download the SAT version code [here](https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT).
- 🔥 **News**: ```2024/10/13```: A more cost-effective fine-tuning framework for `CogVideoX-5B` that works with a single
4090 GPU, [cogvideox-factory](https://github.com/a-r-r-o-w/cogvideox-factory), has been released. It supports
fine-tuning with multiple resolutions. Feel free to use it!
- 🔥 **News**: ```2024/10/10```: We have updated our technical report. Please
@ -40,11 +47,11 @@ Experience the CogVideoX-5B model online at <a href="https://huggingface.co/spac
model [CogVLM2-Caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption), used in the training process of
CogVideoX to convert video data into text descriptions, has been open-sourced. Welcome to download and use it.
- 🔥 ```2024/8/27```: We have open-sourced a larger model in the CogVideoX series, **CogVideoX-5B**. We have
significantly optimized the model's inference performance, greatly lowering the inference threshold. You can run *
*CogVideoX-2B** on older GPUs like `GTX 1080TI`, and **CogVideoX-5B** on desktop GPUs like `RTX 3060`. Please strictly
significantly optimized the model's inference performance, greatly lowering the inference threshold.
You can run **CogVideoX-2B** on older GPUs like `GTX 1080TI`, and **CogVideoX-5B** on desktop GPUs like `RTX 3060`. Please strictly
follow the [requirements](requirements.txt) to update and install dependencies, and refer
to [cli_demo](inference/cli_demo.py) for inference code. Additionally, the open-source license for the **CogVideoX-2B
** model has been changed to the **Apache 2.0 License**.
to [cli_demo](inference/cli_demo.py) for inference code. Additionally, the open-source license for
the **CogVideoX-2B** model has been changed to the **Apache 2.0 License**.
- 🔥 ```2024/8/6```: We have open-sourced **3D Causal VAE**, used for **CogVideoX-2B**, which can reconstruct videos with
almost no loss.
- 🔥 ```2024/8/6```: We have open-sourced the first model of the CogVideoX series video generation models, **CogVideoX-2B
@ -57,19 +64,24 @@ Experience the CogVideoX-5B model online at <a href="https://huggingface.co/spac
Jump to a specific section:
- [Quick Start](#Quick-Start)
- [Quick Start](#quick-start)
- [Prompt Optimization](#prompt-optimization)
- [SAT](#sat)
- [Diffusers](#Diffusers)
- [CogVideoX-2B Video Works](#cogvideox-2b-gallery)
- [Introduction to the CogVideoX Model](#Model-Introduction)
- [Full Project Structure](#project-structure)
- [Diffusers](#diffusers)
- [Gallery](#gallery)
- [CogVideoX-5B](#cogvideox-5b)
- [CogVideoX-2B](#cogvideox-2b)
- [Model Introduction](#model-introduction)
- [Friendly Links](#friendly-links)
- [Project Structure](#project-structure)
- [Quick Start with Colab](#quick-start-with-colab)
- [Inference](#inference)
- [SAT](#sat)
- [finetune](#finetune)
- [sat](#sat-1)
- [Tools](#tools)
- [Introduction to CogVideo(ICLR'23) Model](#cogvideoiclr23)
- [Citations](#Citation)
- [Open Source Project Plan](#Open-Source-Project-Plan)
- [Model License](#Model-License)
- [CogVideo(ICLR'23)](#cogvideoiclr23)
- [Citation](#citation)
- [Model-License](#model-license)
## Quick Start
@ -169,82 +181,93 @@ models we currently offer, along with their foundational information.
<table style="border-collapse: collapse; width: 100%;">
<tr>
<th style="text-align: center;">Model Name</th>
<th style="text-align: center;">CogVideoX1.5-5B (Latest)</th>
<th style="text-align: center;">CogVideoX1.5-5B-I2V (Latest)</th>
<th style="text-align: center;">CogVideoX-2B</th>
<th style="text-align: center;">CogVideoX-5B</th>
<th style="text-align: center;">CogVideoX-5B-I2V</th>
</tr>
<tr>
<td style="text-align: center;">Model Description</td>
<td style="text-align: center;">Entry-level model, balancing compatibility. Low cost for running and secondary development.</td>
<td style="text-align: center;">Larger model with higher video generation quality and better visual effects.</td>
<td style="text-align: center;">CogVideoX-5B image-to-video version.</td>
<td style="text-align: center;">Release Date</td>
<th style="text-align: center;">November 8, 2024</th>
<th style="text-align: center;">November 8, 2024</th>
<th style="text-align: center;">August 6, 2024</th>
<th style="text-align: center;">August 27, 2024</th>
<th style="text-align: center;">September 19, 2024</th>
</tr>
<tr>
<td style="text-align: center;">Video Resolution</td>
<td colspan="1" style="text-align: center;">1360 * 768</td>
<td colspan="1" style="text-align: center;"> Min(W, H) = 768 <br> 768 ≤ Max(W, H) ≤ 1360 <br> Max(W, H) % 16 = 0 </td>
<td colspan="3" style="text-align: center;">720 * 480</td>
</tr>
<tr>
<td style="text-align: center;">Number of Frames</td>
<td colspan="2" style="text-align: center;">Should be <b>16N + 1</b> where N <= 10 (default 81)</td>
<td colspan="3" style="text-align: center;">Should be <b>8N + 1</b> where N <= 6 (default 49)</td>
</tr>
<tr>
<td style="text-align: center;">Inference Precision</td>
<td style="text-align: center;"><b>FP16*(recommended)</b>, BF16, FP32, FP8*, INT8, not supported: INT4</td>
<td colspan="2" style="text-align: center;"><b>BF16 (recommended)</b>, FP16, FP32, FP8*, INT8, not supported: INT4</td>
<td colspan="2" style="text-align: center;"><b>BF16 (Recommended)</b>, FP16, FP32, FP8*, INT8, Not supported: INT4</td>
<td style="text-align: center;"><b>FP16*(Recommended)</b>, BF16, FP32, FP8*, INT8, Not supported: INT4</td>
<td colspan="2" style="text-align: center;"><b>BF16 (Recommended)</b>, FP16, FP32, FP8*, INT8, Not supported: INT4</td>
</tr>
<tr>
<td style="text-align: center;">Single GPU Memory Usage<br></td>
<td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB <br><b>diffusers FP16: from 4GB* </b><br><b>diffusers INT8 (torchao): from 3.6GB*</b></td>
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB <br><b>diffusers BF16: from 5GB* </b><br><b>diffusers INT8 (torchao): from 4.4GB*</b></td>
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 76GB <br><b>diffusers BF16: from 10GB*</b><br><b>diffusers INT8(torchao): from 7GB*</b></td>
<td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB <br><b>diffusers FP16: 4GB minimum* </b><br><b>diffusers INT8 (torchao): 3.6GB minimum*</b></td>
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB <br><b>diffusers BF16 : 5GB minimum* </b><br><b>diffusers INT8 (torchao): 4.4GB minimum* </b></td>
</tr>
<tr>
<td style="text-align: center;">Multi-GPU Inference Memory Usage</td>
<td style="text-align: center;">Multi-GPU Memory Usage</td>
<td colspan="2" style="text-align: center;"><b>BF16: 24GB* using diffusers</b><br></td>
<td style="text-align: center;"><b>FP16: 10GB* using diffusers</b><br></td>
<td colspan="2" style="text-align: center;"><b>BF16: 15GB* using diffusers</b><br></td>
</tr>
<tr>
<td style="text-align: center;">Inference Speed<br>(Step = 50, FP/BF16)</td>
<td colspan="2" style="text-align: center;">Single A100: ~1000 seconds (5-second video)<br>Single H100: ~550 seconds (5-second video)</td>
<td style="text-align: center;">Single A100: ~90 seconds<br>Single H100: ~45 seconds</td>
<td colspan="2" style="text-align: center;">Single A100: ~180 seconds<br>Single H100: ~90 seconds</td>
</tr>
<tr>
<td style="text-align: center;">Fine-tuning Precision</td>
<td style="text-align: center;"><b>FP16</b></td>
<td colspan="2" style="text-align: center;"><b>BF16</b></td>
</tr>
<tr>
<td style="text-align: center;">Fine-tuning Memory Usage</td>
<td style="text-align: center;">47 GB (bs=1, LORA)<br> 61 GB (bs=2, LORA)<br> 62GB (bs=1, SFT)</td>
<td style="text-align: center;">63 GB (bs=1, LORA)<br> 80 GB (bs=2, LORA)<br> 75GB (bs=1, SFT)<br></td>
<td style="text-align: center;">78 GB (bs=1, LORA)<br> 75GB (bs=1, SFT, 16GPU)<br></td>
</tr>
<tr>
<td style="text-align: center;">Prompt Language</td>
<td colspan="3" style="text-align: center;">English*</td>
<td colspan="5" style="text-align: center;">English*</td>
</tr>
<tr>
<td style="text-align: center;">Maximum Prompt Length</td>
<td style="text-align: center;">Prompt Token Limit</td>
<td colspan="2" style="text-align: center;">224 Tokens</td>
<td colspan="3" style="text-align: center;">226 Tokens</td>
</tr>
<tr>
<td style="text-align: center;">Video Length</td>
<td colspan="3" style="text-align: center;">6 Seconds</td>
<td colspan="2" style="text-align: center;">5 seconds or 10 seconds</td>
<td colspan="3" style="text-align: center;">6 seconds</td>
</tr>
<tr>
<td style="text-align: center;">Frame Rate</td>
<td colspan="3" style="text-align: center;">8 Frames / Second</td>
</tr>
<tr>
<td style="text-align: center;">Video Resolution</td>
<td colspan="3" style="text-align: center;">720 x 480, no support for other resolutions (including fine-tuning)</td>
<td colspan="2" style="text-align: center;">16 frames / second </td>
<td colspan="3" style="text-align: center;">8 frames / second </td>
</tr>
<tr>
<td style="text-align: center;">Position Encoding</td>
<td colspan="2" style="text-align: center;">3d_rope_pos_embed</td>
<td style="text-align: center;">3d_sincos_pos_embed</td>
<td style="text-align: center;">3d_sincos_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
</tr>
<tr>
<td style="text-align: center;">Download Link (Diffusers)</td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5B">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5B">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5B">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5B-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5B-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5B-I2V">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-2b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-2b">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b-I2V">🟣 WiseModel</a></td>
</tr>
<tr>
<td style="text-align: center;">Download Link (SAT)</td>
<td colspan="3" style="text-align: center;"><a href="./sat/README.md">SAT</a></td>
<td colspan="2" style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5b-SAT">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🟣 WiseModel</a></td>
<td colspan="3" style="text-align: center;"><a href="./sat/README_zh.md">SAT</a></td>
</tr>
</table>
@ -271,21 +294,20 @@ pipe.vae.enable_tiling()
used to quantize the text encoder, transformer, and VAE modules to reduce the memory requirements of CogVideoX. This
allows the model to run on free T4 Colabs or GPUs with smaller memory! Also, note that TorchAO quantization is fully
compatible with `torch.compile`, which can significantly improve inference speed. FP8 precision must be used on
devices with NVIDIA H100 and above, requiring source installation of `torch`, `torchao`, `diffusers`, and `accelerate`
Python packages. CUDA 12.4 is recommended.
devices with NVIDIA H100 and above, requiring source installation of the `torch` and `torchao` Python packages. CUDA 12.4 is recommended.
+ The inference speed tests also used the above memory optimization scheme. Without memory optimization, inference speed
increases by about 10%. Only the `diffusers` version of the model supports quantization.
+ The model only supports English input; other languages can be translated into English for use via large model
refinement.
+ The memory usage of model fine-tuning is tested in an `8 * H100` environment, and the program automatically
uses `Zero 2` optimization. If a specific number of GPUs is marked in the table, that number or more GPUs must be used
for fine-tuning.
## Friendly Links
We highly welcome contributions from the community and actively contribute to the open-source community. The following
works have already been adapted for CogVideoX, and we invite everyone to use them:
+ [RIFLEx-CogVideoX](https://github.com/thu-ml/RIFLEx)
RIFLEx extends the video length with just one line of code: `freq[k-1] = (2 * np.pi) / (L * s)`. The framework not only supports training-free inference, but also offers models fine-tuned based on CogVideoX. By fine-tuning the model for just 1,000 steps on original-length videos, RIFLEx significantly enhances its length extrapolation capability.
+ [CogVideoX-Fun](https://github.com/aigc-apps/CogVideoX-Fun): CogVideoX-Fun is a modified pipeline based on the
CogVideoX architecture, supporting flexible resolutions and multiple launch methods.
+ [CogStudio](https://github.com/pinokiofactory/cogstudio): A separate repository for CogVideo's Gradio Web UI, which
@ -312,6 +334,10 @@ works have already been adapted for CogVideoX, and we invite everyone to use the
+ [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): DiffSynth Studio is a diffusion engine. It has
restructured the architecture, including text encoders, UNet, VAE, etc., enhancing computational performance while
maintaining compatibility with open-source community models. The framework has been adapted for CogVideoX.
+ [CogVideoX-Controlnet](https://github.com/TheDenk/cogvideox-controlnet): A simple ControlNet module code that includes the CogVideoX model.
+ [VideoTuna](https://github.com/VideoVerses/VideoTuna): VideoTuna is the first repo that integrates multiple AI video generation models for text-to-video, image-to-video, text-to-image generation.
+ [ConsisID](https://github.com/PKU-YuanGroup/ConsisID): An identity-preserving text-to-video generation model, based on CogVideoX-5B, which keeps the face consistent in the generated video through frequency decomposition.
+ [A Step by Step Tutorial](https://www.youtube.com/watch?v=5UCkMzP2VLE&ab_channel=SECourses): A step-by-step guide on installing and optimizing the CogVideoX1.5-5B-I2V model in Windows and cloud environments. Special thanks to the [FurkanGozukara](https://github.com/FurkanGozukara) for his effort and support!
## Project Structure
@ -420,9 +446,7 @@ hands-on practice on text-to-video generation. *The original input is in Chinese
}
```
We welcome your contributions! You can click [here](resources/contribute.md) for more information.
## License Agreement
## Model-License
The code in this repository is released under the [Apache 2.0 License](LICENSE).

View File

@ -1,6 +1,6 @@
# CogVideo & CogVideoX
[Read this in English](./README_zh.md)
[Read this in English](./README.md)
[中文阅读](./README_zh.md)
@ -21,10 +21,18 @@
</p>
## 更新とニュース
- 🔥🔥 **ニュース**: ```2024/10/13```: コスト削減のため、単一の4090 GPUで`CogVideoX-5B`
- 🔥🔥 ```2025/03/24```: [CogKit](https://github.com/THUDM/CogKit) は **CogView4** および **CogVideoX** シリーズの微調整と推論のためのフレームワークです。このツールキットを活用することで、私たちのマルチモーダル生成モデルを最大限に活用できます。
- **ニュース**: ```2025/02/28```: DDIM Inverse が `CogVideoX-5B` と `CogVideoX1.5-5B` でサポートされました。詳細は [こちら](inference/ddim_inversion.py) をご覧ください。
- **ニュース**: ```2025/01/08```: 私たちは`diffusers`バージョンのモデルをベースにした`Lora`微調整用のコードを更新しました。より少ないVRAMビデオメモリで動作します。詳細については[こちら](finetune/README_ja.md)をご覧ください。
- **ニュース**: ```2024/11/15```: `CogVideoX1.5` モデルのdiffusersバージョンをリリースしました。わずかなパラメータ調整で以前のコードをそのまま利用可能です。
- **ニュース**: ```2024/11/08```: `CogVideoX1.5` モデルをリリースしました。CogVideoX1.5 は CogVideoX オープンソースモデルのアップグレードバージョンです。
CogVideoX1.5-5B シリーズモデルは、10秒 長の動画とより高い解像度をサポートしており、`CogVideoX1.5-5B-I2V` は任意の解像度での動画生成に対応しています。
SAT コードはすでに更新されており、`diffusers` バージョンは現在適応中です。
SAT バージョンのコードは [こちら](https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT) からダウンロードできます。
- 🔥 **ニュース**: ```2024/10/13```: コスト削減のため、単一の4090 GPUで`CogVideoX-5B`
を微調整できるフレームワーク [cogvideox-factory](https://github.com/a-r-r-o-w/cogvideox-factory)
がリリースされました。複数の解像度での微調整に対応しています。ぜひご利用ください!- 🔥**ニュース**: ```2024/10/10```:
がリリースされました。複数の解像度での微調整に対応しています。ぜひご利用ください!
- 🔥**ニュース**: ```2024/10/10```:
技術報告書を更新し、より詳細なトレーニング情報とデモを追加しました。
- 🔥 **ニュース**: ```2024/10/10```: 技術報告書を更新しました。[こちら](https://arxiv.org/pdf/2408.06072)
をクリックしてご覧ください。さらにトレーニングの詳細とデモを追加しました。デモを見るには[こちら](https://yzy-thu.github.io/CogVideoX-demo/)
@ -34,7 +42,7 @@
- 🔥**ニュース**: ```2024/9/19```: CogVideoXシリーズの画像生成ビデオモデル **CogVideoX-5B-I2V**
をオープンソース化しました。このモデルは、画像を背景入力として使用し、プロンプトワードと組み合わせてビデオを生成することができ、より高い制御性を提供します。これにより、CogVideoXシリーズのモデルは、テキストからビデオ生成、ビデオの継続、画像からビデオ生成の3つのタスクをサポートするようになりました。オンラインでの[体験](https://huggingface.co/spaces/THUDM/CogVideoX-5B-Space)
をお楽しみください。
- 🔥🔥 **ニュース**: ```2024/9/19```:
- 🔥 **ニュース**: ```2024/9/19```:
CogVideoXのトレーニングプロセスでビデオデータをテキスト記述に変換するために使用されるキャプションモデル [CogVLM2-Caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption)
をオープンソース化しました。ダウンロードしてご利用ください。
- 🔥 ```2024/8/27```: CogVideoXシリーズのより大きなモデル **CogVideoX-5B**
@ -56,18 +64,23 @@
特定のセクションにジャンプ:
- [クイックスタート](#クイックスタート)
- [プロンプトの最適化](#プロンプトの最適化)
- [SAT](#sat)
- [Diffusers](#Diffusers)
- [CogVideoX-2B ギャラリー](#CogVideoX-2B-ギャラリー)
- [Diffusers](#diffusers)
- [Gallery](#gallery)
- [CogVideoX-5B](#cogvideox-5b)
- [CogVideoX-2B](#cogvideox-2b)
- [モデル紹介](#モデル紹介)
- [友好的リンク](#友好的リンク)
- [プロジェクト構造](#プロジェクト構造)
- [推論](#推論)
- [sat](#sat)
- [Colabでのクイックスタート](#colabでのクイックスタート)
- [Inference](#inference)
- [finetune](#finetune)
- [sat](#sat-1)
- [ツール](#ツール)
- [プロジェクト計画](#プロジェクト計画)
- [モデルライセンス](#モデルライセンス)
- [CogVideo(ICLR'23)モデル紹介](#CogVideoICLR23)
- [CogVideo(ICLR'23)](#cogvideoiclr23)
- [引用](#引用)
- [ライセンス契約](#ライセンス契約)
## クイックスタート
@ -159,76 +172,93 @@ CogVideoXは、[清影](https://chatglm.cn/video?fr=osm_cogvideox) と同源の
<table style="border-collapse: collapse; width: 100%;">
<tr>
<th style="text-align: center;">モデル名</th>
<th style="text-align: center;">CogVideoX1.5-5B (最新)</th>
<th style="text-align: center;">CogVideoX1.5-5B-I2V (最新)</th>
<th style="text-align: center;">CogVideoX-2B</th>
<th style="text-align: center;">CogVideoX-5B</th>
<th style="text-align: center;">CogVideoX-5B-I2V </th>
<th style="text-align: center;">CogVideoX-5B-I2V</th>
</tr>
<tr>
<td style="text-align: center;">公開日</td>
<th style="text-align: center;">2024年11月8日</th>
<th style="text-align: center;">2024年11月8日</th>
<th style="text-align: center;">2024年8月6日</th>
<th style="text-align: center;">2024年8月27日</th>
<th style="text-align: center;">2024年9月19日</th>
</tr>
<tr>
<td style="text-align: center;">ビデオ解像度</td>
<td colspan="1" style="text-align: center;">1360 * 768</td>
<td colspan="1" style="text-align: center;"> Min(W, H) = 768 <br> 768 ≤ Max(W, H) ≤ 1360 <br> Max(W, H) % 16 = 0 </td>
<td colspan="3" style="text-align: center;">720 * 480</td>
</tr>
<tr>
<td style="text-align: center;">フレーム数</td>
<td colspan="2" style="text-align: center;"><b>16N + 1</b> (N <= 10) である必要があります (デフォルト 81)</td>
<td colspan="3" style="text-align: center;"><b>8N + 1</b> (N <= 6) である必要があります (デフォルト 49)</td>
</tr>
<tr>
<td style="text-align: center;">推論精度</td>
<td style="text-align: center;"><b>FP16*(推奨)</b>, BF16, FP32, FP8*, INT8, INT4は非対応</td>
<td colspan="2" style="text-align: center;"><b>BF16(推奨)</b>, FP16, FP32, FP8*, INT8, INT4は非対応</td>
<td colspan="2" style="text-align: center;"><b>BF16(推奨)</b>, FP16, FP32FP8*INT8INT4非対応</td>
<td style="text-align: center;"><b>FP16*(推奨)</b>, BF16, FP32FP8*INT8INT4非対応</td>
<td colspan="2" style="text-align: center;"><b>BF16(推奨)</b>, FP16, FP32FP8*INT8INT4非対応</td>
</tr>
<tr>
<td style="text-align: center;">単一GPUのメモリ消費<br></td>
<td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB <br><b>diffusers FP16: 4GBから* </b><br><b>diffusers INT8(torchao): 3.6GBから*</b></td>
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB <br><b>diffusers BF16 : 5GBから* </b><br><b>diffusers INT8(torchao): 4.4GBから* </b></td>
<td style="text-align: center;">単一GPUメモリ消費量<br></td>
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 76GB <br><b>diffusers BF1610GBから*</b><br><b>diffusers INT8(torchao)7GBから*</b></td>
<td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB <br><b>diffusers FP16: 4GB以上* </b><br><b>diffusers INT8(torchao): 3.6GB以上*</b></td>
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB <br><b>diffusers BF16 : 5GB以上* </b><br><b>diffusers INT8(torchao): 4.4GB以上* </b></td>
</tr>
<tr>
<td style="text-align: center;">マルチGPUのメモリ消費</td>
<td style="text-align: center;"><b>FP16: 10GB* using diffusers</b><br></td>
<td colspan="2" style="text-align: center;"><b>BF16: 15GB* using diffusers</b><br></td>
<td style="text-align: center;">複数GPU推論メモリ消費量</td>
<td colspan="2" style="text-align: center;"><b>BF16: 24GB* using diffusers</b><br></td>
<td style="text-align: center;"><b>FP16: 10GB* diffusers使用</b><br></td>
<td colspan="2" style="text-align: center;"><b>BF16: 15GB* diffusers使用</b><br></td>
</tr>
<tr>
<td style="text-align: center;">推論速度<br>(ステップ = 50, FP/BF16)</td>
<td style="text-align: center;">単一A100: 約90秒<br>単一H100: 約45秒</td>
<td colspan="2" style="text-align: center;">単一A100: 約180秒<br>単一H100: 約90秒</td>
</tr>
<tr>
<td style="text-align: center;">ファインチューニング精度</td>
<td style="text-align: center;"><b>FP16</b></td>
<td colspan="2" style="text-align: center;"><b>BF16</b></td>
</tr>
<tr>
<td style="text-align: center;">ファインチューニング時のメモリ消費</td>
<td style="text-align: center;">47 GB (bs=1, LORA)<br> 61 GB (bs=2, LORA)<br> 62GB (bs=1, SFT)</td>
<td style="text-align: center;">63 GB (bs=1, LORA)<br> 80 GB (bs=2, LORA)<br> 75GB (bs=1, SFT)<br></td>
<td style="text-align: center;">78 GB (bs=1, LORA)<br> 75GB (bs=1, SFT, 16GPU)<br></td>
<td style="text-align: center;">推論速度<br>(Step = 50, FP/BF16)</td>
<td colspan="2" style="text-align: center;">シングルA100: ~1000秒(5秒ビデオ)<br>シングルH100: ~550秒(5秒ビデオ)</td>
<td style="text-align: center;">シングルA100: ~90秒<br>シングルH100: ~45秒</td>
<td colspan="2" style="text-align: center;">シングルA100: ~180秒<br>シングルH100: ~90秒</td>
</tr>
<tr>
<td style="text-align: center;">プロンプト言語</td>
<td colspan="3" style="text-align: center;">英語*</td>
<td colspan="5" style="text-align: center;">英語*</td>
</tr>
<tr>
<td style="text-align: center;">プロンプトの最大トークン数</td>
<td style="text-align: center;">プロンプト長さの上限</td>
<td colspan="2" style="text-align: center;">224トークン</td>
<td colspan="3" style="text-align: center;">226トークン</td>
</tr>
<tr>
<td style="text-align: center;">ビデオの長さ</td>
<td style="text-align: center;">ビデオ長さ</td>
<td colspan="2" style="text-align: center;">5秒または10秒</td>
<td colspan="3" style="text-align: center;">6秒</td>
</tr>
<tr>
<td style="text-align: center;">フレームレート</td>
<td colspan="2" style="text-align: center;">16フレーム/秒</td>
<td colspan="3" style="text-align: center;">8フレーム/秒</td>
</tr>
<tr>
<td style="text-align: center;">ビデオ解像度</td>
<td colspan="3" style="text-align: center;">720 * 480、他の解像度は非対応(ファインチューニング含む)</td>
</tr>
<tr>
<td style="text-align: center;">位置エンコーディング</td>
<td colspan="2" style="text-align: center;">3d_rope_pos_embed</td>
<td style="text-align: center;">3d_sincos_pos_embed</td>
<td style="text-align: center;">3d_sincos_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
</tr>
<tr>
<td style="text-align: center;">ダウンロードリンク (Diffusers)</td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5B">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5B">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5B">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5B-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5B-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5B-I2V">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-2b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-2b">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b-I2V">🟣 WiseModel</a></td>
</tr>
<tr>
<td style="text-align: center;">ダウンロードリンク (SAT)</td>
<td colspan="3" style="text-align: center;"><a href="./sat/README_ja.md">SAT</a></td>
<td colspan="2" style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5b-SAT">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🟣 WiseModel</a></td>
<td colspan="3" style="text-align: center;"><a href="./sat/README_zh.md">SAT</a></td>
</tr>
</table>
@ -253,18 +283,18 @@ pipe.vae.enable_tiling()
は、CogVideoXのメモリ要件を削減するためにテキストエンコーダ、トランスフォーマ、およびVAEモジュールを量子化するために使用できます。これにより、無料のT4
Colabやより少ないメモリのGPUでモデルを実行することが可能になります。同様に重要なのは、TorchAOの量子化は`torch.compile`
と完全に互換性があり、推論速度を大幅に向上させることができる点です。`NVIDIA H100`およびそれ以上のデバイスでは`FP8`
精度を使用する必要があります。これには、`torch``torchao``diffusers``accelerate`
Pythonパッケージのソースコードからのインストールが必要です。`CUDA 12.4`の使用をお勧めします。
精度を使用する必要があります。これには、`torch``torchao` Pythonパッケージのソースコードからのインストールが必要です。`CUDA 12.4`の使用をお勧めします。
+ 推論速度テストも同様に、上記のメモリ最適化方法を使用しています。メモリ最適化を使用しない場合、推論速度は約10%向上します。
`diffusers`バージョンのモデルのみが量子化をサポートしています。
+ モデルは英語入力のみをサポートしており、他の言語は大規模モデルの改善を通じて英語に翻訳できます。
+ モデルのファインチューニングに使用されるメモリは`8 * H100`環境でテストされています。プログラムは自動的に`Zero 2`
最適化を使用しています。表に具体的なGPU数が記載されている場合、ファインチューニングにはその数以上のGPUが必要です。
## 友好的リンク
コミュニティからの貢献を大歓迎し、私たちもオープンソースコミュニティに積極的に貢献しています。以下の作品はすでにCogVideoXに対応しており、ぜひご利用ください
+ [RIFLEx-CogVideoX](https://github.com/thu-ml/RIFLEx)
RIFLExは動画の長さを外挿する手法で、たった1行のコードで動画の長さを元の2倍に延長できます。RIFLExはトレーニング不要の推論をサポートするだけでなく、CogVideoXをベースにファインチューニングしたモデルも提供しています。元の長さの動画でわずか1000ステップのファインチューニングを行うだけで、長さ外挿能力を大幅に向上させることができます。
+ [CogVideoX-Fun](https://github.com/aigc-apps/CogVideoX-Fun):
CogVideoX-Funは、CogVideoXアーキテクチャを基にした改良パイプラインで、自由な解像度と複数の起動方法をサポートしています。
+ [CogStudio](https://github.com/pinokiofactory/cogstudio): CogVideo の Gradio Web UI の別のリポジトリ。より高機能な Web
@ -284,6 +314,10 @@ pipe.vae.enable_tiling()
キーフレーム補間生成において、より大きな柔軟性を提供することを目的とした、CogVideoX構造を基にした修正版のパイプライン。
+ [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): DiffSynth
Studioは、拡散エンジンです。テキストエンコーダー、UNet、VAEなどを含むアーキテクチャを再構築し、オープンソースコミュニティモデルとの互換性を維持しつつ、計算性能を向上させました。このフレームワークはCogVideoXに適応しています。
+ [CogVideoX-Controlnet](https://github.com/TheDenk/cogvideox-controlnet): CogVideoXモデルを含むシンプルなControlNetモジュールのコード。
+ [VideoTuna](https://github.com/VideoVerses/VideoTuna): VideoTuna は、テキストからビデオ、画像からビデオ、テキストから画像生成のための複数のAIビデオ生成モデルを統合した最初のリポジトリです。
+ [ConsisID](https://github.com/PKU-YuanGroup/ConsisID): 一貫性のある顔を保持するために、周波数分解を使用するCogVideoX-5Bに基づいたアイデンティティ保持型テキストから動画生成モデル。
+ [ステップバイステップチュートリアル](https://www.youtube.com/watch?v=5UCkMzP2VLE&ab_channel=SECourses): WindowsおよびクラウドでのCogVideoX1.5-5B-I2Vモデルのインストールと最適化に関するステップバイステップガイド。[FurkanGozukara](https://github.com/FurkanGozukara)氏の尽力とサポートに感謝いたします!
## プロジェクト構造
@ -386,8 +420,6 @@ CogVideoのデモは [https://models.aminer.cn/cogvideo](https://models.aminer.c
}
```
あなたの貢献をお待ちしています!詳細は[こちら](resources/contribute_ja.md)をクリックしてください。
## ライセンス契約
このリポジトリのコードは [Apache 2.0 License](LICENSE) の下で公開されています。

View File

@ -1,10 +1,9 @@
# CogVideo & CogVideoX
[Read this in English](./README_zh.md)
[Read this in English](./README.md)
[日本語で読む](./README_ja.md)
<div align="center">
<img src=resources/logo.svg width="50%"/>
</div>
@ -23,7 +22,13 @@
## 项目更新
- 🔥🔥 **News**: ```2024/10/13```: 成本更低单卡4090可微调`CogVideoX-5B`
- 🔥🔥 **News**: ```2025/03/24```: 我们推出了 [CogKit](https://github.com/THUDM/CogKit) 工具,这是一个微调**CogView4**, **CogVideoX** 系列的微调和推理框架,一个工具包,玩转我们的多模态生成模型。
- 🔥 **News**: ```2025/02/28```: DDIM Inverse 已经在`CogVideoX-5B` 和 `CogVideoX1.5-5B` 支持,查看[这里](inference/ddim_inversion.py)。
- 🔥 **News**: ```2025/01/08```: 我们更新了基于`diffusers`版本模型的`Lora`微调代码,占用显存更低,详情请见[这里](finetune/README_zh.md)。
- 🔥 **News**: ```2024/11/15```: 我们发布 `CogVideoX1.5` 模型的diffusers版本仅需调整部分参数仅可沿用之前的代码。
- 🔥 **News**: ```2024/11/08```: 我们发布 `CogVideoX1.5` 模型。CogVideoX1.5 是 CogVideoX 开源模型的升级版本。
CogVideoX1.5-5B 系列模型支持 **10秒** 长度的视频和更高的分辨率,其中 `CogVideoX1.5-5B-I2V` 支持 **任意分辨率** 的视频生成SAT代码已经更新。`diffusers`版本还在适配中。SAT版本代码前往 [这里](https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT) 下载。
- 🔥**News**: ```2024/10/13```: 成本更低单卡4090可微调 `CogVideoX-5B`
的微调框架[cogvideox-factory](https://github.com/a-r-r-o-w/cogvideox-factory)已经推出,多种分辨率微调,欢迎使用。
- 🔥 **News**: ```2024/10/10```: 我们更新了我们的技术报告,请点击 [这里](https://arxiv.org/pdf/2408.06072)
查看附上了更多的训练细节和demo关于demo点击[这里](https://yzy-thu.github.io/CogVideoX-demo/) 查看。
@ -38,8 +43,7 @@
- 🔥 ```2024/8/27```: 我们开源 CogVideoX 系列更大的模型 **CogVideoX-5B**
。我们大幅度优化了模型的推理性能,推理门槛大幅降低,您可以在 `GTX 1080TI` 等早期显卡运行 **CogVideoX-2B**,在 `RTX 3060`
等桌面端甜品卡运行 **CogVideoX-5B** 模型。 请严格按照[要求](requirements.txt)
更新安装依赖,推理代码请查看 [cli_demo](inference/cli_demo.py)。同时,**CogVideoX-2B** 模型开源协议已经修改为**Apache 2.0
协议**。
更新安装依赖,推理代码请查看 [cli_demo](inference/cli_demo.py)。同时,**CogVideoX-2B** 模型开源协议已经修改为**Apache 2.0 协议**。
- 🔥 ```2024/8/6```: 我们开源 **3D Causal VAE**,用于 **CogVideoX-2B**,可以几乎无损地重构视频。
- 🔥 ```2024/8/6```: 我们开源 CogVideoX 系列视频生成模型的第一个模型, **CogVideoX-2B**
- 🌱 **Source**: ```2022/5/19```: 我们开源了 CogVideo 视频生成模型(现在你可以在 `CogVideo` 分支中看到),这是首个开源的基于
@ -50,18 +54,23 @@
跳转到指定部分:
- [快速开始](#快速开始)
- [提示词优化](#提示词优化)
- [SAT](#sat)
- [Diffusers](#Diffusers)
- [CogVideoX-2B 视频作品](#cogvideox-2b-视频作品)
- [CogVideoX模型介绍](#模型介绍)
- [Diffusers](#diffusers)
- [视频作品](#视频作品)
- [CogVideoX-5B](#cogvideox-5b)
- [CogVideoX-2B](#cogvideox-2b)
- [模型介绍](#模型介绍)
- [友情链接](#友情链接)
- [完整项目代码结构](#完整项目代码结构)
- [Inference](#inference)
- [SAT](#sat)
- [Tools](#tools)
- [开源项目规划](#开源项目规划)
- [模型协议](#模型协议)
- [CogVideo(ICLR'23)模型介绍](#cogvideoiclr23)
- [Colab 快速使用](#colab-快速使用)
- [inference](#inference)
- [finetune](#finetune)
- [sat](#sat-1)
- [tools](#tools)
- [CogVideo(ICLR'23)](#cogvideoiclr23)
- [引用](#引用)
- [模型协议](#模型协议)
## 快速开始
@ -154,75 +163,92 @@ CogVideoX是 [清影](https://chatglm.cn/video?fr=osm_cogvideox) 同源的开源
<table style="border-collapse: collapse; width: 100%;">
<tr>
<th style="text-align: center;">模型名</th>
<th style="text-align: center;">CogVideoX1.5-5B (最新)</th>
<th style="text-align: center;">CogVideoX1.5-5B-I2V (最新)</th>
<th style="text-align: center;">CogVideoX-2B</th>
<th style="text-align: center;">CogVideoX-5B</th>
<th style="text-align: center;">CogVideoX-5B-I2V </th>
</tr>
<tr>
<td style="text-align: center;">发布时间</td>
<th style="text-align: center;">2024年11月8日</th>
<th style="text-align: center;">2024年11月8日</th>
<th style="text-align: center;">2024年8月6日</th>
<th style="text-align: center;">2024年8月27日</th>
<th style="text-align: center;">2024年9月19日</th>
</tr>
<tr>
<td style="text-align: center;">视频分辨率</td>
<td colspan="1" style="text-align: center;">1360 * 768</td>
<td colspan="1" style="text-align: center;"> Min(W, H) = 768 <br> 768 ≤ Max(W, H) ≤ 1360 <br> Max(W, H) % 16 = 0 </td>
<td colspan="3" style="text-align: center;">720 * 480</td>
</tr>
<tr>
<td style="text-align: center;">帧数</td>
<td colspan="2" style="text-align: center;">必须为 <b>16N + 1</b> 其中 N <= 10 (默认 81)</td>
<td colspan="3" style="text-align: center;">必须为 <b>8N + 1</b> 其中 N <= 6 (默认 49)</td>
</tr>
<tr>
<td style="text-align: center;">推理精度</td>
<td colspan="2" style="text-align: center;"><b>BF16(推荐)</b>, FP16, FP32FP8*INT8不支持INT4</td>
<td style="text-align: center;"><b>FP16*(推荐)</b>, BF16, FP32FP8*INT8不支持INT4</td>
<td colspan="2" style="text-align: center;"><b>BF16(推荐)</b>, FP16, FP32FP8*INT8不支持INT4</td>
</tr>
<tr>
<td style="text-align: center;">单GPU显存消耗<br></td>
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 76GB <br><b>diffusers BF16 : 10GB起* </b><br><b>diffusers INT8(torchao): 7G起* </b></td>
<td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB <br><b>diffusers FP16: 4GB起* </b><br><b>diffusers INT8(torchao): 3.6G起*</b></td>
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB <br><b>diffusers BF16 : 5GB起* </b><br><b>diffusers INT8(torchao): 4.4G起* </b></td>
</tr>
<tr>
<td style="text-align: center;">多GPU推理显存消耗</td>
<td colspan="2" style="text-align: center;"><b>BF16: 24GB* using diffusers</b><br></td>
<td style="text-align: center;"><b>FP16: 10GB* using diffusers</b><br></td>
<td colspan="2" style="text-align: center;"><b>BF16: 15GB* using diffusers</b><br></td>
</tr>
<tr>
<td style="text-align: center;">推理速度<br>(Step = 50, FP/BF16)</td>
<td colspan="2" style="text-align: center;">单卡A100: ~1000秒(5秒视频)<br>单卡H100: ~550秒(5秒视频)</td>
<td style="text-align: center;">单卡A100: ~90秒<br>单卡H100: ~45秒</td>
<td colspan="2" style="text-align: center;">单卡A100: ~180秒<br>单卡H100: ~90秒</td>
</tr>
<tr>
<td style="text-align: center;">微调精度</td>
<td style="text-align: center;"><b>FP16</b></td>
<td colspan="2" style="text-align: center;"><b>BF16</b></td>
</tr>
<tr>
<td style="text-align: center;">微调显存消耗</td>
<td style="text-align: center;">47 GB (bs=1, LORA)<br> 61 GB (bs=2, LORA)<br> 62GB (bs=1, SFT)</td>
<td style="text-align: center;">63 GB (bs=1, LORA)<br> 80 GB (bs=2, LORA)<br> 75GB (bs=1, SFT)<br></td>
<td style="text-align: center;">78 GB (bs=1, LORA)<br> 75GB (bs=1, SFT, 16GPU)<br></td>
</tr>
<tr>
<td style="text-align: center;">提示词语言</td>
<td colspan="3" style="text-align: center;">English*</td>
<td colspan="5" style="text-align: center;">English*</td>
</tr>
<tr>
<td style="text-align: center;">提示词长度上限</td>
<td colspan="2" style="text-align: center;">224 Tokens</td>
<td colspan="3" style="text-align: center;">226 Tokens</td>
</tr>
<tr>
<td style="text-align: center;">视频长度</td>
<td colspan="2" style="text-align: center;">5 秒 或 10 秒</td>
<td colspan="3" style="text-align: center;">6 秒</td>
</tr>
<tr>
<td style="text-align: center;">帧率</td>
<td colspan="2" style="text-align: center;">16 帧 / 秒 </td>
<td colspan="3" style="text-align: center;">8 帧 / 秒 </td>
</tr>
<tr>
<td style="text-align: center;">视频分辨率</td>
<td colspan="3" style="text-align: center;">720 * 480不支持其他分辨率(含微调)</td>
</tr>
<tr>
<td style="text-align: center;">位置编码</td>
<td colspan="2" style="text-align: center;">3d_rope_pos_embed</td>
<td style="text-align: center;">3d_sincos_pos_embed</td>
<td style="text-align: center;">3d_sincos_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
</tr>
<tr>
<td style="text-align: center;">下载链接 (Diffusers)</td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5B">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5B">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5B">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5B-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5B-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5B-I2V">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-2b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-2b">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b-I2V">🟣 WiseModel</a></td>
</tr>
<tr>
<td style="text-align: center;">下载链接 (SAT)</td>
<td colspan="2" style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5b-SAT">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🟣 WiseModel</a></td>
<td colspan="3" style="text-align: center;"><a href="./sat/README_zh.md">SAT</a></td>
</tr>
</table>
@ -245,16 +271,16 @@ pipe.vae.enable_tiling()
+ [PytorchAO](https://github.com/pytorch/ao) 和 [Optimum-quanto](https://github.com/huggingface/optimum-quanto/)
可以用于量化文本编码器、Transformer 和 VAE 模块,以降低 CogVideoX 的内存需求。这使得在免费的 T4 Colab 或更小显存的 GPU
上运行模型成为可能同样值得注意的是TorchAO 量化完全兼容 `torch.compile`,这可以显著提高推理速度。在 `NVIDIA H100`
及以上设备上必须使用 `FP8` 精度,这需要源码安装 `torch``torchao``diffusers``accelerate` Python
包。建议使用 `CUDA 12.4`
及以上设备上必须使用 `FP8` 精度,这需要源码安装 `torch``torchao` Python 包。建议使用 `CUDA 12.4`
+ 推理速度测试同样采用了上述显存优化方案不采用显存优化的情况下推理速度提升约10%。 只有`diffusers`版本模型支持量化。
+ 模型仅支持英语输入,其他语言可以通过大模型润色时翻译为英语。
+ 模型微调所占用的显存是在 `8 * H100` 环境下进行测试,程序已经自动使用`Zero 2` 优化。表格中若有标注具体GPU数量则必须使用大于等于该数量的GPU进行微调。
## 友情链接
我们非常欢迎来自社区的贡献并积极的贡献开源社区。以下作品已经对CogVideoX进行了适配欢迎大家使用:
+ [RIFLEx-CogVideoX](https://github.com/thu-ml/RIFLEx)
RIFLEx 是一个视频长度外推的方法只需一行代码即可将视频生成长度延伸为原先的二倍。RIFLEx 不仅支持 Training-free 的推理,也提供基于 CogVideoX 进行微调的模型,只需在原有长度视频上微调 1000 步即可大大提高长度外推能力。
+ [CogVideoX-Fun](https://github.com/aigc-apps/CogVideoX-Fun):
CogVideoX-Fun是一个基于CogVideoX结构修改后的pipeline支持自由的分辨率多种启动方式。
+ [CogStudio](https://github.com/pinokiofactory/cogstudio): CogVideo 的 Gradio Web UI单独实现仓库支持更多功能的 Web UI。
@ -269,6 +295,11 @@ pipe.vae.enable_tiling()
+ [CogVideoX-Interpolation](https://github.com/feizc/CogvideX-Interpolation): 基于 CogVideoX 结构修改的管道,旨在为关键帧插值生成提供更大的灵活性。
+ [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): DiffSynth 工作室是一款扩散引擎。重构了架构包括文本编码器、UNet、VAE
等,在保持与开源社区模型兼容性的同时,提升了计算性能。该框架已经适配 CogVideoX。
+ [CogVideoX-Controlnet](https://github.com/TheDenk/cogvideox-controlnet): 一个包含 CogVideoX 模型的简单 Controlnet 模块的代码。
+ [VideoTuna](https://github.com/VideoVerses/VideoTuna)VideoTuna 是首个集成多种 AI 视频生成模型的仓库,支持文本转视频、图像转视频、文本转图像生成。
+ [ConsisID](https://github.com/PKU-YuanGroup/ConsisID): 一种身份保持的文本到视频生成模型,基于 CogVideoX-5B通过频率分解在生成的视频中保持面部一致性。
+ [教程](https://www.youtube.com/watch?v=5UCkMzP2VLE&ab_channel=SECourses): 一个关于在Windows和云环境中安装和优化CogVideoX1.5-5B-I2V模型的分步指南。特别感谢[FurkanGozukara](https://github.com/FurkanGozukara)的努力和支持!
## 完整项目代码结构
@ -369,8 +400,6 @@ CogVideo的demo网站在[https://models.aminer.cn/cogvideo](https://models.amine
}
```
我们欢迎您的贡献,您可以点击[这里](resources/contribute_zh.md)查看更多信息。
## 模型协议
本仓库代码使用 [Apache 2.0 协议](LICENSE) 发布。

View File

@ -1,126 +1,146 @@
# CogVideoX diffusers Fine-tuning Guide
# CogVideoX Diffusers Fine-tuning Guide
[中文阅读](./README_zh.md)
[日本語で読む](./README_ja.md)
This feature is not fully complete yet. If you want to check the fine-tuning for the SAT version, please
see [here](../sat/README_zh.md). The dataset format is different from this version.
If you're looking for the fine-tuning instructions for the SAT version, please check [here](../sat/README_zh.md). Note
that the SAT version uses a different dataset format from the one described here.
## Hardware Requirements
+ CogVideoX-2B / 5B LoRA: 1 * A100 (5B need to use `--use_8bit_adam`)
+ CogVideoX-2B SFT: 8 * A100 (Working)
+ CogVideoX-5B-I2V is not supported yet.
| Model | Training Type | Distribution Strategy | Mixed Precision | Training Resolution (FxHxW) | Hardware Requirements |
|----------------------------|----------------|--------------------------------------|-----------------|-----------------------------|-------------------------|
| cogvideox-t2v-2b | lora (rank128) | DDP | fp16 | 49x480x720 | 16GB VRAM (NVIDIA 4080) |
| cogvideox-{t2v, i2v}-5b | lora (rank128) | DDP | bf16 | 49x480x720 | 24GB VRAM (NVIDIA 4090) |
| cogvideox1.5-{t2v, i2v}-5b | lora (rank128) | DDP | bf16 | 81x768x1360 | 35GB VRAM (NVIDIA A100) |
| cogvideox-t2v-2b | sft | DDP | fp16 | 49x480x720 | 36GB VRAM (NVIDIA A100) |
| cogvideox-t2v-2b | sft | 1-GPU zero-2 + opt offload | fp16 | 49x480x720 | 17GB VRAM (NVIDIA 4090) |
| cogvideox-t2v-2b | sft | 8-GPU zero-2 | fp16 | 49x480x720 | 17GB VRAM (NVIDIA 4090) |
| cogvideox-t2v-2b | sft | 8-GPU zero-3 | fp16 | 49x480x720 | 19GB VRAM (NVIDIA 4090) |
| cogvideox-t2v-2b | sft | 8-GPU zero-3 + opt and param offload | bf16 | 49x480x720 | 14GB VRAM (NVIDIA 4080) |
| cogvideox-{t2v, i2v}-5b | sft | 1-GPU zero-2 + opt offload | bf16 | 49x480x720 | 42GB VRAM (NVIDIA A100) |
| cogvideox-{t2v, i2v}-5b | sft | 8-GPU zero-2 | bf16 | 49x480x720 | 42GB VRAM (NVIDIA 4090) |
| cogvideox-{t2v, i2v}-5b | sft | 8-GPU zero-3 | bf16 | 49x480x720 | 43GB VRAM (NVIDIA 4090) |
| cogvideox-{t2v, i2v}-5b | sft | 8-GPU zero-3 + opt and param offload | bf16 | 49x480x720 | 28GB VRAM (NVIDIA 5090) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 1-GPU zero-2 + opt offload | bf16 | 81x768x1360 | 56GB VRAM (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 8-GPU zero-2 | bf16 | 81x768x1360 | 55GB VRAM (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 8-GPU zero-3 | bf16 | 81x768x1360 | 55GB VRAM (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 8-GPU zero-3 + opt and param offload | bf16 | 81x768x1360 | 40GB VRAM (NVIDIA A100) |
## Install Dependencies
Since the related code has not been merged into the diffusers release, you need to base your fine-tuning on the
diffusers branch. Please follow the steps below to install dependencies:
Since the relevant code has not yet been merged into the official `diffusers` release, you need to fine-tune based on
the diffusers branch. Follow the steps below to install the dependencies:
```shell
git clone https://github.com/huggingface/diffusers.git
cd diffusers # Now in Main branch
cd diffusers # Now on the Main branch
pip install -e .
```
## Prepare the Dataset
First, you need to prepare the dataset. The dataset format should be as follows, with `videos.txt` containing the list
of videos in the `videos` directory:
First, you need to prepare your dataset. Depending on your task type (T2V or I2V), the dataset format will vary
slightly:
```
.
├── prompts.txt
├── videos
└── videos.txt
├── videos.txt
├── images # (Optional) For I2V, if not provided, first frame will be extracted from video as reference
└── images.txt # (Optional) For I2V, if not provided, first frame will be extracted from video as reference
```
You can download
the [Disney Steamboat Willie](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset) dataset from
here.
Where:
This video fine-tuning dataset is used as a test for fine-tuning.
- `prompts.txt`: Contains the prompts
- `videos/`: Contains the .mp4 video files
- `videos.txt`: Contains the list of video files in the `videos/` directory
- `images/`: (Optional) Contains the .png reference image files
- `images.txt`: (Optional) Contains the list of reference image files
## Configuration Files and Execution
You can download a sample dataset (
T2V) [Disney Steamboat Willie](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset).
The `accelerate` configuration files are as follows:
If you need to use a validation dataset during training, make sure to provide a validation dataset with the same format
as the training dataset.
+ `accelerate_config_machine_multi.yaml`: Suitable for multi-GPU use
+ `accelerate_config_machine_single.yaml`: Suitable for single-GPU use
## Running Scripts to Start Fine-tuning
The configuration for the `finetune` script is as follows:
Before starting training, please note the following resolution requirements:
```
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \ # Use accelerate to launch multi-GPU training with the config file accelerate_config_machine_single.yaml
train_cogvideox_lora.py \ # Training script train_cogvideox_lora.py for LoRA fine-tuning on CogVideoX model
--gradient_checkpointing \ # Enable gradient checkpointing to reduce memory usage
--pretrained_model_name_or_path $MODEL_PATH \ # Path to the pretrained model, specified by $MODEL_PATH
--cache_dir $CACHE_PATH \ # Cache directory for model files, specified by $CACHE_PATH
--enable_tiling \ # Enable tiling technique to process videos in chunks, saving memory
--enable_slicing \ # Enable slicing to further optimize memory by slicing inputs
--instance_data_root $DATASET_PATH \ # Dataset path specified by $DATASET_PATH
--caption_column prompts.txt \ # Specify the file prompts.txt for video descriptions used in training
--video_column videos.txt \ # Specify the file videos.txt for video paths used in training
--validation_prompt "" \ # Prompt used for generating validation videos during training
--validation_prompt_separator ::: \ # Set ::: as the separator for validation prompts
--num_validation_videos 1 \ # Generate 1 validation video per validation round
--validation_epochs 100 \ # Perform validation every 100 training epochs
--seed 42 \ # Set random seed to 42 for reproducibility
--rank 128 \ # Set the rank for LoRA parameters to 128
--lora_alpha 64 \ # Set the alpha parameter for LoRA to 64, adjusting LoRA learning rate
--mixed_precision bf16 \ # Use bf16 mixed precision for training to save memory
--output_dir $OUTPUT_PATH \ # Specify the output directory for the model, defined by $OUTPUT_PATH
--height 480 \ # Set video height to 480 pixels
--width 720 \ # Set video width to 720 pixels
--fps 8 \ # Set video frame rate to 8 frames per second
--max_num_frames 49 \ # Set the maximum number of frames per video to 49
--skip_frames_start 0 \ # Skip 0 frames at the start of the video
--skip_frames_end 0 \ # Skip 0 frames at the end of the video
--train_batch_size 4 \ # Set training batch size to 4
--num_train_epochs 30 \ # Total number of training epochs set to 30
--checkpointing_steps 1000 \ # Save model checkpoint every 1000 steps
--gradient_accumulation_steps 1 \ # Accumulate gradients for 1 step, updating after each batch
--learning_rate 1e-3 \ # Set learning rate to 0.001
--lr_scheduler cosine_with_restarts \ # Use cosine learning rate scheduler with restarts
--lr_warmup_steps 200 \ # Warm up the learning rate for the first 200 steps
--lr_num_cycles 1 \ # Set the number of learning rate cycles to 1
--optimizer AdamW \ # Use the AdamW optimizer
--adam_beta1 0.9 \ # Set Adam optimizer beta1 parameter to 0.9
--adam_beta2 0.95 \ # Set Adam optimizer beta2 parameter to 0.95
--max_grad_norm 1.0 \ # Set maximum gradient clipping value to 1.0
--allow_tf32 \ # Enable TF32 to speed up training
--report_to wandb # Use Weights and Biases (wandb) for logging and monitoring the training
1. The number of frames must be a multiple of 8 **plus 1** (i.e., 8N+1), such as 49, 81 ...
2. Recommended video resolutions for each model:
- CogVideoX: 480x720 (height x width)
- CogVideoX1.5: 768x1360 (height x width)
3. For samples (videos or images) that don't match the training resolution, the code will directly resize them. This may
cause aspect ratio distortion and affect training results. It's recommended to preprocess your samples (e.g., using
crop + resize to maintain aspect ratio) before training.
> **Important Note**: To improve training efficiency, we automatically encode videos and cache the results on disk
> before training. If you modify the data after training, please delete the latent directory under the video directory to
> ensure the latest data is used.
### LoRA
```bash
# Modify configuration parameters in train_ddp_t2v.sh
# Main parameters to modify:
# --output_dir: Output directory
# --data_root: Dataset root directory
# --caption_column: Path to prompt file
# --image_column: Optional for I2V, path to reference image file list (remove this parameter to use the first frame of video as image condition)
# --video_column: Path to video file list
# --train_resolution: Training resolution (frames x height x width)
# For other important parameters, please refer to the launch script
bash train_ddp_t2v.sh # Text-to-Video (T2V) fine-tuning
bash train_ddp_i2v.sh # Image-to-Video (I2V) fine-tuning
```
## Running the Script to Start Fine-tuning
### SFT
Single Node (One GPU or Multi GPU) fine-tuning:
We provide several zero configuration templates in the `configs/` directory. Please choose the appropriate training
configuration based on your needs (configure the `deepspeed_config_file` option in `accelerate_config.yaml`).
```shell
bash finetune_single_rank.sh
```bash
# Parameters to configure are the same as LoRA training
bash train_zero_t2v.sh # Text-to-Video (T2V) fine-tuning
bash train_zero_i2v.sh # Image-to-Video (I2V) fine-tuning
```
Multi-Node fine-tuning:
In addition to setting the bash script parameters, you need to set the relevant training options in the zero
configuration file and ensure the zero training configuration matches the parameters in the bash script, such as
batch_size, gradient_accumulation_steps, mixed_precision. For details, please refer to
the [DeepSpeed official documentation](https://www.deepspeed.ai/docs/config-json/)
```shell
bash finetune_multi_rank.sh # Needs to be run on each node
```
When using SFT training, please note:
## Loading the Fine-tuned Model
1. For SFT training, model offload is not used during validation, so the peak VRAM usage may exceed 24GB. For GPUs with
less than 24GB VRAM, it's recommended to disable validation.
+ Please refer to [cli_demo.py](../inference/cli_demo.py) for how to load the fine-tuned model.
2. Validation is slow when zero-3 is enabled, so it's recommended to disable validation when using zero-3.
## Load the Fine-tuned Model
+ Please refer to [cli_demo.py](../inference/cli_demo.py) for instructions on how to load the fine-tuned model.
+ For SFT trained models, please first use the `zero_to_fp32.py` script in the `checkpoint-*/` directory to merge the
model weights
## Best Practices
+ Includes 70 training videos with a resolution of `200 x 480 x 720` (frames x height x width). By skipping frames in
the data preprocessing, we created two smaller datasets with 49 and 16 frames to speed up experimentation, as the
maximum frame limit recommended by the CogVideoX team is 49 frames. We split the 70 videos into three groups of 10,
25, and 50 videos, with similar conceptual nature.
+ Using 25 or more videos works best when training new concepts and styles.
+ It works better to train using identifier tokens specified with `--id_token`. This is similar to Dreambooth training,
but regular fine-tuning without such tokens also works.
+ The original repository used `lora_alpha` set to 1. We found this value ineffective across multiple runs, likely due
to differences in the backend and training setup. Our recommendation is to set `lora_alpha` equal to rank or rank //
2.
+ We recommend using a rank of 64 or higher.
+ We included 70 training videos with a resolution of `200 x 480 x 720` (frames x height x width). Through frame
skipping in the data preprocessing, we created two smaller datasets with 49 and 16 frames to speed up experiments. The
maximum frame count recommended by the CogVideoX team is 49 frames. These 70 videos were divided into three groups:
10, 25, and 50 videos, with similar conceptual nature.
+ Using 25 or more videos works best when training new concepts and styles.
+ It's recommended to use an identifier token, which can be specified using `--id_token`, for better training results.
This is similar to Dreambooth training, though regular fine-tuning without using this token will still work.
+ The original repository uses `lora_alpha` set to 1. We found that this value performed poorly in several runs,
possibly due to differences in the model backend and training settings. Our recommendation is to set `lora_alpha` to
be equal to the rank or `rank // 2`.
+ It's advised to use a rank of 64 or higher.

View File

@ -1,116 +1,124 @@
# CogVideoX diffusers 微調整方法
[Read this in English.](./README_zh)
# CogVideoX Diffusers ファインチューニングガイド
[中文阅读](./README_zh.md)
[Read in English](./README.md)
この機能はまだ完全に完成していません。SATバージョンの微調整を確認したい場合は、[こちら](../sat/README_ja.md)を参照してください。本バージョンとは異なるデータセット形式を使用しています。
SATバージョンのファインチューニング手順については、[こちら](../sat/README_zh.md)をご確認ください。このバージョンのデータセットフォーマットは、こちらのバージョンとは異なります。
## ハードウェア要件
+ CogVideoX-2B / 5B T2V LORA: 1 * A100 (5B need to use `--use_8bit_adam`)
+ CogVideoX-2B SFT: 8 * A100 (動作確認済み)
+ CogVideoX-5B-I2V まだサポートしていません
| モデル | トレーニングタイプ | 分散戦略 | 混合トレーニング精度 | トレーニング解像度(フレーム数×高さ×幅) | ハードウェア要件 |
|----------------------------|-------------------|----------------------------------|-----------------|-----------------------------|----------------------------|
| cogvideox-t2v-2b | lora (rank128) | DDP | fp16 | 49x480x720 | 16GB VRAM (NVIDIA 4080) |
| cogvideox-{t2v, i2v}-5b | lora (rank128) | DDP | bf16 | 49x480x720 | 24GB VRAM (NVIDIA 4090) |
| cogvideox1.5-{t2v, i2v}-5b | lora (rank128) | DDP | bf16 | 81x768x1360 | 35GB VRAM (NVIDIA A100) |
| cogvideox-t2v-2b | sft | DDP | fp16 | 49x480x720 | 36GB VRAM (NVIDIA A100) |
| cogvideox-t2v-2b | sft | 1カード zero-2 + オプティマイゼーションオフロード | fp16 | 49x480x720 | 17GB VRAM (NVIDIA 4090) |
| cogvideox-t2v-2b | sft | 8カード zero-2 | fp16 | 49x480x720 | 17GB VRAM (NVIDIA 4090) |
| cogvideox-t2v-2b | sft | 8カード zero-3 | fp16 | 49x480x720 | 19GB VRAM (NVIDIA 4090) |
| cogvideox-t2v-2b | sft | 8カード zero-3 + オプティマイゼーションとパラメータオフロード | bf16 | 49x480x720 | 14GB VRAM (NVIDIA 4080) |
| cogvideox-{t2v, i2v}-5b | sft | 1カード zero-2 + オプティマイゼーションオフロード | bf16 | 49x480x720 | 42GB VRAM (NVIDIA A100) |
| cogvideox-{t2v, i2v}-5b | sft | 8カード zero-2 | bf16 | 49x480x720 | 42GB VRAM (NVIDIA 4090) |
| cogvideox-{t2v, i2v}-5b | sft | 8カード zero-3 | bf16 | 49x480x720 | 43GB VRAM (NVIDIA 4090) |
| cogvideox-{t2v, i2v}-5b | sft | 8カード zero-3 + オプティマイゼーションとパラメータオフロード | bf16 | 49x480x720 | 28GB VRAM (NVIDIA 5090) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 1カード zero-2 + オプティマイゼーションオフロード | bf16 | 81x768x1360 | 56GB VRAM (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 8カード zero-2 | bf16 | 81x768x1360 | 55GB VRAM (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 8カード zero-3 | bf16 | 81x768x1360 | 55GB VRAM (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 8カード zero-3 + オプティマイゼーションとパラメータオフロード | bf16 | 81x768x1360 | 40GB VRAM (NVIDIA A100) |
## 依存関係のインストール
関連コードはまだdiffusersのリリース版に統合されていないため、diffusersブランチを使用して微調整を行う必要があります。以下の手順に従って依存関係をインストールしてください
関連するコードがまだ `diffusers` の公式リリースに統合されていないため、`diffusers` ブランチを基にファインチューニングを行う必要があります。以下の手順に従って依存関係をインストールしてください:
```shell
git clone https://github.com/huggingface/diffusers.git
cd diffusers # Now in Main branch
cd diffusers # 現在は Main ブランチ
pip install -e .
```
## データセットの準備
まず、データセットを準備する必要があります。データセットの形式は以下のようになります。
まず、データセットを準備する必要があります。タスクの種類T2V または I2Vによって、データセットのフォーマットが少し異なります
```
.
├── prompts.txt
├── videos
└── videos.txt
├── videos.txt
├── images # (オプション) I2Vの場合。提供されない場合、動画の最初のフレームが参照画像として使用されます
└── images.txt # (オプション) I2Vの場合。提供されない場合、動画の最初のフレームが参照画像として使用されます
```
[ディズニースチームボートウィリー](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset)をここからダウンロードできます。
各ファイルの役割は以下の通りです:
- `prompts.txt`: プロンプトを格納
- `videos/`: .mp4 動画ファイルを格納
- `videos.txt`: `videos/` フォルダ内の動画ファイルリストを格納
- `images/`: (オプション) .png 形式の参照画像ファイル
- `images.txt`: (オプション) 参照画像ファイルリスト
ビデオ微調整データセットはテスト用として使用されます。
トレーニング中に検証データセットを使用する場合は、トレーニングデータセットと同じフォーマットで検証データセットを提供する必要があります。
## 設定ファイルと実行
## スクリプトを実行してファインチューニングを開始
`accelerate` 設定ファイルは以下の通りです:
トレーニングを開始する前に、以下の解像度設定要件に注意してください:
+ accelerate_config_machine_multi.yaml 複数GPU向け
+ accelerate_config_machine_single.yaml 単一GPU向け
1. フレーム数は8の倍数 **+1** (つまり8N+1) でなければなりません。例49, 81 ...
2. ビデオ解像度はモデルのデフォルトサイズを使用することをお勧めします:
- CogVideoX: 480x720 (高さ×幅)
- CogVideoX1.5: 768x1360 (高さ×幅)
3. トレーニング解像度に合わないサンプル(ビデオや画像)はコード内で自動的にリサイズされます。このため、サンプルのアスペクト比が変形し、トレーニング効果に影響を与える可能性があります。解像度に関しては、事前にサンプルを処理(例えば、アスペクト比を維持するためにクロップ+リサイズを使用)してからトレーニングを行うことをお勧めします。
`finetune` スクリプト設定ファイルの例:
> **重要な注意**:トレーニング効率を高めるため、トレーニング前にビデオをエンコードし、その結果をディスクにキャッシュします。トレーニング後にデータを変更した場合は、`video`ディレクトリ内の`latent`ディレクトリを削除して、最新のデータを使用するようにしてください。
```
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \ # accelerateを使用してmulti-GPUトレーニングを起動、設定ファイルはaccelerate_config_machine_single.yaml
train_cogvideox_lora.py \ # LoRAの微調整用のトレーニングスクリプトtrain_cogvideox_lora.pyを実行
--gradient_checkpointing \ # メモリ使用量を減らすためにgradient checkpointingを有効化
--pretrained_model_name_or_path $MODEL_PATH \ # 事前学習済みモデルのパスを$MODEL_PATHで指定
--cache_dir $CACHE_PATH \ # モデルファイルのキャッシュディレクトリを$CACHE_PATHで指定
--enable_tiling \ # メモリ節約のためにタイル処理を有効化し、動画をチャンク分けして処理
--enable_slicing \ # 入力をスライスしてさらにメモリ最適化
--instance_data_root $DATASET_PATH \ # データセットのパスを$DATASET_PATHで指定
--caption_column prompts.txt \ # トレーニングで使用する動画の説明ファイルをprompts.txtで指定
--video_column videos.txt \ # トレーニングで使用する動画のパスファイルをvideos.txtで指定
--validation_prompt "" \ # トレーニング中に検証用の動画を生成する際のプロンプト
--validation_prompt_separator ::: \ # 検証プロンプトの区切り文字を:::に設定
--num_validation_videos 1 \ # 各検証ラウンドで1本の動画を生成
--validation_epochs 100 \ # 100エポックごとに検証を実施
--seed 42 \ # 再現性を保証するためにランダムシードを42に設定
--rank 128 \ # LoRAのパラメータのランクを128に設定
--lora_alpha 64 \ # LoRAのalphaパラメータを64に設定し、LoRAの学習率を調整
--mixed_precision bf16 \ # bf16混合精度でトレーニングし、メモリを節約
--output_dir $OUTPUT_PATH \ # モデルの出力ディレクトリを$OUTPUT_PATHで指定
--height 480 \ # 動画の高さを480ピクセルに設定
--width 720 \ # 動画の幅を720ピクセルに設定
--fps 8 \ # 動画のフレームレートを1秒あたり8フレームに設定
--max_num_frames 49 \ # 各動画の最大フレーム数を49に設定
--skip_frames_start 0 \ # 動画の最初のフレームを0スキップ
--skip_frames_end 0 \ # 動画の最後のフレームを0スキップ
--train_batch_size 4 \ # トレーニングのバッチサイズを4に設定
--num_train_epochs 30 \ # 総トレーニングエポック数を30に設定
--checkpointing_steps 1000 \ # 1000ステップごとにモデルのチェックポイントを保存
--gradient_accumulation_steps 1 \ # 1ステップの勾配累積を行い、各バッチ後に更新
--learning_rate 1e-3 \ # 学習率を0.001に設定
--lr_scheduler cosine_with_restarts \ # リスタート付きのコサイン学習率スケジューラを使用
--lr_warmup_steps 200 \ # トレーニングの最初の200ステップで学習率をウォームアップ
--lr_num_cycles 1 \ # 学習率のサイクル数を1に設定
--optimizer AdamW \ # AdamWオプティマイザーを使用
--adam_beta1 0.9 \ # Adamオプティマイザーのbeta1パラメータを0.9に設定
--adam_beta2 0.95 \ # Adamオプティマイザーのbeta2パラメータを0.95に設定
--max_grad_norm 1.0 \ # 勾配クリッピングの最大値を1.0に設定
--allow_tf32 \ # トレーニングを高速化するためにTF32を有効化
--report_to wandb # Weights and Biasesを使用してトレーニングの記録とモニタリングを行う
### LoRA
```bash
# train_ddp_t2v.sh の設定パラメータを変更
# 主に以下のパラメータを変更する必要があります:
# --output_dir: 出力ディレクトリ
# --data_root: データセットのルートディレクトリ
# --caption_column: テキストプロンプトのファイルパス
# --image_column: I2Vの場合、参照画像のファイルリストのパスこのパラメータを削除すると、デフォルトで動画の最初のフレームが画像条件として使用されます
# --video_column: 動画ファイルのリストのパス
# --train_resolution: トレーニング解像度(フレーム数×高さ×幅)
# その他の重要なパラメータについては、起動スクリプトを参照してください
bash train_ddp_t2v.sh # テキストから動画T2V微調整
bash train_ddp_i2v.sh # 画像から動画I2V微調整
```
## 微調整を開始
### SFT
単一マシン (シングルGPU、マルチGPU) での微調整:
`configs/`ディレクトリにはいくつかのZero構成テンプレートが提供されています。必要に応じて適切なトレーニング設定を選択してください`accelerate_config.yaml``deepspeed_config_file`オプションを設定します)。
```shell
bash finetune_single_rank.sh
```bash
# 設定するパラメータはLoRAトレーニングと同様です
bash train_zero_t2v.sh # テキストから動画T2V微調整
bash train_zero_i2v.sh # 画像から動画I2V微調整
```
複数マシン・マルチGPUでの微調整
Bashスクリプトの関連パラメータを設定するだけでなく、Zeroの設定ファイルでトレーニングオプションを設定し、Zeroのトレーニング設定がBashスクリプト内のパラメータと一致していることを確認する必要があります。例えば、`batch_size``gradient_accumulation_steps``mixed_precision`など、具体的な詳細は[DeepSpeed公式ドキュメント](https://www.deepspeed.ai/docs/config-json/)を参照してください。
```shell
bash finetune_multi_rank.sh # 各ノードで実行する必要があります。
```
SFTトレーニングを使用する際に注意すべき点
## 微調整済みモデルのロード
1. SFTトレーニングでは、検証時にモデルオフロードは使用されません。そのため、24GB以下のGPUでは検証時にVRAMのピークが24GBを超える可能性があります。24GB以下のGPUでは、検証を無効にすることをお勧めします。
+ 微調整済みのモデルをロードする方法については、[cli_demo.py](../inference/cli_demo.py) を参照してください。
2. Zero-3を有効にすると検証が遅くなるため、Zero-3では検証を無効にすることをお勧めします。
## ファインチューニングしたモデルの読み込み
+ ファインチューニングしたモデルを読み込む方法については、[cli_demo.py](../inference/cli_demo.py)を参照してください。
+ SFTトレーニングのモデルについては、まず`checkpoint-*`/ディレクトリ内の`zero_to_fp32.py`スクリプトを使用して、モデルの重みを統合してください。
## ベストプラクティス
+ 解像度が `200 x 480 x 720`(フレーム数 x 高さ x 幅のトレーニングビデオが70本含まれています。データ前処理でフレームをスキップすることで、49フレームと16フレームの小さなデータセットを作成しました。これは実験を加速するためのもので、CogVideoXチームが推奨する最大フレーム数制限は49フレームです。
+ 25本以上のビデオが新しい概念やスタイルのトレーニングに最適です。
+ 現在、`--id_token` を指定して識別トークンを使用してトレーニングする方が効果的です。これはDreamboothトレーニングに似ていますが、通常の微調整でも機能します。
+ 元のリポジトリでは `lora_alpha` を1に設定していましたが、複数の実行でこの値が効果的でないことがわかりました。モデルのバックエンドやトレーニング設定によるかもしれません。私たちの提案は、lora_alphaをrankと同じか、rank // 2に設定することです。
+ Rank 64以上の設定を推奨します。
+ 解像度が `200 x 480 x 720`(フレーム数 x 高さ x 幅)の70本のトレーニング動画を使用しました。データ前処理でフレームスキップを行い、49フレームおよび16フレームの2つの小さなデータセットを作成して実験速度を向上させました。CogVideoXチームの推奨最大フレーム数制限は49フレームです。これらの70本の動画は、10、25、50本の3つのグループに分け、概念的に類似した性質のものです。
+ 25本以上の動画を使用することで、新しい概念やスタイルのトレーニングが最適です。
+ `--id_token` で指定できる識別子トークンを使用すると、トレーニング効果がより良くなります。これはDreamboothトレーニングに似ていますが、このトークンを使用しない通常のファインチューニングでも問題なく動作します。
+ 元のリポジトリでは `lora_alpha` が1に設定されていますが、この値は多くの実行で効果が悪かったため、モデルのバックエンドやトレーニング設定の違いが影響している可能性があります。私たちの推奨は、`lora_alpha` を rank と同じか、`rank // 2` に設定することです。
+ rank は64以上に設定することをお勧めします。

View File

@ -1,16 +1,32 @@
# CogVideoX diffusers 微调方案
[Read this in English](./README_zh.md)
[Read this in English](./README.md)
[日本語で読む](./README_ja.md)
本功能尚未完全完善,如果您想查看SAT版本微调请查看[这里](../sat/README_zh.md)。其数据集格式与本版本不同。
如果您想查看SAT版本微调请查看[这里](../sat/README_zh.md)。其数据集格式与本版本不同。
## 硬件要求
+ CogVideoX-2B / 5B T2V LORA: 1 * A100 (5B need to use `--use_8bit_adam`)
+ CogVideoX-2B SFT: 8 * A100 (制作中)
+ CogVideoX-5B-I2V 暂未支持
| 模型 | 训练类型 | 分布式策略 | 混合训练精度 | 训练分辨率(帧数x高x宽) | 硬件要求 |
|----------------------------|----------------|-----------------------------------|------------|----------------------|-----------------------|
| cogvideox-t2v-2b | lora (rank128) | DDP | fp16 | 49x480x720 | 16G显存 (NVIDIA 4080) |
| cogvideox-{t2v, i2v}-5b | lora (rank128) | DDP | bf16 | 49x480x720 | 24G显存 (NVIDIA 4090) |
| cogvideox1.5-{t2v, i2v}-5b | lora (rank128) | DDP | bf16 | 81x768x1360 | 35G显存 (NVIDIA A100) |
| cogvideox-t2v-2b | sft | DDP | fp16 | 49x480x720 | 36G显存 (NVIDIA A100) |
| cogvideox-t2v-2b | sft | 1卡zero-2 + opt offload | fp16 | 49x480x720 | 17G显存 (NVIDIA 4090) |
| cogvideox-t2v-2b | sft | 8卡zero-2 | fp16 | 49x480x720 | 17G显存 (NVIDIA 4090) |
| cogvideox-t2v-2b | sft | 8卡zero-3 | fp16 | 49x480x720 | 19G显存 (NVIDIA 4090) |
| cogvideox-t2v-2b | sft | 8卡zero-3 + opt and param offload | bf16 | 49x480x720 | 14G显存 (NVIDIA 4080) |
| cogvideox-{t2v, i2v}-5b | sft | 1卡zero-2 + opt offload | bf16 | 49x480x720 | 42G显存 (NVIDIA A100) |
| cogvideox-{t2v, i2v}-5b | sft | 8卡zero-2 | bf16 | 49x480x720 | 42G显存 (NVIDIA 4090) |
| cogvideox-{t2v, i2v}-5b | sft | 8卡zero-3 | bf16 | 49x480x720 | 43G显存 (NVIDIA 4090) |
| cogvideox-{t2v, i2v}-5b | sft | 8卡zero-3 + opt and param offload | bf16 | 49x480x720 | 28G显存 (NVIDIA 5090) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 1卡zero-2 + opt offload | bf16 | 81x768x1360 | 56G显存 (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 8卡zero-2 | bf16 | 81x768x1360 | 55G显存 (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 8卡zero-3 | bf16 | 81x768x1360 | 55G显存 (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 8卡zero-3 + opt and param offload | bf16 | 81x768x1360 | 40G显存 (NVIDIA A100) |
## 安装依赖
@ -24,89 +40,83 @@ pip install -e .
## 准备数据集
首先,你需要准备数据集。根据你的任务类型(T2V 或 I2V),数据集格式略有不同:
```
.
├── prompts.txt
├── videos
└── videos.txt
├── videos.txt
├── images # (可选) 对于I2V若不提供则从视频中提取第一帧作为参考图像
└── images.txt # (可选) 对于I2V若不提供则从视频中提取第一帧作为参考图像
```
你可以从这里下载 [迪士尼汽船威利号](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset)
其中:
- `prompts.txt`: 存放提示词
- `videos/`: 存放.mp4视频文件
- `videos.txt`: 存放 videos 目录中的视频文件列表
- `images/`: (可选) 存放.png参考图像文件
- `images.txt`: (可选) 存放参考图像文件列表
你可以从这里下载示例数据集(T2V):[迪士尼汽船威利号](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset),作为测试微调用的视频数据集。
## 配置文件和运行
`accelerate` 配置文件如下:
+ accelerate_config_machine_multi.yaml 适合多GPU使用
+ accelerate_config_machine_single.yaml 适合单GPU使用
`finetune` 脚本配置文件如下:
```shell
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \ # 使用 accelerate 启动多GPU训练配置文件为 accelerate_config_machine_single.yaml
train_cogvideox_lora.py \ # 运行的训练脚本为 train_cogvideox_lora.py用于在 CogVideoX 模型上进行 LoRA 微调
--gradient_checkpointing \ # 启用梯度检查点功能,以减少显存使用
--pretrained_model_name_or_path $MODEL_PATH \ # 预训练模型路径,通过 $MODEL_PATH 指定
--cache_dir $CACHE_PATH \ # 模型缓存路径,由 $CACHE_PATH 指定
--enable_tiling \ # 启用tiling技术以分片处理视频节省显存
--enable_slicing \ # 启用slicing技术将输入切片以进一步优化内存
--instance_data_root $DATASET_PATH \ # 数据集路径,由 $DATASET_PATH 指定
--caption_column prompts.txt \ # 指定用于训练的视频描述文件,文件名为 prompts.txt
--video_column videos.txt \ # 指定用于训练的视频路径文件,文件名为 videos.txt
--validation_prompt "" \ # 验证集的提示语 (prompt),用于在训练期间生成验证视频
--validation_prompt_separator ::: \ # 设置验证提示语的分隔符为 :::
--num_validation_videos 1 \ # 每个验证回合生成 1 个视频
--validation_epochs 100 \ # 每 100 个训练epoch进行一次验证
--seed 42 \ # 设置随机种子为 42以保证结果的可复现性
--rank 128 \ # 设置 LoRA 参数的秩 (rank) 为 128
--lora_alpha 64 \ # 设置 LoRA 的 alpha 参数为 64用于调整LoRA的学习率
--mixed_precision bf16 \ # 使用 bf16 混合精度进行训练,减少显存使用
--output_dir $OUTPUT_PATH \ # 指定模型输出目录,由 $OUTPUT_PATH 定义
--height 480 \ # 视频高度为 480 像素
--width 720 \ # 视频宽度为 720 像素
--fps 8 \ # 视频帧率设置为 8 帧每秒
--max_num_frames 49 \ # 每个视频的最大帧数为 49 帧
--skip_frames_start 0 \ # 跳过视频开头的帧数为 0
--skip_frames_end 0 \ # 跳过视频结尾的帧数为 0
--train_batch_size 4 \ # 训练时的 batch size 设置为 4
--num_train_epochs 30 \ # 总训练epoch数为 30
--checkpointing_steps 1000 \ # 每 1000 步保存一次模型检查点
--gradient_accumulation_steps 1 \ # 梯度累计步数为 1即每个 batch 后都会更新梯度
--learning_rate 1e-3 \ # 学习率设置为 0.001
--lr_scheduler cosine_with_restarts \ # 使用带重启的余弦学习率调度器
--lr_warmup_steps 200 \ # 在训练的前 200 步进行学习率预热
--lr_num_cycles 1 \ # 学习率周期设置为 1
--optimizer AdamW \ # 使用 AdamW 优化器
--adam_beta1 0.9 \ # 设置 Adam 优化器的 beta1 参数为 0.9
--adam_beta2 0.95 \ # 设置 Adam 优化器的 beta2 参数为 0.95
--max_grad_norm 1.0 \ # 最大梯度裁剪值设置为 1.0
--allow_tf32 \ # 启用 TF32 以加速训练
--report_to wandb # 使用 Weights and Biases 进行训练记录与监控
```
如果需要在训练过程中进行validation则需要额外提供验证数据集其中数据格式与训练集相同。
## 运行脚本,开始微调
单机(单卡,多卡)微调
在开始训练之前,请注意以下分辨率设置要求:
```shell
bash finetune_single_rank.sh
1. 帧数必须是8的倍数 **+1** (即8N+1), 例如49, 81 ...
2. 视频分辨率建议使用模型的默认大小:
- CogVideoX: 480x720 (高x宽)
- CogVideoX1.5: 768x1360 (高x宽)
3. 对于不满足训练分辨率的样本视频或图片在代码中会直接进行resize。这可能会导致样本的宽高比发生形变从而影响训练效果。建议用户提前对样本在分辨率上进行处理例如使用crop + resize来维持宽高比再进行训练。
> **重要提示**为了提高训练效率我们会在训练前自动对video进行encode并将结果缓存在磁盘。如果在训练后修改了数据请删除video目录下的latent目录以确保使用最新的数据。
### LoRA
```bash
# 修改 train_ddp_t2v.sh 中的配置参数
# 主要需要修改以下参数:
# --output_dir: 输出目录
# --data_root: 数据集根目录
# --caption_column: 提示词文件路径
# --image_column: I2V可选参考图像文件列表路径 (移除这个参数将默认使用视频第一帧作为image condition)
# --video_column: 视频文件列表路径
# --train_resolution: 训练分辨率 (帧数x高x宽)
# 其他重要参数请参考启动脚本
bash train_ddp_t2v.sh # 文本生成视频 (T2V) 微调
bash train_ddp_i2v.sh # 图像生成视频 (I2V) 微调
```
多机多卡微调:
### SFT
```shell
bash finetune_multi_rank.sh #需要在每个节点运行
我们在`configs/`目录中提供了几个zero配置的模版请根据你的需求选择合适的训练配置`accelerate_config.yaml`中配置`deepspeed_config_file`选项即可)。
```bash
# 需要配置的参数与LoRA训练同理
bash train_zero_t2v.sh # 文本生成视频 (T2V) 微调
bash train_zero_i2v.sh # 图像生成视频 (I2V) 微调
```
除了设置bash脚本的相关参数你还需要在zero的配置文件中设定相关的训练选项并确保zero的训练配置与bash脚本中的参数一致例如batch_sizegradient_accumulation_stepsmixed_precision具体细节请参考[deepspeed官方文档](https://www.deepspeed.ai/docs/config-json/)
在使用sft训练时有以下几点需要注意
1. 对于sft训练validation时不会使用model offload因此显存峰值可能会超出24GB所以对于24GB以下的显卡建议关闭validation。
2. 开启zero-3时validation会比较慢建议在zero-3下关闭validation。
## 载入微调的模型
+ 请关注[cli_demo.py](../inference/cli_demo.py) 以了解如何加载微调的模型。
+ 对于sft训练的模型请先使用`checkpoint-*/`目录下的`zero_to_fp32.py`脚本合并模型权重
## 最佳实践
+ 包含70个分辨率为 `200 x 480 x 720`(帧数 x 高 x
@ -116,4 +126,3 @@ bash finetune_multi_rank.sh #需要在每个节点运行
+ 原始仓库使用 `lora_alpha` 设置为 1。我们发现这个值在多次运行中效果不佳可能是因为模型后端和训练设置的不同。我们的建议是将
lora_alpha 设置为与 rank 相同或 rank // 2。
+ 建议使用 rank 为 64 及以上的设置。

View File

@ -1,21 +1,18 @@
compute_environment: LOCAL_MACHINE
gpu_ids: "0,1,2,3,4,5,6,7"
num_processes: 8 # should be the same as the number of GPUs
debug: false
deepspeed_config:
gradient_accumulation_steps: 1
gradient_clipping: 1.0
offload_optimizer_device: none
offload_param_device: none
deepspeed_config_file: configs/zero2.yaml # e.g. configs/zero2.yaml, need use absolute path
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
dynamo_backend: 'no'
mixed_precision: 'no'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []

View File

@ -1,26 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: true
deepspeed_config:
deepspeed_hostfile: hostfile.txt
deepspeed_multinode_launcher: pdsh
gradient_accumulation_steps: 1
gradient_clipping: 1.0
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: true
zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'yes'
enable_cpu_affinity: true
main_process_ip: 10.250.128.19
main_process_port: 12355
main_training_function: main
mixed_precision: bf16
num_machines: 4
num_processes: 32
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -0,0 +1,38 @@
{
"bf16": {
"enabled": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"weight_decay": "auto",
"torch_adam": true,
"adam_w_mode": true
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients": true
},
"gradient_accumulation_steps": 1,
"train_micro_batch_size_per_gpu": 1,
"train_batch_size": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"wall_clock_breakdown": false
}

View File

@ -0,0 +1,42 @@
{
"bf16": {
"enabled": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"weight_decay": "auto",
"torch_adam": true,
"adam_w_mode": true
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients": true,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
}
},
"gradient_accumulation_steps": 1,
"train_micro_batch_size_per_gpu": 1,
"train_batch_size": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"wall_clock_breakdown": false
}

View File

@ -0,0 +1,43 @@
{
"bf16": {
"enabled": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"weight_decay": "auto",
"torch_adam": true,
"adam_w_mode": true
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 3,
"overlap_comm": true,
"contiguous_gradients": true,
"reduce_bucket_size": 5e8,
"stage3_prefetch_bucket_size": 5e8,
"stage3_param_persistence_threshold": 1e5,
"sub_group_size": 1e9,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": "auto"
},
"gradient_accumulation_steps": 1,
"train_micro_batch_size_per_gpu": 1,
"train_batch_size": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"wall_clock_breakdown": false
}

View File

@ -0,0 +1,51 @@
{
"bf16": {
"enabled": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"weight_decay": "auto",
"torch_adam": true,
"adam_w_mode": true
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"reduce_bucket_size": 5e8,
"stage3_prefetch_bucket_size": 5e8,
"stage3_param_persistence_threshold": 1e6,
"sub_group_size": 1e9,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": "auto"
},
"gradient_accumulation_steps": 1,
"train_micro_batch_size_per_gpu": 1,
"train_batch_size": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"wall_clock_breakdown": false
}

2
finetune/constants.py Normal file
View File

@ -0,0 +1,2 @@
# Name of the logger shared by all finetune trainer modules.
LOG_NAME = "trainer"
# Default verbosity for that logger.
LOG_LEVEL = "INFO"

View File

@ -0,0 +1,12 @@
from .bucket_sampler import BucketSampler
from .i2v_dataset import I2VDatasetWithBuckets, I2VDatasetWithResize
from .t2v_dataset import T2VDatasetWithBuckets, T2VDatasetWithResize
# Public API of the finetune.datasets package.
__all__ = [
    "I2VDatasetWithResize",
    "I2VDatasetWithBuckets",
    "T2VDatasetWithResize",
    "T2VDatasetWithBuckets",
    "BucketSampler",
]

View File

@ -0,0 +1,79 @@
import logging
import random
from torch.utils.data import Dataset, Sampler
logger = logging.getLogger(__name__)
class BucketSampler(Sampler):
    r"""
    PyTorch Sampler that groups 3D data by height, width and frames.

    Samples are accumulated into per-resolution buckets as the data source is
    iterated; whenever a bucket reaches ``batch_size`` entries it is yielded as
    one batch, so every batch contains videos of a single resolution.

    Args:
        data_source (`VideoDataset`):
            A PyTorch dataset object that is an instance of `VideoDataset`.
        batch_size (`int`, defaults to `8`):
            The batch size to use for training.
        shuffle (`bool`, defaults to `True`):
            Whether or not to shuffle the data in each batch before dispatching to dataloader.
        drop_last (`bool`, defaults to `False`):
            Whether or not to drop incomplete buckets of data after completely iterating over all data
            in the dataset. If set to True, only batches that have `batch_size` number of entries will
            be yielded. If set to False, it is guaranteed that all data in the dataset will be processed
            and batches that do not have `batch_size` number of entries will also be yielded.
    """

    def __init__(
        self,
        data_source: Dataset,
        batch_size: int = 8,
        shuffle: bool = True,
        drop_last: bool = False,
    ) -> None:
        self.data_source = data_source
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        # One pending batch per (frames, height, width) resolution.
        self.buckets = {resolution: [] for resolution in data_source.video_resolution_buckets}
        self._raised_warning_for_drop_last = False

    def __len__(self):
        # With drop_last the real number of batches depends on how samples spread
        # over buckets, so this estimate may overshoot — warn once about that.
        if self.drop_last and not self._raised_warning_for_drop_last:
            self._raised_warning_for_drop_last = True
            logger.warning(
                "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training."
            )
        return (len(self.data_source) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        for sample in self.data_source:
            meta = sample["video_metadata"]
            key = (meta["num_frames"], meta["height"], meta["width"])
            self.buckets[key].append(sample)
            if len(self.buckets[key]) == self.batch_size:
                full_bucket = self.buckets[key]
                if self.shuffle:
                    random.shuffle(full_bucket)
                yield full_bucket
                del self.buckets[key]
                self.buckets[key] = []
        if self.drop_last:
            return
        # Flush whatever partially-filled buckets remain.
        for key, leftover in list(self.buckets.items()):
            if len(leftover) == 0:
                continue
            if self.shuffle:
                random.shuffle(leftover)
            yield leftover
            del self.buckets[key]
            self.buckets[key] = []

View File

@ -0,0 +1,325 @@
import hashlib
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
import torch
from accelerate.logging import get_logger
from safetensors.torch import load_file, save_file
from torch.utils.data import Dataset
from torchvision import transforms
from typing_extensions import override
from finetune.constants import LOG_LEVEL, LOG_NAME
from .utils import (
load_images,
load_images_from_videos,
load_prompts,
load_videos,
preprocess_image_with_resize,
preprocess_video_with_buckets,
preprocess_video_with_resize,
)
if TYPE_CHECKING:
from finetune.trainer import Trainer
# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
import decord # isort:skip
decord.bridge.set_bridge("torch")
logger = get_logger(LOG_NAME, LOG_LEVEL)
class BaseI2VDataset(Dataset):
    """
    Base dataset class for Image-to-Video (I2V) training.

    This dataset loads prompts, videos and corresponding conditioning images for I2V
    training. Encoded video latents and prompt embeddings are cached on disk under
    ``<data_root>/cache`` so the expensive VAE / text-encoder passes run only once per
    sample and training resolution.

    Args:
        data_root (str): Root directory containing the dataset files
        caption_column (str): Path to file containing text prompts/captions
        video_column (str): Path to file containing video paths
        image_column (str | None): Path to file containing image paths; when None, the
            first frame of each video is extracted and used as the conditioning image
        device (torch.device): Device to load the data on
        trainer (Trainer): Trainer providing ``encode_video`` / ``encode_text`` and the
            training arguments (train_resolution, model_name, data_root)
    """

    def __init__(
        self,
        data_root: str,
        caption_column: str,
        video_column: str,
        image_column: str | None,
        device: torch.device,
        trainer: "Trainer" = None,
        *args,
        **kwargs,
    ) -> None:
        super().__init__()
        data_root = Path(data_root)
        self.prompts = load_prompts(data_root / caption_column)
        self.videos = load_videos(data_root / video_column)
        if image_column is not None:
            self.images = load_images(data_root / image_column)
        else:
            # No image list provided: fall back to each video's first frame.
            self.images = load_images_from_videos(self.videos)
        self.trainer = trainer
        self.device = device
        self.encode_video = trainer.encode_video
        self.encode_text = trainer.encode_text

        # Check if number of prompts matches number of videos and images
        if not (len(self.videos) == len(self.prompts) == len(self.images)):
            raise ValueError(
                f"Expected length of prompts, videos and images to be the same but found {len(self.prompts)=}, {len(self.videos)=} and {len(self.images)=}. Please ensure that the number of caption prompts, videos and images match in your dataset."
            )

        # Check if all video files exist
        if any(not path.is_file() for path in self.videos):
            raise ValueError(
                f"Some video files were not found. Please ensure that all video files exist in the dataset directory. Missing file: {next(path for path in self.videos if not path.is_file())}"
            )

        # Check if all image files exist
        if any(not path.is_file() for path in self.images):
            raise ValueError(
                f"Some image files were not found. Please ensure that all image files exist in the dataset directory. Missing file: {next(path for path in self.images if not path.is_file())}"
            )

    def __len__(self) -> int:
        return len(self.videos)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        if isinstance(index, list):
            # Here, index is actually a list of data objects that we need to return.
            # The BucketSampler should ideally return indices. But, in the sampler, we'd like
            # to have information about num_frames, height and width. Since this is not stored
            # as metadata, we need to read the video to get this information. You could read this
            # information without loading the full video in memory, but we do it anyway. In order
            # to not load the video twice (once to get the metadata, and once to return the loaded video
            # based on sampled indices), we cache it in the BucketSampler. When the sampler is
            # to yield, we yield the cache data instead of indices. So, this special check ensures
            # that data is not loaded a second time. PRs are welcome for improvements.
            return index

        prompt = self.prompts[index]
        video = self.videos[index]
        image = self.images[index]
        train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)

        # Latents are keyed by model name + training resolution so re-training at a
        # different resolution does not reuse stale cache entries.
        cache_dir = self.trainer.args.data_root / "cache"
        video_latent_dir = (
            cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
        )
        prompt_embeddings_dir = cache_dir / "prompt_embeddings"
        video_latent_dir.mkdir(parents=True, exist_ok=True)
        prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)

        # Prompt embeddings are keyed by content hash; video latents by file stem.
        prompt_hash = str(hashlib.sha256(prompt.encode()).hexdigest())
        prompt_embedding_path = prompt_embeddings_dir / (prompt_hash + ".safetensors")
        encoded_video_path = video_latent_dir / (video.stem + ".safetensors")

        if prompt_embedding_path.exists():
            prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
            logger.debug(
                f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}",
                main_process_only=False,
            )
        else:
            prompt_embedding = self.encode_text(prompt)
            prompt_embedding = prompt_embedding.to("cpu")
            # [1, seq_len, hidden_size] -> [seq_len, hidden_size]
            prompt_embedding = prompt_embedding[0]
            save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
            logger.info(
                f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False
            )

        if encoded_video_path.exists():
            encoded_video = load_file(encoded_video_path)["encoded_video"]
            logger.debug(f"Loaded encoded video from {encoded_video_path}", main_process_only=False)
            # Latent was cached, but the conditioning image is still loaded fresh.
            # shape of image: [C, H, W]
            _, image = self.preprocess(None, self.images[index])
            image = self.image_transform(image)
        else:
            frames, image = self.preprocess(video, image)
            frames = frames.to(self.device)
            image = image.to(self.device)
            image = self.image_transform(image)
            # Current shape of frames: [F, C, H, W]
            frames = self.video_transform(frames)
            # Convert to [B, C, F, H, W]
            frames = frames.unsqueeze(0)
            frames = frames.permute(0, 2, 1, 3, 4).contiguous()
            encoded_video = self.encode_video(frames)
            # [1, C, F, H, W] -> [C, F, H, W]
            encoded_video = encoded_video[0]
            encoded_video = encoded_video.to("cpu")
            image = image.to("cpu")
            save_file({"encoded_video": encoded_video}, encoded_video_path)
            logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)

        # shape of encoded_video: [C, F, H, W]
        # shape of image: [C, H, W]
        return {
            "image": image,
            "prompt_embedding": prompt_embedding,
            "encoded_video": encoded_video,
            "video_metadata": {
                "num_frames": encoded_video.shape[1],
                "height": encoded_video.shape[2],
                "width": encoded_video.shape[3],
            },
        }

    def preprocess(
        self, video_path: Path | None, image_path: Path | None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Loads and preprocesses a video and an image.
        If either path is None, no preprocessing will be done for that input.

        Args:
            video_path: Path to the video file to load
            image_path: Path to the image file to load

        Returns:
            A tuple containing:
                - video(torch.Tensor) of shape [F, C, H, W] where F is number of frames,
                  C is number of channels, H is height and W is width
                - image(torch.Tensor) of shape [C, H, W]
        """
        raise NotImplementedError("Subclass must implement this method")

    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        """
        Applies transformations to a video.

        Args:
            frames (torch.Tensor): A 4D tensor representing a video
                with shape [F, C, H, W] where:
                - F is number of frames
                - C is number of channels (3 for RGB)
                - H is height
                - W is width

        Returns:
            torch.Tensor: The transformed video tensor
        """
        raise NotImplementedError("Subclass must implement this method")

    def image_transform(self, image: torch.Tensor) -> torch.Tensor:
        """
        Applies transformations to an image.

        Args:
            image (torch.Tensor): A 3D tensor representing an image
                with shape [C, H, W] where:
                - C is number of channels (3 for RGB)
                - H is height
                - W is width

        Returns:
            torch.Tensor: The transformed image tensor
        """
        raise NotImplementedError("Subclass must implement this method")
class I2VDatasetWithResize(BaseI2VDataset):
    """
    I2V dataset variant that resizes every sample to one fixed size.

    Videos are resized to at most ``max_num_frames`` frames of ``height`` x ``width``
    pixels; conditioning images are resized to ``height`` x ``width``.

    Args:
        max_num_frames (int): Maximum number of frames to extract from videos
        height (int): Target height for resizing videos and images
        width (int): Target width for resizing videos and images
    """

    def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.max_num_frames = max_num_frames
        self.height = height
        self.width = width
        # Map uint8 pixel values [0, 255] into the model's expected range [-1, 1].
        rescale = transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
        self.__frame_transforms = transforms.Compose([rescale])
        self.__image_transforms = self.__frame_transforms

    @override
    def preprocess(
        self, video_path: Path | None, image_path: Path | None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        video = None
        if video_path is not None:
            video = preprocess_video_with_resize(
                video_path, self.max_num_frames, self.height, self.width
            )
        image = None
        if image_path is not None:
            image = preprocess_image_with_resize(image_path, self.height, self.width)
        return video, image

    @override
    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        transformed = [self.__frame_transforms(frame) for frame in frames]
        return torch.stack(transformed, dim=0)

    @override
    def image_transform(self, image: torch.Tensor) -> torch.Tensor:
        return self.__image_transforms(image)
class I2VDatasetWithBuckets(BaseI2VDataset):
    """I2V dataset variant that groups samples into latent-space resolution buckets."""

    def __init__(
        self,
        video_resolution_buckets: List[Tuple[int, int, int]],
        vae_temporal_compression_ratio: int,
        vae_height_compression_ratio: int,
        vae_width_compression_ratio: int,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        # Buckets are given in pixel space; convert each dimension into latent
        # space by dividing by the corresponding VAE compression ratio.
        self.video_resolution_buckets = []
        for frames, height, width in video_resolution_buckets:
            self.video_resolution_buckets.append(
                (
                    int(frames / vae_temporal_compression_ratio),
                    int(height / vae_height_compression_ratio),
                    int(width / vae_width_compression_ratio),
                )
            )
        # Map uint8 pixel values [0, 255] into the model's expected range [-1, 1].
        rescale = transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
        self.__frame_transforms = transforms.Compose([rescale])
        self.__image_transforms = self.__frame_transforms

    @override
    def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
        video = preprocess_video_with_buckets(video_path, self.video_resolution_buckets)
        # Resize the conditioning image to the bucketed video's spatial size.
        image = preprocess_image_with_resize(image_path, video.shape[2], video.shape[3])
        return video, image

    @override
    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        transformed = [self.__frame_transforms(frame) for frame in frames]
        return torch.stack(transformed, dim=0)

    @override
    def image_transform(self, image: torch.Tensor) -> torch.Tensor:
        return self.__image_transforms(image)

View File

@ -0,0 +1,264 @@
import hashlib
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
import torch
from accelerate.logging import get_logger
from safetensors.torch import load_file, save_file
from torch.utils.data import Dataset
from torchvision import transforms
from typing_extensions import override
from finetune.constants import LOG_LEVEL, LOG_NAME
from .utils import (
load_prompts,
load_videos,
preprocess_video_with_buckets,
preprocess_video_with_resize,
)
if TYPE_CHECKING:
from finetune.trainer import Trainer
# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
import decord # isort:skip
decord.bridge.set_bridge("torch")
logger = get_logger(LOG_NAME, LOG_LEVEL)
class BaseT2VDataset(Dataset):
    """
    Base dataset class for Text-to-Video (T2V) training.

    This dataset loads prompts and videos for T2V training. Encoded video latents and
    prompt embeddings are cached on disk under ``<data_root>/cache`` so the expensive
    VAE / text-encoder passes run only once per sample and training resolution.

    Args:
        data_root (str): Root directory containing the dataset files
        caption_column (str): Path to file containing text prompts/captions
        video_column (str): Path to file containing video paths
        device (torch.device): Device to load the data on
        trainer (Trainer): Trainer providing ``encode_video`` / ``encode_text`` and the
            training arguments (train_resolution, model_name, data_root)
    """

    def __init__(
        self,
        data_root: str,
        caption_column: str,
        video_column: str,
        device: torch.device = None,
        trainer: "Trainer" = None,
        *args,
        **kwargs,
    ) -> None:
        super().__init__()
        data_root = Path(data_root)
        self.prompts = load_prompts(data_root / caption_column)
        self.videos = load_videos(data_root / video_column)
        self.device = device
        self.encode_video = trainer.encode_video
        self.encode_text = trainer.encode_text
        self.trainer = trainer

        # Check if all video files exist
        if any(not path.is_file() for path in self.videos):
            raise ValueError(
                f"Some video files were not found. Please ensure that all video files exist in the dataset directory. Missing file: {next(path for path in self.videos if not path.is_file())}"
            )

        # Check if number of prompts matches number of videos
        if len(self.videos) != len(self.prompts):
            raise ValueError(
                f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.videos)=}. Please ensure that the number of caption prompts and videos match in your dataset."
            )

    def __len__(self) -> int:
        return len(self.videos)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        if isinstance(index, list):
            # Here, index is actually a list of data objects that we need to return.
            # The BucketSampler should ideally return indices. But, in the sampler, we'd like
            # to have information about num_frames, height and width. Since this is not stored
            # as metadata, we need to read the video to get this information. You could read this
            # information without loading the full video in memory, but we do it anyway. In order
            # to not load the video twice (once to get the metadata, and once to return the loaded video
            # based on sampled indices), we cache it in the BucketSampler. When the sampler is
            # to yield, we yield the cache data instead of indices. So, this special check ensures
            # that data is not loaded a second time. PRs are welcome for improvements.
            return index

        prompt = self.prompts[index]
        video = self.videos[index]
        train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)

        # Latents are keyed by model name + training resolution so re-training at a
        # different resolution does not reuse stale cache entries.
        cache_dir = self.trainer.args.data_root / "cache"
        video_latent_dir = (
            cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
        )
        prompt_embeddings_dir = cache_dir / "prompt_embeddings"
        video_latent_dir.mkdir(parents=True, exist_ok=True)
        prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)

        # Prompt embeddings are keyed by content hash; video latents by file stem.
        prompt_hash = str(hashlib.sha256(prompt.encode()).hexdigest())
        prompt_embedding_path = prompt_embeddings_dir / (prompt_hash + ".safetensors")
        encoded_video_path = video_latent_dir / (video.stem + ".safetensors")

        if prompt_embedding_path.exists():
            prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
            logger.debug(
                f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}",
                main_process_only=False,
            )
        else:
            prompt_embedding = self.encode_text(prompt)
            prompt_embedding = prompt_embedding.to("cpu")
            # [1, seq_len, hidden_size] -> [seq_len, hidden_size]
            prompt_embedding = prompt_embedding[0]
            save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
            logger.info(
                f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False
            )

        if encoded_video_path.exists():
            encoded_video = load_file(encoded_video_path)["encoded_video"]
            logger.debug(f"Loaded encoded video from {encoded_video_path}", main_process_only=False)
        else:
            frames = self.preprocess(video)
            frames = frames.to(self.device)
            # Current shape of frames: [F, C, H, W]
            frames = self.video_transform(frames)
            # Convert to [B, C, F, H, W]
            frames = frames.unsqueeze(0)
            frames = frames.permute(0, 2, 1, 3, 4).contiguous()
            encoded_video = self.encode_video(frames)
            # [1, C, F, H, W] -> [C, F, H, W]
            encoded_video = encoded_video[0]
            encoded_video = encoded_video.to("cpu")
            save_file({"encoded_video": encoded_video}, encoded_video_path)
            logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)

        # shape of encoded_video: [C, F, H, W]
        return {
            "prompt_embedding": prompt_embedding,
            "encoded_video": encoded_video,
            "video_metadata": {
                "num_frames": encoded_video.shape[1],
                "height": encoded_video.shape[2],
                "width": encoded_video.shape[3],
            },
        }

    def preprocess(self, video_path: Path) -> torch.Tensor:
        """
        Loads and preprocesses a video.

        Args:
            video_path: Path to the video file to load.

        Returns:
            torch.Tensor: Video tensor of shape [F, C, H, W] where:
                - F is number of frames
                - C is number of channels (3 for RGB)
                - H is height
                - W is width
        """
        raise NotImplementedError("Subclass must implement this method")

    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        """
        Applies transformations to a video.

        Args:
            frames (torch.Tensor): A 4D tensor representing a video
                with shape [F, C, H, W] where:
                - F is number of frames
                - C is number of channels (3 for RGB)
                - H is height
                - W is width

        Returns:
            torch.Tensor: The transformed video tensor with the same shape as the input
        """
        raise NotImplementedError("Subclass must implement this method")
class T2VDatasetWithResize(BaseT2VDataset):
    """
    T2V dataset variant that resizes every sample to one fixed size.

    Videos are resized to at most ``max_num_frames`` frames of ``height`` x ``width``
    pixels.

    Args:
        max_num_frames (int): Maximum number of frames to extract from videos
        height (int): Target height for resizing videos
        width (int): Target width for resizing videos
    """

    def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.max_num_frames = max_num_frames
        self.height = height
        self.width = width
        # Map uint8 pixel values [0, 255] into the model's expected range [-1, 1].
        rescale = transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
        self.__frame_transform = transforms.Compose([rescale])

    @override
    def preprocess(self, video_path: Path) -> torch.Tensor:
        return preprocess_video_with_resize(
            video_path, self.max_num_frames, self.height, self.width
        )

    @override
    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        transformed = [self.__frame_transform(frame) for frame in frames]
        return torch.stack(transformed, dim=0)
class T2VDatasetWithBuckets(BaseT2VDataset):
    """T2V dataset variant that groups samples into latent-space resolution buckets."""

    def __init__(
        self,
        video_resolution_buckets: List[Tuple[int, int, int]],
        vae_temporal_compression_ratio: int,
        vae_height_compression_ratio: int,
        vae_width_compression_ratio: int,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        # Buckets are given in pixel space; convert each dimension into latent
        # space by dividing by the corresponding VAE compression ratio.
        self.video_resolution_buckets = []
        for frames, height, width in video_resolution_buckets:
            self.video_resolution_buckets.append(
                (
                    int(frames / vae_temporal_compression_ratio),
                    int(height / vae_height_compression_ratio),
                    int(width / vae_width_compression_ratio),
                )
            )
        # Map uint8 pixel values [0, 255] into the model's expected range [-1, 1].
        rescale = transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
        self.__frame_transform = transforms.Compose([rescale])

    @override
    def preprocess(self, video_path: Path) -> torch.Tensor:
        return preprocess_video_with_buckets(video_path, self.video_resolution_buckets)

    @override
    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        transformed = [self.__frame_transform(frame) for frame in frames]
        return torch.stack(transformed, dim=0)

196
finetune/datasets/utils.py Normal file
View File

@ -0,0 +1,196 @@
import logging
from pathlib import Path
from typing import List, Tuple
import cv2
import torch
from torchvision.transforms.functional import resize
# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
import decord # isort:skip
decord.bridge.set_bridge("torch")
########## loaders ##########
def load_prompts(prompt_path: Path) -> List[str]:
    """Read one prompt per non-empty line from *prompt_path*."""
    with open(prompt_path, "r", encoding="utf-8") as f:
        stripped = (raw.strip() for raw in f)
        return [text for text in stripped if text]
def load_videos(video_path: Path) -> List[Path]:
    """Read relative video paths (one per line) from *video_path* and resolve
    them against the list file's parent directory; blank lines are dropped."""
    base_dir = video_path.parent
    with open(video_path, "r", encoding="utf-8") as file:
        entries = [line.strip() for line in file]
    return [base_dir / entry for entry in entries if entry]
def load_images(image_path: Path) -> List[Path]:
    """Read relative image paths (one per line) from *image_path* and resolve
    them against the list file's parent directory; blank lines are dropped."""
    base_dir = image_path.parent
    with open(image_path, "r", encoding="utf-8") as file:
        entries = [line.strip() for line in file]
    return [base_dir / entry for entry in entries if entry]
def load_images_from_videos(videos_path: List[Path]) -> List[Path]:
    """Extract (and cache) the first frame of each video as a PNG.

    Frames are written to a "first_frames" directory that is a sibling of the
    videos' parent directory; an already-extracted frame is reused instead of
    re-decoding the video.

    Args:
        videos_path: Paths of the videos to extract first frames from.

    Returns:
        Paths to the first-frame PNGs, in the same order as the input videos.

    Raises:
        RuntimeError: If a video cannot be read.
    """
    # Guard the empty case: videos_path[0] below would raise IndexError.
    if not videos_path:
        return []
    first_frames_dir = videos_path[0].parent.parent / "first_frames"
    first_frames_dir.mkdir(exist_ok=True)

    first_frame_paths = []
    for video_path in videos_path:
        frame_path = first_frames_dir / f"{video_path.stem}.png"
        if frame_path.exists():
            first_frame_paths.append(frame_path)
            continue

        # Open video and read its first frame.
        cap = cv2.VideoCapture(str(video_path))
        try:
            ret, frame = cap.read()
            if not ret:
                raise RuntimeError(f"Failed to read video: {video_path}")
        finally:
            # Release the capture even on the error path so the underlying
            # decoder handle is not leaked.
            cap.release()

        # Save frame as PNG with same name as video.
        cv2.imwrite(str(frame_path), frame)
        logging.info(f"Saved first frame to {frame_path}")
        first_frame_paths.append(frame_path)

    return first_frame_paths
########## preprocessors ##########
def preprocess_image_with_resize(
    image_path: Path | str,
    height: int,
    width: int,
) -> torch.Tensor:
    """
    Loads and resizes a single image.

    Args:
        image_path: Path to the image file.
        height: Target height for resizing.
        width: Target width for resizing.

    Returns:
        torch.Tensor: Image tensor with shape [C, H, W] where:
            C = number of channels (3 for RGB)
            H = height
            W = width

    Raises:
        RuntimeError: If the image cannot be read from disk.
    """
    if isinstance(image_path, str):
        image_path = Path(image_path)
    image = cv2.imread(image_path.as_posix())
    # cv2.imread signals failure by returning None instead of raising; fail
    # loudly here rather than with an opaque error inside cvtColor below.
    if image is None:
        raise RuntimeError(f"Failed to read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (width, height))
    image = torch.from_numpy(image).float()
    # HWC -> CHW.
    image = image.permute(2, 0, 1).contiguous()
    return image
def preprocess_video_with_resize(
    video_path: Path | str,
    max_num_frames: int,
    height: int,
    width: int,
) -> torch.Tensor:
    """
    Loads and resizes a single video.

    The function processes the video through these steps:
      1. If video frame count < max_num_frames, pad by repeating the last frame.
      2. If video frame count > max_num_frames, downsample frames evenly.
      3. Frames are resized to (height, width) during decoding.

    Args:
        video_path: Path to the video file.
        max_num_frames: Maximum number of frames to keep.
        height: Target height for resizing.
        width: Target width for resizing.

    Returns:
        A torch.Tensor with shape [F, C, H, W] where:
            F = number of frames (== max_num_frames when padding occurs)
            C = number of channels (3 for RGB)
            H = height
            W = width
    """
    if isinstance(video_path, str):
        video_path = Path(video_path)
    # decord resizes frames while decoding; tensors come back as torch because
    # of the decord.bridge.set_bridge("torch") call at module import time.
    video_reader = decord.VideoReader(uri=video_path.as_posix(), width=width, height=height)
    video_num_frames = len(video_reader)
    if video_num_frames < max_num_frames:
        # Get all frames first
        frames = video_reader.get_batch(list(range(video_num_frames)))
        # Repeat the last frame until we reach max_num_frames
        last_frame = frames[-1:]
        num_repeats = max_num_frames - video_num_frames
        repeated_frames = last_frame.repeat(num_repeats, 1, 1, 1)
        frames = torch.cat([frames, repeated_frames], dim=0)
        # FHWC -> FCHW.
        return frames.float().permute(0, 3, 1, 2).contiguous()
    else:
        # Stride sampling may yield slightly more than max_num_frames indices;
        # the slice below truncates to exactly max_num_frames.
        indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
        frames = video_reader.get_batch(indices)
        frames = frames[:max_num_frames].float()
        # FHWC -> FCHW.
        frames = frames.permute(0, 3, 1, 2).contiguous()
        return frames
def preprocess_video_with_buckets(
    video_path: Path,
    resolution_buckets: List[Tuple[int, int, int]],
) -> torch.Tensor:
    """
    Args:
        video_path: Path to the video file.
        resolution_buckets: List of tuples (num_frames, height, width) representing
            available resolution buckets.

    Returns:
        torch.Tensor: Video tensor with shape [F, C, H, W] where:
            F = number of frames
            C = number of channels (3 for RGB)
            H = height
            W = width

    Raises:
        ValueError: If the video has fewer frames than every bucket.

    The function processes the video through these steps:
        1. Finds nearest frame bucket <= video frame count
        2. Downsamples frames evenly to match bucket size
        3. Finds nearest resolution bucket based on dimensions
        4. Resizes frames to match bucket resolution
    """
    video_reader = decord.VideoReader(uri=video_path.as_posix())
    video_num_frames = len(video_reader)
    # Only buckets the video can actually fill are eligible.
    resolution_buckets = [bucket for bucket in resolution_buckets if bucket[0] <= video_num_frames]
    if len(resolution_buckets) == 0:
        raise ValueError(
            f"video frame count in {video_path} is less than all frame buckets {resolution_buckets}"
        )

    # All remaining buckets satisfy bucket[0] <= video_num_frames, so minimizing
    # the (non-negative) difference selects the largest frame count that fits.
    # NOTE(review): the `default=1` is dead code — the empty case already raised
    # above, and indexing [0] on the int default would fail anyway.
    nearest_frame_bucket = min(
        resolution_buckets,
        key=lambda bucket: video_num_frames - bucket[0],
        default=1,
    )[0]
    # Evenly strided sampling, truncated to exactly the bucket's frame count.
    frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
    frames = video_reader.get_batch(frame_indices)
    frames = frames[:nearest_frame_bucket].float()
    # FHWC -> FCHW.
    frames = frames.permute(0, 3, 1, 2).contiguous()

    # Pick the bucket whose (H, W) is closest to the decoded frames (L1 distance).
    nearest_res = min(
        resolution_buckets, key=lambda x: abs(x[1] - frames.shape[2]) + abs(x[2] - frames.shape[3])
    )
    nearest_res = (nearest_res[1], nearest_res[2])
    frames = torch.stack([resize(f, nearest_res) for f in frames], dim=0)

    return frames

View File

@ -1,52 +0,0 @@
#!/bin/bash
export MODEL_PATH="THUDM/CogVideoX-2b"
export CACHE_PATH="~/.cache"
export DATASET_PATH="Disney-VideoGeneration-Dataset"
export OUTPUT_PATH="cogvideox-lora-multi-node"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
# max batch-size is 2.
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu --machine_rank 0 \
train_cogvideox_lora.py \
--gradient_checkpointing \
--pretrained_model_name_or_path $MODEL_PATH \
--cache_dir $CACHE_PATH \
--enable_tiling \
--enable_slicing \
--instance_data_root $DATASET_PATH \
--caption_column prompts.txt \
--video_column videos.txt \
--validation_prompt "DISNEY A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance" \
--validation_prompt_separator ::: \
--num_validation_videos 1 \
--validation_epochs 100 \
--seed 42 \
--rank 128 \
--lora_alpha 64 \
--mixed_precision bf16 \
--output_dir $OUTPUT_PATH \
--height 480 \
--width 720 \
--fps 8 \
--max_num_frames 49 \
--skip_frames_start 0 \
--skip_frames_end 0 \
--train_batch_size 1 \
--num_train_epochs 30 \
--checkpointing_steps 1000 \
--gradient_accumulation_steps 1 \
--learning_rate 1e-3 \
--lr_scheduler cosine_with_restarts \
--lr_warmup_steps 200 \
--lr_num_cycles 1 \
--enable_slicing \
--enable_tiling \
--gradient_checkpointing \
--optimizer AdamW \
--adam_beta1 0.9 \
--adam_beta2 0.95 \
--max_grad_norm 1.0 \
--allow_tf32 \
--report_to wandb

View File

@ -1,52 +0,0 @@
#!/bin/bash
export MODEL_PATH="THUDM/CogVideoX-2b"
export CACHE_PATH="~/.cache"
export DATASET_PATH="Disney-VideoGeneration-Dataset"
export OUTPUT_PATH="cogvideox-lora-single-node"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
# if you are not using 8 gpus, change `accelerate_config_machine_single.yaml` num_processes to your gpu number
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \
train_cogvideox_lora.py \
--gradient_checkpointing \
--pretrained_model_name_or_path $MODEL_PATH \
--cache_dir $CACHE_PATH \
--enable_tiling \
--enable_slicing \
--instance_data_root $DATASET_PATH \
--caption_column prompts.txt \
--video_column videos.txt \
--validation_prompt "DISNEY A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance" \
--validation_prompt_separator ::: \
--num_validation_videos 1 \
--validation_epochs 100 \
--seed 42 \
--rank 128 \
--lora_alpha 64 \
--mixed_precision bf16 \
--output_dir $OUTPUT_PATH \
--height 480 \
--width 720 \
--fps 8 \
--max_num_frames 49 \
--skip_frames_start 0 \
--skip_frames_end 0 \
--train_batch_size 1 \
--num_train_epochs 30 \
--checkpointing_steps 1000 \
--gradient_accumulation_steps 1 \
--learning_rate 1e-3 \
--lr_scheduler cosine_with_restarts \
--lr_warmup_steps 200 \
--lr_num_cycles 1 \
--enable_slicing \
--enable_tiling \
--gradient_checkpointing \
--optimizer AdamW \
--adam_beta1 0.9 \
--adam_beta2 0.95 \
--max_grad_norm 1.0 \
--allow_tf32 \
--report_to wandb

View File

@ -1,2 +0,0 @@
node1 slots=8
node2 slots=8

View File

@ -0,0 +1,12 @@
import importlib
from pathlib import Path

# Import every module of every non-private sub-package so that each trainer
# module runs its register(...) call as an import side effect.
package_dir = Path(__file__).parent
for subdir in package_dir.iterdir():
    if not subdir.is_dir() or subdir.name.startswith("_"):
        continue
    for module_path in subdir.glob("*.py"):
        importlib.import_module(f".{subdir.name}.{module_path.stem}", package=__name__)

View File

@ -0,0 +1,9 @@
from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
from ..utils import register
class CogVideoX1_5I2VLoraTrainer(CogVideoXI2VLoraTrainer):
    """LoRA trainer for CogVideoX1.5 I2V; reuses the CogVideoX I2V training logic unchanged."""

    pass


# Makes this trainer discoverable via get_model_cls("cogvideox1.5-i2v", "lora").
register("cogvideox1.5-i2v", "lora", CogVideoX1_5I2VLoraTrainer)

View File

@ -0,0 +1,9 @@
from ..cogvideox_i2v.sft_trainer import CogVideoXI2VSftTrainer
from ..utils import register
class CogVideoX1_5I2VSftTrainer(CogVideoXI2VSftTrainer):
    """SFT trainer for CogVideoX1.5 I2V; reuses the CogVideoX I2V training logic unchanged."""

    pass


# Makes this trainer discoverable via get_model_cls("cogvideox1.5-i2v", "sft").
register("cogvideox1.5-i2v", "sft", CogVideoX1_5I2VSftTrainer)

View File

@ -0,0 +1,9 @@
from ..cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer
from ..utils import register
class CogVideoX1_5T2VLoraTrainer(CogVideoXT2VLoraTrainer):
    """LoRA trainer for CogVideoX1.5 T2V; reuses the CogVideoX T2V training logic unchanged."""

    pass


# Makes this trainer discoverable via get_model_cls("cogvideox1.5-t2v", "lora").
register("cogvideox1.5-t2v", "lora", CogVideoX1_5T2VLoraTrainer)

View File

@ -0,0 +1,9 @@
from ..cogvideox_t2v.sft_trainer import CogVideoXT2VSftTrainer
from ..utils import register
class CogVideoX1_5T2VSftTrainer(CogVideoXT2VSftTrainer):
    """SFT trainer for CogVideoX1.5 T2V; reuses the CogVideoX T2V training logic unchanged."""

    pass


# Makes this trainer discoverable via get_model_cls("cogvideox1.5-t2v", "sft").
register("cogvideox1.5-t2v", "sft", CogVideoX1_5T2VSftTrainer)

View File

@ -0,0 +1,272 @@
from typing import Any, Dict, List, Tuple
import torch
from diffusers import (
AutoencoderKLCogVideoX,
CogVideoXDPMScheduler,
CogVideoXImageToVideoPipeline,
CogVideoXTransformer3DModel,
)
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from PIL import Image
from numpy import dtype
from transformers import AutoTokenizer, T5EncoderModel
from typing_extensions import override
from finetune.schemas import Components
from finetune.trainer import Trainer
from finetune.utils import unwrap_model
from ..utils import register
class CogVideoXI2VLoraTrainer(Trainer):
    """LoRA trainer for CogVideoX image-to-video (I2V) models.

    Loads the diffusers CogVideoX I2V components, encodes videos/prompts, and
    computes the v-prediction training loss with first-frame conditioning.
    """

    # Components to move off the accelerator once their outputs are cached.
    UNLOAD_LIST = ["text_encoder"]

    @override
    def load_components(self) -> Dict[str, Any]:
        """Load tokenizer, text encoder, transformer, VAE and scheduler from args.model_path."""
        components = Components()
        model_path = str(self.args.model_path)

        components.pipeline_cls = CogVideoXImageToVideoPipeline

        components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")

        components.text_encoder = T5EncoderModel.from_pretrained(
            model_path, subfolder="text_encoder"
        )

        components.transformer = CogVideoXTransformer3DModel.from_pretrained(
            model_path, subfolder="transformer"
        )

        components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")

        components.scheduler = CogVideoXDPMScheduler.from_pretrained(
            model_path, subfolder="scheduler"
        )

        return components

    @override
    def initialize_pipeline(self) -> CogVideoXImageToVideoPipeline:
        """Assemble an inference pipeline from the (possibly wrapped) trained components."""
        pipe = CogVideoXImageToVideoPipeline(
            tokenizer=self.components.tokenizer,
            text_encoder=self.components.text_encoder,
            vae=self.components.vae,
            # Unwrap the accelerate-prepared transformer so the pipeline sees the raw model.
            transformer=unwrap_model(self.accelerator, self.components.transformer),
            scheduler=self.components.scheduler,
        )
        return pipe

    @override
    def encode_video(self, video: torch.Tensor) -> torch.Tensor:
        """Encode a pixel-space video into scaled VAE latents.

        Args:
            video: Tensor of shape [B, C, F, H, W] on any device/dtype.

        Returns:
            Sampled latents multiplied by the VAE's scaling factor.
        """
        # shape of input video: [B, C, F, H, W]
        vae = self.components.vae
        video = video.to(vae.device, dtype=vae.dtype)
        latent_dist = vae.encode(video).latent_dist
        latent = latent_dist.sample() * vae.config.scaling_factor
        return latent

    @override
    def encode_text(self, prompt: str) -> torch.Tensor:
        """Tokenize *prompt* (padded/truncated to max_text_seq_length) and return T5 embeddings."""
        prompt_token_ids = self.components.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.state.transformer_config.max_text_seq_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        prompt_token_ids = prompt_token_ids.input_ids
        prompt_embedding = self.components.text_encoder(
            prompt_token_ids.to(self.accelerator.device)
        )[0]
        return prompt_embedding

    @override
    def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Stack per-sample encoded videos, prompt embeddings and conditioning images into batches."""
        ret = {"encoded_videos": [], "prompt_embedding": [], "images": []}

        for sample in samples:
            encoded_video = sample["encoded_video"]
            prompt_embedding = sample["prompt_embedding"]
            image = sample["image"]

            ret["encoded_videos"].append(encoded_video)
            ret["prompt_embedding"].append(prompt_embedding)
            ret["images"].append(image)

        ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
        ret["prompt_embedding"] = torch.stack(ret["prompt_embedding"])
        ret["images"] = torch.stack(ret["images"])

        return ret

    @override
    def compute_loss(self, batch) -> torch.Tensor:
        """Compute the image-conditioned v-prediction diffusion loss for one batch."""
        prompt_embedding = batch["prompt_embedding"]
        latent = batch["encoded_videos"]
        images = batch["images"]

        # Shape of prompt_embedding: [B, seq_len, hidden_size]
        # Shape of latent: [B, C, F, H, W]
        # Shape of images: [B, C, H, W]

        patch_size_t = self.state.transformer_config.patch_size_t
        if patch_size_t is not None:
            ncopy = latent.shape[2] % patch_size_t
            # Copy the first frame ncopy times to match patch_size_t
            # NOTE(review): F % patch_size_t only pads to a multiple when
            # patch_size_t == 2; for larger patch_size_t the assert below can
            # fire (e.g. F=13, p=4 -> 14, not divisible). Confirm p is 2 here.
            first_frame = latent[:, :, :1, :, :]  # Get first frame [B, C, 1, H, W]
            latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
            assert latent.shape[2] % patch_size_t == 0

        batch_size, num_channels, num_frames, height, width = latent.shape

        # Get prompt embeddings
        _, seq_len, _ = prompt_embedding.shape
        prompt_embedding = prompt_embedding.view(batch_size, seq_len, -1).to(dtype=latent.dtype)

        # Add frame dimension to images [B,C,H,W] -> [B,C,F,H,W]
        images = images.unsqueeze(2)

        # Add noise to images (log-normal noise scale, as conditioning augmentation)
        image_noise_sigma = torch.normal(
            mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device
        )
        image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=images.dtype)
        noisy_images = (
            images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
        )
        image_latent_dist = self.components.vae.encode(
            noisy_images.to(dtype=self.components.vae.dtype)
        ).latent_dist
        image_latents = image_latent_dist.sample() * self.components.vae.config.scaling_factor

        # Sample a random timestep for each sample
        timesteps = torch.randint(
            0,
            self.components.scheduler.config.num_train_timesteps,
            (batch_size,),
            device=self.accelerator.device,
        )
        timesteps = timesteps.long()

        # from [B, C, F, H, W] to [B, F, C, H, W]
        latent = latent.permute(0, 2, 1, 3, 4)
        image_latents = image_latents.permute(0, 2, 1, 3, 4)
        assert (latent.shape[0], *latent.shape[2:]) == (
            image_latents.shape[0],
            *image_latents.shape[2:],
        )

        # Padding image_latents to the same frame number as latent
        padding_shape = (latent.shape[0], latent.shape[1] - 1, *latent.shape[2:])
        latent_padding = image_latents.new_zeros(padding_shape)
        image_latents = torch.cat([image_latents, latent_padding], dim=1)

        # Add noise to latent
        noise = torch.randn_like(latent)
        latent_noisy = self.components.scheduler.add_noise(latent, noise, timesteps)

        # Concatenate latent and image_latents in the channel dimension
        latent_img_noisy = torch.cat([latent_noisy, image_latents], dim=2)

        # Prepare rotary embeds
        vae_scale_factor_spatial = 2 ** (len(self.components.vae.config.block_out_channels) - 1)
        transformer_config = self.state.transformer_config
        rotary_emb = (
            self.prepare_rotary_positional_embeddings(
                height=height * vae_scale_factor_spatial,
                width=width * vae_scale_factor_spatial,
                num_frames=num_frames,
                transformer_config=transformer_config,
                vae_scale_factor_spatial=vae_scale_factor_spatial,
                device=self.accelerator.device,
            )
            if transformer_config.use_rotary_positional_embeddings
            else None
        )

        # Predict noise, For CogVideoX1.5 Only.
        ofs_emb = (
            None
            if self.state.transformer_config.ofs_embed_dim is None
            else latent.new_full((1,), fill_value=2.0)
        )
        predicted_noise = self.components.transformer(
            hidden_states=latent_img_noisy,
            encoder_hidden_states=prompt_embedding,
            timestep=timesteps,
            ofs=ofs_emb,
            image_rotary_emb=rotary_emb,
            return_dict=False,
        )[0]

        # Denoise: recover the model's v-prediction target in latent space.
        latent_pred = self.components.scheduler.get_velocity(
            predicted_noise, latent_noisy, timesteps
        )

        # SNR-dependent per-timestep loss weighting.
        alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps]
        weights = 1 / (1 - alphas_cumprod)
        while len(weights.shape) < len(latent_pred.shape):
            weights = weights.unsqueeze(-1)

        loss = torch.mean((weights * (latent_pred - latent) ** 2).reshape(batch_size, -1), dim=1)
        loss = loss.mean()
        return loss

    @override
    def validation_step(
        self, eval_data: Dict[str, Any], pipe: CogVideoXImageToVideoPipeline
    ) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
        """
        Return the data that needs to be saved. For videos, the data format is List[PIL],
        and for images, the data format is PIL
        """
        prompt, image, video = eval_data["prompt"], eval_data["image"], eval_data["video"]

        video_generate = pipe(
            num_frames=self.state.train_frames,
            height=self.state.train_height,
            width=self.state.train_width,
            prompt=prompt,
            image=image,
            generator=self.state.generator,
        ).frames[0]
        return [("video", video_generate)]

    def prepare_rotary_positional_embeddings(
        self,
        height: int,
        width: int,
        num_frames: int,
        transformer_config: Dict,
        vae_scale_factor_spatial: int,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build 3D RoPE cos/sin tables sized to the latent grid and frame count."""
        grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
        grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)

        if transformer_config.patch_size_t is None:
            base_num_frames = num_frames
        else:
            # Ceil-divide frames by the temporal patch size.
            base_num_frames = (
                num_frames + transformer_config.patch_size_t - 1
            ) // transformer_config.patch_size_t

        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=transformer_config.attention_head_dim,
            crops_coords=None,
            grid_size=(grid_height, grid_width),
            temporal_size=base_num_frames,
            grid_type="slice",
            max_size=(grid_height, grid_width),
            device=device,
        )

        return freqs_cos, freqs_sin


# Makes this trainer discoverable via get_model_cls("cogvideox-i2v", "lora").
register("cogvideox-i2v", "lora", CogVideoXI2VLoraTrainer)

View File

@ -0,0 +1,9 @@
from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
from ..utils import register
class CogVideoXI2VSftTrainer(CogVideoXI2VLoraTrainer):
    """SFT trainer for CogVideoX I2V; reuses the LoRA trainer's training logic unchanged."""

    pass


# Makes this trainer discoverable via get_model_cls("cogvideox-i2v", "sft").
register("cogvideox-i2v", "sft", CogVideoXI2VSftTrainer)

View File

@ -0,0 +1,228 @@
from typing import Any, Dict, List, Tuple
import torch
from diffusers import (
AutoencoderKLCogVideoX,
CogVideoXDPMScheduler,
CogVideoXPipeline,
CogVideoXTransformer3DModel,
)
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from typing_extensions import override
from finetune.schemas import Components
from finetune.trainer import Trainer
from finetune.utils import unwrap_model
from ..utils import register
class CogVideoXT2VLoraTrainer(Trainer):
    """LoRA trainer for CogVideoX text-to-video (T2V) models.

    Loads the diffusers CogVideoX T2V components, encodes videos/prompts, and
    computes the v-prediction training loss (no image conditioning).
    """

    # Components to move off the accelerator once their outputs are cached.
    UNLOAD_LIST = ["text_encoder", "vae"]

    @override
    def load_components(self) -> Components:
        """Load tokenizer, text encoder, transformer, VAE and scheduler from args.model_path."""
        components = Components()
        model_path = str(self.args.model_path)

        components.pipeline_cls = CogVideoXPipeline

        components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")

        components.text_encoder = T5EncoderModel.from_pretrained(
            model_path, subfolder="text_encoder"
        )

        components.transformer = CogVideoXTransformer3DModel.from_pretrained(
            model_path, subfolder="transformer"
        )

        components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")

        components.scheduler = CogVideoXDPMScheduler.from_pretrained(
            model_path, subfolder="scheduler"
        )

        return components

    @override
    def initialize_pipeline(self) -> CogVideoXPipeline:
        """Assemble an inference pipeline from the (possibly wrapped) trained components."""
        pipe = CogVideoXPipeline(
            tokenizer=self.components.tokenizer,
            text_encoder=self.components.text_encoder,
            vae=self.components.vae,
            # Unwrap the accelerate-prepared transformer so the pipeline sees the raw model.
            transformer=unwrap_model(self.accelerator, self.components.transformer),
            scheduler=self.components.scheduler,
        )
        return pipe

    @override
    def encode_video(self, video: torch.Tensor) -> torch.Tensor:
        """Encode a pixel-space video of shape [B, C, F, H, W] into scaled VAE latents."""
        # shape of input video: [B, C, F, H, W]
        vae = self.components.vae
        video = video.to(vae.device, dtype=vae.dtype)
        latent_dist = vae.encode(video).latent_dist
        latent = latent_dist.sample() * vae.config.scaling_factor
        return latent

    @override
    def encode_text(self, prompt: str) -> torch.Tensor:
        """Tokenize *prompt* (padded/truncated to max_text_seq_length) and return T5 embeddings."""
        prompt_token_ids = self.components.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.state.transformer_config.max_text_seq_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        prompt_token_ids = prompt_token_ids.input_ids
        prompt_embedding = self.components.text_encoder(
            prompt_token_ids.to(self.accelerator.device)
        )[0]
        return prompt_embedding

    @override
    def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Stack per-sample encoded videos and prompt embeddings into batches."""
        ret = {"encoded_videos": [], "prompt_embedding": []}

        for sample in samples:
            encoded_video = sample["encoded_video"]
            prompt_embedding = sample["prompt_embedding"]

            ret["encoded_videos"].append(encoded_video)
            ret["prompt_embedding"].append(prompt_embedding)

        ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
        ret["prompt_embedding"] = torch.stack(ret["prompt_embedding"])

        return ret

    @override
    def compute_loss(self, batch) -> torch.Tensor:
        """Compute the v-prediction diffusion loss for one batch."""
        prompt_embedding = batch["prompt_embedding"]
        latent = batch["encoded_videos"]

        # Shape of prompt_embedding: [B, seq_len, hidden_size]
        # Shape of latent: [B, C, F, H, W]

        patch_size_t = self.state.transformer_config.patch_size_t
        if patch_size_t is not None:
            ncopy = latent.shape[2] % patch_size_t
            # Copy the first frame ncopy times to match patch_size_t
            # NOTE(review): F % patch_size_t only pads to a multiple when
            # patch_size_t == 2; for larger patch_size_t the assert below can
            # fire (e.g. F=13, p=4 -> 14, not divisible). Confirm p is 2 here.
            first_frame = latent[:, :, :1, :, :]  # Get first frame [B, C, 1, H, W]
            latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
            assert latent.shape[2] % patch_size_t == 0

        batch_size, num_channels, num_frames, height, width = latent.shape

        # Get prompt embeddings
        _, seq_len, _ = prompt_embedding.shape
        prompt_embedding = prompt_embedding.view(batch_size, seq_len, -1).to(dtype=latent.dtype)

        # Sample a random timestep for each sample
        timesteps = torch.randint(
            0,
            self.components.scheduler.config.num_train_timesteps,
            (batch_size,),
            device=self.accelerator.device,
        )
        timesteps = timesteps.long()

        # Add noise to latent
        latent = latent.permute(0, 2, 1, 3, 4)  # from [B, C, F, H, W] to [B, F, C, H, W]
        noise = torch.randn_like(latent)
        latent_added_noise = self.components.scheduler.add_noise(latent, noise, timesteps)

        # Prepare rotary embeds
        vae_scale_factor_spatial = 2 ** (len(self.components.vae.config.block_out_channels) - 1)
        transformer_config = self.state.transformer_config
        rotary_emb = (
            self.prepare_rotary_positional_embeddings(
                height=height * vae_scale_factor_spatial,
                width=width * vae_scale_factor_spatial,
                num_frames=num_frames,
                transformer_config=transformer_config,
                vae_scale_factor_spatial=vae_scale_factor_spatial,
                device=self.accelerator.device,
            )
            if transformer_config.use_rotary_positional_embeddings
            else None
        )

        # Predict noise
        predicted_noise = self.components.transformer(
            hidden_states=latent_added_noise,
            encoder_hidden_states=prompt_embedding,
            timestep=timesteps,
            image_rotary_emb=rotary_emb,
            return_dict=False,
        )[0]

        # Denoise: recover the model's v-prediction target in latent space.
        latent_pred = self.components.scheduler.get_velocity(
            predicted_noise, latent_added_noise, timesteps
        )

        # SNR-dependent per-timestep loss weighting.
        alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps]
        weights = 1 / (1 - alphas_cumprod)
        while len(weights.shape) < len(latent_pred.shape):
            weights = weights.unsqueeze(-1)

        loss = torch.mean((weights * (latent_pred - latent) ** 2).reshape(batch_size, -1), dim=1)
        loss = loss.mean()
        return loss

    @override
    def validation_step(
        self, eval_data: Dict[str, Any], pipe: CogVideoXPipeline
    ) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
        """
        Return the data that needs to be saved. For videos, the data format is List[PIL],
        and for images, the data format is PIL
        """
        # `image` and `video` are unpacked for interface parity with the I2V
        # trainer but only `prompt` is used by the T2V pipeline call.
        prompt, image, video = eval_data["prompt"], eval_data["image"], eval_data["video"]

        video_generate = pipe(
            num_frames=self.state.train_frames,
            height=self.state.train_height,
            width=self.state.train_width,
            prompt=prompt,
            generator=self.state.generator,
        ).frames[0]
        return [("video", video_generate)]

    def prepare_rotary_positional_embeddings(
        self,
        height: int,
        width: int,
        num_frames: int,
        transformer_config: Dict,
        vae_scale_factor_spatial: int,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build 3D RoPE cos/sin tables sized to the latent grid and frame count."""
        grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
        grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)

        if transformer_config.patch_size_t is None:
            base_num_frames = num_frames
        else:
            # Ceil-divide frames by the temporal patch size.
            base_num_frames = (
                num_frames + transformer_config.patch_size_t - 1
            ) // transformer_config.patch_size_t

        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=transformer_config.attention_head_dim,
            crops_coords=None,
            grid_size=(grid_height, grid_width),
            temporal_size=base_num_frames,
            grid_type="slice",
            max_size=(grid_height, grid_width),
            device=device,
        )

        return freqs_cos, freqs_sin


# Makes this trainer discoverable via get_model_cls("cogvideox-t2v", "lora").
register("cogvideox-t2v", "lora", CogVideoXT2VLoraTrainer)

View File

@ -0,0 +1,9 @@
from ..cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer
from ..utils import register
class CogVideoXT2VSftTrainer(CogVideoXT2VLoraTrainer):
    """SFT trainer for CogVideoX T2V; reuses the LoRA trainer's training logic unchanged."""

    pass


# Makes this trainer discoverable via get_model_cls("cogvideox-t2v", "sft").
register("cogvideox-t2v", "sft", CogVideoXT2VSftTrainer)

59
finetune/models/utils.py Normal file
View File

@ -0,0 +1,59 @@
from typing import Dict, Literal
from finetune.trainer import Trainer
# Global registry: model name -> training type ("lora"/"sft") -> Trainer class.
# Populated by register(...) calls executed as import side effects.
SUPPORTED_MODELS: Dict[str, Dict[str, Trainer]] = {}
def register(model_name: str, training_type: Literal["lora", "sft"], trainer_cls: Trainer):
    """Register a model and its associated functions for a specific training type.

    Args:
        model_name (str): Name of the model to register (e.g. "cogvideox-5b")
        training_type (Literal["lora", "sft"]): Type of training - either "lora" or "sft"
        trainer_cls (Trainer): Trainer class to register.

    Raises:
        ValueError: If this (model_name, training_type) pair is already registered.
    """
    # Create the per-model sub-registry on first use, then reject duplicates.
    registered_types = SUPPORTED_MODELS.setdefault(model_name, {})
    if training_type in registered_types:
        raise ValueError(f"Training type {training_type} already exists for model {model_name}")
    registered_types[training_type] = trainer_cls
def show_supported_models():
    """Print every registered model together with its available training types."""
    print("\nSupported Models:")
    print("================")

    for name, trainers_by_type in SUPPORTED_MODELS.items():
        underline = "-" * len(name)
        print(f"\n{name}")
        print(underline)
        for kind in trainers_by_type:
            print(f"{kind}")
def get_model_cls(model_type: str, training_type: Literal["lora", "sft"]) -> Trainer:
    """Look up the trainer class registered for *model_type* / *training_type*.

    Prints the available options before raising when the lookup fails.

    Raises:
        ValueError: If the model or the training type has not been registered.
    """
    trainers_by_type = SUPPORTED_MODELS.get(model_type)

    if trainers_by_type is None:
        print(f"\nModel '{model_type}' is not supported.")
        print("\nSupported models are:")
        for supported_model in SUPPORTED_MODELS:
            print(f"{supported_model}")
        raise ValueError(f"Model '{model_type}' is not supported")

    if training_type not in trainers_by_type:
        print(f"\nTraining type '{training_type}' is not supported for model '{model_type}'.")
        print(f"\nSupported training types for '{model_type}' are:")
        for supported_type in trainers_by_type:
            print(f"{supported_type}")
        raise ValueError(
            f"Training type '{training_type}' is not supported for model '{model_type}'"
        )

    return trainers_by_type[training_type]

View File

@ -0,0 +1,6 @@
from .args import Args
from .components import Components
from .state import State
__all__ = ["Args", "State", "Components"]

254
finetune/schemas/args.py Normal file
View File

@ -0,0 +1,254 @@
import argparse
import datetime
import logging
from pathlib import Path
from typing import Any, List, Literal, Tuple
from pydantic import BaseModel, ValidationInfo, field_validator
class Args(BaseModel):
########## Model ##########
model_path: Path
model_name: str
model_type: Literal["i2v", "t2v"]
training_type: Literal["lora", "sft"] = "lora"
########## Output ##########
output_dir: Path = Path("train_results/{:%Y-%m-%d-%H-%M-%S}".format(datetime.datetime.now()))
report_to: Literal["tensorboard", "wandb", "all"] | None = None
tracker_name: str = "finetrainer-cogvideo"
########## Data ###########
data_root: Path
caption_column: Path
image_column: Path | None = None
video_column: Path
########## Training #########
resume_from_checkpoint: Path | None = None
seed: int | None = None
train_epochs: int
train_steps: int | None = None
checkpointing_steps: int = 200
checkpointing_limit: int = 10
batch_size: int
gradient_accumulation_steps: int = 1
train_resolution: Tuple[int, int, int] # shape: (frames, height, width)
#### deprecated args: video_resolution_buckets
# if use bucket for training, should not be None
# Note1: At least one frame rate in the bucket must be less than or equal to the frame rate of any video in the dataset
# Note2: For cogvideox, cogvideox1.5
# The frame rate set in the bucket must be an integer multiple of 8 (spatial_compression_rate[4] * path_t[2] = 8)
# The height and width set in the bucket must be an integer multiple of 8 (temporal_compression_rate[8])
# video_resolution_buckets: List[Tuple[int, int, int]] | None = None
mixed_precision: Literal["no", "fp16", "bf16"]
learning_rate: float = 2e-5
optimizer: str = "adamw"
beta1: float = 0.9
beta2: float = 0.95
beta3: float = 0.98
epsilon: float = 1e-8
weight_decay: float = 1e-4
max_grad_norm: float = 1.0
lr_scheduler: str = "constant_with_warmup"
lr_warmup_steps: int = 100
lr_num_cycles: int = 1
lr_power: float = 1.0
num_workers: int = 8
pin_memory: bool = True
gradient_checkpointing: bool = True
enable_slicing: bool = True
enable_tiling: bool = True
nccl_timeout: int = 1800
########## Lora ##########
rank: int = 128
lora_alpha: int = 64
target_modules: List[str] = ["to_q", "to_k", "to_v", "to_out.0"]
########## Validation ##########
do_validation: bool = False
validation_steps: int | None # if set, should be a multiple of checkpointing_steps
validation_dir: Path | None # if set do_validation, should not be None
validation_prompts: str | None # if set do_validation, should not be None
validation_images: str | None # if set do_validation and model_type == i2v, should not be None
validation_videos: str | None # if set do_validation and model_type == v2v, should not be None
gen_fps: int = 15
#### deprecated args: gen_video_resolution
# 1. If set do_validation, should not be None
# 2. Suggest selecting the bucket from `video_resolution_buckets` that is closest to the resolution you have chosen for fine-tuning
# or the resolution recommended by the model
# 3. Note: For cogvideox, cogvideox1.5
# The frame rate set in the bucket must be an integer multiple of 8 (spatial_compression_rate[4] * path_t[2] = 8)
# The height and width set in the bucket must be an integer multiple of 8 (temporal_compression_rate[8])
# gen_video_resolution: Tuple[int, int, int] | None # shape: (frames, height, width)
@field_validator("image_column")
def validate_image_column(cls, v: str | None, info: ValidationInfo) -> str | None:
    """Warn (but do not fail) when an i2v model has no image column configured."""
    data = info.data
    is_i2v_without_images = data.get("model_type") == "i2v" and not v
    if is_i2v_without_images:
        logging.warning(
            "No `image_column` specified for i2v model. Will automatically extract first frames from videos as conditioning images."
        )
    return v
@field_validator("validation_dir", "validation_prompts")
def validate_validation_required_fields(cls, v: Any, info: ValidationInfo) -> Any:
    """Require this field to be set whenever validation is enabled."""
    if not info.data.get("do_validation"):
        return v
    if v:
        return v
    field_name = info.field_name
    raise ValueError(f"{field_name} must be specified when do_validation is True")
@field_validator("validation_images")
def validate_validation_images(cls, v: str | None, info: ValidationInfo) -> str | None:
    """Require validation images for i2v models when validation is enabled."""
    data = info.data
    needs_images = data.get("do_validation") and data.get("model_type") == "i2v"
    if needs_images and not v:
        raise ValueError(
            "validation_images must be specified when do_validation is True and model_type is i2v"
        )
    return v
@field_validator("validation_videos")
def validate_validation_videos(cls, v: str | None, info: ValidationInfo) -> str | None:
    """Require validation videos for v2v models when validation is enabled."""
    data = info.data
    needs_videos = data.get("do_validation") and data.get("model_type") == "v2v"
    if needs_videos and not v:
        raise ValueError(
            "validation_videos must be specified when do_validation is True and model_type is v2v"
        )
    return v
@field_validator("validation_steps")
def validate_validation_steps(cls, v: int | None, info: ValidationInfo) -> int | None:
    """Check validation_steps is set and divides evenly by checkpointing_steps."""
    data = info.data
    if not data.get("do_validation"):
        return v
    if v is None:
        raise ValueError("validation_steps must be specified when do_validation is True")
    if data.get("checkpointing_steps") and v % data["checkpointing_steps"] != 0:
        raise ValueError("validation_steps must be a multiple of checkpointing_steps")
    return v
@field_validator("train_resolution")
def validate_train_resolution(
    cls, v: Tuple[int, int, int], info: ValidationInfo
) -> Tuple[int, int, int]:
    """Validate the (frames, height, width) training resolution.

    Raises:
        ValueError: if the value is not a 3-tuple, if ``frames - 1`` is not a
            multiple of 8, or if a cogvideox-5b model is given a resolution
            other than 480x720.
    """
    # Original code compared str(e) against exact exception messages
    # ("not enough values to unpack ...", "invalid literal for int() ..."),
    # which never match the real messages (they embed the offending value),
    # so malformed input leaked raw unpack errors. Catch the unpack failure
    # directly instead. Return annotation also fixed: this returns the tuple,
    # not a str.
    try:
        frames, height, width = v
    except (ValueError, TypeError):
        raise ValueError("train_resolution must be in format 'frames x height x width'")

    # CogVideoX temporal compression requires frames = 8N + 1.
    if (frames - 1) % 8 != 0:
        raise ValueError("Number of frames - 1 must be a multiple of 8")

    # cogvideox-5b checkpoints only support a fixed 480x720 resolution.
    model_name = info.data.get("model_name", "")
    if model_name in ["cogvideox-5b-i2v", "cogvideox-5b-t2v"]:
        if (height, width) != (480, 720):
            raise ValueError(
                "For cogvideox-5b models, height must be 480 and width must be 720"
            )
    return v
@field_validator("mixed_precision")
def validate_mixed_precision(cls, v: str, info: ValidationInfo) -> str:
    """Warn when fp16 is requested for a model family trained in bfloat16."""
    model_path = str(info.data.get("model_path", "")).lower()
    if v == "fp16" and "cogvideox-2b" not in model_path:
        logging.warning(
            "All CogVideoX models except cogvideox-2b were trained with bfloat16. "
            "Using fp16 precision may lead to training instability."
        )
    return v
@classmethod
def parse_args(cls):
    """Parse command line arguments and return Args instance.

    Boolean options are parsed with :func:`str2bool` because argparse's
    ``type=bool`` converts any non-empty string — including the literal
    string "False" — to True, making the flags impossible to disable.
    """

    def str2bool(value: str) -> bool:
        # Same convention the repo's shell scripts use: "true"/"false".
        return value.lower() == "true"

    parser = argparse.ArgumentParser()

    # Required arguments
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--model_type", type=str, required=True)
    parser.add_argument("--training_type", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--data_root", type=str, required=True)
    parser.add_argument("--caption_column", type=str, required=True)
    parser.add_argument("--video_column", type=str, required=True)
    parser.add_argument("--train_resolution", type=str, required=True)
    parser.add_argument("--report_to", type=str, required=True)

    # Training hyperparameters
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--train_epochs", type=int, default=10)
    parser.add_argument("--train_steps", type=int, default=None)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--optimizer", type=str, default="adamw")
    parser.add_argument("--beta1", type=float, default=0.9)
    parser.add_argument("--beta2", type=float, default=0.95)
    parser.add_argument("--beta3", type=float, default=0.98)
    parser.add_argument("--epsilon", type=float, default=1e-8)
    parser.add_argument("--weight_decay", type=float, default=1e-4)
    parser.add_argument("--max_grad_norm", type=float, default=1.0)

    # Learning rate scheduler
    parser.add_argument("--lr_scheduler", type=str, default="constant_with_warmup")
    parser.add_argument("--lr_warmup_steps", type=int, default=100)
    parser.add_argument("--lr_num_cycles", type=int, default=1)
    parser.add_argument("--lr_power", type=float, default=1.0)

    # Data loading
    parser.add_argument("--num_workers", type=int, default=8)
    # BUGFIX: was type=bool, which made `--pin_memory False` evaluate True.
    parser.add_argument("--pin_memory", type=str2bool, default=True)
    parser.add_argument("--image_column", type=str, default=None)

    # Model configuration
    parser.add_argument("--mixed_precision", type=str, default="no")
    # BUGFIX: the three flags below also used type=bool (see str2bool note).
    parser.add_argument("--gradient_checkpointing", type=str2bool, default=True)
    parser.add_argument("--enable_slicing", type=str2bool, default=True)
    parser.add_argument("--enable_tiling", type=str2bool, default=True)
    parser.add_argument("--nccl_timeout", type=int, default=1800)

    # LoRA parameters
    parser.add_argument("--rank", type=int, default=128)
    parser.add_argument("--lora_alpha", type=int, default=64)
    parser.add_argument(
        "--target_modules", type=str, nargs="+", default=["to_q", "to_k", "to_v", "to_out.0"]
    )

    # Checkpointing
    parser.add_argument("--checkpointing_steps", type=int, default=200)
    parser.add_argument("--checkpointing_limit", type=int, default=10)
    parser.add_argument("--resume_from_checkpoint", type=str, default=None)

    # Validation
    parser.add_argument("--do_validation", type=str2bool, default=False)
    parser.add_argument("--validation_steps", type=int, default=None)
    parser.add_argument("--validation_dir", type=str, default=None)
    parser.add_argument("--validation_prompts", type=str, default=None)
    parser.add_argument("--validation_images", type=str, default=None)
    parser.add_argument("--validation_videos", type=str, default=None)
    parser.add_argument("--gen_fps", type=int, default=15)

    args = parser.parse_args()

    # Convert the "FxHxW" resolution string into a (frames, height, width) tuple.
    frames, height, width = args.train_resolution.split("x")
    args.train_resolution = (int(frames), int(height), int(width))

    return cls(**vars(args))

View File

@ -0,0 +1,28 @@
from typing import Any
from pydantic import BaseModel
class Components(BaseModel):
    """Container for the model components a trainer pipeline may hold.

    Every field defaults to None; each concrete trainer populates only the
    components its model family actually uses.
    """

    # pipeline cls
    pipeline_cls: Any = None

    # Tokenizers
    tokenizer: Any = None
    tokenizer_2: Any = None
    tokenizer_3: Any = None

    # Text encoders
    text_encoder: Any = None
    text_encoder_2: Any = None
    text_encoder_3: Any = None

    # Autoencoder
    vae: Any = None

    # Denoiser (either a DiT transformer or a UNet, depending on the model)
    transformer: Any = None
    unet: Any = None

    # Scheduler
    scheduler: Any = None

29
finetune/schemas/state.py Normal file
View File

@ -0,0 +1,29 @@
from pathlib import Path
from typing import Any, Dict, List
import torch
from pydantic import BaseModel
class State(BaseModel):
    """Mutable runtime state shared across a single training run."""

    # Required so pydantic accepts torch.dtype / torch.Generator fields.
    model_config = {"arbitrary_types_allowed": True}

    # Target training resolution (frames, height, width)
    train_frames: int
    train_height: int
    train_width: int

    # Filled in from the transformer's config once models are loaded
    # (annotation widened to include None, matching the None default).
    transformer_config: Dict[str, Any] | None = None

    weight_dtype: torch.dtype = torch.float32  # dtype for mixed precision training
    num_trainable_parameters: int = 0
    # True when train_steps was derived from train_epochs rather than given
    overwrote_max_train_steps: bool = False
    num_update_steps_per_epoch: int = 0
    total_batch_size_count: int = 0

    generator: torch.Generator | None = None

    # Per-sample validation inputs, aligned by index
    validation_prompts: List[str] = []
    validation_images: List[Path | None] = []
    validation_videos: List[Path | None] = []

    using_deepspeed: bool = False

View File

@ -0,0 +1,60 @@
"""Extract the first frame of every video listed in ``videos.txt``.

Each frame is written as ``images/<video_stem>.png`` under the data root and
the relative image paths are recorded in ``images.txt`` (one per line).
"""

import argparse
import os
from pathlib import Path

import cv2


def parse_args():
    """Parse CLI arguments (only the dataset root directory)."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--datadir",
        type=str,
        required=True,
        help="Root directory containing videos.txt and video subdirectory",
    )
    return parser.parse_args()


args = parse_args()

# Create data/images directory if it doesn't exist
data_dir = Path(args.datadir)
image_dir = data_dir / "images"
image_dir.mkdir(exist_ok=True)

# Read videos.txt (one relative video path per non-empty line)
videos_file = data_dir / "videos.txt"
with open(videos_file, "r") as f:
    video_paths = [line.strip() for line in f.readlines() if line.strip()]

# Process each video file and collect image paths
image_paths = []
for video_rel_path in video_paths:
    video_path = data_dir / video_rel_path

    cap = cv2.VideoCapture(str(video_path))
    try:
        # Read first frame
        ret, frame = cap.read()
        if not ret:
            print(f"Failed to read video: {video_path}")
            continue

        # Save frame as PNG with same name as video
        image_name = f"images/{video_path.stem}.png"
        image_path = data_dir / image_name
        cv2.imwrite(str(image_path), frame)
    finally:
        # BUGFIX: always release the capture — the original skipped
        # cap.release() on the failed-read path, leaking the handle.
        cap.release()

    print(f"Extracted first frame from {video_path} to {image_path}")
    image_paths.append(image_name)

# Write images.txt
images_file = data_dir / "images.txt"
with open(images_file, "w") as f:
    for path in image_paths:
        f.write(f"{path}\n")

19
finetune/train.py Normal file
View File

@ -0,0 +1,19 @@
import sys
from pathlib import Path

# Make the repository root importable when this file is run directly.
sys.path.append(str(Path(__file__).parent.parent))

from finetune.models.utils import get_model_cls
from finetune.schemas import Args


def main():
    """Parse CLI arguments, resolve the matching trainer class, and run it."""
    cli_args = Args.parse_args()
    trainer_class = get_model_cls(cli_args.model_name, cli_args.training_type)
    trainer_class(cli_args).fit()


if __name__ == "__main__":
    main()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

70
finetune/train_ddp_i2v.sh Normal file
View File

@ -0,0 +1,70 @@
#!/usr/bin/env bash
# Launch LoRA fine-tuning of CogVideoX1.5-5B-I2V with DDP via `accelerate`.

# Prevent tokenizer parallelism issues
export TOKENIZERS_PARALLELISM=false

# Model Configuration
MODEL_ARGS=(
    --model_path "THUDM/CogVideoX1.5-5B-I2V"
    --model_name "cogvideox1.5-i2v"  # other option: "cogvideox-i2v"
    --model_type "i2v"
    --training_type "lora"
)

# Output Configuration
OUTPUT_ARGS=(
    --output_dir "/path/to/your/output_dir"
    --report_to "tensorboard"
)

# Data Configuration
DATA_ARGS=(
    --data_root "/absolute/path/to/your/data_root"
    --caption_column "prompt.txt"
    --video_column "videos.txt"
    # --image_column "images.txt"  # comment this line will use first frame of video as image conditioning
    --train_resolution "81x768x1360"  # (frames x height x width), frames should be 8N+1
)

# Training Configuration
TRAIN_ARGS=(
    --train_epochs 10  # number of training epochs
    --seed 42  # random seed
    --batch_size 1
    --gradient_accumulation_steps 1
    --mixed_precision "bf16"  # options: ["no", "fp16", "bf16"]; only CogVideoX-2B supports fp16 training
)

# System Configuration
SYSTEM_ARGS=(
    --num_workers 8
    --pin_memory True
    --nccl_timeout 1800
)

# Checkpointing Configuration
CHECKPOINT_ARGS=(
    --checkpointing_steps 10  # save checkpoint every x steps
    --checkpointing_limit 2  # maximum number of checkpoints to keep, after which the oldest one is deleted
    --resume_from_checkpoint "/absolute/path/to/checkpoint_dir"  # if you want to resume from a checkpoint, otherwise, comment this line
)

# Validation Configuration
VALIDATION_ARGS=(
    --do_validation false  # ["true", "false"]
    --validation_dir "/absolute/path/to/your/validation_set"
    --validation_steps 20  # should be multiple of checkpointing_steps
    --validation_prompts "prompts.txt"
    --validation_images "images.txt"
    --gen_fps 16
)

# Combine all arguments and launch training
accelerate launch train.py \
    "${MODEL_ARGS[@]}" \
    "${OUTPUT_ARGS[@]}" \
    "${DATA_ARGS[@]}" \
    "${TRAIN_ARGS[@]}" \
    "${SYSTEM_ARGS[@]}" \
    "${CHECKPOINT_ARGS[@]}" \
    "${VALIDATION_ARGS[@]}"

68
finetune/train_ddp_t2v.sh Normal file
View File

@ -0,0 +1,68 @@
#!/usr/bin/env bash
# Launch LoRA fine-tuning of CogVideoX1.5-5B (text-to-video) with DDP via `accelerate`.

# Prevent tokenizer parallelism issues
export TOKENIZERS_PARALLELISM=false

# Model Configuration
MODEL_ARGS=(
    --model_path "THUDM/CogVideoX1.5-5B"
    --model_name "cogvideox1.5-t2v"  # other option: "cogvideox-t2v"
    --model_type "t2v"
    --training_type "lora"
)

# Output Configuration
OUTPUT_ARGS=(
    --output_dir "/absolute/path/to/your/output_dir"
    --report_to "tensorboard"
)

# Data Configuration
DATA_ARGS=(
    --data_root "/absolute/path/to/your/data_root"
    --caption_column "prompt.txt"
    --video_column "videos.txt"
    --train_resolution "81x768x1360"  # (frames x height x width), frames should be 8N+1
)

# Training Configuration
TRAIN_ARGS=(
    --train_epochs 10  # number of training epochs
    --seed 42  # random seed
    --batch_size 1
    --gradient_accumulation_steps 1
    --mixed_precision "bf16"  # options: ["no", "fp16", "bf16"]; only CogVideoX-2B supports fp16 training
)

# System Configuration
SYSTEM_ARGS=(
    --num_workers 8
    --pin_memory True
    --nccl_timeout 1800
)

# Checkpointing Configuration
CHECKPOINT_ARGS=(
    --checkpointing_steps 10  # save checkpoint every x steps
    --checkpointing_limit 2  # maximum number of checkpoints to keep, after which the oldest one is deleted
    --resume_from_checkpoint "/absolute/path/to/checkpoint_dir"  # if you want to resume from a checkpoint, otherwise, comment this line
)

# Validation Configuration
VALIDATION_ARGS=(
    --do_validation false  # ["true", "false"]
    --validation_dir "/absolute/path/to/your/validation_set"
    --validation_steps 20  # should be multiple of checkpointing_steps
    --validation_prompts "prompts.txt"
    --gen_fps 16
)

# Combine all arguments and launch training
accelerate launch train.py \
    "${MODEL_ARGS[@]}" \
    "${OUTPUT_ARGS[@]}" \
    "${DATA_ARGS[@]}" \
    "${TRAIN_ARGS[@]}" \
    "${SYSTEM_ARGS[@]}" \
    "${CHECKPOINT_ARGS[@]}" \
    "${VALIDATION_ARGS[@]}"

View File

@ -0,0 +1,73 @@
#!/usr/bin/env bash
# Launch SFT fine-tuning of CogVideoX1.5-5B-I2V using the DeepSpeed/ZeRO
# config in accelerate_config.yaml.

# Prevent tokenizer parallelism issues
export TOKENIZERS_PARALLELISM=false

# Model Configuration
MODEL_ARGS=(
    --model_path "THUDM/CogVideoX1.5-5B-I2V"
    --model_name "cogvideox1.5-i2v"  # other option: "cogvideox-i2v"
    --model_type "i2v"
    --training_type "sft"
)

# Output Configuration
OUTPUT_ARGS=(
    --output_dir "/absolute/path/to/your/output_dir"
    --report_to "tensorboard"
)

# Data Configuration
DATA_ARGS=(
    --data_root "/absolute/path/to/your/data_root"
    --caption_column "prompt.txt"
    --video_column "videos.txt"
    # --image_column "images.txt"  # comment this line will use first frame of video as image conditioning
    --train_resolution "81x768x1360"  # (frames x height x width), frames should be 8N+1 and height, width should be multiples of 16
)

# Training Configuration
TRAIN_ARGS=(
    --train_epochs 10  # number of training epochs
    --seed 42  # random seed

    ######### Please keep consistent with deepspeed config file ##########
    --batch_size 1
    --gradient_accumulation_steps 1
    --mixed_precision "bf16"  # options: ["no", "fp16", "bf16"]; only CogVideoX-2B supports fp16 training
    ########################################################################
)

# System Configuration
SYSTEM_ARGS=(
    --num_workers 8
    --pin_memory True
    --nccl_timeout 1800
)

# Checkpointing Configuration
CHECKPOINT_ARGS=(
    --checkpointing_steps 10  # save checkpoint every x steps
    --checkpointing_limit 2  # maximum number of checkpoints to keep, after which the oldest one is deleted
    # --resume_from_checkpoint "/absolute/path/to/checkpoint_dir"  # if you want to resume from a checkpoint, otherwise, comment this line
)

# Validation Configuration
VALIDATION_ARGS=(
    --do_validation false  # ["true", "false"]
    --validation_dir "/absolute/path/to/validation_set"
    --validation_steps 20  # should be multiple of checkpointing_steps
    --validation_prompts "prompts.txt"
    --validation_images "images.txt"
    --gen_fps 16
)

# Combine all arguments and launch training
accelerate launch --config_file accelerate_config.yaml train.py \
    "${MODEL_ARGS[@]}" \
    "${OUTPUT_ARGS[@]}" \
    "${DATA_ARGS[@]}" \
    "${TRAIN_ARGS[@]}" \
    "${SYSTEM_ARGS[@]}" \
    "${CHECKPOINT_ARGS[@]}" \
    "${VALIDATION_ARGS[@]}"

View File

@ -0,0 +1,71 @@
#!/usr/bin/env bash
# Launch SFT fine-tuning of CogVideoX1.5-5B (text-to-video) using the
# DeepSpeed/ZeRO config in accelerate_config.yaml.

# Prevent tokenizer parallelism issues
export TOKENIZERS_PARALLELISM=false

# Model Configuration
MODEL_ARGS=(
    --model_path "THUDM/CogVideoX1.5-5B"
    --model_name "cogvideox1.5-t2v"  # other option: "cogvideox-t2v"
    --model_type "t2v"
    --training_type "sft"
)

# Output Configuration
OUTPUT_ARGS=(
    --output_dir "/absolute/path/to/your/output_dir"
    --report_to "tensorboard"
)

# Data Configuration
DATA_ARGS=(
    --data_root "/absolute/path/to/your/data_root"
    --caption_column "prompt.txt"
    --video_column "videos.txt"
    --train_resolution "81x768x1360"  # (frames x height x width), frames should be 8N+1 and height, width should be multiples of 16
)

# Training Configuration
TRAIN_ARGS=(
    --train_epochs 10  # number of training epochs
    --seed 42  # random seed

    ######### Please keep consistent with deepspeed config file ##########
    --batch_size 1
    --gradient_accumulation_steps 1
    --mixed_precision "bf16"  # options: ["no", "fp16", "bf16"]; only CogVideoX-2B supports fp16 training
    ########################################################################
)

# System Configuration
SYSTEM_ARGS=(
    --num_workers 8
    --pin_memory True
    --nccl_timeout 1800
)

# Checkpointing Configuration
CHECKPOINT_ARGS=(
    --checkpointing_steps 10  # save checkpoint every x steps
    --checkpointing_limit 2  # maximum number of checkpoints to keep, after which the oldest one is deleted
    # --resume_from_checkpoint "/absolute/path/to/checkpoint_dir"  # if you want to resume from a checkpoint, otherwise, comment this line
)

# Validation Configuration
VALIDATION_ARGS=(
    --do_validation false  # ["true", "false"]
    --validation_dir "/absolute/path/to/validation_set"
    --validation_steps 20  # should be multiple of checkpointing_steps
    --validation_prompts "prompts.txt"
    --gen_fps 16
)

# Combine all arguments and launch training
accelerate launch --config_file accelerate_config.yaml train.py \
    "${MODEL_ARGS[@]}" \
    "${OUTPUT_ARGS[@]}" \
    "${DATA_ARGS[@]}" \
    "${TRAIN_ARGS[@]}" \
    "${SYSTEM_ARGS[@]}" \
    "${CHECKPOINT_ARGS[@]}" \
    "${VALIDATION_ARGS[@]}"

811
finetune/trainer.py Normal file
View File

@ -0,0 +1,811 @@
import hashlib
import json
import logging
import math
from datetime import timedelta
from pathlib import Path
from typing import Any, Dict, List, Tuple
import diffusers
import torch
import transformers
import wandb
from accelerate.accelerator import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import (
DistributedDataParallelKwargs,
InitProcessGroupKwargs,
ProjectConfiguration,
gather_object,
set_seed,
)
from diffusers.optimization import get_scheduler
from diffusers.pipelines import DiffusionPipeline
from diffusers.utils.export_utils import export_to_video
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from finetune.constants import LOG_LEVEL, LOG_NAME
from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize
from finetune.datasets.utils import (
load_images,
load_prompts,
load_videos,
preprocess_image_with_resize,
preprocess_video_with_resize,
)
from finetune.schemas import Args, Components, State
from finetune.utils import (
cast_training_params,
free_memory,
get_intermediate_ckpt_path,
get_latest_ckpt_path_to_resume_from,
get_memory_statistics,
get_optimizer,
string_to_filename,
unload_model,
unwrap_model,
)
# Module-level logger shared by all trainers in this file.
logger = get_logger(LOG_NAME, LOG_LEVEL)

# Maps the CLI mixed-precision names to torch dtypes.
_DTYPE_MAP = {
    "fp32": torch.float32,
    "fp16": torch.float16,  # fp16 is only supported for CogVideoX-2B
    "bf16": torch.bfloat16,
}
class Trainer:
    # If set, should be a list of component names to offload when unused
    # (refer to `Components`); None keeps everything loaded on GPU.
    # Annotation widened to include None, matching the default.
    UNLOAD_LIST: List[str] | None = None
def __init__(self, args: Args) -> None:
    """Store args, build initial State, load components, and init distributed/logging/dirs."""
    self.args = args
    self.state = State(
        weight_dtype=self.__get_training_dtype(),
        train_frames=self.args.train_resolution[0],
        train_height=self.args.train_resolution[1],
        train_width=self.args.train_resolution[2],
    )

    self.components: Components = self.load_components()
    # The attributes below are populated later by the prepare_* methods.
    self.accelerator: Accelerator = None
    self.dataset: Dataset = None
    self.data_loader: DataLoader = None

    self.optimizer = None
    self.lr_scheduler = None

    self._init_distributed()
    self._init_logging()
    self._init_directories()

    # True when accelerate is configured with a DeepSpeed plugin.
    self.state.using_deepspeed = self.accelerator.state.deepspeed_plugin is not None
def _init_distributed(self):
    """Create the Accelerator with project, DDP, and NCCL process-group settings, then seed."""
    logging_dir = Path(self.args.output_dir, "logs")
    project_config = ProjectConfiguration(
        project_dir=self.args.output_dir, logging_dir=logging_dir
    )
    # find_unused_parameters is required when some module params get no grad
    # in a step (e.g. frozen branches) — TODO confirm it's still needed here.
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    init_process_group_kwargs = InitProcessGroupKwargs(
        backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout)
    )
    # MPS does not support mixed precision; force full precision there.
    mixed_precision = "no" if torch.backends.mps.is_available() else self.args.mixed_precision
    report_to = None if self.args.report_to.lower() == "none" else self.args.report_to
    accelerator = Accelerator(
        project_config=project_config,
        gradient_accumulation_steps=self.args.gradient_accumulation_steps,
        mixed_precision=mixed_precision,
        log_with=report_to,
        kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
    )

    # Disable AMP for MPS.
    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    self.accelerator = accelerator

    if self.args.seed is not None:
        set_seed(self.args.seed)
def _init_logging(self) -> None:
    """Configure stdlib logging and set library verbosity per process rank."""
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=LOG_LEVEL,
    )
    # Only the local main process logs at normal verbosity; other ranks are
    # restricted to errors to avoid duplicated output.
    if self.accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    logger.info("Initialized Trainer")
    logger.info(f"Accelerator state: \n{self.accelerator.state}", main_process_only=False)
def _init_directories(self) -> None:
    """Create the output directory (main process only) and normalize it to a Path."""
    if self.accelerator.is_main_process:
        self.args.output_dir = Path(self.args.output_dir)
        self.args.output_dir.mkdir(parents=True, exist_ok=True)
def check_setting(self) -> None:
    """Validate UNLOAD_LIST: warn if unset, reject unknown component names."""
    # Check for unload_list
    if self.UNLOAD_LIST is None:
        logger.warning(
            "\033[91mNo unload_list specified for this Trainer. All components will be loaded to GPU during training.\033[0m"
        )
    else:
        # Every entry must name a declared field on the Components model.
        for name in self.UNLOAD_LIST:
            if name not in self.components.model_fields:
                raise ValueError(f"Invalid component name in unload_list: {name}")
def prepare_models(self) -> None:
    """Enable VAE memory optimizations and cache the transformer config on State."""
    logger.info("Initializing models")

    if self.components.vae is not None:
        # Slicing/tiling reduce VAE peak memory during encode/decode.
        if self.args.enable_slicing:
            self.components.vae.enable_slicing()
        if self.args.enable_tiling:
            self.components.vae.enable_tiling()

    self.state.transformer_config = self.components.transformer.config
def prepare_dataset(self) -> None:
    """Build the dataset/dataloader and precompute video latents and prompt embeddings.

    Raises:
        ValueError: if model_type is neither "i2v" nor "t2v".
    """
    logger.info("Initializing dataset and dataloader")

    if self.args.model_type == "i2v":
        self.dataset = I2VDatasetWithResize(
            **(self.args.model_dump()),
            device=self.accelerator.device,
            max_num_frames=self.state.train_frames,
            height=self.state.train_height,
            width=self.state.train_width,
            trainer=self,
        )
    elif self.args.model_type == "t2v":
        self.dataset = T2VDatasetWithResize(
            **(self.args.model_dump()),
            device=self.accelerator.device,
            max_num_frames=self.state.train_frames,
            height=self.state.train_height,
            width=self.state.train_width,
            trainer=self,
        )
    else:
        raise ValueError(f"Invalid model type: {self.args.model_type}")

    # Prepare VAE and text encoder for encoding (frozen, moved to device)
    self.components.vae.requires_grad_(False)
    self.components.text_encoder.requires_grad_(False)
    self.components.vae = self.components.vae.to(
        self.accelerator.device, dtype=self.state.weight_dtype
    )
    self.components.text_encoder = self.components.text_encoder.to(
        self.accelerator.device, dtype=self.state.weight_dtype
    )

    # Precompute latent for video and prompt embedding.
    # A single pass over the loader drives the dataset's encoding; presumably
    # results are cached inside the dataset — TODO confirm against dataset code.
    logger.info("Precomputing latent for video and prompt embedding ...")
    tmp_data_loader = torch.utils.data.DataLoader(
        self.dataset,
        collate_fn=self.collate_fn,
        batch_size=1,
        num_workers=0,
        pin_memory=self.args.pin_memory,
    )
    tmp_data_loader = self.accelerator.prepare_data_loader(tmp_data_loader)
    for _ in tmp_data_loader:
        ...
    self.accelerator.wait_for_everyone()
    logger.info("Precomputing latent for video and prompt embedding ... Done")

    # Encoders are no longer needed once precomputation is done.
    unload_model(self.components.vae)
    unload_model(self.components.text_encoder)
    free_memory()

    self.data_loader = torch.utils.data.DataLoader(
        self.dataset,
        collate_fn=self.collate_fn,
        batch_size=self.args.batch_size,
        num_workers=self.args.num_workers,
        pin_memory=self.args.pin_memory,
        shuffle=True,
    )
def prepare_trainable_parameters(self):
    """Freeze/unfreeze components per training type, attach LoRA adapters if requested,
    and move non-trainable components to the device.

    Raises:
        ValueError: when bf16 is requested on MPS (unsupported).
    """
    logger.info("Initializing trainable parameters")

    # For mixed precision training we cast all non-trainable weights to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = self.state.weight_dtype

    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
        # due to pytorch#99272, MPS does not yet support bfloat16.
        raise ValueError(
            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
        )

    # For LoRA, we freeze all the parameters
    # For SFT, we train all the parameters in transformer model
    for attr_name, component in vars(self.components).items():
        if hasattr(component, "requires_grad_"):
            if self.args.training_type == "sft" and attr_name == "transformer":
                component.requires_grad_(True)
            else:
                component.requires_grad_(False)

    if self.args.training_type == "lora":
        transformer_lora_config = LoraConfig(
            r=self.args.rank,
            lora_alpha=self.args.lora_alpha,
            init_lora_weights=True,
            target_modules=self.args.target_modules,
        )
        self.components.transformer.add_adapter(transformer_lora_config)
        self.__prepare_saving_loading_hooks(transformer_lora_config)

    # Load components needed for training to GPU (except transformer), and cast them to the specified data type
    ignore_list = ["transformer"] + self.UNLOAD_LIST
    self.__move_components_to_device(dtype=weight_dtype, ignore_list=ignore_list)

    if self.args.gradient_checkpointing:
        self.components.transformer.enable_gradient_checkpointing()
def prepare_optimizer(self) -> None:
    """Create the optimizer and LR scheduler (DeepSpeed dummies when configured there)."""
    logger.info("Initializing optimizer and lr scheduler")

    # Make sure the trainable params are in float32
    cast_training_params([self.components.transformer], dtype=torch.float32)

    # For LoRA, we only want to train the LoRA weights
    # For SFT, we want to train all the parameters
    trainable_parameters = list(
        filter(lambda p: p.requires_grad, self.components.transformer.parameters())
    )
    transformer_parameters_with_lr = {
        "params": trainable_parameters,
        "lr": self.args.learning_rate,
    }
    params_to_optimize = [transformer_parameters_with_lr]
    self.state.num_trainable_parameters = sum(p.numel() for p in trainable_parameters)

    # When the DeepSpeed config supplies an optimizer, get_optimizer must
    # hand back a compatible (dummy) optimizer instead of a real one.
    use_deepspeed_opt = (
        self.accelerator.state.deepspeed_plugin is not None
        and "optimizer" in self.accelerator.state.deepspeed_plugin.deepspeed_config
    )
    optimizer = get_optimizer(
        params_to_optimize=params_to_optimize,
        optimizer_name=self.args.optimizer,
        learning_rate=self.args.learning_rate,
        beta1=self.args.beta1,
        beta2=self.args.beta2,
        beta3=self.args.beta3,
        epsilon=self.args.epsilon,
        weight_decay=self.args.weight_decay,
        use_deepspeed=use_deepspeed_opt,
    )

    num_update_steps_per_epoch = math.ceil(
        len(self.data_loader) / self.args.gradient_accumulation_steps
    )
    # Derive train_steps from epochs when not given explicitly; remember the
    # override so prepare_for_training can recompute after accelerate.prepare.
    if self.args.train_steps is None:
        self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch
        self.state.overwrote_max_train_steps = True

    use_deepspeed_lr_scheduler = (
        self.accelerator.state.deepspeed_plugin is not None
        and "scheduler" in self.accelerator.state.deepspeed_plugin.deepspeed_config
    )
    total_training_steps = self.args.train_steps * self.accelerator.num_processes
    num_warmup_steps = self.args.lr_warmup_steps * self.accelerator.num_processes

    if use_deepspeed_lr_scheduler:
        from accelerate.utils import DummyScheduler

        lr_scheduler = DummyScheduler(
            name=self.args.lr_scheduler,
            optimizer=optimizer,
            total_num_steps=total_training_steps,
            num_warmup_steps=num_warmup_steps,
        )
    else:
        lr_scheduler = get_scheduler(
            name=self.args.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=total_training_steps,
            num_cycles=self.args.lr_num_cycles,
            power=self.args.lr_power,
        )

    self.optimizer = optimizer
    self.lr_scheduler = lr_scheduler
def prepare_for_training(self) -> None:
    """Wrap model/optimizer/dataloader/scheduler with accelerate and recompute step counts."""
    self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler = (
        self.accelerator.prepare(
            self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler
        )
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        len(self.data_loader) / self.args.gradient_accumulation_steps
    )
    if self.state.overwrote_max_train_steps:
        self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    self.args.train_epochs = math.ceil(self.args.train_steps / num_update_steps_per_epoch)
    self.state.num_update_steps_per_epoch = num_update_steps_per_epoch
def prepare_for_validation(self):
    """Load validation prompts (and optional images/videos) into State, index-aligned."""
    validation_prompts = load_prompts(self.args.validation_dir / self.args.validation_prompts)

    # Missing image/video lists are padded with None so all three lists stay
    # the same length as the prompt list.
    if self.args.validation_images is not None:
        validation_images = load_images(self.args.validation_dir / self.args.validation_images)
    else:
        validation_images = [None] * len(validation_prompts)

    if self.args.validation_videos is not None:
        validation_videos = load_videos(self.args.validation_dir / self.args.validation_videos)
    else:
        validation_videos = [None] * len(validation_prompts)

    self.state.validation_prompts = validation_prompts
    self.state.validation_images = validation_images
    self.state.validation_videos = validation_videos
def prepare_trackers(self) -> None:
    """Initialize experiment trackers (tensorboard/wandb/...) with the full run config."""
    logger.info("Initializing trackers")
    # NOTE(review): `tracker_name` is not among the Args fields visible here —
    # presumably defined elsewhere on Args; verify.
    tracker_name = self.args.tracker_name or "finetrainers-experiment"
    self.accelerator.init_trackers(tracker_name, config=self.args.model_dump())
def train(self) -> None:
    """Run the main training loop: resume, iterate epochs/steps, checkpoint, validate."""
    logger.info("Starting training")

    memory_statistics = get_memory_statistics()
    logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")

    self.state.total_batch_size_count = (
        self.args.batch_size
        * self.accelerator.num_processes
        * self.args.gradient_accumulation_steps
    )
    info = {
        "trainable parameters": self.state.num_trainable_parameters,
        "total samples": len(self.dataset),
        "train epochs": self.args.train_epochs,
        "train steps": self.args.train_steps,
        "batches per device": self.args.batch_size,
        "total batches observed per epoch": len(self.data_loader),
        "train batch size total count": self.state.total_batch_size_count,
        "gradient accumulation steps": self.args.gradient_accumulation_steps,
    }
    logger.info(f"Training configuration: {json.dumps(info, indent=4)}")

    global_step = 0
    first_epoch = 0
    initial_global_step = 0

    # Potentially load in the weights and states from a previous save
    (
        resume_from_checkpoint_path,
        initial_global_step,
        global_step,
        first_epoch,
    ) = get_latest_ckpt_path_to_resume_from(
        resume_from_checkpoint=self.args.resume_from_checkpoint,
        num_update_steps_per_epoch=self.state.num_update_steps_per_epoch,
    )
    if resume_from_checkpoint_path is not None:
        self.accelerator.load_state(resume_from_checkpoint_path)

    progress_bar = tqdm(
        range(0, self.args.train_steps),
        initial=initial_global_step,
        desc="Training steps",
        disable=not self.accelerator.is_local_main_process,
    )

    accelerator = self.accelerator
    generator = torch.Generator(device=accelerator.device)
    if self.args.seed is not None:
        generator = generator.manual_seed(self.args.seed)
    self.state.generator = generator

    free_memory()
    for epoch in range(first_epoch, self.args.train_epochs):
        logger.debug(f"Starting epoch ({epoch + 1}/{self.args.train_epochs})")

        self.components.transformer.train()
        models_to_accumulate = [self.components.transformer]

        for step, batch in enumerate(self.data_loader):
            logger.debug(f"Starting step {step + 1}")
            logs = {}

            with accelerator.accumulate(models_to_accumulate):
                # These weighting schemes use a uniform timestep sampling and instead post-weight the loss
                loss = self.compute_loss(batch)
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    if accelerator.distributed_type == DistributedType.DEEPSPEED:
                        # DeepSpeed clips internally; just read the norm back.
                        grad_norm = self.components.transformer.get_global_grad_norm()
                        # In some cases the grad norm may not return a float
                        if torch.is_tensor(grad_norm):
                            grad_norm = grad_norm.item()
                    else:
                        grad_norm = accelerator.clip_grad_norm_(
                            self.components.transformer.parameters(), self.args.max_grad_norm
                        )
                        if torch.is_tensor(grad_norm):
                            grad_norm = grad_norm.item()

                    logs["grad_norm"] = grad_norm

                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                self.__maybe_save_checkpoint(global_step)

            logs["loss"] = loss.detach().item()
            logs["lr"] = self.lr_scheduler.get_last_lr()[0]
            progress_bar.set_postfix(logs)

            # Maybe run validation
            should_run_validation = (
                self.args.do_validation and global_step % self.args.validation_steps == 0
            )
            if should_run_validation:
                del loss
                free_memory()
                self.validate(global_step)

            accelerator.log(logs, step=global_step)

            if global_step >= self.args.train_steps:
                break

        memory_statistics = get_memory_statistics()
        logger.info(
            f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}"
        )

    accelerator.wait_for_everyone()
    # Always persist a final checkpoint and optionally run a final validation.
    self.__maybe_save_checkpoint(global_step, must_save=True)
    if self.args.do_validation:
        free_memory()
        self.validate(global_step)

    del self.components
    free_memory()
    memory_statistics = get_memory_statistics()
    logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}")

    accelerator.end_training()
def validate(self, step: int) -> None:
    """Generate validation samples with the current weights and log the artifacts.

    Builds an inference pipeline, runs ``validation_step`` for every configured
    prompt/image/video triple, writes the resulting images/videos under
    ``<output_dir>/validation_res``, logs them to wandb trackers on the main
    process, then restores the training-time placement of all components.

    Args:
        step: Current global training step; used in artifact filenames and as
            the tracker log step.
    """
    logger.info("Starting validation")

    accelerator = self.accelerator
    num_validation_samples = len(self.state.validation_prompts)

    if num_validation_samples == 0:
        logger.warning("No validation samples found. Skipping validation.")
        return

    # Inference mode: no dropout/grad tracking while generating samples.
    self.components.transformer.eval()
    torch.set_grad_enabled(False)

    memory_statistics = get_memory_statistics()
    logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")

    #####  Initialize pipeline  #####
    pipe = self.initialize_pipeline()

    if self.state.using_deepspeed:
        # Can't use model_cpu_offload in deepspeed,
        # so we need to move all components in pipe to device
        # pipe.to(self.accelerator.device, dtype=self.state.weight_dtype)
        self.__move_components_to_device(
            dtype=self.state.weight_dtype, ignore_list=["transformer"]
        )
    else:
        # if not using deepspeed, use model_cpu_offload to further reduce memory usage
        # Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage
        pipe.enable_model_cpu_offload(device=self.accelerator.device)

        # Convert all model weights to training dtype
        # Note, this will change LoRA weights in self.components.transformer to training dtype, rather than keep them in fp32
        pipe = pipe.to(dtype=self.state.weight_dtype)

    #################################

    all_processes_artifacts = []
    for i in range(num_validation_samples):
        if self.state.using_deepspeed and self.accelerator.deepspeed_plugin.zero_stage != 3:
            # Skip current validation on all processes but one
            if i % accelerator.num_processes != accelerator.process_index:
                continue

        prompt = self.state.validation_prompts[i]
        image = self.state.validation_images[i]
        video = self.state.validation_videos[i]

        if image is not None:
            image = preprocess_image_with_resize(
                image, self.state.train_height, self.state.train_width
            )
            # Convert image tensor (C, H, W) to PIL images
            image = image.to(torch.uint8)
            image = image.permute(1, 2, 0).cpu().numpy()
            image = Image.fromarray(image)

        if video is not None:
            video = preprocess_video_with_resize(
                video, self.state.train_frames, self.state.train_height, self.state.train_width
            )
            # Convert video tensor (F, C, H, W) to list of PIL images
            video = video.round().clamp(0, 255).to(torch.uint8)
            video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]

        logger.debug(
            f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
            main_process_only=False,
        )
        validation_artifacts = self.validation_step(
            {"prompt": prompt, "image": image, "video": video}, pipe
        )

        # Under ZeRO-3 every rank ran the generation; only the main process keeps/saves artifacts.
        if (
            self.state.using_deepspeed
            and self.accelerator.deepspeed_plugin.zero_stage == 3
            and not accelerator.is_main_process
        ):
            continue

        prompt_filename = string_to_filename(prompt)[:25]
        # Calculate hash of reversed prompt as a unique identifier
        reversed_prompt = prompt[::-1]
        hash_suffix = hashlib.md5(reversed_prompt.encode()).hexdigest()[:5]

        artifacts = {
            "image": {"type": "image", "value": image},
            "video": {"type": "video", "value": video},
        }
        # NOTE(review): the loop variable `i` shadows the outer sample index; harmless
        # because the outer `for` reassigns `i` each iteration, but easy to trip over.
        for i, (artifact_type, artifact_value) in enumerate(validation_artifacts):
            artifacts.update(
                {f"artifact_{i}": {"type": artifact_type, "value": artifact_value}}
            )
        logger.debug(
            f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
            main_process_only=False,
        )

        for key, value in list(artifacts.items()):
            artifact_type = value["type"]
            artifact_value = value["value"]
            if artifact_type not in ["image", "video"] or artifact_value is None:
                continue

            extension = "png" if artifact_type == "image" else "mp4"
            filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}.{extension}"
            validation_path = self.args.output_dir / "validation_res"
            validation_path.mkdir(parents=True, exist_ok=True)
            filename = str(validation_path / filename)

            if artifact_type == "image":
                # NOTE(review): "(unknown)" looks like a garbled `{filename}` placeholder — confirm against upstream.
                logger.debug(f"Saving image to (unknown)")
                artifact_value.save(filename)
                artifact_value = wandb.Image(filename)
            elif artifact_type == "video":
                # NOTE(review): same suspected garbled `{filename}` placeholder as above.
                logger.debug(f"Saving video to (unknown)")
                export_to_video(artifact_value, filename, fps=self.args.gen_fps)
                artifact_value = wandb.Video(filename, caption=prompt)

            all_processes_artifacts.append(artifact_value)

    # Collect artifacts from every rank so the main process can log them all.
    all_artifacts = gather_object(all_processes_artifacts)

    if accelerator.is_main_process:
        tracker_key = "validation"
        for tracker in accelerator.trackers:
            if tracker.name == "wandb":
                image_artifacts = [
                    artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)
                ]
                video_artifacts = [
                    artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)
                ]
                tracker.log(
                    {
                        tracker_key: {"images": image_artifacts, "videos": video_artifacts},
                    },
                    step=step,
                )

    ##########  Clean up  ##########
    if self.state.using_deepspeed:
        del pipe
        # Unload models except those needed for training
        self.__move_components_to_cpu(unload_list=self.UNLOAD_LIST)
    else:
        pipe.remove_all_hooks()
        del pipe
        # Load models except those not needed for training
        self.__move_components_to_device(
            dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST
        )
        self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)

        # Change trainable weights back to fp32 to keep with dtype after prepare the model
        cast_training_params([self.components.transformer], dtype=torch.float32)

    free_memory()
    accelerator.wait_for_everyone()
    ################################

    memory_statistics = get_memory_statistics()
    logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
    torch.cuda.reset_peak_memory_stats(accelerator.device)

    # Restore training mode for the next training steps.
    torch.set_grad_enabled(True)
    self.components.transformer.train()
def fit(self):
    """Run the full training workflow end to end.

    The call order matters: models must be loaded before trainable parameters
    and the optimizer are derived from them, and everything must be prepared
    before the train loop starts.
    """
    self.check_setting()
    self.prepare_models()
    self.prepare_dataset()
    self.prepare_trainable_parameters()
    self.prepare_optimizer()
    self.prepare_for_training()
    if self.args.do_validation:
        self.prepare_for_validation()
    self.prepare_trackers()
    self.train()
def collate_fn(self, examples: List[Dict[str, Any]]):
    """Collate a list of dataset samples into a single training batch. Subclasses must implement."""
    raise NotImplementedError
def load_components(self) -> Components:
    """Load and return the model components used for training. Subclasses must implement."""
    raise NotImplementedError
def initialize_pipeline(self) -> DiffusionPipeline:
    """Build the inference pipeline used during validation. Subclasses must implement."""
    raise NotImplementedError
def encode_video(self, video: torch.Tensor) -> torch.Tensor:
    """Encode a video tensor into the model's latent space. Subclasses must implement."""
    # shape of input video: [B, C, F, H, W], where B = 1
    # shape of output video: [B, C', F', H', W'], where B = 1
    raise NotImplementedError
def encode_text(self, text: str) -> torch.Tensor:
    """Encode a text prompt into embedding tensors. Subclasses must implement."""
    # shape of output text: [batch size, sequence length, embedding dimension]
    raise NotImplementedError
def compute_loss(self, batch) -> torch.Tensor:
    """Compute the training loss for one batch (as produced by collate_fn). Subclasses must implement."""
    raise NotImplementedError
def validation_step(
    self, eval_data: Dict[str, Any], pipe: DiffusionPipeline
) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
    """Run one validation generation. Subclasses must implement.

    Fix: the stub previously declared no parameters, but ``validate`` invokes it
    as ``self.validation_step({"prompt": ..., "image": ..., "video": ...}, pipe)``;
    the signature now documents the real contract subclasses must satisfy.

    Args:
        eval_data: Dict with keys "prompt", "image" and "video" for one sample.
        pipe: The inference pipeline created by ``initialize_pipeline``.

    Returns:
        A list of ``(artifact_type, artifact_value)`` tuples, where
        ``artifact_type`` is "image" or "video".
    """
    raise NotImplementedError
def __get_training_dtype(self) -> torch.dtype:
    """Map the configured mixed-precision mode ("no"/"fp16"/"bf16") to a torch dtype."""
    mode = self.args.mixed_precision
    if mode == "no":
        return _DTYPE_MAP["fp32"]
    if mode == "fp16":
        return _DTYPE_MAP["fp16"]
    if mode == "bf16":
        return _DTYPE_MAP["bf16"]
    raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")
def __move_components_to_device(self, dtype, ignore_list: List[str] | None = None):
    """Move all movable components to the accelerator device in the given dtype.

    Args:
        dtype: Target dtype for the moved components.
        ignore_list: Component names to leave where they are.

    Fix: replaced the mutable default argument (``[]``) with the ``None``
    sentinel idiom; passing ``ignore_list=[...]`` keeps working unchanged.
    """
    ignore = set(ignore_list or [])
    components = self.components.model_dump()
    for name, component in components.items():
        # Skip classes (e.g. pipeline_cls) and anything without a .to() method.
        if not isinstance(component, type) and hasattr(component, "to"):
            if name not in ignore:
                setattr(
                    self.components, name, component.to(self.accelerator.device, dtype=dtype)
                )
def __move_components_to_cpu(self, unload_list: List[str] | None = None):
    """Move the named components to CPU to free accelerator memory.

    Args:
        unload_list: Component names to move to CPU; all others stay in place.

    Fix: replaced the mutable default argument (``[]``) with the ``None``
    sentinel idiom; passing ``unload_list=[...]`` keeps working unchanged.
    """
    unload = set(unload_list or [])
    components = self.components.model_dump()
    for name, component in components.items():
        # Skip classes (e.g. pipeline_cls) and anything without a .to() method.
        if not isinstance(component, type) and hasattr(component, "to"):
            if name in unload:
                setattr(self.components, name, component.to("cpu"))
def __prepare_saving_loading_hooks(self, transformer_lora_config):
    """Register accelerate hooks so save_state/load_state (de)serialize LoRA weights in pipeline format.

    Args:
        transformer_lora_config: The PEFT LoRA config; used to re-add the
            adapter when the transformer is rebuilt on load under DeepSpeed.
    """

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        if self.accelerator.is_main_process:
            transformer_lora_layers_to_save = None

            for model in models:
                # Only the transformer is expected here; anything else is a bug.
                if isinstance(
                    unwrap_model(self.accelerator, model),
                    type(unwrap_model(self.accelerator, self.components.transformer)),
                ):
                    model = unwrap_model(self.accelerator, model)
                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                else:
                    raise ValueError(f"Unexpected save model: {model.__class__}")

                # make sure to pop weight so that corresponding model is not saved again
                if weights:
                    weights.pop()

            self.components.pipeline_cls.save_lora_weights(
                output_dir,
                transformer_lora_layers=transformer_lora_layers_to_save,
            )

    def load_model_hook(models, input_dir):
        if not self.accelerator.distributed_type == DistributedType.DEEPSPEED:
            # Non-DeepSpeed: reuse the in-memory transformer handed to the hook.
            while len(models) > 0:
                model = models.pop()
                if isinstance(
                    unwrap_model(self.accelerator, model),
                    type(unwrap_model(self.accelerator, self.components.transformer)),
                ):
                    transformer_ = unwrap_model(self.accelerator, model)
                else:
                    raise ValueError(
                        f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}"
                    )
        else:
            # DeepSpeed: rebuild the transformer from pretrained weights and
            # re-attach the LoRA adapter before loading the saved state.
            transformer_ = unwrap_model(
                self.accelerator, self.components.transformer
            ).__class__.from_pretrained(self.args.model_path, subfolder="transformer")
            transformer_.add_adapter(transformer_lora_config)

        lora_state_dict = self.components.pipeline_cls.lora_state_dict(input_dir)
        # Keep only transformer weights and strip the "transformer." prefix.
        transformer_state_dict = {
            f'{k.replace("transformer.", "")}': v
            for k, v in lora_state_dict.items()
            if k.startswith("transformer.")
        }
        incompatible_keys = set_peft_model_state_dict(
            transformer_, transformer_state_dict, adapter_name="default"
        )
        if incompatible_keys is not None:
            # check only for unexpected keys
            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
            if unexpected_keys:
                logger.warning(
                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                    f" {unexpected_keys}. "
                )

    self.accelerator.register_save_state_pre_hook(save_model_hook)
    self.accelerator.register_load_state_pre_hook(load_model_hook)
def __maybe_save_checkpoint(self, global_step: int, must_save: bool = False):
    """Save accelerator training state if this process saves and the step is due.

    Args:
        global_step: Current optimizer step, used for the checkpoint name.
        must_save: Save regardless of the checkpointing interval.
    """
    is_saving_process = (
        self.accelerator.distributed_type == DistributedType.DEEPSPEED
        or self.accelerator.is_main_process
    )
    if not is_saving_process:
        return
    if not must_save and global_step % self.args.checkpointing_steps != 0:
        return
    save_path = get_intermediate_ckpt_path(
        checkpointing_limit=self.args.checkpointing_limit,
        step=global_step,
        output_dir=self.args.output_dir,
    )
    self.accelerator.save_state(save_path, safe_serialization=True)

View File

@ -0,0 +1,5 @@
from .checkpointing import *
from .file_utils import *
from .memory_utils import *
from .optimizer_utils import *
from .torch_utils import *

View File

@ -0,0 +1,57 @@
import os
from pathlib import Path
from typing import Tuple
from accelerate.logging import get_logger
from finetune.constants import LOG_LEVEL, LOG_NAME
from ..utils.file_utils import delete_files, find_files
logger = get_logger(LOG_NAME, LOG_LEVEL)
def get_latest_ckpt_path_to_resume_from(
resume_from_checkpoint: str | None, num_update_steps_per_epoch: int
) -> Tuple[str | None, int, int, int]:
if resume_from_checkpoint is None:
initial_global_step = 0
global_step = 0
first_epoch = 0
resume_from_checkpoint_path = None
else:
resume_from_checkpoint_path = Path(resume_from_checkpoint)
if not resume_from_checkpoint_path.exists():
logger.info(
f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run."
)
initial_global_step = 0
global_step = 0
first_epoch = 0
resume_from_checkpoint_path = None
else:
logger.info(f"Resuming from checkpoint {resume_from_checkpoint}")
global_step = int(resume_from_checkpoint_path.name.split("-")[1])
initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
return resume_from_checkpoint_path, initial_global_step, global_step, first_epoch
def get_intermediate_ckpt_path(checkpointing_limit: int, step: int, output_dir: str) -> str:
    """Return the save path for the checkpoint at *step*, pruning old checkpoints first.

    Args:
        checkpointing_limit: Max number of checkpoints to keep, or None for unlimited.
        step: Current global step, embedded in the directory name.
        output_dir: Directory that holds all checkpoint subdirectories.
    """
    # before saving state, check if this save would set us over the `checkpointing_limit`
    if checkpointing_limit is not None:
        existing = find_files(output_dir, prefix="checkpoint")

        # before we save the new checkpoint, we need to have at_most `checkpoints_total_limit - 1` checkpoints
        if len(existing) >= checkpointing_limit:
            num_to_remove = len(existing) - checkpointing_limit + 1
            delete_files(existing[:num_to_remove])

    logger.info(f"Checkpointing at step {step}")
    save_path = os.path.join(output_dir, f"checkpoint-{step}")
    logger.info(f"Saving state to {save_path}")
    return save_path

View File

@ -0,0 +1,48 @@
import logging
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List, Union
from accelerate.logging import get_logger
from finetune.constants import LOG_LEVEL, LOG_NAME
logger = get_logger(LOG_NAME, LOG_LEVEL)
def find_files(dir: Union[str, Path], prefix: str = "checkpoint") -> List[Path]:
    """Return checkpoint directories under *dir*, sorted numerically by step.

    Args:
        dir: Directory to scan; a missing directory yields an empty list.
        prefix: Only entries whose name starts with this prefix are returned.

    Returns:
        A list of Path objects ordered by the integer after the first "-"
        in the entry name (e.g. "checkpoint-100").

    Fix: the return annotation previously claimed ``List[str]`` although the
    function has always returned ``Path`` objects (``dir / c``).
    """
    if not isinstance(dir, Path):
        dir = Path(dir)
    if not dir.exists():
        return []
    checkpoints = os.listdir(dir.as_posix())
    checkpoints = [c for c in checkpoints if c.startswith(prefix)]
    # Numeric sort on the step suffix so "checkpoint-10" follows "checkpoint-2".
    checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
    checkpoints = [dir / c for c in checkpoints]
    return checkpoints
def delete_files(dirs: Union[str, List[str], Path, List[Path]]) -> None:
    """Recursively delete the given path(s); nonexistent paths are silently skipped."""
    targets = dirs if isinstance(dirs, list) else [dirs]
    targets = [Path(t) if isinstance(t, str) else t for t in targets]
    logger.info(f"Deleting files: {targets}")
    for target in targets:
        if target.exists():
            # ignore_errors keeps deletion best-effort (e.g. concurrent removal).
            shutil.rmtree(target, ignore_errors=True)
def string_to_filename(s: str) -> str:
    """Sanitize *s* for use in a filename by mapping separators/punctuation to "-"."""
    # Single-pass translation of every character previously handled by chained .replace() calls.
    return s.translate(str.maketrans({ch: "-" for ch in " /:.,;!?"}))

View File

@ -0,0 +1,66 @@
import gc
from typing import Any, Dict, Union
import torch
from accelerate.logging import get_logger
from finetune.constants import LOG_LEVEL, LOG_NAME
logger = get_logger(LOG_NAME, LOG_LEVEL)
def get_memory_statistics(precision: int = 3) -> Dict[str, Any]:
    """Collect current/peak device memory statistics, in gigabytes.

    Args:
        precision: Number of decimal digits to round each value to.

    Returns:
        Dict with "memory_allocated", "memory_reserved", "max_memory_allocated"
        and "max_memory_reserved"; metrics the current backend cannot report
        are None.

    Fix: previously ``round(None, ...)`` raised TypeError on machines without
    CUDA (and for the reserved/peak fields on MPS, which only reports the
    currently allocated bytes).
    """
    memory_allocated = None
    memory_reserved = None
    max_memory_allocated = None
    max_memory_reserved = None

    if torch.cuda.is_available():
        device = torch.cuda.current_device()
        memory_allocated = torch.cuda.memory_allocated(device)
        memory_reserved = torch.cuda.memory_reserved(device)
        max_memory_allocated = torch.cuda.max_memory_allocated(device)
        max_memory_reserved = torch.cuda.max_memory_reserved(device)
    elif torch.mps.is_available():
        # MPS only exposes the currently allocated bytes.
        memory_allocated = torch.mps.current_allocated_memory()
    else:
        logger.warning("No CUDA, MPS, or ROCm device found. Memory statistics are not available.")

    def _round_gb(num_bytes):
        # Preserve None for metrics the backend could not provide instead of
        # letting round() crash on it.
        gigabytes = bytes_to_gigabytes(num_bytes)
        return None if gigabytes is None else round(gigabytes, ndigits=precision)

    return {
        "memory_allocated": _round_gb(memory_allocated),
        "memory_reserved": _round_gb(memory_reserved),
        "max_memory_allocated": _round_gb(max_memory_allocated),
        "max_memory_reserved": _round_gb(max_memory_reserved),
    }
def bytes_to_gigabytes(x: int) -> float:
    """Convert a byte count to gigabytes (GiB); returns None when *x* is None.

    Fix: the None case previously fell off the end of the function as an
    implicit ``return None``; it is now explicit so the contract is visible.
    """
    if x is None:
        return None
    return x / 1024**3
def free_memory() -> None:
    """Release cached CUDA memory (garbage-collect, empty cache, reclaim IPC handles)."""
    # TODO(aryan): handle non-cuda devices
    if not torch.cuda.is_available():
        return
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
def unload_model(model):
    """Move *model* to CPU (in place) to free accelerator memory."""
    model.to("cpu")
def make_contiguous(
    x: Union[torch.Tensor, Dict[str, torch.Tensor]],
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    """Return *x* with every tensor made contiguous; recurses into dicts, passes other values through."""
    if isinstance(x, dict):
        return {key: make_contiguous(value) for key, value in x.items()}
    if isinstance(x, torch.Tensor):
        return x.contiguous()
    return x

View File

@ -0,0 +1,191 @@
import inspect
import torch
from accelerate.logging import get_logger
from finetune.constants import LOG_LEVEL, LOG_NAME
logger = get_logger(LOG_NAME, LOG_LEVEL)
def get_optimizer(
    params_to_optimize,
    optimizer_name: str = "adam",
    learning_rate: float = 1e-3,
    beta1: float = 0.9,
    beta2: float = 0.95,
    beta3: float = 0.98,
    epsilon: float = 1e-8,
    weight_decay: float = 1e-4,
    prodigy_decouple: bool = False,
    prodigy_use_bias_correction: bool = False,
    prodigy_safeguard_warmup: bool = False,
    use_8bit: bool = False,
    use_4bit: bool = False,
    use_torchao: bool = False,
    use_deepspeed: bool = False,
    use_cpu_offload_optimizer: bool = False,
    offload_gradients: bool = False,
) -> torch.optim.Optimizer:
    """Create an optimizer from the supported backends.

    Supports Adam/AdamW (optionally 8-bit via bitsandbytes, or 8/4-bit via
    torchao), Prodigy and CAME, plus a DeepSpeed dummy optimizer and a
    torchao CPU-offload wrapper. Unknown names fall back to AdamW.

    Args:
        params_to_optimize: Parameters (or param groups) to optimize.
        optimizer_name: One of "adam", "adamw", "prodigy", "came" (case-insensitive).
        learning_rate, beta1, beta2, beta3, epsilon, weight_decay: Standard
            optimizer hyperparameters; beta3 is used by Prodigy and CAME only.
        prodigy_decouple, prodigy_use_bias_correction, prodigy_safeguard_warmup:
            Prodigy-specific switches.
        use_8bit / use_4bit: Use a quantized Adam/AdamW variant (mutually exclusive).
        use_torchao: Source the quantized variants from torchao instead of bitsandbytes.
        use_deepspeed: Return accelerate's DummyOptim (DeepSpeed supplies the real one).
        use_cpu_offload_optimizer: Wrap the optimizer in torchao's CPUOffloadOptimizer.
        offload_gradients: Also offload gradients when CPU-offloading.

    Raises:
        ValueError: On conflicting or unsupported quantization settings.
        ImportError: When a required optional library is not installed.
    """
    optimizer_name = optimizer_name.lower()

    # Use DeepSpeed optimizer: a placeholder that accelerate swaps for the one
    # configured in the DeepSpeed config.
    if use_deepspeed:
        from accelerate.utils import DummyOptim

        return DummyOptim(
            params_to_optimize,
            lr=learning_rate,
            betas=(beta1, beta2),
            eps=epsilon,
            weight_decay=weight_decay,
        )

    if use_8bit and use_4bit:
        raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.")

    if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer:
        try:
            import torchao

            # Touch the attribute to force module resolution early.
            torchao.__version__
        except ImportError:
            raise ImportError(
                "To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`."
            )

    if not use_torchao and use_4bit:
        raise ValueError("4-bit Optimizers are only supported with torchao.")

    # Optimizer creation
    supported_optimizers = ["adam", "adamw", "prodigy", "came"]
    if optimizer_name not in supported_optimizers:
        logger.warning(
            f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`."
        )
        optimizer_name = "adamw"

    if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]:
        raise ValueError(
            "`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers."
        )

    if use_8bit:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

    if optimizer_name == "adamw":
        if use_torchao:
            from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit

            optimizer_class = (
                AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
            )
        else:
            optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW

        init_kwargs = {
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
        }

    elif optimizer_name == "adam":
        if use_torchao:
            from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit

            optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam
        else:
            optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam

        init_kwargs = {
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
        }

    elif optimizer_name == "prodigy":
        try:
            import prodigyopt
        except ImportError:
            raise ImportError(
                "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`"
            )

        optimizer_class = prodigyopt.Prodigy

        if learning_rate <= 0.1:
            logger.warning(
                "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
            )

        init_kwargs = {
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "beta3": beta3,
            "eps": epsilon,
            "weight_decay": weight_decay,
            "decouple": prodigy_decouple,
            "use_bias_correction": prodigy_use_bias_correction,
            "safeguard_warmup": prodigy_safeguard_warmup,
        }

    elif optimizer_name == "came":
        try:
            import came_pytorch
        except ImportError:
            raise ImportError(
                "To use CAME, please install the came-pytorch library: `pip install came-pytorch`"
            )

        optimizer_class = came_pytorch.CAME

        init_kwargs = {
            "lr": learning_rate,
            "eps": (1e-30, 1e-16),
            "betas": (beta1, beta2, beta3),
            "weight_decay": weight_decay,
        }

    if use_cpu_offload_optimizer:
        from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

        # Fused kernels are required/preferred for the offload wrapper when available.
        if "fused" in inspect.signature(optimizer_class.__init__).parameters:
            init_kwargs.update({"fused": True})

        optimizer = CPUOffloadOptimizer(
            params_to_optimize,
            optimizer_class=optimizer_class,
            offload_gradients=offload_gradients,
            **init_kwargs,
        )
    else:
        optimizer = optimizer_class(params_to_optimize, **init_kwargs)

    return optimizer
def gradient_norm(parameters):
    """Compute the global L2 norm over all gradients; params without a grad are skipped."""
    total_sq = 0.0
    for param in parameters:
        if param.grad is None:
            continue
        total_sq += param.grad.detach().data.norm(2).item() ** 2
    return total_sq**0.5
def max_gradient(parameters):
    """Return the largest absolute gradient entry across parameters (-inf if none have grads)."""
    largest = float("-inf")
    for param in parameters:
        if param.grad is not None:
            largest = max(largest, param.grad.detach().data.abs().max().item())
    return largest

View File

@ -0,0 +1,52 @@
from typing import Dict, List, Optional, Union
import torch
from accelerate import Accelerator
from diffusers.utils.torch_utils import is_compiled_module
def unwrap_model(accelerator: Accelerator, model):
    """Strip the accelerate wrapper (and a torch.compile wrapper, if any) from *model*."""
    unwrapped = accelerator.unwrap_model(model)
    if is_compiled_module(unwrapped):
        return unwrapped._orig_mod
    return unwrapped
def align_device_and_dtype(
    x: Union[torch.Tensor, Dict[str, torch.Tensor]],
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Move and/or cast *x* (a tensor or a dict of values) to *device*/*dtype*.

    Args:
        x: A tensor, a dict (recursed into), or any other value (returned as-is).
        device: Target device, or None to leave the device unchanged.
        dtype: Target dtype, or None to leave the dtype unchanged.

    Fix: the dict branch previously recursed twice — once under the ``device``
    check and again under the ``dtype`` check — rebuilding the dict redundantly;
    a single recursion passing both arguments is sufficient.
    """
    if isinstance(x, torch.Tensor):
        if device is not None:
            x = x.to(device)
        if dtype is not None:
            x = x.to(dtype)
    elif isinstance(x, dict):
        if device is not None or dtype is not None:
            x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
    return x
def expand_tensor_to_dims(tensor, ndim):
    """Append trailing size-1 dimensions until *tensor* has at least *ndim* dims."""
    missing = ndim - tensor.dim()
    if missing <= 0:
        return tensor
    return tensor.reshape(tensor.shape + (1,) * missing)
def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
    """
    Casts the trainable parameters of the model(s) to the specified data type.

    Args:
        model: A model (or list of models) whose parameters will be cast.
        dtype: The data type to which the trainable parameters will be cast.
    """
    models = model if isinstance(model, list) else [model]
    for module in models:
        for param in module.parameters():
            # only upcast trainable parameters into fp32; frozen weights stay untouched
            if not param.requires_grad:
                continue
            param.data = param.to(dtype)

View File

@ -3,40 +3,60 @@ This script demonstrates how to generate a video using the CogVideoX model with
The script supports different types of video generation, including text-to-video (t2v), image-to-video (i2v),
and video-to-video (v2v), depending on the input data and different weight.
- text-to-video: THUDM/CogVideoX-5b or THUDM/CogVideoX-2b
- video-to-video: THUDM/CogVideoX-5b or THUDM/CogVideoX-2b
- image-to-video: THUDM/CogVideoX-5b-I2V
- text-to-video: THUDM/CogVideoX-5b, THUDM/CogVideoX-2b or THUDM/CogVideoX1.5-5b
- video-to-video: THUDM/CogVideoX-5b, THUDM/CogVideoX-2b or THUDM/CogVideoX1.5-5b
- image-to-video: THUDM/CogVideoX-5b-I2V or THUDM/CogVideoX1.5-5b-I2V
Running the Script:
To run the script, use the following command with appropriate arguments:
```bash
$ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX-5b --generate_type "t2v"
$ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX1.5-5b --generate_type "t2v"
```
You can change `pipe.enable_sequential_cpu_offload()` to `pipe.enable_model_cpu_offload()` to speed up inference, but this will use more GPU memory
Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths.
"""
import argparse
from typing import Literal
import logging
from typing import Literal, Optional
import torch
from diffusers import (
CogVideoXPipeline,
CogVideoXDDIMScheduler,
CogVideoXDPMScheduler,
CogVideoXImageToVideoPipeline,
CogVideoXPipeline,
CogVideoXVideoToVideoPipeline,
)
from diffusers.utils import export_to_video, load_image, load_video
logging.basicConfig(level=logging.INFO)
# Recommended resolution for each model (width, height)
RESOLUTION_MAP = {
# cogvideox1.5-*
"cogvideox1.5-5b-i2v": (768, 1360),
"cogvideox1.5-5b": (768, 1360),
# cogvideox-*
"cogvideox-5b-i2v": (480, 720),
"cogvideox-5b": (480, 720),
"cogvideox-2b": (480, 720),
}
def generate_video(
prompt: str,
model_path: str,
lora_path: str = None,
lora_rank: int = 128,
num_frames: int = 81,
width: Optional[int] = None,
height: Optional[int] = None,
output_path: str = "./output.mp4",
image_or_video_path: str = "",
num_inference_steps: int = 50,
@ -45,6 +65,7 @@ def generate_video(
dtype: torch.dtype = torch.bfloat16,
generate_type: str = Literal["t2v", "i2v", "v2v"], # i2v: image to video, v2v: video to video
seed: int = 42,
fps: int = 16,
):
"""
Generates a video based on the given prompt and saves it to the specified path.
@ -56,11 +77,15 @@ def generate_video(
- lora_rank (int): The rank of the LoRA weights.
- output_path (str): The path where the generated video will be saved.
- num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
- num_frames (int): Number of frames to generate. CogVideoX1.0 generates 49 frames for 6 seconds at 8 fps, while CogVideoX1.5 produces either 81 or 161 frames, corresponding to 5 seconds or 10 seconds at 16 fps.
- width (int): The width of the generated video, applicable only for CogVideoX1.5-5B-I2V
- height (int): The height of the generated video, applicable only for CogVideoX1.5-5B-I2V
- guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
- num_videos_per_prompt (int): Number of videos to generate per prompt.
- dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
- generate_type (str): The type of video generation (e.g., 't2v', 'i2v', 'v2v').
- seed (int): The seed for reproducibility.
- fps (int): The frames per second for the generated video.
"""
# 1. Load the pre-trained CogVideoX pipeline with the specified precision (bfloat16).
@ -70,6 +95,26 @@ def generate_video(
image = None
video = None
model_name = model_path.split("/")[-1].lower()
desired_resolution = RESOLUTION_MAP[model_name]
if width is None or height is None:
height, width = desired_resolution
logging.info(
f"\033[1mUsing default resolution {desired_resolution} for {model_name}\033[0m"
)
elif (height, width) != desired_resolution:
if generate_type == "i2v":
# For i2v models, use user-defined width and height
logging.warning(
f"\033[1;31mThe width({width}) and height({height}) are not recommended for {model_name}. The best resolution is {desired_resolution}.\033[0m"
)
else:
# Otherwise, use the recommended width and height
logging.warning(
f"\033[1;31m{model_name} is not supported for custom resolution. Setting back to default resolution {desired_resolution}.\033[0m"
)
height, width = desired_resolution
if generate_type == "i2v":
pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype)
image = load_image(image=image_or_video_path)
@ -81,8 +126,10 @@ def generate_video(
# If you're using with lora, add this code
if lora_path:
pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1")
pipe.fuse_lora(lora_scale=1 / lora_rank)
pipe.load_lora_weights(
lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1"
)
pipe.fuse_lora(components=["transformer"], lora_scale=1.0)
# 2. Set Scheduler.
# Can be changed to `CogVideoXDPMScheduler` or `CogVideoXDDIMScheduler`.
@ -90,61 +137,71 @@ def generate_video(
# using `CogVideoXDPMScheduler` for CogVideoX-5B / CogVideoX-5B-I2V.
# pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe.scheduler = CogVideoXDPMScheduler.from_config(
pipe.scheduler.config, timestep_spacing="trailing"
)
# 3. Enable CPU offload for the model.
# turn off if you have multiple GPUs or enough GPU memory(such as H100) and it will cost less time in inference
# and enable to("cuda")
# pipe.to("cuda")
# pipe.enable_model_cpu_offload()
pipe.enable_sequential_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
# 4. Generate the video frames based on the prompt.
# `num_frames` is the Number of frames to generate.
# This is the default value for 6 seconds video and 8 fps and will plus 1 frame for the first frame and 49 frames.
if generate_type == "i2v":
video_generate = pipe(
height=height,
width=width,
prompt=prompt,
image=image, # The path of the image to be used as the background of the video
image=image,
# The path of the image, the resolution of video will be the same as the image for CogVideoX1.5-5B-I2V, otherwise it will be 720 * 480
num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt
num_inference_steps=num_inference_steps, # Number of inference steps
num_frames=49, # Number of frames to generatechanged to 49 for diffusers version `0.30.3` and after.
use_dynamic_cfg=True, # This id used for DPM Sechduler, for DDIM scheduler, it should be False
num_frames=num_frames, # Number of frames to generate
use_dynamic_cfg=True, # This id used for DPM scheduler, for DDIM scheduler, it should be False
guidance_scale=guidance_scale,
generator=torch.Generator().manual_seed(seed), # Set the seed for reproducibility
).frames[0]
elif generate_type == "t2v":
video_generate = pipe(
height=height,
width=width,
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
num_inference_steps=num_inference_steps,
num_frames=49,
num_frames=num_frames,
use_dynamic_cfg=True,
guidance_scale=guidance_scale,
generator=torch.Generator().manual_seed(seed),
).frames[0]
else:
video_generate = pipe(
height=height,
width=width,
prompt=prompt,
video=video, # The path of the video to be used as the background of the video
num_videos_per_prompt=num_videos_per_prompt,
num_inference_steps=num_inference_steps,
# num_frames=49,
num_frames=num_frames,
use_dynamic_cfg=True,
guidance_scale=guidance_scale,
generator=torch.Generator().manual_seed(seed), # Set the seed for reproducibility
).frames[0]
# 5. Export the generated frames to a video file. fps must be 8 for original video.
export_to_video(video_generate, output_path, fps=8)
export_to_video(video_generate, output_path, fps=fps)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
parser = argparse.ArgumentParser(
description="Generate a video from a text prompt using CogVideoX"
)
parser.add_argument(
"--prompt", type=str, required=True, help="The description of the video to be generated"
)
parser.add_argument(
"--image_or_video_path",
type=str,
@ -152,23 +209,43 @@ if __name__ == "__main__":
help="The path of the image to be used as the background of the video",
)
parser.add_argument(
"--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used"
"--model_path",
type=str,
default="THUDM/CogVideoX1.5-5B",
help="Path of the pre-trained model use",
)
parser.add_argument(
"--lora_path", type=str, default=None, help="The path of the LoRA weights to be used"
)
parser.add_argument("--lora_path", type=str, default=None, help="The path of the LoRA weights to be used")
parser.add_argument("--lora_rank", type=int, default=128, help="The rank of the LoRA weights")
parser.add_argument(
"--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
)
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
parser.add_argument(
"--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
)
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
parser.add_argument(
"--generate_type", type=str, default="t2v", help="The type of video generation (e.g., 't2v', 'i2v', 'v2v')"
"--output_path", type=str, default="./output.mp4", help="The path save generated video"
)
parser.add_argument(
"--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')"
"--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance"
)
parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
parser.add_argument(
"--num_frames", type=int, default=81, help="Number of steps for the inference process"
)
parser.add_argument("--width", type=int, default=None, help="The width of the generated video")
parser.add_argument(
"--height", type=int, default=None, help="The height of the generated video"
)
parser.add_argument(
"--fps", type=int, default=16, help="The frames per second for the generated video"
)
parser.add_argument(
"--num_videos_per_prompt",
type=int,
default=1,
help="Number of videos to generate per prompt",
)
parser.add_argument(
"--generate_type", type=str, default="t2v", help="The type of video generation"
)
parser.add_argument(
"--dtype", type=str, default="bfloat16", help="The data type for computation"
)
parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
@ -180,6 +257,9 @@ if __name__ == "__main__":
lora_path=args.lora_path,
lora_rank=args.lora_rank,
output_path=args.output_path,
num_frames=args.num_frames,
width=args.width,
height=args.height,
image_or_video_path=args.image_or_video_path,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
@ -187,4 +267,5 @@ if __name__ == "__main__":
dtype=dtype,
generate_type=args.generate_type,
seed=args.seed,
fps=args.fps,
)

View File

@ -3,7 +3,7 @@ This script demonstrates how to generate a video from a text prompt using CogVid
Note:
Must install the `torchao``torch`,`diffusers`,`accelerate` library FROM SOURCE to use the quantization feature.
Must install the `torchao` and `torch` libraries FROM SOURCE to use the quantization feature.
Only NVIDIA GPUs like H100 or higher are supported for FP-8 quantization.
ALL quantization schemes must use with NVIDIA GPUs.
@ -19,7 +19,12 @@ import argparse
import os
import torch
import torch._dynamo
from diffusers import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, CogVideoXPipeline, CogVideoXDPMScheduler
from diffusers import (
AutoencoderKLCogVideoX,
CogVideoXTransformer3DModel,
CogVideoXPipeline,
CogVideoXDPMScheduler,
)
from diffusers.utils import export_to_video
from transformers import T5EncoderModel
from torchao.quantization import quantize_, int8_weight_only
@ -51,6 +56,9 @@ def generate_video(
num_videos_per_prompt: int = 1,
quantization_scheme: str = "fp8",
dtype: torch.dtype = torch.bfloat16,
num_frames: int = 81,
fps: int = 8,
seed: int = 42,
):
"""
Generates a video based on the given prompt and saves it to the specified path.
@ -65,10 +73,13 @@ def generate_video(
- quantization_scheme (str): The quantization scheme to use ('int8', 'fp8').
- dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
"""
text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=dtype)
text_encoder = T5EncoderModel.from_pretrained(
model_path, subfolder="text_encoder", torch_dtype=dtype
)
text_encoder = quantize_model(part=text_encoder, quantization_scheme=quantization_scheme)
transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=dtype)
transformer = CogVideoXTransformer3DModel.from_pretrained(
model_path, subfolder="transformer", torch_dtype=dtype
)
transformer = quantize_model(part=transformer, quantization_scheme=quantization_scheme)
vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae", torch_dtype=dtype)
vae = quantize_model(part=vae, quantization_scheme=quantization_scheme)
@ -79,55 +90,59 @@ def generate_video(
vae=vae,
torch_dtype=dtype,
)
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
# Using with compile will run faster. First time infer will cost ~30min to compile.
# pipe.transformer.to(memory_format=torch.channels_last)
# for FP8 should remove pipe.enable_model_cpu_offload()
pipe.scheduler = CogVideoXDPMScheduler.from_config(
pipe.scheduler.config, timestep_spacing="trailing"
)
pipe.enable_model_cpu_offload()
# This is not for FP8 and INT8 and should remove this line
# pipe.enable_sequential_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
video = pipe(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
num_inference_steps=num_inference_steps,
num_frames=49,
num_frames=num_frames,
use_dynamic_cfg=True,
guidance_scale=guidance_scale,
generator=torch.Generator(device="cuda").manual_seed(42),
generator=torch.Generator(device="cuda").manual_seed(seed),
).frames[0]
export_to_video(video, output_path, fps=8)
export_to_video(video, output_path, fps=fps)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
parser.add_argument(
"--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used"
parser = argparse.ArgumentParser(
description="Generate a video from a text prompt using CogVideoX"
)
parser.add_argument(
"--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
"--prompt", type=str, required=True, help="The description of the video to be generated"
)
parser.add_argument(
"--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
"--model_path", type=str, default="THUDM/CogVideoX-5b", help="Path of the pre-trained model"
)
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
parser.add_argument(
"--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16', 'bfloat16')"
"--output_path", type=str, default="./output.mp4", help="Path to save generated video"
)
parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
parser.add_argument(
"--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale"
)
parser.add_argument(
"--num_videos_per_prompt", type=int, default=1, help="Videos to generate per prompt"
)
parser.add_argument(
"--dtype", type=str, default="bfloat16", help="Data type (e.g., 'float16', 'bfloat16')"
)
parser.add_argument(
"--quantization_scheme",
type=str,
default="bf16",
default="fp8",
choices=["int8", "fp8"],
help="The quantization scheme to use (int8, fp8)",
help="Quantization scheme",
)
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames in the video")
parser.add_argument("--fps", type=int, default=16, help="Frames per second for output video")
parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
args = parser.parse_args()
dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16
@ -140,4 +155,7 @@ if __name__ == "__main__":
num_videos_per_prompt=args.num_videos_per_prompt,
quantization_scheme=args.quantization_scheme,
dtype=dtype,
num_frames=args.num_frames,
fps=args.fps,
seed=args.seed,
)

View File

@ -104,18 +104,34 @@ def save_video(tensor, output_path):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="CogVideoX encode/decode demo")
parser.add_argument("--model_path", type=str, required=True, help="The path to the CogVideoX model")
parser.add_argument(
"--model_path", type=str, required=True, help="The path to the CogVideoX model"
)
parser.add_argument("--video_path", type=str, help="The path to the video file (for encoding)")
parser.add_argument("--encoded_path", type=str, help="The path to the encoded tensor file (for decoding)")
parser.add_argument("--output_path", type=str, default=".", help="The path to save the output file")
parser.add_argument(
"--mode", type=str, choices=["encode", "decode", "both"], required=True, help="Mode: encode, decode, or both"
"--encoded_path", type=str, help="The path to the encoded tensor file (for decoding)"
)
parser.add_argument(
"--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')"
"--output_path", type=str, default=".", help="The path to save the output file"
)
parser.add_argument(
"--device", type=str, default="cuda", help="The device to use for computation (e.g., 'cuda' or 'cpu')"
"--mode",
type=str,
choices=["encode", "decode", "both"],
required=True,
help="Mode: encode, decode, or both",
)
parser.add_argument(
"--dtype",
type=str,
default="bfloat16",
help="The data type for computation (e.g., 'float16' or 'bfloat16')",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="The device to use for computation (e.g., 'cuda' or 'cpu')",
)
args = parser.parse_args()
@ -126,15 +142,21 @@ if __name__ == "__main__":
assert args.video_path, "Video path must be provided for encoding."
encoded_output = encode_video(args.model_path, args.video_path, dtype, device)
torch.save(encoded_output, args.output_path + "/encoded.pt")
print(f"Finished encoding the video to a tensor, save it to a file at {encoded_output}/encoded.pt")
print(
f"Finished encoding the video to a tensor, save it to a file at {encoded_output}/encoded.pt"
)
elif args.mode == "decode":
assert args.encoded_path, "Encoded tensor path must be provided for decoding."
decoded_output = decode_video(args.model_path, args.encoded_path, dtype, device)
save_video(decoded_output, args.output_path)
print(f"Finished decoding the video and saved it to a file at {args.output_path}/output.mp4")
print(
f"Finished decoding the video and saved it to a file at {args.output_path}/output.mp4"
)
elif args.mode == "both":
assert args.video_path, "Video path must be provided for encoding."
encoded_output = encode_video(args.model_path, args.video_path, dtype, device)
torch.save(encoded_output, args.output_path + "/encoded.pt")
decoded_output = decode_video(args.model_path, args.output_path + "/encoded.pt", dtype, device)
decoded_output = decode_video(
args.model_path, args.output_path + "/encoded.pt", dtype, device
)
save_video(decoded_output, args.output_path)

View File

@ -144,7 +144,9 @@ def convert_prompt(prompt: str, retry_times: int = 3, type: str = "t2v", image_p
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--prompt", type=str, required=True, help="Prompt to convert")
parser.add_argument("--retry_times", type=int, default=3, help="Number of times to retry the conversion")
parser.add_argument(
"--retry_times", type=int, default=3, help="Number of times to retry the conversion"
)
parser.add_argument("--type", type=str, default="t2v", help="Type of conversion (t2v or i2v)")
parser.add_argument("--image_path", type=str, default=None, help="Path to the image file")
args = parser.parse_args()

519
inference/ddim_inversion.py Normal file
View File

@ -0,0 +1,519 @@
"""
This script performs DDIM inversion for video frames using a pre-trained model and generates
a video reconstruction based on a provided prompt. It utilizes the CogVideoX pipeline to
process video frames, apply the DDIM inverse scheduler, and produce an output video.
**Please notice that this script is based on the CogVideoX 5B model, and would not generate
a good result for 2B variants.**
Usage:
python ddim_inversion.py
--model-path /path/to/model
--prompt "a prompt"
--video-path /path/to/video.mp4
--output-path /path/to/output
For more details about the cli arguments, please run `python ddim_inversion.py --help`.
Author:
LittleNyima <littlenyima[at]163[dot]com>
"""
import argparse
import math
import os
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union, cast
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from diffusers.models.attention_processor import Attention, CogVideoXAttnProcessor2_0
from diffusers.models.autoencoders import AutoencoderKLCogVideoX
from diffusers.models.embeddings import apply_rotary_emb
from diffusers.models.transformers.cogvideox_transformer_3d import (
CogVideoXBlock,
CogVideoXTransformer3DModel,
)
from diffusers.pipelines.cogvideo.pipeline_cogvideox import CogVideoXPipeline, retrieve_timesteps
from diffusers.schedulers import CogVideoXDDIMScheduler, DDIMInverseScheduler
from diffusers.utils import export_to_video
# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error.
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
import decord # isort: skip
class DDIMInversionArguments(TypedDict):
    """Keyword arguments accepted by ``ddim_inversion`` (mirrors the CLI flags of ``get_args``)."""

    model_path: str  # path of the pretrained CogVideoX model
    prompt: str  # prompt used for the direct-sample (reconstruction) pass
    video_path: str  # path of the input video to invert
    output_path: str  # directory where the output videos are written
    guidance_scale: float  # classifier-free guidance scale
    num_inference_steps: int  # number of scheduler steps
    skip_frames_start: int  # frames dropped at the start of the video
    skip_frames_end: int  # frames dropped at the end of the video
    frame_sample_step: Optional[int]  # temporal stride between sampled frames (None = auto)
    max_num_frames: int  # upper bound on the number of sampled frames
    width: int  # resized frame width
    height: int  # resized frame height
    fps: int  # frame rate of the exported videos
    dtype: torch.dtype  # computation dtype (torch.bfloat16 or torch.float16)
    seed: int  # seed for the random number generator
    device: torch.device  # inference device
def get_args() -> DDIMInversionArguments:
    """Parse the command-line options and return them as typed keyword arguments.

    The textual ``--dtype`` and ``--device`` choices are converted into their
    ``torch`` counterparts before being returned, so the result can be splatted
    directly into ``ddim_inversion(**...)``.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument  # local alias to keep the flag list compact

    add("--model_path", type=str, required=True, help="Path of the pretrained model")
    add("--prompt", type=str, required=True, help="Prompt for the direct sample procedure")
    add("--video_path", type=str, required=True, help="Path of the video for inversion")
    add("--output_path", type=str, default="output", help="Path of the output videos")
    add("--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale")
    add("--num_inference_steps", type=int, default=50, help="Number of inference steps")
    add("--skip_frames_start", type=int, default=0, help="Number of skipped frames from the start")
    add("--skip_frames_end", type=int, default=0, help="Number of skipped frames from the end")
    add("--frame_sample_step", type=int, default=None, help="Temporal stride of the sampled frames")
    add("--max_num_frames", type=int, default=81, help="Max number of sampled frames")
    add("--width", type=int, default=720, help="Resized width of the video frames")
    add("--height", type=int, default=480, help="Resized height of the video frames")
    add("--fps", type=int, default=8, help="Frame rate of the output videos")
    add("--dtype", type=str, default="bf16", choices=["bf16", "fp16"], help="Dtype of the model")
    add("--seed", type=int, default=42, help="Seed for the random number generator")
    add("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device for inference")

    parsed = parser.parse_args()
    # Map the textual dtype/device selections onto concrete torch objects.
    parsed.dtype = torch.bfloat16 if parsed.dtype == "bf16" else torch.float16
    parsed.device = torch.device(parsed.device)
    return DDIMInversionArguments(**vars(parsed))
class CogVideoXAttnProcessor2_0ForDDIMInversion(CogVideoXAttnProcessor2_0):
    """Attention processor that lets sampled latents attend over reference latents.

    The incoming batch is expected to carry two stacked halves along dim 0:
    the current sample and the reference (inversion) trajectory. The sample
    half attends over the concatenation of its own keys/values and the
    reference keys/values, while the reference half runs plain self-attention,
    so the reference path is left unchanged.
    """

    def __init__(self):
        super().__init__()

    def calculate_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn: Attention,
        batch_size: int,
        image_seq_length: int,
        text_seq_length: int,
        attention_mask: Optional[torch.Tensor],
        image_rotary_emb: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run scaled-dot-product attention and split the result into image/text parts.

        Returns:
            Tuple of ``(hidden_states, encoder_hidden_states)`` — the image-token
            outputs followed by the text-token outputs.
        """
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        # Reshape [B, S, H*D] -> [B, H, S, D] for multi-head attention.
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed (only to image tokens; text tokens are left as-is).
        if image_rotary_emb is not None:
            query[:, :, text_seq_length:] = apply_rotary_emb(
                query[:, :, text_seq_length:], image_rotary_emb
            )
            if not attn.is_cross_attention:
                if key.size(2) == query.size(2):  # Attention for reference hidden states
                    key[:, :, text_seq_length:] = apply_rotary_emb(
                        key[:, :, text_seq_length:], image_rotary_emb
                    )
                else:  # RoPE should be applied to each group of image tokens
                    # Key layout here: [text | image | text | image] (sample K/V
                    # concatenated with reference K/V along the sequence dim).
                    key[:, :, text_seq_length : text_seq_length + image_seq_length] = (
                        apply_rotary_emb(
                            key[:, :, text_seq_length : text_seq_length + image_seq_length],
                            image_rotary_emb,
                        )
                    )
                    key[:, :, text_seq_length * 2 + image_seq_length :] = apply_rotary_emb(
                        key[:, :, text_seq_length * 2 + image_seq_length :], image_rotary_emb
                    )

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # Split back into text tokens (first) and image tokens (rest).
        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )
        return hidden_states, encoder_hidden_states

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Process a doubled batch of [sample, reference] hidden states."""
        image_seq_length = hidden_states.size(1)
        text_seq_length = encoder_hidden_states.size(1)

        # Prepend text tokens so the sequence layout is [text | image].
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # The batch stacks [sample, reference] along dim 0; split the halves.
        query, query_reference = query.chunk(2)
        key, key_reference = key.chunk(2)
        value, value_reference = value.chunk(2)
        batch_size = batch_size // 2

        # Sample half: attend over its own K/V concatenated with the reference K/V.
        hidden_states, encoder_hidden_states = self.calculate_attention(
            query=query,
            key=torch.cat((key, key_reference), dim=1),
            value=torch.cat((value, value_reference), dim=1),
            attn=attn,
            batch_size=batch_size,
            image_seq_length=image_seq_length,
            text_seq_length=text_seq_length,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )
        # Reference half: ordinary self-attention, keeping the reference path intact.
        hidden_states_reference, encoder_hidden_states_reference = self.calculate_attention(
            query=query_reference,
            key=key_reference,
            value=value_reference,
            attn=attn,
            batch_size=batch_size,
            image_seq_length=image_seq_length,
            text_seq_length=text_seq_length,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )

        # Re-stack the two halves along the batch dimension.
        return (
            torch.cat((hidden_states, hidden_states_reference)),
            torch.cat((encoder_hidden_states, encoder_hidden_states_reference)),
        )
class OverrideAttnProcessors:
    """Context manager that temporarily swaps every transformer block's ``attn1``
    processor for :class:`CogVideoXAttnProcessor2_0ForDDIMInversion`, restoring
    the original processors on exit.
    """

    def __init__(self, transformer: CogVideoXTransformer3DModel):
        self.transformer = transformer
        self.original_processors = {}

    def _blocks(self):
        # Yield the transformer blocks cast to their concrete type.
        for block in self.transformer.transformer_blocks:
            yield cast(CogVideoXBlock, block)

    def __enter__(self):
        for block in self._blocks():
            # Key by object identity so each block maps back to its own processor.
            self.original_processors[id(block)] = block.attn1.get_processor()
            block.attn1.set_processor(CogVideoXAttnProcessor2_0ForDDIMInversion())

    def __exit__(self, _0, _1, _2):
        for block in self._blocks():
            block.attn1.set_processor(self.original_processors[id(block)])
def get_video_frames(
    video_path: str,
    width: int,
    height: int,
    skip_frames_start: int,
    skip_frames_end: int,
    max_num_frames: int,
    frame_sample_step: Optional[int],
) -> torch.FloatTensor:
    """Load, sample, trim and normalize video frames with decord.

    Frames are resized to ``width`` x ``height`` by the reader, subsampled to at
    most ``max_num_frames``, trimmed so the count is of the form 4k + 1, and
    normalized from [0, 255] to [-1, 1].

    Returns:
        Float tensor of shape [F, C, H, W].
    """
    with decord.bridge.use_torch():
        video_reader = decord.VideoReader(uri=video_path, width=width, height=height)
        video_num_frames = len(video_reader)
        # Clamp the skip windows so indices stay within the video.
        start_frame = min(skip_frames_start, video_num_frames)
        end_frame = max(0, video_num_frames - skip_frames_end)

        if end_frame <= start_frame:
            # Degenerate window: fall back to a single frame.
            indices = [start_frame]
        elif end_frame - start_frame <= max_num_frames:
            indices = list(range(start_frame, end_frame))
        else:
            # Subsample with the requested stride, or derive one that covers the window.
            step = frame_sample_step or (end_frame - start_frame) // max_num_frames
            indices = list(range(start_frame, end_frame, step))

        frames = video_reader.get_batch(indices=indices)
        frames = frames[:max_num_frames].float()  # ensure that we don't go over the limit

        # Choose first (4k + 1) frames as this is how many is required by the VAE
        selected_num_frames = frames.size(0)
        remainder = (3 + selected_num_frames) % 4
        if remainder != 0:
            frames = frames[:-remainder]
        assert frames.size(0) % 4 == 1

        # Normalize the frames from [0, 255] to [-1, 1]
        transform = T.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
        frames = torch.stack(tuple(map(transform, frames)), dim=0)

        return frames.permute(0, 3, 1, 2).contiguous()  # [F, C, H, W]
def encode_video_frames(
    vae: AutoencoderKLCogVideoX, video_frames: torch.FloatTensor
) -> torch.FloatTensor:
    """Encode normalized frames [F, C, H, W] into scaled VAE latents.

    Returns:
        Latent tensor with the frame and channel axes swapped back to
        [B, F', C', H', W'] and multiplied by the VAE scaling factor.
    """
    video_frames = video_frames.to(device=vae.device, dtype=vae.dtype)
    # Add a batch dim and reorder to the layout the VAE expects.
    video_frames = video_frames.unsqueeze(0).permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
    latent_dist = vae.encode(x=video_frames).latent_dist.sample().transpose(1, 2)
    return latent_dist * vae.config.scaling_factor
def export_latents_to_video(
    pipeline: CogVideoXPipeline, latents: torch.FloatTensor, video_path: str, fps: int
):
    """Decode latents with the pipeline's VAE and write the result as an MP4 file."""
    video = pipeline.decode_latents(latents)
    frames = pipeline.video_processor.postprocess_video(video=video, output_type="pil")
    export_to_video(video_frames=frames[0], output_video_path=video_path, fps=fps)
# Modified from CogVideoXPipeline.__call__
def sample(
    pipeline: CogVideoXPipeline,
    latents: torch.FloatTensor,
    scheduler: Union[DDIMInverseScheduler, CogVideoXDDIMScheduler],
    prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 6,
    use_dynamic_cfg: bool = False,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    reference_latents: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
    """Run the diffusion loop and return the full latent trajectory.

    With ``DDIMInverseScheduler`` this performs DDIM inversion; with the
    pipeline's regular scheduler it performs sampling. When
    ``reference_latents`` is given, the per-step reference latent is stacked
    onto the batch so the (overridden) attention processors can attend to it;
    the noise prediction for the reference half is discarded.

    Returns:
        Tensor of shape [num_timesteps, *latents.shape] holding the latents
        after every scheduler step (the final state is at index -1).
    """
    pipeline._guidance_scale = guidance_scale
    pipeline._attention_kwargs = attention_kwargs
    pipeline._interrupt = False

    device = pipeline._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
        prompt,
        negative_prompt,
        do_classifier_free_guidance,
        device=device,
    )
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
    if reference_latents is not None:
        # Duplicate the embeddings to cover the extra reference half of the batch.
        prompt_embeds = torch.cat([prompt_embeds] * 2, dim=0)

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device)
    pipeline._num_timesteps = len(timesteps)

    # 5. Prepare latents.
    latents = latents.to(device=device) * scheduler.init_noise_sigma

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
    if isinstance(
        scheduler, DDIMInverseScheduler
    ):  # Inverse scheduler does not accept extra kwargs
        extra_step_kwargs = {}

    # 7. Create rotary embeds if required
    image_rotary_emb = (
        pipeline._prepare_rotary_positional_embeddings(
            height=latents.size(3) * pipeline.vae_scale_factor_spatial,
            width=latents.size(4) * pipeline.vae_scale_factor_spatial,
            num_frames=latents.size(1),
            device=device,
        )
        if pipeline.transformer.config.use_rotary_positional_embeddings
        else None
    )

    # 8. Denoising loop
    num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0)
    # Record the latents after every step; index 0 is the earliest step.
    trajectory = torch.zeros_like(latents).unsqueeze(0).repeat(len(timesteps), 1, 1, 1, 1, 1)
    with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if pipeline.interrupt:
                continue

            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )
            if reference_latents is not None:
                # Append this step's reference latents so attention can see them.
                reference = reference_latents[i]
                reference = torch.cat([reference] * 2) if do_classifier_free_guidance else reference
                latent_model_input = torch.cat([latent_model_input, reference], dim=0)
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latent_model_input.shape[0])

            # predict noise model_output
            noise_pred = pipeline.transformer(
                hidden_states=latent_model_input,
                encoder_hidden_states=prompt_embeds,
                timestep=timestep,
                image_rotary_emb=image_rotary_emb,
                attention_kwargs=attention_kwargs,
                return_dict=False,
            )[0]
            noise_pred = noise_pred.float()

            if reference_latents is not None:  # Recover the original batch size
                noise_pred, _ = noise_pred.chunk(2)

            # perform guidance
            if use_dynamic_cfg:
                # Cosine schedule that ramps the effective guidance scale over the run.
                pipeline._guidance_scale = 1 + guidance_scale * (
                    (
                        1
                        - math.cos(
                            math.pi
                            * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0
                        )
                    )
                    / 2
                )
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + pipeline.guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            # compute the noisy sample x_t-1 -> x_t
            latents = scheduler.step(
                noise_pred, t, latents, **extra_step_kwargs, return_dict=False
            )[0]
            latents = latents.to(prompt_embeds.dtype)
            trajectory[i] = latents

            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0
            ):
                progress_bar.update()

    # Offload all models
    pipeline.maybe_free_model_hooks()

    return trajectory
@torch.no_grad()
def ddim_inversion(
    model_path: str,
    prompt: str,
    video_path: str,
    output_path: str,
    guidance_scale: float,
    num_inference_steps: int,
    skip_frames_start: int,
    skip_frames_end: int,
    frame_sample_step: Optional[int],
    max_num_frames: int,
    width: int,
    height: int,
    fps: int,
    dtype: torch.dtype,
    seed: int,
    device: torch.device,
):
    """Invert a video into noise with DDIM and reconstruct it from the prompt.

    Loads the CogVideoX pipeline, encodes the sampled video frames into
    latents, runs DDIM inversion with an empty prompt, then samples a
    reconstruction from random noise while attending over the reversed
    inversion trajectory. Writes two files into ``output_path``:
    ``<video name>_inversion.mp4`` and ``<video name>_reconstruction.mp4``.

    Raises:
        NotImplementedError: if the model does not use rotary positional
            embeddings (i.e. a 2B variant instead of the 5B model).
    """
    pipeline: CogVideoXPipeline = CogVideoXPipeline.from_pretrained(
        model_path, torch_dtype=dtype
    ).to(device=device)
    if not pipeline.transformer.config.use_rotary_positional_embeddings:
        raise NotImplementedError("This script supports CogVideoX 5B model only.")

    video_frames = get_video_frames(
        video_path=video_path,
        width=width,
        height=height,
        skip_frames_start=skip_frames_start,
        skip_frames_end=skip_frames_end,
        max_num_frames=max_num_frames,
        frame_sample_step=frame_sample_step,
    ).to(device=device)
    video_latents = encode_video_frames(vae=pipeline.vae, video_frames=video_frames)

    # Inversion pass: empty prompt, inverse scheduler, starting from the video latents.
    inverse_scheduler = DDIMInverseScheduler(**pipeline.scheduler.config)
    inverse_latents = sample(
        pipeline=pipeline,
        latents=video_latents,
        scheduler=inverse_scheduler,
        prompt="",
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=torch.Generator(device=device).manual_seed(seed),
    )

    # Reconstruction pass: sample from noise while attending over the reversed
    # inversion trajectory via the overridden attention processors.
    with OverrideAttnProcessors(transformer=pipeline.transformer):
        recon_latents = sample(
            pipeline=pipeline,
            latents=torch.randn_like(video_latents),
            scheduler=pipeline.scheduler,
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device=device).manual_seed(seed),
            reference_latents=reversed(inverse_latents),
        )

    # Fix: the output names must derive from the input video's basename; the
    # original built them from a literal placeholder and left `filename` unused.
    filename, _ = os.path.splitext(os.path.basename(video_path))
    os.makedirs(output_path, exist_ok=True)  # ensure the output directory exists
    inverse_video_path = os.path.join(output_path, f"{filename}_inversion.mp4")
    recon_video_path = os.path.join(output_path, f"{filename}_reconstruction.mp4")
    # The last trajectory entry is the final latent state of each pass.
    export_latents_to_video(pipeline, inverse_latents[-1], inverse_video_path, fps)
    export_latents_to_video(pipeline, recon_latents[-1], recon_video_path, fps)
if __name__ == "__main__":
    # Parse the CLI arguments and run the full inversion + reconstruction pipeline.
    arguments = get_args()
    ddim_inversion(**arguments)

View File

@ -41,7 +41,5 @@ pip install -r requirements.txt
## Running the code
```bash
python gradio_web_demo.py
python app.py
```

View File

@ -30,7 +30,7 @@ from datetime import datetime, timedelta
from diffusers.image_processor import VaeImageProcessor
from openai import OpenAI
import moviepy.editor as mp
from moviepy import VideoFileClip
import utils
from rife_model import load_rife_model, rife_inference_with_latents
from huggingface_hub import hf_hub_download, snapshot_download
@ -39,11 +39,15 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL = "THUDM/CogVideoX-5b"
hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran")
hf_hub_download(
repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran"
)
snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
pipe = CogVideoXPipeline.from_pretrained(MODEL, torch_dtype=torch.bfloat16).to(device)
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe.scheduler = CogVideoXDPMScheduler.from_config(
pipe.scheduler.config, timestep_spacing="trailing"
)
pipe_video = CogVideoXVideoToVideoPipeline.from_pretrained(
MODEL,
transformer=pipe.transformer,
@ -271,9 +275,9 @@ def infer(
def convert_to_gif(video_path):
clip = mp.VideoFileClip(video_path)
clip = clip.set_fps(8)
clip = clip.resize(height=240)
clip = VideoFileClip(video_path)
clip = clip.with_fps(8)
clip = clip.resized(height=240)
gif_path = video_path.replace(".mp4", ".gif")
clip.write_gif(gif_path, fps=8)
return gif_path
@ -296,8 +300,16 @@ def delete_old_files():
threading.Thread(target=delete_old_files, daemon=True).start()
examples_videos = [["example_videos/horse.mp4"], ["example_videos/kitten.mp4"], ["example_videos/train_running.mp4"]]
examples_images = [["example_images/beach.png"], ["example_images/street.png"], ["example_images/camping.png"]]
examples_videos = [
["example_videos/horse.mp4"],
["example_videos/kitten.mp4"],
["example_videos/train_running.mp4"],
]
examples_images = [
["example_images/beach.png"],
["example_images/street.png"],
["example_images/camping.png"],
]
with gr.Blocks() as demo:
gr.Markdown("""
@ -322,14 +334,26 @@ with gr.Blocks() as demo:
""")
with gr.Row():
with gr.Column():
with gr.Accordion("I2V: Image Input (cannot be used simultaneously with video input)", open=False):
with gr.Accordion(
"I2V: Image Input (cannot be used simultaneously with video input)", open=False
):
image_input = gr.Image(label="Input Image (will be cropped to 720 * 480)")
examples_component_images = gr.Examples(examples_images, inputs=[image_input], cache_examples=False)
with gr.Accordion("V2V: Video Input (cannot be used simultaneously with image input)", open=False):
video_input = gr.Video(label="Input Video (will be cropped to 49 frames, 6 seconds at 8fps)")
examples_component_images = gr.Examples(
examples_images, inputs=[image_input], cache_examples=False
)
with gr.Accordion(
"V2V: Video Input (cannot be used simultaneously with image input)", open=False
):
video_input = gr.Video(
label="Input Video (will be cropped to 49 frames, 6 seconds at 8fps)"
)
strength = gr.Slider(0.1, 1.0, value=0.8, step=0.01, label="Strength")
examples_component_videos = gr.Examples(examples_videos, inputs=[video_input], cache_examples=False)
prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)
examples_component_videos = gr.Examples(
examples_videos, inputs=[video_input], cache_examples=False
)
prompt = gr.Textbox(
label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5
)
with gr.Row():
gr.Markdown(
@ -340,11 +364,16 @@ with gr.Blocks() as demo:
with gr.Column():
with gr.Row():
seed_param = gr.Number(
label="Inference Seed (Enter a positive number, -1 for random)", value=-1
label="Inference Seed (Enter a positive number, -1 for random)",
value=-1,
)
with gr.Row():
enable_scale = gr.Checkbox(label="Super-Resolution (720 × 480 -> 2880 × 1920)", value=False)
enable_rife = gr.Checkbox(label="Frame Interpolation (8fps -> 16fps)", value=False)
enable_scale = gr.Checkbox(
label="Super-Resolution (720 × 480 -> 2880 × 1920)", value=False
)
enable_rife = gr.Checkbox(
label="Frame Interpolation (8fps -> 16fps)", value=False
)
gr.Markdown(
"✨In this demo, we use [RIFE](https://github.com/hzwer/ECCV2022-RIFE) for frame interpolation and [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) for upscaling(Super-Resolution).<br>&nbsp;&nbsp;&nbsp;&nbsp;The entire process is based on open-source solutions."
)
@ -430,7 +459,7 @@ with gr.Blocks() as demo:
seed_value,
scale_status,
rife_status,
progress=gr.Progress(track_tqdm=True)
progress=gr.Progress(track_tqdm=True),
):
latents, seed = infer(
prompt,
@ -457,7 +486,9 @@ with gr.Blocks() as demo:
image_pil = VaeImageProcessor.numpy_to_pil(image_np)
batch_video_frames.append(image_pil)
video_path = utils.save_video(batch_video_frames[0], fps=math.ceil((len(batch_video_frames[0]) - 1) / 6))
video_path = utils.save_video(
batch_video_frames[0], fps=math.ceil((len(batch_video_frames[0]) - 1) / 6)
)
video_update = gr.update(visible=True, value=video_path)
gif_path = convert_to_gif(video_path)
gif_update = gr.update(visible=True, value=gif_path)

View File

@ -15,5 +15,5 @@ gradio>=5.4.0
imageio>=2.34.2
imageio-ffmpeg>=0.5.1
openai>=1.45.0
moviepy>=1.0.3
moviepy>=2.0.0
pillow==9.5.0

View File

@ -3,7 +3,9 @@ from .refine import *
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
torch.nn.ConvTranspose2d(
in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1
),
nn.PReLU(out_planes),
)
@ -46,7 +48,11 @@ class IFBlock(nn.Module):
if scale != 1:
x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
if flow != None:
flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale
flow = (
F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
* 1.0
/ scale
)
x = torch.cat((x, flow), 1)
x = self.conv0(x)
x = self.convblock(x) + x
@ -102,7 +108,9 @@ class IFNet(nn.Module):
warped_img0_teacher = warp(img0, flow_teacher[:, :2])
warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
mask_teacher = torch.sigmoid(mask + mask_d)
merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher)
merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (
1 - mask_teacher
)
else:
flow_teacher = None
merged_teacher = None
@ -110,11 +118,16 @@ class IFNet(nn.Module):
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
if gt.shape[1] == 3:
loss_mask = (
((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01)
(
(merged[i] - gt).abs().mean(1, True)
> (merged_teacher - gt).abs().mean(1, True) + 0.01
)
.float()
.detach()
)
loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean()
loss_distill += (
((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask
).mean()
c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4])
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)

View File

@ -3,7 +3,9 @@ from .refine_2R import *
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
torch.nn.ConvTranspose2d(
in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1
),
nn.PReLU(out_planes),
)
@ -46,7 +48,11 @@ class IFBlock(nn.Module):
if scale != 1:
x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
if flow != None:
flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale
flow = (
F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
* 1.0
/ scale
)
x = torch.cat((x, flow), 1)
x = self.conv0(x)
x = self.convblock(x) + x
@ -102,7 +108,9 @@ class IFNet(nn.Module):
warped_img0_teacher = warp(img0, flow_teacher[:, :2])
warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
mask_teacher = torch.sigmoid(mask + mask_d)
merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher)
merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (
1 - mask_teacher
)
else:
flow_teacher = None
merged_teacher = None
@ -110,11 +118,16 @@ class IFNet(nn.Module):
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
if gt.shape[1] == 3:
loss_mask = (
((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01)
(
(merged[i] - gt).abs().mean(1, True)
> (merged_teacher - gt).abs().mean(1, True) + 0.01
)
.float()
.detach()
)
loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean()
loss_distill += (
((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask
).mean()
c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4])
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)

View File

@ -61,11 +61,19 @@ class IFBlock(nn.Module):
def forward(self, x, flow, scale=1):
x = F.interpolate(
x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False
x,
scale_factor=1.0 / scale,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
flow = (
F.interpolate(
flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False
flow,
scale_factor=1.0 / scale,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* 1.0
/ scale
@ -78,11 +86,21 @@ class IFBlock(nn.Module):
flow = self.conv1(feat)
mask = self.conv2(feat)
flow = (
F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
F.interpolate(
flow,
scale_factor=scale,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* scale
)
mask = F.interpolate(
mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False
mask,
scale_factor=scale,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
return flow, mask
@ -112,7 +130,11 @@ class IFNet(nn.Module):
loss_cons = 0
block = [self.block0, self.block1, self.block2]
for i in range(3):
f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
f0, m0 = block[i](
torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1),
flow,
scale=scale_list[i],
)
f1, m1 = block[i](
torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1),
torch.cat((flow[:, 2:4], flow[:, :2]), 1),

View File

@ -3,7 +3,9 @@ from .refine import *
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
torch.nn.ConvTranspose2d(
in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1
),
nn.PReLU(out_planes),
)
@ -46,7 +48,11 @@ class IFBlock(nn.Module):
if scale != 1:
x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
if flow != None:
flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale
flow = (
F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
* 1.0
/ scale
)
x = torch.cat((x, flow), 1)
x = self.conv0(x)
x = self.convblock(x) + x
@ -83,7 +89,9 @@ class IFNet_m(nn.Module):
for i in range(3):
if flow != None:
flow_d, mask_d = stu[i](
torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask), 1), flow, scale=scale[i]
torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask), 1),
flow,
scale=scale[i],
)
flow = flow + flow_d
mask = mask + mask_d
@ -97,13 +105,17 @@ class IFNet_m(nn.Module):
merged.append(merged_student)
if gt.shape[1] == 3:
flow_d, mask_d = self.block_tea(
torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask, gt), 1), flow, scale=1
torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask, gt), 1),
flow,
scale=1,
)
flow_teacher = flow + flow_d
warped_img0_teacher = warp(img0, flow_teacher[:, :2])
warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
mask_teacher = torch.sigmoid(mask + mask_d)
merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher)
merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (
1 - mask_teacher
)
else:
flow_teacher = None
merged_teacher = None
@ -111,11 +123,16 @@ class IFNet_m(nn.Module):
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
if gt.shape[1] == 3:
loss_mask = (
((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01)
(
(merged[i] - gt).abs().mean(1, True)
> (merged_teacher - gt).abs().mean(1, True) + 0.01
)
.float()
.detach()
)
loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean()
loss_distill += (
((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask
).mean()
if returnflow:
return flow
else:

View File

@ -44,7 +44,9 @@ class Model:
if torch.cuda.is_available():
self.flownet.load_state_dict(convert(torch.load("{}/flownet.pkl".format(path))))
else:
self.flownet.load_state_dict(convert(torch.load("{}/flownet.pkl".format(path), map_location="cpu")))
self.flownet.load_state_dict(
convert(torch.load("{}/flownet.pkl".format(path), map_location="cpu"))
)
def save_model(self, path, rank=0):
if rank == 0:

View File

@ -29,10 +29,14 @@ def downsample(x):
def upsample(x):
cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)], dim=3)
cc = torch.cat(
[x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)], dim=3
)
cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3])
cc = cc.permute(0, 1, 3, 2)
cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2).to(device)], dim=3)
cc = torch.cat(
[cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2).to(device)], dim=3
)
cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2)
x_up = cc.permute(0, 1, 3, 2)
return conv_gauss(x_up, 4 * gauss_kernel(channels=x.shape[1]))
@ -64,6 +68,10 @@ class LapLoss(torch.nn.Module):
self.gauss_kernel = gauss_kernel(channels=channels)
def forward(self, input, target):
pyr_input = laplacian_pyramid(img=input, kernel=self.gauss_kernel, max_levels=self.max_levels)
pyr_target = laplacian_pyramid(img=target, kernel=self.gauss_kernel, max_levels=self.max_levels)
pyr_input = laplacian_pyramid(
img=input, kernel=self.gauss_kernel, max_levels=self.max_levels
)
pyr_target = laplacian_pyramid(
img=target, kernel=self.gauss_kernel, max_levels=self.max_levels
)
return sum(torch.nn.functional.l1_loss(a, b) for a, b in zip(pyr_input, pyr_target))

View File

@ -7,7 +7,9 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def gaussian(window_size, sigma):
gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)])
gauss = torch.Tensor(
[exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)]
)
return gauss / gauss.sum()
@ -22,7 +24,9 @@ def create_window_3d(window_size, channel=1):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t())
_3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
window = (
_3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
)
return window
@ -50,16 +54,35 @@ def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False,
# mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
# mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel)
mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel)
mu1 = F.conv2d(
F.pad(img1, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel
)
mu2 = F.conv2d(
F.pad(img2, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel
)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu1_sq
sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu2_sq
sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu1_mu2
sigma1_sq = (
F.conv2d(
F.pad(img1 * img1, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel
)
- mu1_sq
)
sigma2_sq = (
F.conv2d(
F.pad(img2 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel
)
- mu2_sq
)
sigma12 = (
F.conv2d(
F.pad(img1 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel
)
- mu1_mu2
)
C1 = (0.01 * L) ** 2
C2 = (0.03 * L) ** 2
@ -80,7 +103,9 @@ def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False,
return ret
def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
def ssim_matlab(
img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None
):
# Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
if val_range is None:
if torch.max(img1) > 128:
@ -106,16 +131,35 @@ def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full
img1 = img1.unsqueeze(1)
img2 = img2.unsqueeze(1)
mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1)
mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1)
mu1 = F.conv3d(
F.pad(img1, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1
)
mu2 = F.conv3d(
F.pad(img2, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1
)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu1_sq
sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu2_sq
sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu1_mu2
sigma1_sq = (
F.conv3d(
F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1
)
- mu1_sq
)
sigma2_sq = (
F.conv3d(
F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1
)
- mu2_sq
)
sigma12 = (
F.conv3d(
F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1
)
- mu1_mu2
)
C1 = (0.01 * L) ** 2
C2 = (0.03 * L) ** 2
@ -143,7 +187,14 @@ def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normal
mssim = []
mcs = []
for _ in range(levels):
sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
sim, cs = ssim(
img1,
img2,
window_size=window_size,
size_average=size_average,
full=True,
val_range=val_range,
)
mssim.append(sim)
mcs.append(cs)
@ -187,7 +238,9 @@ class SSIM(torch.nn.Module):
self.window = window
self.channel = channel
_ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
_ssim = ssim(
img1, img2, window=window, window_size=self.window_size, size_average=self.size_average
)
dssim = (1 - _ssim) / 2
return dssim

View File

@ -24,7 +24,12 @@ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
torch.nn.ConvTranspose2d(
in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True
in_channels=in_planes,
out_channels=out_planes,
kernel_size=4,
stride=2,
padding=1,
bias=True,
),
nn.PReLU(out_planes),
)
@ -56,25 +61,49 @@ class Contextnet(nn.Module):
def forward(self, x, flow):
x = self.conv1(x)
flow = (
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
F.interpolate(
flow,
scale_factor=0.5,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* 0.5
)
f1 = warp(x, flow)
x = self.conv2(x)
flow = (
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
F.interpolate(
flow,
scale_factor=0.5,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* 0.5
)
f2 = warp(x, flow)
x = self.conv3(x)
flow = (
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
F.interpolate(
flow,
scale_factor=0.5,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* 0.5
)
f3 = warp(x, flow)
x = self.conv4(x)
flow = (
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
F.interpolate(
flow,
scale_factor=0.5,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* 0.5
)
f4 = warp(x, flow)

View File

@ -24,7 +24,12 @@ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
torch.nn.ConvTranspose2d(
in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True
in_channels=in_planes,
out_channels=out_planes,
kernel_size=4,
stride=2,
padding=1,
bias=True,
),
nn.PReLU(out_planes),
)
@ -59,19 +64,37 @@ class Contextnet(nn.Module):
f1 = warp(x, flow)
x = self.conv2(x)
flow = (
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
F.interpolate(
flow,
scale_factor=0.5,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* 0.5
)
f2 = warp(x, flow)
x = self.conv3(x)
flow = (
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
F.interpolate(
flow,
scale_factor=0.5,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* 0.5
)
f3 = warp(x, flow)
x = self.conv4(x)
flow = (
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
F.interpolate(
flow,
scale_factor=0.5,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* 0.5
)
f4 = warp(x, flow)

View File

@ -9,6 +9,7 @@ import logging
import skvideo.io
from rife.RIFE_HDv3 import Model
from huggingface_hub import hf_hub_download, snapshot_download
logger = logging.getLogger(__name__)
device = "cuda" if torch.cuda.is_available() else "cpu"
@ -78,13 +79,12 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
# print(f'I1[0] unpadded shape:{I1.shape}')
I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
if padding[3] > 0 and padding[1] >0 :
frame = I1[:, :, : -padding[3],:-padding[1]]
if padding[3] > 0 and padding[1] > 0:
frame = I1[:, :, : -padding[3], : -padding[1]]
elif padding[3] > 0:
frame = I1[:, :, : -padding[3],:]
elif padding[1] >0:
frame = I1[:, :, :,:-padding[1]]
frame = I1[:, :, : -padding[3], :]
elif padding[1] > 0:
frame = I1[:, :, :, : -padding[1]]
else:
frame = I1
@ -102,7 +102,6 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
frame = F.interpolate(frame, size=(h, w))
output.append(frame.to(output_device))
for i, tmp_frame in enumerate(tmp_output):
# tmp_frame, _ = pad_image(tmp_frame, upscale_amount)
tmp_frame = F.interpolate(tmp_frame, size=(h, w))
output.append(tmp_frame.to(output_device))
@ -145,9 +144,7 @@ def rife_inference_with_path(model, video_path):
frame_rgb = frame[..., ::-1]
frame_rgb = frame_rgb.copy()
tensor = torch.from_numpy(frame_rgb).float().to("cpu", non_blocking=True).float() / 255.0
pt_frame_data.append(
tensor.permute(2, 0, 1)
) # to [c, h, w,]
pt_frame_data.append(tensor.permute(2, 0, 1)) # to [c, h, w,]
pt_frame = torch.from_numpy(np.stack(pt_frame_data))
pt_frame = pt_frame.to(device)
@ -170,7 +167,9 @@ def rife_inference_with_latents(model, latents):
latent = latents[i]
frames = ssim_interpolation_rife(model, latent)
pt_image = torch.stack([frames[i].squeeze(0) for i in range(len(frames))]) # (to [f, c, w, h])
pt_image = torch.stack(
[frames[i].squeeze(0) for i in range(len(frames))]
) # (to [f, c, w, h])
rife_results.append(pt_image)
return torch.stack(rife_results)

View File

@ -22,7 +22,7 @@ def load_torch_file(ckpt, device=None, dtype=torch.float16):
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
sd = safetensors.torch.load_file(ckpt, device=device.type)
else:
if not "weights_only" in torch.load.__code__.co_varnames:
if "weights_only" not in torch.load.__code__.co_varnames:
logger.warning(
"Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely."
)
@ -74,27 +74,39 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
@torch.inference_mode()
def tiled_scale_multidim(
samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None
samples,
function,
tile=(64, 64),
overlap=8,
upscale_amount=4,
out_channels=3,
output_device="cpu",
pbar=None,
):
dims = len(tile)
print(f"samples dtype:{samples.dtype}")
output = torch.empty(
[samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])),
[samples.shape[0], out_channels]
+ list(map(lambda a: round(a * upscale_amount), samples.shape[2:])),
device=output_device,
)
for b in range(samples.shape[0]):
s = samples[b : b + 1]
out = torch.zeros(
[s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
[s.shape[0], out_channels]
+ list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
device=output_device,
)
out_div = torch.zeros(
[s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
[s.shape[0], out_channels]
+ list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
device=output_device,
)
for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
for it in itertools.product(
*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))
):
s_in = s
upscaled = []
@ -142,7 +154,14 @@ def tiled_scale(
pbar=None,
):
return tiled_scale_multidim(
samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels, output_device, pbar
samples,
function,
(tile_y, tile_x),
overlap,
upscale_amount,
out_channels,
output_device,
pbar,
)
@ -186,7 +205,9 @@ def upscale(upscale_model, tensor: torch.Tensor, inf_device, output_device="cpu"
return s
def upscale_batch_and_concatenate(upscale_model, latents, inf_device, output_device="cpu") -> torch.Tensor:
def upscale_batch_and_concatenate(
upscale_model, latents, inf_device, output_device="cpu"
) -> torch.Tensor:
upscaled_latents = []
for i in range(latents.size(0)):
latent = latents[i]
@ -207,7 +228,9 @@ class ProgressBar:
def __init__(self, total, desc=None):
self.total = total
self.current = 0
self.b_unit = tqdm.tqdm(total=total, desc="ProgressBar context index: 0" if desc is None else desc)
self.b_unit = tqdm.tqdm(
total=total, desc="ProgressBar context index: 0" if desc is None else desc
)
def update(self, value):
if value > self.total:

View File

@ -20,9 +20,11 @@ from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
from datetime import datetime, timedelta
from openai import OpenAI
import moviepy.editor as mp
from moviepy import VideoFileClip
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to(
"cuda"
)
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
@ -95,7 +97,12 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
return prompt
def infer(prompt: str, num_inference_steps: int, guidance_scale: float, progress=gr.Progress(track_tqdm=True)):
def infer(
prompt: str,
num_inference_steps: int,
guidance_scale: float,
progress=gr.Progress(track_tqdm=True),
):
torch.cuda.empty_cache()
video = pipe(
prompt=prompt,
@ -117,9 +124,9 @@ def save_video(tensor):
def convert_to_gif(video_path):
clip = mp.VideoFileClip(video_path)
clip = clip.set_fps(8)
clip = clip.resize(height=240)
clip = VideoFileClip(video_path)
clip = clip.with_fps(8)
clip = clip.resized(height=240)
gif_path = video_path.replace(".mp4", ".gif")
clip.write_gif(gif_path, fps=8)
return gif_path
@ -151,7 +158,9 @@ with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)
prompt = gr.Textbox(
label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5
)
with gr.Row():
gr.Markdown(
@ -176,7 +185,13 @@ with gr.Blocks() as demo:
download_video_button = gr.File(label="📥 Download Video", visible=False)
download_gif_button = gr.File(label="📥 Download GIF", visible=False)
def generate(prompt, num_inference_steps, guidance_scale, model_choice, progress=gr.Progress(track_tqdm=True)):
def generate(
prompt,
num_inference_steps,
guidance_scale,
model_choice,
progress=gr.Progress(track_tqdm=True),
):
tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
video_path = save_video(tensor)
video_update = gr.update(visible=True, value=video_path)

View File

@ -1,27 +1,43 @@
[project]
name = "cogvideo"
version = "0.1.0"
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"datasets>=2.14.4",
"decord>=0.6.0",
"deepspeed>=0.16.4",
"diffusers>=0.32.2",
"opencv-python>=4.11.0.86",
"peft>=0.13.2",
"pydantic>=2.10.6",
"sentencepiece>=0.2.0",
"torch>=2.5.1",
"torchvision>=0.20.1",
"transformers>=4.46.3",
"wandb>=0.19.7",
]
[tool.ruff]
line-length = 119
[tool.ruff.lint]
# Never enforce `E501` (line length violations).
ignore = ["C901", "E501", "E741", "F402", "F823"]
select = ["C", "E", "F", "I", "W"]
# Ignore import violations in all `__init__.py` files.
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["E402", "F401", "F403", "F811"]
[tool.ruff.lint.isort]
lines-after-imports = 2
exclude = ['.git', '.mypy_cache', '.ruff_cache', '.venv', 'dist']
target-version = 'py310'
line-length = 100
[tool.ruff.format]
# Like Black, use double quotes for strings.
quote-style = "double"
line-ending = 'lf'
quote-style = 'preserve'
# Like Black, indent with spaces, rather than tabs.
indent-style = "space"
# Like Black, respect magic trailing commas.
skip-magic-trailing-comma = false
# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"
[tool.ruff.lint]
select = ["E", "W"]
ignore = [
"E999",
"EXE001",
"UP009",
"F401",
"TID252",
"F403",
"F841",
"E501",
"W293",
]

View File

@ -1,14 +1,15 @@
diffusers>=0.31.0
accelerate>=1.0.1
transformers>=4.46.1
diffusers>=0.32.1
accelerate>=1.1.1
transformers>=4.46.2
numpy==1.26.0
torch>=2.5.0
torchvision>=0.20.0
sentencepiece>=0.2.0
SwissArmyTransformer>=0.4.12
gradio>=5.4.0
gradio>=5.5.0
imageio>=2.35.1
imageio-ffmpeg>=0.5.1
openai>=1.53.0
moviepy>=1.0.3
openai>=1.54.0
moviepy>=2.0.0
scikit-video>=1.1.11
pydantic>=2.10.3

View File

@ -4,4 +4,3 @@
<p> 扫码关注公众号,加入「 CogVideoX 交流群」 </p>
<p> Scan the QR code to follow the official account and join the "CogVideoX Discussion Group" </p>
</div>

View File

@ -1,49 +0,0 @@
# Contribution Guide
There may still be many incomplete aspects in this project.
We look forward to your contributions to the repository in the following areas. If you complete the work mentioned above
and are willing to submit a PR and share it with the community, upon review, we
will acknowledge your contribution on the project homepage.
## Model Algorithms
- Support for model quantization inference (Int4 quantization project)
- Optimization of model fine-tuning data loading (replacing the existing decord tool)
## Model Engineering
- Model fine-tuning examples / Best prompt practices
- Inference adaptation on different devices (e.g., MLX framework)
- Any tools related to the model
- Any minimal fully open-source project using the CogVideoX open-source model
## Code Standards
Good code style is an art. We have prepared a `pyproject.toml` configuration file for the project to standardize code
style. You can organize the code according to the following specifications:
1. Install the `ruff` tool
```shell
pip install ruff
```
Then, run the `ruff` tool
```shell
ruff check tools sat inference
```
Check the code style. If there are issues, you can automatically fix them using the `ruff format` command.
```shell
ruff format tools sat inference
```
Once your code meets the standard, there should be no errors.
## Naming Conventions
1. Please use English names, do not use Pinyin or other language names. All comments should be in English.
2. Please strictly follow the PEP8 specification and use underscores to separate words. Do not use names like a, b, c.

View File

@ -1,47 +0,0 @@
# コントリビューションガイド
本プロジェクトにはまだ多くの未完成の部分があります。
以下の分野でリポジトリへの貢献をお待ちしています。上記の作業を完了し、PRを提出してコミュニティと共有する意志がある場合、レビュー後、プロジェクトのホームページで貢献を認識します。
## モデルアルゴリズム
- モデル量子化推論のサポート (Int4量子化プロジェクト)
- モデルのファインチューニングデータロードの最適化既存のdecordツールの置き換え
## モデルエンジニアリング
- モデルのファインチューニング例 / 最適なプロンプトの実践
- 異なるデバイスでの推論適応(例: MLXフレームワーク
- モデルに関連するツール
- CogVideoXオープンソースモデルを使用した、完全にオープンソースの最小プロジェクト
## コード標準
良いコードスタイルは一種の芸術です。本プロジェクトにはコードスタイルを標準化するための `pyproject.toml`
設定ファイルを用意しています。以下の仕様に従ってコードを整理してください。
1. `ruff` ツールをインストールする
```shell
pip install ruff
```
次に、`ruff` ツールを実行します
```shell
ruff check tools sat inference
```
コードスタイルを確認します。問題がある場合は、`ruff format` コマンドを使用して自動修正できます。
```shell
ruff format tools sat inference
```
コードが標準に準拠したら、エラーはなくなるはずです。
## 命名規則
1. 英語名を使用してください。ピンインや他の言語の名前を使用しないでください。すべてのコメントは英語で記載してください。
2. PEP8仕様に厳密に従い、単語をアンダースコアで区切ってください。a、b、cのような名前は使用しないでください。

View File

@ -1,44 +0,0 @@
# 贡献指南
本项目可能还存在很多不完善的内容。 我们期待您在以下方面与我们共建仓库, 如果您完成了上述工作并愿意PR和分享到社区在通过审核后我们将在项目首页感谢您的贡献。
## 模型算法
- 模型量化推理支持 (Int4量化工程)
- 模型微调数据载入优化支持(替换现有的decord工具)
## 模型工程
- 模型微调示例 / 最佳提示词实践
- 不同设备上的推理适配(MLX等框架)
- 任何模型周边工具
- 任何使用CogVideoX开源模型制作的最小完整开源项目
## 代码规范
良好的代码风格是一种艺术,我们已经为项目准备好了`pyproject.toml`配置文件,用于规范代码风格。您可以按照以下规范梳理代码:
1. 安装`ruff`工具
```shell
pip install ruff
```
接着,运行`ruff`工具
```shell
ruff check tools sat inference
```
检查代码风格,如果有问题,您可以通过`ruff format .`命令自动修复。
```shell
ruff format tools sat inference
```
如果您的代码符合规范,应该不会出现任何的错误。
## 命名规范
- 请使用英文命名,不要使用拼音或者其他语言命名。所有的注释均使用英文。
- 请严格遵循 PEP8 规范,使用下划线分割单词。请勿使用 a,b,c 这样的命名。

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.6 KiB

After

Width:  |  Height:  |  Size: 195 KiB

View File

@ -1,29 +1,42 @@
# SAT CogVideoX-2B
# SAT CogVideoX
[中文阅读](./README_zh.md)
[日本語で読む](./README_ja.md)
This folder contains the inference code using [SAT](https://github.com/THUDM/SwissArmyTransformer) weights and the
This folder contains inference code using [SAT](https://github.com/THUDM/SwissArmyTransformer) weights, along with
fine-tuning code for SAT weights.
This code is the framework used by the team to train the model. It has few comments and requires careful study.
If you are interested in the `CogVideoX1.0` version of the model, please check the SAT
folder [here](https://github.com/THUDM/CogVideo/releases/tag/v1.0). This branch only supports the `CogVideoX1.5` series
models.
## Inference Model
### 1. Ensure that you have correctly installed the dependencies required by this folder.
### 1. Make sure you have installed all dependencies in this folder
```shell
```
pip install -r requirements.txt
```
### 2. Download the model weights
### 2. Download the Model Weights
### 2. Download model weights
First, download the model weights from the SAT mirror.
First, go to the SAT mirror to download the model weights. For the CogVideoX-2B model, please download as follows:
#### CogVideoX1.5 Model
```shell
```
git lfs install
git clone https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT
```
This command downloads three models: Transformers, VAE, and T5 Encoder.
#### CogVideoX Model
For the CogVideoX-2B model, download as follows:
```
mkdir CogVideoX-2b-sat
cd CogVideoX-2b-sat
wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1
@ -34,13 +47,12 @@ mv 'index.html?dl=1' transformer.zip
unzip transformer.zip
```
For the CogVideoX-5B model, please download the `transformers` file from the following links:
(VAE files are the same as 2B)
Download the `transformers` file for the CogVideoX-5B model (the VAE file is the same as for 2B):
+ [CogVideoX-5B](https://cloud.tsinghua.edu.cn/d/fcef5b3904294a6885e5/?p=%2F&mode=list)
+ [CogVideoX-5B-I2V](https://cloud.tsinghua.edu.cn/d/5cc62a2d6e7d45c0a2f6/?p=%2F1&mode=list)
Next, you need to format the model files as follows:
Arrange the model files in the following structure:
```
.
@ -52,20 +64,24 @@ Next, you need to format the model files as follows:
└── 3d-vae.pt
```
Due to large size of model weight file, using `git lfs` is recommended. Installation of `git lfs` can be
found [here](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing)
Since model weight files are large, it's recommended to use `git lfs`.
See [here](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing) for `git lfs` installation.
Next, clone the T5 model, which is not used for training and fine-tuning, but must be used.
> T5 model is available on [Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b) as well.
```
git lfs install
```
```shell
git clone https://huggingface.co/THUDM/CogVideoX-2b.git
Next, clone the T5 model, which is used as an encoder and doesn't require training or fine-tuning.
> You may also use the model file location on [Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b).
```
git clone https://huggingface.co/THUDM/CogVideoX-2b.git # Download model from Huggingface
# git clone https://www.modelscope.cn/ZhipuAI/CogVideoX-2b.git # Download from Modelscope
mkdir t5-v1_1-xxl
mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
```
By following the above approach, you will obtain a safetensor format T5 file. Ensure that there are no errors when
loading it into Deepspeed in Finetune.
This will yield a safetensor format T5 file that can be loaded without error during Deepspeed fine-tuning.
```
├── added_tokens.json
@ -80,11 +96,11 @@ loading it into Deepspeed in Finetune.
0 directories, 8 files
```
### 3. Modify the file in `configs/cogvideox_2b.yaml`.
### 3. Modify `configs/cogvideox_*.yaml` file.
```yaml
model:
scale_factor: 1.15258426
scale_factor: 1.55258426
disable_first_stage_autocast: true
log_keys:
- txt
@ -160,14 +176,14 @@ model:
ucg_rate: 0.1
target: sgm.modules.encoders.modules.FrozenT5Embedder
params:
model_dir: "t5-v1_1-xxl" # Absolute path to the CogVideoX-2b/t5-v1_1-xxl weights folder
model_dir: "t5-v1_1-xxl" # absolute path to CogVideoX-2b/t5-v1_1-xxl weight folder
max_length: 226
first_stage_config:
target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
params:
cp_size: 1
ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # Absolute path to the CogVideoX-2b-sat/vae/3d-vae.pt folder
ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # absolute path to CogVideoX-2b-sat/vae/3d-vae.pt file
ignore_keys: [ 'loss' ]
loss_config:
@ -239,19 +255,19 @@ model:
num_steps: 50
```
### 4. Modify the file in `configs/inference.yaml`.
### 4. Modify `configs/inference.yaml` file.
```yaml
args:
latent_channels: 16
mode: inference
load: "{absolute_path/to/your}/transformer" # Absolute path to the CogVideoX-2b-sat/transformer folder
load: "{absolute_path/to/your}/transformer" # Absolute path to CogVideoX-2b-sat/transformer folder
# load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter
batch_size: 1
input_type: txt # You can choose txt for pure text input, or change to cli for command line input
input_file: configs/test.txt # Pure text file, which can be edited
sampling_num_frames: 13 # Must be 13, 11 or 9
input_type: txt # You can choose "txt" for plain text input or change to "cli" for command-line input
input_file: configs/test.txt # Plain text file, can be edited
sampling_num_frames: 13 # For CogVideoX1.5-5B it must be 42 or 22. For CogVideoX-5B / 2B, it must be 13, 11, or 9.
sampling_fps: 8
fp16: True # For CogVideoX-2B
# bf16: True # For CogVideoX-5B
@ -259,28 +275,27 @@ args:
force_inference: True
```
+ Modify `configs/test.txt` if multiple prompts are required; each line is one prompt.
+ For better prompt formatting, refer to [convert_demo.py](../inference/convert_demo.py), for which you should set the
OPENAI_API_KEY as your environmental variable.
+ Modify `input_type` in `configs/inference.yaml` if using the command line as prompt input.
+ If using a text file to save multiple prompts, modify `configs/test.txt` as needed. One prompt per line. If you are
unsure how to write prompts, use [this code](../inference/convert_demo.py) to call an LLM for refinement.
+ To use command-line input, modify:
```yaml
```
input_type: cli
```
This allows input from the command line as prompts.
This allows you to enter prompts from the command line.
Change `output_dir` if you wish to modify the address of the output video
To modify the output video location, change:
```yaml
```
output_dir: outputs/
```
It is saved by default in the `.outputs/` folder.
The default location is the `.outputs/` folder.
### 5. Run the inference code to perform inference.
### 5. Run the Inference Code to Perform Inference
```shell
```
bash inference.sh
```
@ -288,95 +303,95 @@ bash inference.sh
### Preparing the Dataset
The dataset format should be as follows:
The dataset should be structured as follows:
```
.
├── labels
│   ├── 1.txt
│   ├── 2.txt
│   ├── ...
├── 1.txt
├── 2.txt
├── ...
└── videos
├── 1.mp4
├── 2.mp4
├── ...
```
Each text file shares the same name as its corresponding video, serving as the label for that video. Videos and labels
should be matched one-to-one. Generally, a single video should not be associated with multiple labels.
Each txt file should have the same name as the corresponding video file and contain the label for that video. The videos
and labels should correspond one-to-one. Generally, avoid using one video with multiple labels.
For style fine-tuning, please prepare at least 50 videos and labels with similar styles to ensure proper fitting.
For style fine-tuning, prepare at least 50 videos and labels with a similar style to facilitate fitting.
### Modifying Configuration Files
### Modifying the Configuration File
We support two fine-tuning methods: `Lora` and full-parameter fine-tuning. Please note that both methods only fine-tune
the `transformer` part and do not modify the `VAE` section. `T5` is used solely as an Encoder. Please modify
the `configs/sft.yaml` (for full-parameter fine-tuning) file as follows:
We support two fine-tuning methods: `Lora` and full-parameter fine-tuning. Note that both methods only fine-tune the
`transformer` part. The `VAE` part is not modified, and `T5` is only used as an encoder.
Modify the files in `configs/sft.yaml` (full fine-tuning) as follows:
```
# checkpoint_activations: True ## Using gradient checkpointing (Both checkpoint_activations in the config file need to be set to True)
```yaml
# checkpoint_activations: True ## using gradient checkpointing (both `checkpoint_activations` in the config file need to be set to True)
model_parallel_size: 1 # Model parallel size
experiment_name: lora-disney # Experiment name (do not modify)
mode: finetune # Mode (do not modify)
load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformer model path
no_load_rng: True # Whether to load random seed
experiment_name: lora-disney # Experiment name (do not change)
mode: finetune # Mode (do not change)
load: "{your_CogVideoX-2b-sat_path}/transformer" ## Path to Transformer model
no_load_rng: True # Whether to load random number seed
train_iters: 1000 # Training iterations
eval_iters: 1 # Evaluation iterations
eval_interval: 100 # Evaluation interval
eval_batch_size: 1 # Evaluation batch size
save: ckpts # Model save path
save_interval: 100 # Model save interval
save_interval: 100 # Save interval
log_interval: 20 # Log output interval
train_data: [ "your train data path" ]
valid_data: [ "your val data path" ] # Training and validation datasets can be the same
split: 1,0,0 # Training, validation, and test set ratio
num_workers: 8 # Number of worker threads for data loader
force_train: True # Allow missing keys when loading checkpoint (T5 and VAE are loaded separately)
only_log_video_latents: True # Avoid memory overhead caused by VAE decode
valid_data: [ "your val data path" ] # Training and validation sets can be the same
split: 1,0,0 # Proportion for training, validation, and test sets
num_workers: 8 # Number of data loader workers
force_train: True # Allow missing keys when loading checkpoint (T5 and VAE loaded separately)
only_log_video_latents: True # Avoid memory usage from VAE decoding
deepspeed:
bf16:
enabled: False # For CogVideoX-2B set to False and for CogVideoX-5B set to True
enabled: False # For CogVideoX-2B Turn to False and For CogVideoX-5B Turn to True
fp16:
enabled: True # For CogVideoX-2B set to True and for CogVideoX-5B set to False
enabled: True # For CogVideoX-2B Turn to True and For CogVideoX-5B Turn to False
```
If you wish to use Lora fine-tuning, you also need to modify the `cogvideox_<model_parameters>_lora` file:
``` To use Lora fine-tuning, you also need to modify `cogvideox_<model parameters>_lora` file:
Here, take `CogVideoX-2B` as a reference:
Here's an example using `CogVideoX-2B`:
```
model:
scale_factor: 1.15258426
scale_factor: 1.55258426
disable_first_stage_autocast: true
not_trainable_prefixes: [ 'all' ] ## Uncomment
not_trainable_prefixes: [ 'all' ] ## Uncomment to unlock
log_keys:
- txt'
- txt
lora_config: ## Uncomment
lora_config: ## Uncomment to unlock
target: sat.model.finetune.lora2.LoraMixin
params:
r: 256
```
### Modifying Run Scripts
### Modify the Run Script
Edit `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` to select the configuration file. Below are two examples:
Edit `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` and select the config file. Below are two examples:
1. If you want to use the `CogVideoX-2B` model and the `Lora` method, you need to modify `finetune_single_gpu.sh`
or `finetune_multi_gpus.sh`:
1. If you want to use the `CogVideoX-2B` model with `Lora`, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh`
as follows:
```
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed $RANDOM"
```
2. If you want to use the `CogVideoX-2B` model and the `full-parameter fine-tuning` method, you need to
modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh`:
2. If you want to use the `CogVideoX-2B` model with full fine-tuning, modify `finetune_single_gpu.sh` or
`finetune_multi_gpus.sh` as follows:
```
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b.yaml configs/sft.yaml --seed $RANDOM"
```
### Fine-Tuning and Evaluation
### Fine-tuning and Validation
Run the inference code to start fine-tuning.
@ -385,45 +400,45 @@ bash finetune_single_gpu.sh # Single GPU
bash finetune_multi_gpus.sh # Multi GPUs
```
### Using the Fine-Tuned Model
### Using the Fine-tuned Model
The fine-tuned model cannot be merged; here is how to modify the inference configuration file `inference.sh`:
The fine-tuned model cannot be merged. Here's how to modify the inference configuration file `inference.sh`:
```
run_cmd="$environs python sample_video.py --base configs/cogvideox_<model_parameters>_lora.yaml configs/inference.yaml --seed 42"
run_cmd="$environs python sample_video.py --base configs/cogvideox_<model parameters>_lora.yaml configs/inference.yaml --seed 42"
```
Then, execute the code:
Then, run the code:
```
bash inference.sh
```
### Converting to Huggingface Diffusers Supported Weights
### Converting to Huggingface Diffusers-compatible Weights
The SAT weight format is different from Huggingface's weight format and needs to be converted. Please run:
The SAT weight format is different from Huggingface's format and requires conversion. Run:
```shell
```
python ../tools/convert_weight_sat2hf.py
```
### Exporting Huggingface Diffusers lora LoRA Weights from SAT Checkpoints
### Exporting Lora Weights from SAT to Huggingface Diffusers
After completing the training using the above steps, we get a SAT checkpoint with LoRA weights. You can find the file
at `{args.save}/1000/1000/mp_rank_00_model_states.pt`.
Support is provided for exporting Lora weights from SAT to Huggingface Diffusers format.
After training with the above steps, you'll find the SAT model with Lora weights in
`{args.save}/1000/1000/mp_rank_00_model_states.pt`.
The script for exporting LoRA weights can be found in the CogVideoX repository at `tools/export_sat_lora_weight.py`.
After exporting, you can use `load_cogvideox_lora.py` for inference.
The export script `export_sat_lora_weight.py` is located in the CogVideoX repository under `tools/`. After exporting,
use `load_cogvideox_lora.py` for inference.
Export command:
```bash
```
python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/
```
This training mainly modified the following model structures. The table below lists the corresponding structure mappings
for converting to the HF (Hugging Face) format LoRA structure. As you can see, LoRA adds a low-rank weight to the
model's attention structure.
The following model structures were modified during training. Here is the mapping between SAT and HF Lora structures.
Lora adds a low-rank weight to the attention structure of the model.
```
'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
@ -436,5 +451,5 @@ model's attention structure.
'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight'
```
Using export_sat_lora_weight.py, you can convert the SAT checkpoint into the HF LoRA format.
Using `export_sat_lora_weight.py` will convert these to the HF format Lora structure.
![alt text](../resources/hf_lora_weights.png)

View File

@ -1,27 +1,39 @@
# SAT CogVideoX-2B
# SAT CogVideoX
[Read this in English.](./README_zh)
[Read this in English.](./README.md)
[中文阅读](./README_zh.md)
このフォルダには、[SAT](https://github.com/THUDM/SwissArmyTransformer) ウェイトを使用した推論コードと、SAT
ウェイトのファインチューニングコードが含まれています。
このコードは、チームがモデルをトレーニングするために使用したフレームワークです。コメントが少なく、注意深く研究する必要があります。
このフォルダには、[SAT](https://github.com/THUDM/SwissArmyTransformer)の重みを使用した推論コードと、SAT重みのファインチューニングコードが含まれています。
`CogVideoX1.0`バージョンのモデルに関心がある場合は、[こちら](https://github.com/THUDM/CogVideo/releases/tag/v1.0)
のSATフォルダを参照してください。このブランチは`CogVideoX1.5`シリーズのモデルのみをサポートしています。
## 推論モデル
### 1. このフォルダに必要な依存関係が正しくインストールされていることを確認してください
### 1. このフォルダ内の必要な依存関係がすべてインストールされていることを確認してください
```shell
```
pip install -r requirements.txt
```
### 2. モデルウェイトをダウンロードします
### 2. モデルの重みをダウンロード
まず、SAT ミラーに移動してモデルの重みをダウンロードします。 CogVideoX-2B モデルの場合は、次のようにダウンロードしてください。
まず、SATミラーからモデルの重みをダウンロードしてください。
```shell
#### CogVideoX1.5 モデル
```
git lfs install
git clone https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT
```
これにより、Transformers、VAE、T5 Encoderの3つのモデルがダウンロードされます。
#### CogVideoX モデル
CogVideoX-2B モデルについては、以下のようにダウンロードしてください:
```
mkdir CogVideoX-2b-sat
cd CogVideoX-2b-sat
wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1
@ -32,12 +44,12 @@ mv 'index.html?dl=1' transformer.zip
unzip transformer.zip
```
CogVideoX-5B モデルの `transformers` ファイルを以下のリンクからダウンロードしてください VAE ファイルは 2B と同じです):
CogVideoX-5B モデルの `transformers` ファイルをダウンロードしてくださいVAEファイルは2Bと同じです
+ [CogVideoX-5B](https://cloud.tsinghua.edu.cn/d/fcef5b3904294a6885e5/?p=%2F&mode=list)
+ [CogVideoX-5B-I2V](https://cloud.tsinghua.edu.cn/d/5cc62a2d6e7d45c0a2f6/?p=%2F1&mode=list)
次に、モデルファイルを以下の形式にフォーマットする必要があります
モデルファイルを以下のように配置してください
```
.
@ -49,24 +61,24 @@ CogVideoX-5B モデルの `transformers` ファイルを以下のリンクから
└── 3d-vae.pt
```
モデルの重みファイルが大きいため、`git lfs`を使用することをお勧めいたします。`git lfs`
のインストールについては、[こちら](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing)を参照ください。
モデルの重みファイルが大きいため、`git lfs`の使用をお勧めします。
`git lfs`のインストール方法は[こちら](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing)を参照してください。
```shell
```
git lfs install
```
次に、T5 モデルをクローンします。これはトレーニングやファインチューニングには使用されませんが、使用する必要があります
> モデルを複製する際には、[Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b)のモデルファイルの場所もご使用いただけます。
次に、T5モデルをクローンします。このモデルはEncoderとしてのみ使用され、訓練やファインチューニングは必要ありません
> [Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b)上のモデルファイルも使用可能です。
```shell
git clone https://huggingface.co/THUDM/CogVideoX-2b.git #ハギングフェイス(huggingface.org)からモデルをダウンロードいただきます
# git clone https://www.modelscope.cn/ZhipuAI/CogVideoX-2b.git #Modelscopeからモデルをダウンロードいただきます
```
git clone https://huggingface.co/THUDM/CogVideoX-2b.git # Huggingfaceからモデルをダウンロード
# git clone https://www.modelscope.cn/ZhipuAI/CogVideoX-2b.git # Modelscopeからダウンロード
mkdir t5-v1_1-xxl
mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
```
上記の方法に従うことで、safetensor 形式の T5 ファイルを取得できます。これにより、Deepspeed でのファインチューニング中にエラーが発生しないようにします。
これにより、Deepspeedファインチューニング中にエラーなくロードできるsafetensor形式のT5ファイルが作成されます。
```
├── added_tokens.json
@ -81,11 +93,11 @@ mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
0 directories, 8 files
```
### 3. `configs/cogvideox_2b.yaml` ファイルを変更します。
### 3. `configs/cogvideox_*.yaml`ファイルを編集
```yaml
model:
scale_factor: 1.15258426
scale_factor: 1.55258426
disable_first_stage_autocast: true
log_keys:
- txt
@ -123,7 +135,7 @@ model:
num_attention_heads: 30
transformer_args:
checkpoint_activations: True ## グラデーション チェックポイントを使用する
checkpoint_activations: True ## using gradient checkpointing
vocab_size: 1
max_sequence_length: 64
layernorm_order: pre
@ -161,14 +173,14 @@ model:
ucg_rate: 0.1
target: sgm.modules.encoders.modules.FrozenT5Embedder
params:
model_dir: "t5-v1_1-xxl" # CogVideoX-2b/t5-v1_1-xxlフォルダの絶対パス
model_dir: "t5-v1_1-xxl" # CogVideoX-2b/t5-v1_1-xxl 重みフォルダの絶対パス
max_length: 226
first_stage_config:
target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
params:
cp_size: 1
ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # CogVideoX-2b-sat/vae/3d-vae.ptフォルダの絶対パス
ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # CogVideoX-2b-sat/vae/3d-vae.ptファイルの絶対パス
ignore_keys: [ 'loss' ]
loss_config:
@ -240,7 +252,7 @@ model:
num_steps: 50
```
### 4. `configs/inference.yaml` ファイルを変更します。
### 4. `configs/inference.yaml`ファイルを編集
```yaml
args:
@ -250,38 +262,38 @@ args:
# load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter
batch_size: 1
input_type: txt #TXTのテキストファイルを入力として選択されたり、CLIコマンドラインを入力として変更されたりいただけます
input_file: configs/test.txt #テキストファイルのパスで、これに対して編集がさせていただけます
sampling_num_frames: 13 # Must be 13, 11 or 9
input_type: txt # "txt"でプレーンテキスト入力、"cli"でコマンドライン入力を選択可能
input_file: configs/test.txt # プレーンテキストファイル、編集可能
sampling_num_frames: 13 # CogVideoX1.5-5Bでは42または22、CogVideoX-5B / 2Bでは13, 11, または9
sampling_fps: 8
fp16: True # For CogVideoX-2B
# bf16: True # For CogVideoX-5B
fp16: True # CogVideoX-2B
# bf16: True # CogVideoX-5B
output_dir: outputs/
force_inference: True
```
+ 複数のプロンプトを保存するために txt を使用する場合は`configs/test.txt`
参照して変更してください。1行に1つのプロンプトを記述します。プロンプトの書き方がわからない場合は、最初に [のコード](../inference/convert_demo.py)
を使用して LLM によるリファインメントを呼び出すことができます。
+ コマンドラインを入力として使用する場合は、次のように変更します。
+ 複数のプロンプトを含むテキストファイルを使用する場合`configs/test.txt`
適宜編集してください。1行につき1プロンプトです。プロンプトの書き方が分からない場合は、[こちらのコード](../inference/convert_demo.py)
を使用してLLMで補正できます。
+ コマンドライン入力を使用する場合、以下のように変更します:
```yaml
```
input_type: cli
```
これにより、コマンドラインからプロンプトを入力できます。
出力ビデオのディレクトリを変更したい場合は、次のように変更できます
出力ビデオの保存場所を変更する場合は、以下を編集してください
```yaml
```
output_dir: outputs/
```
デフォルトでは `.outputs/` フォルダに保存されます。
デフォルトでは`.outputs/`フォルダに保存されます。
### 5. 推論コードを実行して推論を開始します。
### 5. 推論コードを実行して推論を開始
```shell
```
bash inference.sh
```
@ -289,7 +301,7 @@ bash inference.sh
### データセットの準備
データセットの形式は次のようになります:
データセットは以下の構造である必要があります:
```
.
@ -303,123 +315,224 @@ bash inference.sh
├── ...
```
txt ファイルは対応するビデオファイルと同じ名前であり、そのビデオのラベルを含んでいます。各ビデオはラベルと一対一で対応する必要があります。通常、1つのビデオに複数のラベルを持たせることはありません
各txtファイルは対応するビデオファイルと同じ名前で、ビデオのラベルを含んでいます。ビデオとラベルは一対一で対応させる必要があります。通常、1つのビデオに複数のラベルを使用することは避けてください
スタイルファインチューニングの場合、少なくとも50本のスタイルが似たビデオとラベルを準備し、フィッティングを容易にします。
スタイルのファインチューニングの場合、スタイルが似たビデオとラベルを少なくとも50本準備し、フィッティングを促進します。
### 設定ファイルの変更
### 設定ファイルの編集
`Lora` とフルパラメータ微調整の2つの方法をサポートしています。両方の微調整方法は、`transformer` 部分のみを微調整し、`VAE`
部分には変更を加えないことに注意してください。`T5` はエンコーダーとしてのみ使用されます。以下のように `configs/sft.yaml` (
フルパラメータ微調整用) ファイルを変更してください。
``` `Lora`と全パラメータのファインチューニングの2種類をサポートしています。どちらも`transformer`部分のみをファインチューニングし、`VAE`部分は変更されず、`T5`はエンコーダーとしてのみ使用されます。
``` 以下のようにして`configs/sft.yaml`(全量ファインチューニング)ファイルを編集してください:
```
# checkpoint_activations: True ## 勾配チェックポイントを使用する場合 (設定ファイル内の2つの checkpoint_activations を True に設定する必要があります)
# checkpoint_activations: True ## using gradient checkpointing (configファイル内の2つの`checkpoint_activations`を両方Trueに設定)
model_parallel_size: 1 # モデル並列サイズ
experiment_name: lora-disney # 実験名 (変更しないでください)
mode: finetune # モード (変更しないでください)
load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformer モデルのパス
no_load_rng: True # 乱数シードを読み込むかどうか
experiment_name: lora-disney # 実験名(変更不要)
mode: finetune # モード(変更不要)
load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformerモデルのパス
no_load_rng: True # 乱数シードをロードするかどうか
train_iters: 1000 # トレーニングイテレーション数
eval_iters: 1 # 評価イテレーション数
eval_interval: 100 # 評価間隔
eval_batch_size: 1 # 評価バッチサイズ
eval_iters: 1 # 検証イテレーション数
eval_interval: 100 # 検証間隔
eval_batch_size: 1 # 検証バッチサイズ
save: ckpts # モデル保存パス
save_interval: 100 # モデル保存間隔
save_interval: 100 # 保存間隔
log_interval: 20 # ログ出力間隔
train_data: [ "your train data path" ]
valid_data: [ "your val data path" ] # トレーニングデータと評価データは同じでも構いません
split: 1,0,0 # トレーニングセット、評価セット、テストセットの割合
num_workers: 8 # データローダーのワーカースレッド
force_train: True # チェックポイントをロードするときに欠落したキーを許可 (T5 と VAE は別々にロードされます)
only_log_video_latents: True # VAE のデコードによるメモリオーバーヘッドを回避
valid_data: [ "your val data path" ] # トレーニングセットと検証セットは同じでも構いません
split: 1,0,0 # トレーニングセット、検証セット、テストセットの割合
num_workers: 8 # データローダーのワーカー数
force_train: True # チェックポイントをロードする際に`missing keys`を許可T5とVAEは別途ロード
only_log_video_latents: True # VAEのデコードによるメモリ使用量を抑える
deepspeed:
bf16:
enabled: False # CogVideoX-2B の場合は False に設定し、CogVideoX-5B の場合は True に設定
enabled: False # CogVideoX-2B 用は False、CogVideoX-5B 用は True に設定
fp16:
enabled: True # CogVideoX-2B の場合は True に設定し、CogVideoX-5B の場合は False に設定
enabled: True # CogVideoX-2B 用は True、CogVideoX-5B 用は False に設定
```
Lora 微調整を使用したい場合は、`cogvideox_<model_parameters>_lora` ファイルも変更する必要があります。
```yaml
args:
latent_channels: 16
mode: inference
load: "{absolute_path/to/your}/transformer" # Absolute path to CogVideoX-2b-sat/transformer folder
# load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter
ここでは、`CogVideoX-2B` を参考にします。
```
model:
scale_factor: 1.15258426
disable_first_stage_autocast: true
not_trainable_prefixes: [ 'all' ] ## コメントを解除
log_keys:
- txt'
lora_config: ## コメントを解除
target: sat.model.finetune.lora2.LoraMixin
params:
r: 256
batch_size: 1
input_type: txt # You can choose "txt" for plain text input or change to "cli" for command-line input
input_file: configs/test.txt # Plain text file, can be edited
sampling_num_frames: 13 # For CogVideoX1.5-5B it must be 42 or 22. For CogVideoX-5B / 2B, it must be 13, 11, or 9.
sampling_fps: 8
fp16: True # For CogVideoX-2B
# bf16: True # For CogVideoX-5B
output_dir: outputs/
force_inference: True
```
### 実行スクリプトの変更
設定ファイルを選択するために `finetune_single_gpu.sh` または `finetune_multi_gpus.sh` を編集します。以下に2つの例を示します。
1. `CogVideoX-2B` モデルを使用し、`Lora` 手法を利用する場合は、`finetune_single_gpu.sh` または `finetune_multi_gpus.sh`
を変更する必要があります。
+ If using a text file to save multiple prompts, modify `configs/test.txt` as needed. One prompt per line. If you are
unsure how to write prompts, use [this code](../inference/convert_demo.py) to call an LLM for refinement.
+ To use command-line input, modify:
```
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed $RANDOM"
input_type: cli
```
2. `CogVideoX-2B` モデルを使用し、`フルパラメータ微調整` 手法を利用する場合は、`finetune_single_gpu.sh`
または `finetune_multi_gpus.sh` を変更する必要があります。
This allows you to enter prompts from the command line.
To modify the output video location, change:
```
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b.yaml configs/sft.yaml --seed $RANDOM"
output_dir: outputs/
```
### 微調整と評価
The default location is the `.outputs/` folder.
推論コードを実行して微調整を開始します。
```
bash finetune_single_gpu.sh # シングルGPU
bash finetune_multi_gpus.sh # マルチGPU
```
### 微調整後のモデルの使用
微調整されたモデルは統合できません。ここでは、推論設定ファイル `inference.sh` を変更する方法を示します。
```
run_cmd="$environs python sample_video.py --base configs/cogvideox_<model_parameters>_lora.yaml configs/inference.yaml --seed 42"
```
その後、次のコードを実行します。
### 5. Run the Inference Code to Perform Inference
```
bash inference.sh
```
### Huggingface Diffusers サポートのウェイトに変換
## Fine-tuning the Model
SAT ウェイト形式は Huggingface のウェイト形式と異なり、変換が必要です。次のコマンドを実行してください:
### Preparing the Dataset
```shell
The dataset should be structured as follows:
```
.
├── labels
│ ├── 1.txt
│ ├── 2.txt
│ ├── ...
└── videos
├── 1.mp4
├── 2.mp4
├── ...
```
Each txt file should have the same name as the corresponding video file and contain the label for that video. The videos
and labels should correspond one-to-one. Generally, avoid using one video with multiple labels.
For style fine-tuning, prepare at least 50 videos and labels with a similar style to facilitate fitting.
### Modifying the Configuration File
We support two fine-tuning methods: `Lora` and full-parameter fine-tuning. Note that both methods only fine-tune the
`transformer` part. The `VAE` part is not modified, and `T5` is only used as an encoder.
Modify the files in `configs/sft.yaml` (full fine-tuning) as follows:
```yaml
# checkpoint_activations: True ## using gradient checkpointing (both `checkpoint_activations` in the config file need to be set to True)
model_parallel_size: 1 # Model parallel size
experiment_name: lora-disney # Experiment name (do not change)
mode: finetune # Mode (do not change)
load: "{your_CogVideoX-2b-sat_path}/transformer" ## Path to Transformer model
no_load_rng: True # Whether to load random number seed
train_iters: 1000 # Training iterations
eval_iters: 1 # Evaluation iterations
eval_interval: 100 # Evaluation interval
eval_batch_size: 1 # Evaluation batch size
save: ckpts # Model save path
save_interval: 100 # Save interval
log_interval: 20 # Log output interval
train_data: [ "your train data path" ]
valid_data: [ "your val data path" ] # Training and validation sets can be the same
split: 1,0,0 # Proportion for training, validation, and test sets
num_workers: 8 # Number of data loader workers
force_train: True # Allow missing keys when loading checkpoint (T5 and VAE loaded separately)
only_log_video_latents: True # Avoid memory usage from VAE decoding
deepspeed:
bf16:
enabled: False # For CogVideoX-2B Turn to False and For CogVideoX-5B Turn to True
fp16:
enabled: True # For CogVideoX-2B Turn to True and For CogVideoX-5B Turn to False
```
``` To use Lora fine-tuning, you also need to modify `cogvideox_<model parameters>_lora` file:
Here's an example using `CogVideoX-2B`:
```yaml
model:
scale_factor: 1.55258426
disable_first_stage_autocast: true
not_trainable_prefixes: [ 'all' ] ## Uncomment to unlock
log_keys:
- txt
lora_config: ## Uncomment to unlock
target: sat.model.finetune.lora2.LoraMixin
params:
r: 256
```
### Modify the Run Script
Edit `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` and select the config file. Below are two examples:
1. If you want to use the `CogVideoX-2B` model with `Lora`, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh`
as follows:
```
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed $RANDOM"
```
2. If you want to use the `CogVideoX-2B` model with full fine-tuning, modify `finetune_single_gpu.sh` or
`finetune_multi_gpus.sh` as follows:
```
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b.yaml configs/sft.yaml --seed $RANDOM"
```
### Fine-tuning and Validation
Run the inference code to start fine-tuning.
```
bash finetune_single_gpu.sh # Single GPU
bash finetune_multi_gpus.sh # Multi GPUs
```
### Using the Fine-tuned Model
The fine-tuned model cannot be merged. Here's how to modify the inference configuration file `inference.sh`:
```
run_cmd="$environs python sample_video.py --base configs/cogvideox_<model parameters>_lora.yaml configs/inference.yaml --seed 42"
```
Then, run the code:
```
bash inference.sh
```
### Converting to Huggingface Diffusers-compatible Weights
The SAT weight format is different from Huggingface's format and requires conversion. Run:
```
python ../tools/convert_weight_sat2hf.py
```
### SATチェックポイントからHuggingface Diffusers lora LoRAウェイトをエクスポート
### Exporting Lora Weights from SAT to Huggingface Diffusers
上記のステップを完了すると、LoRAウェイト付きのSATチェックポイントが得られます。ファイルは `{args.save}/1000/1000/mp_rank_00_model_states.pt` にあります。
Support is provided for exporting Lora weights from SAT to Huggingface Diffusers format.
After training with the above steps, you'll find the SAT model with Lora weights in
`{args.save}/1000/1000/mp_rank_00_model_states.pt`.
LoRAウェイトをエクスポートするためのスクリプトは、CogVideoXリポジトリの `tools/export_sat_lora_weight.py` にあります。エクスポート後、`load_cogvideox_lora.py` を使用して推論を行うことができます。
The export script `export_sat_lora_weight.py` is located in the CogVideoX repository under `tools/`. After exporting,
use `load_cogvideox_lora.py` for inference.
エクスポートコマンド:
Export command:
```bash
```
python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/
```
このトレーニングでは主に以下のモデル構造が変更されました。以下の表は、HF (Hugging Face) 形式のLoRA構造に変換する際の対応関係を示しています。ご覧の通り、LoRAはモデルの注意メカニズムに低ランクの重みを追加しています。
The following model structures were modified during training. Here is the mapping between SAT and HF Lora structures.
Lora adds a low-rank weight to the attention structure of the model.
```
'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
@ -432,7 +545,5 @@ python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_nam
'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight'
```
export_sat_lora_weight.py を使用して、SATチェックポイントをHF LoRA形式に変換できます。
Using `export_sat_lora_weight.py` will convert these to the HF format Lora structure.
![alt text](../resources/hf_lora_weights.png)

View File

@ -1,12 +1,11 @@
# SAT CogVideoX-2B
# SAT CogVideoX
[Read this in English.](./README_zh)
[Read this in English.](./README.md)
[日本語で読む](./README_ja.md)
本文件夹包含了使用 [SAT](https://github.com/THUDM/SwissArmyTransformer) 权重的推理代码,以及 SAT 权重的微调代码。
该代码是团队训练模型时使用的框架。注释较少,需要认真研究。
如果你关注 `CogVideoX1.0`版本的模型,请查看[这里](https://github.com/THUDM/CogVideo/releases/tag/v1.0)的SAT文件夹该分支仅支持`CogVideoX1.5`系列模型。
## 推理模型
@ -20,6 +19,15 @@ pip install -r requirements.txt
首先,前往 SAT 镜像下载模型权重。
#### CogVideoX1.5 模型
```shell
git lfs install
git clone https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT
```
此操作会下载 Transformers, VAE, T5 Encoder 这三个模型。
#### CogVideoX 模型
对于 CogVideoX-2B 模型,请按照如下方式下载:
```shell
@ -82,11 +90,11 @@ mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
0 directories, 8 files
```
### 3. 修改`configs/cogvideox_2b.yaml`中的文件。
### 3. 修改`configs/cogvideox_*.yaml`中的文件。
```yaml
model:
scale_factor: 1.15258426
scale_factor: 1.55258426
disable_first_stage_autocast: true
log_keys:
- txt
@ -253,7 +261,7 @@ args:
batch_size: 1
input_type: txt #可以选择txt纯文字档作为输入或者改成cli命令行作为输入
input_file: configs/test.txt #纯文字档,可以对此做编辑
sampling_num_frames: 13 # Must be 13, 11 or 9
sampling_num_frames: 13 #CogVideoX1.5-5B 必须是 42 或 22。 CogVideoX-5B / 2B 必须是 13 11 或 9。
sampling_fps: 8
fp16: True # For CogVideoX-2B
# bf16: True # For CogVideoX-5B
@ -346,7 +354,7 @@ Encoder 使用。
```yaml
model:
scale_factor: 1.15258426
scale_factor: 1.55258426
disable_first_stage_autocast: true
not_trainable_prefixes: [ 'all' ] ## 解除注释
log_keys:

View File

@ -18,7 +18,10 @@ def add_model_config_args(parser):
group = parser.add_argument_group("model", "model configuration")
group.add_argument("--base", type=str, nargs="*", help="config for input and saving")
group.add_argument(
"--model-parallel-size", type=int, default=1, help="size of the model parallel. only use if you are an expert."
"--model-parallel-size",
type=int,
default=1,
help="size of the model parallel. only use if you are an expert.",
)
group.add_argument("--force-pretrain", action="store_true")
group.add_argument("--device", type=int, default=-1)
@ -36,6 +39,7 @@ def add_sampling_config_args(parser):
group.add_argument("--input-dir", type=str, default=None)
group.add_argument("--input-type", type=str, default="cli")
group.add_argument("--input-file", type=str, default="input.txt")
group.add_argument("--sampling-image-size", type=list, default=[768, 1360])
group.add_argument("--final-size", type=int, default=2048)
group.add_argument("--sdedit", action="store_true")
group.add_argument("--grid-num-rows", type=int, default=1)
@ -73,10 +77,15 @@ def get_args(args_list=None, parser=None):
if not args.train_data:
print_rank0("No training data specified", level="WARNING")
assert (args.train_iters is None) or (args.epochs is None), "only one of train_iters and epochs should be set."
assert (args.train_iters is None) or (
args.epochs is None
), "only one of train_iters and epochs should be set."
if args.train_iters is None and args.epochs is None:
args.train_iters = 10000 # default 10k iters
print_rank0("No train_iters (recommended) or epochs specified, use default 10k iters.", level="WARNING")
print_rank0(
"No train_iters (recommended) or epochs specified, use default 10k iters.",
level="WARNING",
)
args.cuda = torch.cuda.is_available()
@ -212,7 +221,10 @@ def initialize_distributed(args):
args.master_port = os.getenv("MASTER_PORT", default_master_port)
init_method += args.master_ip + ":" + args.master_port
torch.distributed.init_process_group(
backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
backend=args.distributed_backend,
world_size=args.world_size,
rank=args.rank,
init_method=init_method,
)
# Set the model-parallel / data-parallel communicators.
@ -231,7 +243,10 @@ def initialize_distributed(args):
import deepspeed
deepspeed.init_distributed(
dist_backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
dist_backend=args.distributed_backend,
world_size=args.world_size,
rank=args.rank,
init_method=init_method,
)
# # It seems that it has no negative influence to configure it even without using checkpointing.
# deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers)
@ -261,7 +276,9 @@ def process_config_to_args(args):
args_config = config.pop("args", OmegaConf.create())
for key in args_config:
if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(args_config[key], omegaconf.ListConfig):
if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(
args_config[key], omegaconf.ListConfig
):
arg = OmegaConf.to_object(args_config[key])
else:
arg = args_config[key]

View File

@ -0,0 +1,149 @@
model:
scale_factor: 0.7
disable_first_stage_autocast: true
latent_input: true
log_keys:
- txt
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
quantize_c_noise: False
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
network_config:
target: dit_video_concat.DiffusionTransformer
params:
time_embed_dim: 512
elementwise_affine: True
num_frames: 81 # for 5 seconds and 161 for 10 seconds
time_compressed_rate: 4
latent_width: 300
latent_height: 300
num_layers: 42
patch_size: [2, 2, 2]
in_channels: 16
out_channels: 16
hidden_size: 3072
adm_in_channels: 256
num_attention_heads: 48
transformer_args:
checkpoint_activations: True
vocab_size: 1
max_sequence_length: 64
layernorm_order: pre
skip_init: false
model_parallel_size: 1
is_decoder: false
modules:
pos_embed_config:
target: dit_video_concat.Rotary3DPositionEmbeddingMixin
params:
hidden_size_head: 64
text_length: 224
patch_embed_config:
target: dit_video_concat.ImagePatchEmbeddingMixin
params:
text_hidden_size: 4096
adaln_layer_config:
target: dit_video_concat.AdaLNMixin
params:
qk_ln: True
final_layer_config:
target: dit_video_concat.FinalLayerMixin
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: false
input_key: txt
ucg_rate: 0.1
target: sgm.modules.encoders.modules.FrozenT5Embedder
params:
model_dir: "google/t5-v1_1-xxl"
max_length: 224
first_stage_config:
target : vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
params:
cp_size: 1
ckpt_path: "cogvideox-5b-sat/vae/3d-vae.pt"
ignore_keys: ['loss']
loss_config:
target: torch.nn.Identity
regularizer_config:
target: vae_modules.regularizers.DiagonalGaussianRegularizer
encoder_config:
target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
params:
double_z: true
z_channels: 16
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 2, 4]
attn_resolutions: []
num_res_blocks: 3
dropout: 0.0
gather_norm: True
decoder_config:
target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
params:
double_z: True
z_channels: 16
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 2, 4]
attn_resolutions: []
num_res_blocks: 3
dropout: 0.0
gather_norm: True
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
params:
offset_noise_level: 0
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
params:
uniform_sampling: True
group_num: 40
num_idx: 1000
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
sampler_config:
target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
params:
num_steps: 50
verbose: True
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
guider_config:
target: sgm.modules.diffusionmodules.guiders.DynamicCFG
params:
scale: 6
exp: 5
num_steps: 50

View File

@ -0,0 +1,159 @@
model:
scale_factor: 0.7
disable_first_stage_autocast: true
latent_input: false
noised_image_input: true
noised_image_all_concat: false
noised_image_dropout: 0.05
augmentation_dropout: 0.15
log_keys:
- txt
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
quantize_c_noise: False
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
network_config:
target: dit_video_concat.DiffusionTransformer
params:
ofs_embed_dim: 512
time_embed_dim: 512
elementwise_affine: True
num_frames: 81 # for 5 seconds and 161 for 10 seconds
time_compressed_rate: 4
latent_width: 300
latent_height: 300
num_layers: 42
patch_size: [2, 2, 2]
in_channels: 32
out_channels: 16
hidden_size: 3072
adm_in_channels: 256
num_attention_heads: 48
transformer_args:
checkpoint_activations: True
vocab_size: 1
max_sequence_length: 64
layernorm_order: pre
skip_init: false
model_parallel_size: 1
is_decoder: false
modules:
pos_embed_config:
target: dit_video_concat.Rotary3DPositionEmbeddingMixin
params:
hidden_size_head: 64
text_length: 224
patch_embed_config:
target: dit_video_concat.ImagePatchEmbeddingMixin
params:
text_hidden_size: 4096
adaln_layer_config:
target: dit_video_concat.AdaLNMixin
params:
qk_ln: True
final_layer_config:
target: dit_video_concat.FinalLayerMixin
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: false
input_key: txt
ucg_rate: 0.1
target: sgm.modules.encoders.modules.FrozenT5Embedder
params:
model_dir: "google/t5-v1_1-xxl"
max_length: 224
first_stage_config:
target : vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
params:
cp_size: 1
ckpt_path: "cogvideox-5b-i2v-sat/vae/3d-vae.pt"
ignore_keys: ['loss']
loss_config:
target: torch.nn.Identity
regularizer_config:
target: vae_modules.regularizers.DiagonalGaussianRegularizer
encoder_config:
target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
params:
double_z: true
z_channels: 16
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 2, 4]
attn_resolutions: []
num_res_blocks: 3
dropout: 0.0
gather_norm: True
decoder_config:
target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
params:
double_z: True
z_channels: 16
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 2, 4]
attn_resolutions: []
num_res_blocks: 3
dropout: 0.0
gather_norm: True
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
params:
fixed_frames: 0
offset_noise_level: 0.0
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
params:
uniform_sampling: True
group_num: 40
num_idx: 1000
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
sampler_config:
target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
params:
fixed_frames: 0
num_steps: 50
verbose: True
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
guider_config:
target: sgm.modules.diffusionmodules.guiders.DynamicCFG
params:
scale: 6
exp: 5
num_steps: 50

View File

@ -1,16 +1,14 @@
args:
image2video: False # True for image2video, False for text2video
# image2video: True # True for image2video, False for text2video
latent_channels: 16
mode: inference
load: "{your CogVideoX SAT folder}/transformer" # This is for Full model without lora adapter
# load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter
batch_size: 1
input_type: txt
input_file: configs/test.txt
sampling_image_size: [480, 720]
sampling_num_frames: 13 # Must be 13, 11 or 9
sampling_fps: 8
# fp16: True # For CogVideoX-2B
bf16: True # For CogVideoX-5B and CoGVideoX-5B-I2V
output_dir: outputs/
sampling_image_size: [768, 1360] # remove this for I2V
sampling_num_frames: 22 # 42 for 10 seconds and 22 for 5 seconds
sampling_fps: 16
bf16: True
output_dir: outputs
force_inference: True

View File

@ -1,4 +1,4 @@
In the haunting backdrop of a warIn the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.
The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from its tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds.
A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.
A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall.-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.
A street artist, clad in a worn-out denim jacket and a colorful banana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall.-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.

View File

@ -56,7 +56,9 @@ def read_video(
end_pts = float("inf")
if end_pts < start_pts:
raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")
raise ValueError(
f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
)
info = {}
audio_frames = []
@ -342,7 +344,11 @@ class VideoDataset(MetaDistributedWebDataset):
super().__init__(
path,
partial(
process_fn_video, num_frames=num_frames, image_size=image_size, fps=fps, skip_frms_num=skip_frms_num
process_fn_video,
num_frames=num_frames,
image_size=image_size,
fps=fps,
skip_frms_num=skip_frms_num,
),
seed,
meta_names=meta_names,
@ -400,7 +406,9 @@ class SFTDataset(Dataset):
indices = np.arange(start, end, (end - start) // num_frames).astype(int)
temp_frms = vr.get_batch(np.arange(start, end_safty))
assert temp_frms is not None
tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
tensor_frms = (
torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
)
tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
else:
if ori_vlen > self.max_num_frames:
@ -410,7 +418,11 @@ class SFTDataset(Dataset):
indices = np.arange(start, end, max((end - start) // num_frames, 1)).astype(int)
temp_frms = vr.get_batch(np.arange(start, end))
assert temp_frms is not None
tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
tensor_frms = (
torch.from_numpy(temp_frms)
if type(temp_frms) is not torch.Tensor
else temp_frms
)
tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
else:
@ -423,11 +435,17 @@ class SFTDataset(Dataset):
start = int(self.skip_frms_num)
end = int(ori_vlen - self.skip_frms_num)
num_frames = nearest_smaller_4k_plus_1(end - start) # 3D VAE requires the number of frames to be 4k+1
num_frames = nearest_smaller_4k_plus_1(
end - start
) # 3D VAE requires the number of frames to be 4k+1
end = int(start + num_frames)
temp_frms = vr.get_batch(np.arange(start, end))
assert temp_frms is not None
tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
tensor_frms = (
torch.from_numpy(temp_frms)
if type(temp_frms) is not torch.Tensor
else temp_frms
)
tensor_frms = pad_last_frame(
tensor_frms, self.max_num_frames

View File

@ -10,7 +10,6 @@ import torch
from torch import nn
from sgm.modules import UNCONDITIONAL_CONFIG
from sgm.modules.autoencoding.temporal_ae import VideoDecoder
from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from sgm.util import (
default,
@ -42,7 +41,9 @@ class SATVideoDiffusionEngine(nn.Module):
latent_input = model_config.get("latent_input", False)
disable_first_stage_autocast = model_config.get("disable_first_stage_autocast", False)
no_cond_log = model_config.get("disable_first_stage_autocast", False)
not_trainable_prefixes = model_config.get("not_trainable_prefixes", ["first_stage_model", "conditioner"])
not_trainable_prefixes = model_config.get(
"not_trainable_prefixes", ["first_stage_model", "conditioner"]
)
compile_model = model_config.get("compile_model", False)
en_and_decode_n_samples_a_time = model_config.get("en_and_decode_n_samples_a_time", None)
lr_scale = model_config.get("lr_scale", None)
@ -77,12 +78,18 @@ class SATVideoDiffusionEngine(nn.Module):
)
self.denoiser = instantiate_from_config(denoiser_config)
self.sampler = instantiate_from_config(sampler_config) if sampler_config is not None else None
self.conditioner = instantiate_from_config(default(conditioner_config, UNCONDITIONAL_CONFIG))
self.sampler = (
instantiate_from_config(sampler_config) if sampler_config is not None else None
)
self.conditioner = instantiate_from_config(
default(conditioner_config, UNCONDITIONAL_CONFIG)
)
self._init_first_stage(first_stage_config)
self.loss_fn = instantiate_from_config(loss_fn_config) if loss_fn_config is not None else None
self.loss_fn = (
instantiate_from_config(loss_fn_config) if loss_fn_config is not None else None
)
self.latent_input = latent_input
self.scale_factor = scale_factor
@ -90,8 +97,18 @@ class SATVideoDiffusionEngine(nn.Module):
self.no_cond_log = no_cond_log
self.device = args.device
# put lora add here
def disable_untrainable_params(self):
total_trainable = 0
if self.lora_train:
for n, p in self.named_parameters():
if p.requires_grad == False:
continue
if 'lora_layer' not in n:
p.lr_scale = 0
else:
total_trainable += p.numel()
else:
for n, p in self.named_parameters():
if p.requires_grad == False:
continue
@ -101,7 +118,7 @@ class SATVideoDiffusionEngine(nn.Module):
flag = True
break
lora_prefix = ["matrix_A", "matrix_B"]
lora_prefix = ['matrix_A', 'matrix_B']
for prefix in lora_prefix:
if prefix in n:
flag = False
@ -142,8 +159,12 @@ class SATVideoDiffusionEngine(nn.Module):
def shared_step(self, batch: Dict) -> Any:
x = self.get_input(batch)
if self.lr_scale is not None:
lr_x = F.interpolate(x, scale_factor=1 / self.lr_scale, mode="bilinear", align_corners=False)
lr_x = F.interpolate(lr_x, scale_factor=self.lr_scale, mode="bilinear", align_corners=False)
lr_x = F.interpolate(
x, scale_factor=1 / self.lr_scale, mode="bilinear", align_corners=False
)
lr_x = F.interpolate(
lr_x, scale_factor=self.lr_scale, mode="bilinear", align_corners=False
)
lr_z = self.encode_first_stage(lr_x, batch)
batch["lr_input"] = lr_z
@ -179,14 +200,31 @@ class SATVideoDiffusionEngine(nn.Module):
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
n_rounds = math.ceil(z.shape[0] / n_samples)
all_out = []
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
for n in range(n_rounds):
if isinstance(self.first_stage_model.decoder, VideoDecoder):
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
else:
kwargs = {}
out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], **kwargs)
all_out.append(out)
z_now = z[n * n_samples : (n + 1) * n_samples, :, 1:]
latent_time = z_now.shape[2] # check the time latent
fake_cp_size = min(10, latent_time // 2)
recons = []
start_frame = 0
for i in range(fake_cp_size):
end_frame = (
start_frame
+ latent_time // fake_cp_size
+ (1 if i < latent_time % fake_cp_size else 0)
)
use_cp = True if i == 0 else False
clear_fake_cp_cache = True if i == fake_cp_size - 1 else False
with torch.no_grad():
recon = self.first_stage_model.decode(
z_now[:, :, start_frame:end_frame].contiguous(),
clear_fake_cp_cache=clear_fake_cp_cache,
use_cp=use_cp,
)
recons.append(recon)
start_frame = end_frame
recons = torch.cat(recons, dim=2)
all_out.append(recons)
out = torch.cat(all_out, dim=0)
return out
@ -218,6 +256,7 @@ class SATVideoDiffusionEngine(nn.Module):
shape: Union[None, Tuple, List] = None,
prefix=None,
concat_images=None,
ofs=None,
**kwargs,
):
randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device)
@ -241,7 +280,9 @@ class SATVideoDiffusionEngine(nn.Module):
self.model, input, sigma, c, concat_images=concat_images, **addtional_model_inputs
)
samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb)
samples = self.sampler(
denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb, ofs=ofs
)
samples = samples.to(self.dtype)
return samples
@ -255,7 +296,9 @@ class SATVideoDiffusionEngine(nn.Module):
log = dict()
for embedder in self.conditioner.embedders:
if ((self.log_keys is None) or (embedder.input_key in self.log_keys)) and not self.no_cond_log:
if (
(self.log_keys is None) or (embedder.input_key in self.log_keys)
) and not self.no_cond_log:
x = batch[embedder.input_key][:n]
if isinstance(x, torch.Tensor):
if x.dim() == 1:
@ -331,7 +374,9 @@ class SATVideoDiffusionEngine(nn.Module):
image = torch.concat([image, torch.zeros_like(z[:, 1:])], dim=1)
c["concat"] = image
uc["concat"] = image
samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w
samples = self.sample(
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
) # b t c h w
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
if only_log_video_latents:
latents = 1.0 / self.scale_factor * samples
@ -341,7 +386,9 @@ class SATVideoDiffusionEngine(nn.Module):
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
log["samples"] = samples
else:
samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w
samples = self.sample(
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
) # b t c h w
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
if only_log_video_latents:
latents = 1.0 / self.scale_factor * samples

View File

@ -1,11 +1,12 @@
from functools import partial
from einops import rearrange, repeat
from functools import reduce
from operator import mul
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from sat.model.base_model import BaseModel, non_conflict
from sat.model.mixins import BaseMixin
from sat.transformer_defaults import HOOKS_DEFAULT, attention_fn_default
@ -13,38 +14,35 @@ from sat.mpu.layers import ColumnParallelLinear
from sgm.util import instantiate_from_config
from sgm.modules.diffusionmodules.openaimodel import Timestep
from sgm.modules.diffusionmodules.util import (
linear,
timestep_embedding,
)
from sgm.modules.diffusionmodules.util import linear, timestep_embedding
from sat.ops.layernorm import LayerNorm, RMSNorm
class ImagePatchEmbeddingMixin(BaseMixin):
def __init__(
self,
in_channels,
hidden_size,
patch_size,
bias=True,
text_hidden_size=None,
):
def __init__(self, in_channels, hidden_size, patch_size, text_hidden_size=None):
super().__init__()
self.proj = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=bias)
self.patch_size = patch_size
self.proj = nn.Linear(in_channels * reduce(mul, patch_size), hidden_size)
if text_hidden_size is not None:
self.text_proj = nn.Linear(text_hidden_size, hidden_size)
else:
self.text_proj = None
def word_embedding_forward(self, input_ids, **kwargs):
# now is 3d patch
images = kwargs["images"] # (b,t,c,h,w)
B, T = images.shape[:2]
emb = images.view(-1, *images.shape[2:])
emb = self.proj(emb) # ((b t),d,h/2,w/2)
emb = emb.view(B, T, *emb.shape[1:])
emb = emb.flatten(3).transpose(2, 3) # (b,t,n,d)
emb = rearrange(emb, "b t n d -> b (t n) d")
emb = rearrange(images, "b t c h w -> b (t h w) c")
# emb = rearrange(images, "b c t h w -> b (t h w) c")
emb = rearrange(
emb,
"b (t o h p w q) c -> b (t h w) (c o p q)",
t=kwargs["rope_T"],
h=kwargs["rope_H"],
w=kwargs["rope_W"],
o=self.patch_size[0],
p=self.patch_size[1],
q=self.patch_size[2],
)
emb = self.proj(emb)
if self.text_proj is not None:
text_emb = self.text_proj(kwargs["encoder_outputs"])
@ -74,7 +72,8 @@ def get_3d_sincos_pos_embed(
grid_size: int of the grid height and width
t_size: int of the temporal size
return:
pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
pos_embed: [t_size*grid_size * grid_size, embed_dim] or [1+t_size*grid_size * grid_size, embed_dim]
(w/ or w/o cls_token)
"""
assert embed_dim % 4 == 0
embed_dim_spatial = embed_dim // 4 * 3
@ -95,12 +94,13 @@ def get_3d_sincos_pos_embed(
# concate: [T, H, W] order
pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
pos_embed_temporal = np.repeat(pos_embed_temporal, grid_height * grid_width, axis=1) # [T, H*W, D // 4]
pos_embed_temporal = np.repeat(
pos_embed_temporal, grid_height * grid_width, axis=1
) # [T, H*W, D // 4]
pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0) # [T, H*W, D // 4 * 3]
pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
# pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D]
return pos_embed # [T, H*W, D]
@ -162,7 +162,8 @@ class Basic2DPositionEmbeddingMixin(BaseMixin):
self.width = width
self.spatial_length = height * width
self.pos_embedding = nn.Parameter(
torch.zeros(1, int(text_length + self.spatial_length), int(hidden_size)), requires_grad=False
torch.zeros(1, int(text_length + self.spatial_length), int(hidden_size)),
requires_grad=False,
)
def position_embedding_forward(self, position_ids, **kwargs):
@ -171,7 +172,9 @@ class Basic2DPositionEmbeddingMixin(BaseMixin):
def reinit(self, parent_model=None):
del self.transformer.position_embeddings
pos_embed = get_2d_sincos_pos_embed(self.pos_embedding.shape[-1], self.height, self.width)
self.pos_embedding.data[:, -self.spatial_length :].copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
self.pos_embedding.data[:, -self.spatial_length :].copy_(
torch.from_numpy(pos_embed).float().unsqueeze(0)
)
class Basic3DPositionEmbeddingMixin(BaseMixin):
@ -194,7 +197,8 @@ class Basic3DPositionEmbeddingMixin(BaseMixin):
self.spatial_length = height * width
self.num_patches = height * width * compressed_num_frames
self.pos_embedding = nn.Parameter(
torch.zeros(1, int(text_length + self.num_patches), int(hidden_size)), requires_grad=False
torch.zeros(1, int(text_length + self.num_patches), int(hidden_size)),
requires_grad=False,
)
self.height_interpolation = height_interpolation
self.width_interpolation = width_interpolation
@ -259,6 +263,9 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
text_length,
theta=10000,
rot_v=False,
height_interpolation=1.0,
width_interpolation=1.0,
time_interpolation=1.0,
learnable_pos_embed=False,
):
super().__init__()
@ -284,32 +291,38 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
freqs = rearrange(freqs, "t h w d -> (t h w) d")
freqs = broadcat(
(freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]),
dim=-1,
)
freqs = freqs.contiguous()
freqs_sin = freqs.sin()
freqs_cos = freqs.cos()
self.register_buffer("freqs_sin", freqs_sin)
self.register_buffer("freqs_cos", freqs_cos)
self.freqs_sin = freqs.sin().cuda()
self.freqs_cos = freqs.cos().cuda()
self.text_length = text_length
if learnable_pos_embed:
num_patches = height * width * compressed_num_frames + text_length
self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True)
self.pos_embedding = nn.Parameter(
torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True
)
else:
self.pos_embedding = None
def rotary(self, t, **kwargs):
seq_len = t.shape[2]
freqs_cos = self.freqs_cos[:seq_len].unsqueeze(0).unsqueeze(0)
freqs_sin = self.freqs_sin[:seq_len].unsqueeze(0).unsqueeze(0)
def reshape_freq(freqs):
freqs = freqs[: kwargs["rope_T"], : kwargs["rope_H"], : kwargs["rope_W"]].contiguous()
freqs = rearrange(freqs, "t h w d -> (t h w) d")
freqs = freqs.unsqueeze(0).unsqueeze(0)
return freqs
freqs_cos = reshape_freq(self.freqs_cos).to(t.dtype)
freqs_sin = reshape_freq(self.freqs_sin).to(t.dtype)
return t * freqs_cos + rotate_half(t) * freqs_sin
def position_embedding_forward(self, position_ids, **kwargs):
if self.pos_embedding is not None:
return self.pos_embedding[:, :self.text_length + kwargs["seq_length"]]
return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]
else:
return None
@ -326,10 +339,61 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
):
attention_fn_default = HOOKS_DEFAULT["attention_fn"]
query_layer[:, :, self.text_length :] = self.rotary(query_layer[:, :, self.text_length :])
key_layer[:, :, self.text_length :] = self.rotary(key_layer[:, :, self.text_length :])
query_layer = torch.cat(
(
query_layer[
:,
:,
: kwargs["text_length"],
],
self.rotary(
query_layer[
:,
:,
kwargs["text_length"] :,
],
**kwargs,
),
),
dim=2,
)
key_layer = torch.cat(
(
key_layer[
:,
:,
: kwargs["text_length"],
],
self.rotary(
key_layer[
:,
:,
kwargs["text_length"] :,
],
**kwargs,
),
),
dim=2,
)
if self.rot_v:
value_layer[:, :, self.text_length :] = self.rotary(value_layer[:, :, self.text_length :])
value_layer = torch.cat(
(
value_layer[
:,
:,
: kwargs["text_length"],
],
self.rotary(
value_layer[
:,
:,
kwargs["text_length"] :,
],
**kwargs,
),
),
dim=2,
)
return attention_fn_default(
query_layer,
@ -347,21 +411,25 @@ def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def unpatchify(x, c, p, w, h, rope_position_ids=None, **kwargs):
def unpatchify(x, c, patch_size, w, h, **kwargs):
"""
x: (N, T/2 * S, patch_size**3 * C)
imgs: (N, T, H, W, C)
patch_size 被拆解为三个不同的维度 (o, p, q)分别对应了深度o高度p和宽度q这使得 patch 大小在不同维度上可以不相等增加了灵活性
"""
if rope_position_ids is not None:
assert NotImplementedError
# do pix2struct unpatchify
L = x.shape[1]
x = x.reshape(shape=(x.shape[0], L, p, p, c))
x = torch.einsum("nlpqc->ncplq", x)
imgs = x.reshape(shape=(x.shape[0], c, p, L * p))
else:
b = x.shape[0]
imgs = rearrange(x, "b (t h w) (c p q) -> b t c (h p) (w q)", b=b, h=h, w=w, c=c, p=p, q=p)
imgs = rearrange(
x,
"b (t h w) (c o p q) -> b (t o) c (h p) (w q)",
c=c,
o=patch_size[0],
p=patch_size[1],
q=patch_size[2],
t=kwargs["rope_T"],
h=kwargs["rope_H"],
w=kwargs["rope_W"],
)
return imgs
@ -382,15 +450,16 @@ class FinalLayerMixin(BaseMixin):
self.patch_size = patch_size
self.out_channels = out_channels
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True))
self.spatial_length = latent_width * latent_height // patch_size**2
self.latent_width = latent_width
self.latent_height = latent_height
self.linear = nn.Linear(hidden_size, reduce(mul, patch_size) * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True)
)
def final_forward(self, logits, **kwargs):
x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"] # x:(b,(t n),d)
x, emb = (
logits[:, kwargs["text_length"] :, :],
kwargs["emb"],
) # x:(b,(t n),d),只取了x中后面images的部分
shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
@ -398,10 +467,9 @@ class FinalLayerMixin(BaseMixin):
return unpatchify(
x,
c=self.out_channels,
p=self.patch_size,
w=self.latent_width // self.patch_size,
h=self.latent_height // self.patch_size,
rope_position_ids=kwargs.get("rope_position_ids", None),
patch_size=self.patch_size,
w=kwargs["rope_W"],
h=kwargs["rope_H"],
**kwargs,
)
@ -440,8 +508,6 @@ class SwiGLUMixin(BaseMixin):
class AdaLNMixin(BaseMixin):
def __init__(
self,
width,
height,
hidden_size,
num_layers,
time_embed_dim,
@ -452,12 +518,13 @@ class AdaLNMixin(BaseMixin):
):
super().__init__()
self.num_layers = num_layers
self.width = width
self.height = height
self.compressed_num_frames = compressed_num_frames
self.adaLN_modulations = nn.ModuleList(
[nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size)) for _ in range(num_layers)]
[
nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size))
for _ in range(num_layers)
]
)
self.qk_ln = qk_ln
@ -517,7 +584,9 @@ class AdaLNMixin(BaseMixin):
img_attention_input = modulate(img_attention_input, shift_msa, scale_msa)
text_attention_input = modulate(text_attention_input, text_shift_msa, text_scale_msa)
attention_input = torch.cat((text_attention_input, img_attention_input), dim=1) # (b,n_t+t*n_i,d)
attention_input = torch.cat(
(text_attention_input, img_attention_input), dim=1
) # (b,n_t+t*n_i,d)
attention_output = layer.attention(attention_input, mask, **kwargs)
text_attention_output = attention_output[:, :text_length] # (b,n,d)
img_attention_output = attention_output[:, text_length:] # (b,(t n),d)
@ -541,9 +610,13 @@ class AdaLNMixin(BaseMixin):
img_mlp_output = layer.fourth_layernorm(img_mlp_output)
img_hidden_states = img_hidden_states + gate_mlp * img_mlp_output # vision (b,(t n),d)
text_hidden_states = text_hidden_states + text_gate_mlp * text_mlp_output # language (b,n,d)
text_hidden_states = (
text_hidden_states + text_gate_mlp * text_mlp_output
) # language (b,n,d)
hidden_states = torch.cat((text_hidden_states, img_hidden_states), dim=1) # (b,(n_t+t*n_i),d)
hidden_states = torch.cat(
(text_hidden_states, img_hidden_states), dim=1
) # (b,(n_t+t*n_i),d)
return hidden_states
def reinit(self, parent_model=None):
@ -611,7 +684,7 @@ class DiffusionTransformer(BaseModel):
time_interpolation=1.0,
use_SwiGLU=False,
use_RMSNorm=False,
zero_init_y_embed=False,
ofs_embed_dim=None,
**kwargs,
):
self.latent_width = latent_width
@ -619,12 +692,13 @@ class DiffusionTransformer(BaseModel):
self.patch_size = patch_size
self.num_frames = num_frames
self.time_compressed_rate = time_compressed_rate
self.spatial_length = latent_width * latent_height // patch_size**2
self.spatial_length = latent_width * latent_height // reduce(mul, patch_size[1:])
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_size = hidden_size
self.model_channels = hidden_size
self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size
self.ofs_embed_dim = ofs_embed_dim
self.num_classes = num_classes
self.adm_in_channels = adm_in_channels
self.input_time = input_time
@ -636,7 +710,6 @@ class DiffusionTransformer(BaseModel):
self.width_interpolation = width_interpolation
self.time_interpolation = time_interpolation
self.inner_hidden_size = hidden_size * 4
self.zero_init_y_embed = zero_init_y_embed
try:
self.dtype = str_to_dtype[kwargs.pop("dtype")]
except:
@ -651,7 +724,9 @@ class DiffusionTransformer(BaseModel):
if use_RMSNorm:
kwargs["layernorm"] = RMSNorm
else:
kwargs["layernorm"] = partial(LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6)
kwargs["layernorm"] = partial(
LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6
)
transformer_args.num_layers = num_layers
transformer_args.hidden_size = hidden_size
@ -664,12 +739,13 @@ class DiffusionTransformer(BaseModel):
if use_SwiGLU:
self.add_mixin(
"swiglu", SwiGLUMixin(num_layers, hidden_size, self.inner_hidden_size, bias=False), reinit=True
"swiglu",
SwiGLUMixin(num_layers, hidden_size, self.inner_hidden_size, bias=False),
reinit=True,
)
def _build_modules(self, module_configs):
model_channels = self.hidden_size
# time_embed_dim = model_channels * 4
time_embed_dim = self.time_embed_dim
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
@ -677,6 +753,13 @@ class DiffusionTransformer(BaseModel):
linear(time_embed_dim, time_embed_dim),
)
if self.ofs_embed_dim is not None:
self.ofs_embed = nn.Sequential(
linear(self.ofs_embed_dim, self.ofs_embed_dim),
nn.SiLU(),
linear(self.ofs_embed_dim, self.ofs_embed_dim),
)
if self.num_classes is not None:
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
@ -701,9 +784,6 @@ class DiffusionTransformer(BaseModel):
linear(time_embed_dim, time_embed_dim),
)
)
if self.zero_init_y_embed:
nn.init.constant_(self.label_emb[0][2].weight, 0)
nn.init.constant_(self.label_emb[0][2].bias, 0)
else:
raise ValueError()
@ -712,10 +792,13 @@ class DiffusionTransformer(BaseModel):
"pos_embed",
instantiate_from_config(
pos_embed_config,
height=self.latent_height // self.patch_size,
width=self.latent_width // self.patch_size,
height=self.latent_height // self.patch_size[1],
width=self.latent_width // self.patch_size[2],
compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
hidden_size=self.hidden_size,
height_interpolation=self.height_interpolation,
width_interpolation=self.width_interpolation,
time_interpolation=self.time_interpolation,
),
reinit=True,
)
@ -737,8 +820,6 @@ class DiffusionTransformer(BaseModel):
"adaln_layer",
instantiate_from_config(
adaln_layer_config,
height=self.latent_height // self.patch_size,
width=self.latent_width // self.patch_size,
hidden_size=self.hidden_size,
num_layers=self.num_layers,
compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
@ -749,7 +830,6 @@ class DiffusionTransformer(BaseModel):
)
else:
raise NotImplementedError
final_layer_config = module_configs["final_layer_config"]
self.add_mixin(
"final_layer",
@ -765,44 +845,55 @@ class DiffusionTransformer(BaseModel):
),
reinit=True,
)
if "lora_config" in module_configs:
lora_config = module_configs["lora_config"]
self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True)
self.add_mixin(
"lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True
)
return
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
b, t, d, h, w = x.shape
if x.dtype != self.dtype:
x = x.to(self.dtype)
# This is not use in inference
if "concat_images" in kwargs and kwargs["concat_images"] is not None:
if kwargs["concat_images"].shape[0] != x.shape[0]:
concat_images = kwargs["concat_images"].repeat(2, 1, 1, 1, 1)
else:
concat_images = kwargs["concat_images"]
x = torch.cat([x, concat_images], dim=2)
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype)
t_emb = timestep_embedding(
timesteps, self.model_channels, repeat_only=False, dtype=self.dtype
)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
# assert y.shape[0] == x.shape[0]
assert x.shape[0] % y.shape[0] == 0
y = y.repeat_interleave(x.shape[0] // y.shape[0], dim=0)
emb = emb + self.label_emb(y)
kwargs["seq_length"] = t * h * w // (self.patch_size**2)
if self.ofs_embed_dim is not None:
ofs_emb = timestep_embedding(
kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype
)
ofs_emb = self.ofs_embed(ofs_emb)
emb = emb + ofs_emb
kwargs["seq_length"] = t * h * w // reduce(mul, self.patch_size)
kwargs["images"] = x
kwargs["emb"] = emb
kwargs["encoder_outputs"] = context
kwargs["text_length"] = context.shape[1]
kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones((1, 1)).to(x.dtype)
kwargs["rope_T"] = t // self.patch_size[0]
kwargs["rope_H"] = h // self.patch_size[1]
kwargs["rope_W"] = w // self.patch_size[2]
kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones(
(1, 1)
).to(x.dtype)
output = super().forward(**kwargs)[0]
return output

View File

@ -4,7 +4,7 @@ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"
run_cmd="$environs python sample_video.py --base configs/cogvideox_5b.yaml configs/inference.yaml --seed $RANDOM"
run_cmd="$environs python sample_video.py --base configs/cogvideox1.5_5b.yaml configs/inference.yaml --seed $RANDOM"
echo ${run_cmd}
eval ${run_cmd}

View File

@ -1,16 +1,11 @@
SwissArmyTransformer==0.4.12
omegaconf==2.3.0
torch==2.4.0
torchvision==0.19.0
pytorch_lightning==2.3.3
kornia==0.7.3
beartype==0.18.5
numpy==2.0.1
fsspec==2024.5.0
safetensors==0.4.3
imageio-ffmpeg==0.5.1
imageio==2.34.2
scipy==1.14.0
decord==0.6.0
wandb==0.17.5
deepspeed==0.14.4
SwissArmyTransformer>=0.4.12
omegaconf>=2.3.0
pytorch_lightning>=2.4.0
kornia>=0.7.3
beartype>=0.19.0
fsspec>=2024.2.0
safetensors>=0.4.5
scipy>=1.14.1
decord>=0.6.0
wandb>=0.18.5
deepspeed>=0.15.3

View File

@ -4,23 +4,20 @@ import argparse
from typing import List, Union
from tqdm import tqdm
from omegaconf import ListConfig
from PIL import Image
import imageio
import torch
import numpy as np
from einops import rearrange
from einops import rearrange, repeat
import torchvision.transforms as TT
from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint
from sat import mpu
from diffusion_video import SATVideoDiffusionEngine
from arguments import get_args
from torchvision.transforms.functional import center_crop, resize
from torchvision.transforms import InterpolationMode
from PIL import Image
def read_from_cli():
@ -54,8 +51,60 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda
for key in keys:
if key == "txt":
batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
batch["txt"] = (
np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
)
batch_uc["txt"] = (
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
)
elif key == "original_size_as_tuple":
batch["original_size_as_tuple"] = (
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
.to(device)
.repeat(*N, 1)
)
elif key == "crop_coords_top_left":
batch["crop_coords_top_left"] = (
torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]])
.to(device)
.repeat(*N, 1)
)
elif key == "aesthetic_score":
batch["aesthetic_score"] = (
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
)
batch_uc["aesthetic_score"] = (
torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
)
elif key == "target_size_as_tuple":
batch["target_size_as_tuple"] = (
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
.to(device)
.repeat(*N, 1)
)
elif key == "fps":
batch[key] = torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
elif key == "fps_id":
batch[key] = torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
elif key == "motion_bucket_id":
batch[key] = (
torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N))
)
elif key == "pool_image":
batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(
device, dtype=torch.half
)
elif key == "cond_aug":
batch[key] = repeat(
torch.tensor([value_dict["cond_aug"]]).to("cuda"),
"1 -> b",
b=math.prod(N),
)
elif key == "cond_frames":
batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
elif key == "cond_frames_without_noise":
batch[key] = repeat(value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0])
else:
batch[key] = value_dict[key]
@ -68,7 +117,9 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda
return batch, batch_uc
def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None):
def save_video_as_grid_and_mp4(
video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None
):
os.makedirs(save_path, exist_ok=True)
for i, vid in enumerate(video_batch):
@ -83,37 +134,6 @@ def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: i
writer.append_data(frame)
def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
arr = resize(
arr,
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
interpolation=InterpolationMode.BICUBIC,
)
else:
arr = resize(
arr,
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
interpolation=InterpolationMode.BICUBIC,
)
h, w = arr.shape[2], arr.shape[3]
arr = arr.squeeze(0)
delta_h = h - image_size[0]
delta_w = w - image_size[1]
if reshape_mode == "random" or reshape_mode == "none":
top = np.random.randint(0, delta_h + 1)
left = np.random.randint(0, delta_w + 1)
elif reshape_mode == "center":
top, left = delta_h // 2, delta_w // 2
else:
raise NotImplementedError
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
return arr
def sampling_main(args, model_cls):
if isinstance(model_cls, type):
model = get_model(args, model_cls)
@ -127,40 +147,67 @@ def sampling_main(args, model_cls):
data_iter = read_from_cli()
elif args.input_type == "txt":
rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
print("rank and world_size", rank, world_size)
data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
else:
raise NotImplementedError
image_size = [480, 720]
if args.image2video:
chained_trainsforms = []
chained_trainsforms.append(TT.ToTensor())
transform = TT.Compose(chained_trainsforms)
sample_func = model.sample
T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
num_samples = [1]
force_uc_zero_embeddings = ["txt"]
device = model.device
T, C = args.sampling_num_frames, args.latent_channels
with torch.no_grad():
for text, cnt in tqdm(data_iter):
if args.image2video:
# use with input image shape
text, image_path = text.split("@@")
assert os.path.exists(image_path), image_path
image = Image.open(image_path).convert("RGB")
(img_W, img_H) = image.size
def nearest_multiple_of_16(n):
lower_multiple = (n // 16) * 16
upper_multiple = (n // 16 + 1) * 16
if abs(n - lower_multiple) < abs(n - upper_multiple):
return lower_multiple
else:
return upper_multiple
if img_H < img_W:
H = 96
W = int(nearest_multiple_of_16(img_W / img_H * H * 8)) // 8
else:
W = 96
H = int(nearest_multiple_of_16(img_H / img_W * W * 8)) // 8
chained_trainsforms = []
chained_trainsforms.append(
TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1)
)
chained_trainsforms.append(TT.ToTensor())
transform = TT.Compose(chained_trainsforms)
image = transform(image).unsqueeze(0).to("cuda")
image = resize_for_rectangle_crop(image, image_size, reshape_mode="center").unsqueeze(0)
image = image * 2.0 - 1.0
image = image.unsqueeze(2).to(torch.bfloat16)
image = model.encode_first_stage(image, None)
image = image / model.scale_factor
image = image.permute(0, 2, 1, 3, 4).contiguous()
pad_shape = (image.shape[0], T - 1, C, H // F, W // F)
image = torch.concat([image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1)
pad_shape = (image.shape[0], T - 1, C, H, W)
image = torch.concat(
[image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1
)
else:
image_size = args.sampling_image_size
H, W = image_size[0], image_size[1]
F = 8 # 8x downsampled
image = None
text_cast = [text]
mp_size = mpu.get_model_parallel_world_size()
global_rank = torch.distributed.get_rank() // mp_size
src = global_rank * mp_size
torch.distributed.broadcast_object_list(
text_cast, src=src, group=mpu.get_model_parallel_group()
)
text = text_cast[0]
value_dict = {
"prompt": text,
"negative_prompt": "",
@ -168,7 +215,9 @@ def sampling_main(args, model_cls):
}
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
num_samples,
)
for key in batch:
if isinstance(batch[key], torch.Tensor):
@ -187,54 +236,47 @@ def sampling_main(args, model_cls):
if not k == "crossattn":
c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))
if args.image2video and image is not None:
if args.image2video:
c["concat"] = image
uc["concat"] = image
for index in range(args.batch_size):
# reload model on GPU
model.to(device)
if args.image2video:
samples_z = sample_func(
c,
uc=uc,
batch_size=1,
shape=(T, C, H, W),
ofs=torch.tensor([2.0]).to("cuda"),
)
else:
samples_z = sample_func(
c,
uc=uc,
batch_size=1,
shape=(T, C, H // F, W // F),
)
).to("cuda")
samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
# Unload the model from GPU to save GPU memory
model.to("cpu")
torch.cuda.empty_cache()
first_stage_model = model.first_stage_model
first_stage_model = first_stage_model.to(device)
latent = 1.0 / model.scale_factor * samples_z
# Decode latent serial to save GPU memory
recons = []
loop_num = (T - 1) // 2
for i in range(loop_num):
if i == 0:
start_frame, end_frame = 0, 3
else:
start_frame, end_frame = i * 2 + 1, i * 2 + 3
if i == loop_num - 1:
clear_fake_cp_cache = True
else:
clear_fake_cp_cache = False
with torch.no_grad():
recon = first_stage_model.decode(
latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
)
recons.append(recon)
recon = torch.cat(recons, dim=2).to(torch.float32)
samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
if args.only_save_latents:
samples_z = 1.0 / model.scale_factor * samples_z
save_path = os.path.join(
args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
args.output_dir,
str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120],
str(index),
)
os.makedirs(save_path, exist_ok=True)
torch.save(samples_z, os.path.join(save_path, "latent.pt"))
with open(os.path.join(save_path, "text.txt"), "w") as f:
f.write(text)
else:
samples_x = model.decode_first_stage(samples_z).to(torch.float32)
samples_x = samples_x.permute(0, 2, 1, 3, 4).contiguous()
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
save_path = os.path.join(
args.output_dir,
str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120],
str(index),
)
if mpu.get_model_parallel_rank() == 0:
save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)

View File

@ -71,15 +71,24 @@ class LambdaWarmUpCosineScheduler2:
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
print(
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}"
)
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
cycle
] * n + self.f_start[cycle]
self.last_f = f
return f
else:
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
t = (n - self.lr_warm_up_steps[cycle]) / (
self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
)
t = min(t, 1.0)
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi))
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
1 + np.cos(t * np.pi)
)
self.last_f = f
return f
@ -93,10 +102,15 @@ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
print(
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}"
)
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
cycle
] * n + self.f_start[cycle]
self.last_f = f
return f
else:

View File

@ -218,14 +218,20 @@ class AutoencodingEngine(AbstractAutoencoder):
x = self.decoder(z, **kwargs)
return x
def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]:
def forward(
self, x: torch.Tensor, **additional_decode_kwargs
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
z, reg_log = self.encode(x, return_reg_log=True)
dec = self.decode(z, **additional_decode_kwargs)
return z, dec, reg_log
def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor:
def inner_training_step(
self, batch: dict, batch_idx: int, optimizer_idx: int = 0
) -> torch.Tensor:
x = self.get_input(batch)
additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
additional_decode_kwargs = {
key: batch[key] for key in self.additional_decode_keys.intersection(batch)
}
z, xrec, regularization_log = self(x, **additional_decode_kwargs)
if hasattr(self.loss, "forward_keys"):
extra_info = {
@ -361,12 +367,16 @@ class AutoencodingEngine(AbstractAutoencoder):
if self.trainable_ae_params is None:
ae_params = self.get_autoencoder_params()
else:
ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args)
ae_params, num_ae_params = self.get_param_groups(
self.trainable_ae_params, self.ae_optimizer_args
)
logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
if self.trainable_disc_params is None:
disc_params = self.get_discriminator_params()
else:
disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args)
disc_params, num_disc_params = self.get_param_groups(
self.trainable_disc_params, self.disc_optimizer_args
)
logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}")
opt_ae = self.instantiate_optimizer_from_config(
ae_params,
@ -375,17 +385,23 @@ class AutoencodingEngine(AbstractAutoencoder):
)
opts = [opt_ae]
if len(disc_params) > 0:
opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config)
opt_disc = self.instantiate_optimizer_from_config(
disc_params, self.learning_rate, self.optimizer_config
)
opts.append(opt_disc)
return opts
@torch.no_grad()
def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
def log_images(
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
) -> dict:
log = dict()
additional_decode_kwargs = {}
x = self.get_input(batch)
additional_decode_kwargs.update({key: batch[key] for key in self.additional_decode_keys.intersection(batch)})
additional_decode_kwargs.update(
{key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
)
_, xrec, _ = self(x, **additional_decode_kwargs)
log["inputs"] = x
@ -404,7 +420,9 @@ class AutoencodingEngine(AbstractAutoencoder):
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
diff_ema.clamp_(0, 1.0)
log["diff_ema"] = 2.0 * diff_ema - 1.0
log["diff_boost_ema"] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
log["diff_boost_ema"] = (
2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
)
if additional_log_kwargs:
additional_decode_kwargs.update(additional_log_kwargs)
_, xrec_add, _ = self(x, **additional_decode_kwargs)
@ -446,7 +464,9 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
params = super().get_autoencoder_params()
return params
def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
def encode(
self, x: torch.Tensor, return_reg_log: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
if self.max_batch_size is None:
z = self.encoder(x)
z = self.quant_conv(z)
@ -513,7 +533,9 @@ class VideoAutoencodingEngine(AutoencodingEngine):
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
def log_videos(
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
) -> dict:
return self.log_images(batch, additional_log_kwargs, **kwargs)
def get_input(self, batch: dict) -> torch.Tensor:
@ -524,7 +546,9 @@ class VideoAutoencodingEngine(AutoencodingEngine):
batch = batch[self.input_key]
global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size
torch.distributed.broadcast(batch, src=global_src_rank, group=get_context_parallel_group())
torch.distributed.broadcast(
batch, src=global_src_rank, group=get_context_parallel_group()
)
batch = _conv_split(batch, dim=2, kernel_size=1)
return batch

View File

@ -94,7 +94,11 @@ class FeedForward(nn.Module):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
@ -126,7 +130,9 @@ class LinearAttention(nn.Module):
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
q, k, v = rearrange(
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
)
k = k.softmax(dim=-1)
context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q)
@ -143,7 +149,9 @@ class SpatialSelfAttention(nn.Module):
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
@ -244,7 +252,9 @@ class CrossAttention(nn.Module):
# new
with sdp_kernel(**BACKEND_MAP[self.backend]):
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # scale is dim_head ** -0.5 per default
out = F.scaled_dot_product_attention(
q, k, v, attn_mask=mask
) # scale is dim_head ** -0.5 per default
del q, k, v
out = rearrange(out, "b h n d -> b n (h d)", h=h)
@ -422,7 +432,9 @@ class BasicTransformerBlock(nn.Module):
self.norm1(x),
context=context if self.disable_self_attn else None,
additional_tokens=additional_tokens,
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0,
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
if not self.disable_self_attn
else 0,
)
+ x
)
@ -499,7 +511,9 @@ class SpatialTransformer(nn.Module):
sdp_backend=None,
):
super().__init__()
print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads")
print(
f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
)
from omegaconf import ListConfig
if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
@ -542,7 +556,9 @@ class SpatialTransformer(nn.Module):
]
)
if not use_linear:
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
self.proj_out = zero_module(
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
)
else:
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))

View File

@ -87,7 +87,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
yield from ()
@torch.no_grad()
def log_images(self, inputs: torch.Tensor, reconstructions: torch.Tensor) -> Dict[str, torch.Tensor]:
def log_images(
self, inputs: torch.Tensor, reconstructions: torch.Tensor
) -> Dict[str, torch.Tensor]:
# calc logits of real/fake
logits_real = self.discriminator(inputs.contiguous().detach())
if len(logits_real.shape) < 4:
@ -209,7 +211,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
weights: Union[None, float, torch.Tensor] = None,
) -> Tuple[torch.Tensor, dict]:
if self.scale_input_to_tgt_size:
inputs = torch.nn.functional.interpolate(inputs, reconstructions.shape[2:], mode="bicubic", antialias=True)
inputs = torch.nn.functional.interpolate(
inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
)
if self.dims > 2:
inputs, reconstructions = map(
@ -226,7 +230,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
input_frames = pick_video_frame(inputs, frame_indices)
recon_frames = pick_video_frame(reconstructions, frame_indices)
p_loss = self.perceptual_loss(input_frames.contiguous(), recon_frames.contiguous()).mean()
p_loss = self.perceptual_loss(
input_frames.contiguous(), recon_frames.contiguous()
).mean()
rec_loss = rec_loss + self.perceptual_weight * p_loss
nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
@ -238,7 +244,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
logits_fake = self.discriminator(reconstructions.contiguous())
g_loss = -torch.mean(logits_fake)
if self.training:
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
else:
d_weight = torch.tensor(1.0)
else:

View File

@ -37,12 +37,18 @@ class LatentLPIPS(nn.Module):
if self.perceptual_weight > 0.0:
image_reconstructions = self.decoder.decode(latent_predictions)
image_targets = self.decoder.decode(latent_inputs)
perceptual_loss = self.perceptual_loss(image_targets.contiguous(), image_reconstructions.contiguous())
loss = self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean()
perceptual_loss = self.perceptual_loss(
image_targets.contiguous(), image_reconstructions.contiguous()
)
loss = (
self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean()
)
log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
if self.perceptual_weight_on_inputs > 0.0:
image_reconstructions = default(image_reconstructions, self.decoder.decode(latent_predictions))
image_reconstructions = default(
image_reconstructions, self.decoder.decode(latent_predictions)
)
if self.scale_input_to_tgt_size:
image_inputs = torch.nn.functional.interpolate(
image_inputs,
@ -58,7 +64,9 @@ class LatentLPIPS(nn.Module):
antialias=True,
)
perceptual_loss2 = self.perceptual_loss(image_inputs.contiguous(), image_reconstructions.contiguous())
perceptual_loss2 = self.perceptual_loss(
image_inputs.contiguous(), image_reconstructions.contiguous()
)
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
return loss, log

Some files were not shown because too many files have changed in this diff Show More