import inspect

import torch
from accelerate.logging import get_logger

from finetune.constants import LOG_LEVEL, LOG_NAME


logger = get_logger(LOG_NAME, LOG_LEVEL)


def get_optimizer(
    params_to_optimize,
    optimizer_name: str = "adam",
    learning_rate: float = 1e-3,
    beta1: float = 0.9,
    beta2: float = 0.95,
    beta3: float = 0.98,
    epsilon: float = 1e-8,
    weight_decay: float = 1e-4,
    prodigy_decouple: bool = False,
    prodigy_use_bias_correction: bool = False,
    prodigy_safeguard_warmup: bool = False,
    use_8bit: bool = False,
    use_4bit: bool = False,
    use_torchao: bool = False,
    use_deepspeed: bool = False,
    use_cpu_offload_optimizer: bool = False,
    offload_gradients: bool = False,
) -> torch.optim.Optimizer:
    optimizer_name = optimizer_name.lower()

    # Use the DeepSpeed optimizer
    if use_deepspeed:
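        # NOTE: under DeepSpeed, the actual optimizer is built by DeepSpeed from
        # the `optimizer` section of the DeepSpeed config that Accelerate was
        # launched with; `DummyOptim` is only a placeholder carrying these
        # hyperparameters. What that section looks like depends on the launch
        # setup; an illustrative sketch (not defined in this repo) would be:
        #     "optimizer": {"type": "AdamW", "params": {"lr": 1e-3}}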
        from accelerate.utils import DummyOptim

        return DummyOptim(
            params_to_optimize,
            lr=learning_rate,
            betas=(beta1, beta2),
            eps=epsilon,
            weight_decay=weight_decay,
        )
    if use_8bit and use_4bit:
        raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.")

    if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer:
        try:
            import torchao

            # Touch an attribute so a broken install fails here, with a clear message.
            torchao.__version__
        except ImportError:
            raise ImportError(
                "To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`."
            )
    if not use_torchao and use_4bit:
        raise ValueError("4-bit optimizers are only supported with torchao.")

    # Optimizer creation
    supported_optimizers = ["adam", "adamw", "prodigy", "came"]
    if optimizer_name not in supported_optimizers:
        logger.warning(
            f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`."
        )
        optimizer_name = "adamw"

    if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]:
        raise ValueError(
            "`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers."
        )
    if use_8bit:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )
    if optimizer_name == "adamw":
        if use_torchao:
            from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit

            optimizer_class = (
                AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
            )
        else:
            optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW

        init_kwargs = {
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
        }
    elif optimizer_name == "adam":
        if use_torchao:
            from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit

            optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam
        else:
            optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam

        init_kwargs = {
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
        }
    elif optimizer_name == "prodigy":
        try:
            import prodigyopt
        except ImportError:
            raise ImportError(
                "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`"
            )

        optimizer_class = prodigyopt.Prodigy

        if learning_rate <= 0.1:
            logger.warning(
                "Learning rate is too low. When using Prodigy, it's generally better to set the learning rate to around 1.0"
            )

        init_kwargs = {
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "beta3": beta3,
            "eps": epsilon,
            "weight_decay": weight_decay,
            "decouple": prodigy_decouple,
            "use_bias_correction": prodigy_use_bias_correction,
            "safeguard_warmup": prodigy_safeguard_warmup,
        }
    elif optimizer_name == "came":
        try:
            import came_pytorch
        except ImportError:
            raise ImportError(
                "To use CAME, please install the came-pytorch library: `pip install came-pytorch`"
            )

        optimizer_class = came_pytorch.CAME

        init_kwargs = {
            "lr": learning_rate,
            # CAME takes a pair of regularization constants rather than a single eps.
            "eps": (1e-30, 1e-16),
            "betas": (beta1, beta2, beta3),
            "weight_decay": weight_decay,
        }
    if use_cpu_offload_optimizer:
        from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

        # Prefer the fused implementation when the base optimizer supports it,
        # since the CPU-offload optimizer performs its step on the CPU.
        if "fused" in inspect.signature(optimizer_class.__init__).parameters:
            init_kwargs.update({"fused": True})

        optimizer = CPUOffloadOptimizer(
            params_to_optimize,
            optimizer_class=optimizer_class,
            offload_gradients=offload_gradients,
            **init_kwargs,
        )
    else:
        optimizer = optimizer_class(params_to_optimize, **init_kwargs)

    return optimizer
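

# Example usage (a minimal sketch, not part of the original file): `transformer`
# is a hypothetical module being fine-tuned, and the hyperparameter values are
# illustrative only.
#
#     optimizer = get_optimizer(
#         params_to_optimize=[{"params": transformer.parameters()}],
#         optimizer_name="adamw",
#         learning_rate=1e-4,
#         use_8bit=True,
#     )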


def gradient_norm(parameters):
    norm = 0
    for param in parameters:
        if param.grad is None:
            continue
        local_norm = param.grad.detach().data.norm(2)
        norm += local_norm.item() ** 2
    norm = norm**0.5
    return norm
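

# NOTE: for norm_type=2 this matches the total norm computed by
# torch.nn.utils.clip_grad_norm_, so it can be logged alongside clipping, e.g.
# (`model` is a hypothetical module):
#
#     logger.info(f"grad norm: {gradient_norm(model.parameters()):.4f}")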


def max_gradient(parameters):
    max_grad_value = float("-inf")
    for param in parameters:
        if param.grad is None:
            continue
        local_max_grad = param.grad.detach().data.abs().max()
        max_grad_value = max(max_grad_value, local_max_grad.item())
    return max_grad_value
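

# Example (a sketch assuming a conventional training loop; `model`, `loss`, and
# `global_step` are hypothetical names):
#
#     loss.backward()
#     logger.info(
#         f"step {global_step}: grad_norm={gradient_norm(model.parameters()):.4f}, "
#         f"max_grad={max_gradient(model.parameters()):.4f}"
#     )
#     optimizer.step()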