Merge pull request #462 from DefTruth/main

[Parallel] Avoid OOM while batch size > 1
Yuxuan.Zhang 2024-11-06 11:03:42 +08:00 committed by GitHub
commit 4aebdb4b66
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@@ -61,11 +61,14 @@ def main():
     )
     if args.enable_sequential_cpu_offload:
         pipe.enable_model_cpu_offload(gpu_id=local_rank)
-        pipe.vae.enable_tiling()
     else:
         device = torch.device(f"cuda:{local_rank}")
         pipe = pipe.to(device)
 
+    # Always enable tiling and slicing to avoid VAE OOM while batch size > 1
+    pipe.vae.enable_slicing()
+    pipe.vae.enable_tiling()
+
     torch.cuda.reset_peak_memory_stats()
     start_time = time.time()
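
For context, the change relies on two memory-saving switches on the diffusers CogVideoX VAE: enable_slicing() decodes the latent batch one sample at a time, and enable_tiling() decodes each sample in spatial tiles, so peak VAE memory no longer grows with batch size. Below is a minimal standalone sketch of the same idea outside the parallel xDiT script; the model id, prompt, and generation parameters are illustrative assumptions, not taken from this PR.

import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# Illustrative model id and dtype; the PR's script configures these via its own args.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Always enable both switches so VAE decoding does not OOM when batch size > 1:
# slicing decodes one latent per step instead of the whole batch at once,
# tiling splits each latent into smaller spatial tiles before decoding.
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()

torch.cuda.reset_peak_memory_stats()
videos = pipe(
    prompt="A panda playing a guitar in a bamboo forest",
    num_videos_per_prompt=2,  # batch size > 1 is what used to trigger the VAE OOM
    num_inference_steps=50,
    num_frames=49,
    guidance_scale=6.0,
).frames

peak_gb = torch.cuda.max_memory_allocated() / 1024**3
print(f"peak CUDA memory: {peak_gb:.2f} GiB")
for i, frames in enumerate(videos):
    export_to_video(frames, f"output_{i}.mp4", fps=8)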