mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-11-15 22:42:10 +08:00
Update train_cogvideox_image_to_video_lora.py
This commit is contained in:
parent
4339f65660
commit
f28708d845
@ -58,6 +58,7 @@ from torchvision.transforms.functional import center_crop, resize
|
|||||||
from torchvision.transforms import InterpolationMode
|
from torchvision.transforms import InterpolationMode
|
||||||
import torchvision.transforms as TT
|
import torchvision.transforms as TT
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from diffusers.image_processor import VaeImageProcessor
|
||||||
|
|
||||||
|
|
||||||
if is_wandb_available():
|
if is_wandb_available():
|
||||||
@ -773,8 +774,13 @@ def log_validation(
|
|||||||
|
|
||||||
videos = []
|
videos = []
|
||||||
for _ in range(args.num_validation_videos):
|
for _ in range(args.num_validation_videos):
|
||||||
video = pipe(**pipeline_args, generator=generator, output_type="pil").frames[0]
|
pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0]
|
||||||
videos.append(video)
|
pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])])
|
||||||
|
|
||||||
|
image_np = VaeImageProcessor.pt_to_numpy(pt_images)
|
||||||
|
image_pil = VaeImageProcessor.numpy_to_pil(image_np)
|
||||||
|
|
||||||
|
videos.append(image_pil)
|
||||||
|
|
||||||
for tracker in accelerator.trackers:
|
for tracker in accelerator.trackers:
|
||||||
phase_name = "test" if is_final_validation else "validation"
|
phase_name = "test" if is_final_validation else "validation"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user