mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 03:04:56 +08:00
106 lines
4.4 KiB
Python
106 lines
4.4 KiB
Python
"""
|
|
This script demonstrates how to generate a video from a text prompt using CogVideoX with 🤗Huggingface Diffusers Pipeline.
|
|
|
|
Note:
|
|
This script requires the `diffusers>=0.30.1` and `torchao>=0.4.0` libraries to be installed.
|
|
|
|
Run the script:
|
|
$ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX-2b
|
|
|
|
In this script, we have only provided the script for testing and inference in INT8 for the entire process
|
|
(including T5 Encoder, CogVideoX Transformer, VAE).
|
|
You can use other functionalities provided by torchao to convert to other precisions.
|
|
Please note that INT4 is not supported.
|
|
"""
|
|
import argparse
|
|
|
|
import torch
|
|
from diffusers import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, CogVideoXPipeline
|
|
from diffusers.utils import export_to_video
|
|
from transformers import T5EncoderModel
|
|
|
|
# Make sure to install torchao>=0.4.0
|
|
from torchao.quantization import quantize_, int8_weight_only
|
|
|
|
|
|
def generate_video(
    prompt: str,
    model_path: str,
    output_path: str = "./output.mp4",
    num_inference_steps: int = 50,
    guidance_scale: float = 6.0,
    num_videos_per_prompt: int = 1,
    dtype: torch.dtype = torch.bfloat16,
) -> None:
    """
    Generate a video from a text prompt with INT8 weight-only quantized models and save it.

    The text encoder, transformer and VAE are each loaded from their own
    subfolder of ``model_path``, quantized in place with torchao's
    ``int8_weight_only`` configuration, and then assembled into a
    ``CogVideoXPipeline``.

    Parameters:
    - prompt (str): The description of the video to be generated.
    - model_path (str): The path of the pre-trained model to be used.
    - output_path (str): The path where the generated video will be saved.
    - num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
    - guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
    - num_videos_per_prompt (int): Number of videos to generate per prompt.
    - dtype (torch.dtype): The data type for computation (default is torch.bfloat16).

    Returns:
    - None. The video is written to ``output_path``.

    Notes:
    - Generation is fixed to 49 frames, exported at 8 fps, with a generator seeded to 42.
    - Only the first generated video (``frames[0]``) is exported, even when
      ``num_videos_per_prompt`` is greater than 1.
    """
    # Load and quantize each component separately, before the pipeline is assembled.
    text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=dtype)
    quantize_(text_encoder, int8_weight_only())

    transformer = CogVideoXTransformer3DModel.from_pretrained(
        model_path, subfolder="transformer", torch_dtype=dtype
    )
    quantize_(transformer, int8_weight_only())

    vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae", torch_dtype=dtype)
    quantize_(vae, int8_weight_only())

    # Use the caller-supplied dtype here as well, so every pipeline component
    # agrees with the dtype the quantized sub-models were loaded in.
    pipe = CogVideoXPipeline.from_pretrained(
        model_path,
        text_encoder=text_encoder,
        transformer=transformer,
        vae=vae,
        torch_dtype=dtype,
    )

    # Memory-saving options provided by diffusers (see the pipeline documentation).
    pipe.enable_model_cpu_offload()
    pipe.vae.enable_tiling()

    video = pipe(
        prompt=prompt,
        num_videos_per_prompt=num_videos_per_prompt,
        num_inference_steps=num_inference_steps,
        num_frames=49,
        guidance_scale=guidance_scale,
        # Fixed seed so repeated runs use the same generator state.
        generator=torch.Generator().manual_seed(42),
    ).frames[0]

    export_to_video(video, output_path, fps=8)
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the options and forward them to generate_video.
    parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
    parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
    parser.add_argument(
        "--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used"
    )
    parser.add_argument(
        "--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
    )
    parser.add_argument(
        "--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
    )
    parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
    parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")

    # Map the accepted --dtype strings to torch dtypes. Restricting the option
    # to these keys makes argparse reject an unsupported or misspelled value
    # instead of silently running in bfloat16.
    dtype_by_name = {"float16": torch.float16, "bfloat16": torch.bfloat16}
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=sorted(dtype_by_name),
        help="The data type for computation (e.g., 'float16' or 'bfloat16')",
    )

    args = parser.parse_args()
    dtype = dtype_by_name[args.dtype]

    generate_video(
        prompt=args.prompt,
        model_path=args.model_path,
        output_path=args.output_path,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        num_videos_per_prompt=args.num_videos_per_prompt,
        dtype=dtype,
    )
|