diff --git a/finetune/accelerate_config.yaml b/finetune/accelerate_config.yaml
new file mode 100644
index 0000000..ee3aed5
--- /dev/null
+++ b/finetune/accelerate_config.yaml
@@ -0,0 +1,27 @@
+compute_environment: LOCAL_MACHINE
+
+# gpu_ids: "0"  # 0,1,2,3,4,5,6,7
+# num_processes: 1
+
+gpu_ids: all
+num_processes: 8
+
+debug: false
+deepspeed_config:
+  deepspeed_config_file: /path/to/your/configs/zero2.yaml
+  # deepspeed_config_file: /path/to/your/configs/zero2_offload.yaml
+  # deepspeed_config_file: /path/to/your/configs/zero3.yaml
+  # deepspeed_config_file: /path/to/your/configs/zero3_offload.yaml
+  zero3_init_flag: false
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+machine_rank: 0
+main_training_function: main
+num_machines: 1
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/finetune/configs/zero2.yaml b/finetune/configs/zero2.yaml
new file mode 100644
index 0000000..96afa13
--- /dev/null
+++ b/finetune/configs/zero2.yaml
@@ -0,0 +1,38 @@
+{
+    "bf16": {
+        "enabled": true
+    },
+    "optimizer": {
+        "type": "AdamW",
+        "params": {
+            "lr": "auto",
+            "weight_decay": "auto",
+            "torch_adam": true,
+            "adam_w_mode": true
+        }
+    },
+    "scheduler": {
+        "type": "WarmupDecayLR",
+        "params": {
+            "warmup_min_lr": "auto",
+            "warmup_max_lr": "auto",
+            "warmup_num_steps": "auto",
+            "total_num_steps": "auto"
+        }
+    },
+    "zero_optimization": {
+        "stage": 2,
+        "allgather_partitions": true,
+        "allgather_bucket_size": 2e8,
+        "overlap_comm": true,
+        "reduce_scatter": true,
+        "reduce_bucket_size": 5e8,
+        "contiguous_gradients": true
+    },
+    "gradient_accumulation_steps": 1,
+    "train_micro_batch_size_per_gpu": 1,
+    "train_batch_size": "auto",
+    "gradient_clipping": "auto",
+    "steps_per_print": 2000,
+    "wall_clock_breakdown": false
+}
\ No newline at end of file
diff --git a/finetune/configs/zero2_offload.yaml b/finetune/configs/zero2_offload.yaml
new file mode 100644
index 0000000..b542665
--- /dev/null
+++ b/finetune/configs/zero2_offload.yaml
@@ -0,0 +1,42 @@
+{
+    "bf16": {
+        "enabled": true
+    },
+    "optimizer": {
+        "type": "AdamW",
+        "params": {
+            "lr": "auto",
+            "weight_decay": "auto",
+            "torch_adam": true,
+            "adam_w_mode": true
+        }
+    },
+    "scheduler": {
+        "type": "WarmupDecayLR",
+        "params": {
+            "warmup_min_lr": "auto",
+            "warmup_max_lr": "auto",
+            "warmup_num_steps": "auto",
+            "total_num_steps": "auto"
+        }
+    },
+    "zero_optimization": {
+        "stage": 2,
+        "allgather_partitions": true,
+        "allgather_bucket_size": 2e8,
+        "overlap_comm": true,
+        "reduce_scatter": true,
+        "reduce_bucket_size": 5e8,
+        "contiguous_gradients": true,
+        "offload_optimizer": {
+            "device": "cpu",
+            "pin_memory": true
+        }
+    },
+    "gradient_accumulation_steps": 1,
+    "train_micro_batch_size_per_gpu": 1,
+    "train_batch_size": "auto",
+    "gradient_clipping": "auto",
+    "steps_per_print": 2000,
+    "wall_clock_breakdown": false
+}
\ No newline at end of file
diff --git a/finetune/configs/zero3.yaml b/finetune/configs/zero3.yaml
new file mode 100644
index 0000000..8f73fe8
--- /dev/null
+++ b/finetune/configs/zero3.yaml
@@ -0,0 +1,41 @@
+{
+    "bf16": {
+        "enabled": true
+    },
+    "optimizer": {
+        "type": "AdamW",
+        "params": {
+            "lr": "auto",
+            "weight_decay": "auto",
+            "torch_adam": true,
+            "adam_w_mode": true
+        }
+    },
+    "scheduler": {
+        "type": "WarmupDecayLR",
+        "params": {
+            "warmup_min_lr": "auto",
+            "warmup_max_lr": "auto",
+            "warmup_num_steps": "auto",
+            "total_num_steps": "auto"
+        }
+    },
+    "zero_optimization": {
+        "stage": 3,
+        "overlap_comm": true,
+        "contiguous_gradients": true,
+        "reduce_bucket_size": 5e8,
+        "sub_group_size": 1e9,
+        "stage3_max_live_parameters": 1e9,
+        "stage3_max_reuse_distance": 1e9,
+        "stage3_gather_16bit_weights_on_model_save": "auto",
+        "stage3_prefetch_bucket_size": 5e8,
+        "stage3_param_persistence_threshold": 1e5
+    },
+    "gradient_accumulation_steps": 1,
+    "train_micro_batch_size_per_gpu": 1,
+    "train_batch_size": "auto",
+    "gradient_clipping": "auto",
+    "steps_per_print": 2000,
+    "wall_clock_breakdown": false
+}
\ No newline at end of file
diff --git a/finetune/configs/zero3_offload.yaml b/finetune/configs/zero3_offload.yaml
new file mode 100644
index 0000000..9a2c502
--- /dev/null
+++ b/finetune/configs/zero3_offload.yaml
@@ -0,0 +1,49 @@
+{
+    "bf16": {
+        "enabled": true
+    },
+    "optimizer": {
+        "type": "AdamW",
+        "params": {
+            "lr": "auto",
+            "weight_decay": "auto",
+            "torch_adam": true,
+            "adam_w_mode": true
+        }
+    },
+    "scheduler": {
+        "type": "WarmupDecayLR",
+        "params": {
+            "warmup_min_lr": "auto",
+            "warmup_max_lr": "auto",
+            "warmup_num_steps": "auto",
+            "total_num_steps": "auto"
+        }
+    },
+    "zero_optimization": {
+        "stage": 3,
+        "offload_optimizer": {
+            "device": "cpu",
+            "pin_memory": true
+        },
+        "offload_param": {
+            "device": "cpu",
+            "pin_memory": true
+        },
+        "overlap_comm": true,
+        "contiguous_gradients": true,
+        "reduce_bucket_size": 5e8,
+        "sub_group_size": 1e9,
+        "stage3_max_live_parameters": 1e9,
+        "stage3_max_reuse_distance": 1e9,
+        "stage3_gather_16bit_weights_on_model_save": "auto",
+        "stage3_prefetch_bucket_size": 5e8,
+        "stage3_param_persistence_threshold": 1e6
+    },
+    "gradient_accumulation_steps": 1,
+    "train_micro_batch_size_per_gpu": 1,
+    "train_batch_size": "auto",
+    "gradient_clipping": "auto",
+    "steps_per_print": 2000,
+    "wall_clock_breakdown": false
+}
\ No newline at end of file
diff --git a/finetune/models/cogvideox1dot5_i2v/sft_trainer.py b/finetune/models/cogvideox1dot5_i2v/sft_trainer.py
new file mode 100644
index 0000000..5023178
--- /dev/null
+++ b/finetune/models/cogvideox1dot5_i2v/sft_trainer.py
@@ -0,0 +1,9 @@
+from ..cogvideox_i2v.sft_trainer import CogVideoXI2VSftTrainer
+from ..utils import register
+
+
+class CogVideoX1dot5I2VSftTrainer(CogVideoXI2VSftTrainer):
+    pass
+
+
+register("cogvideox1.5-i2v", "sft", CogVideoX1dot5I2VSftTrainer)
diff --git a/finetune/models/cogvideox1dot5_t2v/sft_trainer.py b/finetune/models/cogvideox1dot5_t2v/sft_trainer.py
new file mode 100644
index 0000000..16294cb
--- /dev/null
+++ b/finetune/models/cogvideox1dot5_t2v/sft_trainer.py
@@ -0,0 +1,9 @@
+from ..cogvideox_t2v.sft_trainer import CogVideoXT2VSftTrainer
+from ..utils import register
+
+
+class CogVideoX1dot5T2VSftTrainer(CogVideoXT2VSftTrainer):
+    pass
+
+
+register("cogvideox1.5-t2v", "sft", CogVideoX1dot5T2VSftTrainer)
diff --git a/finetune/models/cogvideox_i2v/sft_trainer.py b/finetune/models/cogvideox_i2v/sft_trainer.py
new file mode 100644
index 0000000..0cb5f12
--- /dev/null
+++ b/finetune/models/cogvideox_i2v/sft_trainer.py
@@ -0,0 +1,27 @@
+import dotmap
+from diffusers import CogVideoXImageToVideoPipeline
+from typing_extensions import override
+
+from finetune.utils import unwrap_model
+
+from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
+from ..utils import register
+
+
+class CogVideoXI2VSftTrainer(CogVideoXI2VLoraTrainer):
+    @override
+    def initialize_pipeline(self) -> CogVideoXImageToVideoPipeline:
+        origin_model = unwrap_model(self.accelerator, self.components.transformer)
+        self.components.transformer.config.update(origin_model.config)
+        self.components.transformer.config = dotmap.DotMap(self.components.transformer.config)
+        pipe = CogVideoXImageToVideoPipeline(
+            tokenizer=self.components.tokenizer,
+            text_encoder=self.components.text_encoder,
+            vae=self.components.vae,
+            transformer=self.components.transformer,
+            scheduler=self.components.scheduler,
+        )
+        return pipe
+
+
+register("cogvideox-i2v", "sft", CogVideoXI2VSftTrainer)
diff --git a/finetune/models/cogvideox_t2v/sft_trainer.py b/finetune/models/cogvideox_t2v/sft_trainer.py
new file mode 100644
index 0000000..85b80fa
--- /dev/null
+++ b/finetune/models/cogvideox_t2v/sft_trainer.py
@@ -0,0 +1,27 @@
+import dotmap
+from diffusers import CogVideoXPipeline
+from typing_extensions import override
+
+from finetune.utils import unwrap_model
+
+from ..cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer
+from ..utils import register
+
+
+class CogVideoXT2VSftTrainer(CogVideoXT2VLoraTrainer):
+    @override
+    def initialize_pipeline(self) -> CogVideoXPipeline:
+        origin_model = unwrap_model(self.accelerator, self.components.transformer)
+        self.components.transformer.config.update(origin_model.config)
+        self.components.transformer.config = dotmap.DotMap(self.components.transformer.config)
+        pipe = CogVideoXPipeline(
+            tokenizer=self.components.tokenizer,
+            text_encoder=self.components.text_encoder,
+            vae=self.components.vae,
+            transformer=self.components.transformer,
+            scheduler=self.components.scheduler,
+        )
+        return pipe
+
+
+register("cogvideox-t2v", "sft", CogVideoXT2VSftTrainer)
diff --git a/finetune/models/utils.py b/finetune/models/utils.py
index dcc963f..ef3ea5b 100644
--- a/finetune/models/utils.py
+++ b/finetune/models/utils.py
@@ -15,17 +15,12 @@ def register(model_name: str, training_type: Literal["lora", "sft"], trainer_cls
         trainer_cls (Trainer): Trainer class to register.
     """
 
-    # Check if model_name exists in SUPPORTED_MODELS
+    # Check if model_name and training_type exist in SUPPORTED_MODELS
     if model_name not in SUPPORTED_MODELS:
         SUPPORTED_MODELS[model_name] = {}
     else:
-        raise ValueError(f"Model {model_name} already exists")
-
-    # Check if training_type exists for this model
-    if training_type not in SUPPORTED_MODELS[model_name]:
-        SUPPORTED_MODELS[model_name][training_type] = {}
-    else:
-        raise ValueError(f"Training type {training_type} already exists for model {model_name}")
+        if training_type in SUPPORTED_MODELS[model_name]:
+            raise ValueError(f"Training type {training_type} already exists for model {model_name}")
 
     SUPPORTED_MODELS[model_name][training_type] = trainer_cls

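Aside (not part of the patch): the `register(...)` calls in the new sft_trainer.py files depend on the relaxed duplicate check introduced in `finetune/models/utils.py` above, which now allows several training types per model name. A minimal, simplified sketch of that registry behavior, with placeholder trainer classes standing in for the real ones:

```python
# Simplified sketch (not the repo's actual module) of the registry semantics
# after the utils.py change: one model name may hold several training types,
# but re-registering the same (model_name, training_type) pair still raises.
from typing import Dict, Literal

SUPPORTED_MODELS: Dict[str, Dict[str, type]] = {}


def register(model_name: str, training_type: Literal["lora", "sft"], trainer_cls: type) -> None:
    if model_name not in SUPPORTED_MODELS:
        SUPPORTED_MODELS[model_name] = {}
    elif training_type in SUPPORTED_MODELS[model_name]:
        raise ValueError(f"Training type {training_type} already exists for model {model_name}")
    SUPPORTED_MODELS[model_name][training_type] = trainer_cls


class DummyLoraTrainer: ...


class DummySftTrainer: ...


register("cogvideox-i2v", "lora", DummyLoraTrainer)
register("cogvideox-i2v", "sft", DummySftTrainer)  # now allowed: a second training type for the same model
# register("cogvideox-i2v", "sft", DummySftTrainer)  # would raise: duplicate (model, training_type) pair
```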
""" - # Check if model_name exists in SUPPORTED_MODELS + # Check if model_name and training_type exists in SUPPORTED_MODELS if model_name not in SUPPORTED_MODELS: SUPPORTED_MODELS[model_name] = {} else: - raise ValueError(f"Model {model_name} already exists") - - # Check if training_type exists for this model - if training_type not in SUPPORTED_MODELS[model_name]: - SUPPORTED_MODELS[model_name][training_type] = {} - else: - raise ValueError(f"Training type {training_type} already exists for model {model_name}") + if training_type in SUPPORTED_MODELS[model_name]: + raise ValueError(f"Training type {training_type} already exists for model {model_name}") SUPPORTED_MODELS[model_name][training_type] = trainer_cls diff --git a/finetune/trainer.py b/finetune/trainer.py index 53fc193..e3ec8e6 100644 --- a/finetune/trainer.py +++ b/finetune/trainer.py @@ -1,3 +1,4 @@ +import hashlib import json import logging import math @@ -71,7 +72,7 @@ class Trainer: train_width=self.args.train_resolution[2], ) - self.components = Components() + self.components: Components = self.load_components() self.accelerator: Accelerator = None self.dataset: Dataset = None self.data_loader: DataLoader = None @@ -145,9 +146,6 @@ class Trainer: def prepare_models(self) -> None: logger.info("Initializing models") - # Initialize model components - self.components = self.load_components() - if self.components.vae is not None: if self.args.enable_slicing: self.components.vae.enable_slicing() @@ -159,15 +157,11 @@ class Trainer: def prepare_dataset(self) -> None: logger.info("Initializing dataset and dataloader") - # self.state.train_frames includes one padding frame for image conditioning - # so we only sample train_frames - 1 frames from the actual video - sample_frames = self.state.train_frames - 1 - if self.args.model_type == "i2v": self.dataset = I2VDatasetWithResize( **(self.args.model_dump()), device=self.accelerator.device, - max_num_frames=sample_frames, + max_num_frames=self.state.train_frames, height=self.state.train_height, width=self.state.train_width, trainer=self, @@ -176,7 +170,7 @@ class Trainer: self.dataset = T2VDatasetWithResize( **(self.args.model_dump()), device=self.accelerator.device, - max_num_frames=sample_frames, + max_num_frames=self.state.train_frames, height=self.state.train_height, width=self.state.train_width, trainer=self, @@ -223,12 +217,7 @@ class Trainer: def prepare_trainable_parameters(self): logger.info("Initializing trainable parameters") - # For now only lora is supported - for attr_name, component in vars(self.components).items(): - if hasattr(component, "requires_grad_"): - component.requires_grad_(False) - - # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # For mixed precision training we cast all non-trainable weights to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. weight_dtype = self.state.weight_dtype @@ -238,35 +227,47 @@ class Trainer: "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." 
@@ -238,35 +227,47 @@
                 "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
             )
 
-        self.__load_components()
+        # For LoRA, we freeze all the parameters
+        # For SFT, we train all the parameters in transformer model
+        for attr_name, component in vars(self.components).items():
+            if hasattr(component, "requires_grad_"):
+                if self.args.training_type == "sft" and attr_name == "transformer":
+                    component.requires_grad_(True)
+                else:
+                    component.requires_grad_(False)
+
+        if self.args.training_type == "lora":
+            transformer_lora_config = LoraConfig(
+                r=self.args.rank,
+                lora_alpha=self.args.lora_alpha,
+                init_lora_weights=True,
+                target_modules=self.args.target_modules,
+            )
+            self.components.transformer.add_adapter(transformer_lora_config)
+            self.__prepare_saving_loading_hooks(transformer_lora_config)
+
+        # Load components needed for training to GPU (except transformer),
+        # and cast them to the specified data type
+        self.__move_components_to_device(dtype=weight_dtype)
 
         if self.args.gradient_checkpointing:
             self.components.transformer.enable_gradient_checkpointing()
 
-        transformer_lora_config = LoraConfig(
-            r=self.args.rank,
-            lora_alpha=self.args.lora_alpha,
-            init_lora_weights=True,
-            target_modules=self.args.target_modules,
-        )
-        self.components.transformer.add_adapter(transformer_lora_config)
-        self.__prepare_saving_loading_hooks(transformer_lora_config)
-
     def prepare_optimizer(self) -> None:
         logger.info("Initializing optimizer and lr scheduler")
 
         # Make sure the trainable params are in float32
-        if self.args.mixed_precision != "no":
-            # only upcast trainable parameters (LoRA) into fp32
-            cast_training_params([self.components.transformer], dtype=torch.float32)
+        cast_training_params([self.components.transformer], dtype=torch.float32)
 
-        transformer_lora_parameters = list(filter(lambda p: p.requires_grad, self.components.transformer.parameters()))
+        # For LoRA, we only want to train the LoRA weights
+        # For SFT, we want to train all the parameters
+        trainable_parameters = list(filter(lambda p: p.requires_grad, self.components.transformer.parameters()))
         transformer_parameters_with_lr = {
-            "params": transformer_lora_parameters,
+            "params": trainable_parameters,
             "lr": self.args.learning_rate,
         }
         params_to_optimize = [transformer_parameters_with_lr]
-        self.state.num_trainable_parameters = sum(p.numel() for p in transformer_lora_parameters)
+        self.state.num_trainable_parameters = sum(p.numel() for p in trainable_parameters)
 
         use_deepspeed_opt = (
             self.accelerator.state.deepspeed_plugin is not None
@@ -502,13 +503,15 @@
         # Convert all model weights to training dtype
         # Note, this will change LoRA weights in self.components.transformer to training dtype, rather than keep them in fp32
         pipe = pipe.to(dtype=self.state.weight_dtype)
+        #################################
 
         all_processes_artifacts = []
         for i in range(num_validation_samples):
-            # Skip current validation on all processes but one
-            if i % accelerator.num_processes != accelerator.process_index:
-                continue
+            if self.accelerator.deepspeed_plugin and self.accelerator.deepspeed_plugin.zero_stage != 3:
+                # Skip current validation on all processes but one
+                if i % accelerator.num_processes != accelerator.process_index:
+                    continue
 
             prompt = self.state.validation_prompts[i]
             image = self.state.validation_images[i]
@@ -534,7 +537,19 @@
                 main_process_only=False,
             )
             validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)
+
+            if (
+                self.accelerator.deepspeed_plugin is not None
+                and self.accelerator.deepspeed_plugin.zero_stage == 3
+                and not accelerator.is_main_process
+            ):
+                continue
+
             prompt_filename = string_to_filename(prompt)[:25]
+            # Calculate hash of reversed prompt as a unique identifier
+            reversed_prompt = prompt[::-1]
+            hash_suffix = hashlib.md5(reversed_prompt.encode()).hexdigest()[:5]
+
             artifacts = {
                 "image": {"type": "image", "value": image},
                 "video": {"type": "video", "value": video},
@@ -553,7 +568,7 @@
                     continue
 
                 extension = "png" if artifact_type == "image" else "mp4"
-                filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}.{extension}"
+                filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}.{extension}"
                 validation_path = self.args.output_dir / "validation_res"
                 validation_path.mkdir(parents=True, exist_ok=True)
                 filename = str(validation_path / filename)
@@ -587,11 +602,15 @@
         pipe.remove_all_hooks()
         del pipe
         # Unload models except those needed for training
-        self.__unload_components()
+        self.__move_components_to_cpu()
+
         # Load models except those not needed for training
-        self.__load_components()
-        # Change LoRA weights back to fp32
-        cast_training_params([self.components.transformer], dtype=torch.float32)
+        self.__move_components_to_device(dtype=self.state.weight_dtype)
+        self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)
+
+        if self.accelerator.state.deepspeed_plugin is None:
+            # Change trainable weights back to fp32
+            cast_training_params([self.components.transformer], dtype=torch.float32)
 
         accelerator.wait_for_everyone()
 
@@ -649,16 +668,21 @@
         else:
             raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")
 
-    def __load_components(self):
+    def __move_components_to_device(self, dtype):
         components = self.components.model_dump()
         for name, component in components.items():
            if not isinstance(component, type) and hasattr(component, "to"):
                 if name in self.UNLOAD_LIST:
                     continue
-                # setattr(self.components, name, component.to(self.accelerator.device))
-                setattr(self.components, name, component.to(self.accelerator.device, dtype=self.state.weight_dtype))
+                # We don't need to move transformer to device
+                # because we will prepare it in the `prepare_for_training()`
+                if name == "transformer":
+                    continue
+
+                setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))
+
+    def __move_components_to_cpu(self):
 
-    def __unload_components(self):
         components = self.components.model_dump()
         for name, component in components.items():
             if not isinstance(component, type) and hasattr(component, "to"):
@@ -723,13 +747,6 @@
                     f" {unexpected_keys}. "
                 )
 
-            # Make sure the trainable params are in float32. This is again needed since the base models
-            # are in `weight_dtype`. More details:
-            # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
-            if self.args.mixed_precision == "fp16":
-                # only upcast trainable parameters (LoRA) into fp32
-                cast_training_params([transformer_])
-
         self.accelerator.register_save_state_pre_hook(save_model_hook)
         self.accelerator.register_load_state_pre_hook(load_model_hook)
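Aside (not part of the patch): the validation filename change in `finetune/trainer.py` appends an md5-of-the-reversed-prompt suffix so that artifacts from prompts sharing the same first 25 sanitized characters no longer collide. A small illustrative sketch, where `sanitize` is a hypothetical stand-in for the repo's `string_to_filename` helper (its implementation is not part of this diff):

```python
# Illustrative sketch of the filename scheme added in the last trainer.py hunks.
# `sanitize` is a hypothetical stand-in for the repo's string_to_filename helper.
import hashlib
import re


def sanitize(prompt: str) -> str:
    # Keep only filesystem-friendly characters (stand-in, not the repo's logic).
    return re.sub(r"[^a-zA-Z0-9_\-]+", "-", prompt)


def validation_filename(step: int, process_index: int, prompt: str, extension: str) -> str:
    prompt_filename = sanitize(prompt)[:25]
    # Hash of the reversed prompt, as in the diff: distinguishes prompts whose
    # first 25 sanitized characters are identical.
    hash_suffix = hashlib.md5(prompt[::-1].encode()).hexdigest()[:5]
    return f"validation-{step}-{process_index}-{prompt_filename}-{hash_suffix}.{extension}"


# Same 25-character prefix, different suffixes, so neither file overwrites the other.
print(validation_filename(500, 0, "a cat playing piano in a sunlit room", "mp4"))
print(validation_filename(500, 0, "a cat playing piano in a rainy alley", "mp4"))
```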