Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-04-05 19:41:59 +08:00)
feat(trainer): add validation functionality to Trainer class

Add validation capabilities to the Trainer class, including:

- Support for validating images and videos during training
- Periodic validation based on validation_steps parameter
- Artifact logging to wandb for validation results
- Memory tracking during validation process
parent 6971364591
commit fa4659fb2c
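For context, the validation flow added here is driven by a handful of Args fields that appear throughout the diff below (do_validation, validation_steps, validation_dir, validation_prompts, validation_images, validation_videos, gen_video_resolution, gen_fps). A minimal configuration sketch; the concrete values and the constructor-style usage are illustrative assumptions, not taken from the commit:

    from pathlib import Path

    from finetune.schemas import Args  # schema imported by the trainer

    # Hypothetical values; other required training fields are omitted for brevity.
    args = Args(
        do_validation=True,                 # gates prepare_for_validation() and validate()
        validation_steps=500,               # run validation every 500 training steps
        validation_dir=Path("data/val"),
        validation_prompts="prompts.txt",   # read via load_prompts()
        validation_images="images.txt",     # optional; missing -> [None] * len(prompts)
        validation_videos=None,             # optional; missing -> [None] * len(prompts)
        gen_video_resolution="49x480x720",  # parsed as frames x height x width
        gen_fps=15,
    )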
finetune/trainer.py

@@ -1,15 +1,18 @@
-import torch
 import os
 import logging
-import transformers
-import diffusers
 import math
 import json
+import multiprocessing
+
+import torch
+import transformers
+import diffusers
 import wandb
-
 from datetime import timedelta
 from pathlib import Path
 from tqdm import tqdm
-from typing import Dict, Any, List
+from typing import Dict, Any, List, Tuple
+from PIL import Image
+
 from torch.utils.data import Dataset, DataLoader
 from accelerate.logging import get_logger
@@ -23,6 +26,7 @@ from accelerate.utils import (
 )
 
 from diffusers.optimization import get_scheduler
+from diffusers.utils.export_utils import export_to_video
 from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 
 from finetune.schemas import Args, State, Components
@@ -37,8 +41,14 @@ from finetune.utils import (
     get_intermediate_ckpt_path,
     get_latest_ckpt_path_to_resume_from,
     get_intermediate_ckpt_path,
+
+    string_to_filename
 )
 from finetune.datasets import I2VDatasetWithBuckets, T2VDatasetWithBuckets, BucketSampler
+from finetune.datasets.utils import (
+    load_prompts, load_images, load_videos,
+    preprocess_image_with_resize, preprocess_video_with_resize
+)
 
 from finetune.constants import LOG_NAME, LOG_LEVEL
 
@@ -263,6 +273,24 @@ class Trainer:
         # Afterwards we recalculate our number of training epochs
         self.args.train_epochs = math.ceil(self.args.train_steps / num_update_steps_per_epoch)
         self.state.num_update_steps_per_epoch = num_update_steps_per_epoch
 
+    def prepare_for_validation(self):
+        self.state.gen_frames, self.state.gen_height, self.state.gen_width = [int(elem) for elem in self.args.gen_video_resolution.split('x')]
+        validation_prompts = load_prompts(self.args.validation_dir / self.args.validation_prompts)
+
+        if self.args.validation_images is not None:
+            validation_images = load_images(self.args.validation_dir / self.args.validation_images)
+        else:
+            validation_images = [None] * len(validation_prompts)
+
+        if self.args.validation_videos is not None:
+            validation_videos = load_videos(self.args.validation_dir / self.args.validation_videos)
+        else:
+            validation_videos = [None] * len(validation_prompts)
+
+        self.state.validation_prompts = validation_prompts
+        self.state.validation_images = validation_images
+        self.state.validation_videos = validation_videos
+
     def prepare_trackers(self) -> None:
         logger.info("Initializing trackers")
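The method above keeps the three lists index-aligned, so sample i is always the triple (prompt, image-or-None, video-or-None). A small illustration with made-up values:

    # Illustrative only: mirrors how prepare_for_validation pads missing modalities.
    prompts = ["a cat surfing", "a city street at night"]  # from load_prompts()
    images = [None] * len(prompts)                         # no validation_images given
    videos = ["ref_cat.mp4", None]                         # hypothetical paths

    for prompt, image, video in zip(prompts, images, videos):
        print({"prompt": prompt, "image": image, "video": video})  # validation_step input shape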
@@ -319,6 +347,7 @@ class Trainer:
         generator = torch.Generator(device=accelerator.device)
         if self.args.seed is not None:
             generator = generator.manual_seed(self.args.seed)
+        self.state.generator = generator
 
         for epoch in range(first_epoch, self.args.train_epochs):
             logger.debug(f"Starting epoch ({epoch + 1}/{self.args.train_epochs})")
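Storing the (optionally seeded) generator on state means later sampling, including validation, can draw from one reproducible stream. A standalone illustration of the underlying torch behavior:

    import torch

    # With a fixed seed, the generator yields identical values on every run.
    g = torch.Generator(device="cpu").manual_seed(42)
    print(torch.randn(2, generator=g))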
@@ -362,7 +391,7 @@ class Trainer:
 
                 # Maybe run validation
                 should_run_validation = (
-                    self.args.validation_steps is not None
+                    self.args.do_validation
                     and global_step % self.args.validation_steps == 0
                 )
                 if should_run_validation:
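With this change, periodic validation requires do_validation to be set and fires whenever global_step is a multiple of validation_steps. A toy trace of the gate (values illustrative):

    do_validation, validation_steps = True, 500

    for global_step in (499, 500, 501, 1000):
        should_run_validation = do_validation and global_step % validation_steps == 0
        print(global_step, should_run_validation)  # True only at 500 and 1000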
@@ -381,7 +410,8 @@ class Trainer:
 
         accelerator.wait_for_everyone()
         self.__maybe_save_checkpoint(global_step, must_save=True)
-        self.validate(global_step)
+        if self.args.do_validation:
+            self.validate(global_step)
 
         del self.components
         free_memory()
@@ -390,12 +420,124 @@ class Trainer:
 
         accelerator.end_training()
 
+    def validate(self, step: int) -> None:
+        logger.info("Starting validation")
+
+        accelerator = self.accelerator
+        num_validation_samples = len(self.state.validation_prompts)
+
+        if num_validation_samples == 0:
+            logger.warning("No validation samples found. Skipping validation.")
+            return
+
+        self.components.transformer.eval()
+
+        memory_statistics = get_memory_statistics()
+        logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")
+
+        all_processes_artifacts = []
+        for i in range(num_validation_samples):
+            # Skip current validation on all processes but one
+            if i % accelerator.num_processes != accelerator.process_index:
+                continue
+
+            prompt = self.state.validation_prompts[i]
+            image = self.state.validation_images[i]
+            video = self.state.validation_videos[i]
+
+            if image is not None:
+                image = preprocess_image_with_resize(
+                    image, self.state.gen_height, self.state.gen_width
+                )
+                # Convert image tensor (C, H, W) to PIL image
+                image = image.to(torch.uint8)
+                image = image.permute(1, 2, 0).cpu().numpy()
+                image = Image.fromarray(image)
+
+            if video is not None:
+                video = preprocess_video_with_resize(
+                    video, self.state.gen_frames, self.state.gen_height, self.state.gen_width
+                )
+                # Convert video tensor (F, C, H, W) to a list of PIL images
+                video = (video * 255).round().clamp(0, 255).to(torch.uint8)
+                video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]
+
+            logger.debug(
+                f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
+                main_process_only=False,
+            )
+            validation_artifacts = self.validation_step({
+                "prompt": prompt,
+                "image": image,
+                "video": video
+            })
+            prompt_filename = string_to_filename(prompt)[:25]
+            artifacts = {
+                "image": {"type": "image", "value": image},
+                "video": {"type": "video", "value": video},
+            }
+            for i, (artifact_type, artifact_value) in enumerate(validation_artifacts):
+                artifacts.update({f"artifact_{i}": {"type": artifact_type, "value": artifact_value}})
+            logger.debug(
+                f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
+                main_process_only=False,
+            )
+
+            for key, value in list(artifacts.items()):
+                artifact_type = value["type"]
+                artifact_value = value["value"]
+                if artifact_type not in ["image", "video"] or artifact_value is None:
+                    continue
+
+                extension = "png" if artifact_type == "image" else "mp4"
+                filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}.{extension}"
+                validation_path = self.args.output_dir / "validation_res"
+                validation_path.mkdir(parents=True, exist_ok=True)
+                filename = str(validation_path / filename)
+
+                if artifact_type == "image":
+                    logger.debug(f"Saving image to {filename}")
+                    artifact_value.save(filename)
+                    artifact_value = wandb.Image(filename)
+                elif artifact_type == "video":
+                    logger.debug(f"Saving video to {filename}")
+                    export_to_video(artifact_value, filename, fps=self.args.gen_fps)
+                    artifact_value = wandb.Video(filename, caption=prompt)
+
+                all_processes_artifacts.append(artifact_value)
+
+        all_artifacts = gather_object(all_processes_artifacts)
+
+        if accelerator.is_main_process:
+            tracker_key = "validation"
+            for tracker in accelerator.trackers:
+                if tracker.name == "wandb":
+                    image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
+                    video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
+                    tracker.log(
+                        {
+                            tracker_key: {"images": image_artifacts, "videos": video_artifacts},
+                        },
+                        step=step,
+                    )
+
+        accelerator.wait_for_everyone()
+
+        free_memory()
+        memory_statistics = get_memory_statistics()
+        logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
+        torch.cuda.reset_peak_memory_stats(accelerator.device)
+
+        self.components.transformer.train()
+
     def fit(self):
         self.prepare_dataset()
         self.prepare_models()
         self.prepare_trainable_parameters()
         self.prepare_optimizer()
         self.prepare_for_training()
+        if self.args.do_validation:
+            self.prepare_for_validation()
         self.prepare_trackers()
         self.train()
 
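The new validate() shards samples round-robin across processes (the i % num_processes != process_index check), then collects each rank's wandb artifacts with gather_object() so that only the main process logs them. A standalone illustration of that sharding:

    # Illustrative only: which sample indices each rank would generate.
    num_processes, num_samples = 4, 10

    for rank in range(num_processes):
        mine = [i for i in range(num_samples) if i % num_processes == rank]
        print(f"rank {rank} -> samples {mine}")
    # rank 0 -> samples [0, 4, 8]; rank 1 -> [1, 5, 9]; ...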
@@ -412,7 +554,7 @@ class Trainer:
     def compute_loss(self, batch) -> torch.Tensor:
         raise NotImplementedError
 
-    def validate(self) -> None:
+    def validation_step(self) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
         raise NotImplementedError
 
     def __get_training_dtype(self) -> torch.dtype:
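Subclasses are now expected to override validation_step rather than validate. Note that validate() calls it with an eval-sample dict ({"prompt": ..., "image": ..., "video": ...}), so a concrete override takes that dict even though the base stub declares no argument. A hypothetical sketch (the subclass, import path, and stand-in frames are assumptions):

    from typing import Any, Dict, List, Tuple

    from PIL import Image

    from finetune.trainer import Trainer  # module path assumed

    class MyTrainer(Trainer):
        def validation_step(self, eval_data: Dict[str, Any]) -> List[Tuple[str, List[Image.Image]]]:
            # A real subclass would run its generation pipeline on eval_data["prompt"];
            # blank frames stand in here so the contract is visible.
            frames = [Image.new("RGB", (720, 480)) for _ in range(49)]
            return [("video", frames)]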