feat: support DeepSpeed ZeRO-3 and optimize peak memory usage

- Add DeepSpeed ZeRO-3 configuration support
- Optimize memory usage during training
- Rename training scripts to reflect ZeRO usage
- Update related configuration files and trainers

Author: OleehyO
Date: 2025-01-12 05:33:56 +00:00
Parent: 2f275e82b5
Commit: fdb9820949

7 changed files with 83 additions and 107 deletions

@@ -1,17 +1,11 @@
 compute_environment: LOCAL_MACHINE
-# gpu_ids: "0" # 0,1,2,3,4,5,6,7
-# num_processes: 1
-gpu_ids: all
-num_processes: 8
+gpu_ids: "0,1,2,4"
+num_processes: 4 # should be the same as the number of GPUs
 debug: false
 deepspeed_config:
-  deepspeed_config_file: /path/to/your/configs/zero2.yaml
-  # deepspeed_config_file: /path/to/your/configs/zero2_offload.yaml
-  # deepspeed_config_file: /path/to/your/configs/zero3.yaml
-  # deepspeed_config_file: /path/to/your/configs/zero3_offload.yaml
+  deepspeed_config_file: /absolute/path/to/your/deepspeed_config.yaml # e.g. /home/user/cogvideo/finetune/configs/zero2.yaml
   zero3_init_flag: false
 distributed_type: DEEPSPEED
 downcast_bf16: 'no'
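
Note: the file above is the Accelerate launch config; `deepspeed_config_file` now points at a single DeepSpeed config (for example the repo's zero2 or zero3 YAML). As a rough sketch of what a ZeRO-3 setup involves, the same thing can be expressed programmatically with Accelerate's DeepSpeedPlugin. The dictionary keys are standard DeepSpeed options, but the concrete values below are assumptions for illustration, not the shipped config files:

from accelerate import Accelerator, DeepSpeedPlugin

# Example ZeRO-3 settings; a zero3.yaml referenced by deepspeed_config_file
# would carry the same keys in YAML form.
ds_config = {
    "zero_optimization": {
        "stage": 3,  # shard optimizer state, gradients, and parameters across ranks
        "overlap_comm": True,
        "stage3_gather_16bit_weights_on_model_save": True,
    },
    "bf16": {"enabled": True},            # keep consistent with --mixed_precision "bf16"
    "train_micro_batch_size_per_gpu": 1,  # keep consistent with --batch_size
    "gradient_accumulation_steps": 1,     # keep consistent with --gradient_accumulation_steps
}

plugin = DeepSpeedPlugin(hf_ds_config=ds_config)
accelerator = Accelerator(deepspeed_plugin=plugin)  # normally done for you by `accelerate launch`

The "keep consistent" comments mirror the banner this commit adds to the training scripts further down.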

@ -1,27 +1,9 @@
import dotmap
from diffusers import CogVideoXImageToVideoPipeline
from typing_extensions import override
from finetune.utils import unwrap_model
from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
from ..utils import register from ..utils import register
class CogVideoXI2VSftTrainer(CogVideoXI2VLoraTrainer): class CogVideoXI2VSftTrainer(CogVideoXI2VLoraTrainer):
@override pass
def initialize_pipeline(self) -> CogVideoXImageToVideoPipeline:
origin_model = unwrap_model(self.accelerator, self.components.transformer)
self.components.transformer.config.update(origin_model.config)
self.components.transformer.config = dotmap.DotMap(self.components.transformer.config)
pipe = CogVideoXImageToVideoPipeline(
tokenizer=self.components.tokenizer,
text_encoder=self.components.text_encoder,
vae=self.components.vae,
transformer=self.components.transformer,
scheduler=self.components.scheduler,
)
return pipe
register("cogvideox-i2v", "sft", CogVideoXI2VSftTrainer) register("cogvideox-i2v", "sft", CogVideoXI2VSftTrainer)

@ -1,27 +1,9 @@
import dotmap
from diffusers import CogVideoXPipeline
from typing_extensions import override
from finetune.utils import unwrap_model
from ..cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer from ..cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer
from ..utils import register from ..utils import register
class CogVideoXT2VSftTrainer(CogVideoXT2VLoraTrainer): class CogVideoXT2VSftTrainer(CogVideoXT2VLoraTrainer):
@override pass
def initialize_pipeline(self) -> CogVideoXPipeline:
origin_model = unwrap_model(self.accelerator, self.components.transformer)
self.components.transformer.config.update(origin_model.config)
self.components.transformer.config = dotmap.DotMap(self.components.transformer.config)
pipe = CogVideoXPipeline(
tokenizer=self.components.tokenizer,
text_encoder=self.components.text_encoder,
vae=self.components.vae,
transformer=self.components.transformer,
scheduler=self.components.scheduler,
)
return pipe
register("cogvideox-t2v", "sft", CogVideoXT2VSftTrainer) register("cogvideox-t2v", "sft", CogVideoXT2VSftTrainer)

@@ -8,13 +8,13 @@ from pydantic import BaseModel
 class State(BaseModel):
     model_config = {"arbitrary_types_allowed": True}

-    train_frames: int  # user-defined training frames, **containing one image padding frame**
+    train_frames: int
     train_height: int
     train_width: int

     transformer_config: Dict[str, Any] = None
-    weight_dtype: torch.dtype = torch.float32
+    weight_dtype: torch.dtype = torch.float32  # dtype for mixed precision training
     num_trainable_parameters: int = 0
     overwrote_max_train_steps: bool = False
     num_update_steps_per_epoch: int = 0
@@ -25,3 +25,5 @@ class State(BaseModel):
     validation_prompts: List[str] = []
     validation_images: List[Path | None] = []
     validation_videos: List[Path | None] = []
+
+    using_deepspeed: bool = False
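
Note: the only functional addition to State is the using_deepspeed flag; the trainer (last file in this commit) sets it once from the Accelerate state. A minimal, self-contained sketch of how the flag gets populated; the Accelerator call and field name follow the diff, the rest is illustrative:

from accelerate import Accelerator
from pydantic import BaseModel

class State(BaseModel):
    using_deepspeed: bool = False

accelerator = Accelerator()  # reads the config passed to `accelerate launch`
state = State()
# True only when the accelerate config sets distributed_type: DEEPSPEED
state.using_deepspeed = accelerator.state.deepspeed_plugin is not None
print(state.using_deepspeed)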

@@ -8,31 +8,34 @@ MODEL_ARGS=(
     --model_path "THUDM/CogVideoX1.5-5B-I2V"
     --model_name "cogvideox1.5-i2v"  # ["cogvideox-i2v"]
     --model_type "i2v"
-    --training_type "lora"
+    --training_type "sft"
 )

 # Output Configuration
 OUTPUT_ARGS=(
-    --output_dir "/path/to/output/dir"
+    --output_dir "/absolute/path/to/your/output_dir"
     --report_to "tensorboard"
 )

 # Data Configuration
 DATA_ARGS=(
-    --data_root "/path/to/data/dir"
+    --data_root "/absolute/path/to/your/data_root"
     --caption_column "prompt.txt"
     --video_column "videos.txt"
-    --image_column "images.txt"
-    --train_resolution "81x768x1360"
+    # --image_column "images.txt"  # comment this line will use first frame of video as image conditioning
+    --train_resolution "81x768x1360"  # (frames x height x width), frames should be 8N+1
 )

 # Training Configuration
 TRAIN_ARGS=(
     --train_epochs 10
+    --seed 42
+
+    ######### Please keep consistent with deepspeed config file ##########
     --batch_size 1
     --gradient_accumulation_steps 1
     --mixed_precision "bf16"  # ["no", "fp16"]
-    --seed 42
+    ########################################################################
 )

 # System Configuration
@@ -44,22 +47,23 @@ SYSTEM_ARGS=(
 # Checkpointing Configuration
 CHECKPOINT_ARGS=(
-    --checkpointing_steps 200
+    --checkpointing_steps 5
     --checkpointing_limit 10
+    --resume_from_checkpoint "/absolute/path/to/checkpoint_dir"  # if you want to resume from a checkpoint, otherwise, comment this line
 )

 # Validation Configuration
 VALIDATION_ARGS=(
-    --do_validation False
-    --validation_dir "/path/to/validation/dir"
-    --validation_steps 400
+    --do_validation false  # ["true", "false"]
+    --validation_dir "/absolute/path/to/validation_set"
+    --validation_steps 20  # should be multiple of checkpointing_steps
     --validation_prompts "prompts.txt"
     --validation_images "images.txt"
     --gen_fps 16
 )

 # Combine all arguments and launch training
-accelerate launch train.py \
+accelerate launch --config_file accelerate_config.yaml train.py \
     "${MODEL_ARGS[@]}" \
     "${OUTPUT_ARGS[@]}" \
     "${DATA_ARGS[@]}" \

@@ -8,30 +8,33 @@ MODEL_ARGS=(
     --model_path "THUDM/CogVideoX1.5-5B"
     --model_name "cogvideox1.5-t2v"  # ["cogvideox-t2v"]
     --model_type "t2v"
-    --training_type "lora"
+    --training_type "sft"
 )

 # Output Configuration
 OUTPUT_ARGS=(
-    --output_dir "/path/to/output/dir"
+    --output_dir "/absolute/path/to/your/output_dir"
     --report_to "tensorboard"
 )

 # Data Configuration
 DATA_ARGS=(
-    --data_root "/path/to/data/dir"
+    --data_root "/absolute/path/to/your/data_root"
     --caption_column "prompt.txt"
     --video_column "videos.txt"
-    --train_resolution "81x768x1360"
+    --train_resolution "81x768x1360"  # (frames x height x width), frames should be 8N+1
 )

 # Training Configuration
 TRAIN_ARGS=(
     --train_epochs 10
+    --seed 42
+
+    ######### Please keep consistent with deepspeed config file ##########
     --batch_size 1
     --gradient_accumulation_steps 1
     --mixed_precision "bf16"  # ["no", "fp16"]
-    --seed 42
+    ########################################################################
 )

 # System Configuration
@@ -43,21 +46,22 @@ SYSTEM_ARGS=(
 # Checkpointing Configuration
 CHECKPOINT_ARGS=(
-    --checkpointing_steps 200
+    --checkpointing_steps 5
     --checkpointing_limit 10
+    --resume_from_checkpoint "/absolute/path/to/checkpoint_dir"  # if you want to resume from a checkpoint, otherwise, comment this line
 )

 # Validation Configuration
 VALIDATION_ARGS=(
-    --do_validation False
-    --validation_dir "/path/to/validation/dir"
-    --validation_steps 400
+    --do_validation false  # ["true", "false"]
+    --validation_dir "/absolute/path/to/validation_set"
+    --validation_steps 20  # should be multiple of checkpointing_steps
     --validation_prompts "prompts.txt"
     --gen_fps 16
 )

 # Combine all arguments and launch training
-accelerate launch train.py \
+accelerate launch --config_file accelerate_config.yaml train.py \
     "${MODEL_ARGS[@]}" \
     "${OUTPUT_ARGS[@]}" \
     "${DATA_ARGS[@]}" \

@@ -84,6 +84,8 @@ class Trainer:
         self._init_logging()
         self._init_directories()

+        self.state.using_deepspeed = self.accelerator.state.deepspeed_plugin is not None
+
     def _init_distributed(self):
         logging_dir = Path(self.args.output_dir, "logs")
         project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)
@@ -246,9 +248,9 @@
             self.components.transformer.add_adapter(transformer_lora_config)
             self.__prepare_saving_loading_hooks(transformer_lora_config)

-        # Load components needed for training to GPU (except transformer),
-        # and cast them to the specified data type
-        self.__move_components_to_device(dtype=weight_dtype)
+        # Load components needed for training to GPU (except transformer), and cast them to the specified data type
+        ignore_list = ["transformer"] + self.UNLOAD_LIST
+        self.__move_components_to_device(dtype=weight_dtype, ignore_list=ignore_list)

         if self.args.gradient_checkpointing:
             self.components.transformer.enable_gradient_checkpointing()
@@ -406,6 +408,7 @@
         generator = generator.manual_seed(self.args.seed)
         self.state.generator = generator

+        free_memory()
         for epoch in range(first_epoch, self.args.train_epochs):
             logger.debug(f"Starting epoch ({epoch + 1}/{self.args.train_epochs})")
@@ -497,6 +500,13 @@
         ##### Initialize pipeline #####
         pipe = self.initialize_pipeline()

+        if self.state.using_deepspeed:
+            # Can't using model_cpu_offload in deepspeed,
+            # so we need to move all components in pipe to device
+            # pipe.to(self.accelerator.device, dtype=self.state.weight_dtype)
+            self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=["transformer"])
+        else:
+            # if not using deepspeed, use model_cpu_offload to further reduce memory usage
             # Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage
             pipe.enable_model_cpu_offload(device=self.accelerator.device)
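
Aside: the branch added here follows the diff's own comment that model_cpu_offload can't be used under DeepSpeed, so in that case the pipeline components are simply moved onto the device, while the non-DeepSpeed path keeps diffusers' offload hooks. A standalone diffusers sketch of the two modes; the model id and the flag below are examples, not what the trainer actually loads (the remaining hunks of this file continue below):

import torch
from diffusers import CogVideoXPipeline

# Example pipeline; the trainer instead builds its pipeline from already-loaded components.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16)

using_deepspeed = False  # stand-in for self.state.using_deepspeed
if using_deepspeed:
    # DeepSpeed manages the transformer; just place everything on the GPU.
    pipe.to("cuda")
else:
    # Keeps components on CPU and moves each to the GPU only while it runs.
    pipe.enable_model_cpu_offload()
    # Lower peak memory still, at a speed cost:
    # pipe.enable_sequential_cpu_offload()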
@@ -508,7 +518,7 @@
         all_processes_artifacts = []
         for i in range(num_validation_samples):
-            if self.accelerator.deepspeed_plugin and self.accelerator.deepspeed_plugin.zero_stage != 3:
+            if self.state.using_deepspeed and self.accelerator.deepspeed_plugin.zero_stage != 3:
                 # Skip current validation on all processes but one
                 if i % accelerator.num_processes != accelerator.process_index:
                     continue
@@ -539,7 +549,7 @@
             validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)

             if (
-                self.accelerator.deepspeed_plugin is not None
+                self.state.using_deepspeed
                 and self.accelerator.deepspeed_plugin.zero_stage == 3
                 and not accelerator.is_main_process
             ):
@@ -599,22 +609,25 @@
                     step=step,
                 )

-        pipe.remove_all_hooks()
-        del pipe
-        # Unload models except those needed for training
-        self.__move_components_to_cpu()
+        ########## Clean up ##########
+        if self.state.using_deepspeed:
+            del pipe
+            # Unload models except those needed for training
+            self.__move_components_to_cpu(unload_list=self.UNLOAD_LIST)
+        else:
+            pipe.remove_all_hooks()
+            del pipe

         # Load models except those not needed for training
-        self.__move_components_to_device(dtype=self.state.weight_dtype)
+        self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST)
         self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)

-        if self.accelerator.state.deepspeed_plugin is None:
-            # Change trainable weights back to fp32
+        # Change trainable weights back to fp32 to keep with dtype after prepare the model
         cast_training_params([self.components.transformer], dtype=torch.float32)

+        accelerator.wait_for_everyone()
         free_memory()
-        accelerator.wait_for_everyone()
+        ################################

         memory_statistics = get_memory_statistics()
         logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
         torch.cuda.reset_peak_memory_stats(accelerator.device)
@@ -668,25 +681,20 @@
         else:
             raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")

-    def __move_components_to_device(self, dtype):
+    def __move_components_to_device(self, dtype, ignore_list: List[str] = []):
+        ignore_list = set(ignore_list)
         components = self.components.model_dump()
         for name, component in components.items():
             if not isinstance(component, type) and hasattr(component, "to"):
-                if name in self.UNLOAD_LIST:
-                    continue
-                # We don't need to move transformer to device
-                # because we will prepare it in the `prepare_for_training()`
-                if name == "transformer":
-                    continue
+                if name not in ignore_list:
                     setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))

-    def __move_components_to_cpu(self):
+    def __move_components_to_cpu(self, unload_list: List[str] = []):
+        unload_list = set(unload_list)
         components = self.components.model_dump()
         for name, component in components.items():
             if not isinstance(component, type) and hasattr(component, "to"):
-                if name in self.UNLOAD_LIST:
+                if name in unload_list:
                     setattr(self.components, name, component.to("cpu"))

     def __prepare_saving_loading_hooks(self, transformer_lora_config):
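
Note: the two private helpers above carry most of the memory work. Instead of hard-coding UNLOAD_LIST and a special case for the transformer, callers now state explicitly which components to keep off the device (ignore_list) or push back to the CPU (unload_list). A simplified, standalone sketch of the same pattern; it uses a plain dict instead of the trainer's components object, and the function names are illustrative:

from typing import Dict, List

import torch

def move_to_device(components: Dict[str, torch.nn.Module], device: torch.device,
                   dtype: torch.dtype, ignore_list: List[str] = []) -> None:
    # Move everything except the named components (e.g. the DeepSpeed-managed
    # transformer and the models that stay unloaded during training).
    skip = set(ignore_list)
    for name, module in components.items():
        if name not in skip and hasattr(module, "to"):
            components[name] = module.to(device, dtype=dtype)

def move_to_cpu(components: Dict[str, torch.nn.Module], unload_list: List[str] = []) -> None:
    # Push only the named components back to the CPU to free device memory.
    for name in set(unload_list):
        if name in components and hasattr(components[name], "to"):
            components[name] = components[name].to("cpu")

During validation the trainer can then pass ignore_list=["transformer"] to leave the sharded transformer alone, and unload_list=self.UNLOAD_LIST to evict the components named there once validation finishes.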