mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
Add utils
This commit is contained in:
parent
85e00a1082
commit
78f655a9a4
4
finetune/utils/__init__.py
Normal file
4
finetune/utils/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
from .torch_utils import *
|
||||
from .optimizer_utils import *
|
||||
from .memory_utils import *
|
||||
from .checkpointing import *
|
53
finetune/utils/checkpointing.py
Normal file
53
finetune/utils/checkpointing.py
Normal file
@ -0,0 +1,53 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
from accelerate.logging import get_logger
|
||||
|
||||
from finetune.constants import LOG_NAME, LOG_LEVEL
|
||||
from ..utils.file_utils import find_files, delete_files
|
||||
|
||||
|
||||
logger = get_logger(LOG_NAME, LOG_LEVEL)
|
||||
|
||||
|
||||
def get_latest_ckpt_path_to_resume_from(
|
||||
resume_from_checkpoint: str | None, num_update_steps_per_epoch: int
|
||||
) -> Tuple[str | None, int, int, int]:
|
||||
if resume_from_checkpoint is None:
|
||||
initial_global_step = 0
|
||||
global_step = 0
|
||||
first_epoch = 0
|
||||
resume_from_checkpoint_path = None
|
||||
else:
|
||||
resume_from_checkpoint_path = Path(resume_from_checkpoint)
|
||||
if not resume_from_checkpoint_path.exists():
|
||||
logger.info(f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run.")
|
||||
initial_global_step = 0
|
||||
global_step = 0
|
||||
first_epoch = 0
|
||||
resume_from_checkpoint_path = None
|
||||
else:
|
||||
logger.info(f"Resuming from checkpoint {resume_from_checkpoint}")
|
||||
global_step = int(resume_from_checkpoint_path.name.split("-")[1])
|
||||
|
||||
initial_global_step = global_step
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
|
||||
return resume_from_checkpoint_path, initial_global_step, global_step, first_epoch
|
||||
|
||||
|
||||
def get_intermediate_ckpt_path(checkpointing_limit: int, step: int, output_dir: str) -> str:
|
||||
# before saving state, check if this save would set us over the `checkpointing_limit`
|
||||
if checkpointing_limit is not None:
|
||||
checkpoints = find_files(output_dir, prefix="checkpoint")
|
||||
|
||||
# before we save the new checkpoint, we need to have at_most `checkpoints_total_limit - 1` checkpoints
|
||||
if len(checkpoints) >= checkpointing_limit:
|
||||
num_to_remove = len(checkpoints) - checkpointing_limit + 1
|
||||
checkpoints_to_remove = checkpoints[0:num_to_remove]
|
||||
delete_files(checkpoints_to_remove)
|
||||
|
||||
logger.info(f"Checkpointing at step {step}")
|
||||
save_path = os.path.join(output_dir, f"checkpoint-{step}")
|
||||
logger.info(f"Saving state to {save_path}")
|
||||
return save_path
|
47
finetune/utils/file_utils.py
Normal file
47
finetune/utils/file_utils.py
Normal file
@ -0,0 +1,47 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Union
|
||||
from accelerate.logging import get_logger
|
||||
from finetune.constants import LOG_NAME, LOG_LEVEL
|
||||
|
||||
|
||||
logger = get_logger(LOG_NAME, LOG_LEVEL)
|
||||
|
||||
|
||||
def find_files(dir: Union[str, Path], prefix: str = "checkpoint") -> List[str]:
|
||||
if not isinstance(dir, Path):
|
||||
dir = Path(dir)
|
||||
if not dir.exists():
|
||||
return []
|
||||
checkpoints = os.listdir(dir.as_posix())
|
||||
checkpoints = [c for c in checkpoints if c.startswith(prefix)]
|
||||
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
||||
checkpoints = [dir / c for c in checkpoints]
|
||||
return checkpoints
|
||||
|
||||
|
||||
def delete_files(dirs: Union[str, List[str], Path, List[Path]]) -> None:
|
||||
if not isinstance(dirs, list):
|
||||
dirs = [dirs]
|
||||
dirs = [Path(d) if isinstance(d, str) else d for d in dirs]
|
||||
logger.info(f"Deleting files: {dirs}")
|
||||
for dir in dirs:
|
||||
if not dir.exists():
|
||||
continue
|
||||
shutil.rmtree(dir, ignore_errors=True)
|
||||
|
||||
|
||||
def string_to_filename(s: str) -> str:
|
||||
return (
|
||||
s.replace(" ", "-")
|
||||
.replace("/", "-")
|
||||
.replace(":", "-")
|
||||
.replace(".", "-")
|
||||
.replace(",", "-")
|
||||
.replace(";", "-")
|
||||
.replace("!", "-")
|
||||
.replace("?", "-")
|
||||
)
|
60
finetune/utils/memory_utils.py
Normal file
60
finetune/utils/memory_utils.py
Normal file
@ -0,0 +1,60 @@
|
||||
import gc
|
||||
import torch
|
||||
|
||||
from typing import Any, Dict, Union
|
||||
from accelerate.logging import get_logger
|
||||
|
||||
from finetune.constants import LOG_NAME, LOG_LEVEL
|
||||
|
||||
|
||||
logger = get_logger(LOG_NAME, LOG_LEVEL)
|
||||
|
||||
|
||||
def get_memory_statistics(precision: int = 3) -> Dict[str, Any]:
|
||||
memory_allocated = None
|
||||
memory_reserved = None
|
||||
max_memory_allocated = None
|
||||
max_memory_reserved = None
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device = torch.cuda.current_device()
|
||||
memory_allocated = torch.cuda.memory_allocated(device)
|
||||
memory_reserved = torch.cuda.memory_reserved(device)
|
||||
max_memory_allocated = torch.cuda.max_memory_allocated(device)
|
||||
max_memory_reserved = torch.cuda.max_memory_reserved(device)
|
||||
|
||||
elif torch.mps.is_available():
|
||||
memory_allocated = torch.mps.current_allocated_memory()
|
||||
|
||||
else:
|
||||
logger.warning("No CUDA, MPS, or ROCm device found. Memory statistics are not available.")
|
||||
|
||||
return {
|
||||
"memory_allocated": round(bytes_to_gigabytes(memory_allocated), ndigits=precision),
|
||||
"memory_reserved": round(bytes_to_gigabytes(memory_reserved), ndigits=precision),
|
||||
"max_memory_allocated": round(bytes_to_gigabytes(max_memory_allocated), ndigits=precision),
|
||||
"max_memory_reserved": round(bytes_to_gigabytes(max_memory_reserved), ndigits=precision),
|
||||
}
|
||||
|
||||
|
||||
def bytes_to_gigabytes(x: int) -> float:
|
||||
if x is not None:
|
||||
return x / 1024**3
|
||||
|
||||
|
||||
def free_memory() -> None:
|
||||
if torch.cuda.is_available():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
# TODO(aryan): handle non-cuda devices
|
||||
|
||||
|
||||
def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
if isinstance(x, torch.Tensor):
|
||||
return x.contiguous()
|
||||
elif isinstance(x, dict):
|
||||
return {k: make_contiguous(v) for k, v in x.items()}
|
||||
else:
|
||||
return x
|
180
finetune/utils/optimizer_utils.py
Normal file
180
finetune/utils/optimizer_utils.py
Normal file
@ -0,0 +1,180 @@
|
||||
import inspect
|
||||
import torch
|
||||
|
||||
from accelerate.logging import get_logger
|
||||
|
||||
from finetune.constants import LOG_NAME, LOG_LEVEL
|
||||
|
||||
|
||||
logger = get_logger(LOG_NAME, LOG_LEVEL)
|
||||
|
||||
|
||||
def get_optimizer(
|
||||
params_to_optimize,
|
||||
optimizer_name: str = "adam",
|
||||
learning_rate: float = 1e-3,
|
||||
beta1: float = 0.9,
|
||||
beta2: float = 0.95,
|
||||
beta3: float = 0.98,
|
||||
epsilon: float = 1e-8,
|
||||
weight_decay: float = 1e-4,
|
||||
prodigy_decouple: bool = False,
|
||||
prodigy_use_bias_correction: bool = False,
|
||||
prodigy_safeguard_warmup: bool = False,
|
||||
use_8bit: bool = False,
|
||||
use_4bit: bool = False,
|
||||
use_torchao: bool = False,
|
||||
use_deepspeed: bool = False,
|
||||
use_cpu_offload_optimizer: bool = False,
|
||||
offload_gradients: bool = False,
|
||||
) -> torch.optim.Optimizer:
|
||||
optimizer_name = optimizer_name.lower()
|
||||
|
||||
# Use DeepSpeed optimzer
|
||||
if use_deepspeed:
|
||||
from accelerate.utils import DummyOptim
|
||||
|
||||
return DummyOptim(
|
||||
params_to_optimize,
|
||||
lr=learning_rate,
|
||||
betas=(beta1, beta2),
|
||||
eps=epsilon,
|
||||
weight_decay=weight_decay,
|
||||
)
|
||||
|
||||
if use_8bit and use_4bit:
|
||||
raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.")
|
||||
|
||||
if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer:
|
||||
try:
|
||||
import torchao
|
||||
|
||||
torchao.__version__
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`."
|
||||
)
|
||||
|
||||
if not use_torchao and use_4bit:
|
||||
raise ValueError("4-bit Optimizers are only supported with torchao.")
|
||||
|
||||
# Optimizer creation
|
||||
supported_optimizers = ["adam", "adamw", "prodigy", "came"]
|
||||
if optimizer_name not in supported_optimizers:
|
||||
logger.warning(
|
||||
f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`."
|
||||
)
|
||||
optimizer_name = "adamw"
|
||||
|
||||
if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]:
|
||||
raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.")
|
||||
|
||||
if use_8bit:
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
||||
)
|
||||
|
||||
if optimizer_name == "adamw":
|
||||
if use_torchao:
|
||||
from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit
|
||||
|
||||
optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
|
||||
else:
|
||||
optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW
|
||||
|
||||
init_kwargs = {
|
||||
"betas": (beta1, beta2),
|
||||
"eps": epsilon,
|
||||
"weight_decay": weight_decay,
|
||||
}
|
||||
|
||||
elif optimizer_name == "adam":
|
||||
if use_torchao:
|
||||
from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit
|
||||
|
||||
optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam
|
||||
else:
|
||||
optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam
|
||||
|
||||
init_kwargs = {
|
||||
"betas": (beta1, beta2),
|
||||
"eps": epsilon,
|
||||
"weight_decay": weight_decay,
|
||||
}
|
||||
|
||||
elif optimizer_name == "prodigy":
|
||||
try:
|
||||
import prodigyopt
|
||||
except ImportError:
|
||||
raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
|
||||
|
||||
optimizer_class = prodigyopt.Prodigy
|
||||
|
||||
if learning_rate <= 0.1:
|
||||
logger.warning(
|
||||
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
|
||||
)
|
||||
|
||||
init_kwargs = {
|
||||
"lr": learning_rate,
|
||||
"betas": (beta1, beta2),
|
||||
"beta3": beta3,
|
||||
"eps": epsilon,
|
||||
"weight_decay": weight_decay,
|
||||
"decouple": prodigy_decouple,
|
||||
"use_bias_correction": prodigy_use_bias_correction,
|
||||
"safeguard_warmup": prodigy_safeguard_warmup,
|
||||
}
|
||||
|
||||
elif optimizer_name == "came":
|
||||
try:
|
||||
import came_pytorch
|
||||
except ImportError:
|
||||
raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`")
|
||||
|
||||
optimizer_class = came_pytorch.CAME
|
||||
|
||||
init_kwargs = {
|
||||
"lr": learning_rate,
|
||||
"eps": (1e-30, 1e-16),
|
||||
"betas": (beta1, beta2, beta3),
|
||||
"weight_decay": weight_decay,
|
||||
}
|
||||
|
||||
if use_cpu_offload_optimizer:
|
||||
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer
|
||||
|
||||
if "fused" in inspect.signature(optimizer_class.__init__).parameters:
|
||||
init_kwargs.update({"fused": True})
|
||||
|
||||
optimizer = CPUOffloadOptimizer(
|
||||
params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs
|
||||
)
|
||||
else:
|
||||
optimizer = optimizer_class(params_to_optimize, **init_kwargs)
|
||||
|
||||
return optimizer
|
||||
|
||||
|
||||
def gradient_norm(parameters):
|
||||
norm = 0
|
||||
for param in parameters:
|
||||
if param.grad is None:
|
||||
continue
|
||||
local_norm = param.grad.detach().data.norm(2)
|
||||
norm += local_norm.item() ** 2
|
||||
norm = norm**0.5
|
||||
return norm
|
||||
|
||||
|
||||
def max_gradient(parameters):
|
||||
max_grad_value = float("-inf")
|
||||
for param in parameters:
|
||||
if param.grad is None:
|
||||
continue
|
||||
local_max_grad = param.grad.detach().data.abs().max()
|
||||
max_grad_value = max(max_grad_value, local_max_grad.item())
|
||||
return max_grad_value
|
52
finetune/utils/torch_utils.py
Normal file
52
finetune/utils/torch_utils.py
Normal file
@ -0,0 +1,52 @@
|
||||
from typing import Dict, Optional, Union, List
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
def unwrap_model(accelerator: Accelerator, model):
|
||||
model = accelerator.unwrap_model(model)
|
||||
model = model._orig_mod if is_compiled_module(model) else model
|
||||
return model
|
||||
|
||||
|
||||
def align_device_and_dtype(
|
||||
x: Union[torch.Tensor, Dict[str, torch.Tensor]],
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
if isinstance(x, torch.Tensor):
|
||||
if device is not None:
|
||||
x = x.to(device)
|
||||
if dtype is not None:
|
||||
x = x.to(dtype)
|
||||
elif isinstance(x, dict):
|
||||
if device is not None:
|
||||
x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
|
||||
if dtype is not None:
|
||||
x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
|
||||
return x
|
||||
|
||||
|
||||
def expand_tensor_to_dims(tensor, ndim):
|
||||
while len(tensor.shape) < ndim:
|
||||
tensor = tensor.unsqueeze(-1)
|
||||
return tensor
|
||||
|
||||
|
||||
def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
|
||||
"""
|
||||
Casts the training parameters of the model to the specified data type.
|
||||
|
||||
Args:
|
||||
model: The PyTorch model whose parameters will be cast.
|
||||
dtype: The data type to which the model parameters will be cast.
|
||||
"""
|
||||
if not isinstance(model, list):
|
||||
model = [model]
|
||||
for m in model:
|
||||
for param in m.parameters():
|
||||
# only upcast trainable parameters into fp32
|
||||
if param.requires_grad:
|
||||
param.data = param.to(dtype)
|
Loading…
x
Reference in New Issue
Block a user