From f0098c06621a9de3f73f7122bf4c9b48b5f37e52 Mon Sep 17 00:00:00 2001 From: glide-the Date: Thu, 3 Oct 2024 13:25:55 +0800 Subject: [PATCH] padding fix --- inference/gradio_composite_demo/rife_model.py | 47 +++++++++++-------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/inference/gradio_composite_demo/rife_model.py b/inference/gradio_composite_demo/rife_model.py index 901038d..e1783e3 100644 --- a/inference/gradio_composite_demo/rife_model.py +++ b/inference/gradio_composite_demo/rife_model.py @@ -9,7 +9,6 @@ import logging import skvideo.io from rife.RIFE_HDv3 import Model from huggingface_hub import hf_hub_download, snapshot_download - logger = logging.getLogger(__name__) device = "cuda" if torch.cuda.is_available() else "cpu" @@ -20,8 +19,7 @@ def pad_image(img, scale): tmp = max(32, int(32 / scale)) ph = ((h - 1) // tmp + 1) * tmp pw = ((w - 1) // tmp + 1) * tmp - padding = (0, pw - w, 0, ph - h) - + padding = (0, pw - w, 0, ph - h) return F.pad(img, padding), padding @@ -47,15 +45,15 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi for b in range(samples.shape[0]): frame = samples[b : b + 1] _, _, h, w = frame.shape - + I0 = samples[b : b + 1] I1 = samples[b + 1 : b + 2] if b + 2 < samples.shape[0] else samples[-1:] - + I0, padding = pad_image(I0, upscale_amount) I0 = I0.to(torch.float) I1, _ = pad_image(I1, upscale_amount) I1 = I1.to(torch.float) - + # [c, h, w] I0_small = F.interpolate(I0, (32, 32), mode="bilinear", align_corners=False) I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False) @@ -72,15 +70,23 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi # print(f'I0 shape:{I0.shape}') # print(f'I1 shape:{I1.shape}') I1 = make_inference(model, I0, I1, upscale_amount, 1) - + # print(f'I0 shape:{I0.shape}') - # print(f'I1[0] shape:{I1[0].shape}') + # print(f'I1[0] shape:{I1[0].shape}') I1 = I1[0] - - # print(f'I1[0] unpadded shape:{I1.shape}') + + # print(f'I1[0] unpadded shape:{I1.shape}') I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False) ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3]) - frame = I1[padding[0] :, padding[2] :, : -padding[3], padding[1] :] + if padding[3] > 0 and padding[1] >0 : + + frame = I1[:, :, : -padding[3],:-padding[1]] + elif padding[3] > 0: + frame = I1[:, :, : -padding[3],:] + elif padding[1] >0: + frame = I1[:, :, :,:-padding[1]] + else: + frame = I1 tmp_output = [] if ssim < 0.2: @@ -95,7 +101,8 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi frame = F.interpolate(frame, size=(h, w)) output.append(frame.to(output_device)) - for i, tmp_frame in enumerate(tmp_output): + for i, tmp_frame in enumerate(tmp_output): + # tmp_frame, _ = pad_image(tmp_frame, upscale_amount) tmp_frame = F.interpolate(tmp_frame, size=(h, w)) output.append(tmp_frame.to(output_device)) @@ -138,7 +145,9 @@ def rife_inference_with_path(model, video_path): frame_rgb = frame[..., ::-1] frame_rgb = frame_rgb.copy() tensor = torch.from_numpy(frame_rgb).float().to("cpu", non_blocking=True).float() / 255.0 - pt_frame_data.append(tensor.permute(2, 0, 1)) # to [c, h, w,] + pt_frame_data.append( + tensor.permute(2, 0, 1) + ) # to [c, h, w,] pt_frame = torch.from_numpy(np.stack(pt_frame_data)) pt_frame = pt_frame.to(device) @@ -167,9 +176,9 @@ def rife_inference_with_latents(model, latents): return torch.stack(rife_results) -if __name__ == "__main__": - snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife") - model = load_rife_model("model_rife") - - video_path = rife_inference_with_path(model, "/mnt/ceph/develop/jiawei/CogVideo/output/chunk_3710_1.mp4") - print(video_path) +# if __name__ == "__main__": +# snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife") +# model = load_rife_model("model_rife") + +# video_path = rife_inference_with_path(model, "/mnt/ceph/develop/jiawei/CogVideo/output/20241003_130720.mp4") +# print(video_path) \ No newline at end of file