mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
Update parallel_inference_xdit.py
This commit is contained in:
parent
3710a612d8
commit
bb69713fbb
@ -61,11 +61,14 @@ def main():
|
|||||||
)
|
)
|
||||||
if args.enable_sequential_cpu_offload:
|
if args.enable_sequential_cpu_offload:
|
||||||
pipe.enable_model_cpu_offload(gpu_id=local_rank)
|
pipe.enable_model_cpu_offload(gpu_id=local_rank)
|
||||||
pipe.vae.enable_tiling()
|
|
||||||
else:
|
else:
|
||||||
device = torch.device(f"cuda:{local_rank}")
|
device = torch.device(f"cuda:{local_rank}")
|
||||||
pipe = pipe.to(device)
|
pipe = pipe.to(device)
|
||||||
|
|
||||||
|
# Always enable tiling and slicing to avoid VAE OOM while batch size > 1
|
||||||
|
pipe.vae.enable_slicing()
|
||||||
|
pipe.vae.enable_tiling()
|
||||||
|
|
||||||
torch.cuda.reset_peak_memory_stats()
|
torch.cuda.reset_peak_memory_stats()
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user