add width and height

This commit is contained in:
zR 2024-11-12 00:17:19 +08:00
parent 2151a3bdfb
commit bb2cb130a0

View File

@ -37,6 +37,8 @@ def generate_video(
lora_path: str = None,
lora_rank: int = 128,
num_frames: int = 81,
width: int = 1360,
height: int = 768,
output_path: str = "./output.mp4",
image_or_video_path: str = "",
num_inference_steps: int = 50,
@ -58,6 +60,8 @@ def generate_video(
- output_path (str): The path where the generated video will be saved.
- num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
- num_frames (int): Number of frames to generate. CogVideoX1.0 generates 49 frames for 6 seconds at 8 fps, while CogVideoX1.5 produces either 81 or 161 frames, corresponding to 5 seconds or 10 seconds at 16 fps.
- width (int): The width of the generated video, applicable only for CogVideoX1.5-5B-I2V
- height (int): The height of the generated video, applicable only for CogVideoX1.5-5B-I2V
- guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
- num_videos_per_prompt (int): Number of videos to generate per prompt.
- dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
@ -111,8 +115,11 @@ def generate_video(
# This is the default value for 6 seconds video and 8 fps and will plus 1 frame for the first frame and 49 frames.
if generate_type == "i2v":
video_generate = pipe(
height=height,
width=width,
prompt=prompt,
image=image, # The path of the image, the resolution of video will be the same as the image for CogVideoX1.5-5B-I2V, otherwise it will be 720 * 480
image=image,
# The path of the image, the resolution of video will be the same as the image for CogVideoX1.5-5B-I2V, otherwise it will be 720 * 480
num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt
num_inference_steps=num_inference_steps, # Number of inference steps
num_frames=num_frames, # Number of frames to generate
@ -162,6 +169,8 @@ if __name__ == "__main__":
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
parser.add_argument("--num_frames", type=int, default=81, help="Number of steps for the inference process")
parser.add_argument("--width", type=int, default=1360, help="Number of steps for the inference process")
parser.add_argument("--height", type=int, default=768, help="Number of steps for the inference process")
parser.add_argument("--fps", type=int, default=16, help="Number of steps for the inference process")
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
parser.add_argument("--generate_type", type=str, default="t2v", help="The type of video generation")
@ -177,6 +186,8 @@ if __name__ == "__main__":
lora_rank=args.lora_rank,
output_path=args.output_path,
num_frames=args.num_frames,
width=args.width,
height=args.height,
image_or_video_path=args.image_or_video_path,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,