Mirror of https://github.com/THUDM/CogVideo.git
feat: support DeepSpeed ZeRO-3 and optimize peak memory usage
- Add DeepSpeed ZeRO-3 configuration support
- Optimize memory usage during training
- Rename training scripts to reflect ZeRO usage
- Update related configuration files and trainers
commit fdb9820949
parent 2f275e82b5
@@ -1,17 +1,11 @@
 compute_environment: LOCAL_MACHINE
 
-# gpu_ids: "0" # 0,1,2,3,4,5,6,7
-# num_processes: 1
-
-gpu_ids: all
-num_processes: 8
+gpu_ids: "0,1,2,4"
+num_processes: 4 # should be the same as the number of GPUs
 
 debug: false
 deepspeed_config:
-  deepspeed_config_file: /path/to/your/configs/zero2.yaml
-  # deepspeed_config_file: /path/to/your/configs/zero2_offload.yaml
-  # deepspeed_config_file: /path/to/your/configs/zero3.yaml
-  # deepspeed_config_file: /path/to/your/configs/zero3_offload.yaml
+  deepspeed_config_file: /absolute/path/to/your/deepspeed_config.yaml # e.g. /home/user/cogvideo/finetune/configs/zero2.yaml
   zero3_init_flag: false
 distributed_type: DEEPSPEED
 downcast_bf16: 'no'
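The inline comment above ties `num_processes` to the number of entries in `gpu_ids`. A quick sanity check, a minimal sketch assuming PyYAML and a hypothetical config path (not part of the commit), can catch a mismatch before `accelerate launch` does:

```python
# Sketch only: verify num_processes matches the gpu_ids list in an accelerate
# config. The path is hypothetical; "all" is skipped since its size is unknown here.
import yaml

with open("finetune/accelerate_config.yaml") as f:
    cfg = yaml.safe_load(f)

gpu_ids = str(cfg.get("gpu_ids", ""))
if gpu_ids and gpu_ids != "all":
    n_gpus = len([g for g in gpu_ids.split(",") if g])
    if cfg.get("num_processes") != n_gpus:
        raise ValueError(f"num_processes={cfg.get('num_processes')}, but gpu_ids lists {n_gpus} GPUs")
```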
@@ -1,27 +1,9 @@
-import dotmap
-from diffusers import CogVideoXImageToVideoPipeline
-from typing_extensions import override
-
-from finetune.utils import unwrap_model
-
 from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
 from ..utils import register
 
 
 class CogVideoXI2VSftTrainer(CogVideoXI2VLoraTrainer):
-    @override
-    def initialize_pipeline(self) -> CogVideoXImageToVideoPipeline:
-        origin_model = unwrap_model(self.accelerator, self.components.transformer)
-        self.components.transformer.config.update(origin_model.config)
-        self.components.transformer.config = dotmap.DotMap(self.components.transformer.config)
-        pipe = CogVideoXImageToVideoPipeline(
-            tokenizer=self.components.tokenizer,
-            text_encoder=self.components.text_encoder,
-            vae=self.components.vae,
-            transformer=self.components.transformer,
-            scheduler=self.components.scheduler,
-        )
-        return pipe
+    pass
 
 
 register("cogvideox-i2v", "sft", CogVideoXI2VSftTrainer)
@@ -1,27 +1,9 @@
-import dotmap
-from diffusers import CogVideoXPipeline
-from typing_extensions import override
-
-from finetune.utils import unwrap_model
-
 from ..cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer
 from ..utils import register
 
 
 class CogVideoXT2VSftTrainer(CogVideoXT2VLoraTrainer):
-    @override
-    def initialize_pipeline(self) -> CogVideoXPipeline:
-        origin_model = unwrap_model(self.accelerator, self.components.transformer)
-        self.components.transformer.config.update(origin_model.config)
-        self.components.transformer.config = dotmap.DotMap(self.components.transformer.config)
-        pipe = CogVideoXPipeline(
-            tokenizer=self.components.tokenizer,
-            text_encoder=self.components.text_encoder,
-            vae=self.components.vae,
-            transformer=self.components.transformer,
-            scheduler=self.components.scheduler,
-        )
-        return pipe
+    pass
 
 
 register("cogvideox-t2v", "sft", CogVideoXT2VSftTrainer)
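Both SFT trainers are reduced to a bare `pass`, so they now inherit `initialize_pipeline` from their LoRA counterparts instead of rebuilding the pipeline through the removed dotmap workaround. A minimal sketch of the `register` pattern these files end with (illustrative registry, not the repo's actual implementation):

```python
# Sketch: a (model_name, training_type) -> trainer class registry, as implied
# by the register() calls above.
TRAINERS: dict[tuple[str, str], type] = {}

def register(model_name: str, training_type: str, cls: type) -> None:
    TRAINERS[(model_name, training_type)] = cls

class CogVideoXT2VLoraTrainer:
    def initialize_pipeline(self):
        ...

class CogVideoXT2VSftTrainer(CogVideoXT2VLoraTrainer):
    pass  # inherits initialize_pipeline from the LoRA trainer

register("cogvideox-t2v", "sft", CogVideoXT2VSftTrainer)
assert TRAINERS[("cogvideox-t2v", "sft")] is CogVideoXT2VSftTrainer
```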
@@ -8,13 +8,13 @@ from pydantic import BaseModel
 class State(BaseModel):
     model_config = {"arbitrary_types_allowed": True}
 
-    train_frames: int  # user-defined training frames, **containing one image padding frame**
+    train_frames: int
     train_height: int
     train_width: int
 
     transformer_config: Dict[str, Any] = None
 
-    weight_dtype: torch.dtype = torch.float32
+    weight_dtype: torch.dtype = torch.float32  # dtype for mixed precision training
     num_trainable_parameters: int = 0
     overwrote_max_train_steps: bool = False
     num_update_steps_per_epoch: int = 0
@@ -25,3 +25,5 @@ class State(BaseModel):
     validation_prompts: List[str] = []
     validation_images: List[Path | None] = []
     validation_videos: List[Path | None] = []
+
+    using_deepspeed: bool = False
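A compact sketch of the updated `State` model (field subset copied from the hunk above, pydantic v2 syntax assumed), showing that the new flag defaults to off until a trainer detects an active DeepSpeed plugin:

```python
# Sketch: subset of the State model above, enough to see the new flag's default.
import torch
from pydantic import BaseModel

class State(BaseModel):
    model_config = {"arbitrary_types_allowed": True}

    train_frames: int
    train_height: int
    train_width: int
    weight_dtype: torch.dtype = torch.float32  # dtype for mixed precision training
    using_deepspeed: bool = False

state = State(train_frames=81, train_height=768, train_width=1360)
assert state.using_deepspeed is False  # flipped on later if a DeepSpeed plugin is active
```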
@@ -8,31 +8,34 @@ MODEL_ARGS=(
     --model_path "THUDM/CogVideoX1.5-5B-I2V"
     --model_name "cogvideox1.5-i2v"  # ["cogvideox-i2v"]
     --model_type "i2v"
-    --training_type "lora"
+    --training_type "sft"
 )
 
 # Output Configuration
 OUTPUT_ARGS=(
-    --output_dir "/path/to/output/dir"
+    --output_dir "/absolute/path/to/your/output_dir"
     --report_to "tensorboard"
 )
 
 # Data Configuration
 DATA_ARGS=(
-    --data_root "/path/to/data/dir"
+    --data_root "/absolute/path/to/your/data_root"
     --caption_column "prompt.txt"
     --video_column "videos.txt"
-    --image_column "images.txt"
-    --train_resolution "81x768x1360"
+    # --image_column "images.txt"  # commenting this line out uses the first frame of each video as the image conditioning
+    --train_resolution "81x768x1360"  # (frames x height x width), frames should be 8N+1
 )
 
 # Training Configuration
 TRAIN_ARGS=(
     --train_epochs 10
+    --seed 42
+
+    ######### Please keep consistent with the DeepSpeed config file #########
     --batch_size 1
     --gradient_accumulation_steps 1
     --mixed_precision "bf16"  # ["no", "fp16"]
-    --seed 42
+    ##########################################################################
 )
 
 # System Configuration
@@ -44,22 +47,23 @@ SYSTEM_ARGS=(
 
 # Checkpointing Configuration
 CHECKPOINT_ARGS=(
-    --checkpointing_steps 200
+    --checkpointing_steps 5
     --checkpointing_limit 10
+    --resume_from_checkpoint "/absolute/path/to/checkpoint_dir"  # comment this line out unless resuming from a checkpoint
 )
 
 # Validation Configuration
 VALIDATION_ARGS=(
-    --do_validation False
-    --validation_dir "/path/to/validation/dir"
-    --validation_steps 400
+    --do_validation false  # ["true", "false"]
+    --validation_dir "/absolute/path/to/validation_set"
+    --validation_steps 20  # should be a multiple of checkpointing_steps
     --validation_prompts "prompts.txt"
     --validation_images "images.txt"
     --gen_fps 16
 )
 
 # Combine all arguments and launch training
-accelerate launch train.py \
+accelerate launch --config_file accelerate_config.yaml train.py \
     "${MODEL_ARGS[@]}" \
     "${OUTPUT_ARGS[@]}" \
    "${DATA_ARGS[@]}" \
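The `--train_resolution` comment encodes a model constraint: the frame count must have the form 8N+1. A minimal parsing sketch (the helper name is hypothetical, not part of the commit):

```python
# Sketch: parse and validate the "frames x height x width" string used by
# --train_resolution above. The 8N+1 rule comes from the script's comment.
def parse_train_resolution(spec: str) -> tuple[int, int, int]:
    frames, height, width = (int(x) for x in spec.split("x"))
    if frames % 8 != 1:
        raise ValueError(f"frames={frames} must be 8N+1 (e.g. 49, 81)")
    return frames, height, width

assert parse_train_resolution("81x768x1360") == (81, 768, 1360)
```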
@@ -8,30 +8,33 @@ MODEL_ARGS=(
     --model_path "THUDM/CogVideoX1.5-5B"
     --model_name "cogvideox1.5-t2v"  # ["cogvideox-t2v"]
     --model_type "t2v"
-    --training_type "lora"
+    --training_type "sft"
 )
 
 # Output Configuration
 OUTPUT_ARGS=(
-    --output_dir "/path/to/output/dir"
+    --output_dir "/absolute/path/to/your/output_dir"
     --report_to "tensorboard"
 )
 
 # Data Configuration
 DATA_ARGS=(
-    --data_root "/path/to/data/dir"
+    --data_root "/absolute/path/to/your/data_root"
     --caption_column "prompt.txt"
     --video_column "videos.txt"
-    --train_resolution "81x768x1360"
+    --train_resolution "81x768x1360"  # (frames x height x width), frames should be 8N+1
 )
 
 # Training Configuration
 TRAIN_ARGS=(
     --train_epochs 10
+    --seed 42
+
+    ######### Please keep consistent with the DeepSpeed config file #########
     --batch_size 1
     --gradient_accumulation_steps 1
     --mixed_precision "bf16"  # ["no", "fp16"]
-    --seed 42
+    ##########################################################################
 )
 
 # System Configuration
@@ -43,21 +46,22 @@ SYSTEM_ARGS=(
 
 # Checkpointing Configuration
 CHECKPOINT_ARGS=(
-    --checkpointing_steps 200
+    --checkpointing_steps 5
     --checkpointing_limit 10
+    --resume_from_checkpoint "/absolute/path/to/checkpoint_dir"  # comment this line out unless resuming from a checkpoint
 )
 
 # Validation Configuration
 VALIDATION_ARGS=(
-    --do_validation False
-    --validation_dir "/path/to/validation/dir"
-    --validation_steps 400
+    --do_validation false  # ["true", "false"]
+    --validation_dir "/absolute/path/to/validation_set"
+    --validation_steps 20  # should be a multiple of checkpointing_steps
     --validation_prompts "prompts.txt"
     --gen_fps 16
 )
 
 # Combine all arguments and launch training
-accelerate launch train.py \
+accelerate launch --config_file accelerate_config.yaml train.py \
     "${MODEL_ARGS[@]}" \
     "${OUTPUT_ARGS[@]}" \
    "${DATA_ARGS[@]}" \
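The script comments also impose a schedule invariant (`validation_steps` divisible by `checkpointing_steps`) and switch the boolean convention from `False` to `false`. A launcher-side sketch of both checks (function names hypothetical):

```python
# Sketch: checks implied by the script comments; names are hypothetical.
def parse_bool(flag: str) -> bool:
    if flag not in ("true", "false"):
        raise ValueError('expected "true" or "false"')
    return flag == "true"

def check_schedule(checkpointing_steps: int, validation_steps: int) -> None:
    if validation_steps % checkpointing_steps != 0:
        raise ValueError("validation_steps should be a multiple of checkpointing_steps")

assert parse_bool("false") is False                          # --do_validation false
check_schedule(checkpointing_steps=5, validation_steps=20)   # values used above
```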
@@ -84,6 +84,8 @@ class Trainer:
         self._init_logging()
         self._init_directories()
 
+        self.state.using_deepspeed = self.accelerator.state.deepspeed_plugin is not None
+
     def _init_distributed(self):
         logging_dir = Path(self.args.output_dir, "logs")
         project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)
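The flag is derived once from Accelerate's global state during trainer initialization. A standalone sketch of the same detection using the Accelerate API:

```python
# Sketch: detect an active DeepSpeed plugin via Hugging Face Accelerate,
# mirroring the assignment in the hunk above.
from accelerate import Accelerator

accelerator = Accelerator()
using_deepspeed = accelerator.state.deepspeed_plugin is not None
if using_deepspeed:
    # zero_stage distinguishes ZeRO-2 from ZeRO-3 later during validation
    print("ZeRO stage:", accelerator.state.deepspeed_plugin.zero_stage)
```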
@@ -246,9 +248,9 @@ class Trainer:
             self.components.transformer.add_adapter(transformer_lora_config)
             self.__prepare_saving_loading_hooks(transformer_lora_config)
 
-        # Load components needed for training to GPU (except transformer),
-        # and cast them to the specified data type
-        self.__move_components_to_device(dtype=weight_dtype)
+        # Load components needed for training to GPU (except transformer), and cast them to the specified data type
+        ignore_list = ["transformer"] + self.UNLOAD_LIST
+        self.__move_components_to_device(dtype=weight_dtype, ignore_list=ignore_list)
 
         if self.args.gradient_checkpointing:
             self.components.transformer.enable_gradient_checkpointing()
@@ -406,6 +408,7 @@ class Trainer:
             generator = generator.manual_seed(self.args.seed)
         self.state.generator = generator
 
+        free_memory()
         for epoch in range(first_epoch, self.args.train_epochs):
             logger.debug(f"Starting epoch ({epoch + 1}/{self.args.train_epochs})")
 
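`free_memory()` is now called right before the epoch loop to start training from a lower memory baseline. A plausible sketch of such a helper, assuming the usual gc-plus-CUDA-cache shape; the repo's actual implementation may differ:

```python
# Sketch: a typical free_memory() helper; assumed shape, not the repo's code.
import gc
import torch

def free_memory() -> None:
    gc.collect()                     # drop unreachable Python objects
    if torch.cuda.is_available():
        torch.cuda.empty_cache()     # return cached blocks to the allocator pool
        torch.cuda.ipc_collect()     # release CUDA IPC handles
```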
@@ -497,6 +500,13 @@ class Trainer:
         ##### Initialize pipeline #####
         pipe = self.initialize_pipeline()
 
+        if self.state.using_deepspeed:
+            # Can't use model_cpu_offload with DeepSpeed,
+            # so we move all pipeline components to the device instead
+            # pipe.to(self.accelerator.device, dtype=self.state.weight_dtype)
+            self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=["transformer"])
+        else:
+            # If not using DeepSpeed, use model_cpu_offload to further reduce memory usage
             # Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage
             pipe.enable_model_cpu_offload(device=self.accelerator.device)
 
|
|||||||
|
|
||||||
all_processes_artifacts = []
|
all_processes_artifacts = []
|
||||||
for i in range(num_validation_samples):
|
for i in range(num_validation_samples):
|
||||||
if self.accelerator.deepspeed_plugin and self.accelerator.deepspeed_plugin.zero_stage != 3:
|
if self.state.using_deepspeed and self.accelerator.deepspeed_plugin.zero_stage != 3:
|
||||||
# Skip current validation on all processes but one
|
# Skip current validation on all processes but one
|
||||||
if i % accelerator.num_processes != accelerator.process_index:
|
if i % accelerator.num_processes != accelerator.process_index:
|
||||||
continue
|
continue
|
||||||
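Under ZeRO stages below 3 every rank holds a full copy of the parameters, so validation prompts can be round-robined across processes; under ZeRO-3 the parameters are sharded across all ranks, so every rank must participate in each forward pass, which is why stage 3 skips this gating and discards non-main artifacts afterwards (next hunk). A toy sketch of the round-robin assignment:

```python
# Sketch: round-robin sample assignment used by the non-ZeRO-3 path above.
num_processes, process_index = 4, 1  # illustrative rank layout
assigned = [i for i in range(10) if i % num_processes == process_index]
print(assigned)  # rank 1 validates samples [1, 5, 9]
```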
@@ -539,7 +549,7 @@ class Trainer:
             validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)
 
             if (
-                self.accelerator.deepspeed_plugin is not None
+                self.state.using_deepspeed
                 and self.accelerator.deepspeed_plugin.zero_stage == 3
                 and not accelerator.is_main_process
             ):
@@ -599,22 +609,25 @@ class Trainer:
                     step=step,
                 )
 
-        pipe.remove_all_hooks()
-        del pipe
-        # Unload models except those needed for training
-        self.__move_components_to_cpu()
-
-        # Load models except those not needed for training
-        self.__move_components_to_device(dtype=self.state.weight_dtype)
-        self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)
-
-        if self.accelerator.state.deepspeed_plugin is None:
-            # Change trainable weights back to fp32
-            cast_training_params([self.components.transformer], dtype=torch.float32)
-
-        accelerator.wait_for_everyone()
-
-        free_memory()
+        ########## Clean up ##########
+        if self.state.using_deepspeed:
+            del pipe
+            # Unload models except those needed for training
+            self.__move_components_to_cpu(unload_list=self.UNLOAD_LIST)
+        else:
+            pipe.remove_all_hooks()
+            del pipe
+            # Load models except those not needed for training
+            self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST)
+            self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)
+
+            # Change trainable weights back to fp32 to keep the dtype consistent with the prepared model
+            cast_training_params([self.components.transformer], dtype=torch.float32)
+
+        free_memory()
+        accelerator.wait_for_everyone()
+        ################################
 
         memory_statistics = get_memory_statistics()
         logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
         torch.cuda.reset_peak_memory_stats(accelerator.device)
@@ -668,25 +681,20 @@ class Trainer:
         else:
             raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")
 
-    def __move_components_to_device(self, dtype):
+    def __move_components_to_device(self, dtype, ignore_list: List[str] = []):
+        ignore_list = set(ignore_list)
         components = self.components.model_dump()
         for name, component in components.items():
             if not isinstance(component, type) and hasattr(component, "to"):
-                if name in self.UNLOAD_LIST:
-                    continue
-
-                # We don't need to move transformer to device
-                # because we will prepare it in the `prepare_for_training()`
-                if name == "transformer":
-                    continue
-
-                setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))
+                if name not in ignore_list:
+                    setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))
 
-    def __move_components_to_cpu(self):
+    def __move_components_to_cpu(self, unload_list: List[str] = []):
+        unload_list = set(unload_list)
         components = self.components.model_dump()
         for name, component in components.items():
             if not isinstance(component, type) and hasattr(component, "to"):
-                if name in self.UNLOAD_LIST:
-                    setattr(self.components, name, component.to("cpu"))
+                if name in unload_list:
+                    setattr(self.components, name, component.to("cpu"))
 
     def __prepare_saving_loading_hooks(self, transformer_lora_config):
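The two helpers now take explicit name lists instead of hardcoding `self.UNLOAD_LIST` and the transformer special case, which is what lets training and validation request different placements. A standalone analog using a plain dict in place of the repo's components object:

```python
# Sketch: ignore_list/unload_list semantics of the two helpers above,
# using a plain dict instead of the repo's components container.
import torch

def move_to_device(components: dict, device: str, dtype: torch.dtype, ignore_list=()) -> None:
    ignore = set(ignore_list)
    for name, module in components.items():
        if hasattr(module, "to") and name not in ignore:
            components[name] = module.to(device, dtype=dtype)

components = {"vae": torch.nn.Linear(4, 4), "transformer": torch.nn.Linear(4, 4)}
move_to_device(components, "cpu", torch.float32, ignore_list=["transformer"])  # transformer untouched
```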