Mirror of https://github.com/THUDM/CogVideo.git, synced 2025-06-28 21:29:21 +08:00
- Add SFT (Supervised Fine-Tuning) trainers for all model variants:
  - CogVideoX I2V and T2V
  - CogVideoX-1.5 I2V and T2V
- Add DeepSpeed ZeRO configuration files:
  - ZeRO-2 with and without CPU offload
  - ZeRO-3 with and without CPU offload
- Add base accelerate config for distributed training
- Update trainer.py to support SFT training mode

This enables full-parameter fine-tuning with memory-efficient distributed training using DeepSpeed ZeRO optimization.
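The four ZeRO variants named in the commit can also be expressed programmatically. Below is a minimal, hypothetical sketch using accelerate's DeepSpeedPlugin to reproduce one combination (ZeRO-2 with optimizer CPU offload); the YAML config files the commit actually adds may set different fields, and the values here are illustrative only.

from accelerate import Accelerator, DeepSpeedPlugin

# ZeRO-2 with optimizer-state CPU offload; drop the offload field for the
# no-offload variant, and use zero_stage=3 for the ZeRO-3 variants.
ds_plugin = DeepSpeedPlugin(
    zero_stage=2,
    offload_optimizer_device="cpu",
    gradient_clipping=1.0,  # illustrative value, not taken from the commit
)
accelerator = Accelerator(deepspeed_plugin=ds_plugin)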
28 lines · 971 B · Python
import dotmap
from diffusers import CogVideoXPipeline
from typing_extensions import override

from finetune.utils import unwrap_model

from ..cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer
from ..utils import register


class CogVideoXT2VSftTrainer(CogVideoXT2VLoraTrainer):
    @override
    def initialize_pipeline(self) -> CogVideoXPipeline:
        # The transformer may be wrapped by accelerate/DeepSpeed during
        # training; unwrap it to reach the underlying model and its
        # up-to-date config.
        origin_model = unwrap_model(self.accelerator, self.components.transformer)
        self.components.transformer.config.update(origin_model.config)
        # Wrap the config in a DotMap so attribute-style access keeps
        # working after the dict update above.
        self.components.transformer.config = dotmap.DotMap(self.components.transformer.config)
        pipe = CogVideoXPipeline(
            tokenizer=self.components.tokenizer,
            text_encoder=self.components.text_encoder,
            vae=self.components.vae,
            transformer=self.components.transformer,
            scheduler=self.components.scheduler,
        )
        return pipe


register("cogvideox-t2v", "sft", CogVideoXT2VSftTrainer)
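The register call at the bottom maps the ("cogvideox-t2v", "sft") pair to this trainer class so the training entry point can look the trainer up by name. A minimal sketch of how such a registry plausibly works follows; the real register imported from ..utils may be implemented differently, and get_model_cls is an assumed lookup name used here only for illustration.

# Hypothetical registry sketch; names and internal structure are assumptions.
SUPPORTED_TRAINERS: dict[str, dict[str, type]] = {}

def register(model_name: str, training_type: str, trainer_cls: type) -> None:
    # Key trainers by (model name, training type) so new variants plug in
    # without hard-coded imports at the call site.
    SUPPORTED_TRAINERS.setdefault(model_name, {})[training_type] = trainer_cls

def get_model_cls(model_name: str, training_type: str) -> type:
    return SUPPORTED_TRAINERS[model_name][training_type]

# Usage: the CLI would resolve the trainer registered above, e.g.
# trainer_cls = get_model_cls("cogvideox-t2v", "sft")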