feat: add SFT support with ZeRO optimization strategies

- Add SFT (Supervised Fine-Tuning) trainers for all model variants:
  - CogVideoX I2V and T2V
  - CogVideoX-1.5 I2V and T2V
- Add DeepSpeed ZeRO configuration files:
  - ZeRO-2 with and without CPU offload
  - ZeRO-3 with and without CPU offload
- Add base accelerate config for distributed training
- Update trainer.py to support SFT training mode

This enables full-parameter fine-tuning with memory-efficient distributed training using DeepSpeed ZeRO optimization.
Author: OleehyO
Date: 2025-01-11 02:13:32 +00:00
Parent: e213b6c083
Commit: caa24bdc36
11 changed files with 344 additions and 59 deletions

View File

@@ -0,0 +1,27 @@
compute_environment: LOCAL_MACHINE
# gpu_ids: "0" # 0,1,2,3,4,5,6,7
# num_processes: 1
gpu_ids: all
num_processes: 8
debug: false
deepspeed_config:
deepspeed_config_file: /path/to/your/configs/zero2.yaml
# deepspeed_config_file: /path/to/your/configs/zero2_offload.yaml
# deepspeed_config_file: /path/to/your/configs/zero3.yaml
# deepspeed_config_file: /path/to/your/configs/zero3_offload.yaml
zero3_init_flag: false
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
num_machines: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
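
For reference, the same setup can be built programmatically rather than through this accelerate config file. The sketch below is illustrative only and not part of this commit: the config path is the same placeholder used above, and a multi-GPU run would still be started with `accelerate launch` or `torchrun`.

# Illustrative sketch: programmatic equivalent of the accelerate config above.
# The config path is a placeholder; multi-GPU runs still need a distributed launcher.
from accelerate import Accelerator, DeepSpeedPlugin

ds_plugin = DeepSpeedPlugin(
    hf_ds_config="/path/to/your/configs/zero2.yaml",  # same file referenced above
    zero3_init_flag=False,
)
accelerator = Accelerator(deepspeed_plugin=ds_plugin)
# The model, optimizer, dataloader and scheduler are then wrapped via accelerator.prepare(...).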

View File

@@ -0,0 +1,38 @@
{
"bf16": {
"enabled": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"weight_decay": "auto",
"torch_adam": true,
"adam_w_mode": true
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients": true
},
"gradient_accumulation_steps": 1,
"train_micro_batch_size_per_gpu": 1,
"train_batch_size": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"wall_clock_breakdown": false
}
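
Because the optimizer and scheduler blocks above use "auto" values, the training script has to hand the real hyper-parameters to DeepSpeed when the accelerator prepares the model. With HF Accelerate this is commonly done through the DummyOptim/DummyScheduler pair; the snippet below is a hedged sketch with a placeholder model and placeholder values, not code from this commit, and assumes the script is started via `accelerate launch` with a config that points at this file.

# Hedged sketch: resolving the "auto" optimizer/scheduler fields at prepare() time.
import torch
from accelerate import Accelerator
from accelerate.utils import DummyOptim, DummyScheduler

accelerator = Accelerator()
model = torch.nn.Linear(8, 8)  # stand-in for the CogVideoX transformer
# DummyOptim/DummyScheduler defer optimizer construction to DeepSpeed, which builds
# the AdamW/WarmupDecayLR declared above and fills the "auto" entries from these args.
optimizer = DummyOptim(model.parameters(), lr=2e-5, weight_decay=1e-2)
lr_scheduler = DummyScheduler(optimizer, total_num_steps=10_000, warmup_num_steps=100)
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)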

View File

@@ -0,0 +1,42 @@
{
"bf16": {
"enabled": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"weight_decay": "auto",
"torch_adam": true,
"adam_w_mode": true
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients": true,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
}
},
"gradient_accumulation_steps": 1,
"train_micro_batch_size_per_gpu": 1,
"train_batch_size": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"wall_clock_breakdown": false
}

View File

@@ -0,0 +1,43 @@
{
"bf16": {
"enabled": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"weight_decay": "auto",
"torch_adam": true,
"adam_w_mode": true
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 3,
"overlap_comm": true,
"contiguous_gradients": true,
"reduce_bucket_size": 5e8,
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"sub_group_size": 1e9,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": "auto",
"stage3_prefetch_bucket_size": 5e8,
"stage3_param_persistence_threshold": 1e5
},
"gradient_accumulation_steps": 1,
"train_micro_batch_size_per_gpu": 1,
"train_batch_size": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"wall_clock_breakdown": false
}
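
Because ZeRO-3 shards parameters across ranks, saving a usable checkpoint requires gathering the shards first, which is what "stage3_gather_16bit_weights_on_model_save" enables. The snippet below is a hedged sketch of one common pattern with HF Accelerate; the `accelerator`, `transformer` and output path are placeholders from the surrounding training script, not code in this commit.

# Hedged sketch: consolidating sharded ZeRO-3 weights into a full state dict on save.
import torch

state_dict = accelerator.get_state_dict(transformer)  # gathers the 16-bit shards across ranks
if accelerator.is_main_process:
    torch.save(state_dict, "checkpoint/transformer_full.pt")
accelerator.wait_for_everyone()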

View File

@@ -0,0 +1,51 @@
{
"bf16": {
"enabled": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"weight_decay": "auto",
"torch_adam": true,
"adam_w_mode": true
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"reduce_bucket_size": 5e8,
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"sub_group_size": 1e9,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": "auto",
"stage3_prefetch_bucket_size": 5e8,
"stage3_param_persistence_threshold": 1e6
},
"gradient_accumulation_steps": 1,
"train_micro_batch_size_per_gpu": 1,
"train_batch_size": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"wall_clock_breakdown": false
}
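
The offload variants trade GPU memory for host RAM and PCIe traffic, so the choice between zero3.yaml and zero3_offload.yaml is mostly a sizing question. DeepSpeed ships memory estimators that can answer it before training starts; the snippet below is a hedged sketch (the `transformer` object and GPU count are placeholders, not code from this commit).

# Hedged sketch: estimating ZeRO-3 memory needs to decide whether CPU offload is required.
from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live

# Prints per-GPU and per-CPU memory estimates for ZeRO-3 with and without
# optimizer/parameter offload for the given model and world size.
estimate_zero3_model_states_mem_needs_all_live(transformer, num_gpus_per_node=8, num_nodes=1)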

View File

@@ -0,0 +1,9 @@
from ..cogvideox_i2v.sft_trainer import CogVideoXI2VSftTrainer
from ..utils import register
class CogVideoX1dot5I2VSftTrainer(CogVideoXI2VSftTrainer):
pass
register("cogvideox1.5-i2v", "sft", CogVideoX1dot5I2VSftTrainer)

View File

@@ -0,0 +1,9 @@
from ..cogvideox_t2v.sft_trainer import CogVideoXT2VSftTrainer
from ..utils import register
class CogVideoX1dot5T2VSftTrainer(CogVideoXT2VSftTrainer):
pass
register("cogvideox1.5-t2v", "sft", CogVideoX1dot5T2VSftTrainer)

View File

@@ -0,0 +1,27 @@
import dotmap
from diffusers import CogVideoXImageToVideoPipeline
from typing_extensions import override
from finetune.utils import unwrap_model
from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
from ..utils import register
class CogVideoXI2VSftTrainer(CogVideoXI2VLoraTrainer):
@override
def initialize_pipeline(self) -> CogVideoXImageToVideoPipeline:
origin_model = unwrap_model(self.accelerator, self.components.transformer)
self.components.transformer.config.update(origin_model.config)
self.components.transformer.config = dotmap.DotMap(self.components.transformer.config)
pipe = CogVideoXImageToVideoPipeline(
tokenizer=self.components.tokenizer,
text_encoder=self.components.text_encoder,
vae=self.components.vae,
transformer=self.components.transformer,
scheduler=self.components.scheduler,
)
return pipe
register("cogvideox-i2v", "sft", CogVideoXI2VSftTrainer)

View File

@@ -0,0 +1,27 @@
import dotmap
from diffusers import CogVideoXPipeline
from typing_extensions import override
from finetune.utils import unwrap_model
from ..cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer
from ..utils import register
class CogVideoXT2VSftTrainer(CogVideoXT2VLoraTrainer):
@override
def initialize_pipeline(self) -> CogVideoXPipeline:
origin_model = unwrap_model(self.accelerator, self.components.transformer)
self.components.transformer.config.update(origin_model.config)
self.components.transformer.config = dotmap.DotMap(self.components.transformer.config)
pipe = CogVideoXPipeline(
tokenizer=self.components.tokenizer,
text_encoder=self.components.text_encoder,
vae=self.components.vae,
transformer=self.components.transformer,
scheduler=self.components.scheduler,
)
return pipe
register("cogvideox-t2v", "sft", CogVideoXT2VSftTrainer)

View File

@@ -15,17 +15,12 @@ def register(model_name: str, training_type: Literal["lora", "sft"], trainer_cls
trainer_cls (Trainer): Trainer class to register.
"""
# Check if model_name exists in SUPPORTED_MODELS
# Check if model_name and training_type exists in SUPPORTED_MODELS
if model_name not in SUPPORTED_MODELS:
SUPPORTED_MODELS[model_name] = {}
else:
raise ValueError(f"Model {model_name} already exists")
# Check if training_type exists for this model
if training_type not in SUPPORTED_MODELS[model_name]:
SUPPORTED_MODELS[model_name][training_type] = {}
else:
raise ValueError(f"Training type {training_type} already exists for model {model_name}")
if training_type in SUPPORTED_MODELS[model_name]:
raise ValueError(f"Training type {training_type} already exists for model {model_name}")
SUPPORTED_MODELS[model_name][training_type] = trainer_cls
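
The updated register() allows several training types per model while still rejecting duplicate (model, training type) pairs, and the registry itself is a plain nested dict. The lookup helper below is hypothetical (this commit does not add it) and only illustrates how a training entry point could resolve the "sft" trainers registered above.

# Hypothetical lookup helper (not part of this commit), shown for illustration only.
def get_trainer_cls(model_name: str, training_type: str):
    try:
        return SUPPORTED_MODELS[model_name][training_type]
    except KeyError:
        raise ValueError(
            f"No trainer registered for model '{model_name}' with training type '{training_type}'"
        ) from None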

View File

@@ -1,3 +1,4 @@
import hashlib
import json
import logging
import math
@@ -71,7 +72,7 @@ class Trainer:
train_width=self.args.train_resolution[2],
)
self.components = Components()
self.components: Components = self.load_components()
self.accelerator: Accelerator = None
self.dataset: Dataset = None
self.data_loader: DataLoader = None
@@ -145,9 +146,6 @@ class Trainer:
def prepare_models(self) -> None:
logger.info("Initializing models")
# Initialize model components
self.components = self.load_components()
if self.components.vae is not None:
if self.args.enable_slicing:
self.components.vae.enable_slicing()
@@ -159,15 +157,11 @@ class Trainer:
def prepare_dataset(self) -> None:
logger.info("Initializing dataset and dataloader")
# self.state.train_frames includes one padding frame for image conditioning
# so we only sample train_frames - 1 frames from the actual video
sample_frames = self.state.train_frames - 1
if self.args.model_type == "i2v":
self.dataset = I2VDatasetWithResize(
**(self.args.model_dump()),
device=self.accelerator.device,
max_num_frames=sample_frames,
max_num_frames=self.state.train_frames,
height=self.state.train_height,
width=self.state.train_width,
trainer=self,
@@ -176,7 +170,7 @@
self.dataset = T2VDatasetWithResize(
**(self.args.model_dump()),
device=self.accelerator.device,
max_num_frames=sample_frames,
max_num_frames=self.state.train_frames,
height=self.state.train_height,
width=self.state.train_width,
trainer=self,
@@ -223,12 +217,7 @@
def prepare_trainable_parameters(self):
logger.info("Initializing trainable parameters")
# For now only lora is supported
for attr_name, component in vars(self.components).items():
if hasattr(component, "requires_grad_"):
component.requires_grad_(False)
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
# For mixed precision training we cast all non-trainable weights to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = self.state.weight_dtype
@@ -238,35 +227,47 @@
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)
self.__load_components()
# For LoRA, we freeze all the parameters
# For SFT, we train all the parameters in transformer model
for attr_name, component in vars(self.components).items():
if hasattr(component, "requires_grad_"):
if self.args.training_type == "sft" and attr_name == "transformer":
component.requires_grad_(True)
else:
component.requires_grad_(False)
if self.args.training_type == "lora":
transformer_lora_config = LoraConfig(
r=self.args.rank,
lora_alpha=self.args.lora_alpha,
init_lora_weights=True,
target_modules=self.args.target_modules,
)
self.components.transformer.add_adapter(transformer_lora_config)
self.__prepare_saving_loading_hooks(transformer_lora_config)
# Load components needed for training to GPU (except transformer),
# and cast them to the specified data type
self.__move_components_to_device(dtype=weight_dtype)
if self.args.gradient_checkpointing:
self.components.transformer.enable_gradient_checkpointing()
transformer_lora_config = LoraConfig(
r=self.args.rank,
lora_alpha=self.args.lora_alpha,
init_lora_weights=True,
target_modules=self.args.target_modules,
)
self.components.transformer.add_adapter(transformer_lora_config)
self.__prepare_saving_loading_hooks(transformer_lora_config)
def prepare_optimizer(self) -> None:
logger.info("Initializing optimizer and lr scheduler")
# Make sure the trainable params are in float32
if self.args.mixed_precision != "no":
# only upcast trainable parameters (LoRA) into fp32
cast_training_params([self.components.transformer], dtype=torch.float32)
cast_training_params([self.components.transformer], dtype=torch.float32)
transformer_lora_parameters = list(filter(lambda p: p.requires_grad, self.components.transformer.parameters()))
# For LoRA, we only want to train the LoRA weights
# For SFT, we want to train all the parameters
trainable_parameters = list(filter(lambda p: p.requires_grad, self.components.transformer.parameters()))
transformer_parameters_with_lr = {
"params": transformer_lora_parameters,
"params": trainable_parameters,
"lr": self.args.learning_rate,
}
params_to_optimize = [transformer_parameters_with_lr]
self.state.num_trainable_parameters = sum(p.numel() for p in transformer_lora_parameters)
self.state.num_trainable_parameters = sum(p.numel() for p in trainable_parameters)
use_deepspeed_opt = (
self.accelerator.state.deepspeed_plugin is not None
@@ -502,13 +503,15 @@
# Convert all model weights to training dtype
# Note, this will change LoRA weights in self.components.transformer to training dtype, rather than keep them in fp32
pipe = pipe.to(dtype=self.state.weight_dtype)
#################################
all_processes_artifacts = []
for i in range(num_validation_samples):
# Skip current validation on all processes but one
if i % accelerator.num_processes != accelerator.process_index:
continue
if self.accelerator.deepspeed_plugin and self.accelerator.deepspeed_plugin.zero_stage != 3:
# Skip current validation on all processes but one
if i % accelerator.num_processes != accelerator.process_index:
continue
prompt = self.state.validation_prompts[i]
image = self.state.validation_images[i]
@@ -534,7 +537,19 @@
main_process_only=False,
)
validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)
if (
self.accelerator.deepspeed_plugin is not None
and self.accelerator.deepspeed_plugin.zero_stage == 3
and not accelerator.is_main_process
):
continue
prompt_filename = string_to_filename(prompt)[:25]
# Calculate hash of reversed prompt as a unique identifier
reversed_prompt = prompt[::-1]
hash_suffix = hashlib.md5(reversed_prompt.encode()).hexdigest()[:5]
artifacts = {
"image": {"type": "image", "value": image},
"video": {"type": "video", "value": video},
@@ -553,7 +568,7 @@
continue
extension = "png" if artifact_type == "image" else "mp4"
filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}.{extension}"
filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}.{extension}"
validation_path = self.args.output_dir / "validation_res"
validation_path.mkdir(parents=True, exist_ok=True)
filename = str(validation_path / filename)
@@ -587,11 +602,15 @@
pipe.remove_all_hooks()
del pipe
# Unload models except those needed for training
self.__unload_components()
self.__move_components_to_cpu()
# Load models except those not needed for training
self.__load_components()
# Change LoRA weights back to fp32
cast_training_params([self.components.transformer], dtype=torch.float32)
self.__move_components_to_device(dtype=self.state.weight_dtype)
self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)
if self.accelerator.state.deepspeed_plugin is None:
# Change trainable weights back to fp32
cast_training_params([self.components.transformer], dtype=torch.float32)
accelerator.wait_for_everyone()
@@ -649,16 +668,21 @@
else:
raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")
def __load_components(self):
def __move_components_to_device(self, dtype):
components = self.components.model_dump()
for name, component in components.items():
if not isinstance(component, type) and hasattr(component, "to"):
if name in self.UNLOAD_LIST:
continue
# setattr(self.components, name, component.to(self.accelerator.device))
setattr(self.components, name, component.to(self.accelerator.device, dtype=self.state.weight_dtype))
def __unload_components(self):
# We don't need to move transformer to device
# because we will prepare it in the `prepare_for_training()`
if name == "transformer":
continue
setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))
def __move_components_to_cpu(self):
components = self.components.model_dump()
for name, component in components.items():
if not isinstance(component, type) and hasattr(component, "to"):
@@ -723,13 +747,6 @@
f" {unexpected_keys}. "
)
# Make sure the trainable params are in float32. This is again needed since the base models
# are in `weight_dtype`. More details:
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
if self.args.mixed_precision == "fp16":
# only upcast trainable parameters (LoRA) into fp32
cast_training_params([transformer_])
self.accelerator.register_save_state_pre_hook(save_model_hook)
self.accelerator.register_load_state_pre_hook(load_model_hook)