feat(trainer): add validation functionality to Trainer class

Add validation capabilities to the Trainer class including:
- Support for validating images and videos during training
- Periodic validation based on validation_steps parameter
- Artifact logging to wandb for validation results
- Memory tracking during validation process
This commit is contained in:
OleehyO 2024-12-30 06:51:03 +00:00
parent 6971364591
commit fa4659fb2c

View File

@ -1,15 +1,18 @@
import torch
import os
import logging
import transformers
import diffusers
import math
import json
import multiprocessing
import torch
import transformers
import diffusers
import wandb
from datetime import timedelta
from pathlib import Path
from tqdm import tqdm
from typing import Dict, Any, List
from typing import Dict, Any, List, Tuple
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from accelerate.logging import get_logger
@ -23,6 +26,7 @@ from accelerate.utils import (
)
from diffusers.optimization import get_scheduler
from diffusers.utils.export_utils import export_to_video
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from finetune.schemas import Args, State, Components
@ -37,8 +41,14 @@ from finetune.utils import (
get_intermediate_ckpt_path,
get_latest_ckpt_path_to_resume_from,
get_intermediate_ckpt_path,
string_to_filename
)
from finetune.datasets import I2VDatasetWithBuckets, T2VDatasetWithBuckets, BucketSampler
from finetune.datasets.utils import (
load_prompts, load_images, load_videos,
preprocess_image_with_resize, preprocess_video_with_resize
)
from finetune.constants import LOG_NAME, LOG_LEVEL
@ -263,6 +273,24 @@ class Trainer:
# Afterwards we recalculate our number of training epochs
self.args.train_epochs = math.ceil(self.args.train_steps / num_update_steps_per_epoch)
self.state.num_update_steps_per_epoch = num_update_steps_per_epoch
def prepare_for_validation(self):
    """Parse the generation resolution and load the validation inputs.

    Reads ``args.gen_video_resolution`` (three integers separated by ``x``)
    into ``state.gen_frames``, ``state.gen_height`` and ``state.gen_width``,
    then loads the validation prompts and, when configured, the validation
    images and videos from ``args.validation_dir``. Unconfigured image/video
    inputs are stored as a list of ``None`` with one entry per prompt.
    """
    frames, height, width = (
        int(part) for part in self.args.gen_video_resolution.split('x')
    )
    self.state.gen_frames = frames
    self.state.gen_height = height
    self.state.gen_width = width

    prompts = load_prompts(self.args.validation_dir / self.args.validation_prompts)

    # Each optional modality gets its own placeholder list (never shared).
    images = (
        [None] * len(prompts)
        if self.args.validation_images is None
        else load_images(self.args.validation_dir / self.args.validation_images)
    )
    videos = (
        [None] * len(prompts)
        if self.args.validation_videos is None
        else load_videos(self.args.validation_dir / self.args.validation_videos)
    )

    self.state.validation_prompts = prompts
    self.state.validation_images = images
    self.state.validation_videos = videos
