padding fix: crop only non-zero bottom/right padding from interpolated frames

This commit is contained in:
glide-the 2024-10-03 13:25:55 +08:00
parent 824feef38d
commit f0098c0662

View File

@@ -9,7 +9,6 @@ import logging
import skvideo.io
from rife.RIFE_HDv3 import Model
from huggingface_hub import hf_hub_download, snapshot_download
logger = logging.getLogger(__name__)
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -21,7 +20,6 @@ def pad_image(img, scale):
ph = ((h - 1) // tmp + 1) * tmp
pw = ((w - 1) // tmp + 1) * tmp
padding = (0, pw - w, 0, ph - h)
return F.pad(img, padding), padding
@@ -80,7 +78,15 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
# print(f'I1[0] unpadded shape:{I1.shape}')
I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
frame = I1[padding[0] :, padding[2] :, : -padding[3], padding[1] :]
if padding[3] > 0 and padding[1] >0 :
frame = I1[:, :, : -padding[3],:-padding[1]]
elif padding[3] > 0:
frame = I1[:, :, : -padding[3],:]
elif padding[1] >0:
frame = I1[:, :, :,:-padding[1]]
else:
frame = I1
tmp_output = []
if ssim < 0.2:
@@ -96,6 +102,7 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
frame = F.interpolate(frame, size=(h, w))
output.append(frame.to(output_device))
for i, tmp_frame in enumerate(tmp_output):
# tmp_frame, _ = pad_image(tmp_frame, upscale_amount)
tmp_frame = F.interpolate(tmp_frame, size=(h, w))
output.append(tmp_frame.to(output_device))
@@ -138,7 +145,9 @@ def rife_inference_with_path(model, video_path):
frame_rgb = frame[..., ::-1]
frame_rgb = frame_rgb.copy()
tensor = torch.from_numpy(frame_rgb).float().to("cpu", non_blocking=True).float() / 255.0
pt_frame_data.append(tensor.permute(2, 0, 1)) # to [c, h, w,]
pt_frame_data.append(
tensor.permute(2, 0, 1)
) # to [c, h, w,]
pt_frame = torch.from_numpy(np.stack(pt_frame_data))
pt_frame = pt_frame.to(device)
@@ -167,9 +176,9 @@ def rife_inference_with_latents(model, latents):
return torch.stack(rife_results)
if __name__ == "__main__":
snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
model = load_rife_model("model_rife")
# if __name__ == "__main__":
# snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
# model = load_rife_model("model_rife")
video_path = rife_inference_with_path(model, "/mnt/ceph/develop/jiawei/CogVideo/output/chunk_3710_1.mp4")
print(video_path)
# video_path = rife_inference_with_path(model, "/mnt/ceph/develop/jiawei/CogVideo/output/20241003_130720.mp4")
# print(video_path)