diff --git a/finetune/trainer.py b/finetune/trainer.py
index cc9882c..501a7e5 100644
--- a/finetune/trainer.py
+++ b/finetune/trainer.py
@@ -1,15 +1,18 @@
-import torch
+import os
 import logging
-import transformers
-import diffusers
 import math
 import json
-import multiprocessing
+
+import torch
+import transformers
+import diffusers
+import wandb
 
 from datetime import timedelta
 from pathlib import Path
 from tqdm import tqdm
-from typing import Dict, Any, List
+from typing import Dict, Any, List, Tuple
+from PIL import Image
 
 from torch.utils.data import Dataset, DataLoader
 from accelerate.logging import get_logger
@@ -23,6 +26,7 @@ from accelerate.utils import (
 )
 
 from diffusers.optimization import get_scheduler
+from diffusers.utils.export_utils import export_to_video
 from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 
 from finetune.schemas import Args, State, Components
@@ -37,8 +41,14 @@ from finetune.utils import (
     get_intermediate_ckpt_path,
     get_latest_ckpt_path_to_resume_from,
     get_intermediate_ckpt_path,
+
+    string_to_filename
 )
 
 from finetune.datasets import I2VDatasetWithBuckets, T2VDatasetWithBuckets, BucketSampler
+from finetune.datasets.utils import (
+    load_prompts, load_images, load_videos,
+    preprocess_image_with_resize, preprocess_video_with_resize
+)
 
 from finetune.constants import LOG_NAME, LOG_LEVEL
@@ -263,6 +273,24 @@ class Trainer:
         # Afterwards we recalculate our number of training epochs
         self.args.train_epochs = math.ceil(self.args.train_steps / num_update_steps_per_epoch)
         self.state.num_update_steps_per_epoch = num_update_steps_per_epoch
+
+    def prepare_for_validation(self):
+        self.state.gen_frames, self.state.gen_height, self.state.gen_width = [int(elem) for elem in self.args.gen_video_resolution.split('x')]
+        validation_prompts = load_prompts(self.args.validation_dir / self.args.validation_prompts)
+
+        if self.args.validation_images is not None:
+            validation_images = load_images(self.args.validation_dir / self.args.validation_images)
+        else:
+            validation_images = [None] * len(validation_prompts)
+
+        if self.args.validation_videos is not None:
+            validation_videos = load_videos(self.args.validation_dir / self.args.validation_videos)
+        else:
+            validation_videos = [None] * len(validation_prompts)
+
+        self.state.validation_prompts = validation_prompts
+        self.state.validation_images = validation_images
+        self.state.validation_videos = validation_videos
 
     def prepare_trackers(self) -> None:
         logger.info("Initializing trackers")
@@ -319,6 +347,7 @@ class Trainer:
         generator = torch.Generator(device=accelerator.device)
         if self.args.seed is not None:
             generator = generator.manual_seed(self.args.seed)
+        self.state.generator = generator
 
         for epoch in range(first_epoch, self.args.train_epochs):
             logger.debug(f"Starting epoch ({epoch + 1}/{self.args.train_epochs})")
@@ -362,7 +391,7 @@ class Trainer:
 
                 # Maybe run validation
                 should_run_validation = (
-                    self.args.validation_steps is not None
+                    self.args.do_validation
                     and global_step % self.args.validation_steps == 0
                 )
                 if should_run_validation:
@@ -381,7 +410,8 @@ class Trainer:
 
         accelerator.wait_for_everyone()
         self.__maybe_save_checkpoint(global_step, must_save=True)
-        self.validate(global_step)
+        if self.args.do_validation:
+            self.validate(global_step)
 
         del self.components
         free_memory()
@@ -390,12 +420,124 @@ class Trainer:
 
         accelerator.end_training()
 
