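"""Memory-management helpers for the CogVideo finetuning code."""
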
import gc
from typing import Any, Dict, Union

import torch
from accelerate.logging import get_logger

from finetune.constants import LOG_LEVEL, LOG_NAME


logger = get_logger(LOG_NAME, LOG_LEVEL)


def get_memory_statistics(precision: int = 3) -> Dict[str, Any]:
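    """Return current and peak device memory usage as a dict of GiB values.

    On CUDA this queries allocated, reserved, max-allocated, and max-reserved
    bytes for the current device; on MPS only the currently allocated memory
    is queried.
    """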
    memory_allocated = None
    memory_reserved = None
    max_memory_allocated = None
    max_memory_reserved = None

    if torch.cuda.is_available():
        device = torch.cuda.current_device()
        memory_allocated = torch.cuda.memory_allocated(device)
        memory_reserved = torch.cuda.memory_reserved(device)
        max_memory_allocated = torch.cuda.max_memory_allocated(device)
        max_memory_reserved = torch.cuda.max_memory_reserved(device)

    elif torch.mps.is_available():
        memory_allocated = torch.mps.current_allocated_memory()

    else:
        logger.warning("No CUDA, MPS, or ROCm device found. Memory statistics are not available.")

    return {
        "memory_allocated": round(bytes_to_gigabytes(memory_allocated), ndigits=precision),
        "memory_reserved": round(bytes_to_gigabytes(memory_reserved), ndigits=precision),
        "max_memory_allocated": round(bytes_to_gigabytes(max_memory_allocated), ndigits=precision),
        "max_memory_reserved": round(bytes_to_gigabytes(max_memory_reserved), ndigits=precision),
    }


def bytes_to_gigabytes(x: int) -> float:
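    """Convert a byte count to gigabytes (x / 1024**3); returns None if x is None."""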
    if x is not None:
        return x / 1024**3


def free_memory() -> None:
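    """Trigger garbage collection and release cached CUDA memory, if CUDA is available."""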
    if torch.cuda.is_available():
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    # TODO(aryan): handle non-cuda devices


def unload_model(model):
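    """Move the model's parameters and buffers to CPU to free accelerator memory."""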
    model.to("cpu")


def make_contiguous(
    x: Union[torch.Tensor, Dict[str, torch.Tensor]],
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
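    """Return x with a contiguous memory layout.

    Tensors are made contiguous directly, dicts are processed recursively, and
    any other value is returned unchanged.
    """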
    if isinstance(x, torch.Tensor):
        return x.contiguous()
    elif isinstance(x, dict):
        return {k: make_contiguous(v) for k, v in x.items()}
    else:
        return x
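

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the training entry
    # points; assumes a CUDA device, and the tensor below is a made-up example).
    if torch.cuda.is_available():
        tensor = torch.randn(1024, 1024, device="cuda")
        print("after allocation:", get_memory_statistics())
        tensor = make_contiguous(tensor.t())  # transposed views are non-contiguous
        del tensor
        free_memory()
        print("after free_memory:", get_memory_statistics())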