feat: add SFT support with ZeRO optimization strategies

- Add SFT (Supervised Fine-Tuning) trainers for all model variants:
  - CogVideoX I2V and T2V
  - CogVideoX-1.5 I2V and T2V
- Add DeepSpeed ZeRO configuration files:
  - ZeRO-2 with and without CPU offload
  - ZeRO-3 with and without CPU offload
- Add base accelerate config for distributed training
- Update trainer.py to support SFT training mode

This enables full-parameter fine-tuning with memory-efficient distributed training using DeepSpeed ZeRO optimization.
OleehyO 2025-01-11 02:13:32 +00:00
parent e213b6c083
commit caa24bdc36
11 changed files with 344 additions and 59 deletions
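As a usage sketch (not part of this commit): once the SFT trainers below are registered, a caller can resolve one from the registry that register() populates. The module path and the lookup helper here are assumptions for illustration; only SUPPORTED_MODELS and register appear in the diff.

from finetune.models.utils import SUPPORTED_MODELS  # module path assumed from the diff layout

def get_trainer_cls(model_name: str, training_type: str):
    # Hypothetical lookup helper mirroring how trainer selection could work.
    try:
        return SUPPORTED_MODELS[model_name][training_type]
    except KeyError as exc:
        raise ValueError(f"No trainer registered for {model_name}/{training_type}") from exc

# e.g. full-parameter SFT for CogVideoX-1.5 image-to-video:
# trainer_cls = get_trainer_cls("cogvideox1.5-i2v", "sft")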


@@ -0,0 +1,27 @@
compute_environment: LOCAL_MACHINE
# gpu_ids: "0"  # 0,1,2,3,4,5,6,7
# num_processes: 1
gpu_ids: all
num_processes: 8
debug: false
deepspeed_config:
  deepspeed_config_file: /path/to/your/configs/zero2.yaml
  # deepspeed_config_file: /path/to/your/configs/zero2_offload.yaml
  # deepspeed_config_file: /path/to/your/configs/zero3.yaml
  # deepspeed_config_file: /path/to/your/configs/zero3_offload.yaml
  zero3_init_flag: false
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
num_machines: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false


@@ -0,0 +1,38 @@
{
    "bf16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": true
    },
    "gradient_accumulation_steps": 1,
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "wall_clock_breakdown": false
}
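The "auto" fields above are resolved at launch time. One relation worth keeping in mind is that DeepSpeed requires train_batch_size to equal train_micro_batch_size_per_gpu times gradient_accumulation_steps times the world size. A small illustrative sketch; the world size of 8 is taken from the accelerate config above, the other values are the ones in this file.

# Sketch: how the "auto" train_batch_size is expected to resolve for this setup.
micro_batch_per_gpu = 1   # train_micro_batch_size_per_gpu
grad_accum_steps = 1      # gradient_accumulation_steps
world_size = 8            # num_processes in the accelerate config

effective_batch_size = micro_batch_per_gpu * grad_accum_steps * world_size
print(effective_batch_size)  # 8 samples per optimizer step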


@@ -0,0 +1,42 @@
{
    "bf16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": true,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        }
    },
    "gradient_accumulation_steps": 1,
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "wall_clock_breakdown": false
}
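The only change relative to the plain ZeRO-2 file is offload_optimizer, which keeps the AdamW state in pinned CPU memory. A rough, hedged estimate of what that moves off the GPUs, assuming fp32 AdamW state and a parameter count that is only illustrative:

# Back-of-envelope sketch; the numbers are assumptions, not measurements from this repo.
num_params = 5_000_000_000          # illustrative transformer size
optimizer_state_bytes = 8           # fp32 momentum + fp32 variance per parameter
master_weight_bytes = 4             # fp32 master copy kept with the optimizer

total_gib = num_params * (optimizer_state_bytes + master_weight_bytes) / 2**30
per_gpu_gib = total_gib / 8         # ZeRO-2 partitions this state across the 8 ranks
print(f"~{total_gib:.0f} GiB total, ~{per_gpu_gib:.0f} GiB per GPU moved to CPU")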


@@ -0,0 +1,43 @@
{
    "bf16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "reduce_bucket_size": 5e8,
        "sub_group_size": 1e9,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": "auto",
        "stage3_prefetch_bucket_size": 5e8,
        "stage3_param_persistence_threshold": 1e5
    },
    "gradient_accumulation_steps": 1,
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "wall_clock_breakdown": false
}
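Because ZeRO-3 shards the parameters themselves, stage3_gather_16bit_weights_on_model_save matters at checkpoint time: without gathering, each rank only holds its own shard. A hedged sketch of consolidating a full state dict through Accelerate (API behavior as commonly documented; verify against your accelerate/DeepSpeed versions):

import torch

def save_consolidated(accelerator, model, path):
    # Under ZeRO-3, get_state_dict() gathers the sharded parameters before returning them.
    state_dict = accelerator.get_state_dict(model)
    if accelerator.is_main_process:
        torch.save(state_dict, path)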


@@ -0,0 +1,51 @@
{
    "bf16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "reduce_bucket_size": 5e8,
        "sub_group_size": 1e9,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": "auto",
        "stage3_prefetch_bucket_size": 5e8,
        "stage3_param_persistence_threshold": 1e6
    },
    "gradient_accumulation_steps": 1,
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "wall_clock_breakdown": false
}
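With offload_param, the shards additionally live in CPU memory and are paged in on demand, so a full parameter tensor is not normally materialized outside forward/backward. If one is needed ad hoc (inspection, debugging), DeepSpeed's gathering context can be used; a hedged sketch:

import deepspeed

def full_param_norm(param):
    # Temporarily gather one ZeRO-3 (possibly CPU-offloaded) parameter so its
    # complete tensor can be read on this rank.
    with deepspeed.zero.GatheredParameters([param]):
        return param.detach().float().norm().item()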


@@ -0,0 +1,9 @@
from ..cogvideox_i2v.sft_trainer import CogVideoXI2VSftTrainer
from ..utils import register


class CogVideoX1dot5I2VSftTrainer(CogVideoXI2VSftTrainer):
    pass


register("cogvideox1.5-i2v", "sft", CogVideoX1dot5I2VSftTrainer)


@@ -0,0 +1,9 @@
from ..cogvideox_t2v.sft_trainer import CogVideoXT2VSftTrainer
from ..utils import register


class CogVideoX1dot5T2VSftTrainer(CogVideoXT2VSftTrainer):
    pass


register("cogvideox1.5-t2v", "sft", CogVideoX1dot5T2VSftTrainer)


@@ -0,0 +1,27 @@
import dotmap
from diffusers import CogVideoXImageToVideoPipeline
from typing_extensions import override

from finetune.utils import unwrap_model

from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
from ..utils import register


class CogVideoXI2VSftTrainer(CogVideoXI2VLoraTrainer):
    @override
    def initialize_pipeline(self) -> CogVideoXImageToVideoPipeline:
        origin_model = unwrap_model(self.accelerator, self.components.transformer)
        self.components.transformer.config.update(origin_model.config)
        self.components.transformer.config = dotmap.DotMap(self.components.transformer.config)
        pipe = CogVideoXImageToVideoPipeline(
            tokenizer=self.components.tokenizer,
            text_encoder=self.components.text_encoder,
            vae=self.components.vae,
            transformer=self.components.transformer,
            scheduler=self.components.scheduler,
        )
        return pipe


register("cogvideox-i2v", "sft", CogVideoXI2VSftTrainer)


@@ -0,0 +1,27 @@
import dotmap
from diffusers import CogVideoXPipeline
from typing_extensions import override

from finetune.utils import unwrap_model

from ..cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer
from ..utils import register


class CogVideoXT2VSftTrainer(CogVideoXT2VLoraTrainer):
    @override
    def initialize_pipeline(self) -> CogVideoXPipeline:
        origin_model = unwrap_model(self.accelerator, self.components.transformer)
        self.components.transformer.config.update(origin_model.config)
        self.components.transformer.config = dotmap.DotMap(self.components.transformer.config)
        pipe = CogVideoXPipeline(
            tokenizer=self.components.tokenizer,
            text_encoder=self.components.text_encoder,
            vae=self.components.vae,
            transformer=self.components.transformer,
            scheduler=self.components.scheduler,
        )
        return pipe


register("cogvideox-t2v", "sft", CogVideoXT2VSftTrainer)


@@ -15,17 +15,12 @@ def register(model_name: str, training_type: Literal["lora", "sft"], trainer_cls
         trainer_cls (Trainer): Trainer class to register.
     """
-    # Check if model_name exists in SUPPORTED_MODELS
+    # Check if model_name and training_type exists in SUPPORTED_MODELS
     if model_name not in SUPPORTED_MODELS:
         SUPPORTED_MODELS[model_name] = {}
     else:
-        raise ValueError(f"Model {model_name} already exists")
-
-    # Check if training_type exists for this model
-    if training_type not in SUPPORTED_MODELS[model_name]:
-        SUPPORTED_MODELS[model_name][training_type] = {}
-    else:
-        raise ValueError(f"Training type {training_type} already exists for model {model_name}")
+        if training_type in SUPPORTED_MODELS[model_name]:
+            raise ValueError(f"Training type {training_type} already exists for model {model_name}")

     SUPPORTED_MODELS[model_name][training_type] = trainer_cls
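In short, the registry now allows several training types per model and only rejects re-registering the same (model, training type) pair. A small hedged sketch of that behavior (import path assumed; the trainer classes are stand-ins):

from finetune.models.utils import register  # path assumed from the diff layout

class DummyLoraTrainer: ...
class DummySftTrainer: ...

register("my-model", "lora", DummyLoraTrainer)   # first training type for this model
register("my-model", "sft", DummySftTrainer)     # same model, new type: accepted after this change
# register("my-model", "sft", DummySftTrainer)   # would raise ValueError: type already registered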


@@ -1,3 +1,4 @@
+import hashlib
 import json
 import logging
 import math
@@ -71,7 +72,7 @@ class Trainer:
             train_width=self.args.train_resolution[2],
         )

-        self.components = Components()
+        self.components: Components = self.load_components()
         self.accelerator: Accelerator = None
         self.dataset: Dataset = None
         self.data_loader: DataLoader = None
@@ -145,9 +146,6 @@ class Trainer:
     def prepare_models(self) -> None:
         logger.info("Initializing models")

-        # Initialize model components
-        self.components = self.load_components()
-
         if self.components.vae is not None:
             if self.args.enable_slicing:
                 self.components.vae.enable_slicing()
@@ -159,15 +157,11 @@ class Trainer:
     def prepare_dataset(self) -> None:
         logger.info("Initializing dataset and dataloader")

-        # self.state.train_frames includes one padding frame for image conditioning
-        # so we only sample train_frames - 1 frames from the actual video
-        sample_frames = self.state.train_frames - 1
-
         if self.args.model_type == "i2v":
             self.dataset = I2VDatasetWithResize(
                 **(self.args.model_dump()),
                 device=self.accelerator.device,
-                max_num_frames=sample_frames,
+                max_num_frames=self.state.train_frames,
                 height=self.state.train_height,
                 width=self.state.train_width,
                 trainer=self,
@@ -176,7 +170,7 @@ class Trainer:
             self.dataset = T2VDatasetWithResize(
                 **(self.args.model_dump()),
                 device=self.accelerator.device,
-                max_num_frames=sample_frames,
+                max_num_frames=self.state.train_frames,
                 height=self.state.train_height,
                 width=self.state.train_width,
                 trainer=self,
@@ -223,12 +217,7 @@ class Trainer:
     def prepare_trainable_parameters(self):
         logger.info("Initializing trainable parameters")

-        # For now only lora is supported
-        for attr_name, component in vars(self.components).items():
-            if hasattr(component, "requires_grad_"):
-                component.requires_grad_(False)
-
-        # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
+        # For mixed precision training we cast all non-trainable weights to half-precision
         # as these weights are only used for inference, keeping weights in full precision is not required.
         weight_dtype = self.state.weight_dtype
@@ -238,35 +227,47 @@ class Trainer:
                 "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
             )

-        self.__load_components()
+        # For LoRA, we freeze all the parameters
+        # For SFT, we train all the parameters in transformer model
+        for attr_name, component in vars(self.components).items():
+            if hasattr(component, "requires_grad_"):
+                if self.args.training_type == "sft" and attr_name == "transformer":
+                    component.requires_grad_(True)
+                else:
+                    component.requires_grad_(False)
+
+        if self.args.training_type == "lora":
+            transformer_lora_config = LoraConfig(
+                r=self.args.rank,
+                lora_alpha=self.args.lora_alpha,
+                init_lora_weights=True,
+                target_modules=self.args.target_modules,
+            )
+            self.components.transformer.add_adapter(transformer_lora_config)
+            self.__prepare_saving_loading_hooks(transformer_lora_config)
+
+        # Load components needed for training to GPU (except transformer),
+        # and cast them to the specified data type
+        self.__move_components_to_device(dtype=weight_dtype)

         if self.args.gradient_checkpointing:
             self.components.transformer.enable_gradient_checkpointing()

-        transformer_lora_config = LoraConfig(
-            r=self.args.rank,
-            lora_alpha=self.args.lora_alpha,
-            init_lora_weights=True,
-            target_modules=self.args.target_modules,
-        )
-        self.components.transformer.add_adapter(transformer_lora_config)
-        self.__prepare_saving_loading_hooks(transformer_lora_config)
-
     def prepare_optimizer(self) -> None:
         logger.info("Initializing optimizer and lr scheduler")

         # Make sure the trainable params are in float32
-        if self.args.mixed_precision != "no":
-            # only upcast trainable parameters (LoRA) into fp32
-            cast_training_params([self.components.transformer], dtype=torch.float32)
+        cast_training_params([self.components.transformer], dtype=torch.float32)

-        transformer_lora_parameters = list(filter(lambda p: p.requires_grad, self.components.transformer.parameters()))
+        # For LoRA, we only want to train the LoRA weights
+        # For SFT, we want to train all the parameters
+        trainable_parameters = list(filter(lambda p: p.requires_grad, self.components.transformer.parameters()))
         transformer_parameters_with_lr = {
-            "params": transformer_lora_parameters,
+            "params": trainable_parameters,
             "lr": self.args.learning_rate,
         }
         params_to_optimize = [transformer_parameters_with_lr]
-        self.state.num_trainable_parameters = sum(p.numel() for p in transformer_lora_parameters)
+        self.state.num_trainable_parameters = sum(p.numel() for p in trainable_parameters)

         use_deepspeed_opt = (
             self.accelerator.state.deepspeed_plugin is not None
@@ -502,13 +503,15 @@ class Trainer:
         # Convert all model weights to training dtype
         # Note, this will change LoRA weights in self.components.transformer to training dtype, rather than keep them in fp32
         pipe = pipe.to(dtype=self.state.weight_dtype)
         #################################

         all_processes_artifacts = []
         for i in range(num_validation_samples):
-            # Skip current validation on all processes but one
-            if i % accelerator.num_processes != accelerator.process_index:
-                continue
+            if self.accelerator.deepspeed_plugin and self.accelerator.deepspeed_plugin.zero_stage != 3:
+                # Skip current validation on all processes but one
+                if i % accelerator.num_processes != accelerator.process_index:
+                    continue

             prompt = self.state.validation_prompts[i]
             image = self.state.validation_images[i]
@@ -534,7 +537,19 @@ class Trainer:
                 main_process_only=False,
             )
             validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)
+
+            if (
+                self.accelerator.deepspeed_plugin is not None
+                and self.accelerator.deepspeed_plugin.zero_stage == 3
+                and not accelerator.is_main_process
+            ):
+                continue
+
             prompt_filename = string_to_filename(prompt)[:25]
+            # Calculate hash of reversed prompt as a unique identifier
+            reversed_prompt = prompt[::-1]
+            hash_suffix = hashlib.md5(reversed_prompt.encode()).hexdigest()[:5]
             artifacts = {
                 "image": {"type": "image", "value": image},
                 "video": {"type": "video", "value": video},
@@ -553,7 +568,7 @@ class Trainer:
                     continue
                 extension = "png" if artifact_type == "image" else "mp4"
-                filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}.{extension}"
+                filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}.{extension}"
                 validation_path = self.args.output_dir / "validation_res"
                 validation_path.mkdir(parents=True, exist_ok=True)
                 filename = str(validation_path / filename)
@@ -587,11 +602,15 @@ class Trainer:
         pipe.remove_all_hooks()
         del pipe
         # Unload models except those needed for training
-        self.__unload_components()
+        self.__move_components_to_cpu()
         # Load models except those not needed for training
-        self.__load_components()
+        self.__move_components_to_device(dtype=self.state.weight_dtype)
+        self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)

-        # Change LoRA weights back to fp32
-        cast_training_params([self.components.transformer], dtype=torch.float32)
+        if self.accelerator.state.deepspeed_plugin is None:
+            # Change trainable weights back to fp32
+            cast_training_params([self.components.transformer], dtype=torch.float32)

         accelerator.wait_for_everyone()
@@ -649,16 +668,21 @@ class Trainer:
         else:
             raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")

-    def __load_components(self):
+    def __move_components_to_device(self, dtype):
         components = self.components.model_dump()
         for name, component in components.items():
             if not isinstance(component, type) and hasattr(component, "to"):
                 if name in self.UNLOAD_LIST:
                     continue
-                # setattr(self.components, name, component.to(self.accelerator.device))
-                setattr(self.components, name, component.to(self.accelerator.device, dtype=self.state.weight_dtype))

-    def __unload_components(self):
+                # We don't need to move transformer to device
+                # because we will prepare it in the `prepare_for_training()`
+                if name == "transformer":
+                    continue
+
+                setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))
+
+    def __move_components_to_cpu(self):
         components = self.components.model_dump()
         for name, component in components.items():
             if not isinstance(component, type) and hasattr(component, "to"):
@@ -723,13 +747,6 @@ class Trainer:
                         f" {unexpected_keys}. "
                     )

-            # Make sure the trainable params are in float32. This is again needed since the base models
-            # are in `weight_dtype`. More details:
-            # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
-            if self.args.mixed_precision == "fp16":
-                # only upcast trainable parameters (LoRA) into fp32
-                cast_training_params([transformer_])
-
         self.accelerator.register_save_state_pre_hook(save_model_hook)
         self.accelerator.register_load_state_pre_hook(load_model_hook)
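The validation changes above encode a ZeRO-3 constraint: because the parameters are sharded across ranks, every process has to run the pipeline's forward pass, and only the main process keeps the resulting artifacts; under other stages each sample is still handled by a single rank. A hedged, stripped-down sketch of that control flow (run_pipeline is a stand-in, not a function from this repo):

def collect_validation_artifacts(accelerator, samples, run_pipeline, zero_stage):
    artifacts = []
    for i, sample in enumerate(samples):
        if zero_stage != 3 and i % accelerator.num_processes != accelerator.process_index:
            continue  # non-ZeRO-3: round-robin the samples across ranks
        result = run_pipeline(sample)  # ZeRO-3: every rank must reach this call
        if zero_stage == 3 and not accelerator.is_main_process:
            continue  # all shards participated, but only rank 0 keeps the output
        artifacts.append(result)
    return artifacts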