fix import order and deprecate for CVX 2B models

This commit is contained in:
LittleNyima 2025-02-26 15:54:58 +08:00
parent d6bb910697
commit 2c33c0982b

View File

@ -3,8 +3,17 @@ This script performs DDIM inversion for video frames using a pre-trained model a
a video reconstruction based on a provided prompt. It utilizes the CogVideoX pipeline to
process video frames, apply the DDIM inverse scheduler, and produce an output video.
**Please notice that this script is based on the CogVideoX 5B model, and would not generate
a good result for 2B variants.**
Usage:
python script.py --model-path /path/to/model --prompt "a prompt" --video-path /path/to/video.mp4 --output-path /path/to/output
python ddim_inversion.py
--model-path /path/to/model
--prompt "a prompt"
--video-path /path/to/video.mp4
--output-path /path/to/output
For more details about the cli arguments, please run `python ddim_inversion.py --help`.
Author:
LittleNyima <littlenyima[at]163[dot]com>
@ -15,7 +24,6 @@ import math
import os
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union, cast
import decord
import torch
import torch.nn.functional as F
import torchvision.transforms as T
@ -27,6 +35,10 @@ from diffusers.pipelines.cogvideo.pipeline_cogvideox import CogVideoXPipeline, r
from diffusers.schedulers import CogVideoXDDIMScheduler, DDIMInverseScheduler
from diffusers.utils import export_to_video
# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error.
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
import decord # isort: skip
class DDIMInversionArguments(TypedDict):
model_path: str
@ -399,6 +411,8 @@ def ddim_inversion(
device: torch.device,
):
pipeline: CogVideoXPipeline = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device=device)
if not pipeline.transformer.config.use_rotary_positional_embeddings:
raise NotImplementedError("This script supports CogVideoX 5B model only.")
video_frames = get_video_frames(
video_path=video_path,
width=width,