Update parallel_inference_xdit.py

This commit is contained in:
DefTruth 2024-11-05 20:21:06 +08:00 committed by GitHub
parent 3710a612d8
commit bb69713fbb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -61,11 +61,14 @@ def main():
)
if args.enable_sequential_cpu_offload:
pipe.enable_model_cpu_offload(gpu_id=local_rank)
pipe.vae.enable_tiling()
else:
device = torch.device(f"cuda:{local_rank}")
pipe = pipe.to(device)
# Always enable tiling and slicing to avoid VAE OOM when batch size > 1
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
torch.cuda.reset_peak_memory_stats()
start_time = time.time()