mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 11:18:35 +08:00
Update parallel_inference_xdit.py
This commit is contained in:
parent
3710a612d8
commit
bb69713fbb
@ -61,11 +61,14 @@ def main():
|
||||
)
|
||||
if args.enable_sequential_cpu_offload:
|
||||
pipe.enable_model_cpu_offload(gpu_id=local_rank)
|
||||
pipe.vae.enable_tiling()
|
||||
else:
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
pipe = pipe.to(device)
|
||||
|
||||
# Always enable tiling and slicing to avoid VAE OOM while batch size > 1
|
||||
pipe.vae.enable_slicing()
|
||||
pipe.vae.enable_tiling()
|
||||
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
start_time = time.time()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user