Update rife_model.py

This commit is contained in:
zR 2024-08-27 20:10:23 +08:00
parent f57945811d
commit bfac01abcf

View File

@ -46,7 +46,8 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
_, _, h, w = frame.shape _, _, h, w = frame.shape
I0 = samples[b: b + 1] I0 = samples[b: b + 1]
I1 = samples[b + 1: b + 2] if b + 2 < samples.shape[0] else samples[-1:] I1 = samples[b + 1: b + 2] if b + 2 < samples.shape[0] else samples[-1:]
I1 = pad_image(I1, upscale_amount) I0 = pad_image(I0, upscale_amount).to(torch.float)
I1 = pad_image(I1, upscale_amount).to(torch.float)
# [c, h, w] # [c, h, w]
I0_small = F.interpolate(I0, (32, 32), mode="bilinear", align_corners=False) I0_small = F.interpolate(I0, (32, 32), mode="bilinear", align_corners=False)
I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False) I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
@ -55,7 +56,7 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
if ssim > 0.996: if ssim > 0.996:
I1 = I0 I1 = I0
I1 = pad_image(I1, upscale_amount) I1 = I1
I1 = make_inference(model, I0, I1, upscale_amount, 1) I1 = make_inference(model, I0, I1, upscale_amount, 1)
I1_small = F.interpolate(I1[0], (32, 32), mode="bilinear", align_corners=False) I1_small = F.interpolate(I1[0], (32, 32), mode="bilinear", align_corners=False)
@ -74,6 +75,7 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
frame = pad_image(frame, upscale_amount) frame = pad_image(frame, upscale_amount)
tmp_output = [frame] + tmp_output tmp_output = [frame] + tmp_output
for i, frame in enumerate(tmp_output): for i, frame in enumerate(tmp_output):
frame = F.interpolate(frame, size=(h, w))
output.append(frame.to(output_device)) output.append(frame.to(output_device))
return output return output