diff --git a/inference/ddim_inversion.py b/inference/ddim_inversion.py index 7be35b3..e932bf4 100644 --- a/inference/ddim_inversion.py +++ b/inference/ddim_inversion.py @@ -3,8 +3,17 @@ This script performs DDIM inversion for video frames using a pre-trained model a a video reconstruction based on a provided prompt. It utilizes the CogVideoX pipeline to process video frames, apply the DDIM inverse scheduler, and produce an output video. +**Please note that this script is based on the CogVideoX 5B model and will not generate +good results for the 2B variants.** + Usage: - python script.py --model-path /path/to/model --prompt "a prompt" --video-path /path/to/video.mp4 --output-path /path/to/output + python ddim_inversion.py + --model-path /path/to/model + --prompt "a prompt" + --video-path /path/to/video.mp4 + --output-path /path/to/output + +For more details about the CLI arguments, please run `python ddim_inversion.py --help`. Author: LittleNyima @@ -15,7 +24,6 @@ import math import os from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union, cast -import decord import torch import torch.nn.functional as F import torchvision.transforms as T @@ -27,6 +35,10 @@ from diffusers.pipelines.cogvideo.pipeline_cogvideox import CogVideoXPipeline, r from diffusers.schedulers import CogVideoXDDIMScheduler, DDIMInverseScheduler from diffusers.utils import export_to_video +# decord must be imported after torch; importing it first can sometimes lead to a nasty segmentation fault or +# stack smashing error. There are very few bug reports, but it happens. See the decord GitHub issues for details. 
+import decord # isort: skip + class DDIMInversionArguments(TypedDict): model_path: str @@ -399,6 +411,8 @@ def ddim_inversion( device: torch.device, ): pipeline: CogVideoXPipeline = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device=device) + if not pipeline.transformer.config.use_rotary_positional_embeddings: + raise NotImplementedError("This script supports CogVideoX 5B model only.") video_frames = get_video_frames( video_path=video_path, width=width,