From c84b92e081cc7713224a9a8936152974af23e32d Mon Sep 17 00:00:00 2001
From: Jia Zheng
Date: Tue, 27 Aug 2024 21:05:16 +0800
Subject: [PATCH 01/12] :zap: load videos on the fly

---
 sat/data_video.py | 150 +++++++++++++++++++++++-----------------------
 1 file changed, 76 insertions(+), 74 deletions(-)

diff --git a/sat/data_video.py b/sat/data_video.py
index 04a11ca..505f2e0 100644
--- a/sat/data_video.py
+++ b/sat/data_video.py
@@ -362,93 +362,95 @@ class SFTDataset(Dataset):
             skip_frms_num: ignore the first and the last xx frames, avoiding transitions.
         """
         super(SFTDataset, self).__init__()
+
+        self.video_size = video_size
+        self.fps = fps
+        self.max_num_frames = max_num_frames
+        self.skip_frms_num = skip_frms_num
 
-        self.videos_list = []
-        self.captions_list = []
-        self.num_frames_list = []
-        self.fps_list = []
+        self.video_paths = []
 
-        decord.bridge.set_bridge("torch")
         for root, dirnames, filenames in os.walk(data_dir):
             for filename in filenames:
                 if filename.endswith(".mp4"):
                     video_path = os.path.join(root, filename)
-                    vr = VideoReader(uri=video_path, height=-1, width=-1)
-                    actual_fps = vr.get_avg_fps()
-                    ori_vlen = len(vr)
-
-                    if ori_vlen / actual_fps * fps > max_num_frames:
-                        num_frames = max_num_frames
-                        start = int(skip_frms_num)
-                        end = int(start + num_frames / fps * actual_fps)
-                        end_safty = min(int(start + num_frames / fps * actual_fps), int(ori_vlen))
-                        indices = np.arange(start, end, (end - start) // num_frames).astype(int)
-                        temp_frms = vr.get_batch(np.arange(start, end_safty))
-                        assert temp_frms is not None
-                        tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
-                        tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
-                    else:
-                        if ori_vlen > max_num_frames:
-                            num_frames = max_num_frames
-                            start = int(skip_frms_num)
-                            end = int(ori_vlen - skip_frms_num)
-                            indices = np.arange(start, end, (end - start) // num_frames).astype(int)
-                            temp_frms = vr.get_batch(np.arange(start, end))
-                            assert temp_frms is not None
-                            tensor_frms = (
-                                torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
-                            )
-                            tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
-                        else:
-
-                            def nearest_smaller_4k_plus_1(n):
-                                remainder = n % 4
-                                if remainder == 0:
-                                    return n - 3
-                                else:
-                                    return n - remainder + 1
-
-                            start = int(skip_frms_num)
-                            end = int(ori_vlen - skip_frms_num)
-                            num_frames = nearest_smaller_4k_plus_1(
-                                end - start
-                            )  # 3D VAE requires the number of frames to be 4k+1
-                            end = int(start + num_frames)
-                            temp_frms = vr.get_batch(np.arange(start, end))
-                            assert temp_frms is not None
-                            tensor_frms = (
-                                torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
-                            )
-
-                    tensor_frms = pad_last_frame(
-                        tensor_frms, max_num_frames
-                    )  # the len of indices may be less than num_frames, due to round error
-                    tensor_frms = tensor_frms.permute(0, 3, 1, 2)  # [T, H, W, C] -> [T, C, H, W]
-                    tensor_frms = resize_for_rectangle_crop(tensor_frms, video_size, reshape_mode="center")
-                    tensor_frms = (tensor_frms - 127.5) / 127.5
-                    self.videos_list.append(tensor_frms)
-
-                    # caption
-                    caption_path = os.path.join(root, filename.replace(".mp4", ".txt")).replace("videos", "labels")
-                    if os.path.exists(caption_path):
-                        caption = open(caption_path, "r").read().splitlines()[0]
-                    else:
-                        caption = ""
-                    self.captions_list.append(caption)
-                    self.num_frames_list.append(num_frames)
-                    self.fps_list.append(fps)
+                    self.video_paths.append(video_path)
 
     def __getitem__(self, index):
+
+        decord.bridge.set_bridge("torch")
+
+        video_path = self.video_paths[index]
+        vr = VideoReader(uri=video_path, height=-1, width=-1)
+        actual_fps = vr.get_avg_fps()
+        ori_vlen = len(vr)
+
+        if ori_vlen / actual_fps * self.fps > self.max_num_frames:
+            num_frames = self.max_num_frames
+            start = int(self.skip_frms_num)
+            end = int(start + num_frames / self.fps * actual_fps)
+            end_safety = min(int(start + num_frames / self.fps * actual_fps), int(ori_vlen))
+            indices = np.arange(start, end, (end - start) // num_frames).astype(int)
+            temp_frms = vr.get_batch(np.arange(start, end_safety))
+            assert temp_frms is not None
+            tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
+            tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
+        else:
+            if ori_vlen > self.max_num_frames:
+                num_frames = self.max_num_frames
+                start = int(self.skip_frms_num)
+                end = int(ori_vlen - self.skip_frms_num)
+                indices = np.arange(start, end, (end - start) // num_frames).astype(int)
+                temp_frms = vr.get_batch(np.arange(start, end))
+                assert temp_frms is not None
+                tensor_frms = (
+                    torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
+                )
+                tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
+            else:
+
+                def nearest_smaller_4k_plus_1(n):
+                    remainder = n % 4
+                    if remainder == 0:
+                        return n - 3
+                    else:
+                        return n - remainder + 1
+
+                start = int(self.skip_frms_num)
+                end = int(ori_vlen - self.skip_frms_num)
+                num_frames = nearest_smaller_4k_plus_1(
+                    end - start
+                )  # 3D VAE requires the number of frames to be 4k+1
+                end = int(start + num_frames)
+                temp_frms = vr.get_batch(np.arange(start, end))
+                assert temp_frms is not None
+                tensor_frms = (
+                    torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
+                )
+
+        tensor_frms = pad_last_frame(
+            tensor_frms, self.max_num_frames
+        )  # the len of indices may be less than num_frames, due to rounding error
+        tensor_frms = tensor_frms.permute(0, 3, 1, 2)  # [T, H, W, C] -> [T, C, H, W]
+        tensor_frms = resize_for_rectangle_crop(tensor_frms, self.video_size, reshape_mode="center")
+        tensor_frms = (tensor_frms - 127.5) / 127.5
+
+        caption_path = video_path.replace(".mp4", ".txt").replace("videos", "labels")
+        if os.path.exists(caption_path):
+            caption = open(caption_path, "r").read().splitlines()[0]
+        else:
+            caption = ""
+
         item = {
-            "mp4": self.videos_list[index],
-            "txt": self.captions_list[index],
-            "num_frames": self.num_frames_list[index],
-            "fps": self.fps_list[index],
+            "mp4": tensor_frms,
+            "txt": caption,
+            "num_frames": num_frames,
+            "fps": self.fps,
        }
         return item
 
     def __len__(self):
-        return len(self.fps_list)
+        return len(self.video_paths)
 
     @classmethod
     def create_dataset_function(cls, path, args, **kwargs):
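Patch 01 defers all decoding to `__getitem__`: `__init__` now only records file paths, so startup no longer decodes every video in the dataset. The `nearest_smaller_4k_plus_1` helper it carries over encodes the 3D VAE's temporal constraint that a clip must contain 4k+1 frames. A standalone sketch of that helper with a few worked values (the test values are illustrative, not from the patch):

```python
def nearest_smaller_4k_plus_1(n):
    # Largest integer of the form 4k + 1 that is <= n (assumes n >= 1).
    remainder = n % 4
    if remainder == 0:
        return n - 3
    else:
        return n - remainder + 1

# Worked values: everything in [49, 52] maps down to 49 = 4 * 12 + 1.
for n in [49, 50, 51, 52, 53]:
    print(n, "->", nearest_smaller_4k_plus_1(n))  # 49, 49, 49, 49, 53
```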
From a6291c985af3059ba7d8117e067fe7e2dc0934a7 Mon Sep 17 00:00:00 2001
From: Jia Zheng
Date: Tue, 27 Aug 2024 22:05:37 +0800
Subject: [PATCH 02/12] :zap: preload captions

---
 sat/data_video.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/sat/data_video.py b/sat/data_video.py
index 505f2e0..d5d0ea0 100644
--- a/sat/data_video.py
+++ b/sat/data_video.py
@@ -369,6 +369,7 @@ class SFTDataset(Dataset):
         self.skip_frms_num = skip_frms_num
 
         self.video_paths = []
+        self.captions = []
 
         for root, dirnames, filenames in os.walk(data_dir):
             for filename in filenames:
@@ -376,6 +377,13 @@ class SFTDataset(Dataset):
                     video_path = os.path.join(root, filename)
                     self.video_paths.append(video_path)
 
+                    caption_path = video_path.replace(".mp4", ".txt").replace("videos", "labels")
+                    if os.path.exists(caption_path):
+                        caption = open(caption_path, "r").read().splitlines()[0]
+                    else:
+                        caption = ""
+                    self.captions.append(caption)
+
     def __getitem__(self, index):
 
         decord.bridge.set_bridge("torch")
@@ -435,15 +443,9 @@ class SFTDataset(Dataset):
         tensor_frms = resize_for_rectangle_crop(tensor_frms, self.video_size, reshape_mode="center")
         tensor_frms = (tensor_frms - 127.5) / 127.5
 
-        caption_path = video_path.replace(".mp4", ".txt").replace("videos", "labels")
-        if os.path.exists(caption_path):
-            caption = open(caption_path, "r").read().splitlines()[0]
-        else:
-            caption = ""
-
         item = {
             "mp4": tensor_frms,
-            "txt": caption,
+            "txt": self.captions[index],
             "num_frames": num_frames,
             "fps": self.fps,
         }
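Together, patches 01 and 02 land on the usual PyTorch split: cheap metadata (paths, captions) is gathered once in `__init__`, while the expensive decord decode stays in `__getitem__`, where DataLoader worker processes run it in parallel; `decord.bridge.set_bridge("torch")` also moves into `__getitem__`, presumably so each worker process sets the bridge for itself. A minimal sketch of the same pattern (class name and layout are illustrative, not from the repo):

```python
import os
from torch.utils.data import Dataset, DataLoader

class LazyVideoDataset(Dataset):
    """Illustrative sketch: gather cheap metadata up front, decode lazily."""

    def __init__(self, data_dir):
        # Cheap at startup: walk the tree once, remember paths and captions.
        self.video_paths = []
        self.captions = []
        for root, _, filenames in os.walk(data_dir):
            for filename in filenames:
                if filename.endswith(".mp4"):
                    path = os.path.join(root, filename)
                    txt = path.replace(".mp4", ".txt")
                    self.video_paths.append(path)
                    self.captions.append(open(txt).readline().strip() if os.path.exists(txt) else "")

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, index):
        # Expensive work (decord decode, resize, normalize) belongs here:
        # it then runs inside DataLoader workers, not in __init__.
        return {"mp4": self.video_paths[index], "txt": self.captions[index]}

# Decoding cost is then amortized across worker processes:
# loader = DataLoader(LazyVideoDataset("data/videos"), batch_size=2, num_workers=4)
```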
From c2a74b7f8dc6f9e3bdd0e236a2f97281aff9c6f8 Mon Sep 17 00:00:00 2001
From: Dr. Artificial曾小健 <875100501@qq.com>
Date: Wed, 28 Aug 2024 22:26:14 +0800
Subject: [PATCH 03/12] Update README_zh.md

wrong file name
---
 README_zh.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README_zh.md b/README_zh.md
index a26fa4b..284c5b9 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -1,6 +1,6 @@
 # CogVideo && CogVideoX
 
-[Read this in English.](./README_zh)
+[Read this in English](./README_zh.md)
 
 [日本語で読む](./README_ja.md)
 
@@ -351,4 +351,4 @@ CogVideoX-2B 模型 (包括其对应的Transformers模块,VAE模块) 根据 [A
 
 CogVideoX-5B 模型 (Transformers 模块) 根据 [CogVideoX LICENSE](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE)
-许可证发布。
\ No newline at end of file
+许可证发布。

From 0c1dbd144c0f4aba0718f0168deaaf9886b3adfc Mon Sep 17 00:00:00 2001
From: Dr. Artificial曾小健 <875100501@qq.com>
Date: Wed, 28 Aug 2024 22:28:14 +0800
Subject: [PATCH 04/12] fix wrong file name

fix wrong file name
---
 README_ja.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README_ja.md b/README_ja.md
index 90a8941..509d5e1 100644
--- a/README_ja.md
+++ b/README_ja.md
@@ -1,6 +1,6 @@
 # CogVideo && CogVideoX
 
-[Read this in English.](./README_zh)
+[Read this in English](./README_zh.md)
 
 [中文阅读](./README_zh.md)
 
@@ -359,4 +359,4 @@ CogVideoX-2B モデル (対応するTransformersモジュールやVAEモジュ
 [Apache 2.0 License](LICENSE) の下で公開されています。
 
 CogVideoX-5B モデル (Transformersモジュール) は
-[CogVideoX LICENSE](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE) の下で公開されています。
\ No newline at end of file
+[CogVideoX LICENSE](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE) の下で公開されています。

From e226010ffc54bf5078707c7443c0aa13e09881e8 Mon Sep 17 00:00:00 2001
From: Rodrigo Antônio de Araújo
Date: Thu, 29 Aug 2024 14:54:19 -0300
Subject: [PATCH 05/12] Remove .to(device) to avoid memory allocation errors

Remove .to(device) so that the CPU-offload settings are applied first.
---
 inference/gradio_web_demo.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/inference/gradio_web_demo.py b/inference/gradio_web_demo.py
index a00d1b8..9c5520e 100644
--- a/inference/gradio_web_demo.py
+++ b/inference/gradio_web_demo.py
@@ -19,9 +19,8 @@ from openai import OpenAI
 import moviepy.editor as mp
 
 dtype = torch.bfloat16
-device = "cuda"  # Need to use cuda
 
-pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=dtype).to(device)
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=dtype)
 pipe.enable_model_cpu_offload()
 pipe.enable_sequential_cpu_offload()
 pipe.vae.enable_slicing()
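The rationale in patch 05: `enable_model_cpu_offload()` and `enable_sequential_cpu_offload()` install hooks that move weights to the GPU only while a submodule is running, so the pipeline must still be on the CPU when they are called; an earlier `.to(device)` places the whole model on the GPU first and can exhaust VRAM. A minimal usage sketch under that assumption (the model ID is the one from the demo; the prompt is illustrative, and sequential offload alone is the stricter of the two calls the demo makes):

```python
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# Configure memory savers while the weights are still on the CPU; the offload
# hooks then shuttle each submodule to the GPU only for its forward pass.
pipe.enable_sequential_cpu_offload()  # lowest VRAM usage, slowest
pipe.vae.enable_slicing()             # decode the latent batch slice by slice

# Note: no pipe.to("cuda") here; offload manages device placement itself.
video = pipe(prompt="a corgi surfing a small wave", num_inference_steps=50).frames[0]
# from diffusers.utils import export_to_video; export_to_video(video, "out.mp4", fps=8)
```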
From 202484e77f5b3b271484fd6dafc7ee0952ac29c4 Mon Sep 17 00:00:00 2001
From: u
Date: Sun, 8 Sep 2024 17:09:21 +0800
Subject: [PATCH 06/12] padding fix

---
 inference/gradio_composite_demo/rife_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/inference/gradio_composite_demo/rife_model.py b/inference/gradio_composite_demo/rife_model.py
index e964ef2..b7eac2b 100644
--- a/inference/gradio_composite_demo/rife_model.py
+++ b/inference/gradio_composite_demo/rife_model.py
@@ -19,7 +19,7 @@ def pad_image(img, scale):
     tmp = max(32, int(32 / scale))
     ph = ((h - 1) // tmp + 1) * tmp
     pw = ((w - 1) // tmp + 1) * tmp
-    padding = (0, 0, pw - w, ph - h)
+    padding = (0, pw - w, 0, ph - h)
     return F.pad(img, padding)
 

From 88c79c5ecdbcfd3eac0166e97ace6438c8196a03 Mon Sep 17 00:00:00 2001
From: u
Date: Sun, 8 Sep 2024 17:11:50 +0800
Subject: [PATCH 07/12] known issue

---
 inference/gradio_composite_demo/README.md | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/inference/gradio_composite_demo/README.md b/inference/gradio_composite_demo/README.md
index 5743cf1..6e6c5e4 100644
--- a/inference/gradio_composite_demo/README.md
+++ b/inference/gradio_composite_demo/README.md
@@ -42,4 +42,9 @@ pip install -r requirements.txt
 
 ```bash
 python gradio_web_demo.py
-```
\ No newline at end of file
+```
+
+
+### Known issue
+
+RIFE may experience precision overflow on some devices, resulting in incorrect image colors.
\ No newline at end of file

From 41990b0228c1ba158c396a3b65dafaa6aa24983f Mon Sep 17 00:00:00 2001
From: u
Date: Mon, 9 Sep 2024 01:10:00 +0800
Subject: [PATCH 08/12] convert frames from BGR to RGB

---
 inference/gradio_composite_demo/rife_model.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/inference/gradio_composite_demo/rife_model.py b/inference/gradio_composite_demo/rife_model.py
index b7eac2b..f66d5e7 100644
--- a/inference/gradio_composite_demo/rife_model.py
+++ b/inference/gradio_composite_demo/rife_model.py
@@ -103,9 +103,13 @@ def rife_inference_with_path(model, video_path):
     pt_frame_data = []
     pt_frame = skvideo.io.vreader(video_path)
     for frame in pt_frame:
+        # BGR to RGB
+        frame_rgb = frame[..., ::-1]
+        frame_rgb = frame_rgb.copy()
+        tensor = torch.from_numpy(frame_rgb).float().to("cpu", non_blocking=True).float() / 255.0
         pt_frame_data.append(
-            torch.from_numpy(np.transpose(frame, (2, 0, 1))).to("cpu", non_blocking=True).float() / 255.0
-        )
+            tensor.permute(2, 0, 1)
+        )  # to [C, H, W]
 
     pt_frame = torch.from_numpy(np.stack(pt_frame_data))
     pt_frame = pt_frame.to(device)

From e74e1ae52d33ed06b18066f6cc4d9312e9c1155d Mon Sep 17 00:00:00 2001
From: glide-the
Date: Mon, 9 Sep 2024 10:28:03 +0800
Subject: [PATCH 09/12] Update README.md

---
 inference/gradio_composite_demo/README.md | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/inference/gradio_composite_demo/README.md b/inference/gradio_composite_demo/README.md
index 6e6c5e4..3582d86 100644
--- a/inference/gradio_composite_demo/README.md
+++ b/inference/gradio_composite_demo/README.md
@@ -45,6 +45,3 @@ python gradio_web_demo.py
 ```
 
 
-### Known issue
-
-RIFE may experience precision overflow on some devices, resulting in incorrect image colors.
\ No newline at end of file
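Patch 06 works because `torch.nn.functional.pad` consumes its `pad` tuple from the last dimension inward: for an `[N, C, H, W]` tensor the first pair pads W (left, right) and the second pair pads H (top, bottom). The old tuple `(0, 0, pw - w, ph - h)` therefore left the width untouched and padded only the height, with the wrong amounts. A quick check with illustrative sizes (scale = 1, so `tmp` is 32):

```python
import torch
import torch.nn.functional as F

img = torch.zeros(1, 3, 60, 100)  # [N, C, H=60, W=100]
ph, pw = 64, 128                  # H and W rounded up to multiples of tmp = 32

# F.pad reads the pad tuple from the last dim inward:
# (W_left, W_right, H_top, H_bottom) for a 4D tensor.
old = F.pad(img, (0, 0, pw - 100, ph - 60))  # old tuple pads H only -> [1, 3, 92, 100]
new = F.pad(img, (0, pw - 100, 0, ph - 60))  # fixed: W right, H bottom -> [1, 3, 64, 128]

print(old.shape)  # torch.Size([1, 3, 92, 100]) -- still not a multiple of 32 in W
print(new.shape)  # torch.Size([1, 3, 64, 128]) -- the shape RIFE expects
```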
From ccb57e095d6efe35c920ffdfe1c07faa826a3d24 Mon Sep 17 00:00:00 2001
From: Roooy
Date: Tue, 10 Sep 2024 17:02:02 +0900
Subject: [PATCH 10/12] fix: division by zero error when creating video indices

---
 sat/data_video.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sat/data_video.py b/sat/data_video.py
index d5d0ea0..f3764c4 100644
--- a/sat/data_video.py
+++ b/sat/data_video.py
@@ -408,7 +408,7 @@ class SFTDataset(Dataset):
                 num_frames = self.max_num_frames
                 start = int(self.skip_frms_num)
                 end = int(ori_vlen - self.skip_frms_num)
-                indices = np.arange(start, end, (end - start) // num_frames).astype(int)
+                indices = np.arange(start, end, max((end - start) // num_frames), 1).astype(int)
                 temp_frms = vr.get_batch(np.arange(start, end))
                 assert temp_frms is not None
                 tensor_frms = (

From 09ac69ecb60af4cc2e9dea1ab9490f4d2af51874 Mon Sep 17 00:00:00 2001
From: Roooy
Date: Thu, 12 Sep 2024 10:14:24 +0900
Subject: [PATCH 11/12] fix misplaced bracket

---
 sat/data_video.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sat/data_video.py b/sat/data_video.py
index f3764c4..25d17ee 100644
--- a/sat/data_video.py
+++ b/sat/data_video.py
@@ -408,7 +408,7 @@ class SFTDataset(Dataset):
                 num_frames = self.max_num_frames
                 start = int(self.skip_frms_num)
                 end = int(ori_vlen - self.skip_frms_num)
-                indices = np.arange(start, end, max((end - start) // num_frames), 1).astype(int)
+                indices = np.arange(start, end, max((end - start) // num_frames, 1)).astype(int)
                 temp_frms = vr.get_batch(np.arange(start, end))
                 assert temp_frms is not None
                 tensor_frms = (

From 33b81dd3076d0ffd6a56fce180430bfb85a918b6 Mon Sep 17 00:00:00 2001
From: zyssyz123 <916125788@qq.com>
Date: Thu, 12 Sep 2024 14:51:24 +0800
Subject: [PATCH 12/12] Update data_video.py

fix coding bug
---
 sat/data_video.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sat/data_video.py b/sat/data_video.py
index f3764c4..25d17ee 100644
--- a/sat/data_video.py
+++ b/sat/data_video.py
@@ -408,7 +408,7 @@ class SFTDataset(Dataset):
                 num_frames = self.max_num_frames
                 start = int(self.skip_frms_num)
                 end = int(ori_vlen - self.skip_frms_num)
-                indices = np.arange(start, end, max((end - start) // num_frames), 1).astype(int)
+                indices = np.arange(start, end, max((end - start) // num_frames, 1)).astype(int)
                 temp_frms = vr.get_batch(np.arange(start, end))
                 assert temp_frms is not None
                 tensor_frms = (
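Patches 10-12 chase one failure mode: for a short clip, `(end - start) // num_frames` floors to 0 and `np.arange` rejects a zero step with `ZeroDivisionError`. Patch 10's first attempt misplaces a bracket: `max((end - start) // num_frames), 1` calls `max()` on a bare int (a `TypeError`) and passes the `1` to `np.arange` instead. Patches 11 and 12 (the same fix, merged twice) move the bracket so the clamp reads `max(..., 1)`. A small repro of the guard with illustrative numbers:

```python
import numpy as np

start, end, num_frames = 3, 40, 49  # fewer usable frames than requested

step = (end - start) // num_frames  # 37 // 49 == 0
# np.arange(start, end, 0) would raise ZeroDivisionError here.

# The corrected guard from patches 11/12: clamp the stride to at least 1,
# so short clips fall back to taking every frame.
indices = np.arange(start, end, max(step, 1)).astype(int)
print(indices[:5])  # [3 4 5 6 7]
```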