""" This is a parallel inference script for CogVideo. The original script can be found from the xDiT project at https://github.com/xdit-project/xDiT/blob/main/examples/cogvideox_example.py By using this code, the inference process is parallelized on multiple GPUs, and thus speeded up. Usage: 1. pip install xfuser 2. mkdir results 3. run the following command to generate video torchrun --nproc_per_node=4 parallel_inference_xdit.py \ --model --ulysses_degree 1 --ring_degree 2 \ --use_cfg_parallel --height 480 --width 720 --num_frames 9 \ --prompt 'A small dog.' You can also use the run.sh file in the same folder to automate running this code for batch generation of videos, by running: sh ./run.sh """ import time import torch import torch.distributed from diffusers import AutoencoderKLTemporalDecoder from xfuser import xFuserCogVideoXPipeline, xFuserArgs from xfuser.config import FlexibleArgumentParser from xfuser.core.distributed import ( get_world_group, get_data_parallel_rank, get_data_parallel_world_size, get_runtime_state, is_dp_last_group, ) from diffusers.utils import export_to_video def main(): parser = FlexibleArgumentParser(description="xFuser Arguments") args = xFuserArgs.add_cli_args(parser).parse_args() engine_args = xFuserArgs.from_cli_args(args) # Check if ulysses_degree is valid num_heads = 30 if engine_args.ulysses_degree > 0 and num_heads % engine_args.ulysses_degree != 0: raise ValueError( f"ulysses_degree ({engine_args.ulysses_degree}) must be a divisor of the number of heads ({num_heads})" ) engine_config, input_config = engine_args.create_config() local_rank = get_world_group().local_rank pipe = xFuserCogVideoXPipeline.from_pretrained( pretrained_model_name_or_path=engine_config.model_config.model, engine_config=engine_config, torch_dtype=torch.bfloat16, ) if args.enable_sequential_cpu_offload: pipe.enable_model_cpu_offload(gpu_id=local_rank) else: device = torch.device(f"cuda:{local_rank}") pipe = pipe.to(device) # Always enable tiling and slicing to avoid VAE OOM while batch size > 1 pipe.vae.enable_slicing() pipe.vae.enable_tiling() torch.cuda.reset_peak_memory_stats() start_time = time.time() output = pipe( height=input_config.height, width=input_config.width, num_frames=input_config.num_frames, prompt=input_config.prompt, num_inference_steps=input_config.num_inference_steps, generator=torch.Generator().manual_seed(input_config.seed), guidance_scale=6, use_dynamic_cfg=True, ).frames[0] end_time = time.time() elapsed_time = end_time - start_time peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") parallel_info = ( f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_" f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_" f"tp{engine_args.tensor_parallel_degree}_" f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}" ) if is_dp_last_group(): world_size = get_data_parallel_world_size() resolution = f"{input_config.width}x{input_config.height}" output_filename = f"results/cogvideox_{parallel_info}_{resolution}.mp4" export_to_video(output, output_filename, fps=8) print(f"output saved to {output_filename}") if get_world_group().rank == get_world_group().world_size - 1: print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB") get_runtime_state().destory_distributed_env() if __name__ == "__main__": main()