# Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-04-05 03:04:56 +08:00)

import io
import math
import os
import random
import sys
import threading
from fractions import Fraction
from functools import partial
from typing import Any, Dict, Optional, Tuple, Union

import decord
import numpy as np
import torch
import torchvision.transforms as TT
from decord import VideoReader
from torch.utils.data import Dataset
from torchvision.io import _video_opt
from torchvision.io.video import av, _align_audio_frames, _check_av_available, _read_from_stream
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import center_crop, resize

from sgm.webds import MetaDistributedWebDataset


def read_video(
    filename: str,
    start_pts: Union[float, Fraction] = 0,
    end_pts: Optional[Union[float, Fraction]] = None,
    pts_unit: str = "pts",
    output_format: str = "THWC",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
    """
    Reads a video from a file, returning both the video frames and the audio frames.

    Args:
        filename (str): path to the video file
        start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The start presentation time of the video
        end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The end presentation time
        pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
            either 'pts' or 'sec'. Defaults to 'pts'.
        output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".

    Returns:
        vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
        aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
        info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
    """
    output_format = output_format.upper()
    if output_format not in ("THWC", "TCHW"):
        raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")

    _check_av_available()

    if end_pts is None:
        end_pts = float("inf")

    if end_pts < start_pts:
        raise ValueError(
            f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
        )

    info = {}
    video_frames = []
    audio_frames = []
    audio_timebase = _video_opt.default_timebase

    with av.open(filename, metadata_errors="ignore") as container:
        if container.streams.audio:
            audio_timebase = container.streams.audio[0].time_base
        if container.streams.video:
            video_frames = _read_from_stream(
                container,
                start_pts,
                end_pts,
                pts_unit,
                container.streams.video[0],
                {"video": 0},
            )
            video_fps = container.streams.video[0].average_rate
            # guard against potentially corrupted files
            if video_fps is not None:
                info["video_fps"] = float(video_fps)

        if container.streams.audio:
            audio_frames = _read_from_stream(
                container,
                start_pts,
                end_pts,
                pts_unit,
                container.streams.audio[0],
                {"audio": 0},
            )
            info["audio_fps"] = container.streams.audio[0].rate

    # decode the buffered packets into RGB frame arrays / audio sample arrays;
    # without this step vframes would always stay empty, contradicting the docstring
    vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
    aframes_list = [frame.to_ndarray() for frame in audio_frames]

    if vframes_list:
        vframes = torch.as_tensor(np.stack(vframes_list))
    else:
        vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

    if aframes_list:
        aframes = np.concatenate(aframes_list, 1)
        aframes = torch.as_tensor(aframes)
        if pts_unit == "sec":
            start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
            if end_pts != float("inf"):
                end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
        aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
    else:
        aframes = torch.empty((1, 0), dtype=torch.float32)

    if output_format == "TCHW":
        # [T, H, W, C] --> [T, C, H, W]
        vframes = vframes.permute(0, 3, 1, 2)

    return vframes, aframes, info
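
# Usage sketch (illustrative; the file name is hypothetical): decode the first
# two seconds of a clip as a [T, C, H, W] uint8 tensor.
#
#     vframes, aframes, info = read_video(
#         "sample.mp4", start_pts=0, end_pts=2.0, pts_unit="sec", output_format="TCHW"
#     )
#     print(vframes.shape, info.get("video_fps"))
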
def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
    # arr: [T, C, H, W]; image_size: (height, width).
    # Scale so the target rectangle is fully covered, then crop the excess.
    if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
        arr = resize(
            arr,
            size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
            interpolation=InterpolationMode.BICUBIC,
        )
    else:
        arr = resize(
            arr,
            size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
            interpolation=InterpolationMode.BICUBIC,
        )

    h, w = arr.shape[2], arr.shape[3]
    arr = arr.squeeze(0)

    delta_h = h - image_size[0]
    delta_w = w - image_size[1]

    if reshape_mode == "random" or reshape_mode == "none":
        top = np.random.randint(0, delta_h + 1)
        left = np.random.randint(0, delta_w + 1)
    elif reshape_mode == "center":
        top, left = delta_h // 2, delta_w // 2
    else:
        raise NotImplementedError
    arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
    return arr
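
# Shape sketch (illustrative): a 360x640 clip targeted at 480x720 is first
# scaled up until it covers the target rectangle, then center-cropped.
#
#     clip = torch.zeros(9, 3, 360, 640, dtype=torch.uint8)  # [T, C, H, W]
#     out = resize_for_rectangle_crop(clip, [480, 720], reshape_mode="center")
#     assert out.shape == (9, 3, 480, 720)
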
def pad_last_frame(tensor, num_frames):
    # T, H, W, C
    if len(tensor) < num_frames:
        pad_length = num_frames - len(tensor)
        # Use the last frame to pad instead of zero
        last_frame = tensor[-1]
        pad_tensor = last_frame.unsqueeze(0).expand(pad_length, *tensor.shape[1:])
        padded_tensor = torch.cat([tensor, pad_tensor], dim=0)
        return padded_tensor
    else:
        return tensor[:num_frames]
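
# Padding sketch (illustrative): a clip shorter than num_frames repeats its
# last frame; a longer clip is truncated instead of padded.
#
#     clip = torch.arange(3).view(3, 1, 1, 1).expand(3, 2, 2, 3)  # [T, H, W, C]
#     padded = pad_last_frame(clip, 5)
#     assert padded.shape[0] == 5 and torch.equal(padded[4], clip[2])
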
def load_video(
    video_data,
    sampling="uniform",
    duration=None,
    num_frames=4,
    wanted_fps=None,
    actual_fps=None,
    skip_frms_num=0.0,
    nb_read_frames=None,
):
    decord.bridge.set_bridge("torch")
    vr = VideoReader(uri=video_data, height=-1, width=-1)
    if nb_read_frames is not None:
        ori_vlen = nb_read_frames
    else:
        ori_vlen = min(int(duration * actual_fps) - 1, len(vr))

    max_seek = int(ori_vlen - skip_frms_num - num_frames / wanted_fps * actual_fps)
    # random.randint is inclusive on both ends and needs int bounds, so clamp
    # the bounds instead of seeking one frame past max_seek
    start = random.randint(int(skip_frms_num), max(int(skip_frms_num), max_seek))
    end = int(start + num_frames / wanted_fps * actual_fps)
    n_frms = num_frames

    if sampling == "uniform":
        indices = np.arange(start, end, (end - start) / n_frms).astype(int)
    else:
        raise NotImplementedError

    # get_batch -> T, H, W, C
    temp_frms = vr.get_batch(np.arange(start, end))
    assert temp_frms is not None
    tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
    tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]

    return pad_last_frame(tensor_frms, num_frames)
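
# Sampling sketch (hypothetical numbers): draw 4 frames at 2 fps from a 10 s
# clip recorded at 30 fps, i.e. a 60-source-frame window starting at a random
# seek position past skip_frms_num.
#
#     frames = load_video(
#         "sample.mp4", duration=10, num_frames=4, wanted_fps=2, actual_fps=30
#     )  # -> uint8 tensor of shape [4, H, W, 3]
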
def load_video_with_timeout(*args, **kwargs):
    video_container = {}

    def target_function():
        video = load_video(*args, **kwargs)
        video_container["video"] = video

    thread = threading.Thread(target=target_function)
    thread.start()
    timeout = 20
    thread.join(timeout)

    if thread.is_alive():
        print("Loading video timed out")
        raise TimeoutError
    video = video_container.get("video", None)
    if video is None:
        # the worker thread exited without storing a result (e.g. a decode error)
        raise RuntimeError("Loading video failed")
    return video.contiguous()
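
# Design note: thread.join(timeout) does not kill the worker thread, so a
# stuck decode keeps running in the background after TimeoutError is raised;
# the timeout only unblocks the caller.
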
def process_video(
    video_path,
    image_size=None,
    duration=None,
    num_frames=4,
    wanted_fps=None,
    actual_fps=None,
    skip_frms_num=0.0,
    nb_read_frames=None,
):
    """
    video_path: str or io.BytesIO
    image_size: target (height, width) of the output frames.
    duration: duration in seconds, known in advance to speed up loading by seeking to a sampled start. TODO: bypass if unknown.
    num_frames: wanted number of frames.
    wanted_fps: target sampling fps.
    skip_frms_num: ignore the first and the last xx frames, avoiding transitions.
    """
    video = load_video_with_timeout(
        video_path,
        duration=duration,
        num_frames=num_frames,
        wanted_fps=wanted_fps,
        actual_fps=actual_fps,
        skip_frms_num=skip_frms_num,
        nb_read_frames=nb_read_frames,
    )

    # --- copy and modify the image process ---
    video = video.permute(0, 3, 1, 2)  # [T, H, W, C] -> [T, C, H, W]

    # resize
    if image_size is not None:
        video = resize_for_rectangle_crop(video, image_size, reshape_mode="center")

    return video
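
# End-to-end sketch (hypothetical file name and numbers): load, permute, and
# center-crop a clip to 16 frames of 480x720 sampled at 8 fps.
#
#     video = process_video(
#         "sample.mp4", image_size=[480, 720], duration=10,
#         num_frames=16, wanted_fps=8, actual_fps=30,
#     )  # -> uint8 tensor of shape [16, 3, 480, 720]
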
def process_fn_video(src, image_size, fps, num_frames, skip_frms_num=0.0, txt_key="caption"):
    while True:
        r = next(src)
        if "mp4" in r:
            video_data = r["mp4"]
        elif "avi" in r:
            video_data = r["avi"]
        else:
            print("No video data found")
            continue

        if txt_key not in r:
            txt = ""
        else:
            txt = r[txt_key]

        if isinstance(txt, bytes):
            txt = txt.decode("utf-8")
        else:
            txt = str(txt)

        duration = r.get("duration", None)
        if duration is not None:
            duration = float(duration)
        else:
            continue

        actual_fps = r.get("fps", None)
        if actual_fps is not None:
            actual_fps = float(actual_fps)
        else:
            continue

        required_duration = num_frames / fps + 2 * skip_frms_num / actual_fps
        if duration < required_duration:
            continue

        try:
            frames = process_video(
                io.BytesIO(video_data),
                num_frames=num_frames,
                wanted_fps=fps,
                image_size=image_size,
                duration=duration,
                actual_fps=actual_fps,
                skip_frms_num=skip_frms_num,
            )
            frames = (frames - 127.5) / 127.5
        except Exception as e:
            print(e)
            continue

        item = {
            "mp4": frames,
            "txt": txt,
            "num_frames": num_frames,
            "fps": fps,
        }

        yield item
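
# Record sketch (hypothetical sample): each WebDataset record carries raw
# video bytes plus metadata; records without "duration" or "fps" are skipped.
#
#     record = {"mp4": mp4_bytes, "caption": b"a cat", "duration": 10.0, "fps": 30.0}
#     gen = process_fn_video(iter([record]), image_size=[480, 720], fps=8, num_frames=16)
#     item = next(gen)  # {"mp4": frames in [-1, 1], "txt": "a cat", ...}
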
class VideoDataset(MetaDistributedWebDataset):
    def __init__(
        self,
        path,
        image_size,
        num_frames,
        fps,
        skip_frms_num=0.0,
        nshards=sys.maxsize,
        seed=1,
        meta_names=None,
        shuffle_buffer=1000,
        include_dirs=None,
        txt_key="caption",
        **kwargs,
    ):
        if seed == -1:
            seed = random.randint(0, 1000000)
        if meta_names is None:
            meta_names = []

        if path.startswith(";"):
            path, include_dirs = path.split(";", 1)
        super().__init__(
            path,
            partial(
                process_fn_video,
                num_frames=num_frames,
                image_size=image_size,
                fps=fps,
                skip_frms_num=skip_frms_num,
                txt_key=txt_key,  # forward the caption key instead of silently ignoring it
            ),
            seed,
            meta_names=meta_names,
            shuffle_buffer=shuffle_buffer,
            nshards=nshards,
            include_dirs=include_dirs,
        )

    @classmethod
    def create_dataset_function(cls, path, args, **kwargs):
        return cls(path, **kwargs)
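
# Instantiation sketch (hypothetical path); sharding and shuffling across
# ranks are handled by MetaDistributedWebDataset:
#
#     dataset = VideoDataset(
#         "data/videos/", image_size=[480, 720], num_frames=16, fps=8, seed=42
#     )
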
class SFTDataset(Dataset):
    def __init__(self, data_dir, video_size, fps, max_num_frames, skip_frms_num=3):
        """
        skip_frms_num: ignore the first and the last xx frames, avoiding transitions.
        """
        super(SFTDataset, self).__init__()

        self.video_size = video_size
        self.fps = fps
        self.max_num_frames = max_num_frames
        self.skip_frms_num = skip_frms_num

        self.video_paths = []
        self.captions = []

        for root, dirnames, filenames in os.walk(data_dir):
            for filename in filenames:
                if filename.endswith(".mp4"):
                    video_path = os.path.join(root, filename)
                    self.video_paths.append(video_path)

                    caption_path = video_path.replace(".mp4", ".txt").replace("videos", "labels")
                    if os.path.exists(caption_path):
                        with open(caption_path, "r") as f:
                            lines = f.read().splitlines()
                        caption = lines[0] if lines else ""
                    else:
                        caption = ""
                    self.captions.append(caption)

    def __getitem__(self, index):
        decord.bridge.set_bridge("torch")

        video_path = self.video_paths[index]
        vr = VideoReader(uri=video_path, height=-1, width=-1)
        actual_fps = vr.get_avg_fps()
        ori_vlen = len(vr)

        if ori_vlen / actual_fps * self.fps > self.max_num_frames:
            num_frames = self.max_num_frames
            start = int(self.skip_frms_num)
            end = int(start + num_frames / self.fps * actual_fps)
            end_safety = min(end, int(ori_vlen))
            indices = np.arange(start, end, max((end - start) // num_frames, 1)).astype(int)
            # only index into frames that were actually decoded
            indices = indices[indices < end_safety]
            temp_frms = vr.get_batch(np.arange(start, end_safety))
            assert temp_frms is not None
            tensor_frms = (
                torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
            )
            tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
        else:
            if ori_vlen > self.max_num_frames:
                num_frames = self.max_num_frames
                start = int(self.skip_frms_num)
                end = int(ori_vlen - self.skip_frms_num)
                indices = np.arange(start, end, max((end - start) // num_frames, 1)).astype(int)
                temp_frms = vr.get_batch(np.arange(start, end))
                assert temp_frms is not None
                tensor_frms = (
                    torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
                )
                tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
            else:

                def nearest_smaller_4k_plus_1(n):
                    remainder = n % 4
                    if remainder == 0:
                        return n - 3
                    else:
                        return n - remainder + 1

                start = int(self.skip_frms_num)
                end = int(ori_vlen - self.skip_frms_num)
                num_frames = nearest_smaller_4k_plus_1(
                    end - start
                )  # 3D VAE requires the number of frames to be 4k+1
                end = int(start + num_frames)
                temp_frms = vr.get_batch(np.arange(start, end))
                assert temp_frms is not None
                tensor_frms = (
                    torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
                )

        tensor_frms = pad_last_frame(
            tensor_frms, self.max_num_frames
        )  # the len of indices may be less than num_frames, due to rounding error
        tensor_frms = tensor_frms.permute(0, 3, 1, 2)  # [T, H, W, C] -> [T, C, H, W]
        tensor_frms = resize_for_rectangle_crop(tensor_frms, self.video_size, reshape_mode="center")
        tensor_frms = (tensor_frms - 127.5) / 127.5

        item = {
            "mp4": tensor_frms,
            "txt": self.captions[index],
            "num_frames": num_frames,
            "fps": self.fps,
        }
        return item

    def __len__(self):
        return len(self.video_paths)

    @classmethod
    def create_dataset_function(cls, path, args, **kwargs):
        return cls(data_dir=path, **kwargs)
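
# Layout sketch (hypothetical paths): captions are looked up by swapping
# "videos" -> "labels" and ".mp4" -> ".txt" in each video path, e.g.
#
#     data_dir/videos/clip001.mp4
#     data_dir/labels/clip001.txt
#
#     dataset = SFTDataset("data_dir/videos", video_size=[480, 720], fps=8, max_num_frames=49)
#     sample = dataset[0]  # {"mp4": [T, 3, 480, 720] in [-1, 1], "txt": ..., ...}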