Merge pull request #468 from THUDM/main

merge
This commit is contained in:
Yuxuan.Zhang 2024-11-07 23:59:21 +08:00 committed by GitHub
commit d1e45fbb86
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -61,11 +61,14 @@ def main():
)
if args.enable_sequential_cpu_offload:
pipe.enable_model_cpu_offload(gpu_id=local_rank)
pipe.vae.enable_tiling()
else:
device = torch.device(f"cuda:{local_rank}")
pipe = pipe.to(device)
# Always enable tiling and slicing to avoid VAE OOM while batch size > 1
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
torch.cuda.reset_peak_memory_stats()
start_time = time.time()