mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
feat: support DeepSpeed ZeRO-3 and optimize peak memory usage
- Add DeepSpeed ZeRO-3 configuration support - Optimize memory usage during training - Rename training scripts to reflect ZeRO usage - Update related configuration files and trainers
This commit is contained in:
parent
2f275e82b5
commit
fdb9820949
@ -1,17 +1,11 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
|
||||
# gpu_ids: "0" # 0,1,2,3,4,5,6,7
|
||||
# num_processes: 1
|
||||
|
||||
gpu_ids: all
|
||||
num_processes: 8
|
||||
gpu_ids: "0,1,2,4"
|
||||
num_processes: 4 # should be the same as the number of GPUs
|
||||
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
deepspeed_config_file: /path/to/your/configs/zero2.yaml
|
||||
# deepspeed_config_file: /path/to/your/configs/zero2_offload.yaml
|
||||
# deepspeed_config_file: /path/to/your/configs/zero3.yaml
|
||||
# deepspeed_config_file: /path/to/your/configs/zero3_offload.yaml
|
||||
deepspeed_config_file: /absolute/path/to/your/deepspeed_config.yaml # e.g. /home/user/cogvideo/finetune/configs/zero2.yaml
|
||||
zero3_init_flag: false
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
|
@ -1,27 +1,9 @@
|
||||
import dotmap
|
||||
from diffusers import CogVideoXImageToVideoPipeline
|
||||
from typing_extensions import override
|
||||
|
||||
from finetune.utils import unwrap_model
|
||||
|
||||
from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
|
||||
from ..utils import register
|
||||
|
||||
|
||||
class CogVideoXI2VSftTrainer(CogVideoXI2VLoraTrainer):
|
||||
@override
|
||||
def initialize_pipeline(self) -> CogVideoXImageToVideoPipeline:
|
||||
origin_model = unwrap_model(self.accelerator, self.components.transformer)
|
||||
self.components.transformer.config.update(origin_model.config)
|
||||
self.components.transformer.config = dotmap.DotMap(self.components.transformer.config)
|
||||
pipe = CogVideoXImageToVideoPipeline(
|
||||
tokenizer=self.components.tokenizer,
|
||||
text_encoder=self.components.text_encoder,
|
||||
vae=self.components.vae,
|
||||
transformer=self.components.transformer,
|
||||
scheduler=self.components.scheduler,
|
||||
)
|
||||
return pipe
|
||||
pass
|
||||
|
||||
|
||||
register("cogvideox-i2v", "sft", CogVideoXI2VSftTrainer)
|
||||
|
@ -1,27 +1,9 @@
|
||||
import dotmap
|
||||
from diffusers import CogVideoXPipeline
|
||||
from typing_extensions import override
|
||||
|
||||
from finetune.utils import unwrap_model
|
||||
|
||||
from ..cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer
|
||||
from ..utils import register
|
||||
|
||||
|
||||
class CogVideoXT2VSftTrainer(CogVideoXT2VLoraTrainer):
|
||||
@override
|
||||
def initialize_pipeline(self) -> CogVideoXPipeline:
|
||||
origin_model = unwrap_model(self.accelerator, self.components.transformer)
|
||||
self.components.transformer.config.update(origin_model.config)
|
||||
self.components.transformer.config = dotmap.DotMap(self.components.transformer.config)
|
||||
pipe = CogVideoXPipeline(
|
||||
tokenizer=self.components.tokenizer,
|
||||
text_encoder=self.components.text_encoder,
|
||||
vae=self.components.vae,
|
||||
transformer=self.components.transformer,
|
||||
scheduler=self.components.scheduler,
|
||||
)
|
||||
return pipe
|
||||
pass
|
||||
|
||||
|
||||
register("cogvideox-t2v", "sft", CogVideoXT2VSftTrainer)
|
||||
|
@ -8,13 +8,13 @@ from pydantic import BaseModel
|
||||
class State(BaseModel):
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
train_frames: int # user-defined training frames, **containing one image padding frame**
|
||||
train_frames: int
|
||||
train_height: int
|
||||
train_width: int
|
||||
|
||||
transformer_config: Dict[str, Any] = None
|
||||
|
||||
weight_dtype: torch.dtype = torch.float32
|
||||
weight_dtype: torch.dtype = torch.float32 # dtype for mixed precision training
|
||||
num_trainable_parameters: int = 0
|
||||
overwrote_max_train_steps: bool = False
|
||||
num_update_steps_per_epoch: int = 0
|
||||
@ -25,3 +25,5 @@ class State(BaseModel):
|
||||
validation_prompts: List[str] = []
|
||||
validation_images: List[Path | None] = []
|
||||
validation_videos: List[Path | None] = []
|
||||
|
||||
using_deepspeed: bool = False
|
||||
|
@ -8,31 +8,34 @@ MODEL_ARGS=(
|
||||
--model_path "THUDM/CogVideoX1.5-5B-I2V"
|
||||
--model_name "cogvideox1.5-i2v" # ["cogvideox-i2v"]
|
||||
--model_type "i2v"
|
||||
--training_type "lora"
|
||||
--training_type "sft"
|
||||
)
|
||||
|
||||
# Output Configuration
|
||||
OUTPUT_ARGS=(
|
||||
--output_dir "/path/to/output/dir"
|
||||
--output_dir "/absolute/path/to/your/output_dir"
|
||||
--report_to "tensorboard"
|
||||
)
|
||||
|
||||
# Data Configuration
|
||||
DATA_ARGS=(
|
||||
--data_root "/path/to/data/dir"
|
||||
--data_root "/absolute/path/to/your/data_root"
|
||||
--caption_column "prompt.txt"
|
||||
--video_column "videos.txt"
|
||||
--image_column "images.txt"
|
||||
--train_resolution "81x768x1360"
|
||||
# --image_column "images.txt" # comment this line will use first frame of video as image conditioning
|
||||
--train_resolution "81x768x1360" # (frames x height x width), frames should be 8N+1
|
||||
)
|
||||
|
||||
# Training Configuration
|
||||
TRAIN_ARGS=(
|
||||
--train_epochs 10
|
||||
--seed 42
|
||||
|
||||
######### Please keep consistent with deepspeed config file ##########
|
||||
--batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--mixed_precision "bf16" # ["no", "fp16"]
|
||||
--seed 42
|
||||
########################################################################
|
||||
)
|
||||
|
||||
# System Configuration
|
||||
@ -44,22 +47,23 @@ SYSTEM_ARGS=(
|
||||
|
||||
# Checkpointing Configuration
|
||||
CHECKPOINT_ARGS=(
|
||||
--checkpointing_steps 200
|
||||
--checkpointing_steps 5
|
||||
--checkpointing_limit 10
|
||||
--resume_from_checkpoint "/absolute/path/to/checkpoint_dir" # if you want to resume from a checkpoint, otherwise, comment this line
|
||||
)
|
||||
|
||||
# Validation Configuration
|
||||
VALIDATION_ARGS=(
|
||||
--do_validation False
|
||||
--validation_dir "/path/to/validation/dir"
|
||||
--validation_steps 400
|
||||
--do_validation false # ["true", "false"]
|
||||
--validation_dir "/absolute/path/to/validation_set"
|
||||
--validation_steps 20 # should be multiple of checkpointing_steps
|
||||
--validation_prompts "prompts.txt"
|
||||
--validation_images "images.txt"
|
||||
--gen_fps 16
|
||||
)
|
||||
|
||||
# Combine all arguments and launch training
|
||||
accelerate launch train.py \
|
||||
accelerate launch --config_file accelerate_config.yaml train.py \
|
||||
"${MODEL_ARGS[@]}" \
|
||||
"${OUTPUT_ARGS[@]}" \
|
||||
"${DATA_ARGS[@]}" \
|
@ -8,30 +8,33 @@ MODEL_ARGS=(
|
||||
--model_path "THUDM/CogVideoX1.5-5B"
|
||||
--model_name "cogvideox1.5-t2v" # ["cogvideox-t2v"]
|
||||
--model_type "t2v"
|
||||
--training_type "lora"
|
||||
--training_type "sft"
|
||||
)
|
||||
|
||||
# Output Configuration
|
||||
OUTPUT_ARGS=(
|
||||
--output_dir "/path/to/output/dir"
|
||||
--output_dir "/absolute/path/to/your/output_dir"
|
||||
--report_to "tensorboard"
|
||||
)
|
||||
|
||||
# Data Configuration
|
||||
DATA_ARGS=(
|
||||
--data_root "/path/to/data/dir"
|
||||
--data_root "/absolute/path/to/your/data_root"
|
||||
--caption_column "prompt.txt"
|
||||
--video_column "videos.txt"
|
||||
--train_resolution "81x768x1360"
|
||||
--train_resolution "81x768x1360" # (frames x height x width), frames should be 8N+1
|
||||
)
|
||||
|
||||
# Training Configuration
|
||||
TRAIN_ARGS=(
|
||||
--train_epochs 10
|
||||
--seed 42
|
||||
|
||||
######### Please keep consistent with deepspeed config file ##########
|
||||
--batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--mixed_precision "bf16" # ["no", "fp16"]
|
||||
--seed 42
|
||||
########################################################################
|
||||
)
|
||||
|
||||
# System Configuration
|
||||
@ -43,21 +46,22 @@ SYSTEM_ARGS=(
|
||||
|
||||
# Checkpointing Configuration
|
||||
CHECKPOINT_ARGS=(
|
||||
--checkpointing_steps 200
|
||||
--checkpointing_steps 5
|
||||
--checkpointing_limit 10
|
||||
--resume_from_checkpoint "/absolute/path/to/checkpoint_dir" # if you want to resume from a checkpoint, otherwise, comment this line
|
||||
)
|
||||
|
||||
# Validation Configuration
|
||||
VALIDATION_ARGS=(
|
||||
--do_validation False
|
||||
--validation_dir "/path/to/validation/dir"
|
||||
--validation_steps 400
|
||||
--do_validation false # ["true", "false"]
|
||||
--validation_dir "/absolute/path/to/validation_set"
|
||||
--validation_steps 20 # should be multiple of checkpointing_steps
|
||||
--validation_prompts "prompts.txt"
|
||||
--gen_fps 16
|
||||
)
|
||||
|
||||
# Combine all arguments and launch training
|
||||
accelerate launch train.py \
|
||||
accelerate launch --config_file accelerate_config.yaml train.py \
|
||||
"${MODEL_ARGS[@]}" \
|
||||
"${OUTPUT_ARGS[@]}" \
|
||||
"${DATA_ARGS[@]}" \
|
@ -84,6 +84,8 @@ class Trainer:
|
||||
self._init_logging()
|
||||
self._init_directories()
|
||||
|
||||
self.state.using_deepspeed = self.accelerator.state.deepspeed_plugin is not None
|
||||
|
||||
def _init_distributed(self):
|
||||
logging_dir = Path(self.args.output_dir, "logs")
|
||||
project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)
|
||||
@ -246,9 +248,9 @@ class Trainer:
|
||||
self.components.transformer.add_adapter(transformer_lora_config)
|
||||
self.__prepare_saving_loading_hooks(transformer_lora_config)
|
||||
|
||||
# Load components needed for training to GPU (except transformer),
|
||||
# and cast them to the specified data type
|
||||
self.__move_components_to_device(dtype=weight_dtype)
|
||||
# Load components needed for training to GPU (except transformer), and cast them to the specified data type
|
||||
ignore_list = ["transformer"] + self.UNLOAD_LIST
|
||||
self.__move_components_to_device(dtype=weight_dtype, ignore_list=ignore_list)
|
||||
|
||||
if self.args.gradient_checkpointing:
|
||||
self.components.transformer.enable_gradient_checkpointing()
|
||||
@ -406,6 +408,7 @@ class Trainer:
|
||||
generator = generator.manual_seed(self.args.seed)
|
||||
self.state.generator = generator
|
||||
|
||||
free_memory()
|
||||
for epoch in range(first_epoch, self.args.train_epochs):
|
||||
logger.debug(f"Starting epoch ({epoch + 1}/{self.args.train_epochs})")
|
||||
|
||||
@ -497,6 +500,13 @@ class Trainer:
|
||||
##### Initialize pipeline #####
|
||||
pipe = self.initialize_pipeline()
|
||||
|
||||
if self.state.using_deepspeed:
|
||||
# Can't using model_cpu_offload in deepspeed,
|
||||
# so we need to move all components in pipe to device
|
||||
# pipe.to(self.accelerator.device, dtype=self.state.weight_dtype)
|
||||
self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=["transformer"])
|
||||
else:
|
||||
# if not using deepspeed, use model_cpu_offload to further reduce memory usage
|
||||
# Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage
|
||||
pipe.enable_model_cpu_offload(device=self.accelerator.device)
|
||||
|
||||
@ -508,7 +518,7 @@ class Trainer:
|
||||
|
||||
all_processes_artifacts = []
|
||||
for i in range(num_validation_samples):
|
||||
if self.accelerator.deepspeed_plugin and self.accelerator.deepspeed_plugin.zero_stage != 3:
|
||||
if self.state.using_deepspeed and self.accelerator.deepspeed_plugin.zero_stage != 3:
|
||||
# Skip current validation on all processes but one
|
||||
if i % accelerator.num_processes != accelerator.process_index:
|
||||
continue
|
||||
@ -539,7 +549,7 @@ class Trainer:
|
||||
validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)
|
||||
|
||||
if (
|
||||
self.accelerator.deepspeed_plugin is not None
|
||||
self.state.using_deepspeed
|
||||
and self.accelerator.deepspeed_plugin.zero_stage == 3
|
||||
and not accelerator.is_main_process
|
||||
):
|
||||
@ -599,22 +609,25 @@ class Trainer:
|
||||
step=step,
|
||||
)
|
||||
|
||||
pipe.remove_all_hooks()
|
||||
########## Clean up ##########
|
||||
if self.state.using_deepspeed:
|
||||
del pipe
|
||||
# Unload models except those needed for training
|
||||
self.__move_components_to_cpu()
|
||||
|
||||
self.__move_components_to_cpu(unload_list=self.UNLOAD_LIST)
|
||||
else:
|
||||
pipe.remove_all_hooks()
|
||||
del pipe
|
||||
# Load models except those not needed for training
|
||||
self.__move_components_to_device(dtype=self.state.weight_dtype)
|
||||
self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST)
|
||||
self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)
|
||||
|
||||
if self.accelerator.state.deepspeed_plugin is None:
|
||||
# Change trainable weights back to fp32
|
||||
# Change trainable weights back to fp32 to keep with dtype after prepare the model
|
||||
cast_training_params([self.components.transformer], dtype=torch.float32)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
free_memory()
|
||||
accelerator.wait_for_everyone()
|
||||
################################
|
||||
|
||||
memory_statistics = get_memory_statistics()
|
||||
logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
|
||||
torch.cuda.reset_peak_memory_stats(accelerator.device)
|
||||
@ -668,25 +681,20 @@ class Trainer:
|
||||
else:
|
||||
raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")
|
||||
|
||||
def __move_components_to_device(self, dtype):
|
||||
def __move_components_to_device(self, dtype, ignore_list: List[str] = []):
|
||||
ignore_list = set(ignore_list)
|
||||
components = self.components.model_dump()
|
||||
for name, component in components.items():
|
||||
if not isinstance(component, type) and hasattr(component, "to"):
|
||||
if name in self.UNLOAD_LIST:
|
||||
continue
|
||||
|
||||
# We don't need to move transformer to device
|
||||
# because we will prepare it in the `prepare_for_training()`
|
||||
if name == "transformer":
|
||||
continue
|
||||
|
||||
if name not in ignore_list:
|
||||
setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))
|
||||
|
||||
def __move_components_to_cpu(self):
|
||||
def __move_components_to_cpu(self, unload_list: List[str] = []):
|
||||
unload_list = set(unload_list)
|
||||
components = self.components.model_dump()
|
||||
for name, component in components.items():
|
||||
if not isinstance(component, type) and hasattr(component, "to"):
|
||||
if name in self.UNLOAD_LIST:
|
||||
if name in unload_list:
|
||||
setattr(self.components, name, component.to("cpu"))
|
||||
|
||||
def __prepare_saving_loading_hooks(self, transformer_lora_config):
|
||||
|
Loading…
x
Reference in New Issue
Block a user