13 Commits

Author SHA1 Message Date
OleehyO
30ba1085ff Merge remote-tracking branch 'upstream/main' into dev 2025-01-12 05:58:07 +00:00
OleehyO
fdb9820949 feat: support DeepSpeed ZeRO-3 and optimize peak memory usage
- Add DeepSpeed ZeRO-3 configuration support
- Optimize memory usage during training
- Rename training scripts to reflect ZeRO usage
- Update related configuration files and trainers
2025-01-12 05:33:56 +00:00
Zheng Guang Cong
09a49d3546
fix bug of i2v; video is already 0-255
video is already 0-255 and should not be multiplied 255 any more
2025-01-11 17:29:27 +08:00
OleehyO
caa24bdc36 feat: add SFT support with ZeRO optimization strategies
- Add SFT (Supervised Fine-Tuning) trainers for all model variants:
  - CogVideoX I2V and T2V
  - CogVideoX-1.5 I2V and T2V
- Add DeepSpeed ZeRO configuration files:
  - ZeRO-2 with and without CPU offload
  - ZeRO-3 with and without CPU offload
- Add base accelerate config for distributed training
- Update trainer.py to support SFT training mode

This enables full-parameter fine-tuning with memory-efficient distributed training using DeepSpeed ZeRO optimization.
2025-01-11 02:13:32 +00:00
OleehyO
10de04fc08 perf: cast VAE and text encoder to target dtype before precomputing cache
Before precomputing the latent cache and text embeddings, cast the VAE and
text encoder to the target training dtype (fp16/bf16) instead of keeping them
in fp32. This reduces memory usage during the precomputation phase.

The change occurs in prepare_dataset() where the models are moved to device
and cast to weight_dtype before being used to generate the cache.
2025-01-08 01:38:13 +00:00
OleehyO
36427274d6 style: format import statements across finetune module 2025-01-07 05:54:52 +00:00
zR
1789f07256 format and check fp16 for cogvideox2b 2025-01-07 13:16:18 +08:00
OleehyO
49dc370de6 fix: remove pipeline hooks after validation
- Add pipe.remove_all_hooks() after validation to prevent memory leaks
- Clean up validation pipeline properly to avoid potential issues in subsequent training steps
2025-01-04 06:21:17 +00:00
OleehyO
e5b8f9a2ee feat: add caching for prompt embeddings
- Add caching for prompt embeddings
- Store cached files using safetensors format
- Add cache directory structure under data_root/cache
- Optimize memory usage by moving tensors to CPU after caching
- Add debug logging for cache hits
- Add info logging for cache writes

The caching system helps reduce redundant computation and memory usage during training by:
1. Caching prompt embeddings based on prompt text hash
2. Caching encoded video latents based on video filename
3. Moving tensors to CPU after caching to free GPU memory
2025-01-04 06:16:31 +00:00
OleehyO
a001842834 feat: implement CogVideoX trainers for I2V and T2V tasks
Add and refactor trainers for CogVideoX model variants:
- Implement CogVideoXT2VLoraTrainer for text-to-video generation
- Refactor CogVideoXI2VLoraTrainer for image-to-video generation

Both trainers support LoRA fine-tuning with proper handling of:
- Model components loading and initialization
- Video encoding and batch collation
- Loss computation with noise prediction
- Validation step for generation
2025-01-01 15:10:54 +00:00
OleehyO
45d40450a1 refactor: simplify dataset implementation and add latent precomputation
- Replace bucket-based dataset with simpler resize-based implementation
- Add video latent precomputation during dataset initialization
- Improve code readability and user experience
- Remove complexity of bucket sampling for better maintainability

This change makes the codebase more straightforward and easier to use while
maintaining functionality through resize-based video processing.
2025-01-01 15:10:54 +00:00
OleehyO
fa4659fb2c feat(trainer): add validation functionality to Trainer class
Add validation capabilities to the Trainer class including:
- Support for validating images and videos during training
- Periodic validation based on validation_steps parameter
- Artifact logging to wandb for validation results
- Memory tracking during validation process
2025-01-01 15:10:41 +00:00
OleehyO
60f6a3d7ee feat: add base trainer implementation and training script
- Add Trainer base class with core training loop functionality
- Implement distributed training setup with Accelerate
- Add training script with model/trainer initialization
- Support LoRA fine-tuning with checkpointing and validation
2025-01-01 15:10:41 +00:00