CogVideo/tools/parallel_inference/parallel_inference_xdit.py

102 lines
3.5 KiB
Python

"""
This is a parallel inference script for CogVideo. The original script
can be found in the xDiT project at
https://github.com/xdit-project/xDiT/blob/main/examples/cogvideox_example.py
By using this code, the inference process is parallelized across multiple GPUs
and thus sped up.
You can also use the run.sh file in the same folder to automate running this
code for batch generation of videos.
Usage:
1. pip install xfuser
2. run the following command to generate video
torchrun --nproc_per_node=4 parallel_inference_xdit.py --model <cogvideox-model-path> \
--ring_degree 2 --use_cfg_parallel --height 480 --width 720 --num_frames 9 \
--prompt 'A small dog.'
"""
import time
import torch
import torch.distributed
from diffusers import AutoencoderKLTemporalDecoder
from xfuser import xFuserCogVideoXPipeline, xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
get_world_group,
get_data_parallel_rank,
get_data_parallel_world_size,
get_runtime_state,
is_dp_last_group,
)
from diffusers.utils import export_to_video
def main():
    """Generate one CogVideoX video with the xfuser parallel pipeline.

    Steps, in order:
      1. Parse the xFuser command-line arguments and build the engine/input
         configuration.
      2. Load ``xFuserCogVideoXPipeline`` in bfloat16 and place it on this
         process's local GPU (or enable CPU offload plus VAE tiling).
      3. Run the pipeline once, timing the call and recording peak CUDA memory.
      4. On the last data-parallel group, export the frames to
         ``results/cogvideox_<parallel-config>_<width>x<height>.mp4``.
      5. On the last rank of the world group, print elapsed time and peak
         memory, then tear down the distributed environment.

    Raises:
        ValueError: If ``ulysses_degree`` is positive and does not evenly
            divide the head count hard-coded in this script (30).
    """
    # Function-scope import: only needed to prepare the output directory.
    import os

    parser = FlexibleArgumentParser(description="xFuser Arguments")
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)

    # Check if ulysses_degree is valid.
    # NOTE(review): the head count is hard-coded; presumably it matches the
    # CogVideoX transformer being loaded -- confirm for other checkpoints.
    num_heads = 30
    if engine_args.ulysses_degree > 0 and num_heads % engine_args.ulysses_degree != 0:
        raise ValueError(
            f"ulysses_degree ({engine_args.ulysses_degree}) must be a divisor of the number of heads ({num_heads})"
        )

    engine_config, input_config = engine_args.create_config()
    local_rank = get_world_group().local_rank

    pipe = xFuserCogVideoXPipeline.from_pretrained(
        pretrained_model_name_or_path=engine_config.model_config.model,
        engine_config=engine_config,
        torch_dtype=torch.bfloat16,
    )
    if args.enable_sequential_cpu_offload:
        # NOTE(review): the CLI flag is named "sequential" offload but the
        # call below enables *model* CPU offload -- confirm this is intended.
        pipe.enable_model_cpu_offload(gpu_id=local_rank)
        pipe.vae.enable_tiling()
    else:
        device = torch.device(f"cuda:{local_rank}")
        pipe = pipe.to(device)

    # Reset the counter so the peak reported below covers only the pipeline call.
    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()

    output = pipe(
        height=input_config.height,
        width=input_config.width,
        num_frames=input_config.num_frames,
        prompt=input_config.prompt,
        num_inference_steps=input_config.num_inference_steps,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
        guidance_scale=6,
    ).frames[0]

    end_time = time.time()
    elapsed_time = end_time - start_time
    peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

    # Encode the parallel configuration into the output file name.
    parallel_info = (
        f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
        f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
        f"tp{engine_args.tensor_parallel_degree}_"
        f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
    )
    if is_dp_last_group():
        resolution = f"{input_config.width}x{input_config.height}"
        output_filename = f"results/cogvideox_{parallel_info}_{resolution}.mp4"
        # Make sure the output directory exists so the finished generation is
        # not lost on a fresh checkout; exist_ok tolerates several ranks
        # reaching this point at the same time.
        os.makedirs(os.path.dirname(output_filename), exist_ok=True)
        export_to_video(output, output_filename, fps=8)
        print(f"output saved to {output_filename}")

    if get_world_group().rank == get_world_group().world_size - 1:
        print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB")

    # "destory" is the spelling used by the runtime-state API called here.
    get_runtime_state().destory_distributed_env()
# Script entry point: run inference only when executed directly (e.g. via
# torchrun), not when this module is imported.
if __name__ == "__main__":
    main()