diff --git a/finetune/train_cogvideox_image_to_video_lora.py b/finetune/train_cogvideox_image_to_video_lora.py index ba9b0b7..3a62eff 100644 --- a/finetune/train_cogvideox_image_to_video_lora.py +++ b/finetune/train_cogvideox_image_to_video_lora.py @@ -58,6 +58,7 @@ from torchvision.transforms.functional import center_crop, resize from torchvision.transforms import InterpolationMode import torchvision.transforms as TT import numpy as np +from diffusers.image_processor import VaeImageProcessor if is_wandb_available(): @@ -773,8 +774,13 @@ def log_validation( videos = [] for _ in range(args.num_validation_videos): - video = pipe(**pipeline_args, generator=generator, output_type="pil").frames[0] - videos.append(video) + pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0] + pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])]) + + image_np = VaeImageProcessor.pt_to_numpy(pt_images) + image_pil = VaeImageProcessor.numpy_to_pil(image_np) + + videos.append(image_pil) for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation"