Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-06-28 13:25:21 +08:00)
- Add SFT (Supervised Fine-Tuning) trainers for all model variants:
  - CogVideoX I2V and T2V
  - CogVideoX-1.5 I2V and T2V
- Add DeepSpeed ZeRO configuration files:
  - ZeRO-2 with and without CPU offload
  - ZeRO-3 with and without CPU offload
- Add base accelerate config for distributed training
- Update trainer.py to support SFT training mode

This enables full-parameter fine-tuning with memory-efficient distributed training using DeepSpeed ZeRO optimization.
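To make the ZeRO variants named above concrete, here is a minimal sketch of a ZeRO-2 setup with optimizer CPU offload, expressed through Hugging Face Accelerate's DeepSpeedPlugin rather than the YAML/JSON config files this commit actually adds; every field value below is an illustrative assumption, not the shipped configuration.

from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

# Illustrative values only; the repo ships equivalent settings as config files.
ds_plugin = DeepSpeedPlugin(
    zero_stage=2,                      # ZeRO-2: shard optimizer state and gradients
    offload_optimizer_device="cpu",    # the "with CPU offload" variant
    gradient_accumulation_steps=1,
    gradient_clipping=1.0,
)

accelerator = Accelerator(mixed_precision="bf16", deepspeed_plugin=ds_plugin)

Dropping offload_optimizer_device gives the no-offload variant, and setting zero_stage=3 (optionally with offload_param_device="cpu") corresponds to the ZeRO-3 configurations listed in the commit message.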
import dotmap
from diffusers import CogVideoXImageToVideoPipeline
from typing_extensions import override

from finetune.utils import unwrap_model

from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
from ..utils import register


class CogVideoXI2VSftTrainer(CogVideoXI2VLoraTrainer):
    @override
    def initialize_pipeline(self) -> CogVideoXImageToVideoPipeline:
        # Unwrap the accelerate/DeepSpeed-wrapped transformer to reach the
        # underlying model and its config.
        origin_model = unwrap_model(self.accelerator, self.components.transformer)

        # Merge the unwrapped model's config into the wrapped module's config,
        # then rewrap it as a DotMap so the pipeline can read it with
        # attribute-style access (e.g. `transformer.config.in_channels`).
        self.components.transformer.config.update(origin_model.config)
        self.components.transformer.config = dotmap.DotMap(self.components.transformer.config)

        # Reuse the trainer's already-loaded components to assemble a full
        # image-to-video pipeline for validation/inference during SFT.
        pipe = CogVideoXImageToVideoPipeline(
            tokenizer=self.components.tokenizer,
            text_encoder=self.components.text_encoder,
            vae=self.components.vae,
            transformer=self.components.transformer,
            scheduler=self.components.scheduler,
        )
        return pipe


# Expose this trainer in the model registry under ("cogvideox-i2v", "sft").
register("cogvideox-i2v", "sft", CogVideoXI2VSftTrainer)
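For orientation, here is a hypothetical launch-side sketch of how the registration above might be consumed; the lookup helper's name (get_model_cls) and the trainer's entry-point method (fit) are my assumptions, not confirmed API.

# Hypothetical usage sketch: get_model_cls and fit() are assumed names.
from finetune.models.utils import get_model_cls

trainer_cls = get_model_cls("cogvideox-i2v", "sft")  # -> CogVideoXI2VSftTrainer
trainer = trainer_cls(args)  # args: the parsed fine-tuning arguments
trainer.fit()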