mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-06 03:57:56 +08:00
padding fix
This commit is contained in:
parent
824feef38d
commit
f0098c0662
@ -9,7 +9,6 @@ import logging
|
|||||||
import skvideo.io
|
import skvideo.io
|
||||||
from rife.RIFE_HDv3 import Model
|
from rife.RIFE_HDv3 import Model
|
||||||
from huggingface_hub import hf_hub_download, snapshot_download
|
from huggingface_hub import hf_hub_download, snapshot_download
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
@ -20,8 +19,7 @@ def pad_image(img, scale):
|
|||||||
tmp = max(32, int(32 / scale))
|
tmp = max(32, int(32 / scale))
|
||||||
ph = ((h - 1) // tmp + 1) * tmp
|
ph = ((h - 1) // tmp + 1) * tmp
|
||||||
pw = ((w - 1) // tmp + 1) * tmp
|
pw = ((w - 1) // tmp + 1) * tmp
|
||||||
padding = (0, pw - w, 0, ph - h)
|
padding = (0, pw - w, 0, ph - h)
|
||||||
|
|
||||||
return F.pad(img, padding), padding
|
return F.pad(img, padding), padding
|
||||||
|
|
||||||
|
|
||||||
@ -47,15 +45,15 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
|
|||||||
for b in range(samples.shape[0]):
|
for b in range(samples.shape[0]):
|
||||||
frame = samples[b : b + 1]
|
frame = samples[b : b + 1]
|
||||||
_, _, h, w = frame.shape
|
_, _, h, w = frame.shape
|
||||||
|
|
||||||
I0 = samples[b : b + 1]
|
I0 = samples[b : b + 1]
|
||||||
I1 = samples[b + 1 : b + 2] if b + 2 < samples.shape[0] else samples[-1:]
|
I1 = samples[b + 1 : b + 2] if b + 2 < samples.shape[0] else samples[-1:]
|
||||||
|
|
||||||
I0, padding = pad_image(I0, upscale_amount)
|
I0, padding = pad_image(I0, upscale_amount)
|
||||||
I0 = I0.to(torch.float)
|
I0 = I0.to(torch.float)
|
||||||
I1, _ = pad_image(I1, upscale_amount)
|
I1, _ = pad_image(I1, upscale_amount)
|
||||||
I1 = I1.to(torch.float)
|
I1 = I1.to(torch.float)
|
||||||
|
|
||||||
# [c, h, w]
|
# [c, h, w]
|
||||||
I0_small = F.interpolate(I0, (32, 32), mode="bilinear", align_corners=False)
|
I0_small = F.interpolate(I0, (32, 32), mode="bilinear", align_corners=False)
|
||||||
I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
|
I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
|
||||||
@ -72,15 +70,23 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
|
|||||||
# print(f'I0 shape:{I0.shape}')
|
# print(f'I0 shape:{I0.shape}')
|
||||||
# print(f'I1 shape:{I1.shape}')
|
# print(f'I1 shape:{I1.shape}')
|
||||||
I1 = make_inference(model, I0, I1, upscale_amount, 1)
|
I1 = make_inference(model, I0, I1, upscale_amount, 1)
|
||||||
|
|
||||||
# print(f'I0 shape:{I0.shape}')
|
# print(f'I0 shape:{I0.shape}')
|
||||||
# print(f'I1[0] shape:{I1[0].shape}')
|
# print(f'I1[0] shape:{I1[0].shape}')
|
||||||
I1 = I1[0]
|
I1 = I1[0]
|
||||||
|
|
||||||
# print(f'I1[0] unpadded shape:{I1.shape}')
|
# print(f'I1[0] unpadded shape:{I1.shape}')
|
||||||
I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
|
I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
|
||||||
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
|
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
|
||||||
frame = I1[padding[0] :, padding[2] :, : -padding[3], padding[1] :]
|
if padding[3] > 0 and padding[1] >0 :
|
||||||
|
|
||||||
|
frame = I1[:, :, : -padding[3],:-padding[1]]
|
||||||
|
elif padding[3] > 0:
|
||||||
|
frame = I1[:, :, : -padding[3],:]
|
||||||
|
elif padding[1] >0:
|
||||||
|
frame = I1[:, :, :,:-padding[1]]
|
||||||
|
else:
|
||||||
|
frame = I1
|
||||||
|
|
||||||
tmp_output = []
|
tmp_output = []
|
||||||
if ssim < 0.2:
|
if ssim < 0.2:
|
||||||
@ -95,7 +101,8 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
|
|||||||
|
|
||||||
frame = F.interpolate(frame, size=(h, w))
|
frame = F.interpolate(frame, size=(h, w))
|
||||||
output.append(frame.to(output_device))
|
output.append(frame.to(output_device))
|
||||||
for i, tmp_frame in enumerate(tmp_output):
|
for i, tmp_frame in enumerate(tmp_output):
|
||||||
|
|
||||||
# tmp_frame, _ = pad_image(tmp_frame, upscale_amount)
|
# tmp_frame, _ = pad_image(tmp_frame, upscale_amount)
|
||||||
tmp_frame = F.interpolate(tmp_frame, size=(h, w))
|
tmp_frame = F.interpolate(tmp_frame, size=(h, w))
|
||||||
output.append(tmp_frame.to(output_device))
|
output.append(tmp_frame.to(output_device))
|
||||||
@ -138,7 +145,9 @@ def rife_inference_with_path(model, video_path):
|
|||||||
frame_rgb = frame[..., ::-1]
|
frame_rgb = frame[..., ::-1]
|
||||||
frame_rgb = frame_rgb.copy()
|
frame_rgb = frame_rgb.copy()
|
||||||
tensor = torch.from_numpy(frame_rgb).float().to("cpu", non_blocking=True).float() / 255.0
|
tensor = torch.from_numpy(frame_rgb).float().to("cpu", non_blocking=True).float() / 255.0
|
||||||
pt_frame_data.append(tensor.permute(2, 0, 1)) # to [c, h, w,]
|
pt_frame_data.append(
|
||||||
|
tensor.permute(2, 0, 1)
|
||||||
|
) # to [c, h, w,]
|
||||||
|
|
||||||
pt_frame = torch.from_numpy(np.stack(pt_frame_data))
|
pt_frame = torch.from_numpy(np.stack(pt_frame_data))
|
||||||
pt_frame = pt_frame.to(device)
|
pt_frame = pt_frame.to(device)
|
||||||
@ -167,9 +176,9 @@ def rife_inference_with_latents(model, latents):
|
|||||||
return torch.stack(rife_results)
|
return torch.stack(rife_results)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
# if __name__ == "__main__":
|
||||||
snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
|
# snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
|
||||||
model = load_rife_model("model_rife")
|
# model = load_rife_model("model_rife")
|
||||||
|
|
||||||
video_path = rife_inference_with_path(model, "/mnt/ceph/develop/jiawei/CogVideo/output/chunk_3710_1.mp4")
|
# video_path = rife_inference_with_path(model, "/mnt/ceph/develop/jiawei/CogVideo/output/20241003_130720.mp4")
|
||||||
print(video_path)
|
# print(video_path)
|
Loading…
x
Reference in New Issue
Block a user