Merge pull request #286 from glide-the/rife_bug_resize

修复padding后帧变化太大引起的图片大小被拉伸 (Fix the image being stretched when the frame changes too much after padding)
This commit is contained in:
Yuxuan.Zhang 2024-09-14 12:06:47 +08:00 committed by GitHub
commit 405d1bfdde
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -8,7 +8,7 @@ import numpy as np
import logging
import skvideo.io
from rife.RIFE_HDv3 import Model
from huggingface_hub import hf_hub_download, snapshot_download
logger = logging.getLogger(__name__)
device = "cuda" if torch.cuda.is_available() else "cpu"
@ -19,8 +19,9 @@ def pad_image(img, scale):
tmp = max(32, int(32 / scale))
ph = ((h - 1) // tmp + 1) * tmp
pw = ((w - 1) // tmp + 1) * tmp
padding = (0, pw - w, 0, ph - h)
return F.pad(img, padding)
padding = (0, pw - w, 0, ph - h)
return F.pad(img, padding), padding
def make_inference(model, I0, I1, upscale_amount, n):
@ -44,10 +45,15 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
for b in range(samples.shape[0]):
frame = samples[b : b + 1]
_, _, h, w = frame.shape
I0 = samples[b : b + 1]
I1 = samples[b + 1 : b + 2] if b + 2 < samples.shape[0] else samples[-1:]
I0 = pad_image(I0, upscale_amount).to(torch.float)
I1 = pad_image(I1, upscale_amount).to(torch.float)
I0, padding = pad_image(I0, upscale_amount)
I0 = I0.to(torch.float)
I1, _ = pad_image(I1, upscale_amount)
I1 = I1.to(torch.float)
# [c, h, w]
I0_small = F.interpolate(I0, (32, 32), mode="bilinear", align_corners=False)
I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
@ -55,14 +61,24 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
if ssim > 0.996:
I1 = I0
I1 = I1
I1 = samples[b : b + 1]
# print(f'upscale_amount:{upscale_amount}')
# print(f'ssim:{upscale_amount}')
# print(f'I0 shape:{I0.shape}')
# print(f'I1 shape:{I1.shape}')
I1, padding = pad_image(I1, upscale_amount)
# print(f'I0 shape:{I0.shape}')
# print(f'I1 shape:{I1.shape}')
I1 = make_inference(model, I0, I1, upscale_amount, 1)
I1_small = F.interpolate(I1[0], (32, 32), mode="bilinear", align_corners=False)
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
frame = I1[0]
# print(f'I0 shape:{I0.shape}')
# print(f'I1[0] shape:{I1[0].shape}')
I1 = I1[0]
# print(f'I1[0] unpadded shape:{I1.shape}')
I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
frame = I1[padding[0]:, padding[2]:, padding[3]:,padding[1]:]
tmp_output = []
if ssim < 0.2:
@ -72,9 +88,12 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
else:
tmp_output = make_inference(model, I0, I1, upscale_amount, 2**exp - 1) if exp else []
frame = pad_image(frame, upscale_amount)
frame, _ = pad_image(frame, upscale_amount)
print(f'frame shape:{frame.shape}')
print(f'tmp_output[0] shape:{tmp_output[0].shape}')
tmp_output = [frame] + tmp_output
for i, frame in enumerate(tmp_output):
for i, frame in enumerate(tmp_output):
frame = F.interpolate(frame, size=(h, w))
output.append(frame.to(output_device))
return output
@ -98,11 +117,19 @@ def frame_generator(video_capture):
def rife_inference_with_path(model, video_path):
# Open the video file
video_capture = cv2.VideoCapture(video_path)
tot_frame = video_capture.get(cv2.CAP_PROP_FRAME_COUNT)
fps = video_capture.get(cv2.CAP_PROP_FPS) # Get the frames per second
tot_frame = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT)) # Total frames in the video
pt_frame_data = []
pt_frame = skvideo.io.vreader(video_path)
for frame in pt_frame:
# Cyclic reading of the video frames
while video_capture.isOpened():
ret, frame = video_capture.read()
if not ret:
break
# BGR to RGB
frame_rgb = frame[..., ::-1]
frame_rgb = frame_rgb.copy()
@ -137,3 +164,11 @@ def rife_inference_with_latents(model, latents):
rife_results.append(pt_image)
return torch.stack(rife_results)
if __name__ == "__main__":
    # Manual smoke test: fetch the RIFE weights, load the model and run
    # interpolation on a single video file.
    import sys

    # Sample clip from the original author's machine; used only as a fallback
    # so that running the script without arguments behaves as before.
    default_video_path = "/mnt/ceph/develop/jiawei/CogVideo/sat/configs/outputs/1_In_the_heart_of_a_bustling_city,_a_young_woman_with_long,_flowing_brown_hair_and_a_radiant_smile_stands_out._She's_donne/0/000000.mp4"

    # An optional first command-line argument overrides the sample clip.
    input_video_path = sys.argv[1] if len(sys.argv) > 1 else default_video_path

    # Download the "AlexWortega/RIFE" repository into the local "model_rife" directory.
    snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
    model = load_rife_model("model_rife")

    # Print whatever rife_inference_with_path returns
    # (presumably the path of the interpolated video - TODO confirm).
    video_path = rife_inference_with_path(model, input_video_path)
    print(video_path)