From 60f6a3d7eed6c8d29abce8795d3a85f19941bf72 Mon Sep 17 00:00:00 2001
From: OleehyO
Date: Sun, 29 Dec 2024 15:27:43 +0000
Subject: [PATCH] feat: add base trainer implementation and training script

- Add Trainer base class with core training loop functionality
- Implement distributed training setup with Accelerate
- Add training script with model/trainer initialization
- Support LoRA fine-tuning with checkpointing and validation
---
 finetune/train.py   |  18 ++
 finetune/trainer.py | 512 ++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 530 insertions(+)
 create mode 100644 finetune/train.py
 create mode 100644 finetune/trainer.py

diff --git a/finetune/train.py b/finetune/train.py
new file mode 100644
index 0000000..5f49f4b
--- /dev/null
+++ b/finetune/train.py
@@ -0,0 +1,18 @@
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).parent.parent))
+
+from finetune.schemas import Args
+from finetune.models.utils import get_model_cls
+
+
+def main():
+    args = Args.parse_args()
+    trainer_cls = get_model_cls(args.model_name, args.training_type)
+    trainer = trainer_cls(args)
+    trainer.fit()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/finetune/trainer.py b/finetune/trainer.py
new file mode 100644
index 0000000..cc9882c
--- /dev/null
+++ b/finetune/trainer.py
@@ -0,0 +1,512 @@
+import torch
+import logging
+import transformers
+import diffusers
+import math
+import json
+import multiprocessing
+
+from datetime import timedelta
+from pathlib import Path
+from tqdm import tqdm
+from typing import Dict, Any, List
+
+from torch.utils.data import Dataset, DataLoader
+from accelerate.logging import get_logger
+from accelerate.accelerator import Accelerator, DistributedType
+from accelerate.utils import (
+    DistributedDataParallelKwargs,
+    InitProcessGroupKwargs,
+    ProjectConfiguration,
+    set_seed,
+    gather_object,
+)
+
+from diffusers.optimization import get_scheduler
+from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
+
+from finetune.schemas import Args, State, Components
+from finetune.utils import (
+    unwrap_model,
+    cast_training_params,
+    get_optimizer,
+    get_memory_statistics,
+    free_memory,
+    get_latest_ckpt_path_to_resume_from,
+    get_intermediate_ckpt_path,
+)
+from finetune.datasets import I2VDatasetWithBuckets, T2VDatasetWithBuckets, BucketSampler
+
+from finetune.constants import LOG_NAME, LOG_LEVEL
+
+
+logger = get_logger(LOG_NAME, LOG_LEVEL)
+
+_DTYPE_MAP = {
+    "fp32": torch.float32,
+    "fp16": torch.float16,
+    "bf16": torch.bfloat16,
+}
+
+
+class Trainer:
+
+    def __init__(self, args: Args) -> None:
+        self.args = args
+        self.state = State(weight_dtype=self.__get_training_dtype())
+
+        self.components = Components()
+        self.accelerator: Accelerator = None
+        self.dataset: Dataset = None
+        self.data_loader: DataLoader = None
+
+        self.optimizer = None
+        self.lr_scheduler = None
+
+        self._init_distributed()
+        self._init_logging()
+        self._init_directories()
+
+    def _init_distributed(self):
+        logging_dir = Path(self.args.output_dir, "logs")
+        project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)
+        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+        init_process_group_kwargs = InitProcessGroupKwargs(
+            backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout)
+        )
+        mixed_precision = "no" if torch.backends.mps.is_available() else self.args.mixed_precision
+        report_to = None if self.args.report_to.lower() == "none" else self.args.report_to
+
+        accelerator = Accelerator(
+            project_config=project_config,
+            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
+            mixed_precision=mixed_precision,
+            log_with=report_to,
+            kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
+        )
+
+        # Disable AMP for MPS.
+        if torch.backends.mps.is_available():
+            accelerator.native_amp = False
+
+        self.accelerator = accelerator
+
+        if self.args.seed is not None:
+            set_seed(self.args.seed)
+
+    def _init_logging(self) -> None:
+        logging.basicConfig(
+            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+            datefmt="%m/%d/%Y %H:%M:%S",
+            level=LOG_LEVEL,
+        )
+        if self.accelerator.is_local_main_process:
+            transformers.utils.logging.set_verbosity_warning()
+            diffusers.utils.logging.set_verbosity_info()
+        else:
+            transformers.utils.logging.set_verbosity_error()
+            diffusers.utils.logging.set_verbosity_error()
+
+        logger.info("Initialized Trainer")
+        logger.info(f"Accelerator state: \n{self.accelerator.state}", main_process_only=False)
+
+    def _init_directories(self) -> None:
+        if self.accelerator.is_main_process:
+            self.args.output_dir = Path(self.args.output_dir)
+            self.args.output_dir.mkdir(parents=True, exist_ok=True)
+
+    def prepare_dataset(self) -> None:
+        logger.info("Initializing dataset and dataloader")
+
+        if self.args.model_type == "i2v":
+            self.dataset = I2VDatasetWithBuckets(**(self.args.model_dump()))
+        elif self.args.model_type == "t2v":
+            self.dataset = T2VDatasetWithBuckets(**(self.args.model_dump()))
+        else:
+            raise ValueError(f"Invalid model type: {self.args.model_type}")
+
+        self.data_loader = torch.utils.data.DataLoader(
+            self.dataset,
+            batch_size=1,
+            sampler=BucketSampler(self.dataset, batch_size=self.args.batch_size, shuffle=True),
+            collate_fn=self.collate_fn,
+            num_workers=self.args.num_workers,
+            pin_memory=self.args.pin_memory,
+        )
+
+    def prepare_models(self) -> None:
+        logger.info("Initializing models")
+
+        # Initialize model components
+        self.components = self.load_components()
+
+        if self.components.vae is not None:
+            if self.args.enable_slicing:
+                self.components.vae.enable_slicing()
+            if self.args.enable_tiling:
+                self.components.vae.enable_tiling()
+
+    def prepare_trainable_parameters(self):
+        logger.info("Initializing trainable parameters")
+
+        # For now only LoRA is supported, so freeze all base model components
+        for attr_name, component in vars(self.components).items():
+            if hasattr(component, 'requires_grad_'):
+                component.requires_grad_(False)
+
+        # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer)
+        # to half precision, as these weights are only used for inference and keeping them in full precision
+        # is not required.
+        weight_dtype = self.state.weight_dtype
+
+        if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+            # due to pytorch#99272, MPS does not yet support bfloat16.
+            raise ValueError(
+                "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+            )
+
+        self.__move_components_to_device()
+
+        if self.args.gradient_checkpointing:
+            self.components.transformer.enable_gradient_checkpointing()
+
+        transformer_lora_config = LoraConfig(
+            r=self.args.rank,
+            lora_alpha=self.args.lora_alpha,
+            init_lora_weights=True,
+            target_modules=self.args.target_modules,
+        )
+        self.components.transformer.add_adapter(transformer_lora_config)
+        self.__prepare_saving_loading_hooks(transformer_lora_config)
+
+    def prepare_optimizer(self) -> None:
+        logger.info("Initializing optimizer and lr scheduler")
+
+        # Make sure the trainable params are in float32
+        if self.args.mixed_precision == "fp16":
+            # only upcast trainable parameters (LoRA) into fp32
+            cast_training_params([self.components.transformer], dtype=torch.float32)
+
+        transformer_lora_parameters = list(filter(lambda p: p.requires_grad, self.components.transformer.parameters()))
+        transformer_parameters_with_lr = {
+            "params": transformer_lora_parameters,
+            "lr": self.args.learning_rate,
+        }
+        params_to_optimize = [transformer_parameters_with_lr]
+        self.state.num_trainable_parameters = sum(p.numel() for p in transformer_lora_parameters)
+
+        use_deepspeed_opt = (
+            self.accelerator.state.deepspeed_plugin is not None
+            and "optimizer" in self.accelerator.state.deepspeed_plugin.deepspeed_config
+        )
+        optimizer = get_optimizer(
+            params_to_optimize=params_to_optimize,
+            optimizer_name=self.args.optimizer,
+            learning_rate=self.args.learning_rate,
+            beta1=self.args.beta1,
+            beta2=self.args.beta2,
+            beta3=self.args.beta3,
+            epsilon=self.args.epsilon,
+            weight_decay=self.args.weight_decay,
+            use_deepspeed=use_deepspeed_opt,
+        )
+
+        num_update_steps_per_epoch = math.ceil(len(self.data_loader) / self.args.gradient_accumulation_steps)
+        if self.args.train_steps is None:
+            self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch
+            self.state.overwrote_max_train_steps = True
+
+        use_deepspeed_lr_scheduler = (
+            self.accelerator.state.deepspeed_plugin is not None
+            and "scheduler" in self.accelerator.state.deepspeed_plugin.deepspeed_config
+        )
+        total_training_steps = self.args.train_steps * self.accelerator.num_processes
+        num_warmup_steps = self.args.lr_warmup_steps * self.accelerator.num_processes
+
+        if use_deepspeed_lr_scheduler:
+            from accelerate.utils import DummyScheduler
+
+            lr_scheduler = DummyScheduler(
+                name=self.args.lr_scheduler,
+                optimizer=optimizer,
+                total_num_steps=total_training_steps,
+                num_warmup_steps=num_warmup_steps,
+            )
+        else:
+            lr_scheduler = get_scheduler(
+                name=self.args.lr_scheduler,
+                optimizer=optimizer,
+                num_warmup_steps=num_warmup_steps,
+                num_training_steps=total_training_steps,
+                num_cycles=self.args.lr_num_cycles,
+                power=self.args.lr_power,
+            )
+
+        self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
+
+    def prepare_for_training(self) -> None:
+        self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler = self.accelerator.prepare(
+            self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler
+        )
+
+        # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+        num_update_steps_per_epoch = math.ceil(len(self.data_loader) / self.args.gradient_accumulation_steps)
+        if self.state.overwrote_max_train_steps:
+            self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch
+        # Afterwards we recalculate our number of training epochs
+        self.args.train_epochs = math.ceil(self.args.train_steps / num_update_steps_per_epoch)
+        self.state.num_update_steps_per_epoch = num_update_steps_per_epoch
+
+    def prepare_trackers(self) -> None:
+        logger.info("Initializing trackers")
+
+        tracker_name = self.args.tracker_name or "finetrainers-experiment"
+        self.accelerator.init_trackers(tracker_name, config=self.args.model_dump())
+
+    def train(self) -> None:
+        logger.info("Starting training")
+
+        memory_statistics = get_memory_statistics()
+        logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
+
+        self.state.total_batch_size_count = (
+            self.args.batch_size * self.accelerator.num_processes * self.args.gradient_accumulation_steps
+        )
+        info = {
+            "trainable parameters": self.state.num_trainable_parameters,
+            "total samples": len(self.dataset),
+            "train epochs": self.args.train_epochs,
+            "train steps": self.args.train_steps,
+            "batches per device": self.args.batch_size,
+            "total batches observed per epoch": len(self.data_loader),
+            "train batch size total count": self.state.total_batch_size_count,
+            "gradient accumulation steps": self.args.gradient_accumulation_steps,
+        }
+        logger.info(f"Training configuration: {json.dumps(info, indent=4)}")
+
+        global_step = 0
+        first_epoch = 0
+        initial_global_step = 0
+
+        # Potentially load in the weights and states from a previous save
+        (
+            resume_from_checkpoint_path,
+            initial_global_step,
+            global_step,
+            first_epoch,
+        ) = get_latest_ckpt_path_to_resume_from(
+            resume_from_checkpoint=self.args.resume_from_checkpoint,
+            num_update_steps_per_epoch=self.state.num_update_steps_per_epoch,
+        )
+        if resume_from_checkpoint_path is not None:
+            self.accelerator.load_state(resume_from_checkpoint_path)
+
+        progress_bar = tqdm(
+            range(0, self.args.train_steps),
+            initial=initial_global_step,
+            desc="Training steps",
+            disable=not self.accelerator.is_local_main_process,
+        )
+
+        accelerator = self.accelerator
+        generator = torch.Generator(device=accelerator.device)
+        if self.args.seed is not None:
+            generator = generator.manual_seed(self.args.seed)
+
+        for epoch in range(first_epoch, self.args.train_epochs):
+            logger.debug(f"Starting epoch ({epoch + 1}/{self.args.train_epochs})")
+
+            self.components.transformer.train()
+            models_to_accumulate = [self.components.transformer]
+
+            for step, batch in enumerate(self.data_loader):
+                logger.debug(f"Starting step {step + 1}")
+                logs = {}
+
+                with accelerator.accumulate(models_to_accumulate):
+                    # These weighting schemes use a uniform timestep sampling and instead post-weight the loss
+                    loss = self.compute_loss(batch)
+                    accelerator.backward(loss)
+
+                    if accelerator.sync_gradients:
+                        if accelerator.distributed_type == DistributedType.DEEPSPEED:
+                            grad_norm = self.components.transformer.get_global_grad_norm()
+                            # In some cases the grad norm may not return a float
+                            if torch.is_tensor(grad_norm):
+                                grad_norm = grad_norm.item()
+                        else:
+                            grad_norm = accelerator.clip_grad_norm_(
+                                self.components.transformer.parameters(), self.args.max_grad_norm
+                            )
+                            if torch.is_tensor(grad_norm):
+                                grad_norm = grad_norm.item()
+
+                        logs["grad_norm"] = grad_norm
+
+                    self.optimizer.step()
+                    self.lr_scheduler.step()
+                    self.optimizer.zero_grad()
+
+                # Checks if the accelerator has performed an optimization step behind the scenes
+                if accelerator.sync_gradients:
+                    progress_bar.update(1)
+                    global_step += 1
+                    self.__maybe_save_checkpoint(global_step)
+
+                    # Maybe run validation
+                    should_run_validation = (
+                        self.args.validation_steps is not None
+                        and global_step % self.args.validation_steps == 0
+                    )
+                    if should_run_validation:
+                        self.validate(global_step)
+
+                logs["loss"] = loss.detach().item()
+                logs["lr"] = self.lr_scheduler.get_last_lr()[0]
+                progress_bar.set_postfix(logs)
+                accelerator.log(logs, step=global_step)
+
+                if global_step >= self.args.train_steps:
+                    break
+
+            memory_statistics = get_memory_statistics()
+            logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}")
+
+        accelerator.wait_for_everyone()
+        self.__maybe_save_checkpoint(global_step, must_save=True)
+        self.validate(global_step)
+
+        del self.components
+        free_memory()
+        memory_statistics = get_memory_statistics()
+        logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}")
+
+        accelerator.end_training()
+
+    def fit(self):
+        self.prepare_dataset()
+        self.prepare_models()
+        self.prepare_trainable_parameters()
+        self.prepare_optimizer()
+        self.prepare_for_training()
+        self.prepare_trackers()
+        self.train()
+
+    def collate_fn(self, examples: List[List[Dict[str, Any]]]):
+        """
+        Since we use BucketSampler, the examples parameter is a nested list where the outer list contains
+        only one element, which is the batch data we need. Therefore, when processing the data, we need to
+        access the batch through examples[0].
+        """
+        raise NotImplementedError
+
+    def load_components(self) -> Components:
+        raise NotImplementedError
+
+    def compute_loss(self, batch) -> torch.Tensor:
+        raise NotImplementedError
+
+    def validate(self, step: int) -> None:
+        raise NotImplementedError
+
+    def __get_training_dtype(self) -> torch.dtype:
+        if self.args.mixed_precision == "no":
+            return _DTYPE_MAP["fp32"]
+        elif self.args.mixed_precision == "fp16":
+            return _DTYPE_MAP["fp16"]
+        elif self.args.mixed_precision == "bf16":
+            return _DTYPE_MAP["bf16"]
+        else:
+            raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")
+
+    def __move_components_to_device(self):
+        components = self.components.model_dump()
+        for name, component in components.items():
+            if not isinstance(component, type) and hasattr(component, 'to'):
+                setattr(self.components, name, component.to(self.accelerator.device))
+
+    def __prepare_saving_loading_hooks(self, transformer_lora_config):
+        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+        def save_model_hook(models, weights, output_dir):
+            if self.accelerator.is_main_process:
+                transformer_lora_layers_to_save = None
+
+                for model in models:
+                    if isinstance(
+                        unwrap_model(self.accelerator, model),
+                        type(unwrap_model(self.accelerator, self.components.transformer)),
+                    ):
+                        model = unwrap_model(self.accelerator, model)
+                        transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                    else:
+                        raise ValueError(f"Unexpected save model: {model.__class__}")
+
+                    # make sure to pop weight so that corresponding model is not saved again
+                    if weights:
+                        weights.pop()
+
+                self.components.pipeline_cls.save_lora_weights(
+                    output_dir,
+                    transformer_lora_layers=transformer_lora_layers_to_save,
+                )
+
+        def load_model_hook(models, input_dir):
+            if not self.accelerator.distributed_type == DistributedType.DEEPSPEED:
+                while len(models) > 0:
+                    model = models.pop()
+                    if isinstance(
+                        unwrap_model(self.accelerator, model),
+                        type(unwrap_model(self.accelerator, self.components.transformer)),
+                    ):
+                        transformer_ = unwrap_model(self.accelerator, model)
+                    else:
+                        raise ValueError(
+                            f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}"
+                        )
+            else:
+                transformer_ = unwrap_model(self.accelerator, self.components.transformer).__class__.from_pretrained(
+                    self.args.model_path, subfolder="transformer"
+                )
+                transformer_.add_adapter(transformer_lora_config)
+
+            lora_state_dict = self.components.pipeline_cls.lora_state_dict(input_dir)
+            transformer_state_dict = {
+                f'{k.replace("transformer.", "")}': v
+                for k, v in lora_state_dict.items()
+                if k.startswith("transformer.")
+            }
+            incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+            if incompatible_keys is not None:
+                # check only for unexpected keys
+                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+                if unexpected_keys:
+                    logger.warning(
+                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                        f" {unexpected_keys}. "
+                    )
+
+            # Make sure the trainable params are in float32. This is again needed since the base models
+            # are in `weight_dtype`. More details:
+            # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+            if self.args.mixed_precision == "fp16":
+                # only upcast trainable parameters (LoRA) into fp32
+                cast_training_params([transformer_])
+
+        self.accelerator.register_save_state_pre_hook(save_model_hook)
+        self.accelerator.register_load_state_pre_hook(load_model_hook)
+
+    def __maybe_save_checkpoint(self, global_step: int, must_save: bool = False):
+        if self.accelerator.distributed_type == DistributedType.DEEPSPEED or self.accelerator.is_main_process:
+            if must_save or global_step % self.args.checkpointing_steps == 0:
+                save_path = get_intermediate_ckpt_path(
+                    checkpointing_limit=self.args.checkpointing_limit,
+                    step=global_step,
+                    output_dir=self.args.output_dir,
+                )
+                self.accelerator.save_state(save_path)
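
Concrete model trainers are expected to subclass `Trainer`, implement the four hooks left as `NotImplementedError` (`load_components`, `collate_fn`, `compute_loss`, `validate`), and register themselves so that `get_model_cls(args.model_name, args.training_type)` in `train.py` can resolve them. The sketch below illustrates that contract only; the class name, the batch keys, the placeholder loss, and the registration step are assumptions and are not part of this patch.

    # Hypothetical subclass sketch (illustrative only; not included in this patch)
    from typing import Any, Dict, List

    import torch

    from finetune.schemas import Components
    from finetune.trainer import Trainer


    class MyT2VTrainer(Trainer):
        def load_components(self) -> Components:
            # Model-specific: populate pipeline_cls, transformer, vae, text_encoder, ...
            components = Components()
            # components.transformer = ...  # load the model weights here
            return components

        def collate_fn(self, examples: List[List[Dict[str, Any]]]):
            # BucketSampler wraps each batch in an outer list, so the real batch is examples[0].
            batch = examples[0]
            return {
                "encoded_videos": torch.stack([x["encoded_videos"] for x in batch]),
                "prompt_embeds": torch.stack([x["prompt_embeds"] for x in batch]),
            }

        def compute_loss(self, batch) -> torch.Tensor:
            # Placeholder loss: a real trainer would run self.components.transformer on noised
            # latents and regress the noise/velocity target for the sampled timesteps.
            latents = batch["encoded_videos"].to(self.accelerator.device, dtype=self.state.weight_dtype)
            return torch.nn.functional.mse_loss(latents.float(), torch.zeros_like(latents).float())

        def validate(self, step: int) -> None:
            # Placeholder: generate a few samples and log them to the configured trackers.
            self.accelerator.print(f"[validation] step {step}")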