+    def validate(self, step: int) -> None:
+        logger.info("Starting validation")
+
+        accelerator = self.accelerator
+        num_validation_samples = len(self.state.validation_prompts)
+
+        if num_validation_samples == 0:
+            logger.warning("No validation samples found. Skipping validation.")
+            return
+
+        self.components.transformer.eval()
+
+        memory_statistics = get_memory_statistics()
+        logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")
+
+        all_processes_artifacts = []
+        for i in range(num_validation_samples):
+            # Skip current validation on all processes but one
+            if i % accelerator.num_processes != accelerator.process_index:
+                continue
+
+            prompt = self.state.validation_prompts[i]
+            image = self.state.validation_images[i]
+            video = self.state.validation_videos[i]
+
+            if image is not None:
+                image = preprocess_image_with_resize(
+                    image, self.state.gen_height, self.state.gen_width
+                )
+                # Convert image tensor (C, H, W) to a PIL image
+                image = image.to(torch.uint8)
+                image = image.permute(1, 2, 0).cpu().numpy()
+                image = Image.fromarray(image)
+
+            if video is not None:
+                video = preprocess_video_with_resize(
+                    video, self.state.gen_frames, self.state.gen_height, self.state.gen_width
+                )
+                # Convert video tensor (F, C, H, W) to a list of PIL images
+                video = (video * 255).round().clamp(0, 255).to(torch.uint8)
+                video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]
+
+            logger.debug(
+                f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
+                main_process_only=False,
+            )
+            validation_artifacts = self.validation_step({
+                "prompt": prompt,
+                "image": image,
+                "video": video
+            })
+            prompt_filename = string_to_filename(prompt)[:25]
+            artifacts = {
+                "image": {"type": "image", "value": image},
+                "video": {"type": "video", "value": video},
+            }
+            for idx, (artifact_type, artifact_value) in enumerate(validation_artifacts):
+                artifacts.update({f"artifact_{idx}": {"type": artifact_type, "value": artifact_value}})
+            logger.debug(
+                f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
+                main_process_only=False,
+            )
+
+            for key, value in list(artifacts.items()):
+                artifact_type = value["type"]
+                artifact_value = value["value"]
+                if artifact_type not in ["image", "video"] or artifact_value is None:
+                    continue
+
+                extension = "png" if artifact_type == "image" else "mp4"
+                filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}.{extension}"
+                validation_path = self.args.output_dir / "validation_res"
+                validation_path.mkdir(parents=True, exist_ok=True)
+                filename = str(validation_path / filename)
+
+                if artifact_type == "image":
+                    logger.debug(f"Saving image to {filename}")
+                    artifact_value.save(filename)
+                    artifact_value = wandb.Image(filename)
+                elif artifact_type == "video":
+                    logger.debug(f"Saving video to {filename}")
+                    export_to_video(artifact_value, filename, fps=self.args.gen_fps)
+                    artifact_value = wandb.Video(filename, caption=prompt)
+
+                all_processes_artifacts.append(artifact_value)
+
+        all_artifacts = gather_object(all_processes_artifacts)
+
+        if accelerator.is_main_process:
+            tracker_key = "validation"
+            for tracker in accelerator.trackers:
+                if tracker.name == "wandb":
+                    image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
+                    video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
+                    tracker.log(
+                        {
+                            tracker_key: {"images": image_artifacts, "videos": video_artifacts},
+                        },
+                        step=step,
+                    )
+
+        accelerator.wait_for_everyone()
+
+        free_memory()
+        memory_statistics = get_memory_statistics()
logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}") + torch.cuda.reset_peak_memory_stats(accelerator.device) + + self.components.transformer.train() + def fit(self): self.prepare_dataset() self.prepare_models() self.prepare_trainable_parameters() self.prepare_optimizer() self.prepare_for_training() + if self.args.do_validation: + self.prepare_for_validation() self.prepare_trackers() self.train() @@ -412,7 +554,7 @@ class Trainer: def compute_loss(self, batch) -> torch.Tensor: raise NotImplementedError - def validate(self) -> None: + def validation_step(self) -> List[Tuple[str, Image.Image | List[Image.Image]]]: raise NotImplementedError def __get_training_dtype(self) -> torch.dtype: