Mirror of https://github.com/THUDM/CogVideo.git
Commit aeb7d9d056
@@ -1,6 +1,6 @@
 # CogVideo & CogVideoX

-[Read this in English.](./README_zh)
+[Read this in English](./README_zh.md)

 [中文阅读](./README_zh.md)

@@ -373,4 +373,4 @@ CogVideoX-2B モデル (対応するTransformersモジュールやVAEモジュ
 [Apache 2.0 License](LICENSE) の下で公開されています。

 CogVideoX-5B モデル (Transformersモジュール) は
-[CogVideoX LICENSE](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE) の下で公開されています。
+[CogVideoX LICENSE](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE) の下で公開されています。

@@ -1,6 +1,6 @@
 # CogVideo & CogVideoX

-[Read this in English.](./README_zh)
+[Read this in English](./README_zh.md)

 [日本語で読む](./README_ja.md)

@@ -363,4 +363,4 @@ CogVideoX-2B 模型 (包括其对应的Transformers模块,VAE模块) 根据 [A

 CogVideoX-5B 模型 (Transformers 模块)
 根据 [CogVideoX LICENSE](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE)
-许可证发布。
+许可证发布。

@@ -42,4 +42,6 @@ pip install -r requirements.txt

+```bash
 python gradio_web_demo.py
-```
+```

@@ -19,7 +19,7 @@ def pad_image(img, scale):
     tmp = max(32, int(32 / scale))
     ph = ((h - 1) // tmp + 1) * tmp
     pw = ((w - 1) // tmp + 1) * tmp
-    padding = (0, 0, pw - w, ph - h)
+    padding = (0, pw - w, 0, ph - h)
     return F.pad(img, padding)
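
Why this one-liner matters: `torch.nn.functional.pad` reads a 4-tuple as `(left, right, top, bottom)` padding over the last two dimensions, so the old tuple `(0, 0, pw - w, ph - h)` never padded the width at all and padded the height by the wrong amounts, while the new `(0, pw - w, 0, ph - h)` pads the right edge out to `pw` and the bottom edge out to `ph`. A minimal sketch with toy shapes (not from the repo):

```python
import torch
import torch.nn.functional as F

img = torch.zeros(1, 3, 30, 50)  # [N, C, H=30, W=50]
h, w = img.shape[-2:]
tmp = 32
ph = ((h - 1) // tmp + 1) * tmp  # round H up to a multiple of 32 -> 32
pw = ((w - 1) // tmp + 1) * tmp  # round W up to a multiple of 32 -> 64

# old tuple: width untouched, height padded by the width remainder
print(F.pad(img, (0, 0, pw - w, ph - h)).shape)  # torch.Size([1, 3, 46, 50])
# new tuple: right edge padded to pw, bottom edge padded to ph
print(F.pad(img, (0, pw - w, 0, ph - h)).shape)  # torch.Size([1, 3, 32, 64])
```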
@@ -103,9 +103,13 @@ def rife_inference_with_path(model, video_path):
     pt_frame_data = []
     pt_frame = skvideo.io.vreader(video_path)
     for frame in pt_frame:
+        # BGR to RGB
+        frame_rgb = frame[..., ::-1]
+        frame_rgb = frame_rgb.copy()
+        tensor = torch.from_numpy(frame_rgb).float().to("cpu", non_blocking=True).float() / 255.0
         pt_frame_data.append(
-            torch.from_numpy(np.transpose(frame, (2, 0, 1))).to("cpu", non_blocking=True).float() / 255.0
-        )
+            tensor.permute(2, 0, 1)
+        )  # to [c, h, w,]

     pt_frame = torch.from_numpy(np.stack(pt_frame_data))
     pt_frame = pt_frame.to(device)
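
The added `.copy()` is load-bearing here: `frame[..., ::-1]` flips BGR to RGB by producing a negative-stride NumPy view, and `torch.from_numpy` rejects arrays with negative strides. A toy reproduction (illustrative array, not repo data):

```python
import numpy as np
import torch

frame = np.random.randint(0, 256, (4, 4, 3), dtype=np.uint8)  # toy BGR frame
view = frame[..., ::-1]  # BGR -> RGB, but as a negative-stride view

try:
    torch.from_numpy(view)
except ValueError as e:
    print("rejected:", e)  # negative strides are not supported

tensor = torch.from_numpy(view.copy()).float() / 255.0  # .copy() makes it contiguous
tensor = tensor.permute(2, 0, 1)  # [H, W, C] -> [C, H, W]
print(tensor.shape)  # torch.Size([3, 4, 4])
```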
@@ -19,9 +19,8 @@ from openai import OpenAI
 import moviepy.editor as mp

 dtype = torch.bfloat16
-device = "cuda"  # Need to use cuda

-pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=dtype).to(device)
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=dtype)
-pipe.enable_model_cpu_offload()
+pipe.enable_sequential_cpu_offload()
 pipe.vae.enable_slicing()
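
This hunk swaps a GPU-resident pipeline for offloaded execution: `enable_model_cpu_offload()` keeps one whole component (text encoder, transformer, VAE) on the GPU at a time, while `enable_sequential_cpu_offload()` streams individual submodules in and out, which is slower but has the smallest VRAM footprint; in both cases the pipeline must not be moved with `.to(device)` first, hence the dropped line. A minimal sketch of the resulting setup using standard diffusers calls (the prompt and step count are illustrative):

```python
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# Do NOT call pipe.to("cuda") here: offloading manages device placement itself.
pipe.enable_sequential_cpu_offload()  # lowest VRAM use, slowest inference
pipe.vae.enable_slicing()             # decode latents slice by slice to cap peak memory

video = pipe(prompt="a panda playing guitar", num_inference_steps=50).frames[0]
```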
@@ -362,93 +362,97 @@ class SFTDataset(Dataset):
         skip_frms_num: ignore the first and the last xx frames, avoiding transitions.
         """
         super(SFTDataset, self).__init__()

         self.video_size = video_size
         self.fps = fps
         self.max_num_frames = max_num_frames
         self.skip_frms_num = skip_frms_num

-        self.videos_list = []
-        self.captions_list = []
-        self.num_frames_list = []
-        self.fps_list = []
+        self.video_paths = []
+        self.captions = []

         decord.bridge.set_bridge("torch")
         for root, dirnames, filenames in os.walk(data_dir):
             for filename in filenames:
                 if filename.endswith(".mp4"):
                     video_path = os.path.join(root, filename)
-                    vr = VideoReader(uri=video_path, height=-1, width=-1)
-                    actual_fps = vr.get_avg_fps()
-                    ori_vlen = len(vr)
+                    self.video_paths.append(video_path)

-                    if ori_vlen / actual_fps * fps > max_num_frames:
-                        num_frames = max_num_frames
-                        start = int(skip_frms_num)
-                        end = int(start + num_frames / fps * actual_fps)
-                        end_safty = min(int(start + num_frames / fps * actual_fps), int(ori_vlen))
-                        indices = np.arange(start, end, (end - start) // num_frames).astype(int)
-                        temp_frms = vr.get_batch(np.arange(start, end_safty))
-                        assert temp_frms is not None
-                        tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
-                        tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
-                    else:
-                        if ori_vlen > max_num_frames:
-                            num_frames = max_num_frames
-                            start = int(skip_frms_num)
-                            end = int(ori_vlen - skip_frms_num)
-                            indices = np.arange(start, end, (end - start) // num_frames).astype(int)
-                            temp_frms = vr.get_batch(np.arange(start, end))
-                            assert temp_frms is not None
-                            tensor_frms = (
-                                torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
-                            )
-                            tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
-                        else:
-
-                            def nearest_smaller_4k_plus_1(n):
-                                remainder = n % 4
-                                if remainder == 0:
-                                    return n - 3
-                                else:
-                                    return n - remainder + 1
-
-                            start = int(skip_frms_num)
-                            end = int(ori_vlen - skip_frms_num)
-                            num_frames = nearest_smaller_4k_plus_1(
-                                end - start
-                            )  # 3D VAE requires the number of frames to be 4k+1
-                            end = int(start + num_frames)
-                            temp_frms = vr.get_batch(np.arange(start, end))
-                            assert temp_frms is not None
-                            tensor_frms = (
-                                torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
-                            )
-
-                    tensor_frms = pad_last_frame(
-                        tensor_frms, max_num_frames
-                    )  # the len of indices may be less than num_frames, due to round error
-                    tensor_frms = tensor_frms.permute(0, 3, 1, 2)  # [T, H, W, C] -> [T, C, H, W]
-                    tensor_frms = resize_for_rectangle_crop(tensor_frms, video_size, reshape_mode="center")
-                    tensor_frms = (tensor_frms - 127.5) / 127.5
-                    self.videos_list.append(tensor_frms)

                     # caption
-                    caption_path = os.path.join(root, filename.replace(".mp4", ".txt")).replace("videos", "labels")
+                    caption_path = video_path.replace(".mp4", ".txt").replace("videos", "labels")
                     if os.path.exists(caption_path):
                         caption = open(caption_path, "r").read().splitlines()[0]
                     else:
                         caption = ""
-                    self.captions_list.append(caption)
-                    self.num_frames_list.append(num_frames)
-                    self.fps_list.append(fps)
+                    self.captions.append(caption)

     def __getitem__(self, index):
+        decord.bridge.set_bridge("torch")
+
+        video_path = self.video_paths[index]
+        vr = VideoReader(uri=video_path, height=-1, width=-1)
+        actual_fps = vr.get_avg_fps()
+        ori_vlen = len(vr)
+
+        if ori_vlen / actual_fps * self.fps > self.max_num_frames:
+            num_frames = self.max_num_frames
+            start = int(self.skip_frms_num)
+            end = int(start + num_frames / self.fps * actual_fps)
+            end_safty = min(int(start + num_frames / self.fps * actual_fps), int(ori_vlen))
+            indices = np.arange(start, end, (end - start) // num_frames).astype(int)
+            temp_frms = vr.get_batch(np.arange(start, end_safty))
+            assert temp_frms is not None
+            tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
+            tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
+        else:
+            if ori_vlen > self.max_num_frames:
+                num_frames = self.max_num_frames
+                start = int(self.skip_frms_num)
+                end = int(ori_vlen - self.skip_frms_num)
+                indices = np.arange(start, end, max((end - start) // num_frames, 1)).astype(int)
+                temp_frms = vr.get_batch(np.arange(start, end))
+                assert temp_frms is not None
+                tensor_frms = (
+                    torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
+                )
+                tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
+            else:
+
+                def nearest_smaller_4k_plus_1(n):
+                    remainder = n % 4
+                    if remainder == 0:
+                        return n - 3
+                    else:
+                        return n - remainder + 1
+
+                start = int(self.skip_frms_num)
+                end = int(ori_vlen - self.skip_frms_num)
+                num_frames = nearest_smaller_4k_plus_1(
+                    end - start
+                )  # 3D VAE requires the number of frames to be 4k+1
+                end = int(start + num_frames)
+                temp_frms = vr.get_batch(np.arange(start, end))
+                assert temp_frms is not None
+                tensor_frms = (
+                    torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
+                )
+
+        tensor_frms = pad_last_frame(
+            tensor_frms, self.max_num_frames
+        )  # the len of indices may be less than num_frames, due to round error
+        tensor_frms = tensor_frms.permute(0, 3, 1, 2)  # [T, H, W, C] -> [T, C, H, W]
+        tensor_frms = resize_for_rectangle_crop(tensor_frms, self.video_size, reshape_mode="center")
+        tensor_frms = (tensor_frms - 127.5) / 127.5
+
         item = {
-            "mp4": self.videos_list[index],
-            "txt": self.captions_list[index],
-            "num_frames": self.num_frames_list[index],
-            "fps": self.fps_list[index],
+            "mp4": tensor_frms,
+            "txt": self.captions[index],
+            "num_frames": num_frames,
+            "fps": self.fps,
         }
         return item

     def __len__(self):
-        return len(self.fps_list)
+        return len(self.video_paths)

     @classmethod
     def create_dataset_function(cls, path, args, **kwargs):
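
Both versions of the loader round clip lengths with `nearest_smaller_4k_plus_1`: per the in-code comment, the 3D VAE requires the number of frames to be 4k+1, so a raw length must be rounded down to the largest such value. The helper is easy to check in isolation:

```python
def nearest_smaller_4k_plus_1(n):
    # largest m <= n with m % 4 == 1
    remainder = n % 4
    if remainder == 0:
        return n - 3
    return n - remainder + 1

# 1..4 -> 1, 5..8 -> 5, 9 -> 9
assert [nearest_smaller_4k_plus_1(n) for n in range(1, 10)] == [1, 1, 1, 1, 5, 5, 5, 5, 9]
```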
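The broader effect of this hunk is a switch from eager to lazy loading: the old `__init__` decoded and normalized every video up front into `videos_list`, while the new code stores only paths and captions and defers all decoding to `__getitem__` (also picking up a `max((end - start) // num_frames, 1)` guard against a zero `np.arange` stride). A reduced skeleton of the pattern, with a stub decoder standing in for the real decord logic:

```python
import torch
from torch.utils.data import Dataset

def decode_video(path):
    # stand-in for the real decord-based decoding; returns a dummy clip
    return torch.zeros(9, 3, 32, 32)

class LazyVideoDataset(Dataset):
    """Skeleton of the pattern the diff adopts: cheap __init__, heavy __getitem__."""

    def __init__(self, video_paths):
        self.video_paths = video_paths  # memory cost is just the path strings

    def __getitem__(self, index):
        # decoding happens per item; with DataLoader(num_workers > 0) it runs
        # in worker processes instead of serially at construction time
        return {"mp4": decode_video(self.video_paths[index])}

    def __len__(self):
        return len(self.video_paths)
```

With this shape, startup time and resident memory no longer scale with the dataset, which is presumably why `__len__` now counts `video_paths` instead of `fps_list`.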