mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-09-19 03:55:53 +08:00
修复padding后帧变化太大引起的图片大小被拉伸
This commit is contained in:
parent
3fb5631b76
commit
5f7f2e424a
@ -8,7 +8,7 @@ import numpy as np
|
|||||||
import logging
|
import logging
|
||||||
import skvideo.io
|
import skvideo.io
|
||||||
from rife.RIFE_HDv3 import Model
|
from rife.RIFE_HDv3 import Model
|
||||||
|
from huggingface_hub import hf_hub_download, snapshot_download
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
@ -20,7 +20,8 @@ def pad_image(img, scale):
|
|||||||
ph = ((h - 1) // tmp + 1) * tmp
|
ph = ((h - 1) // tmp + 1) * tmp
|
||||||
pw = ((w - 1) // tmp + 1) * tmp
|
pw = ((w - 1) // tmp + 1) * tmp
|
||||||
padding = (0, pw - w, 0, ph - h)
|
padding = (0, pw - w, 0, ph - h)
|
||||||
return F.pad(img, padding)
|
|
||||||
|
return F.pad(img, padding), padding
|
||||||
|
|
||||||
|
|
||||||
def make_inference(model, I0, I1, upscale_amount, n):
|
def make_inference(model, I0, I1, upscale_amount, n):
|
||||||
@ -44,10 +45,15 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
|
|||||||
for b in range(samples.shape[0]):
|
for b in range(samples.shape[0]):
|
||||||
frame = samples[b : b + 1]
|
frame = samples[b : b + 1]
|
||||||
_, _, h, w = frame.shape
|
_, _, h, w = frame.shape
|
||||||
|
|
||||||
I0 = samples[b : b + 1]
|
I0 = samples[b : b + 1]
|
||||||
I1 = samples[b + 1 : b + 2] if b + 2 < samples.shape[0] else samples[-1:]
|
I1 = samples[b + 1 : b + 2] if b + 2 < samples.shape[0] else samples[-1:]
|
||||||
I0 = pad_image(I0, upscale_amount).to(torch.float)
|
|
||||||
I1 = pad_image(I1, upscale_amount).to(torch.float)
|
I0, padding = pad_image(I0, upscale_amount)
|
||||||
|
I0 = I0.to(torch.float)
|
||||||
|
I1, _ = pad_image(I1, upscale_amount)
|
||||||
|
I1 = I1.to(torch.float)
|
||||||
|
|
||||||
# [c, h, w]
|
# [c, h, w]
|
||||||
I0_small = F.interpolate(I0, (32, 32), mode="bilinear", align_corners=False)
|
I0_small = F.interpolate(I0, (32, 32), mode="bilinear", align_corners=False)
|
||||||
I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
|
I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
|
||||||
@ -55,15 +61,25 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
|
|||||||
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
|
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
|
||||||
|
|
||||||
if ssim > 0.996:
|
if ssim > 0.996:
|
||||||
I1 = I0
|
I1 = samples[b : b + 1]
|
||||||
I1 = I1
|
# print(f'upscale_amount:{upscale_amount}')
|
||||||
|
# print(f'ssim:{upscale_amount}')
|
||||||
|
# print(f'I0 shape:{I0.shape}')
|
||||||
|
# print(f'I1 shape:{I1.shape}')
|
||||||
|
I1, padding = pad_image(I1, upscale_amount)
|
||||||
|
# print(f'I0 shape:{I0.shape}')
|
||||||
|
# print(f'I1 shape:{I1.shape}')
|
||||||
I1 = make_inference(model, I0, I1, upscale_amount, 1)
|
I1 = make_inference(model, I0, I1, upscale_amount, 1)
|
||||||
|
|
||||||
I1_small = F.interpolate(I1[0], (32, 32), mode="bilinear", align_corners=False)
|
# print(f'I0 shape:{I0.shape}')
|
||||||
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
|
# print(f'I1[0] shape:{I1[0].shape}')
|
||||||
frame = I1[0]
|
|
||||||
I1 = I1[0]
|
I1 = I1[0]
|
||||||
|
|
||||||
|
# print(f'I1[0] unpadded shape:{I1.shape}')
|
||||||
|
I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
|
||||||
|
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
|
||||||
|
frame = I1[padding[0]:, padding[2]:, padding[3]:,padding[1]:]
|
||||||
|
|
||||||
tmp_output = []
|
tmp_output = []
|
||||||
if ssim < 0.2:
|
if ssim < 0.2:
|
||||||
for i in range((2**exp) - 1):
|
for i in range((2**exp) - 1):
|
||||||
@ -72,8 +88,11 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
|
|||||||
else:
|
else:
|
||||||
tmp_output = make_inference(model, I0, I1, upscale_amount, 2**exp - 1) if exp else []
|
tmp_output = make_inference(model, I0, I1, upscale_amount, 2**exp - 1) if exp else []
|
||||||
|
|
||||||
frame = pad_image(frame, upscale_amount)
|
frame, _ = pad_image(frame, upscale_amount)
|
||||||
|
print(f'frame shape:{frame.shape}')
|
||||||
|
print(f'tmp_output[0] shape:{tmp_output[0].shape}')
|
||||||
tmp_output = [frame] + tmp_output
|
tmp_output = [frame] + tmp_output
|
||||||
|
|
||||||
for i, frame in enumerate(tmp_output):
|
for i, frame in enumerate(tmp_output):
|
||||||
frame = F.interpolate(frame, size=(h, w))
|
frame = F.interpolate(frame, size=(h, w))
|
||||||
output.append(frame.to(output_device))
|
output.append(frame.to(output_device))
|
||||||
@ -98,11 +117,19 @@ def frame_generator(video_capture):
|
|||||||
|
|
||||||
|
|
||||||
def rife_inference_with_path(model, video_path):
|
def rife_inference_with_path(model, video_path):
|
||||||
|
# Open the video file
|
||||||
video_capture = cv2.VideoCapture(video_path)
|
video_capture = cv2.VideoCapture(video_path)
|
||||||
tot_frame = video_capture.get(cv2.CAP_PROP_FRAME_COUNT)
|
fps = video_capture.get(cv2.CAP_PROP_FPS) # Get the frames per second
|
||||||
|
tot_frame = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT)) # Total frames in the video
|
||||||
pt_frame_data = []
|
pt_frame_data = []
|
||||||
pt_frame = skvideo.io.vreader(video_path)
|
pt_frame = skvideo.io.vreader(video_path)
|
||||||
for frame in pt_frame:
|
# Cyclic reading of the video frames
|
||||||
|
while video_capture.isOpened():
|
||||||
|
ret, frame = video_capture.read()
|
||||||
|
|
||||||
|
if not ret:
|
||||||
|
break
|
||||||
|
|
||||||
# BGR to RGB
|
# BGR to RGB
|
||||||
frame_rgb = frame[..., ::-1]
|
frame_rgb = frame[..., ::-1]
|
||||||
frame_rgb = frame_rgb.copy()
|
frame_rgb = frame_rgb.copy()
|
||||||
@ -137,3 +164,11 @@ def rife_inference_with_latents(model, latents):
|
|||||||
rife_results.append(pt_image)
|
rife_results.append(pt_image)
|
||||||
|
|
||||||
return torch.stack(rife_results)
|
return torch.stack(rife_results)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
|
||||||
|
model = load_rife_model("model_rife")
|
||||||
|
|
||||||
|
video_path = rife_inference_with_path(model, "/mnt/ceph/develop/jiawei/CogVideo/sat/configs/outputs/1_In_the_heart_of_a_bustling_city,_a_young_woman_with_long,_flowing_brown_hair_and_a_radiant_smile_stands_out._She's_donne/0/000000.mp4")
|
||||||
|
print(video_path)
|
Loading…
x
Reference in New Issue
Block a user