mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 03:04:56 +08:00
162 lines
6.1 KiB
Python
162 lines
6.1 KiB
Python
"""
|
||
This script demonstrates how to generate a video from a text prompt using CogVideoX with quantization.
|
||
|
||
Note:
|
||
|
||
Must install the `torchao`,`torch` library FROM SOURCE to use the quantization feature.
|
||
Only NVIDIA GPUs like H100 or higher are supported om FP-8 quantization.
|
||
|
||
ALL quantization schemes must use with NVIDIA GPUs.
|
||
|
||
# Run the script:
|
||
|
||
python cli_demo_quantization.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX-2b --quantization_scheme fp8 --dtype float16
|
||
python cli_demo_quantization.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX-5b --quantization_scheme fp8 --dtype bfloat16
|
||
|
||
"""
|
||
|
||
import argparse
|
||
import os
|
||
import torch
|
||
import torch._dynamo
|
||
from diffusers import (
|
||
AutoencoderKLCogVideoX,
|
||
CogVideoXTransformer3DModel,
|
||
CogVideoXPipeline,
|
||
CogVideoXDPMScheduler,
|
||
)
|
||
from diffusers.utils import export_to_video
|
||
from transformers import T5EncoderModel
|
||
from torchao.quantization import quantize_, int8_weight_only
|
||
from torchao.float8.inference import ActivationCasting, QuantConfig, quantize_to_float8
|
||
|
||
os.environ["TORCH_LOGS"] = "+dynamo,output_code,graph_breaks,recompiles"
|
||
torch._dynamo.config.suppress_errors = True
|
||
torch.set_float32_matmul_precision("high")
|
||
torch._inductor.config.conv_1x1_as_mm = True
|
||
torch._inductor.config.coordinate_descent_tuning = True
|
||
torch._inductor.config.epilogue_fusion = False
|
||
torch._inductor.config.coordinate_descent_check_all_directions = True
|
||
|
||
|
||
def quantize_model(part, quantization_scheme):
    """
    Apply the requested quantization scheme to one model component.

    Parameters:
    - part: The model component to quantize (this script passes the text
      encoder, the transformer and the VAE).
    - quantization_scheme (str): 'int8' applies torchao int8 weight-only
      quantization; 'fp8' applies torchao float8 quantization with dynamic
      activation casting. Any other value leaves `part` untouched.

    Returns:
    - The same `part` object that was passed in. The return values of the
      torchao calls are not used, so they are presumably in-place — confirm
      against the installed torchao version.
    """
    if quantization_scheme == "int8":
        quantize_(part, int8_weight_only())
    elif quantization_scheme == "fp8":
        quantize_to_float8(part, QuantConfig(ActivationCasting.DYNAMIC))
    return part
|
||
|
||
|
||
def generate_video(
    prompt: str,
    model_path: str,
    output_path: str = "./output.mp4",
    num_inference_steps: int = 50,
    guidance_scale: float = 6.0,
    num_videos_per_prompt: int = 1,
    quantization_scheme: str = "fp8",
    dtype: torch.dtype = torch.bfloat16,
    num_frames: int = 81,
    fps: int = 8,
    seed: int = 42,
) -> None:
    """
    Generates a video based on the given prompt and saves it to the specified path.

    The text encoder, transformer and VAE are each loaded from `model_path`,
    quantized with `quantize_model`, and then assembled into a
    `CogVideoXPipeline`.

    Parameters:
    - prompt (str): The description of the video to be generated.
    - model_path (str): The path of the pre-trained model to be used.
    - output_path (str): The path where the generated video will be saved.
    - num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
    - guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
    - num_videos_per_prompt (int): Number of videos to generate per prompt.
      Only the first entry of the pipeline's `frames` output is exported.
    - quantization_scheme (str): The quantization scheme to use ('int8', 'fp8').
    - dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
    - num_frames (int): Number of frames requested from the pipeline.
    - fps (int): Frames per second used when writing the output video.
    - seed (int): Seed for the CUDA random generator passed to the pipeline.

    Returns:
    - None. The video is written to `output_path`.
    """
    # Load and quantize each component separately so that the pipeline below
    # is built from the already-quantized modules.
    text_encoder = T5EncoderModel.from_pretrained(
        model_path, subfolder="text_encoder", torch_dtype=dtype
    )
    text_encoder = quantize_model(part=text_encoder, quantization_scheme=quantization_scheme)

    transformer = CogVideoXTransformer3DModel.from_pretrained(
        model_path, subfolder="transformer", torch_dtype=dtype
    )
    transformer = quantize_model(part=transformer, quantization_scheme=quantization_scheme)

    vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae", torch_dtype=dtype)
    vae = quantize_model(part=vae, quantization_scheme=quantization_scheme)

    pipe = CogVideoXPipeline.from_pretrained(
        model_path,
        text_encoder=text_encoder,
        transformer=transformer,
        vae=vae,
        torch_dtype=dtype,
    )

    # Swap in the DPM scheduler, reusing the loaded scheduler's config but
    # with "trailing" timestep spacing.
    pipe.scheduler = CogVideoXDPMScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing"
    )

    # Offload / slicing / tiling options provided by diffusers
    # (presumably to reduce peak GPU memory — see the diffusers docs).
    pipe.enable_model_cpu_offload()
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()

    video = pipe(
        prompt=prompt,
        num_videos_per_prompt=num_videos_per_prompt,
        num_inference_steps=num_inference_steps,
        num_frames=num_frames,
        use_dynamic_cfg=True,
        guidance_scale=guidance_scale,
        # The generator is created on "cuda", so a CUDA device is required.
        generator=torch.Generator(device="cuda").manual_seed(seed),
    ).frames[0]

    export_to_video(video, output_path, fps=fps)
|
||
|
||
|
||
if __name__ == "__main__":
    # Command-line entry point: parse the options and forward them to
    # generate_video().
    parser = argparse.ArgumentParser(
        description="Generate a video from a text prompt using CogVideoX"
    )
    parser.add_argument(
        "--prompt", type=str, required=True, help="The description of the video to be generated"
    )
    parser.add_argument(
        "--model_path", type=str, default="THUDM/CogVideoX-5b", help="Path of the pre-trained model"
    )
    parser.add_argument(
        "--output_path", type=str, default="./output.mp4", help="Path to save generated video"
    )
    parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
    parser.add_argument(
        "--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale"
    )
    parser.add_argument(
        "--num_videos_per_prompt", type=int, default=1, help="Videos to generate per prompt"
    )
    # Only these two values are honoured below; restricting the choices
    # rejects typos such as "fp16" instead of silently running in bfloat16.
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["float16", "bfloat16"],
        help="Data type (e.g., 'float16', 'bfloat16')",
    )
    parser.add_argument(
        "--quantization_scheme",
        type=str,
        default="fp8",
        choices=["int8", "fp8"],
        help="Quantization scheme",
    )
    parser.add_argument("--num_frames", type=int, default=81, help="Number of frames in the video")
    parser.add_argument("--fps", type=int, default=16, help="Frames per second for output video")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")

    args = parser.parse_args()

    # Map the textual dtype option onto the corresponding torch dtype.
    dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16

    generate_video(
        prompt=args.prompt,
        model_path=args.model_path,
        output_path=args.output_path,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        num_videos_per_prompt=args.num_videos_per_prompt,
        quantization_scheme=args.quantization_scheme,
        dtype=dtype,
        num_frames=args.num_frames,
        fps=args.fps,
        seed=args.seed,
    )
|