mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
update diffusers code
This commit is contained in:
parent
a8205b575d
commit
c8c7b62aa1
@ -3,15 +3,15 @@ This script demonstrates how to generate a video using the CogVideoX model with
|
|||||||
The script supports different types of video generation, including text-to-video (t2v), image-to-video (i2v),
|
The script supports different types of video generation, including text-to-video (t2v), image-to-video (i2v),
|
||||||
and video-to-video (v2v), depending on the input data and different weight.
|
and video-to-video (v2v), depending on the input data and different weight.
|
||||||
|
|
||||||
- text-to-video: THUDM/CogVideoX-5b or THUDM/CogVideoX-2b
|
- text-to-video: THUDM/CogVideoX-5b, THUDM/CogVideoX-2b or THUDM/CogVideoX1.5-5b
|
||||||
- video-to-video: THUDM/CogVideoX-5b or THUDM/CogVideoX-2b
|
- video-to-video: THUDM/CogVideoX-5b, THUDM/CogVideoX-2b or THUDM/CogVideoX1.5-5b
|
||||||
- image-to-video: THUDM/CogVideoX-5b-I2V
|
- image-to-video: THUDM/CogVideoX-5b-I2V or THUDM/CogVideoX1.5-5b-I2V
|
||||||
|
|
||||||
Running the Script:
|
Running the Script:
|
||||||
To run the script, use the following command with appropriate arguments:
|
To run the script, use the following command with appropriate arguments:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX-5b --generate_type "t2v"
|
$ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX1.5-5b --generate_type "t2v"
|
||||||
```
|
```
|
||||||
|
|
||||||
Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths.
|
Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths.
|
||||||
@ -23,7 +23,6 @@ from typing import Literal
|
|||||||
import torch
|
import torch
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
CogVideoXPipeline,
|
CogVideoXPipeline,
|
||||||
CogVideoXDDIMScheduler,
|
|
||||||
CogVideoXDPMScheduler,
|
CogVideoXDPMScheduler,
|
||||||
CogVideoXImageToVideoPipeline,
|
CogVideoXImageToVideoPipeline,
|
||||||
CogVideoXVideoToVideoPipeline,
|
CogVideoXVideoToVideoPipeline,
|
||||||
@ -37,6 +36,7 @@ def generate_video(
|
|||||||
model_path: str,
|
model_path: str,
|
||||||
lora_path: str = None,
|
lora_path: str = None,
|
||||||
lora_rank: int = 128,
|
lora_rank: int = 128,
|
||||||
|
num_frames=81,
|
||||||
output_path: str = "./output.mp4",
|
output_path: str = "./output.mp4",
|
||||||
image_or_video_path: str = "",
|
image_or_video_path: str = "",
|
||||||
num_inference_steps: int = 50,
|
num_inference_steps: int = 50,
|
||||||
@ -45,6 +45,7 @@ def generate_video(
|
|||||||
dtype: torch.dtype = torch.bfloat16,
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
generate_type: str = Literal["t2v", "i2v", "v2v"], # i2v: image to video, v2v: video to video
|
generate_type: str = Literal["t2v", "i2v", "v2v"], # i2v: image to video, v2v: video to video
|
||||||
seed: int = 42,
|
seed: int = 42,
|
||||||
|
fps: int = 8,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Generates a video based on the given prompt and saves it to the specified path.
|
Generates a video based on the given prompt and saves it to the specified path.
|
||||||
@ -56,11 +57,13 @@ def generate_video(
|
|||||||
- lora_rank (int): The rank of the LoRA weights.
|
- lora_rank (int): The rank of the LoRA weights.
|
||||||
- output_path (str): The path where the generated video will be saved.
|
- output_path (str): The path where the generated video will be saved.
|
||||||
- num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
|
- num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
|
||||||
|
- num_frames (int): Number of frames to generate.
|
||||||
- guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
|
- guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
|
||||||
- num_videos_per_prompt (int): Number of videos to generate per prompt.
|
- num_videos_per_prompt (int): Number of videos to generate per prompt.
|
||||||
- dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
|
- dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
|
||||||
- generate_type (str): The type of video generation (e.g., 't2v', 'i2v', 'v2v').·
|
- generate_type (str): The type of video generation (e.g., 't2v', 'i2v', 'v2v').·
|
||||||
- seed (int): The seed for reproducibility.
|
- seed (int): The seed for reproducibility.
|
||||||
|
- fps (int): The frames per second for the generated video.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 1. Load the pre-trained CogVideoX pipeline with the specified precision (bfloat16).
|
# 1. Load the pre-trained CogVideoX pipeline with the specified precision (bfloat16).
|
||||||
@ -109,11 +112,11 @@ def generate_video(
|
|||||||
if generate_type == "i2v":
|
if generate_type == "i2v":
|
||||||
video_generate = pipe(
|
video_generate = pipe(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
image=image, # The path of the image to be used as the background of the video
|
image=image, # The path of the image, the resolution of video will be the same as the image for CogVideoX1.5-5B-I2V, otherwise it will be 720 * 480
|
||||||
num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt
|
num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt
|
||||||
num_inference_steps=num_inference_steps, # Number of inference steps
|
num_inference_steps=num_inference_steps, # Number of inference steps
|
||||||
num_frames=49, # Number of frames to generate,changed to 49 for diffusers version `0.30.3` and after.
|
num_frames=num_frames, # Number of frames to generate
|
||||||
use_dynamic_cfg=True, # This id used for DPM Sechduler, for DDIM scheduler, it should be False
|
use_dynamic_cfg=True, # This id used for DPM scheduler, for DDIM scheduler, it should be False
|
||||||
guidance_scale=guidance_scale,
|
guidance_scale=guidance_scale,
|
||||||
generator=torch.Generator().manual_seed(seed), # Set the seed for reproducibility
|
generator=torch.Generator().manual_seed(seed), # Set the seed for reproducibility
|
||||||
).frames[0]
|
).frames[0]
|
||||||
@ -122,7 +125,7 @@ def generate_video(
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
num_videos_per_prompt=num_videos_per_prompt,
|
num_videos_per_prompt=num_videos_per_prompt,
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
num_frames=49,
|
num_frames=num_frames,
|
||||||
use_dynamic_cfg=True,
|
use_dynamic_cfg=True,
|
||||||
guidance_scale=guidance_scale,
|
guidance_scale=guidance_scale,
|
||||||
generator=torch.Generator().manual_seed(seed),
|
generator=torch.Generator().manual_seed(seed),
|
||||||
@ -133,13 +136,12 @@ def generate_video(
|
|||||||
video=video, # The path of the video to be used as the background of the video
|
video=video, # The path of the video to be used as the background of the video
|
||||||
num_videos_per_prompt=num_videos_per_prompt,
|
num_videos_per_prompt=num_videos_per_prompt,
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
# num_frames=49,
|
num_frames=num_frames,
|
||||||
use_dynamic_cfg=True,
|
use_dynamic_cfg=True,
|
||||||
guidance_scale=guidance_scale,
|
guidance_scale=guidance_scale,
|
||||||
generator=torch.Generator().manual_seed(seed), # Set the seed for reproducibility
|
generator=torch.Generator().manual_seed(seed), # Set the seed for reproducibility
|
||||||
).frames[0]
|
).frames[0]
|
||||||
# 5. Export the generated frames to a video file. fps must be 8 for original video.
|
export_to_video(video_generate, output_path, fps=fps)
|
||||||
export_to_video(video_generate, output_path, fps=8)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@ -152,24 +154,18 @@ if __name__ == "__main__":
|
|||||||
help="The path of the image to be used as the background of the video",
|
help="The path of the image to be used as the background of the video",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used"
|
"--model_path", type=str, default="THUDM/CogVideoX-5b", help="Path of the pre-trained model use"
|
||||||
)
|
)
|
||||||
parser.add_argument("--lora_path", type=str, default=None, help="The path of the LoRA weights to be used")
|
parser.add_argument("--lora_path", type=str, default=None, help="The path of the LoRA weights to be used")
|
||||||
parser.add_argument("--lora_rank", type=int, default=128, help="The rank of the LoRA weights")
|
parser.add_argument("--lora_rank", type=int, default=128, help="The rank of the LoRA weights")
|
||||||
parser.add_argument(
|
parser.add_argument("--output_path", type=str, default="./output.mp4", help="The path save generated video")
|
||||||
"--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
|
|
||||||
)
|
|
||||||
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
|
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
|
||||||
parser.add_argument(
|
parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
|
||||||
"--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
|
parser.add_argument("--num_frames", type=int, default=81, help="Number of steps for the inference process")
|
||||||
)
|
parser.add_argument("--fps", type=int, default=16, help="Number of steps for the inference process")
|
||||||
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
|
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
|
||||||
parser.add_argument(
|
parser.add_argument("--generate_type", type=str, default="t2v", help="The type of video generation")
|
||||||
"--generate_type", type=str, default="t2v", help="The type of video generation (e.g., 't2v', 'i2v', 'v2v')"
|
parser.add_argument("--dtype", type=str, default="bfloat16", help="The data type for computation")
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')"
|
|
||||||
)
|
|
||||||
parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
|
parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -180,6 +176,7 @@ if __name__ == "__main__":
|
|||||||
lora_path=args.lora_path,
|
lora_path=args.lora_path,
|
||||||
lora_rank=args.lora_rank,
|
lora_rank=args.lora_rank,
|
||||||
output_path=args.output_path,
|
output_path=args.output_path,
|
||||||
|
num_frames=args.num_frames,
|
||||||
image_or_video_path=args.image_or_video_path,
|
image_or_video_path=args.image_or_video_path,
|
||||||
num_inference_steps=args.num_inference_steps,
|
num_inference_steps=args.num_inference_steps,
|
||||||
guidance_scale=args.guidance_scale,
|
guidance_scale=args.guidance_scale,
|
||||||
@ -187,4 +184,5 @@ if __name__ == "__main__":
|
|||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
generate_type=args.generate_type,
|
generate_type=args.generate_type,
|
||||||
seed=args.seed,
|
seed=args.seed,
|
||||||
|
fps=args.fps,
|
||||||
)
|
)
|
||||||
|
@ -3,7 +3,7 @@ This script demonstrates how to generate a video from a text prompt using CogVid
|
|||||||
|
|
||||||
Note:
|
Note:
|
||||||
|
|
||||||
Must install the `torchao`,`torch`,`diffusers`,`accelerate` library FROM SOURCE to use the quantization feature.
|
Must install the `torchao`,`torch` library FROM SOURCE to use the quantization feature.
|
||||||
Only NVIDIA GPUs like H100 or higher are supported om FP-8 quantization.
|
Only NVIDIA GPUs like H100 or higher are supported om FP-8 quantization.
|
||||||
|
|
||||||
ALL quantization schemes must use with NVIDIA GPUs.
|
ALL quantization schemes must use with NVIDIA GPUs.
|
||||||
@ -51,6 +51,9 @@ def generate_video(
|
|||||||
num_videos_per_prompt: int = 1,
|
num_videos_per_prompt: int = 1,
|
||||||
quantization_scheme: str = "fp8",
|
quantization_scheme: str = "fp8",
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
|
num_frames: int = 81,
|
||||||
|
fps: int = 8,
|
||||||
|
seed: int = 42,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Generates a video based on the given prompt and saves it to the specified path.
|
Generates a video based on the given prompt and saves it to the specified path.
|
||||||
@ -65,7 +68,6 @@ def generate_video(
|
|||||||
- quantization_scheme (str): The quantization scheme to use ('int8', 'fp8').
|
- quantization_scheme (str): The quantization scheme to use ('int8', 'fp8').
|
||||||
- dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
|
- dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=dtype)
|
text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=dtype)
|
||||||
text_encoder = quantize_model(part=text_encoder, quantization_scheme=quantization_scheme)
|
text_encoder = quantize_model(part=text_encoder, quantization_scheme=quantization_scheme)
|
||||||
transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=dtype)
|
transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=dtype)
|
||||||
@ -80,54 +82,38 @@ def generate_video(
|
|||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
)
|
)
|
||||||
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
|
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
|
||||||
|
|
||||||
# Using with compile will run faster. First time infer will cost ~30min to compile.
|
|
||||||
# pipe.transformer.to(memory_format=torch.channels_last)
|
|
||||||
|
|
||||||
# for FP8 should remove pipe.enable_model_cpu_offload()
|
|
||||||
pipe.enable_model_cpu_offload()
|
pipe.enable_model_cpu_offload()
|
||||||
|
|
||||||
# This is not for FP8 and INT8 and should remove this line
|
|
||||||
# pipe.enable_sequential_cpu_offload()
|
|
||||||
pipe.vae.enable_slicing()
|
pipe.vae.enable_slicing()
|
||||||
pipe.vae.enable_tiling()
|
pipe.vae.enable_tiling()
|
||||||
|
|
||||||
video = pipe(
|
video = pipe(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
num_videos_per_prompt=num_videos_per_prompt,
|
num_videos_per_prompt=num_videos_per_prompt,
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
num_frames=49,
|
num_frames=num_frames,
|
||||||
use_dynamic_cfg=True,
|
use_dynamic_cfg=True,
|
||||||
guidance_scale=guidance_scale,
|
guidance_scale=guidance_scale,
|
||||||
generator=torch.Generator(device="cuda").manual_seed(42),
|
generator=torch.Generator(device="cuda").manual_seed(seed),
|
||||||
).frames[0]
|
).frames[0]
|
||||||
|
|
||||||
export_to_video(video, output_path, fps=8)
|
export_to_video(video, output_path, fps=fps)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
|
parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
|
||||||
parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
|
parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
|
||||||
|
parser.add_argument("--model_path", type=str, default="THUDM/CogVideoX-5b", help="Path of the pre-trained model")
|
||||||
|
parser.add_argument("--output_path", type=str, default="./output.mp4", help="Path to save generated video")
|
||||||
|
parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
|
||||||
|
parser.add_argument("--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale")
|
||||||
|
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Videos to generate per prompt")
|
||||||
|
parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type (e.g., 'float16', 'bfloat16')")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used"
|
"--quantization_scheme", type=str, default="fp8", choices=["int8", "fp8"], help="Quantization scheme"
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
|
|
||||||
)
|
|
||||||
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
|
|
||||||
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
|
|
||||||
parser.add_argument(
|
|
||||||
"--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16', 'bfloat16')"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--quantization_scheme",
|
|
||||||
type=str,
|
|
||||||
default="bf16",
|
|
||||||
choices=["int8", "fp8"],
|
|
||||||
help="The quantization scheme to use (int8, fp8)",
|
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames in the video")
|
||||||
|
parser.add_argument("--fps", type=int, default=16, help="Frames per second for output video")
|
||||||
|
parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16
|
dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16
|
||||||
@ -140,4 +126,7 @@ if __name__ == "__main__":
|
|||||||
num_videos_per_prompt=args.num_videos_per_prompt,
|
num_videos_per_prompt=args.num_videos_per_prompt,
|
||||||
quantization_scheme=args.quantization_scheme,
|
quantization_scheme=args.quantization_scheme,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
num_frames=args.num_frames,
|
||||||
|
fps=args.fps,
|
||||||
|
seed=args.seed,
|
||||||
)
|
)
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
diffusers>=0.31.0
|
diffusers>=0.31.0
|
||||||
accelerate>=1.0.1
|
accelerate>=1.1.1
|
||||||
transformers>=4.46.1
|
transformers>=4.46.2
|
||||||
numpy==1.26.0
|
numpy==1.26.0
|
||||||
torch>=2.5.0
|
torch>=2.5.0
|
||||||
torchvision>=0.20.0
|
torchvision>=0.20.0
|
||||||
sentencepiece>=0.2.0
|
sentencepiece>=0.2.0
|
||||||
SwissArmyTransformer>=0.4.12
|
SwissArmyTransformer>=0.4.12
|
||||||
gradio>=5.4.0
|
gradio>=5.5.0
|
||||||
imageio>=2.35.1
|
imageio>=2.35.1
|
||||||
imageio-ffmpeg>=0.5.1
|
imageio-ffmpeg>=0.5.1
|
||||||
openai>=1.53.0
|
openai>=1.54.0
|
||||||
moviepy>=1.0.3
|
moviepy>=1.0.3
|
||||||
scikit-video>=1.1.11
|
scikit-video>=1.1.11
|
Loading…
x
Reference in New Issue
Block a user