mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
Update rife_model.py
This commit is contained in:
parent
f57945811d
commit
bfac01abcf
@ -42,11 +42,12 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
|
||||
output = []
|
||||
# [f, c, h, w]
|
||||
for b in range(samples.shape[0]):
|
||||
frame = samples[b : b + 1]
|
||||
frame = samples[b: b + 1]
|
||||
_, _, h, w = frame.shape
|
||||
I0 = samples[b : b + 1]
|
||||
I1 = samples[b + 1 : b + 2] if b + 2 < samples.shape[0] else samples[-1:]
|
||||
I1 = pad_image(I1, upscale_amount)
|
||||
I0 = samples[b: b + 1]
|
||||
I1 = samples[b + 1: b + 2] if b + 2 < samples.shape[0] else samples[-1:]
|
||||
I0 = pad_image(I0, upscale_amount).to(torch.float)
|
||||
I1 = pad_image(I1, upscale_amount).to(torch.float)
|
||||
# [c, h, w]
|
||||
I0_small = F.interpolate(I0, (32, 32), mode="bilinear", align_corners=False)
|
||||
I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
|
||||
@ -55,7 +56,7 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
|
||||
|
||||
if ssim > 0.996:
|
||||
I1 = I0
|
||||
I1 = pad_image(I1, upscale_amount)
|
||||
I1 = I1
|
||||
I1 = make_inference(model, I0, I1, upscale_amount, 1)
|
||||
|
||||
I1_small = F.interpolate(I1[0], (32, 32), mode="bilinear", align_corners=False)
|
||||
@ -65,15 +66,16 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
|
||||
|
||||
tmp_output = []
|
||||
if ssim < 0.2:
|
||||
for i in range((2**exp) - 1):
|
||||
for i in range((2 ** exp) - 1):
|
||||
tmp_output.append(I0)
|
||||
|
||||
else:
|
||||
tmp_output = make_inference(model, I0, I1, upscale_amount, 2**exp - 1) if exp else []
|
||||
tmp_output = make_inference(model, I0, I1, upscale_amount, 2 ** exp - 1) if exp else []
|
||||
|
||||
frame = pad_image(frame, upscale_amount)
|
||||
tmp_output = [frame] + tmp_output
|
||||
for i, frame in enumerate(tmp_output):
|
||||
frame = F.interpolate(frame, size=(h, w))
|
||||
output.append(frame.to(output_device))
|
||||
return output
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user