mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-06-05 05:09:20 +08:00
Update rife_model.py
This commit is contained in:
parent
f57945811d
commit
bfac01abcf
@ -46,7 +46,8 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
|
|||||||
_, _, h, w = frame.shape
|
_, _, h, w = frame.shape
|
||||||
I0 = samples[b: b + 1]
|
I0 = samples[b: b + 1]
|
||||||
I1 = samples[b + 1: b + 2] if b + 2 < samples.shape[0] else samples[-1:]
|
I1 = samples[b + 1: b + 2] if b + 2 < samples.shape[0] else samples[-1:]
|
||||||
I1 = pad_image(I1, upscale_amount)
|
I0 = pad_image(I0, upscale_amount).to(torch.float)
|
||||||
|
I1 = pad_image(I1, upscale_amount).to(torch.float)
|
||||||
# [c, h, w]
|
# [c, h, w]
|
||||||
I0_small = F.interpolate(I0, (32, 32), mode="bilinear", align_corners=False)
|
I0_small = F.interpolate(I0, (32, 32), mode="bilinear", align_corners=False)
|
||||||
I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
|
I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
|
||||||
@ -55,7 +56,7 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
|
|||||||
|
|
||||||
if ssim > 0.996:
|
if ssim > 0.996:
|
||||||
I1 = I0
|
I1 = I0
|
||||||
I1 = pad_image(I1, upscale_amount)
|
I1 = I1
|
||||||
I1 = make_inference(model, I0, I1, upscale_amount, 1)
|
I1 = make_inference(model, I0, I1, upscale_amount, 1)
|
||||||
|
|
||||||
I1_small = F.interpolate(I1[0], (32, 32), mode="bilinear", align_corners=False)
|
I1_small = F.interpolate(I1[0], (32, 32), mode="bilinear", align_corners=False)
|
||||||
@ -74,6 +75,7 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
|
|||||||
frame = pad_image(frame, upscale_amount)
|
frame = pad_image(frame, upscale_amount)
|
||||||
tmp_output = [frame] + tmp_output
|
tmp_output = [frame] + tmp_output
|
||||||
for i, frame in enumerate(tmp_output):
|
for i, frame in enumerate(tmp_output):
|
||||||
|
frame = F.interpolate(frame, size=(h, w))
|
||||||
output.append(frame.to(output_device))
|
output.append(frame.to(output_device))
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user