diff --git a/finetune/utils/__init__.py b/finetune/utils/__init__.py
new file mode 100644
index 0000000..5c6fe5d
--- /dev/null
+++ b/finetune/utils/__init__.py
@@ -0,0 +1,4 @@
+from .torch_utils import *
+from .optimizer_utils import *
+from .memory_utils import *
+from .checkpointing import *
diff --git a/finetune/utils/checkpointing.py b/finetune/utils/checkpointing.py
new file mode 100644
index 0000000..1797153
--- /dev/null
+++ b/finetune/utils/checkpointing.py
@@ -0,0 +1,53 @@
+import os
+from pathlib import Path
+from typing import Tuple
+from accelerate.logging import get_logger
+
+from finetune.constants import LOG_NAME, LOG_LEVEL
+from ..utils.file_utils import find_files, delete_files
+
+
+logger = get_logger(LOG_NAME, LOG_LEVEL)
+
+
+def get_latest_ckpt_path_to_resume_from(
+    resume_from_checkpoint: str | None, num_update_steps_per_epoch: int
+) -> Tuple[Path | None, int, int, int]:
+    if resume_from_checkpoint is None:
+        initial_global_step = 0
+        global_step = 0
+        first_epoch = 0
+        resume_from_checkpoint_path = None
+    else:
+        resume_from_checkpoint_path = Path(resume_from_checkpoint)
+        if not resume_from_checkpoint_path.exists():
+            logger.info(f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run.")
+            initial_global_step = 0
+            global_step = 0
+            first_epoch = 0
+            resume_from_checkpoint_path = None
+        else:
+            logger.info(f"Resuming from checkpoint {resume_from_checkpoint}")
+            # Checkpoint directories are named `checkpoint-<step>`, so the step count can be parsed from the name.
+            global_step = int(resume_from_checkpoint_path.name.split("-")[1])
+
+            initial_global_step = global_step
+            first_epoch = global_step // num_update_steps_per_epoch
+
+    return resume_from_checkpoint_path, initial_global_step, global_step, first_epoch
+
+
+def get_intermediate_ckpt_path(checkpointing_limit: int | None, step: int, output_dir: str) -> str:
+    # Before saving state, check whether this save would put us over the `checkpointing_limit`.
+    if checkpointing_limit is not None:
+        checkpoints = find_files(output_dir, prefix="checkpoint")
+
+        # Before saving the new checkpoint, keep at most `checkpointing_limit - 1` existing checkpoints.
+        if len(checkpoints) >= checkpointing_limit:
+            num_to_remove = len(checkpoints) - checkpointing_limit + 1
+            checkpoints_to_remove = checkpoints[0:num_to_remove]
+            delete_files(checkpoints_to_remove)
+
+    logger.info(f"Checkpointing at step {step}")
+    save_path = os.path.join(output_dir, f"checkpoint-{step}")
+    logger.info(f"Saving state to {save_path}")
+    return save_path
diff --git a/finetune/utils/file_utils.py b/finetune/utils/file_utils.py
new file mode 100644
index 0000000..f04dd85
--- /dev/null
+++ b/finetune/utils/file_utils.py
@@ -0,0 +1,47 @@
+import logging
+import os
+import shutil
+
+from pathlib import Path
+from typing import Any, Dict, List, Union
+from accelerate.logging import get_logger
+from finetune.constants import LOG_NAME, LOG_LEVEL
+
+
+logger = get_logger(LOG_NAME, LOG_LEVEL)
+
+
+def find_files(dir: Union[str, Path], prefix: str = "checkpoint") -> List[Path]:
+    if not isinstance(dir, Path):
+        dir = Path(dir)
+    if not dir.exists():
+        return []
+    checkpoints = os.listdir(dir.as_posix())
+    checkpoints = [c for c in checkpoints if c.startswith(prefix)]
+    # Sort checkpoints by the step number parsed from the `checkpoint-<step>` directory name.
+    checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+    checkpoints = [dir / c for c in checkpoints]
+    return checkpoints
+
+
+def delete_files(dirs: Union[str, List[str], Path, List[Path]]) -> None:
+    if not isinstance(dirs, list):
+        dirs = [dirs]
+    dirs = [Path(d) if isinstance(d, str) else d for d in dirs]
+    logger.info(f"Deleting files: {dirs}")
+    for dir in dirs:
+        if not dir.exists():
+            continue
+        shutil.rmtree(dir, ignore_errors=True)
+
+
+def string_to_filename(s: str) -> str:
+    return (
+        s.replace(" ", "-")
+        .replace("/", "-")
+        .replace(":", "-")
+        .replace(".", "-")
+        .replace(",", "-")
+        .replace(";", "-")
+        .replace("!", "-")
+        .replace("?", "-")
+    )
diff --git a/finetune/utils/memory_utils.py b/finetune/utils/memory_utils.py
new file mode 100644
index 0000000..a7a136b
--- /dev/null
+++ b/finetune/utils/memory_utils.py
@@ -0,0 +1,60 @@
+import gc
+import torch
+
+from typing import Any, Dict, Optional, Union
+from accelerate.logging import get_logger
+
+from finetune.constants import LOG_NAME, LOG_LEVEL
+
+
+logger = get_logger(LOG_NAME, LOG_LEVEL)
+
+
+def get_memory_statistics(precision: int = 3) -> Dict[str, Any]:
+    memory_allocated = None
+    memory_reserved = None
+    max_memory_allocated = None
+    max_memory_reserved = None
+
+    if torch.cuda.is_available():
+        device = torch.cuda.current_device()
+        memory_allocated = torch.cuda.memory_allocated(device)
+        memory_reserved = torch.cuda.memory_reserved(device)
+        max_memory_allocated = torch.cuda.max_memory_allocated(device)
+        max_memory_reserved = torch.cuda.max_memory_reserved(device)
+
+    elif torch.backends.mps.is_available():
+        memory_allocated = torch.mps.current_allocated_memory()
+
+    else:
+        logger.warning("No CUDA, MPS, or ROCm device found. Memory statistics are not available.")
+
+    stats = {
+        "memory_allocated": bytes_to_gigabytes(memory_allocated),
+        "memory_reserved": bytes_to_gigabytes(memory_reserved),
+        "max_memory_allocated": bytes_to_gigabytes(max_memory_allocated),
+        "max_memory_reserved": bytes_to_gigabytes(max_memory_reserved),
+    }
+    # Round only the statistics that are available; keep `None` for values the current backend does not report.
+    return {key: round(value, ndigits=precision) if value is not None else None for key, value in stats.items()}
+
+
+def bytes_to_gigabytes(x: Optional[int]) -> Optional[float]:
+    if x is None:
+        return None
+    return x / 1024**3
+
+
+def free_memory() -> None:
+    if torch.cuda.is_available():
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+
+    # TODO(aryan): handle non-cuda devices
+
+
+def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+    if isinstance(x, torch.Tensor):
+        return x.contiguous()
+    elif isinstance(x, dict):
+        return {k: make_contiguous(v) for k, v in x.items()}
+    else:
+        return x
diff --git a/finetune/utils/optimizer_utils.py b/finetune/utils/optimizer_utils.py
new file mode 100644
index 0000000..bd93f9c
--- /dev/null
+++ b/finetune/utils/optimizer_utils.py
@@ -0,0 +1,180 @@
+import inspect
+import torch
+
+from accelerate.logging import get_logger
+
+from finetune.constants import LOG_NAME, LOG_LEVEL
+
+
+logger = get_logger(LOG_NAME, LOG_LEVEL)
+
+
+def get_optimizer(
+    params_to_optimize,
+    optimizer_name: str = "adam",
+    learning_rate: float = 1e-3,
+    beta1: float = 0.9,
+    beta2: float = 0.95,
+    beta3: float = 0.98,
+    epsilon: float = 1e-8,
+    weight_decay: float = 1e-4,
+    prodigy_decouple: bool = False,
+    prodigy_use_bias_correction: bool = False,
+    prodigy_safeguard_warmup: bool = False,
+    use_8bit: bool = False,
+    use_4bit: bool = False,
+    use_torchao: bool = False,
+    use_deepspeed: bool = False,
+    use_cpu_offload_optimizer: bool = False,
+    offload_gradients: bool = False,
+) -> torch.optim.Optimizer:
+    optimizer_name = optimizer_name.lower()
+
+    # Use the DeepSpeed optimizer (DummyOptim is a placeholder; DeepSpeed configures the actual optimizer).
+    if use_deepspeed:
+        from accelerate.utils import DummyOptim
+
+        return DummyOptim(
+            params_to_optimize,
+            lr=learning_rate,
+            betas=(beta1, beta2),
+            eps=epsilon,
+            weight_decay=weight_decay,
+        )
+
+    if use_8bit and use_4bit:
raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.") + + if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer: + try: + import torchao + + torchao.__version__ + except ImportError: + raise ImportError( + "To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`." + ) + + if not use_torchao and use_4bit: + raise ValueError("4-bit Optimizers are only supported with torchao.") + + # Optimizer creation + supported_optimizers = ["adam", "adamw", "prodigy", "came"] + if optimizer_name not in supported_optimizers: + logger.warning( + f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`." + ) + optimizer_name = "adamw" + + if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]: + raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.") + + if use_8bit: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + if optimizer_name == "adamw": + if use_torchao: + from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit + + optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW + else: + optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW + + init_kwargs = { + "betas": (beta1, beta2), + "eps": epsilon, + "weight_decay": weight_decay, + } + + elif optimizer_name == "adam": + if use_torchao: + from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit + + optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam + else: + optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam + + init_kwargs = { + "betas": (beta1, beta2), + "eps": epsilon, + "weight_decay": weight_decay, + } + + elif optimizer_name == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. 
+            )
+
+        init_kwargs = {
+            "lr": learning_rate,
+            "betas": (beta1, beta2),
+            "beta3": beta3,
+            "eps": epsilon,
+            "weight_decay": weight_decay,
+            "decouple": prodigy_decouple,
+            "use_bias_correction": prodigy_use_bias_correction,
+            "safeguard_warmup": prodigy_safeguard_warmup,
+        }
+
+    elif optimizer_name == "came":
+        try:
+            import came_pytorch
+        except ImportError:
+            raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`")
+
+        optimizer_class = came_pytorch.CAME
+
+        init_kwargs = {
+            "lr": learning_rate,
+            "eps": (1e-30, 1e-16),
+            "betas": (beta1, beta2, beta3),
+            "weight_decay": weight_decay,
+        }
+
+    if use_cpu_offload_optimizer:
+        from torchao.prototype.low_bit_optim import CPUOffloadOptimizer
+
+        # Enable fused updates when the wrapped optimizer supports them; this speeds up CPU-offloaded steps.
+        if "fused" in inspect.signature(optimizer_class.__init__).parameters:
+            init_kwargs.update({"fused": True})
+
+        optimizer = CPUOffloadOptimizer(
+            params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs
+        )
+    else:
+        optimizer = optimizer_class(params_to_optimize, **init_kwargs)
+
+    return optimizer
+
+
+def gradient_norm(parameters):
+    # Global L2 norm of all parameter gradients: square root of the sum of squared per-parameter norms.
+    norm = 0
+    for param in parameters:
+        if param.grad is None:
+            continue
+        local_norm = param.grad.detach().data.norm(2)
+        norm += local_norm.item() ** 2
+    norm = norm**0.5
+    return norm
+
+
+def max_gradient(parameters):
+    # Largest absolute gradient value across all parameters.
+    max_grad_value = float("-inf")
+    for param in parameters:
+        if param.grad is None:
+            continue
+        local_max_grad = param.grad.detach().data.abs().max()
+        max_grad_value = max(max_grad_value, local_max_grad.item())
+    return max_grad_value
diff --git a/finetune/utils/torch_utils.py b/finetune/utils/torch_utils.py
new file mode 100644
index 0000000..8a74271
--- /dev/null
+++ b/finetune/utils/torch_utils.py
@@ -0,0 +1,52 @@
+from typing import Dict, Optional, Union, List
+
+import torch
+from accelerate import Accelerator
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+def unwrap_model(accelerator: Accelerator, model):
+    model = accelerator.unwrap_model(model)
+    model = model._orig_mod if is_compiled_module(model) else model
+    return model
+
+
+def align_device_and_dtype(
+    x: Union[torch.Tensor, Dict[str, torch.Tensor]],
+    device: Optional[torch.device] = None,
+    dtype: Optional[torch.dtype] = None,
+):
+    if isinstance(x, torch.Tensor):
+        if device is not None:
+            x = x.to(device)
+        if dtype is not None:
+            x = x.to(dtype)
+    elif isinstance(x, dict):
+        x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
+    return x
+
+
+def expand_tensor_to_dims(tensor, ndim):
+    # Append trailing singleton dimensions until the tensor has `ndim` dimensions (useful for broadcasting).
+    while len(tensor.shape) < ndim:
+        tensor = tensor.unsqueeze(-1)
+    return tensor
+
+
+def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
+    """
+    Casts the training parameters of the model to the specified data type.
+
+    Args:
+        model: The PyTorch model whose parameters will be cast.
+        dtype: The data type to which the model parameters will be cast.
+    """
+    if not isinstance(model, list):
+        model = [model]
+    for m in model:
+        for param in m.parameters():
+            # Only upcast trainable parameters to the requested dtype (typically fp32).
+            if param.requires_grad:
+                param.data = param.to(dtype)
\ No newline at end of file
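
For reference, a minimal sketch of how `get_optimizer` from the new `optimizer_utils.py` might be called; this is not part of the diff above, and the toy model and hyperparameter values are illustrative assumptions:

# Hypothetical usage sketch (not part of this diff).
import torch

from finetune.utils import get_optimizer

model = torch.nn.Linear(16, 16)  # stand-in for the transformer being fine-tuned
params_to_optimize = [{"params": list(model.parameters()), "lr": 1e-4}]  # lr carried by the param group

optimizer = get_optimizer(
    params_to_optimize,
    optimizer_name="adamw",  # unsupported names fall back to AdamW with a warning
    learning_rate=1e-4,
    beta1=0.9,
    beta2=0.95,
    weight_decay=1e-4,
    use_8bit=False,  # set True only with bitsandbytes (or torchao) installed
)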
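
Similarly, a sketch of how the checkpointing and memory helpers might be wired into an Accelerate-based loop; the output directory, step counts, and checkpoint limit are assumptions for illustration:

# Hypothetical usage sketch (not part of this diff).
from accelerate import Accelerator

from finetune.utils import (
    free_memory,
    get_intermediate_ckpt_path,
    get_latest_ckpt_path_to_resume_from,
    get_memory_statistics,
)

accelerator = Accelerator()
output_dir = "outputs"

# Resolve the requested checkpoint (if any) and derive the step/epoch counters from its directory name.
resume_path, initial_global_step, global_step, first_epoch = get_latest_ckpt_path_to_resume_from(
    resume_from_checkpoint=None, num_update_steps_per_epoch=100
)
if resume_path is not None:
    accelerator.load_state(resume_path.as_posix())

# Inside the training loop: prune old checkpoints if needed, then save the current state.
save_path = get_intermediate_ckpt_path(checkpointing_limit=3, step=global_step, output_dir=output_dir)
accelerator.save_state(save_path)

print(get_memory_statistics())  # e.g. {"memory_allocated": ..., "memory_reserved": ..., ...}
free_memory()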