diff --git a/inference/gradio_composite_demo/rife_model.py b/inference/gradio_composite_demo/rife_model.py index 0a69ca6..84a4fb6 100644 --- a/inference/gradio_composite_demo/rife_model.py +++ b/inference/gradio_composite_demo/rife_model.py @@ -42,11 +42,12 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi output = [] # [f, c, h, w] for b in range(samples.shape[0]): - frame = samples[b : b + 1] + frame = samples[b: b + 1] _, _, h, w = frame.shape - I0 = samples[b : b + 1] - I1 = samples[b + 1 : b + 2] if b + 2 < samples.shape[0] else samples[-1:] - I1 = pad_image(I1, upscale_amount) + I0 = samples[b: b + 1] + I1 = samples[b + 1: b + 2] if b + 2 < samples.shape[0] else samples[-1:] + I0 = pad_image(I0, upscale_amount).to(torch.float) + I1 = pad_image(I1, upscale_amount).to(torch.float) # [c, h, w] I0_small = F.interpolate(I0, (32, 32), mode="bilinear", align_corners=False) I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False) @@ -55,7 +56,7 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi if ssim > 0.996: I1 = I0 - I1 = pad_image(I1, upscale_amount) + I1 = I1 I1 = make_inference(model, I0, I1, upscale_amount, 1) I1_small = F.interpolate(I1[0], (32, 32), mode="bilinear", align_corners=False) @@ -65,15 +66,16 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi tmp_output = [] if ssim < 0.2: - for i in range((2**exp) - 1): + for i in range((2 ** exp) - 1): tmp_output.append(I0) else: - tmp_output = make_inference(model, I0, I1, upscale_amount, 2**exp - 1) if exp else [] + tmp_output = make_inference(model, I0, I1, upscale_amount, 2 ** exp - 1) if exp else [] frame = pad_image(frame, upscale_amount) tmp_output = [frame] + tmp_output for i, frame in enumerate(tmp_output): + frame = F.interpolate(frame, size=(h, w)) output.append(frame.to(output_device)) return output