Merge branch 'CogVideoX_dev' of github.com:THUDM/CogVideo into CogVideoX_dev

This commit is contained in:
zR 2024-11-08 11:52:56 +08:00
commit 494296b063

View File

@ -61,11 +61,14 @@ def main():
)
if args.enable_sequential_cpu_offload:
pipe.enable_model_cpu_offload(gpu_id=local_rank)
pipe.vae.enable_tiling()
else:
device = torch.device(f"cuda:{local_rank}")
pipe = pipe.to(device)
# Always enable tiling and slicing to avoid VAE OOM while batch size > 1
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
torch.cuda.reset_peak_memory_stats()
start_time = time.time()