diff --git a/README.md b/README.md index 6e8a37a..2c03bba 100644 --- a/README.md +++ b/README.md @@ -358,6 +358,9 @@ This folder contains some tools for model conversion / caption generation, etc. Adapter. + [llm_flux_cogvideox](tools/llm_flux_cogvideox/llm_flux_cogvideox.py): Automatically generate videos using an open-source local large language model + Flux + CogVideoX. ++ [parallel_inference_xdit](tools/parallel_inference/parallel_inference_xdit.py): +Supported by [xDiT](https://github.com/xdit-project/xDiT), parallelize the + video generation process on multiple GPUs. ## CogVideo(ICLR'23) diff --git a/README_ja.md b/README_ja.md index 2a8076b..c24aa02 100644 --- a/README_ja.md +++ b/README_ja.md @@ -329,6 +329,9 @@ pipe.vae.enable_tiling() をロードするためのツールコード。 + [llm_flux_cogvideox](tools/llm_flux_cogvideox/llm_flux_cogvideox.py): オープンソースのローカル大規模言語モデル + Flux + CogVideoX を使用して自動的に動画を生成します。 ++ [parallel_inference_xdit](tools/parallel_inference/parallel_inference_xdit.py): +[xDiT](https://github.com/xdit-project/xDiT) + によってサポートされ、ビデオ生成プロセスを複数の GPU で並列化します。 ## CogVideo(ICLR'23) diff --git a/README_zh.md b/README_zh.md index 371de16..d831ec1 100644 --- a/README_zh.md +++ b/README_zh.md @@ -312,6 +312,9 @@ pipe.vae.enable_tiling() + [load_cogvideox_lora](tools/load_cogvideox_lora.py): 载入diffusers版微调Lora Adapter的工具代码。 + [llm_flux_cogvideox](tools/llm_flux_cogvideox/llm_flux_cogvideox.py): 使用开源本地大语言模型 + Flux + CogVideoX实现自动化生成视频。 ++ [parallel_inference_xdit](tools/parallel_inference/parallel_inference_xdit.py): +在多个 GPU 上并行化视频生成过程, + 由[xDiT](https://github.com/xdit-project/xDiT)提供支持。 ## CogVideo(ICLR'23) diff --git a/tools/parallel_inference/parallel_inference_xdit.py b/tools/parallel_inference/parallel_inference_xdit.py new file mode 100644 index 0000000..e10f385 --- /dev/null +++ b/tools/parallel_inference/parallel_inference_xdit.py @@ -0,0 +1,101 @@ +""" +This is a parallel inference script for CogVideo. The original script +can be found from the xDiT project at + +https://github.com/xdit-project/xDiT/blob/main/examples/cogvideox_example.py + +By using this code, the inference process is parallelized on multiple GPUs, +and thus speeded up. + +You can also use the run.sh file in the same folder to automate running this +code for batch generation of videos. + +Usage: +1. pip install xfuser +2. run the following command to generate video +torchrun --nproc_per_node=4 cogvideox_xdit.py --model \ + --ring_degree 2 --use_cfg_parallel --height 480 --width 720 --num_frames 9 \ + --prompt 'A small dog.' + +""" + +import time +import torch +import torch.distributed +from diffusers import AutoencoderKLTemporalDecoder +from xfuser import xFuserCogVideoXPipeline, xFuserArgs +from xfuser.config import FlexibleArgumentParser +from xfuser.core.distributed import ( + get_world_group, + get_data_parallel_rank, + get_data_parallel_world_size, + get_runtime_state, + is_dp_last_group, +) +from diffusers.utils import export_to_video + + +def main(): + parser = FlexibleArgumentParser(description="xFuser Arguments") + args = xFuserArgs.add_cli_args(parser).parse_args() + engine_args = xFuserArgs.from_cli_args(args) + + # Check if ulysses_degree is valid + num_heads = 30 + if engine_args.ulysses_degree > 0 and num_heads % engine_args.ulysses_degree != 0: + raise ValueError( + f"ulysses_degree ({engine_args.ulysses_degree}) must be a divisor of the number of heads ({num_heads})" + ) + + engine_config, input_config = engine_args.create_config() + local_rank = get_world_group().local_rank + + pipe = xFuserCogVideoXPipeline.from_pretrained( + pretrained_model_name_or_path=engine_config.model_config.model, + engine_config=engine_config, + torch_dtype=torch.bfloat16, + ) + if args.enable_sequential_cpu_offload: + pipe.enable_model_cpu_offload(gpu_id=local_rank) + pipe.vae.enable_tiling() + else: + device = torch.device(f"cuda:{local_rank}") + pipe = pipe.to(device) + + torch.cuda.reset_peak_memory_stats() + start_time = time.time() + + output = pipe( + height=input_config.height, + width=input_config.width, + num_frames=input_config.num_frames, + prompt=input_config.prompt, + num_inference_steps=input_config.num_inference_steps, + generator=torch.Generator(device="cuda").manual_seed(input_config.seed), + guidance_scale=6, + ).frames[0] + + end_time = time.time() + elapsed_time = end_time - start_time + peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") + + parallel_info = ( + f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_" + f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_" + f"tp{engine_args.tensor_parallel_degree}_" + f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}" + ) + if is_dp_last_group(): + world_size = get_data_parallel_world_size() + resolution = f"{input_config.width}x{input_config.height}" + output_filename = f"results/cogvideox_{parallel_info}_{resolution}.mp4" + export_to_video(output, output_filename, fps=8) + print(f"output saved to {output_filename}") + + if get_world_group().rank == get_world_group().world_size - 1: + print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB") + get_runtime_state().destory_distributed_env() + + +if __name__ == "__main__": + main() diff --git a/tools/parallel_inference/run.sh b/tools/parallel_inference/run.sh new file mode 100644 index 0000000..7f9d5a8 --- /dev/null +++ b/tools/parallel_inference/run.sh @@ -0,0 +1,51 @@ +set -x + +export PYTHONPATH=$PWD:$PYTHONPATH + +# Select the model type +# The model is downloaded to a specified location on disk, +# or you can simply use the model's ID on Hugging Face, +# which will then be downloaded to the default cache path on Hugging Face. + +export MODEL_TYPE="CogVideoX" +# Configuration for different model types +# script, model_id, inference_step +declare -A MODEL_CONFIGS=( + ["CogVideoX"]="parallel_inference_xdit.py /cfs/dit/CogVideoX-2b 20" +) + +if [[ -v MODEL_CONFIGS[$MODEL_TYPE] ]]; then + IFS=' ' read -r SCRIPT MODEL_ID INFERENCE_STEP <<< "${MODEL_CONFIGS[$MODEL_TYPE]}" + export SCRIPT MODEL_ID INFERENCE_STEP +else + echo "Invalid MODEL_TYPE: $MODEL_TYPE" + exit 1 +fi + +mkdir -p ./results + +# task args +if [ "$MODEL_TYPE" = "CogVideoX" ]; then + TASK_ARGS="--height 480 --width 720 --num_frames 9" +fi + +# CogVideoX asserts sp_degree == ulysses_degree*ring_degree <= 2. Also, do not set the pipefusion degree. +if [ "$MODEL_TYPE" = "CogVideoX" ]; then +N_GPUS=4 +PARALLEL_ARGS="--ulysses_degree 2 --ring_degree 1" +CFG_ARGS="--use_cfg_parallel" +fi + + +torchrun --nproc_per_node=$N_GPUS ./$SCRIPT \ +--model $MODEL_ID \ +$PARALLEL_ARGS \ +$TASK_ARGS \ +$PIPEFUSION_ARGS \ +$OUTPUT_ARGS \ +--num_inference_steps $INFERENCE_STEP \ +--warmup_steps 0 \ +--prompt "A small dog." \ +$CFG_ARGS \ +$PARALLLEL_VAE \ +$COMPILE_FLAG