mirror of https://github.com/THUDM/CogVideo.git
synced 2025-04-06 03:57:56 +08:00

Add utils

This commit is contained in:
parent 85e00a1082
commit 78f655a9a4
4  finetune/utils/__init__.py  Normal file
@@ -0,0 +1,4 @@
from .torch_utils import *
from .optimizer_utils import *
from .memory_utils import *
from .checkpointing import *
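These re-exports let trainer code import the helpers from the package root instead of the individual modules. A minimal sketch of that usage (not part of the commit; the chosen names are just examples from the files below):

# Illustrative only, assuming `finetune` is on the import path:
from finetune.utils import get_optimizer, free_memory, unwrap_model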
53  finetune/utils/checkpointing.py  Normal file
@@ -0,0 +1,53 @@
import os
from pathlib import Path
from typing import Tuple

from accelerate.logging import get_logger

from finetune.constants import LOG_NAME, LOG_LEVEL
from ..utils.file_utils import find_files, delete_files


logger = get_logger(LOG_NAME, LOG_LEVEL)


def get_latest_ckpt_path_to_resume_from(
    resume_from_checkpoint: str | None, num_update_steps_per_epoch: int
) -> Tuple[str | None, int, int, int]:
    if resume_from_checkpoint is None:
        initial_global_step = 0
        global_step = 0
        first_epoch = 0
        resume_from_checkpoint_path = None
    else:
        resume_from_checkpoint_path = Path(resume_from_checkpoint)
        if not resume_from_checkpoint_path.exists():
            logger.info(f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run.")
            initial_global_step = 0
            global_step = 0
            first_epoch = 0
            resume_from_checkpoint_path = None
        else:
            logger.info(f"Resuming from checkpoint {resume_from_checkpoint}")
            global_step = int(resume_from_checkpoint_path.name.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch

    return resume_from_checkpoint_path, initial_global_step, global_step, first_epoch


def get_intermediate_ckpt_path(checkpointing_limit: int, step: int, output_dir: str) -> str:
    # Before saving state, check whether this save would put us over the `checkpointing_limit`.
    if checkpointing_limit is not None:
        checkpoints = find_files(output_dir, prefix="checkpoint")

        # Before saving the new checkpoint, keep at most `checkpointing_limit - 1` existing checkpoints.
        if len(checkpoints) >= checkpointing_limit:
            num_to_remove = len(checkpoints) - checkpointing_limit + 1
            checkpoints_to_remove = checkpoints[0:num_to_remove]
            delete_files(checkpoints_to_remove)

    logger.info(f"Checkpointing at step {step}")
    save_path = os.path.join(output_dir, f"checkpoint-{step}")
    logger.info(f"Saving state to {save_path}")
    return save_path
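A hedged usage sketch for the two helpers above (not part of the commit; the paths, step counts, and the `accelerator` object are illustrative placeholders):

# Illustrative only: resume bookkeeping at the start of training.
ckpt_path, initial_global_step, global_step, first_epoch = get_latest_ckpt_path_to_resume_from(
    resume_from_checkpoint="output/checkpoint-500",  # or None for a fresh run
    num_update_steps_per_epoch=100,
)
if ckpt_path is not None:
    accelerator.load_state(ckpt_path.as_posix())  # assumes a Hugging Face Accelerate `accelerator`

# Illustrative only: rotate old checkpoints, then save the current training state.
save_path = get_intermediate_ckpt_path(checkpointing_limit=3, step=global_step, output_dir="output")
accelerator.save_state(save_path)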
47  finetune/utils/file_utils.py  Normal file
@@ -0,0 +1,47 @@
import logging
import os
import shutil

from pathlib import Path
from typing import Any, Dict, List, Union
from accelerate.logging import get_logger
from finetune.constants import LOG_NAME, LOG_LEVEL


logger = get_logger(LOG_NAME, LOG_LEVEL)


def find_files(dir: Union[str, Path], prefix: str = "checkpoint") -> List[Path]:
    if not isinstance(dir, Path):
        dir = Path(dir)
    if not dir.exists():
        return []
    checkpoints = os.listdir(dir.as_posix())
    checkpoints = [c for c in checkpoints if c.startswith(prefix)]
    checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
    checkpoints = [dir / c for c in checkpoints]
    return checkpoints


def delete_files(dirs: Union[str, List[str], Path, List[Path]]) -> None:
    if not isinstance(dirs, list):
        dirs = [dirs]
    dirs = [Path(d) if isinstance(d, str) else d for d in dirs]
    logger.info(f"Deleting files: {dirs}")
    for dir in dirs:
        if not dir.exists():
            continue
        shutil.rmtree(dir, ignore_errors=True)


def string_to_filename(s: str) -> str:
    return (
        s.replace(" ", "-")
        .replace("/", "-")
        .replace(":", "-")
        .replace(".", "-")
        .replace(",", "-")
        .replace(";", "-")
        .replace("!", "-")
        .replace("?", "-")
    )
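A short usage sketch for these helpers (not part of the diff; the directory name and caption string are made up):

# Illustrative only: list checkpoint directories, prune all but the newest, sanitize a caption.
ckpts = find_files("output", prefix="checkpoint")    # e.g. [Path("output/checkpoint-100"), ...], sorted by step
delete_files(ckpts[:-1])                             # keep only the most recent checkpoint
name = string_to_filename("a girl, riding a bike.")  # -> "a-girl--riding-a-bike-"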
60  finetune/utils/memory_utils.py  Normal file
@@ -0,0 +1,60 @@
import gc
import torch

from typing import Any, Dict, Union
from accelerate.logging import get_logger

from finetune.constants import LOG_NAME, LOG_LEVEL


logger = get_logger(LOG_NAME, LOG_LEVEL)


def get_memory_statistics(precision: int = 3) -> Dict[str, Any]:
    memory_allocated = None
    memory_reserved = None
    max_memory_allocated = None
    max_memory_reserved = None

    if torch.cuda.is_available():
        device = torch.cuda.current_device()
        memory_allocated = torch.cuda.memory_allocated(device)
        memory_reserved = torch.cuda.memory_reserved(device)
        max_memory_allocated = torch.cuda.max_memory_allocated(device)
        max_memory_reserved = torch.cuda.max_memory_reserved(device)

    elif torch.mps.is_available():
        memory_allocated = torch.mps.current_allocated_memory()

    else:
        logger.warning("No CUDA, MPS, or ROCm device found. Memory statistics are not available.")

    return {
        "memory_allocated": round(bytes_to_gigabytes(memory_allocated), ndigits=precision),
        "memory_reserved": round(bytes_to_gigabytes(memory_reserved), ndigits=precision),
        "max_memory_allocated": round(bytes_to_gigabytes(max_memory_allocated), ndigits=precision),
        "max_memory_reserved": round(bytes_to_gigabytes(max_memory_reserved), ndigits=precision),
    }


def bytes_to_gigabytes(x: int) -> float:
    if x is not None:
        return x / 1024**3


def free_memory() -> None:
    if torch.cuda.is_available():
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    # TODO(aryan): handle non-cuda devices


def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    if isinstance(x, torch.Tensor):
        return x.contiguous()
    elif isinstance(x, dict):
        return {k: make_contiguous(v) for k, v in x.items()}
    else:
        return x
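A hedged example of how these helpers are typically called around a training step on a CUDA machine (not part of the commit; the logging pattern and the step placeholder are assumptions):

# Illustrative only: snapshot memory before and after a step, then release cached blocks.
stats = get_memory_statistics()              # e.g. {"memory_allocated": 12.345, ...} in GiB on CUDA
logger.info(f"Memory before step: {stats}")
# ... forward / backward / optimizer.step() ...
free_memory()                                # gc.collect() + empty_cache() on CUDA
logger.info(f"Memory after cleanup: {get_memory_statistics()}")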
180  finetune/utils/optimizer_utils.py  Normal file
@@ -0,0 +1,180 @@
import inspect
import torch

from accelerate.logging import get_logger

from finetune.constants import LOG_NAME, LOG_LEVEL


logger = get_logger(LOG_NAME, LOG_LEVEL)


def get_optimizer(
    params_to_optimize,
    optimizer_name: str = "adam",
    learning_rate: float = 1e-3,
    beta1: float = 0.9,
    beta2: float = 0.95,
    beta3: float = 0.98,
    epsilon: float = 1e-8,
    weight_decay: float = 1e-4,
    prodigy_decouple: bool = False,
    prodigy_use_bias_correction: bool = False,
    prodigy_safeguard_warmup: bool = False,
    use_8bit: bool = False,
    use_4bit: bool = False,
    use_torchao: bool = False,
    use_deepspeed: bool = False,
    use_cpu_offload_optimizer: bool = False,
    offload_gradients: bool = False,
) -> torch.optim.Optimizer:
    optimizer_name = optimizer_name.lower()

    # Use the DeepSpeed optimizer
    if use_deepspeed:
        from accelerate.utils import DummyOptim

        return DummyOptim(
            params_to_optimize,
            lr=learning_rate,
            betas=(beta1, beta2),
            eps=epsilon,
            weight_decay=weight_decay,
        )

    if use_8bit and use_4bit:
        raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.")

    if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer:
        try:
            import torchao

            torchao.__version__
        except ImportError:
            raise ImportError(
                "To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`."
            )

    if not use_torchao and use_4bit:
        raise ValueError("4-bit optimizers are only supported with torchao.")

    # Optimizer creation
    supported_optimizers = ["adam", "adamw", "prodigy", "came"]
    if optimizer_name not in supported_optimizers:
        logger.warning(
            f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`."
        )
        optimizer_name = "adamw"

    if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]:
        raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.")

    if use_8bit:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

    if optimizer_name == "adamw":
        if use_torchao:
            from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit

            optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
        else:
            optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW

        init_kwargs = {
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
        }

    elif optimizer_name == "adam":
        if use_torchao:
            from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit

            optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam
        else:
            optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam

        init_kwargs = {
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
        }

    elif optimizer_name == "prodigy":
        try:
            import prodigyopt
        except ImportError:
            raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")

        optimizer_class = prodigyopt.Prodigy

        if learning_rate <= 0.1:
            logger.warning(
                "Learning rate is too low. When using Prodigy, it's generally better to set the learning rate around 1.0"
            )

        init_kwargs = {
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "beta3": beta3,
            "eps": epsilon,
            "weight_decay": weight_decay,
            "decouple": prodigy_decouple,
            "use_bias_correction": prodigy_use_bias_correction,
            "safeguard_warmup": prodigy_safeguard_warmup,
        }

    elif optimizer_name == "came":
        try:
            import came_pytorch
        except ImportError:
            raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`")

        optimizer_class = came_pytorch.CAME

        init_kwargs = {
            "lr": learning_rate,
            "eps": (1e-30, 1e-16),
            "betas": (beta1, beta2, beta3),
            "weight_decay": weight_decay,
        }

    if use_cpu_offload_optimizer:
        from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

        if "fused" in inspect.signature(optimizer_class.__init__).parameters:
            init_kwargs.update({"fused": True})

        optimizer = CPUOffloadOptimizer(
            params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs
        )
    else:
        optimizer = optimizer_class(params_to_optimize, **init_kwargs)

    return optimizer


def gradient_norm(parameters):
    norm = 0
    for param in parameters:
        if param.grad is None:
            continue
        local_norm = param.grad.detach().data.norm(2)
        norm += local_norm.item() ** 2
    norm = norm**0.5
    return norm


def max_gradient(parameters):
    max_grad_value = float("-inf")
    for param in parameters:
        if param.grad is None:
            continue
        local_max_grad = param.grad.detach().data.abs().max()
        max_grad_value = max(max_grad_value, local_max_grad.item())
    return max_grad_value
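A minimal sketch of building an optimizer with the helper above (not part of the commit; `transformer` and the hyperparameter values are placeholders):

# Illustrative only: plain AdamW over the trainable parameters of a hypothetical `transformer`.
params_to_optimize = [p for p in transformer.parameters() if p.requires_grad]
optimizer = get_optimizer(
    params_to_optimize,
    optimizer_name="adamw",
    learning_rate=1e-4,
    weight_decay=1e-4,
    use_8bit=False,        # set True (with bitsandbytes or torchao installed) for 8-bit AdamW
)
grad_l2 = gradient_norm(transformer.parameters())  # call after backward() to inspect gradient magnitude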
52  finetune/utils/torch_utils.py  Normal file
@@ -0,0 +1,52 @@
from typing import Dict, Optional, Union, List

import torch
from accelerate import Accelerator
from diffusers.utils.torch_utils import is_compiled_module


def unwrap_model(accelerator: Accelerator, model):
    model = accelerator.unwrap_model(model)
    model = model._orig_mod if is_compiled_module(model) else model
    return model


def align_device_and_dtype(
    x: Union[torch.Tensor, Dict[str, torch.Tensor]],
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    if isinstance(x, torch.Tensor):
        if device is not None:
            x = x.to(device)
        if dtype is not None:
            x = x.to(dtype)
    elif isinstance(x, dict):
        if device is not None:
            x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
        if dtype is not None:
            x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
    return x


def expand_tensor_to_dims(tensor, ndim):
    while len(tensor.shape) < ndim:
        tensor = tensor.unsqueeze(-1)
    return tensor


def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
    """
    Casts the training parameters of the model to the specified data type.

    Args:
        model: The PyTorch model whose parameters will be cast.
        dtype: The data type to which the model parameters will be cast.
    """
    if not isinstance(model, list):
        model = [model]
    for m in model:
        for param in m.parameters():
            # only upcast trainable parameters into fp32
            if param.requires_grad:
                param.data = param.to(dtype)
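A brief usage sketch for the tensor and device helpers (not part of the commit; `batch`, `timesteps`, `accelerator`, and `model` are assumed to exist in the surrounding trainer):

# Illustrative only: align a batch dict, broadcast timesteps, and unwrap the model for saving.
batch = align_device_and_dtype(batch, device=accelerator.device, dtype=torch.bfloat16)
timesteps = expand_tensor_to_dims(timesteps, ndim=5)   # add trailing dims to broadcast against a 5D video latent
base_model = unwrap_model(accelerator, model)          # strips Accelerate and torch.compile wrappers
cast_training_params(base_model, dtype=torch.float32)  # upcast trainable params (e.g. LoRA weights) to fp32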