mirror of https://github.com/THUDM/CogVideo.git
synced 2025-04-06 03:57:56 +08:00

feat(trainer): add validation functionality to Trainer class

Add validation capabilities to the Trainer class, including:

- Support for validating images and videos during training
- Periodic validation based on the validation_steps parameter
- Artifact logging to wandb for validation results
- Memory tracking during the validation process

parent 6971364591
commit fa4659fb2c
@@ -1,15 +1,18 @@
-import torch
+import os
 import logging
-import transformers
-import diffusers
 import math
 import json
-import multiprocessing
+
+import torch
+import transformers
+import diffusers
+import wandb
 
 from datetime import timedelta
 from pathlib import Path
 from tqdm import tqdm
-from typing import Dict, Any, List
+from typing import Dict, Any, List, Tuple
+from PIL import Image
 
 from torch.utils.data import Dataset, DataLoader
 from accelerate.logging import get_logger
@@ -23,6 +26,7 @@ from accelerate.utils import (
 )
 
 from diffusers.optimization import get_scheduler
+from diffusers.utils.export_utils import export_to_video
 from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 
 from finetune.schemas import Args, State, Components
@@ -37,8 +41,14 @@ from finetune.utils import (
     get_intermediate_ckpt_path,
     get_latest_ckpt_path_to_resume_from,
     get_intermediate_ckpt_path,
+
+    string_to_filename
 )
 from finetune.datasets import I2VDatasetWithBuckets, T2VDatasetWithBuckets, BucketSampler
+from finetune.datasets.utils import (
+    load_prompts, load_images, load_videos,
+    preprocess_image_with_resize, preprocess_video_with_resize
+)
 
 from finetune.constants import LOG_NAME, LOG_LEVEL
 
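
Note: load_prompts, load_images, and load_videos come from finetune.datasets.utils, whose implementations are not part of this diff. As a rough sketch only, assuming one entry per non-empty line in each list file (behavior unverified against the real module):

from pathlib import Path
from typing import List

def load_prompts(path: Path) -> List[str]:
    # One prompt per non-empty line.
    return [line.strip() for line in path.read_text().splitlines() if line.strip()]

def load_images(path: Path) -> List[Path]:
    # One image path per non-empty line.
    return [Path(line.strip()) for line in path.read_text().splitlines() if line.strip()]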
@@ -263,6 +273,24 @@ class Trainer:
         # Afterwards we recalculate our number of training epochs
         self.args.train_epochs = math.ceil(self.args.train_steps / num_update_steps_per_epoch)
         self.state.num_update_steps_per_epoch = num_update_steps_per_epoch
 
+    def prepare_for_validation(self):
+        self.state.gen_frames, self.state.gen_height, self.state.gen_width = [int(elem) for elem in self.args.gen_video_resolution.split('x')]
+        validation_prompts = load_prompts(self.args.validation_dir / self.args.validation_prompts)
+
+        if self.args.validation_images is not None:
+            validation_images = load_images(self.args.validation_dir / self.args.validation_images)
+        else:
+            validation_images = [None] * len(validation_prompts)
+
+        if self.args.validation_videos is not None:
+            validation_videos = load_videos(self.args.validation_dir / self.args.validation_videos)
+        else:
+            validation_videos = [None] * len(validation_prompts)
+
+        self.state.validation_prompts = validation_prompts
+        self.state.validation_images = validation_images
+        self.state.validation_videos = validation_videos
+
     def prepare_trackers(self) -> None:
         logger.info("Initializing trackers")
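
For reference, the gen_video_resolution parsing in prepare_for_validation() expects a "frames x height x width" string; the "49x480x720" value below is illustrative, not a documented default:

gen_video_resolution = "49x480x720"  # illustrative value only
frames, height, width = [int(elem) for elem in gen_video_resolution.split('x')]
assert (frames, height, width) == (49, 480, 720)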
@@ -319,6 +347,7 @@ class Trainer:
         generator = torch.Generator(device=accelerator.device)
         if self.args.seed is not None:
             generator = generator.manual_seed(self.args.seed)
+        self.state.generator = generator
 
         for epoch in range(first_epoch, self.args.train_epochs):
             logger.debug(f"Starting epoch ({epoch + 1}/{self.args.train_epochs})")
@@ -362,7 +391,7 @@ class Trainer:
 
                 # Maybe run validation
                 should_run_validation = (
-                    self.args.validation_steps is not None
+                    self.args.do_validation
                     and global_step % self.args.validation_steps == 0
                 )
                 if should_run_validation:
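
The gate above, restated in isolation: validation now requires the do_validation flag and still fires every validation_steps optimizer steps:

def should_run_validation(do_validation: bool, validation_steps: int, global_step: int) -> bool:
    # Runs only when validation is enabled, on every validation_steps-th step.
    return do_validation and global_step % validation_steps == 0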
@@ -381,7 +410,8 @@ class Trainer:
 
         accelerator.wait_for_everyone()
         self.__maybe_save_checkpoint(global_step, must_save=True)
-        self.validate(global_step)
+        if self.args.do_validation:
+            self.validate(global_step)
 
         del self.components
         free_memory()
@@ -390,12 +420,124 @@ class Trainer:
 
         accelerator.end_training()
 
+    def validate(self, step: int) -> None:
+        logger.info("Starting validation")
+
+        accelerator = self.accelerator
+        num_validation_samples = len(self.state.validation_prompts)
+
+        if num_validation_samples == 0:
+            logger.warning("No validation samples found. Skipping validation.")
+            return
+
+        self.components.transformer.eval()
+
+        memory_statistics = get_memory_statistics()
+        logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")
+
+        all_processes_artifacts = []
+        for i in range(num_validation_samples):
+            # Skip current validation on all processes but one
+            if i % accelerator.num_processes != accelerator.process_index:
+                continue
+
+            prompt = self.state.validation_prompts[i]
+            image = self.state.validation_images[i]
+            video = self.state.validation_videos[i]
+
+            if image is not None:
+                image = preprocess_image_with_resize(
+                    image, self.state.gen_height, self.state.gen_width
+                )
+                # Convert image tensor (C, H, W) to PIL images
+                image = image.to(torch.uint8)
+                image = image.permute(1, 2, 0).cpu().numpy()
+                image = Image.fromarray(image)
+
+            if video is not None:
+                video = preprocess_video_with_resize(
+                    video, self.state.gen_frames, self.state.gen_height, self.state.gen_width
+                )
+                # Convert video tensor (F, C, H, W) to list of PIL images
+                video = (video * 255).round().clamp(0, 255).to(torch.uint8)
+                video = [Image.fromarray(frame.permute(1,2,0).cpu().numpy()) for frame in video]
+
+            logger.debug(
+                f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
+                main_process_only=False,
+            )
+            validation_artifacts = self.validation_step({
+                "prompt": prompt,
+                "image": image,
+                "video": video
+            })
+            prompt_filename = string_to_filename(prompt)[:25]
+            artifacts = {
+                "image": {"type": "image", "value": image},
+                "video": {"type": "video", "value": video},
+            }
+            for i, (artifact_type, artifact_value) in enumerate(validation_artifacts):
+                artifacts.update({f"artifact_{i}": {"type": artifact_type, "value": artifact_value}})
+            logger.debug(
+                f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
+                main_process_only=False,
+            )
+
+            for key, value in list(artifacts.items()):
+                artifact_type = value["type"]
+                artifact_value = value["value"]
+                if artifact_type not in ["image", "video"] or artifact_value is None:
+                    continue
+
+                extension = "png" if artifact_type == "image" else "mp4"
+                filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}.{extension}"
+                validation_path = self.args.output_dir / "validation_res"
+                validation_path.mkdir(parents=True, exist_ok=True)
+                filename = str(validation_path / filename)
+
+                if artifact_type == "image":
+                    logger.debug(f"Saving image to {filename}")
+                    artifact_value.save(filename)
+                    artifact_value = wandb.Image(filename)
+                elif artifact_type == "video":
+                    logger.debug(f"Saving video to {filename}")
+                    export_to_video(artifact_value, filename, fps=self.args.gen_fps)
+                    artifact_value = wandb.Video(filename, caption=prompt)
+
+                all_processes_artifacts.append(artifact_value)
+
+        all_artifacts = gather_object(all_processes_artifacts)
+
+        if accelerator.is_main_process:
+            tracker_key = "validation"
+            for tracker in accelerator.trackers:
+                if tracker.name == "wandb":
+                    image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
+                    video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
+                    tracker.log(
+                        {
+                            tracker_key: {"images": image_artifacts, "videos": video_artifacts},
+                        },
+                        step=step,
+                    )
+
+        accelerator.wait_for_everyone()
+
+        free_memory()
+        memory_statistics = get_memory_statistics()
+        logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
+        torch.cuda.reset_peak_memory_stats(accelerator.device)
+
+        self.components.transformer.train()
+
     def fit(self):
         self.prepare_dataset()
         self.prepare_models()
         self.prepare_trainable_parameters()
         self.prepare_optimizer()
         self.prepare_for_training()
+        if self.args.do_validation:
+            self.prepare_for_validation()
         self.prepare_trackers()
         self.train()
 
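
The multi-process pattern inside validate(), reduced to a minimal standalone sketch: samples are assigned round-robin by index, each process keeps only its share, and gather_object() reassembles every process's local artifacts on all ranks:

from accelerate import Accelerator
from accelerate.utils import gather_object

accelerator = Accelerator()
samples = [f"prompt-{i}" for i in range(8)]

# Each rank handles indices i where i % num_processes == process_index.
local_artifacts = [
    s.upper()  # stand-in for generating one validation artifact
    for i, s in enumerate(samples)
    if i % accelerator.num_processes == accelerator.process_index
]
all_artifacts = gather_object(local_artifacts)  # combined list on every rank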
@@ -412,7 +554,7 @@ class Trainer:
     def compute_loss(self, batch) -> torch.Tensor:
         raise NotImplementedError
 
-    def validate(self) -> None:
+    def validation_step(self) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
         raise NotImplementedError
 
     def __get_training_dtype(self) -> torch.dtype:
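
Subclasses now override validation_step() instead of validate(). A hedged sketch of what an override might look like under the new return contract, where each tuple is (artifact_type, artifact_value); the eval_data parameter mirrors how validate() invokes it, while the subclass name, import path, and generate_frames() helper are all hypothetical:

from typing import List, Tuple
from PIL import Image

from finetune.trainer import Trainer  # import path assumed


def generate_frames(prompt: str, image) -> List[Image.Image]:
    # Placeholder: a real trainer would run the diffusion sampling pipeline here.
    return [Image.new("RGB", (720, 480)) for _ in range(49)]


class MyVideoTrainer(Trainer):  # hypothetical subclass
    def validation_step(self, eval_data: dict) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
        frames = generate_frames(eval_data["prompt"], eval_data["image"])
        return [("video", frames)]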