Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-04-05 19:41:59 +08:00)
Commit 67ba369a39
@@ -22,6 +22,7 @@ Experience the CogVideoX-5B model online at <a href="https://huggingface.co/spac

## Project Updates

- 🔥🔥 **News**: ```2024/9/25```: The CogVideoX web demo is available on Replicate. Try the text-to-video model **CogVideoX-5B** here: [Replicate](https://replicate.com/chenxwh/cogvideox-t2v), and the image-to-video model **CogVideoX-5B-I2V** here: [Replicate](https://replicate.com/chenxwh/cogvideox-i2v).
- 🔥🔥 **News**: ```2024/9/19```: We have open-sourced the CogVideoX series image-to-video model **CogVideoX-5B-I2V**.
  This model can take an image as a background input and generate a video combined with a prompt, offering greater
  controllability. With this, the CogVideoX series models now support three tasks: text-to-video generation, video
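As a concrete illustration of the image-to-video task announced in the news item above, here is a minimal hedged sketch using the same diffusers pipeline that `predict_i2v.py` later in this diff relies on; the model id and file names are illustrative assumptions, not part of this commit:

```python
# Minimal sketch (not from this commit): image-to-video with CogVideoX-5B-I2V via diffusers.
# "THUDM/CogVideoX-5b-I2V" and the file names are illustrative assumptions.
import torch
from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

pipe = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("background.jpg")             # the image used as the background input
video = pipe(
    prompt="A small dog runs across the lawn.",  # prompt combined with the image
    image=image,
    num_inference_steps=50,
    num_frames=49,
    guidance_scale=6,
    generator=torch.Generator(device="cuda").manual_seed(42),
).frames[0]
export_to_video(video, "output.mp4", fps=8)
```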
@@ -358,6 +359,9 @@ This folder contains some tools for model conversion / caption generation, etc.

  Adapter.
+ [llm_flux_cogvideox](tools/llm_flux_cogvideox/llm_flux_cogvideox.py): Automatically generate videos using an
  open-source local large language model + Flux + CogVideoX.
+ [parallel_inference_xdit](tools/parallel_inference/parallel_inference_xdit.py):
  Supported by [xDiT](https://github.com/xdit-project/xDiT), this parallelizes the
  video generation process across multiple GPUs.

## CogVideo (ICLR'23)
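The llm_flux_cogvideox entry above chains three models. The tool's own code is not shown in this diff, so the following is only a hedged sketch of the idea, using APIs that exist in transformers and diffusers; all model ids and file names are illustrative assumptions:

```python
# Hedged sketch of the llm_flux_cogvideox idea only (not the tool's actual code):
# a local LLM expands a short idea into a detailed prompt, Flux renders a key image,
# and CogVideoX-I2V animates it. All model ids below are illustrative assumptions.
import torch
from transformers import pipeline
from diffusers import FluxPipeline, CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video

llm = pipeline("text-generation", model="Qwen/Qwen2.5-7B-Instruct", device_map="auto")
idea = "a small dog playing in the rain"
prompt = llm(f"Expand into a vivid, detailed video prompt: {idea}", max_new_tokens=120)[0]["generated_text"]

flux = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
first_frame = flux(prompt, height=480, width=720).images[0]

i2v = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16).to("cuda")
video = i2v(prompt=prompt, image=first_frame, num_frames=49, guidance_scale=6).frames[0]
export_to_video(video, "llm_flux_cogvideox.mp4", fps=8)
```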
@@ -329,6 +329,9 @@ pipe.vae.enable_tiling()

  Tool code for loading the fine-tuned LoRA Adapter.
+ [llm_flux_cogvideox](tools/llm_flux_cogvideox/llm_flux_cogvideox.py): Automatically generates videos using an
  open-source local large language model + Flux + CogVideoX.
+ [parallel_inference_xdit](tools/parallel_inference/parallel_inference_xdit.py):
  Supported by [xDiT](https://github.com/xdit-project/xDiT), this parallelizes the
  video generation process across multiple GPUs.

## CogVideo (ICLR'23)
@@ -312,6 +312,9 @@ pipe.vae.enable_tiling()

+ [load_cogvideox_lora](tools/load_cogvideox_lora.py): Tool code for loading a diffusers-version fine-tuned LoRA Adapter.
+ [llm_flux_cogvideox](tools/llm_flux_cogvideox/llm_flux_cogvideox.py): Automatically generates videos using an
  open-source local large language model + Flux + CogVideoX.
+ [parallel_inference_xdit](tools/parallel_inference/parallel_inference_xdit.py):
  Parallelizes the video generation process across multiple GPUs,
  supported by [xDiT](https://github.com/xdit-project/xDiT).

## CogVideo (ICLR'23)
@@ -133,7 +133,7 @@ def generate_video(

        video=video,  # Path of the video used as the background for generation
        num_videos_per_prompt=num_videos_per_prompt,
        num_inference_steps=num_inference_steps,
        num_frames=49,
        # num_frames=49,
        use_dynamic_cfg=True,
        guidance_scale=guidance_scale,
        generator=torch.Generator().manual_seed(seed),  # Set the seed for reproducibility
@@ -11,4 +11,5 @@ imageio>=2.35.1

imageio-ffmpeg>=0.5.1
openai>=1.45.0
moviepy>=1.0.3
pillow==9.5.0
scikit-video
tools/parallel_inference/parallel_inference_xdit.py (new file, 105 lines)
@@ -0,0 +1,105 @@
"""
This is a parallel inference script for CogVideo. The original script
can be found from the xDiT project at

https://github.com/xdit-project/xDiT/blob/main/examples/cogvideox_example.py

By using this code, the inference process is parallelized on multiple GPUs,
and thus sped up.

Usage:
1. pip install xfuser
2. mkdir results
3. run the following command to generate video
torchrun --nproc_per_node=4 parallel_inference_xdit.py \
    --model <cogvideox-model-path> --ulysses_degree 1 --ring_degree 2 \
    --use_cfg_parallel --height 480 --width 720 --num_frames 9 \
    --prompt 'A small dog.'

You can also use the run.sh file in the same folder to automate running this
code for batch generation of videos, by running:

sh ./run.sh
"""

import time
import torch
import torch.distributed
from diffusers import AutoencoderKLTemporalDecoder
from xfuser import xFuserCogVideoXPipeline, xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
    get_world_group,
    get_data_parallel_rank,
    get_data_parallel_world_size,
    get_runtime_state,
    is_dp_last_group,
)
from diffusers.utils import export_to_video


def main():
    parser = FlexibleArgumentParser(description="xFuser Arguments")
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)

    # Check if ulysses_degree is valid
    num_heads = 30
    if engine_args.ulysses_degree > 0 and num_heads % engine_args.ulysses_degree != 0:
        raise ValueError(
            f"ulysses_degree ({engine_args.ulysses_degree}) must be a divisor of the number of heads ({num_heads})"
        )

    engine_config, input_config = engine_args.create_config()
    local_rank = get_world_group().local_rank

    # Build the parallelized CogVideoX pipeline and place it according to the offload setting
    pipe = xFuserCogVideoXPipeline.from_pretrained(
        pretrained_model_name_or_path=engine_config.model_config.model,
        engine_config=engine_config,
        torch_dtype=torch.bfloat16,
    )
    if args.enable_sequential_cpu_offload:
        pipe.enable_model_cpu_offload(gpu_id=local_rank)
        pipe.vae.enable_tiling()
    else:
        device = torch.device(f"cuda:{local_rank}")
        pipe = pipe.to(device)

    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()

    output = pipe(
        height=input_config.height,
        width=input_config.width,
        num_frames=input_config.num_frames,
        prompt=input_config.prompt,
        num_inference_steps=input_config.num_inference_steps,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
        guidance_scale=6,
    ).frames[0]

    end_time = time.time()
    elapsed_time = end_time - start_time
    peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

    # Encode the parallel configuration into the output filename
    parallel_info = (
        f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
        f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
        f"tp{engine_args.tensor_parallel_degree}_"
        f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
    )
    if is_dp_last_group():
        world_size = get_data_parallel_world_size()
        resolution = f"{input_config.width}x{input_config.height}"
        output_filename = f"results/cogvideox_{parallel_info}_{resolution}.mp4"
        export_to_video(output, output_filename, fps=8)
        print(f"output saved to {output_filename}")

    if get_world_group().rank == get_world_group().world_size - 1:
        print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB")
    get_runtime_state().destory_distributed_env()  # spelling as in the xfuser API


if __name__ == "__main__":
    main()
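The `num_heads % ulysses_degree` check above is plain divisor arithmetic. A minimal standalone sketch (mine, not part of the commit) of which `--ulysses_degree` values pass for CogVideoX's 30 attention heads:

```python
# Standalone sketch: ulysses_degree values that satisfy the divisibility check in the script above.
# num_heads = 30 is the value hard-coded there; everything else is illustrative.
num_heads = 30
valid_degrees = [d for d in range(1, num_heads + 1) if num_heads % d == 0]
print(valid_degrees)  # [1, 2, 3, 5, 6, 10, 15, 30]
```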
tools/parallel_inference/run.sh (new file, 51 lines)
@@ -0,0 +1,51 @@
set -x

export PYTHONPATH=$PWD:$PYTHONPATH

# Select the model type
# The model is downloaded to a specified location on disk,
# or you can simply use the model's ID on Hugging Face,
# which will then be downloaded to the default cache path on Hugging Face.

export MODEL_TYPE="CogVideoX"
# Configuration for different model types
# script, model_id, inference_step
declare -A MODEL_CONFIGS=(
    ["CogVideoX"]="parallel_inference_xdit.py /cfs/dit/CogVideoX-2b 20"
)

if [[ -v MODEL_CONFIGS[$MODEL_TYPE] ]]; then
    IFS=' ' read -r SCRIPT MODEL_ID INFERENCE_STEP <<< "${MODEL_CONFIGS[$MODEL_TYPE]}"
    export SCRIPT MODEL_ID INFERENCE_STEP
else
    echo "Invalid MODEL_TYPE: $MODEL_TYPE"
    exit 1
fi

mkdir -p ./results

# task args
if [ "$MODEL_TYPE" = "CogVideoX" ]; then
    TASK_ARGS="--height 480 --width 720 --num_frames 9"
fi

# CogVideoX asserts sp_degree == ulysses_degree*ring_degree <= 2. Also, do not set the pipefusion degree.
if [ "$MODEL_TYPE" = "CogVideoX" ]; then
    N_GPUS=4
    PARALLEL_ARGS="--ulysses_degree 2 --ring_degree 1"
    CFG_ARGS="--use_cfg_parallel"
fi


torchrun --nproc_per_node=$N_GPUS ./$SCRIPT \
--model $MODEL_ID \
$PARALLEL_ARGS \
$TASK_ARGS \
$PIPEFUSION_ARGS \
$OUTPUT_ARGS \
--num_inference_steps $INFERENCE_STEP \
--warmup_steps 0 \
--prompt "A small dog." \
$CFG_ARGS \
$PARALLLEL_VAE \
$COMPILE_FLAG
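The `N_GPUS=4` in run.sh follows from the degrees it sets. A short sketch of that arithmetic (mine, not part of the commit), assuming `--use_cfg_parallel` contributes a classifier-free-guidance degree of 2 as in xDiT's conventions:

```python
# Sketch of the world-size arithmetic behind N_GPUS=4 in run.sh.
# Assumption: --use_cfg_parallel splits classifier-free guidance across 2 ranks (cfg_degree = 2).
ulysses_degree = 2
ring_degree = 1
cfg_degree = 2

sp_degree = ulysses_degree * ring_degree   # sequence-parallel degree; CogVideoX requires sp_degree <= 2
world_size = sp_degree * cfg_degree        # total ranks torchrun must launch
assert sp_degree <= 2 and 30 % ulysses_degree == 0
print(world_size)  # 4, matching N_GPUS=4 (--nproc_per_node)
```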
tools/replicate/cog.yaml (new file, 37 lines)
@@ -0,0 +1,37 @@
# Configuration for Cog ⚙️
# Reference: https://cog.run/yaml

build:
  # set to true if your model requires a GPU
  gpu: true

  # a list of ubuntu apt packages to install
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"

  # python version in the form '3.11' or '3.11.4'
  python_version: "3.11"

  # a list of packages in the format <package-name>==<version>
  python_packages:
    - diffusers>=0.30.3
    - accelerate>=0.34.2
    - transformers>=4.44.2
    - numpy==1.26.0
    - torch>=2.4.0
    - torchvision>=0.19.0
    - sentencepiece>=0.2.0
    - SwissArmyTransformer>=0.4.12
    - imageio>=2.35.1
    - imageio-ffmpeg>=0.5.1
    - openai>=1.45.0
    - moviepy>=1.0.3
    - pillow==9.5.0
    - pydantic==1.10.7
  run:
    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget

# predict.py defines how predictions are run on your model
predict: "predict_t2v.py:Predictor"
# predict: "predict_i2v.py:Predictor"
tools/replicate/predict_i2v.py (new file, 89 lines)
@@ -0,0 +1,89 @@
# Prediction interface for Cog ⚙️
# https://cog.run/python

import os
import subprocess
import time
import torch
from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from cog import BasePredictor, Input, Path


MODEL_CACHE = "model_cache_i2v"
MODEL_URL = (
    f"https://weights.replicate.delivery/default/THUDM/CogVideo/{MODEL_CACHE}.tar"
)
os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_HOME"] = MODEL_CACHE
os.environ["TORCH_HOME"] = MODEL_CACHE
os.environ["HF_DATASETS_CACHE"] = MODEL_CACHE
os.environ["TRANSFORMERS_CACHE"] = MODEL_CACHE
os.environ["HUGGINGFACE_HUB_CACHE"] = MODEL_CACHE


def download_weights(url, dest):
    start = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
    print("downloading took: ", time.time() - start)


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""

        if not os.path.exists(MODEL_CACHE):
            download_weights(MODEL_URL, MODEL_CACHE)

        # model_id: THUDM/CogVideoX-5b-I2V
        self.pipe = CogVideoXImageToVideoPipeline.from_pretrained(
            MODEL_CACHE, torch_dtype=torch.bfloat16
        ).to("cuda")

        self.pipe.enable_model_cpu_offload()
        self.pipe.vae.enable_tiling()

    def predict(
        self,
        prompt: str = Input(
            description="Input prompt", default="Starry sky slowly rotating."
        ),
        image: Path = Input(description="Input image"),
        num_inference_steps: int = Input(
            description="Number of denoising steps", ge=1, le=500, default=50
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=1, le=20, default=6
        ),
        num_frames: int = Input(
            description="Number of frames for the output video", default=49
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> Path:
        """Run a single prediction on the model"""

        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        img = load_image(image=str(image))

        video = self.pipe(
            prompt=prompt,
            image=img,
            num_videos_per_prompt=1,
            num_inference_steps=num_inference_steps,
            num_frames=num_frames,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cuda").manual_seed(seed),
        ).frames[0]

        out_path = "/tmp/out.mp4"

        export_to_video(video, out_path, fps=8)
        return Path(out_path)
tools/replicate/predict_t2v.py (new file, 87 lines)
@@ -0,0 +1,87 @@
# Prediction interface for Cog ⚙️
# https://cog.run/python

import os
import subprocess
import time
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
from cog import BasePredictor, Input, Path


MODEL_CACHE = "model_cache"
MODEL_URL = (
    f"https://weights.replicate.delivery/default/THUDM/CogVideo/{MODEL_CACHE}.tar"
)
os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_HOME"] = MODEL_CACHE
os.environ["TORCH_HOME"] = MODEL_CACHE
os.environ["HF_DATASETS_CACHE"] = MODEL_CACHE
os.environ["TRANSFORMERS_CACHE"] = MODEL_CACHE
os.environ["HUGGINGFACE_HUB_CACHE"] = MODEL_CACHE


def download_weights(url, dest):
    start = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
    print("downloading took: ", time.time() - start)


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""

        if not os.path.exists(MODEL_CACHE):
            download_weights(MODEL_URL, MODEL_CACHE)

        # model_id: THUDM/CogVideoX-5b
        self.pipe = CogVideoXPipeline.from_pretrained(
            MODEL_CACHE,
            torch_dtype=torch.bfloat16,
        ).to("cuda")

        self.pipe.enable_model_cpu_offload()
        self.pipe.vae.enable_tiling()

    def predict(
        self,
        prompt: str = Input(
            description="Input prompt",
            default="A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance.",
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps", ge=1, le=500, default=50
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=1, le=20, default=6
        ),
        num_frames: int = Input(
            description="Number of frames for the output video", default=49
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> Path:
        """Run a single prediction on the model"""

        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        video = self.pipe(
            prompt=prompt,
            num_videos_per_prompt=1,
            num_inference_steps=num_inference_steps,
            num_frames=num_frames,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cuda").manual_seed(seed),
        ).frames[0]

        out_path = "/tmp/out.mp4"

        export_to_video(video, out_path, fps=8)
        return Path(out_path)
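To make the Cog interface above concrete, here is a hedged sketch (not from the repository) of driving the text-to-video Predictor directly in Python, outside the Cog runtime; in normal use Cog itself calls setup() and predict() when serving requests, and the values below are illustrative:

```python
# Hedged sketch: exercising predict_t2v.py's Predictor directly, outside Cog.
# setup() downloads multi-GB weights into model_cache/ on first run; values are illustrative.
from predict_t2v import Predictor

predictor = Predictor()
predictor.setup()
out_path = predictor.predict(
    prompt="A small dog.",
    num_inference_steps=50,
    guidance_scale=6,
    num_frames=49,
    seed=42,
)
print(out_path)  # /tmp/out.mp4
```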