mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-06-10 17:09:16 +08:00
Update train_cogvideox_image_to_video_lora.py
This commit is contained in:
parent
4339f65660
commit
f28708d845
@ -58,6 +58,7 @@ from torchvision.transforms.functional import center_crop, resize
|
||||
from torchvision.transforms import InterpolationMode
|
||||
import torchvision.transforms as TT
|
||||
import numpy as np
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
@ -773,8 +774,13 @@ def log_validation(
|
||||
|
||||
videos = []
|
||||
for _ in range(args.num_validation_videos):
|
||||
video = pipe(**pipeline_args, generator=generator, output_type="pil").frames[0]
|
||||
videos.append(video)
|
||||
pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0]
|
||||
pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])])
|
||||
|
||||
image_np = VaeImageProcessor.pt_to_numpy(pt_images)
|
||||
image_pil = VaeImageProcessor.numpy_to_pil(image_np)
|
||||
|
||||
videos.append(image_pil)
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
phase_name = "test" if is_final_validation else "validation"
|
||||
|
Loading…
x
Reference in New Issue
Block a user