""" This is a parallel inference script for CogVideo. The original script can be found from the xDiT project at https://github.com/xdit-project/xDiT/blob/main/examples/cogvideox_example.py By using this code, the inference process is parallelized on multiple GPUs, and thus speeded up. You can also use the run.sh file in the same folder to automate running this code for batch generation of videos. Usage: 1. pip install xfuser 2. run the following command to generate video torchrun --nproc_per_node=4 cogvideox_xdit.py --model \ --ring_degree 2 --use_cfg_parallel --height 480 --width 720 --num_frames 9 \ --prompt 'A small dog.' """ import time import torch import torch.distributed from diffusers import AutoencoderKLTemporalDecoder from xfuser import xFuserCogVideoXPipeline, xFuserArgs from xfuser.config import FlexibleArgumentParser from xfuser.core.distributed import ( get_world_group, get_data_parallel_rank, get_data_parallel_world_size, get_runtime_state, is_dp_last_group, ) from diffusers.utils import export_to_video def main(): parser = FlexibleArgumentParser(description="xFuser Arguments") args = xFuserArgs.add_cli_args(parser).parse_args() engine_args = xFuserArgs.from_cli_args(args) # Check if ulysses_degree is valid num_heads = 30 if engine_args.ulysses_degree > 0 and num_heads % engine_args.ulysses_degree != 0: raise ValueError( f"ulysses_degree ({engine_args.ulysses_degree}) must be a divisor of the number of heads ({num_heads})" ) engine_config, input_config = engine_args.create_config() local_rank = get_world_group().local_rank pipe = xFuserCogVideoXPipeline.from_pretrained( pretrained_model_name_or_path=engine_config.model_config.model, engine_config=engine_config, torch_dtype=torch.bfloat16, ) if args.enable_sequential_cpu_offload: pipe.enable_model_cpu_offload(gpu_id=local_rank) pipe.vae.enable_tiling() else: device = torch.device(f"cuda:{local_rank}") pipe = pipe.to(device) torch.cuda.reset_peak_memory_stats() start_time = time.time() output = pipe( height=input_config.height, width=input_config.width, num_frames=input_config.num_frames, prompt=input_config.prompt, num_inference_steps=input_config.num_inference_steps, generator=torch.Generator(device="cuda").manual_seed(input_config.seed), guidance_scale=6, ).frames[0] end_time = time.time() elapsed_time = end_time - start_time peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") parallel_info = ( f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_" f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_" f"tp{engine_args.tensor_parallel_degree}_" f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}" ) if is_dp_last_group(): world_size = get_data_parallel_world_size() resolution = f"{input_config.width}x{input_config.height}" output_filename = f"results/cogvideox_{parallel_info}_{resolution}.mp4" export_to_video(output, output_filename, fps=8) print(f"output saved to {output_filename}") if get_world_group().rank == get_world_group().world_size - 1: print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB") get_runtime_state().destory_distributed_env() if __name__ == "__main__": main()