mirror of https://github.com/THUDM/CogVideo.git
synced 2025-06-26 20:29:24 +08:00
update
This commit is contained in:
parent e26c3c426f
commit 4339f65660
@@ -1,5 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
-# All rights reserved.
+# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -45,10 +44,7 @@ from diffusers import (
 from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
-from diffusers.training_utils import (
-    cast_training_params,
-    clear_objs_and_retain_memory,
-)
+from diffusers.training_utils import cast_training_params, free_memory
 from diffusers.utils import (
     check_min_version,
     convert_unet_state_dict_to_peft,
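Throughout the rest of the diff, the removed `clear_objs_and_retain_memory` helper is replaced by an explicit `del` of the object followed by `free_memory` from `diffusers.training_utils`. For an older diffusers install that predates `free_memory`, a rough stand-in might look like this (an assumption about the helper's behavior, not part of this commit):

```python
import gc

import torch


def free_memory():
    # Rough stand-in for diffusers.training_utils.free_memory; the real helper
    # may also handle non-CUDA backends (MPS, XPU, ...).
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
```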
@@ -58,6 +54,10 @@ from diffusers.utils import (
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.torch_utils import is_compiled_module
+from torchvision.transforms.functional import center_crop, resize
+from torchvision.transforms import InterpolationMode
+import torchvision.transforms as TT
+import numpy as np


 if is_wandb_available():
@@ -236,6 +236,12 @@ def get_args():
         default=720,
         help="All input videos are resized to this width.",
     )
+    parser.add_argument(
+        "--video_reshape_mode",
+        type=str,
+        default="center",
+        help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
+    )
     parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
     parser.add_argument(
         "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
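The new flag documents its valid values only in the help string; passing `choices=` would make argparse reject anything outside the three modes. A standalone sketch of that tightening (hypothetical, not what the commit does):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--video_reshape_mode",
    type=str,
    default="center",
    choices=["center", "random", "none"],  # enforce what the help text only documents
    help="All input videos are reshaped to this mode.",
)
args = parser.parse_args(["--video_reshape_mode", "random"])
print(args.video_reshape_mode)  # random
```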
@@ -442,6 +448,7 @@ class VideoDataset(Dataset):
         video_column: str = "video",
         height: int = 480,
         width: int = 720,
+        video_reshape_mode: str = "center",
         fps: int = 8,
         max_num_frames: int = 49,
         skip_frames_start: int = 0,
@@ -450,6 +457,7 @@ class VideoDataset(Dataset):
         id_token: Optional[str] = None,
     ) -> None:
         super().__init__()

         self.instance_data_root = Path(instance_data_root) if instance_data_root is not None else None
         self.dataset_name = dataset_name
         self.dataset_config_name = dataset_config_name
@@ -457,12 +465,14 @@
         self.video_column = video_column
         self.height = height
         self.width = width
+        self.video_reshape_mode = video_reshape_mode
         self.fps = fps
         self.max_num_frames = max_num_frames
         self.skip_frames_start = skip_frames_start
         self.skip_frames_end = skip_frames_end
         self.cache_dir = cache_dir
         self.id_token = id_token or ""

         if dataset_name is not None:
             self.instance_prompts, self.instance_video_paths = self._load_dataset_from_hub()
         else:
@@ -561,6 +571,38 @@ class VideoDataset(Dataset):

         return instance_prompts, instance_videos

+    def _resize_for_rectangle_crop(self, arr):
+        image_size = self.height, self.width
+        reshape_mode = self.video_reshape_mode
+        if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
+            arr = resize(
+                arr,
+                size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+        else:
+            arr = resize(
+                arr,
+                size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+
+        h, w = arr.shape[2], arr.shape[3]
+        arr = arr.squeeze(0)
+
+        delta_h = h - image_size[0]
+        delta_w = w - image_size[1]
+
+        if reshape_mode == "random" or reshape_mode == "none":
+            top = np.random.randint(0, delta_h + 1)
+            left = np.random.randint(0, delta_w + 1)
+        elif reshape_mode == "center":
+            top, left = delta_h // 2, delta_w // 2
+        else:
+            raise NotImplementedError
+        arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
+        return arr
+
     def _preprocess_data(self):
         try:
             import decord
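The new helper resizes the clip so the target rectangle is fully covered while preserving aspect ratio, then crops the overhang according to `video_reshape_mode`; note that "none" falls into the same branch as "random" in the code above. A self-contained sketch of the center path on a dummy clip (the 480x720 target and clip size are illustrative):

```python
import torch
import torchvision.transforms as TT
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize

frames = torch.rand(9, 3, 480, 960)  # [F, C, H, W], wider than the target
target_h, target_w = 480, 720

# Scale the clip so the target rectangle fits inside it, keeping aspect ratio.
if frames.shape[3] / frames.shape[2] > target_w / target_h:
    size = [target_h, int(frames.shape[3] * target_h / frames.shape[2])]
else:
    size = [int(frames.shape[2] * target_w / frames.shape[3]), target_w]
frames = resize(frames, size=size, interpolation=InterpolationMode.BICUBIC)

# Center-crop the overhang (reshape_mode="center").
top = (frames.shape[2] - target_h) // 2
left = (frames.shape[3] - target_w) // 2
frames = TT.functional.crop(frames, top=top, left=left, height=target_h, width=target_w)
print(frames.shape)  # torch.Size([9, 3, 480, 720])
```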
@@ -571,14 +613,15 @@ class VideoDataset(Dataset):

         decord.bridge.set_bridge("torch")

-        videos = []
-        train_transforms = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
-            ]
-        )
+        progress_dataset_bar = tqdm(
+            range(0, len(self.instance_video_paths)),
+            desc="Loading progress resize and crop videos",
+        )
+
+        videos = []

         for filename in self.instance_video_paths:
+            progress_dataset_bar.update(1)
             video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
             video_num_frames = len(video_reader)
@@ -605,9 +648,12 @@ class VideoDataset(Dataset):
             assert (selected_num_frames - 1) % 4 == 0

             # Training transforms
-            frames = frames.float()
-            frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
-            videos.append(frames.permute(0, 3, 1, 2).contiguous())  # [F, C, H, W]
+            frames = (frames - 127.5) / 127.5
+            frames = frames.permute(0, 3, 1, 2)  # [F, C, H, W]
+            frames = self._resize_for_rectangle_crop(frames)
+            videos.append(frames.contiguous())  # [F, C, H, W]
+
+        progress_dataset_bar.close()

         return videos

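The inline normalization is algebraically identical to the removed Lambda transform: (x - 127.5) / 127.5 = x / 255 * 2 - 1, mapping uint8 pixel values from [0, 255] to [-1, 1]. A quick check:

```python
import torch

x = torch.tensor([0.0, 63.75, 127.5, 255.0])
print((x - 127.5) / 127.5)    # tensor([-1.0000, -0.5000,  0.0000,  1.0000])
print(x / 255.0 * 2.0 - 1.0)  # same values
```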
@@ -727,7 +773,7 @@ def log_validation(

    videos = []
    for _ in range(args.num_validation_videos):
-        video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
+        video = pipe(**pipeline_args, generator=generator, output_type="pil").frames[0]
        videos.append(video)

    for tracker in accelerator.trackers:
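With `output_type="pil"`, `.frames[0]` is a list of PIL.Image frames rather than a numpy array. If any downstream logging still expects arrays, a conversion along these lines (hypothetical, not part of the commit) recovers them:

```python
import numpy as np

# `video` is the list of PIL.Image frames returned above.
video_np = np.stack([np.asarray(frame) for frame in video])  # [F, H, W, C], uint8
```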
@@ -756,7 +802,8 @@ def log_validation(
                }
            )

-    clear_objs_and_retain_memory([pipe])
+    del pipe
+    free_memory()

    return videos

@@ -1204,6 +1251,7 @@ def main(args):
        video_column=args.video_column,
        height=args.height,
        width=args.width,
+        video_reshape_mode=args.video_reshape_mode,
        fps=args.fps,
        max_num_frames=args.max_num_frames,
        skip_frames_start=args.skip_frames_start,
@@ -1212,7 +1260,8 @@ def main(args):
        id_token=args.id_token,
    )

-    def encode_video(video):
+    def encode_video(video, bar):
+        bar.update(1)
        video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
        video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
        image = video[:, :, :1].clone()
@@ -1238,7 +1287,13 @@ def main(args):
            )
            for prompt in train_dataset.instance_prompts
        ]
-    train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
+
+    progress_encode_bar = tqdm(
+        range(0, len(train_dataset.instance_videos)),
+        desc="Loading Encode videos",
+    )
+    train_dataset.instance_videos = [encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos]
+    progress_encode_bar.close()

    def collate_fn(examples):
        videos = []
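Both new progress bars use the manual tqdm pattern: construct the bar over a range so the total is known, then call update(1) per item (here via the extra `bar` argument threaded through `encode_video`). A hypothetical alternative that keeps `encode_video`'s original one-argument signature is to wrap the iterable instead:

```python
from tqdm import tqdm

# Assumes encode_video kept its original single-argument signature.
train_dataset.instance_videos = [
    encode_video(video) for video in tqdm(train_dataset.instance_videos, desc="Encoding videos")
]
```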
@@ -1378,9 +1433,6 @@ def main(args):
    )
    vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)

-    # Delete VAE and Text Encoder to save memory
-    clear_objs_and_retain_memory([vae, text_encoder])
-
    # For DeepSpeed training
    model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config

@@ -1550,7 +1602,8 @@ def main(args):
        )

        # Cleanup trained models to save memory
-        clear_objs_and_retain_memory([transformer])
+        del transformer
+        free_memory()

        # Final test inference
        pipe = CogVideoXImageToVideoPipeline.from_pretrained(
The same license-header change is applied to the second file touched by this commit:

@@ -1,5 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
-# All rights reserved.
+# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.