mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
commit
d1e45fbb86
@ -61,11 +61,14 @@ def main():
|
||||
)
|
||||
if args.enable_sequential_cpu_offload:
|
||||
pipe.enable_model_cpu_offload(gpu_id=local_rank)
|
||||
pipe.vae.enable_tiling()
|
||||
else:
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
pipe = pipe.to(device)
|
||||
|
||||
# Always enable tiling and slicing to avoid VAE OOM while batch size > 1
|
||||
pipe.vae.enable_slicing()
|
||||
pipe.vae.enable_tiling()
|
||||
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
start_time = time.time()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user