mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
[bugfix] fix specific resolution setting
Different models use different resolutions; for example, for the CogVideoX1.5 series models the optimal generation resolution is 1360x768, but for CogVideoX the best resolution is 720x480.
This commit is contained in:
parent
cfaca91cde
commit
92a589240f
@ -17,9 +17,9 @@ $ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVide
|
||||
Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
import logging
|
||||
import argparse
|
||||
from typing import Literal
|
||||
from typing import Literal, Optional
|
||||
|
||||
import torch
|
||||
from diffusers import (
|
||||
@ -31,6 +31,20 @@ from diffusers import (
|
||||
|
||||
from diffusers.utils import export_to_video, load_image, load_video
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# Recommended generation resolution per model, keyed by the lowercased model
# directory name (the last path component of --model_path); values are
# (width, height) tuples.
RESOLUTION_MAP = {
    # CogVideoX1.5 family: best results at 1360x768.
    name: (1360, 768)
    for name in ("cogvideox1.5-5b-i2v", "cogvideox1.5-5b")
} | {
    # Original CogVideoX family: best results at 720x480.
    name: (720, 480)
    for name in ("cogvideox-5b-i2v", "cogvideox-5b", "cogvideox-2b")
}
|
||||
|
||||
|
||||
def generate_video(
|
||||
prompt: str,
|
||||
@ -38,8 +52,8 @@ def generate_video(
|
||||
lora_path: str = None,
|
||||
lora_rank: int = 128,
|
||||
num_frames: int = 81,
|
||||
width: int = 1360,
|
||||
height: int = 768,
|
||||
width: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
output_path: str = "./output.mp4",
|
||||
image_or_video_path: str = "",
|
||||
num_inference_steps: int = 50,
|
||||
@ -48,7 +62,7 @@ def generate_video(
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
generate_type: str = Literal["t2v", "i2v", "v2v"], # i2v: image to video, v2v: video to video
|
||||
seed: int = 42,
|
||||
fps: int = 8,
|
||||
fps: int = 16,
|
||||
):
|
||||
"""
|
||||
Generates a video based on the given prompt and saves it to the specified path.
|
||||
@ -78,10 +92,19 @@ def generate_video(
|
||||
image = None
|
||||
video = None
|
||||
|
||||
if (width != 1360 or height != 768) and "cogvideox1.5-5b-i2v" in model_path.lower():
|
||||
warnings.warn(f"The width({width}) and height({height}) are not recommended for CogVideoX1.5-5B-I2V. The best resolution for CogVideoX1.5-5B-I2V is 1360x768.")
|
||||
elif (width != 720 or height != 480) and "cogvideox-5b-i2v" in model_path.lower():
|
||||
warnings.warn(f"The width({width}) and height({height}) are not recommended for CogVideo-5B-I2V. The best resolution for CogVideo-5B-I2V is 720x480.")
|
||||
model_name = model_path.split("/")[-1].lower()
|
||||
desired_resolution = RESOLUTION_MAP[model_name]
|
||||
if width is None or height is None:
|
||||
width, height = desired_resolution
|
||||
logging.info(f"\033[1mUsing default resolution {desired_resolution} for {model_name}\033[0m")
|
||||
elif (width, height) != desired_resolution:
|
||||
if generate_type == "i2v":
|
||||
# For i2v models, use user-defined width and height
|
||||
logging.warning(f"\033[1;31mThe width({width}) and height({height}) are not recommended for {model_name}. The best resolution is {desired_resolution}.\033[0m")
|
||||
else:
|
||||
# Otherwise, use the recommended width and height
|
||||
logging.warning(f"\033[1;31m{model_name} is not supported for custom resolution. Setting back to default resolution {desired_resolution}.\033[0m")
|
||||
width, height = desired_resolution
|
||||
|
||||
if generate_type == "i2v":
|
||||
pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype)
|
||||
@ -132,6 +155,8 @@ def generate_video(
|
||||
).frames[0]
|
||||
elif generate_type == "t2v":
|
||||
video_generate = pipe(
|
||||
height=height,
|
||||
width=width,
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
num_inference_steps=num_inference_steps,
|
||||
@ -142,6 +167,8 @@ def generate_video(
|
||||
).frames[0]
|
||||
else:
|
||||
video_generate = pipe(
|
||||
height=height,
|
||||
width=width,
|
||||
prompt=prompt,
|
||||
video=video, # The path of the video to be used as the background of the video
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
@ -172,8 +199,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
|
||||
parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
|
||||
parser.add_argument("--num_frames", type=int, default=81, help="Number of steps for the inference process")
|
||||
parser.add_argument("--width", type=int, default=1360, help="Number of steps for the inference process")
|
||||
parser.add_argument("--height", type=int, default=768, help="Number of steps for the inference process")
|
||||
parser.add_argument("--width", type=int, default=None, help="Number of steps for the inference process")
|
||||
parser.add_argument("--height", type=int, default=None, help="Number of steps for the inference process")
|
||||
parser.add_argument("--fps", type=int, default=16, help="Number of steps for the inference process")
|
||||
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
|
||||
parser.add_argument("--generate_type", type=str, default="t2v", help="The type of video generation")
|
||||
|
Loading…
x
Reference in New Issue
Block a user