def prepare_trackers(self) -> None:
logger.info("Initializing trackers")
@ -319,6 +347,7 @@ class Trainer:
generator = torch.Generator(device=accelerator.device)
if self.args.seed is not None:
generator = generator.manual_seed(self.args.seed)
self.state.generator = generator
for epoch in range(first_epoch, self.args.train_epochs):
logger.debug(f"Starting epoch ({epoch + 1}/{self.args.train_epochs})")
@ -362,7 +391,7 @@ class Trainer:
# Maybe run validation
should_run_validation = (
self.args.validation_steps is not None
self.args.do_validation
and global_step % self.args.validation_steps == 0
)
if should_run_validation:
@ -381,7 +410,8 @@ class Trainer:
accelerator.wait_for_everyone()
self.__maybe_save_checkpoint(global_step, must_save=True)
self.validate(global_step)
if self.args.do_validation:
self.validate(global_step)
del self.components
free_memory()
@ -390,12 +420,124 @@ class Trainer:
accelerator.end_training()
def validate(self, step: int) -> None:
    """Run validation over the stored samples and log the resulting media.

    Samples are split round-robin across processes. For every sample handled
    by this process, the optional conditioning image/video is preprocessed,
    passed to ``self.validation_step`` together with the prompt, and each
    image/video artifact (inputs and returned outputs) is written under
    ``<output_dir>/validation_res`` and wrapped as a wandb media object.
    Artifacts from all processes are gathered, and the main process logs them
    to the tracker named ``"wandb"``.

    Args:
        step: Global training step; used in artifact filenames and as the
            logging step passed to the tracker.
    """
    logger.info("Starting validation")

    accelerator = self.accelerator
    num_validation_samples = len(self.state.validation_prompts)

    if num_validation_samples == 0:
        logger.warning("No validation samples found. Skipping validation.")
        return

    self.components.transformer.eval()

    memory_statistics = get_memory_statistics()
    logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")

    # The output directory does not depend on the sample, so resolve and
    # create it once instead of once per saved artifact.
    validation_path = self.args.output_dir / "validation_res"
    validation_path.mkdir(parents=True, exist_ok=True)

    all_processes_artifacts = []
    for sample_idx in range(num_validation_samples):
        # Skip current validation on all processes but one
        if sample_idx % accelerator.num_processes != accelerator.process_index:
            continue

        prompt = self.state.validation_prompts[sample_idx]
        image = self.state.validation_images[sample_idx]
        video = self.state.validation_videos[sample_idx]

        if image is not None:
            image = preprocess_image_with_resize(
                image, self.state.gen_height, self.state.gen_width
            )
            # Convert image tensor (C, H, W) to PIL images
            # NOTE(review): unlike the video branch below, no x255 scaling is
            # applied here -- confirm the value ranges returned by the two
            # preprocess helpers are what each branch expects.
            image = image.to(torch.uint8)
            image = image.permute(1, 2, 0).cpu().numpy()
            image = Image.fromarray(image)

        if video is not None:
            video = preprocess_video_with_resize(
                video, self.state.gen_frames, self.state.gen_height, self.state.gen_width
            )
            # Convert video tensor (F, C, H, W) to list of PIL images
            video = (video * 255).round().clamp(0, 255).to(torch.uint8)
            video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]

        logger.debug(
            f"Validating sample {sample_idx + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
            main_process_only=False,
        )
        validation_artifacts = self.validation_step({
            "prompt": prompt,
            "image": image,
            "video": video
        })

        prompt_filename = string_to_filename(prompt)[:25]
        artifacts = {
            "image": {"type": "image", "value": image},
            "video": {"type": "video", "value": video},
        }
        # A dedicated index name keeps the outer sample index intact.
        for artifact_idx, (artifact_type, artifact_value) in enumerate(validation_artifacts):
            artifacts[f"artifact_{artifact_idx}"] = {"type": artifact_type, "value": artifact_value}
        logger.debug(
            f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
            main_process_only=False,
        )

        for key, value in artifacts.items():
            artifact_type = value["type"]
            artifact_value = value["value"]
            if artifact_type not in ("image", "video") or artifact_value is None:
                continue

            extension = "png" if artifact_type == "image" else "mp4"
            # Include the sample index and artifact key so that the
            # conditioning input and the generated output of the same type
            # (or two prompts sharing the same truncated name) do not
            # overwrite each other on disk.
            filename = (
                f"validation-{step}-{accelerator.process_index}"
                f"-{sample_idx}-{prompt_filename}-{key}.{extension}"
            )
            filename = str(validation_path / filename)

            if artifact_type == "image":
                logger.debug(f"Saving image to {filename}")
                artifact_value.save(filename)
                artifact_value = wandb.Image(filename)
            elif artifact_type == "video":
                logger.debug(f"Saving video to {filename}")
                export_to_video(artifact_value, filename, fps=self.args.gen_fps)
                artifact_value = wandb.Video(filename, caption=prompt)

            all_processes_artifacts.append(artifact_value)

    all_artifacts = gather_object(all_processes_artifacts)

    if accelerator.is_main_process:
        tracker_key = "validation"
        for tracker in accelerator.trackers:
            if tracker.name == "wandb":
                image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
                video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
                tracker.log(
                    {
                        tracker_key: {"images": image_artifacts, "videos": video_artifacts},
                    },
                    step=step,
                )

    accelerator.wait_for_everyone()

    free_memory()
    memory_statistics = get_memory_statistics()
    logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
    torch.cuda.reset_peak_memory_stats(accelerator.device)

    # Restore training mode for the next optimisation steps.
    self.components.transformer.train()
def fit(self):
    """Run the full training pipeline from setup through training.

    Executes the preparation stages in a fixed order, optionally prepares
    validation inputs when ``args.do_validation`` is set, then initialises
    the trackers and starts training.
    """
    setup_stages = (
        self.prepare_dataset,
        self.prepare_models,
        self.prepare_trainable_parameters,
        self.prepare_optimizer,
        self.prepare_for_training,
    )
    for run_stage in setup_stages:
        run_stage()

    # Validation inputs are only loaded when validation is enabled.
    if self.args.do_validation:
        self.prepare_for_validation()

    self.prepare_trackers()
    self.train()
@ -412,7 +554,7 @@ class Trainer:
def compute_loss(self, batch) -> torch.Tensor:
    """Return the training loss for ``batch``.

    This base implementation is a placeholder that always raises; a concrete
    trainer is expected to override it.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()
def validate(self) -> None:
def validation_step(
    self, eval_data: Dict[str, Any] | None = None
) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
    """Generate validation artifacts for a single sample.

    This base implementation is a placeholder that always raises; a concrete
    trainer is expected to override it. The optional ``eval_data`` parameter
    matches how ``validate`` invokes this hook (one positional dict), so
    calling the un-overridden stub raises ``NotImplementedError`` rather than
    a ``TypeError`` about unexpected arguments.

    Args:
        eval_data: Dict with the keys ``"prompt"``, ``"image"`` and
            ``"video"`` as built by ``validate``; image/video may be ``None``.

    Returns:
        A list of ``(artifact_type, artifact_value)`` pairs, where the value
        is a PIL image or a list of PIL images.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
def __get_training_dtype(self) -> torch.dtype: