mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-06 03:57:56 +08:00
commit
7935bd58a1
@ -17,9 +17,9 @@ $ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVide
|
||||
Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
import logging
|
||||
import argparse
|
||||
from typing import Literal
|
||||
from typing import Literal, Optional
|
||||
|
||||
import torch
|
||||
from diffusers import (
|
||||
@ -31,6 +31,20 @@ from diffusers import (
|
||||
|
||||
from diffusers.utils import export_to_video, load_image, load_video
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# Recommended resolution for each model (width, height)
|
||||
RESOLUTION_MAP = {
|
||||
# cogvideox1.5-*
|
||||
"cogvideox1.5-5b-i2v": (1360, 768),
|
||||
"cogvideox1.5-5b": (1360, 768),
|
||||
|
||||
# cogvideox-*
|
||||
"cogvideox-5b-i2v": (720, 480),
|
||||
"cogvideox-5b": (720, 480),
|
||||
"cogvideox-2b": (720, 480),
|
||||
}
|
||||
|
||||
|
||||
def generate_video(
|
||||
prompt: str,
|
||||
@ -38,8 +52,8 @@ def generate_video(
|
||||
lora_path: str = None,
|
||||
lora_rank: int = 128,
|
||||
num_frames: int = 81,
|
||||
width: int = 1360,
|
||||
height: int = 768,
|
||||
width: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
output_path: str = "./output.mp4",
|
||||
image_or_video_path: str = "",
|
||||
num_inference_steps: int = 50,
|
||||
@ -48,7 +62,7 @@ def generate_video(
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
generate_type: str = Literal["t2v", "i2v", "v2v"], # i2v: image to video, v2v: video to video
|
||||
seed: int = 42,
|
||||
fps: int = 8,
|
||||
fps: int = 16,
|
||||
):
|
||||
"""
|
||||
Generates a video based on the given prompt and saves it to the specified path.
|
||||
@ -78,10 +92,19 @@ def generate_video(
|
||||
image = None
|
||||
video = None
|
||||
|
||||
if (width != 1360 or height != 768) and "cogvideox1.5-5b-i2v" in model_path.lower():
|
||||
warnings.warn(f"The width({width}) and height({height}) are not recommended for CogVideoX1.5-5B-I2V. The best resolution for CogVideoX1.5-5B-I2V is 1360x768.")
|
||||
elif (width != 720 or height != 480) and "cogvideox-5b-i2v" in model_path.lower():
|
||||
warnings.warn(f"The width({width}) and height({height}) are not recommended for CogVideo-5B-I2V. The best resolution for CogVideo-5B-I2V is 720x480.")
|
||||
model_name = model_path.split("/")[-1].lower()
|
||||
desired_resolution = RESOLUTION_MAP[model_name]
|
||||
if width is None or height is None:
|
||||
width, height = desired_resolution
|
||||
logging.info(f"\033[1mUsing default resolution {desired_resolution} for {model_name}\033[0m")
|
||||
elif (width, height) != desired_resolution:
|
||||
if generate_type == "i2v":
|
||||
# For i2v models, use user-defined width and height
|
||||
logging.warning(f"\033[1;31mThe width({width}) and height({height}) are not recommended for {model_name}. The best resolution is {desired_resolution}.\033[0m")
|
||||
else:
|
||||
# Otherwise, use the recommended width and height
|
||||
logging.warning(f"\033[1;31m{model_name} is not supported for custom resolution. Setting back to default resolution {desired_resolution}.\033[0m")
|
||||
width, height = desired_resolution
|
||||
|
||||
if generate_type == "i2v":
|
||||
pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype)
|
||||
@ -132,6 +155,8 @@ def generate_video(
|
||||
).frames[0]
|
||||
elif generate_type == "t2v":
|
||||
video_generate = pipe(
|
||||
height=height,
|
||||
width=width,
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
num_inference_steps=num_inference_steps,
|
||||
@ -142,6 +167,8 @@ def generate_video(
|
||||
).frames[0]
|
||||
else:
|
||||
video_generate = pipe(
|
||||
height=height,
|
||||
width=width,
|
||||
prompt=prompt,
|
||||
video=video, # The source video to transform in video-to-video (v2v) generation
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
@ -172,9 +199,9 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
|
||||
parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
|
||||
parser.add_argument("--num_frames", type=int, default=81, help="Number of steps for the inference process")
|
||||
parser.add_argument("--width", type=int, default=1360, help="Number of steps for the inference process")
|
||||
parser.add_argument("--height", type=int, default=768, help="Number of steps for the inference process")
|
||||
parser.add_argument("--fps", type=int, default=16, help="Number of steps for the inference process")
|
||||
parser.add_argument("--width", type=int, default=None, help="The width of the generated video")
|
||||
parser.add_argument("--height", type=int, default=None, help="The height of the generated video")
|
||||
parser.add_argument("--fps", type=int, default=16, help="The frames per second for the generated video")
|
||||
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
|
||||
parser.add_argument("--generate_type", type=str, default="t2v", help="The type of video generation")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="The data type for computation")
|
||||
|
Loading…
x
Reference in New Issue
Block a user