mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
rife
This commit is contained in:
parent
628f736628
commit
824feef38d
@ -8,8 +8,10 @@ import numpy as np
|
||||
import logging
|
||||
import skvideo.io
|
||||
from rife.RIFE_HDv3 import Model
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
@ -18,8 +20,9 @@ def pad_image(img, scale):
|
||||
tmp = max(32, int(32 / scale))
|
||||
ph = ((h - 1) // tmp + 1) * tmp
|
||||
pw = ((w - 1) // tmp + 1) * tmp
|
||||
padding = (0, 0, pw - w, ph - h)
|
||||
return F.pad(img, padding)
|
||||
padding = (0, pw - w, 0, ph - h)
|
||||
|
||||
return F.pad(img, padding), padding
|
||||
|
||||
|
||||
def make_inference(model, I0, I1, upscale_amount, n):
|
||||
@ -36,15 +39,23 @@ def make_inference(model, I0, I1, upscale_amount, n):
|
||||
|
||||
@torch.inference_mode()
|
||||
def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_device="cpu"):
|
||||
|
||||
print(f"samples dtype:{samples.dtype}")
|
||||
print(f"samples shape:{samples.shape}")
|
||||
output = []
|
||||
pbar = utils.ProgressBar(samples.shape[0], desc="RIFE inference")
|
||||
# [f, c, h, w]
|
||||
for b in range(samples.shape[0]):
|
||||
frame = samples[b : b + 1]
|
||||
_, _, h, w = frame.shape
|
||||
|
||||
I0 = samples[b : b + 1]
|
||||
I1 = samples[b + 1 : b + 2] if b + 2 < samples.shape[0] else samples[-1:]
|
||||
I1 = pad_image(I1, upscale_amount)
|
||||
|
||||
I0, padding = pad_image(I0, upscale_amount)
|
||||
I0 = I0.to(torch.float)
|
||||
I1, _ = pad_image(I1, upscale_amount)
|
||||
I1 = I1.to(torch.float)
|
||||
|
||||
# [c, h, w]
|
||||
I0_small = F.interpolate(I0, (32, 32), mode="bilinear", align_corners=False)
|
||||
I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
|
||||
@ -52,15 +63,25 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
|
||||
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
|
||||
|
||||
if ssim > 0.996:
|
||||
I1 = I0
|
||||
I1 = pad_image(I1, upscale_amount)
|
||||
I1 = samples[b : b + 1]
|
||||
# print(f'upscale_amount:{upscale_amount}')
|
||||
# print(f'ssim:{upscale_amount}')
|
||||
# print(f'I0 shape:{I0.shape}')
|
||||
# print(f'I1 shape:{I1.shape}')
|
||||
I1, padding = pad_image(I1, upscale_amount)
|
||||
# print(f'I0 shape:{I0.shape}')
|
||||
# print(f'I1 shape:{I1.shape}')
|
||||
I1 = make_inference(model, I0, I1, upscale_amount, 1)
|
||||
|
||||
I1_small = F.interpolate(I1[0], (32, 32), mode="bilinear", align_corners=False)
|
||||
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
|
||||
frame = I1[0]
|
||||
# print(f'I0 shape:{I0.shape}')
|
||||
# print(f'I1[0] shape:{I1[0].shape}')
|
||||
I1 = I1[0]
|
||||
|
||||
# print(f'I1[0] unpadded shape:{I1.shape}')
|
||||
I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
|
||||
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
|
||||
frame = I1[padding[0] :, padding[2] :, : -padding[3], padding[1] :]
|
||||
|
||||
tmp_output = []
|
||||
if ssim < 0.2:
|
||||
for i in range((2**exp) - 1):
|
||||
@ -69,10 +90,16 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
|
||||
else:
|
||||
tmp_output = make_inference(model, I0, I1, upscale_amount, 2**exp - 1) if exp else []
|
||||
|
||||
frame = pad_image(frame, upscale_amount)
|
||||
tmp_output = [frame] + tmp_output
|
||||
for i, frame in enumerate(tmp_output):
|
||||
output.append(frame.to(output_device))
|
||||
frame, _ = pad_image(frame, upscale_amount)
|
||||
# print(f'frame shape:{frame.shape}')
|
||||
|
||||
frame = F.interpolate(frame, size=(h, w))
|
||||
output.append(frame.to(output_device))
|
||||
for i, tmp_frame in enumerate(tmp_output):
|
||||
# tmp_frame, _ = pad_image(tmp_frame, upscale_amount)
|
||||
tmp_frame = F.interpolate(tmp_frame, size=(h, w))
|
||||
output.append(tmp_frame.to(output_device))
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
@ -94,14 +121,24 @@ def frame_generator(video_capture):
|
||||
|
||||
|
||||
def rife_inference_with_path(model, video_path):
|
||||
# Open the video file
|
||||
video_capture = cv2.VideoCapture(video_path)
|
||||
tot_frame = video_capture.get(cv2.CAP_PROP_FRAME_COUNT)
|
||||
fps = video_capture.get(cv2.CAP_PROP_FPS) # Get the frames per second
|
||||
tot_frame = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT)) # Total frames in the video
|
||||
pt_frame_data = []
|
||||
pt_frame = skvideo.io.vreader(video_path)
|
||||
for frame in pt_frame:
|
||||
pt_frame_data.append(
|
||||
torch.from_numpy(np.transpose(frame, (2, 0, 1))).to("cpu", non_blocking=True).float() / 255.0
|
||||
)
|
||||
# Cyclic reading of the video frames
|
||||
while video_capture.isOpened():
|
||||
ret, frame = video_capture.read()
|
||||
|
||||
if not ret:
|
||||
break
|
||||
|
||||
# BGR to RGB
|
||||
frame_rgb = frame[..., ::-1]
|
||||
frame_rgb = frame_rgb.copy()
|
||||
tensor = torch.from_numpy(frame_rgb).float().to("cpu", non_blocking=True).float() / 255.0
|
||||
pt_frame_data.append(tensor.permute(2, 0, 1)) # to [c, h, w,]
|
||||
|
||||
pt_frame = torch.from_numpy(np.stack(pt_frame_data))
|
||||
pt_frame = pt_frame.to(device)
|
||||
@ -122,8 +159,17 @@ def rife_inference_with_latents(model, latents):
|
||||
for i in range(latents.size(0)):
|
||||
# [f, c, w, h]
|
||||
latent = latents[i]
|
||||
|
||||
frames = ssim_interpolation_rife(model, latent)
|
||||
pt_image = torch.stack([frames[i].squeeze(0) for i in range(len(frames))]) # (to [f, c, w, h])
|
||||
rife_results.append(pt_image)
|
||||
|
||||
return torch.stack(rife_results)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
|
||||
model = load_rife_model("model_rife")
|
||||
|
||||
video_path = rife_inference_with_path(model, "/mnt/ceph/develop/jiawei/CogVideo/output/chunk_3710_1.mp4")
|
||||
print(video_path)
|
||||
|
Loading…
x
Reference in New Issue
Block a user