From 0ae12e3ea304d0adff1dc322c6915562cac7e2d7 Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Tue, 5 Nov 2024 22:06:05 +0800 Subject: [PATCH] use original up.upsample --- sat/vae_modules/cp_enc_dec.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sat/vae_modules/cp_enc_dec.py b/sat/vae_modules/cp_enc_dec.py index 15f3356..1d9c34f 100644 --- a/sat/vae_modules/cp_enc_dec.py +++ b/sat/vae_modules/cp_enc_dec.py @@ -960,10 +960,10 @@ class ContextParallelDecoder3D(nn.Module): up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True) else: up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False) - # if i_level < self.num_resolutions - self.temporal_compress_level: - # up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False) - # else: - # up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True) + if i_level < self.num_resolutions - self.temporal_compress_level: + up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False) + else: + up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True) self.up.insert(0, up) self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm)