diff --git a/inference/cli_demo.py b/inference/cli_demo.py
index ea8b4fc..f4dbc28 100644
--- a/inference/cli_demo.py
+++ b/inference/cli_demo.py
@@ -17,9 +17,9 @@ $ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVide
 Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths.
 """
 
-import warnings
+import logging
 import argparse
-from typing import Literal
+from typing import Literal, Optional
 
 import torch
 from diffusers import (
@@ -31,6 +31,20 @@ from diffusers import (
 )
 from diffusers.utils import export_to_video, load_image, load_video
 
+logging.basicConfig(level=logging.INFO)
+
+# Recommended resolution for each model (width, height)
+RESOLUTION_MAP = {
+    # cogvideox1.5-*
+    "cogvideox1.5-5b-i2v": (1360, 768),
+    "cogvideox1.5-5b": (1360, 768),
+
+    # cogvideox-*
+    "cogvideox-5b-i2v": (720, 480),
+    "cogvideox-5b": (720, 480),
+    "cogvideox-2b": (720, 480),
+}
+
 
 def generate_video(
     prompt: str,
@@ -38,8 +52,8 @@ def generate_video(
     lora_path: str = None,
     lora_rank: int = 128,
     num_frames: int = 81,
-    width: int = 1360,
-    height: int = 768,
+    width: Optional[int] = None,
+    height: Optional[int] = None,
     output_path: str = "./output.mp4",
     image_or_video_path: str = "",
     num_inference_steps: int = 50,
@@ -48,7 +62,7 @@
     dtype: torch.dtype = torch.bfloat16,
     generate_type: str = Literal["t2v", "i2v", "v2v"],  # i2v: image to video, v2v: video to video
     seed: int = 42,
-    fps: int = 8,
+    fps: int = 16,
 ):
     """
     Generates a video based on the given prompt and saves it to the specified path.
@@ -78,10 +92,19 @@ def generate_video(
     image = None
     video = None
 
-    if (width != 1360 or height != 768) and "cogvideox1.5-5b-i2v" in model_path.lower():
-        warnings.warn(f"The width({width}) and height({height}) are not recommended for CogVideoX1.5-5B-I2V. The best resolution for CogVideoX1.5-5B-I2V is 1360x768.")
-    elif (width != 720 or height != 480) and "cogvideox-5b-i2v" in model_path.lower():
-        warnings.warn(f"The width({width}) and height({height}) are not recommended for CogVideo-5B-I2V. The best resolution for CogVideo-5B-I2V is 720x480.")
+    model_name = model_path.split("/")[-1].lower()
+    desired_resolution = RESOLUTION_MAP[model_name]
+    if width is None or height is None:
+        width, height = desired_resolution
+        logging.info(f"\033[1mUsing default resolution {desired_resolution} for {model_name}\033[0m")
+    elif (width, height) != desired_resolution:
+        if generate_type == "i2v":
+            # For i2v models, warn but keep the user-defined width and height
+            logging.warning(f"\033[1;31mThe width({width}) and height({height}) are not recommended for {model_name}. The best resolution is {desired_resolution}.\033[0m")
+        else:
+            # Otherwise, fall back to the recommended width and height
+            logging.warning(f"\033[1;31m{model_name} does not support custom resolution. Setting back to default resolution {desired_resolution}.\033[0m")
+            width, height = desired_resolution
 
     if generate_type == "i2v":
         pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype)
@@ -132,6 +155,8 @@ def generate_video(
         ).frames[0]
     elif generate_type == "t2v":
         video_generate = pipe(
+            height=height,
+            width=width,
             prompt=prompt,
             num_videos_per_prompt=num_videos_per_prompt,
             num_inference_steps=num_inference_steps,
@@ -142,6 +167,8 @@
         ).frames[0]
     else:
         video_generate = pipe(
+            height=height,
+            width=width,
             prompt=prompt,
             video=video,  # The path of the video to be used as the background of the video
             num_videos_per_prompt=num_videos_per_prompt,
@@ -172,8 +199,8 @@ if __name__ == "__main__":
     parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
     parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
     parser.add_argument("--num_frames", type=int, default=81, help="Number of steps for the inference process")
-    parser.add_argument("--width", type=int, default=1360, help="Number of steps for the inference process")
-    parser.add_argument("--height", type=int, default=768, help="Number of steps for the inference process")
+    parser.add_argument("--width", type=int, default=None, help="The width of the generated video; defaults to the model's recommended resolution")
+    parser.add_argument("--height", type=int, default=None, help="The height of the generated video; defaults to the model's recommended resolution")
    parser.add_argument("--fps", type=int, default=16, help="Number of steps for the inference process")
     parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
     parser.add_argument("--generate_type", type=str, default="t2v", help="The type of video generation")
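
A minimal sketch of the lookup the patch introduces, for reference: the recommended resolution is resolved from the basename of model_path. The recommended_resolution helper and its (720, 480) fallback are illustrative assumptions; the patch itself indexes RESOLUTION_MAP directly and would raise KeyError for a model name missing from the table.

# Illustrative sketch, not part of the patch.
RESOLUTION_MAP = {
    "cogvideox1.5-5b-i2v": (1360, 768),
    "cogvideox1.5-5b": (1360, 768),
    "cogvideox-5b-i2v": (720, 480),
    "cogvideox-5b": (720, 480),
    "cogvideox-2b": (720, 480),
}

def recommended_resolution(model_path: str) -> tuple:
    # Basename of the Hugging Face repo id, e.g. "THUDM/CogVideoX1.5-5B" -> "cogvideox1.5-5b"
    model_name = model_path.split("/")[-1].lower()
    # The (720, 480) fallback is an assumption; the patched script raises KeyError for unknown names.
    return RESOLUTION_MAP.get(model_name, (720, 480))

print(recommended_resolution("THUDM/CogVideoX1.5-5B-I2V"))  # -> (1360, 768)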