feat: support DeepSpeed ZeRO-3 and optimize peak memory usage

- Add DeepSpeed ZeRO-3 configuration support
- Optimize memory usage during training
- Rename training scripts to reflect ZeRO usage
- Update related configuration files and trainers
OleehyO 2025-01-12 05:33:56 +00:00
parent 2f275e82b5
commit fdb9820949
7 changed files with 83 additions and 107 deletions

View File

@@ -1,17 +1,11 @@
compute_environment: LOCAL_MACHINE
# gpu_ids: "0" # 0,1,2,3,4,5,6,7
# num_processes: 1
gpu_ids: all
num_processes: 8
gpu_ids: "0,1,2,4"
num_processes: 4 # should be the same as the number of GPUs
debug: false
deepspeed_config:
deepspeed_config_file: /path/to/your/configs/zero2.yaml
# deepspeed_config_file: /path/to/your/configs/zero2_offload.yaml
# deepspeed_config_file: /path/to/your/configs/zero3.yaml
# deepspeed_config_file: /path/to/your/configs/zero3_offload.yaml
deepspeed_config_file: /absolute/path/to/your/deepspeed_config.yaml # e.g. /home/user/cogvideo/finetune/configs/zero2.yaml
zero3_init_flag: false
distributed_type: DEEPSPEED
downcast_bf16: 'no'
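
The deepspeed_config_file above points at a plain DeepSpeed configuration. The contents of the repo's zero2/zero3 files are not shown in this diff, so the following is only a minimal sketch of typical ZeRO-3 settings, written as a small Python script that emits them; the output file name and the concrete values are illustrative assumptions, while the key names follow DeepSpeed's documented config schema.

import json

# Minimal ZeRO-3 sketch (illustrative values; key names are from DeepSpeed's config schema).
zero3_config = {
    "zero_optimization": {
        "stage": 3,                               # shard optimizer states, gradients, and parameters
        "overlap_comm": True,                     # overlap communication with computation
        "contiguous_gradients": True,
        "offload_optimizer": {"device": "none"},  # "cpu" for the *_offload variants
        "offload_param": {"device": "none"},
        "stage3_gather_16bit_weights_on_model_save": True,
    },
    "bf16": {"enabled": True},                    # keep in sync with --mixed_precision "bf16"
    "train_micro_batch_size_per_gpu": 1,          # keep in sync with --batch_size
    "gradient_accumulation_steps": 1,             # keep in sync with --gradient_accumulation_steps
    "gradient_clipping": 1.0,
}

with open("zero3_example.json", "w") as f:        # placeholder path, not the repo's file
    json.dump(zero3_config, f, indent=2)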

View File

@@ -1,27 +1,9 @@
import dotmap
from diffusers import CogVideoXImageToVideoPipeline
from typing_extensions import override
from finetune.utils import unwrap_model
from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
from ..utils import register
class CogVideoXI2VSftTrainer(CogVideoXI2VLoraTrainer):
@override
def initialize_pipeline(self) -> CogVideoXImageToVideoPipeline:
origin_model = unwrap_model(self.accelerator, self.components.transformer)
self.components.transformer.config.update(origin_model.config)
self.components.transformer.config = dotmap.DotMap(self.components.transformer.config)
pipe = CogVideoXImageToVideoPipeline(
tokenizer=self.components.tokenizer,
text_encoder=self.components.text_encoder,
vae=self.components.vae,
transformer=self.components.transformer,
scheduler=self.components.scheduler,
)
return pipe
pass
register("cogvideox-i2v", "sft", CogVideoXI2VSftTrainer)

View File

@@ -1,27 +1,9 @@
import dotmap
from diffusers import CogVideoXPipeline
from typing_extensions import override
from finetune.utils import unwrap_model
from ..cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer
from ..utils import register
class CogVideoXT2VSftTrainer(CogVideoXT2VLoraTrainer):
@override
def initialize_pipeline(self) -> CogVideoXPipeline:
origin_model = unwrap_model(self.accelerator, self.components.transformer)
self.components.transformer.config.update(origin_model.config)
self.components.transformer.config = dotmap.DotMap(self.components.transformer.config)
pipe = CogVideoXPipeline(
tokenizer=self.components.tokenizer,
text_encoder=self.components.text_encoder,
vae=self.components.vae,
transformer=self.components.transformer,
scheduler=self.components.scheduler,
)
return pipe
pass
register("cogvideox-t2v", "sft", CogVideoXT2VSftTrainer)

View File

@@ -8,13 +8,13 @@ from pydantic import BaseModel
class State(BaseModel):
model_config = {"arbitrary_types_allowed": True}
train_frames: int # user-defined training frames, **containing one image padding frame**
train_frames: int
train_height: int
train_width: int
transformer_config: Dict[str, Any] = None
weight_dtype: torch.dtype = torch.float32
weight_dtype: torch.dtype = torch.float32 # dtype for mixed precision training
num_trainable_parameters: int = 0
overwrote_max_train_steps: bool = False
num_update_steps_per_epoch: int = 0
@@ -25,3 +25,5 @@ class State(BaseModel):
validation_prompts: List[str] = []
validation_images: List[Path | None] = []
validation_videos: List[Path | None] = []
using_deepspeed: bool = False
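
State is a pydantic BaseModel, and model_config = {"arbitrary_types_allowed": True} is what lets non-pydantic types such as torch.dtype sit in its fields. Below is a trimmed, illustrative sketch of how the trainer could populate the new flag and the weight dtype; the field names come from the hunk above, the concrete values do not.

import torch
from typing import List
from pydantic import BaseModel

# Trimmed copy of the State model above, enough to show why arbitrary_types_allowed matters.
class State(BaseModel):
    model_config = {"arbitrary_types_allowed": True}  # lets torch.dtype pass pydantic validation

    train_frames: int
    train_height: int
    train_width: int
    weight_dtype: torch.dtype = torch.float32
    using_deepspeed: bool = False
    validation_prompts: List[str] = []

state = State(train_frames=81, train_height=768, train_width=1360)
state.weight_dtype = torch.bfloat16   # dtype used for mixed-precision training
state.using_deepspeed = True          # e.g. when accelerator.state.deepspeed_plugin is not None
print(state.weight_dtype, state.using_deepspeed)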

View File

@@ -8,31 +8,34 @@ MODEL_ARGS=(
--model_path "THUDM/CogVideoX1.5-5B-I2V"
--model_name "cogvideox1.5-i2v" # ["cogvideox-i2v"]
--model_type "i2v"
--training_type "lora"
--training_type "sft"
)
# Output Configuration
OUTPUT_ARGS=(
--output_dir "/path/to/output/dir"
--output_dir "/absolute/path/to/your/output_dir"
--report_to "tensorboard"
)
# Data Configuration
DATA_ARGS=(
--data_root "/path/to/data/dir"
--data_root "/absolute/path/to/your/data_root"
--caption_column "prompt.txt"
--video_column "videos.txt"
--image_column "images.txt"
--train_resolution "81x768x1360"
# --image_column "images.txt" # if this line is commented out, the first frame of each video is used as the image conditioning
--train_resolution "81x768x1360" # (frames x height x width), frames should be 8N+1
)
# Training Configuration
TRAIN_ARGS=(
--train_epochs 10
--seed 42
######### Please keep these consistent with the DeepSpeed config file ##########
--batch_size 1
--gradient_accumulation_steps 1
--mixed_precision "bf16" # ["no", "fp16"]
--seed 42
########################################################################
)
# System Configuration
@@ -44,22 +47,23 @@ SYSTEM_ARGS=(
# Checkpointing Configuration
CHECKPOINT_ARGS=(
--checkpointing_steps 200
--checkpointing_steps 5
--checkpointing_limit 10
--resume_from_checkpoint "/absolute/path/to/checkpoint_dir" # set this if you want to resume from a checkpoint; otherwise, comment this line out
)
# Validation Configuration
VALIDATION_ARGS=(
--do_validation False
--validation_dir "/path/to/validation/dir"
--validation_steps 400
--do_validation false # ["true", "false"]
--validation_dir "/absolute/path/to/validation_set"
--validation_steps 20 # should be a multiple of checkpointing_steps
--validation_prompts "prompts.txt"
--validation_images "images.txt"
--gen_fps 16
)
# Combine all arguments and launch training
accelerate launch train.py \
accelerate launch --config_file accelerate_config.yaml train.py \
"${MODEL_ARGS[@]}" \
"${OUTPUT_ARGS[@]}" \
"${DATA_ARGS[@]}" \

View File

@@ -8,30 +8,33 @@ MODEL_ARGS=(
--model_path "THUDM/CogVideoX1.5-5B"
--model_name "cogvideox1.5-t2v" # ["cogvideox-t2v"]
--model_type "t2v"
--training_type "lora"
--training_type "sft"
)
# Output Configuration
OUTPUT_ARGS=(
--output_dir "/path/to/output/dir"
--output_dir "/absolute/path/to/your/output_dir"
--report_to "tensorboard"
)
# Data Configuration
DATA_ARGS=(
--data_root "/path/to/data/dir"
--data_root "/absolute/path/to/your/data_root"
--caption_column "prompt.txt"
--video_column "videos.txt"
--train_resolution "81x768x1360"
--train_resolution "81x768x1360" # (frames x height x width), frames should be 8N+1
)
# Training Configuration
TRAIN_ARGS=(
--train_epochs 10
--seed 42
######### Please keep these consistent with the DeepSpeed config file ##########
--batch_size 1
--gradient_accumulation_steps 1
--mixed_precision "bf16" # ["no", "fp16"]
--seed 42
########################################################################
)
# System Configuration
@@ -43,21 +46,22 @@ SYSTEM_ARGS=(
# Checkpointing Configuration
CHECKPOINT_ARGS=(
--checkpointing_steps 200
--checkpointing_steps 5
--checkpointing_limit 10
--resume_from_checkpoint "/absolute/path/to/checkpoint_dir" # set this if you want to resume from a checkpoint; otherwise, comment this line out
)
# Validation Configuration
VALIDATION_ARGS=(
--do_validation False
--validation_dir "/path/to/validation/dir"
--validation_steps 400
--do_validation false # ["true", "false"]
--validation_dir "/absolute/path/to/validation_set"
--validation_steps 20 # should be a multiple of checkpointing_steps
--validation_prompts "prompts.txt"
--gen_fps 16
)
# Combine all arguments and launch training
accelerate launch train.py \
accelerate launch --config_file accelerate_config.yaml train.py \
"${MODEL_ARGS[@]}" \
"${OUTPUT_ARGS[@]}" \
"${DATA_ARGS[@]}" \

View File

@@ -84,6 +84,8 @@ class Trainer:
self._init_logging()
self._init_directories()
self.state.using_deepspeed = self.accelerator.state.deepspeed_plugin is not None
def _init_distributed(self):
logging_dir = Path(self.args.output_dir, "logs")
project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)
@@ -246,9 +248,9 @@ class Trainer:
self.components.transformer.add_adapter(transformer_lora_config)
self.__prepare_saving_loading_hooks(transformer_lora_config)
# Load components needed for training to GPU (except transformer),
# and cast them to the specified data type
self.__move_components_to_device(dtype=weight_dtype)
# Load components needed for training to GPU (except transformer), and cast them to the specified data type
ignore_list = ["transformer"] + self.UNLOAD_LIST
self.__move_components_to_device(dtype=weight_dtype, ignore_list=ignore_list)
if self.args.gradient_checkpointing:
self.components.transformer.enable_gradient_checkpointing()
@@ -406,6 +408,7 @@ class Trainer:
generator = generator.manual_seed(self.args.seed)
self.state.generator = generator
free_memory()
for epoch in range(first_epoch, self.args.train_epochs):
logger.debug(f"Starting epoch ({epoch + 1}/{self.args.train_epochs})")
@@ -497,18 +500,25 @@ class Trainer:
##### Initialize pipeline #####
pipe = self.initialize_pipeline()
# Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage
pipe.enable_model_cpu_offload(device=self.accelerator.device)
if self.state.using_deepspeed:
# Can't use model_cpu_offload with DeepSpeed,
# so we move all components of the pipeline to the device instead
# pipe.to(self.accelerator.device, dtype=self.state.weight_dtype)
self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=["transformer"])
else:
# If not using DeepSpeed, offload models to CPU to reduce memory usage
# (or use pipe.enable_sequential_cpu_offload() to reduce it even further)
pipe.enable_model_cpu_offload(device=self.accelerator.device)
# Convert all model weights to training dtype
# Note, this will change LoRA weights in self.components.transformer to training dtype, rather than keep them in fp32
pipe = pipe.to(dtype=self.state.weight_dtype)
# Convert all model weights to training dtype
# Note, this will change LoRA weights in self.components.transformer to training dtype, rather than keep them in fp32
pipe = pipe.to(dtype=self.state.weight_dtype)
#################################
all_processes_artifacts = []
for i in range(num_validation_samples):
if self.accelerator.deepspeed_plugin and self.accelerator.deepspeed_plugin.zero_stage != 3:
if self.state.using_deepspeed and self.accelerator.deepspeed_plugin.zero_stage != 3:
# Round-robin this validation sample to a single process (with ZeRO-3, every rank must join each forward pass, so samples are not split)
if i % accelerator.num_processes != accelerator.process_index:
continue
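
For context on the branch above: diffusers exposes two CPU-offload helpers on its pipelines, and the diff notes that model_cpu_offload cannot be combined with DeepSpeed, which is why the DeepSpeed path moves the pipeline components to the device manually instead. A minimal sketch of the two helpers on a standalone pipeline (the checkpoint id and dtype are illustrative; the trainer itself builds the pipeline from its in-memory components):

import torch
from diffusers import CogVideoXPipeline

# Illustrative standalone pipeline; the trainer assembles its pipeline from trained components.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# Option used by the non-DeepSpeed branch: keep only the active sub-model (text encoder,
# transformer, or VAE) on the GPU; whole sub-models are swapped, so the speed cost is moderate.
pipe.enable_model_cpu_offload(device="cuda")

# Lower peak memory but slower: offload at the sub-module level instead.
# pipe.enable_sequential_cpu_offload(device="cuda")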
@@ -539,7 +549,7 @@ class Trainer:
validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)
if (
self.accelerator.deepspeed_plugin is not None
self.state.using_deepspeed
and self.accelerator.deepspeed_plugin.zero_stage == 3
and not accelerator.is_main_process
):
@@ -599,22 +609,25 @@ class Trainer:
step=step,
)
pipe.remove_all_hooks()
del pipe
# Unload models except those needed for training
self.__move_components_to_cpu()
########## Clean up ##########
if self.state.using_deepspeed:
del pipe
# Unload models except those needed for training
self.__move_components_to_cpu(unload_list=self.UNLOAD_LIST)
else:
pipe.remove_all_hooks()
del pipe
# Reload the models needed for training (everything not in UNLOAD_LIST)
self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST)
self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)
# Load models except those not needed for training
self.__move_components_to_device(dtype=self.state.weight_dtype)
self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)
if self.accelerator.state.deepspeed_plugin is None:
# Change trainable weights back to fp32
# Cast trainable weights back to fp32, matching their dtype right after the model was prepared
cast_training_params([self.components.transformer], dtype=torch.float32)
accelerator.wait_for_everyone()
free_memory()
accelerator.wait_for_everyone()
################################
memory_statistics = get_memory_statistics()
logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
torch.cuda.reset_peak_memory_stats(accelerator.device)
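
After validation, only the trainable (LoRA) parameters are cast back to fp32; frozen weights stay in the low-precision inference dtype. This is roughly what diffusers' cast_training_params does; the sketch below uses a toy module as a stand-in, so nothing here is the repo's actual model.

import torch
import torch.nn as nn

def cast_trainable_params_to(module: nn.Module, dtype: torch.dtype = torch.float32) -> None:
    # Only parameters that receive gradients are upcast; frozen weights keep their
    # low-precision dtype, which is what keeps peak memory down.
    for param in module.parameters():
        if param.requires_grad:
            param.data = param.data.to(dtype)

# Toy stand-in: a frozen "backbone" plus a small trainable head (think LoRA weights).
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2)).to(torch.bfloat16)
model[0].requires_grad_(False)
cast_trainable_params_to(model)
print({name: p.dtype for name, p in model.named_parameters()})
# 0.* stays torch.bfloat16, 1.* becomes torch.float32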
@@ -668,25 +681,20 @@ class Trainer:
else:
raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")
def __move_components_to_device(self, dtype):
def __move_components_to_device(self, dtype, ignore_list: List[str] = []):
ignore_list = set(ignore_list)
components = self.components.model_dump()
for name, component in components.items():
if not isinstance(component, type) and hasattr(component, "to"):
if name in self.UNLOAD_LIST:
continue
if name not in ignore_list:
setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))
# We don't need to move transformer to device
# because we will prepare it in the `prepare_for_training()`
if name == "transformer":
continue
setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))
def __move_components_to_cpu(self):
def __move_components_to_cpu(self, unload_list: List[str] = []):
unload_list = set(unload_list)
components = self.components.model_dump()
for name, component in components.items():
if not isinstance(component, type) and hasattr(component, "to"):
if name in self.UNLOAD_LIST:
if name in unload_list:
setattr(self.components, name, component.to("cpu"))
def __prepare_saving_loading_hooks(self, transformer_lora_config):