From 5f7f2e424aa6395f759efc5227d06a33331f56a9 Mon Sep 17 00:00:00 2001 From: glide-the Date: Sat, 14 Sep 2024 12:02:34 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dpadding=E5=90=8E=E5=B8=A7?= =?UTF-8?q?=E5=8F=98=E5=8C=96=E5=A4=AA=E5=A4=A7=E5=BC=95=E8=B5=B7=E7=9A=84?= =?UTF-8?q?=E5=9B=BE=E7=89=87=E5=A4=A7=E5=B0=8F=E8=A2=AB=E6=8B=89=E4=BC=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- inference/gradio_composite_demo/rife_model.py | 65 ++++++++++++++----- 1 file changed, 50 insertions(+), 15 deletions(-) diff --git a/inference/gradio_composite_demo/rife_model.py b/inference/gradio_composite_demo/rife_model.py index f66d5e7..455cfff 100644 --- a/inference/gradio_composite_demo/rife_model.py +++ b/inference/gradio_composite_demo/rife_model.py @@ -8,7 +8,7 @@ import numpy as np import logging import skvideo.io from rife.RIFE_HDv3 import Model - +from huggingface_hub import hf_hub_download, snapshot_download logger = logging.getLogger(__name__) device = "cuda" if torch.cuda.is_available() else "cpu" @@ -19,8 +19,9 @@ def pad_image(img, scale): tmp = max(32, int(32 / scale)) ph = ((h - 1) // tmp + 1) * tmp pw = ((w - 1) // tmp + 1) * tmp - padding = (0, pw - w, 0, ph - h) - return F.pad(img, padding) + padding = (0, pw - w, 0, ph - h) + + return F.pad(img, padding), padding def make_inference(model, I0, I1, upscale_amount, n): @@ -44,10 +45,15 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi for b in range(samples.shape[0]): frame = samples[b : b + 1] _, _, h, w = frame.shape + I0 = samples[b : b + 1] I1 = samples[b + 1 : b + 2] if b + 2 < samples.shape[0] else samples[-1:] - I0 = pad_image(I0, upscale_amount).to(torch.float) - I1 = pad_image(I1, upscale_amount).to(torch.float) + + I0, padding = pad_image(I0, upscale_amount) + I0 = I0.to(torch.float) + I1, _ = pad_image(I1, upscale_amount) + I1 = I1.to(torch.float) + # [c, h, w] I0_small = F.interpolate(I0, (32, 32), mode="bilinear", align_corners=False) I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False) @@ -55,14 +61,24 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3]) if ssim > 0.996: - I1 = I0 - I1 = I1 + I1 = samples[b : b + 1] + # print(f'upscale_amount:{upscale_amount}') + # print(f'ssim:{upscale_amount}') + # print(f'I0 shape:{I0.shape}') + # print(f'I1 shape:{I1.shape}') + I1, padding = pad_image(I1, upscale_amount) + # print(f'I0 shape:{I0.shape}') + # print(f'I1 shape:{I1.shape}') I1 = make_inference(model, I0, I1, upscale_amount, 1) - - I1_small = F.interpolate(I1[0], (32, 32), mode="bilinear", align_corners=False) - ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3]) - frame = I1[0] + + # print(f'I0 shape:{I0.shape}') + # print(f'I1[0] shape:{I1[0].shape}') I1 = I1[0] + + # print(f'I1[0] unpadded shape:{I1.shape}') + I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False) + ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3]) + frame = I1[padding[0]:, padding[2]:, padding[3]:,padding[1]:] tmp_output = [] if ssim < 0.2: @@ -72,9 +88,12 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi else: tmp_output = make_inference(model, I0, I1, upscale_amount, 2**exp - 1) if exp else [] - frame = pad_image(frame, upscale_amount) + frame, _ = pad_image(frame, upscale_amount) + print(f'frame shape:{frame.shape}') + print(f'tmp_output[0] shape:{tmp_output[0].shape}') tmp_output = [frame] + tmp_output - for i, frame in enumerate(tmp_output): + + for i, frame in enumerate(tmp_output): frame = F.interpolate(frame, size=(h, w)) output.append(frame.to(output_device)) return output @@ -98,11 +117,19 @@ def frame_generator(video_capture): def rife_inference_with_path(model, video_path): + # Open the video file video_capture = cv2.VideoCapture(video_path) - tot_frame = video_capture.get(cv2.CAP_PROP_FRAME_COUNT) + fps = video_capture.get(cv2.CAP_PROP_FPS) # Get the frames per second + tot_frame = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT)) # Total frames in the video pt_frame_data = [] pt_frame = skvideo.io.vreader(video_path) - for frame in pt_frame: + # Cyclic reading of the video frames + while video_capture.isOpened(): + ret, frame = video_capture.read() + + if not ret: + break + # BGR to RGB frame_rgb = frame[..., ::-1] frame_rgb = frame_rgb.copy() @@ -137,3 +164,11 @@ def rife_inference_with_latents(model, latents): rife_results.append(pt_image) return torch.stack(rife_results) + + +if __name__ == "__main__": + snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife") + model = load_rife_model("model_rife") + + video_path = rife_inference_with_path(model, "/mnt/ceph/develop/jiawei/CogVideo/sat/configs/outputs/1_In_the_heart_of_a_bustling_city,_a_young_woman_with_long,_flowing_brown_hair_and_a_radiant_smile_stands_out._She's_donne/0/000000.mp4") + print(video_path) \ No newline at end of file