Update train_cogvideox_image_to_video_lora.py

This commit is contained in:
zR 2024-10-05 22:12:22 +08:00
parent 4339f65660
commit f28708d845

View File

@ -58,6 +58,7 @@ from torchvision.transforms.functional import center_crop, resize
from torchvision.transforms import InterpolationMode
import torchvision.transforms as TT
import numpy as np
from diffusers.image_processor import VaeImageProcessor
if is_wandb_available():
@ -773,8 +774,13 @@ def log_validation(
videos = []
for _ in range(args.num_validation_videos):
video = pipe(**pipeline_args, generator=generator, output_type="pil").frames[0]
videos.append(video)
pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0]
pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])])
image_np = VaeImageProcessor.pt_to_numpy(pt_images)
image_pil = VaeImageProcessor.numpy_to_pil(image_np)
videos.append(image_pil)
for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"