Merge pull request #615 from THUDM/CogVideoX_dev

Cog video x dev
This commit is contained in:
Yuxuan.Zhang 2024-12-19 12:57:56 +08:00 committed by GitHub
commit 7935bd58a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -17,9 +17,9 @@ $ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVide
Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths.
"""
import warnings
import logging
import argparse
from typing import Literal
from typing import Literal, Optional
import torch
from diffusers import (
@ -31,6 +31,20 @@ from diffusers import (
from diffusers.utils import export_to_video, load_image, load_video
logging.basicConfig(level=logging.INFO)
# Recommended resolution for each model (width, height).
# Keys are the lower-cased final path component of --model_path
# (looked up below via model_path.split("/")[-1].lower()).
# NOTE(review): an unlisted model name will raise KeyError at lookup —
# confirm whether a fallback default is intended.
RESOLUTION_MAP = {
# cogvideox1.5-*
"cogvideox1.5-5b-i2v": (1360, 768),
"cogvideox1.5-5b": (1360, 768),
# cogvideox-*
"cogvideox-5b-i2v": (720, 480),
"cogvideox-5b": (720, 480),
"cogvideox-2b": (720, 480),
}
def generate_video(
prompt: str,
@ -38,8 +52,8 @@ def generate_video(
lora_path: str = None,
lora_rank: int = 128,
num_frames: int = 81,
width: int = 1360,
height: int = 768,
width: Optional[int] = None,
height: Optional[int] = None,
output_path: str = "./output.mp4",
image_or_video_path: str = "",
num_inference_steps: int = 50,
@ -48,7 +62,7 @@ def generate_video(
dtype: torch.dtype = torch.bfloat16,
generate_type: str = Literal["t2v", "i2v", "v2v"], # i2v: image to video, v2v: video to video
seed: int = 42,
fps: int = 8,
fps: int = 16,
):
"""
Generates a video based on the given prompt and saves it to the specified path.
@ -78,10 +92,19 @@ def generate_video(
image = None
video = None
if (width != 1360 or height != 768) and "cogvideox1.5-5b-i2v" in model_path.lower():
warnings.warn(f"The width({width}) and height({height}) are not recommended for CogVideoX1.5-5B-I2V. The best resolution for CogVideoX1.5-5B-I2V is 1360x768.")
elif (width != 720 or height != 480) and "cogvideox-5b-i2v" in model_path.lower():
warnings.warn(f"The width({width}) and height({height}) are not recommended for CogVideo-5B-I2V. The best resolution for CogVideo-5B-I2V is 720x480.")
model_name = model_path.split("/")[-1].lower()
desired_resolution = RESOLUTION_MAP[model_name]
if width is None or height is None:
width, height = desired_resolution
logging.info(f"\033[1mUsing default resolution {desired_resolution} for {model_name}\033[0m")
elif (width, height) != desired_resolution:
if generate_type == "i2v":
# For i2v models, use user-defined width and height
logging.warning(f"\033[1;31mThe width({width}) and height({height}) are not recommended for {model_name}. The best resolution is {desired_resolution}.\033[0m")
else:
# Otherwise, use the recommended width and height
logging.warning(f"\033[1;31m{model_name} is not supported for custom resolution. Setting back to default resolution {desired_resolution}.\033[0m")
width, height = desired_resolution
if generate_type == "i2v":
pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype)
@ -132,6 +155,8 @@ def generate_video(
).frames[0]
elif generate_type == "t2v":
video_generate = pipe(
height=height,
width=width,
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
num_inference_steps=num_inference_steps,
@ -142,6 +167,8 @@ def generate_video(
).frames[0]
else:
video_generate = pipe(
height=height,
width=width,
prompt=prompt,
video=video, # The path of the video to be used as the background of the video
num_videos_per_prompt=num_videos_per_prompt,
@ -172,9 +199,9 @@ if __name__ == "__main__":
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
parser.add_argument("--num_frames", type=int, default=81, help="Number of steps for the inference process")
parser.add_argument("--width", type=int, default=1360, help="Number of steps for the inference process")
parser.add_argument("--height", type=int, default=768, help="Number of steps for the inference process")
parser.add_argument("--fps", type=int, default=16, help="Number of steps for the inference process")
parser.add_argument("--width", type=int, default=None, help="The width of the generated video")
parser.add_argument("--height", type=int, default=None, help="The height of the generated video")
parser.add_argument("--fps", type=int, default=16, help="The frames per second for the generated video")
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
parser.add_argument("--generate_type", type=str, default="t2v", help="The type of video generation")
parser.add_argument("--dtype", type=str, default="bfloat16", help="The data type for computation")