CogVideo/inference/cli_demo_quantization.py

144 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
This script demonstrates how to generate a video from a text prompt using CogVideoX with quantization.
Note:
Must install the `torchao``torch`,`diffusers`,`accelerate` library FROM SOURCE to use the quantization feature.
Only NVIDIA GPUs like H100 or higher are supported om FP-8 quantization.
ALL quantization schemes must use with NVIDIA GPUs.
# Run the script:
python cli_demo_quantization.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX-2b --quantization_scheme fp8 --dtype float16
python cli_demo_quantization.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX-5b --quantization_scheme fp8 --dtype bfloat16
"""
import argparse
import os
import torch
import torch._dynamo
from diffusers import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, CogVideoXPipeline, CogVideoXDPMScheduler
from diffusers.utils import export_to_video
from transformers import T5EncoderModel
from torchao.quantization import quantize_, int8_weight_only
from torchao.float8.inference import ActivationCasting, QuantConfig, quantize_to_float8
os.environ["TORCH_LOGS"] = "+dynamo,output_code,graph_breaks,recompiles"
torch._dynamo.config.suppress_errors = True
torch.set_float32_matmul_precision("high")
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
def quantize_model(part, quantization_scheme):
if quantization_scheme == "int8":
quantize_(part, int8_weight_only())
elif quantization_scheme == "fp8":
quantize_to_float8(part, QuantConfig(ActivationCasting.DYNAMIC))
return part
def generate_video(
prompt: str,
model_path: str,
output_path: str = "./output.mp4",
num_inference_steps: int = 50,
guidance_scale: float = 6.0,
num_videos_per_prompt: int = 1,
quantization_scheme: str = "fp8",
dtype: torch.dtype = torch.bfloat16,
):
"""
Generates a video based on the given prompt and saves it to the specified path.
Parameters:
- prompt (str): The description of the video to be generated.
- model_path (str): The path of the pre-trained model to be used.
- output_path (str): The path where the generated video will be saved.
- num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
- guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
- num_videos_per_prompt (int): Number of videos to generate per prompt.
- quantization_scheme (str): The quantization scheme to use ('int8', 'fp8').
- dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
"""
text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=dtype)
text_encoder = quantize_model(part=text_encoder, quantization_scheme=quantization_scheme)
transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=dtype)
transformer = quantize_model(part=transformer, quantization_scheme=quantization_scheme)
vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae", torch_dtype=dtype)
vae = quantize_model(part=vae, quantization_scheme=quantization_scheme)
pipe = CogVideoXPipeline.from_pretrained(
model_path,
text_encoder=text_encoder,
transformer=transformer,
vae=vae,
torch_dtype=dtype,
)
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
# Using with compile will run faster. First time infer will cost ~30min to compile.
# pipe.transformer.to(memory_format=torch.channels_last)
# for FP8 should remove pipe.enable_model_cpu_offload()
pipe.enable_model_cpu_offload()
# This is not for FP8 and INT8 and should remove this line
# pipe.enable_sequential_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
video = pipe(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
num_inference_steps=num_inference_steps,
num_frames=49,
use_dynamic_cfg=True,
guidance_scale=guidance_scale,
generator=torch.Generator(device="cuda").manual_seed(42),
).frames[0]
export_to_video(video, output_path, fps=8)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
parser.add_argument(
"--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used"
)
parser.add_argument(
"--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
)
parser.add_argument(
"--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
)
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
parser.add_argument(
"--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16', 'bfloat16')"
)
parser.add_argument(
"--quantization_scheme",
type=str,
default="bf16",
choices=["int8", "fp8"],
help="The quantization scheme to use (int8, fp8)",
)
args = parser.parse_args()
dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16
generate_video(
prompt=args.prompt,
model_path=args.model_path,
output_path=args.output_path,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
num_videos_per_prompt=args.num_videos_per_prompt,
quantization_scheme=args.quantization_scheme,
dtype=dtype,
)