Update rife_model.py

This commit is contained in:
zR 2024-08-27 20:10:23 +08:00
parent f57945811d
commit bfac01abcf

View File

@ -42,11 +42,12 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
output = []
# [f, c, h, w]
for b in range(samples.shape[0]):
frame = samples[b : b + 1]
frame = samples[b: b + 1]
_, _, h, w = frame.shape
I0 = samples[b : b + 1]
I1 = samples[b + 1 : b + 2] if b + 2 < samples.shape[0] else samples[-1:]
I1 = pad_image(I1, upscale_amount)
I0 = samples[b: b + 1]
I1 = samples[b + 1: b + 2] if b + 2 < samples.shape[0] else samples[-1:]
I0 = pad_image(I0, upscale_amount).to(torch.float)
I1 = pad_image(I1, upscale_amount).to(torch.float)
# [c, h, w]
I0_small = F.interpolate(I0, (32, 32), mode="bilinear", align_corners=False)
I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
@ -55,7 +56,7 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
if ssim > 0.996:
I1 = I0
I1 = pad_image(I1, upscale_amount)
I1 = I1
I1 = make_inference(model, I0, I1, upscale_amount, 1)
I1_small = F.interpolate(I1[0], (32, 32), mode="bilinear", align_corners=False)
@ -65,15 +66,16 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
tmp_output = []
if ssim < 0.2:
for i in range((2**exp) - 1):
for i in range((2 ** exp) - 1):
tmp_output.append(I0)
else:
tmp_output = make_inference(model, I0, I1, upscale_amount, 2**exp - 1) if exp else []
tmp_output = make_inference(model, I0, I1, upscale_amount, 2 ** exp - 1) if exp else []
frame = pad_image(frame, upscale_amount)
tmp_output = [frame] + tmp_output
for i, frame in enumerate(tmp_output):
frame = F.interpolate(frame, size=(h, w))
output.append(frame.to(output_device))
return output