From fdb9820949ec0ef9424e0198aa66207c44021bd8 Mon Sep 17 00:00:00 2001
From: OleehyO
Date: Sun, 12 Jan 2025 05:33:56 +0000
Subject: [PATCH] feat: support DeepSpeed ZeRO-3 and optimize peak memory usage

- Add DeepSpeed ZeRO-3 configuration support
- Optimize memory usage during training
- Rename training scripts to reflect ZeRO usage
- Update related configuration files and trainers
---
 finetune/accelerate_config.yaml               | 14 +---
 finetune/models/cogvideox_i2v/sft_trainer.py  | 20 +----
 finetune/models/cogvideox_t2v/sft_trainer.py  | 20 +----
 finetune/schemas/state.py                     |  6 +-
 ...elerate_train_i2v.sh => train_zero_i2v.sh} | 28 ++++---
 ...elerate_train_t2v.sh => train_zero_t2v.sh} | 26 ++++---
 finetune/trainer.py                           | 76 ++++++++++---------
 7 files changed, 83 insertions(+), 107 deletions(-)
 rename finetune/{accelerate_train_i2v.sh => train_zero_i2v.sh} (55%)
 rename finetune/{accelerate_train_t2v.sh => train_zero_t2v.sh} (57%)

diff --git a/finetune/accelerate_config.yaml b/finetune/accelerate_config.yaml
index ee3aed5..5d99cf4 100644
--- a/finetune/accelerate_config.yaml
+++ b/finetune/accelerate_config.yaml
@@ -1,17 +1,11 @@
 compute_environment: LOCAL_MACHINE

-# gpu_ids: "0" # 0,1,2,3,4,5,6,7
-# num_processes: 1
-
-gpu_ids: all
-num_processes: 8
+gpu_ids: "0,1,2,4"
+num_processes: 4  # should be the same as the number of GPUs

 debug: false
 deepspeed_config:
-  deepspeed_config_file: /path/to/your/configs/zero2.yaml
-  # deepspeed_config_file: /path/to/your/configs/zero2_offload.yaml
-  # deepspeed_config_file: /path/to/your/configs/zero3.yaml
-  # deepspeed_config_file: /path/to/your/configs/zero3_offload.yaml
+  deepspeed_config_file: /absolute/path/to/your/deepspeed_config.yaml  # e.g. /home/user/cogvideo/finetune/configs/zero2.yaml
   zero3_init_flag: false
 distributed_type: DEEPSPEED
 downcast_bf16: 'no'
@@ -24,4 +18,4 @@ same_network: true
 tpu_env: []
 tpu_use_cluster: false
 tpu_use_sudo: false
-use_cpu: false
+use_cpu: false
\ No newline at end of file
diff --git a/finetune/models/cogvideox_i2v/sft_trainer.py b/finetune/models/cogvideox_i2v/sft_trainer.py
index 0cb5f12..b55bee8 100644
--- a/finetune/models/cogvideox_i2v/sft_trainer.py
+++ b/finetune/models/cogvideox_i2v/sft_trainer.py
@@ -1,27 +1,9 @@
-import dotmap
-from diffusers import CogVideoXImageToVideoPipeline
-from typing_extensions import override
-
-from finetune.utils import unwrap_model
-
 from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
 from ..utils import register


 class CogVideoXI2VSftTrainer(CogVideoXI2VLoraTrainer):
-    @override
-    def initialize_pipeline(self) -> CogVideoXImageToVideoPipeline:
-        origin_model = unwrap_model(self.accelerator, self.components.transformer)
-        self.components.transformer.config.update(origin_model.config)
-        self.components.transformer.config = dotmap.DotMap(self.components.transformer.config)
-        pipe = CogVideoXImageToVideoPipeline(
-            tokenizer=self.components.tokenizer,
-            text_encoder=self.components.text_encoder,
-            vae=self.components.vae,
-            transformer=self.components.transformer,
-            scheduler=self.components.scheduler,
-        )
-        return pipe
+    pass


 register("cogvideox-i2v", "sft", CogVideoXI2VSftTrainer)
diff --git a/finetune/models/cogvideox_t2v/sft_trainer.py b/finetune/models/cogvideox_t2v/sft_trainer.py
index 85b80fa..12b239e 100644
--- a/finetune/models/cogvideox_t2v/sft_trainer.py
+++ b/finetune/models/cogvideox_t2v/sft_trainer.py
@@ -1,27 +1,9 @@
-import dotmap
-from diffusers import CogVideoXPipeline
-from typing_extensions import override
-
-from finetune.utils import unwrap_model
-
 from ..cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer
 from ..utils import register


 class CogVideoXT2VSftTrainer(CogVideoXT2VLoraTrainer):
-    @override
-    def initialize_pipeline(self) -> CogVideoXPipeline:
-        origin_model = unwrap_model(self.accelerator, self.components.transformer)
-        self.components.transformer.config.update(origin_model.config)
-        self.components.transformer.config = dotmap.DotMap(self.components.transformer.config)
-        pipe = CogVideoXPipeline(
-            tokenizer=self.components.tokenizer,
-            text_encoder=self.components.text_encoder,
-            vae=self.components.vae,
-            transformer=self.components.transformer,
-            scheduler=self.components.scheduler,
-        )
-        return pipe
+    pass


 register("cogvideox-t2v", "sft", CogVideoXT2VSftTrainer)
diff --git a/finetune/schemas/state.py b/finetune/schemas/state.py
index 315185c..36fa0fa 100644
--- a/finetune/schemas/state.py
+++ b/finetune/schemas/state.py
@@ -8,13 +8,13 @@ from pydantic import BaseModel
 class State(BaseModel):
     model_config = {"arbitrary_types_allowed": True}

-    train_frames: int  # user-defined training frames, **containing one image padding frame**
+    train_frames: int
     train_height: int
     train_width: int

     transformer_config: Dict[str, Any] = None

-    weight_dtype: torch.dtype = torch.float32
+    weight_dtype: torch.dtype = torch.float32  # dtype for mixed precision training
     num_trainable_parameters: int = 0
     overwrote_max_train_steps: bool = False
     num_update_steps_per_epoch: int = 0
@@ -25,3 +25,5 @@ class State(BaseModel):
     validation_prompts: List[str] = []
     validation_images: List[Path | None] = []
     validation_videos: List[Path | None] = []
+
+    using_deepspeed: bool = False
diff --git a/finetune/accelerate_train_i2v.sh b/finetune/train_zero_i2v.sh
similarity index 55%
rename from finetune/accelerate_train_i2v.sh
rename to finetune/train_zero_i2v.sh
index fe6280c..036761e 100644
--- a/finetune/accelerate_train_i2v.sh
+++ b/finetune/train_zero_i2v.sh
@@ -8,31 +8,34 @@ MODEL_ARGS=(
     --model_path "THUDM/CogVideoX1.5-5B-I2V"
     --model_name "cogvideox1.5-i2v"  # ["cogvideox-i2v"]
     --model_type "i2v"
-    --training_type "lora"
+    --training_type "sft"
 )

 # Output Configuration
 OUTPUT_ARGS=(
-    --output_dir "/path/to/output/dir"
+    --output_dir "/absolute/path/to/your/output_dir"
     --report_to "tensorboard"
 )

 # Data Configuration
 DATA_ARGS=(
-    --data_root "/path/to/data/dir"
+    --data_root "/absolute/path/to/your/data_root"
     --caption_column "prompt.txt"
     --video_column "videos.txt"
-    --image_column "images.txt"
-    --train_resolution "81x768x1360"
+    # --image_column "images.txt"  # if commented out, the first frame of each video is used as the image conditioning
+    --train_resolution "81x768x1360"  # (frames x height x width), frames should be 8N+1
 )

 # Training Configuration
 TRAIN_ARGS=(
     --train_epochs 10
+    --seed 42
+
+    ######### Please keep consistent with deepspeed config file ##########
     --batch_size 1
     --gradient_accumulation_steps 1
     --mixed_precision "bf16"  # ["no", "fp16"]
-    --seed 42
+    ########################################################################
 )

 # System Configuration
@@ -44,26 +47,27 @@ SYSTEM_ARGS=(

 # Checkpointing Configuration
 CHECKPOINT_ARGS=(
-    --checkpointing_steps 200
+    --checkpointing_steps 5
     --checkpointing_limit 10
+    --resume_from_checkpoint "/absolute/path/to/checkpoint_dir"  # set this to resume from a checkpoint; otherwise, comment this line out
 )

 # Validation Configuration
 VALIDATION_ARGS=(
-    --do_validation False
-    --validation_dir "/path/to/validation/dir"
-    --validation_steps 400
+    --do_validation false  # ["true", "false"]
+    --validation_dir "/absolute/path/to/validation_set"
+    --validation_steps 20  # should be a multiple of checkpointing_steps
     --validation_prompts "prompts.txt"
     --validation_images "images.txt"
     --gen_fps 16
 )

 # Combine all arguments and launch training
-accelerate launch train.py \
+accelerate launch --config_file accelerate_config.yaml train.py \
     "${MODEL_ARGS[@]}" \
     "${OUTPUT_ARGS[@]}" \
     "${DATA_ARGS[@]}" \
     "${TRAIN_ARGS[@]}" \
     "${SYSTEM_ARGS[@]}" \
     "${CHECKPOINT_ARGS[@]}" \
-    "${VALIDATION_ARGS[@]}"
\ No newline at end of file
+    "${VALIDATION_ARGS[@]}"
diff --git a/finetune/accelerate_train_t2v.sh b/finetune/train_zero_t2v.sh
similarity index 57%
rename from finetune/accelerate_train_t2v.sh
rename to finetune/train_zero_t2v.sh
index ce2c2bd..75fb1d6 100644
--- a/finetune/accelerate_train_t2v.sh
+++ b/finetune/train_zero_t2v.sh
@@ -8,30 +8,33 @@ MODEL_ARGS=(
     --model_path "THUDM/CogVideoX1.5-5B"
     --model_name "cogvideox1.5-t2v"  # ["cogvideox-t2v"]
     --model_type "t2v"
-    --training_type "lora"
+    --training_type "sft"
 )

 # Output Configuration
 OUTPUT_ARGS=(
-    --output_dir "/path/to/output/dir"
+    --output_dir "/absolute/path/to/your/output_dir"
     --report_to "tensorboard"
 )

 # Data Configuration
 DATA_ARGS=(
-    --data_root "/path/to/data/dir"
+    --data_root "/absolute/path/to/your/data_root"
     --caption_column "prompt.txt"
     --video_column "videos.txt"
-    --train_resolution "81x768x1360"
+    --train_resolution "81x768x1360"  # (frames x height x width), frames should be 8N+1
 )

 # Training Configuration
 TRAIN_ARGS=(
     --train_epochs 10
+    --seed 42
+
+    ######### Please keep consistent with deepspeed config file ##########
     --batch_size 1
     --gradient_accumulation_steps 1
     --mixed_precision "bf16"  # ["no", "fp16"]
-    --seed 42
+    ########################################################################
 )

 # System Configuration
@@ -43,25 +46,26 @@ SYSTEM_ARGS=(

 # Checkpointing Configuration
 CHECKPOINT_ARGS=(
-    --checkpointing_steps 200
+    --checkpointing_steps 5
     --checkpointing_limit 10
+    --resume_from_checkpoint "/absolute/path/to/checkpoint_dir"  # set this to resume from a checkpoint; otherwise, comment this line out
 )

 # Validation Configuration
 VALIDATION_ARGS=(
-    --do_validation False
-    --validation_dir "/path/to/validation/dir"
-    --validation_steps 400
+    --do_validation false  # ["true", "false"]
+    --validation_dir "/absolute/path/to/validation_set"
+    --validation_steps 20  # should be a multiple of checkpointing_steps
     --validation_prompts "prompts.txt"
     --gen_fps 16
 )

 # Combine all arguments and launch training
-accelerate launch train.py \
+accelerate launch --config_file accelerate_config.yaml train.py \
     "${MODEL_ARGS[@]}" \
     "${OUTPUT_ARGS[@]}" \
     "${DATA_ARGS[@]}" \
     "${TRAIN_ARGS[@]}" \
     "${SYSTEM_ARGS[@]}" \
     "${CHECKPOINT_ARGS[@]}" \
-    "${VALIDATION_ARGS[@]}"
\ No newline at end of file
+    "${VALIDATION_ARGS[@]}"
diff --git a/finetune/trainer.py b/finetune/trainer.py
index e3ec8e6..ee808f9 100644
--- a/finetune/trainer.py
+++ b/finetune/trainer.py
@@ -84,6 +84,8 @@ class Trainer:
         self._init_logging()
         self._init_directories()

+        self.state.using_deepspeed = self.accelerator.state.deepspeed_plugin is not None
+
     def _init_distributed(self):
         logging_dir = Path(self.args.output_dir, "logs")
         project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)
@@ -246,9 +248,9 @@
             self.components.transformer.add_adapter(transformer_lora_config)
             self.__prepare_saving_loading_hooks(transformer_lora_config)

-        # Load components needed for training to GPU (except transformer),
-        # and cast them to the specified data type
-        self.__move_components_to_device(dtype=weight_dtype)
+        # Load components needed for training to GPU (except transformer), and cast them to the specified data type
+        ignore_list = ["transformer"] + self.UNLOAD_LIST
+        self.__move_components_to_device(dtype=weight_dtype, ignore_list=ignore_list)

         if self.args.gradient_checkpointing:
             self.components.transformer.enable_gradient_checkpointing()
@@ -406,6 +408,7 @@
             generator = generator.manual_seed(self.args.seed)
         self.state.generator = generator

+        free_memory()
         for epoch in range(first_epoch, self.args.train_epochs):
             logger.debug(f"Starting epoch ({epoch + 1}/{self.args.train_epochs})")

@@ -497,18 +500,25 @@
         ##### Initialize pipeline #####
         pipe = self.initialize_pipeline()

-        # Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage
-        pipe.enable_model_cpu_offload(device=self.accelerator.device)
+        if self.state.using_deepspeed:
+            # Can't use model_cpu_offload with DeepSpeed,
+            # so we need to move all components of the pipe to the device ourselves
+            # pipe.to(self.accelerator.device, dtype=self.state.weight_dtype)
+            self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=["transformer"])
+        else:
+            # If not using DeepSpeed, use model_cpu_offload to reduce memory usage,
+            # or use pipe.enable_sequential_cpu_offload() to reduce it even further
+            pipe.enable_model_cpu_offload(device=self.accelerator.device)

-        # Convert all model weights to training dtype
-        # Note, this will change LoRA weights in self.components.transformer to training dtype, rather than keep them in fp32
-        pipe = pipe.to(dtype=self.state.weight_dtype)
+            # Convert all model weights to training dtype
+            # Note: this will change LoRA weights in self.components.transformer to the training dtype, rather than keeping them in fp32
+            pipe = pipe.to(dtype=self.state.weight_dtype)

         #################################

         all_processes_artifacts = []
         for i in range(num_validation_samples):
-            if self.accelerator.deepspeed_plugin and self.accelerator.deepspeed_plugin.zero_stage != 3:
+            if self.state.using_deepspeed and self.accelerator.deepspeed_plugin.zero_stage != 3:
                 # Skip current validation on all processes but one
                 if i % accelerator.num_processes != accelerator.process_index:
                     continue
@@ -539,7 +549,7 @@
                 validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)

                 if (
-                    self.accelerator.deepspeed_plugin is not None
+                    self.state.using_deepspeed
                     and self.accelerator.deepspeed_plugin.zero_stage == 3
                     and not accelerator.is_main_process
                 ):
@@ -599,22 +609,25 @@
                         step=step,
                     )

-        pipe.remove_all_hooks()
-        del pipe
-        # Unload models except those needed for training
-        self.__move_components_to_cpu()
+        ########## Clean up ##########
+        if self.state.using_deepspeed:
+            del pipe
+            # Unload models except those needed for training
+            self.__move_components_to_cpu(unload_list=self.UNLOAD_LIST)
+        else:
+            pipe.remove_all_hooks()
+            del pipe
+            # Load back the models needed for training (those not in UNLOAD_LIST)
+            self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST)
+            self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)

-        # Load models except those not needed for training
-        self.__move_components_to_device(dtype=self.state.weight_dtype)
-        self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)
-
-        if self.accelerator.state.deepspeed_plugin is None:
-            # Change trainable weights back to fp32
+            # Change trainable weights back to fp32 to keep consistent with the dtype after preparing the model
             cast_training_params([self.components.transformer], dtype=torch.float32)

-        accelerator.wait_for_everyone()
-        free_memory()
+        accelerator.wait_for_everyone()
+        ################################
+
         memory_statistics = get_memory_statistics()
         logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
         torch.cuda.reset_peak_memory_stats(accelerator.device)
@@ -668,25 +681,20 @@
         else:
             raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")

-    def __move_components_to_device(self, dtype):
+    def __move_components_to_device(self, dtype, ignore_list: List[str] = []):
+        ignore_list = set(ignore_list)
         components = self.components.model_dump()
         for name, component in components.items():
             if not isinstance(component, type) and hasattr(component, "to"):
-                if name in self.UNLOAD_LIST:
-                    continue
+                if name not in ignore_list:
+                    setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))

-                # We don't need to move transformer to device
-                # because we will prepare it in the `prepare_for_training()`
-                if name == "transformer":
-                    continue
-
-                setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))
-
-    def __move_components_to_cpu(self):
+    def __move_components_to_cpu(self, unload_list: List[str] = []):
+        unload_list = set(unload_list)
         components = self.components.model_dump()
         for name, component in components.items():
             if not isinstance(component, type) and hasattr(component, "to"):
-                if name in self.UNLOAD_LIST:
+                if name in unload_list:
                     setattr(self.components, name, component.to("cpu"))

     def __prepare_saving_loading_hooks(self, transformer_lora_config):
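
Note: accelerate_config.yaml above only points at an external DeepSpeed config file. For reference, a minimal ZeRO-3 file of that kind could look like the sketch below. The file name and all values are illustrative assumptions, not the repository's actual configs/zero3.yaml; in particular, train_micro_batch_size_per_gpu, gradient_accumulation_steps, and the bf16 flag must stay in sync with --batch_size, --gradient_accumulation_steps, and --mixed_precision in TRAIN_ARGS, as the launch scripts request.

# configs/zero3.yaml -- hypothetical minimal ZeRO-3 example, adjust to your setup
bf16:
  enabled: true                       # matches --mixed_precision "bf16"
zero_optimization:
  stage: 3
  overlap_comm: true
  contiguous_gradients: true
  stage3_gather_16bit_weights_on_model_save: true  # gather full half-precision weights when saving checkpoints
  # For a ZeRO-3 offload variant, additionally set:
  # offload_optimizer:
  #   device: cpu
  # offload_param:
  #   device: cpu
gradient_accumulation_steps: 1        # matches --gradient_accumulation_steps
gradient_clipping: 1.0
train_micro_batch_size_per_gpu: 1     # matches --batch_size

With such a file saved and referenced from deepspeed_config_file, training is started from the finetune directory with bash train_zero_t2v.sh (or train_zero_i2v.sh), which runs accelerate launch --config_file accelerate_config.yaml train.py as shown above.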