[bugfix] fix specific resolution setting

Different models use different resolutions. For example, the optimal generation resolution for the CogVideoX1.5 series models is 1360x768, whereas the best resolution for the CogVideoX series is 720x480.
This commit is contained in:
OleehyO 2024-12-18 07:38:10 +00:00
parent cfaca91cde
commit 92a589240f

View File

@ -17,9 +17,9 @@ $ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVide
Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths.
"""
import warnings
import logging
import argparse
from typing import Literal
from typing import Literal, Optional
import torch
from diffusers import (
@ -31,6 +31,20 @@ from diffusers import (
from diffusers.utils import export_to_video, load_image, load_video
logging.basicConfig(level=logging.INFO)
# Recommended output resolution as (width, height), keyed by model name.
# Models of the same series share one resolution, so each series is
# expanded from its list of model names.
RESOLUTION_MAP = {
    # cogvideox1.5-* series
    **dict.fromkeys(("cogvideox1.5-5b-i2v", "cogvideox1.5-5b"), (1360, 768)),
    # cogvideox-* series
    **dict.fromkeys(("cogvideox-5b-i2v", "cogvideox-5b", "cogvideox-2b"), (720, 480)),
}
def generate_video(
prompt: str,
@ -38,8 +52,8 @@ def generate_video(
lora_path: str = None,
lora_rank: int = 128,
num_frames: int = 81,
width: int = 1360,
height: int = 768,
width: Optional[int] = None,
height: Optional[int] = None,
output_path: str = "./output.mp4",
image_or_video_path: str = "",
num_inference_steps: int = 50,
@ -48,7 +62,7 @@ def generate_video(
dtype: torch.dtype = torch.bfloat16,
generate_type: str = Literal["t2v", "i2v", "v2v"], # i2v: image to video, v2v: video to video
seed: int = 42,
fps: int = 8,
fps: int = 16,
):
"""
Generates a video based on the given prompt and saves it to the specified path.
@ -78,10 +92,19 @@ def generate_video(
image = None
video = None
if (width != 1360 or height != 768) and "cogvideox1.5-5b-i2v" in model_path.lower():
warnings.warn(f"The width({width}) and height({height}) are not recommended for CogVideoX1.5-5B-I2V. The best resolution for CogVideoX1.5-5B-I2V is 1360x768.")
elif (width != 720 or height != 480) and "cogvideox-5b-i2v" in model_path.lower():
warnings.warn(f"The width({width}) and height({height}) are not recommended for CogVideo-5B-I2V. The best resolution for CogVideo-5B-I2V is 720x480.")
model_name = model_path.split("/")[-1].lower()
desired_resolution = RESOLUTION_MAP[model_name]
if width is None or height is None:
width, height = desired_resolution
logging.info(f"\033[1mUsing default resolution {desired_resolution} for {model_name}\033[0m")
elif (width, height) != desired_resolution:
if generate_type == "i2v":
# For i2v models, use user-defined width and height
logging.warning(f"\033[1;31mThe width({width}) and height({height}) are not recommended for {model_name}. The best resolution is {desired_resolution}.\033[0m")
else:
# Otherwise, use the recommended width and height
logging.warning(f"\033[1;31m{model_name} is not supported for custom resolution. Setting back to default resolution {desired_resolution}.\033[0m")
width, height = desired_resolution
if generate_type == "i2v":
pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype)
@ -132,6 +155,8 @@ def generate_video(
).frames[0]
elif generate_type == "t2v":
video_generate = pipe(
height=height,
width=width,
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
num_inference_steps=num_inference_steps,
@ -142,6 +167,8 @@ def generate_video(
).frames[0]
else:
video_generate = pipe(
height=height,
width=width,
prompt=prompt,
video=video, # The path of the video to be used as the background of the video
num_videos_per_prompt=num_videos_per_prompt,
@ -172,8 +199,8 @@ if __name__ == "__main__":
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
parser.add_argument("--num_frames", type=int, default=81, help="Number of steps for the inference process")
parser.add_argument("--width", type=int, default=1360, help="Number of steps for the inference process")
parser.add_argument("--height", type=int, default=768, help="Number of steps for the inference process")
# Width and height default to None so that the recommended resolution for the
# selected model is used when the user does not specify them.
parser.add_argument("--width", type=int, default=None, help="Width of the generated video in pixels; if width or height is omitted, the recommended resolution for the model is used")
parser.add_argument("--height", type=int, default=None, help="Height of the generated video in pixels; if width or height is omitted, the recommended resolution for the model is used")
parser.add_argument("--fps", type=int, default=16, help="The frames per second for the generated video")
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
parser.add_argument("--generate_type", type=str, default="t2v", help="The type of video generation")