feat: add xDiT support for parallel inference

Xibo Sun 2024-09-26 11:21:32 +08:00
parent 628f736628
commit acc7eac759
5 changed files with 161 additions and 0 deletions

View File

@@ -358,6 +358,9 @@ This folder contains some tools for model conversion / caption generation, etc.
Adapter.
+ [llm_flux_cogvideox](tools/llm_flux_cogvideox/llm_flux_cogvideox.py): Automatically generate videos using an
open-source local large language model + Flux + CogVideoX.
+ [parallel_inference_xdit](tools/parallel_inference/parallel_inference_xdit.py): Supported
  by [xDiT](https://github.com/xdit-project/xDiT), parallelizes the video generation process
  across multiple GPUs; see the example command below.
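
A minimal invocation sketch (flags taken from the script's docstring and `run.sh`; the model path is a placeholder to substitute, and 4 GPUs are assumed):

```shell
torchrun --nproc_per_node=4 tools/parallel_inference/parallel_inference_xdit.py \
    --model <cogvideox-model-path> --ulysses_degree 2 --ring_degree 1 --use_cfg_parallel \
    --height 480 --width 720 --num_frames 9 --num_inference_steps 20 --prompt "A small dog."
```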
## CogVideo(ICLR'23)

View File

@@ -329,6 +329,9 @@ pipe.vae.enable_tiling()
  tool code for loading the fine-tuned LoRA Adapter (diffusers format).
+ [llm_flux_cogvideox](tools/llm_flux_cogvideox/llm_flux_cogvideox.py): Automatically generates
  videos using an open-source local large language model + Flux + CogVideoX.
+ [parallel_inference_xdit](tools/parallel_inference/parallel_inference_xdit.py): Supported
  by [xDiT](https://github.com/xdit-project/xDiT), parallelizes the video generation process
  across multiple GPUs.
## CogVideo(ICLR'23)

View File

@@ -312,6 +312,9 @@ pipe.vae.enable_tiling()
+ [load_cogvideox_lora](tools/load_cogvideox_lora.py): Tool code for loading a diffusers-format
  fine-tuned LoRA Adapter.
+ [llm_flux_cogvideox](tools/llm_flux_cogvideox/llm_flux_cogvideox.py): Automatically generates
  videos using an open-source local large language model + Flux + CogVideoX.
+ [parallel_inference_xdit](tools/parallel_inference/parallel_inference_xdit.py): Parallelizes
  the video generation process across multiple GPUs, supported by
  [xDiT](https://github.com/xdit-project/xDiT).
## CogVideo(ICLR'23)

View File

@@ -0,0 +1,101 @@
"""
This is a parallel inference script for CogVideo. The original script
can be found from the xDiT project at
https://github.com/xdit-project/xDiT/blob/main/examples/cogvideox_example.py
By using this code, the inference process is parallelized on multiple GPUs,
and thus speeded up.
You can also use the run.sh file in the same folder to automate running this
code for batch generation of videos.
Usage:
1. pip install xfuser
2. run the following command to generate video
torchrun --nproc_per_node=4 cogvideox_xdit.py --model <cogvideox-model-path> \
--ring_degree 2 --use_cfg_parallel --height 480 --width 720 --num_frames 9 \
--prompt 'A small dog.'
"""
import time

import torch
import torch.distributed
from diffusers.utils import export_to_video
from xfuser import xFuserCogVideoXPipeline, xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
    get_world_group,
    get_data_parallel_rank,
    get_data_parallel_world_size,
    get_runtime_state,
    is_dp_last_group,
)

def main():
    parser = FlexibleArgumentParser(description="xFuser Arguments")
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)

    # Check that ulysses_degree evenly divides the transformer's attention head count
    # (30 for CogVideoX-2b), since Ulysses sequence parallelism splits attention heads across GPUs.
    num_heads = 30
    if engine_args.ulysses_degree > 0 and num_heads % engine_args.ulysses_degree != 0:
        raise ValueError(
            f"ulysses_degree ({engine_args.ulysses_degree}) must be a divisor of the number of heads ({num_heads})"
        )
    engine_config, input_config = engine_args.create_config()
    local_rank = get_world_group().local_rank

    pipe = xFuserCogVideoXPipeline.from_pretrained(
        pretrained_model_name_or_path=engine_config.model_config.model,
        engine_config=engine_config,
        torch_dtype=torch.bfloat16,
    )
    if args.enable_sequential_cpu_offload:
        pipe.enable_model_cpu_offload(gpu_id=local_rank)
        pipe.vae.enable_tiling()
    else:
        device = torch.device(f"cuda:{local_rank}")
        pipe = pipe.to(device)
    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()

    output = pipe(
        height=input_config.height,
        width=input_config.width,
        num_frames=input_config.num_frames,
        prompt=input_config.prompt,
        num_inference_steps=input_config.num_inference_steps,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
        guidance_scale=6,
    ).frames[0]

    end_time = time.time()
    elapsed_time = end_time - start_time
    peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

    parallel_info = (
        f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
        f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
        f"tp{engine_args.tensor_parallel_degree}_"
        f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
    )
    if is_dp_last_group():
        world_size = get_data_parallel_world_size()
        resolution = f"{input_config.width}x{input_config.height}"
        output_filename = f"results/cogvideox_{parallel_info}_{resolution}.mp4"
        export_to_video(output, output_filename, fps=8)
        print(f"output saved to {output_filename}")

    if get_world_group().rank == get_world_group().world_size - 1:
        print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9:.2f} GB")
    get_runtime_state().destory_distributed_env()


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,51 @@
#!/bin/bash
set -x

export PYTHONPATH=$PWD:$PYTHONPATH
# Select the model type
# The model is downloaded to a specified location on disk,
# or you can simply use the model's ID on Hugging Face,
# which will then be downloaded to the default cache path on Hugging Face.
export MODEL_TYPE="CogVideoX"
# Configuration for different model types
# script, model_id, inference_step
declare -A MODEL_CONFIGS=(
["CogVideoX"]="parallel_inference_xdit.py /cfs/dit/CogVideoX-2b 20"
)
if [[ -v MODEL_CONFIGS[$MODEL_TYPE] ]]; then
IFS=' ' read -r SCRIPT MODEL_ID INFERENCE_STEP <<< "${MODEL_CONFIGS[$MODEL_TYPE]}"
export SCRIPT MODEL_ID INFERENCE_STEP
else
echo "Invalid MODEL_TYPE: $MODEL_TYPE"
exit 1
fi
mkdir -p ./results
# task args
if [ "$MODEL_TYPE" = "CogVideoX" ]; then
TASK_ARGS="--height 480 --width 720 --num_frames 9"
fi
# CogVideoX asserts sp_degree == ulysses_degree*ring_degree <= 2. Also, do not set the pipefusion degree.
if [ "$MODEL_TYPE" = "CogVideoX" ]; then
N_GPUS=4
PARALLEL_ARGS="--ulysses_degree 2 --ring_degree 1"
CFG_ARGS="--use_cfg_parallel"
fi
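
# A possible alternative (an assumption, not part of the original script): on the same
# 4 GPUs, the sp_degree = ulysses_degree * ring_degree <= 2 constraint can also be met with
# ring attention instead of Ulysses, keeping --use_cfg_parallel so 2 (sp) x 2 (cfg) = 4 ranks:
# PARALLEL_ARGS="--ulysses_degree 1 --ring_degree 2"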
torchrun --nproc_per_node=$N_GPUS ./$SCRIPT \
--model $MODEL_ID \
$PARALLEL_ARGS \
$TASK_ARGS \
$PIPEFUSION_ARGS \
$OUTPUT_ARGS \
--num_inference_steps $INFERENCE_STEP \
--warmup_steps 0 \
--prompt "A small dog." \
$CFG_ARGS \
$PARALLEL_VAE \
$COMPILE_FLAG