"""RIFE-based frame interpolation helpers for video tensors and video files."""

import importlib
import torch
from diffusers.image_processor import VaeImageProcessor
from torch.nn import functional as F
import cv2
import utils
from rife.pytorch_msssim import ssim_matlab
import numpy as np
import logging
import os
import skvideo.io

logger = logging.getLogger(__name__)

device = "cuda" if torch.cuda.is_available() else "cpu"


def pad_image(img, scale):
    """Zero-pad a 4-D image batch so height and width are multiples of a block size.

    The block size is ``max(32, int(32 / scale))``. Padding is added only on
    the right and bottom edges, so the original pixels stay at
    ``[..., :h, :w]`` of the result.

    Args:
        img: Tensor whose shape unpacks as ``(_, _, h, w)``.
        scale: Positive scale factor used to derive the block size.

    Returns:
        The padded tensor (same values as ``img`` when it is already aligned).
    """
    _, _, h, w = img.shape
    tmp = max(32, int(32 / scale))
    ph = ((h - 1) // tmp + 1) * tmp
    pw = ((w - 1) // tmp + 1) * tmp
    # F.pad lists pads starting from the LAST dimension:
    # (left, right, top, bottom). Width takes the first pair, height the second.
    padding = (0, pw - w, 0, ph - h)
    return F.pad(img, padding)


def make_inference(model, I0, I1, upscale_amount, n):
    """Recursively build ``n`` in-between frames for the pair ``(I0, I1)``.

    Each level asks ``model.inference`` for one frame between its two inputs
    (presumably the temporal midpoint -- confirm against the model), then
    recurses on the left and right halves.

    Args:
        model: Object exposing ``inference(img0, img1, scale)``.
        I0: First frame of the pair.
        I1: Second frame of the pair.
        upscale_amount: Forwarded unchanged to ``model.inference``.
        n: Number of frames to produce.

    Returns:
        A list of ``n`` frames ordered from ``I0`` towards ``I1``; an empty
        list when ``n`` is not positive.
    """
    # Without this guard n == 0 never reaches the n == 1 base case and the
    # recursion would run (and call the model) until RecursionError.
    if n <= 0:
        return []

    middle = model.inference(I0, I1, upscale_amount)
    if n == 1:
        return [middle]

    first_half = make_inference(model, I0, middle, upscale_amount, n=n // 2)
    second_half = make_inference(model, middle, I1, upscale_amount, n=n // 2)
    if n % 2:
        return [*first_half, middle, *second_half]
    return [*first_half, *second_half]


@torch.inference_mode()
def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_device="cpu", pbar=None):
    """Interpolate between consecutive frames, using SSIM to pick the strategy.

    For every input frame the frame itself (or a re-synthesised version of it,
    see below) is emitted, followed by ``2 ** exp - 1`` extra frames:

    * SSIM with the next frame > 0.996 (near-duplicate): the frame is replaced
      by the model's output for ``(I0, I0)`` and that output is used as the
      interpolation target.
    * SSIM < 0.2 (treated as a hard cut): the extra frames are copies of the
      current frame instead of model output.
    * otherwise: the extra frames come from ``make_inference``.

    The last frame is paired with itself, so it also produces extra frames.

    Args:
        model: RIFE-style model exposing ``inference(img0, img1, scale)``;
            assumed to return frames at its input resolution -- TODO confirm.
        samples: Tensor laid out as ``[f, c, h, w]``.
        exp: Interpolation exponent; ``0`` disables the extra frames.
        upscale_amount: Scale forwarded to ``pad_image`` and the model.
        output_device: Device every output frame is moved to.
        pbar: Optional progress object; ``update(1)`` is called once per input
            frame.

    Returns:
        A list of ``[1, c, h, w]`` tensors on ``output_device``.
    """
    print(f"samples dtype:{samples.dtype}")
    print(f"samples shape:{samples.shape}")

    output = []
    num_frames = samples.shape[0]
    # samples layout: [f, c, h, w]
    for b in range(num_frames):
        frame = samples[b:b + 1]
        _, _, h, w = frame.shape

        # Pair the current frame with its successor; the final frame is
        # paired with itself.
        next_idx = min(b + 1, num_frames - 1)

        # Pad BOTH inputs so the model always receives equally sized tensors,
        # also when h or w is not a multiple of the padding block.
        I0 = pad_image(frame, upscale_amount)
        I1 = pad_image(samples[next_idx:next_idx + 1], upscale_amount)

        # SSIM is evaluated on 32x32 thumbnails of the first three channels.
        I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
        I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
        ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])

        if ssim > 0.996:
            # Near-duplicate neighbour: synthesise a frame from I0 paired with
            # itself and use it both as the emitted frame and as the new target.
            I1 = make_inference(model, I0, I0, upscale_amount, 1)[0]
            I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
            ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
            frame = I1

        if ssim < 0.2:
            # Frames too different to blend: repeat the current frame.
            tmp_output = [I0] * (2 ** exp - 1)
        elif exp:
            tmp_output = make_inference(model, I0, I1, upscale_amount, 2 ** exp - 1)
        else:
            tmp_output = []

        # Emit the frame followed by its in-betweens, cropping away any
        # right/bottom padding so every output has the original (h, w).
        for out_frame in (frame, *tmp_output):
            output.append(out_frame[..., :h, :w].to(output_device))

        if pbar is not None:
            pbar.update(1)
    return output


def load_rife_model(model_path):
    """Load the RIFE HD v3 model from ``model_path`` and put it in eval mode.

    NOTE: this changes global torch state: autograd is disabled process-wide
    and, when CUDA is available, cuDNN benchmarking is enabled and the default
    tensor type becomes ``torch.cuda.FloatTensor``.

    Args:
        model_path: Location passed to ``Model.load_model``.

    Returns:
        The loaded model.
    """
    pbar = utils.ProgressBar(1, desc="Loading RIFE model")
    torch.set_grad_enabled(False)
    if torch.cuda.is_available():
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        # NOTE(review): set_default_tensor_type is deprecated in recent
        # PyTorch releases; kept because the model code may rely on it.
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    # Imported after the torch defaults above have been configured.
    from rife.RIFE_HDv3 import Model

    model = Model()
    model.load_model(model_path, -1)
    print("Loaded v3.x HD model.")
    model.eval()
    model.device()  # presumably moves the networks to the active device -- confirm
    pbar.update(1)  # Update progress by 1
    return model


def frame_generator(video_capture):
    """Yield frames from a ``cv2.VideoCapture``-like object until it runs dry.

    The capture is released when the generator finishes, is closed early, or
    is interrupted by an exception.

    Args:
        video_capture: Object with ``read() -> (ok, frame)`` and ``release()``.

    Yields:
        Each successfully read frame.
    """
    try:
        while True:
            ret, frame = video_capture.read()
            if not ret:
                break
            yield frame
    finally:
        video_capture.release()


def rife_inference_with_path(model, video_path):
    """Interpolate the video at ``video_path`` and save the result.

    Args:
        model: Model accepted by ``ssim_interpolation_rife``.
        video_path: Path of the input video.

    Returns:
        The path returned by ``utils.save_video`` for the interpolated video.

    Raises:
        ValueError: If no frame could be decoded from ``video_path``.
    """
    # OpenCV is only used to read the frame count for the progress bar, so
    # the handle is released straight away.
    video_capture = cv2.VideoCapture(video_path)
    try:
        tot_frame = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        video_capture.release()

    # Decode frames as HWC arrays, convert each to a CHW float tensor in [0, 1].
    pt_frame_data = [
        torch.from_numpy(np.transpose(frame, (2, 0, 1))).to("cpu", non_blocking=True).float() / 255.
        for frame in skvideo.io.vreader(video_path)
    ]
    if not pt_frame_data:
        raise ValueError(f"No frames could be decoded from {video_path!r}")

    pt_frame = torch.stack(pt_frame_data).to(device)

    pbar = utils.ProgressBar(tot_frame, desc="RIFE inference")
    frames = ssim_interpolation_rife(model, pt_frame, pbar=pbar)

    # Each returned frame is [1, c, h, w]; drop the leading axis and stack.
    pt_image = torch.stack([item.squeeze(0) for item in frames])
    image_np = VaeImageProcessor.pt_to_numpy(pt_image)  # (to [49, 512, 480, 3])
    image_pil = VaeImageProcessor.numpy_to_pil(image_np)
    video_path = utils.save_video(image_pil, fps=16)
    print(f"Video saved to {video_path}")
    return video_path


def rife_inference_with_latents(model, latents):
    """Interpolate every clip of a batched frame tensor.

    Args:
        model: Model accepted by ``ssim_interpolation_rife``.
        latents: Tensor indexed as ``[batch, f, c, ., .]``; each ``latents[i]``
            is passed on as one clip.

    Returns:
        A tensor stacking the interpolated clips along a new leading axis.
    """
    # The bar is ticked once per input frame of every clip, so the total must
    # cover the whole batch, not just one clip.
    pbar = utils.ProgressBar(latents.shape[0] * latents.shape[1], desc="RIFE inference")
    latents = latents.to(device)

    rife_results = []
    for latent in latents:
        frames = ssim_interpolation_rife(model, latent, pbar=pbar)
        # Each returned frame is [1, c, ., .]; drop the leading axis and stack.
        rife_results.append(torch.stack([item.squeeze(0) for item in frames]))

    return torch.stack(rife_results)


# Example usage:
# if __name__ == '__main__':
#     model_path = "/media/gpt4-pdf-chatbot-langchain/ECCV2022-RIFE/train_log"
#     video_path = "/media/gpt4-pdf-chatbot-langchain/CogVideo/inference/output/20240823_110325.mp4"
#     model = load_rife_model(model_path)
#     rife_inference_with_path(model, video_path)