mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
fix convert
This commit is contained in:
parent
8e8275d2e8
commit
487a815219
@ -11,10 +11,30 @@ Run the script:
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import tempfile
|
||||
from typing import Union, List
|
||||
|
||||
import PIL
|
||||
import imageio
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
|
||||
def export_to_video_imageio(
    video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
) -> str:
    """
    Export video frames to a video file using imageio, to avoid the "green screen" issue (for example with CogVideoX).

    Args:
        video_frames: Non-empty list of frames, either numpy arrays or PIL images. When the first
            frame is a PIL image, every frame is converted to a numpy array before writing.
        output_video_path: Destination file. When None, a temporary ".mp4" file is created and its
            path is used; that file is NOT removed automatically, so the caller owns its cleanup.
        fps: Frame rate passed to the imageio writer.

    Returns:
        The path the video was written to.

    Raises:
        IndexError: If `video_frames` is empty.
    """
    # Inspect/convert the frames first so that invalid input fails before any temporary file is created.
    if isinstance(video_frames[0], PIL.Image.Image):
        video_frames = [np.array(frame) for frame in video_frames]

    if output_video_path is None:
        # delete=False keeps the file on disk after the handle is closed, so the name stays reserved
        # for us. Taking `.name` from an unreferenced NamedTemporaryFile would delete the file at once
        # and hand back a path that another process could claim.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
            output_video_path = tmp_file.name

    # The writer is used as a context manager so the output is finalized even if a frame fails to encode.
    with imageio.get_writer(output_video_path, fps=fps) as writer:
        for frame in video_frames:
            writer.append_data(frame)
    return output_video_path
|
||||
|
||||
|
||||
def generate_video(
|
||||
@ -43,7 +63,7 @@ def generate_video(
|
||||
|
||||
# Load the pre-trained CogVideoX pipeline with the specified precision (float16) and move it to the specified device
|
||||
pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
|
||||
pipe.enable_sequential_cpu_offload() # Enable sequential CPU offload to reduce GPU memory usage (at the cost of slower inference)
|
||||
|
||||
# Encode the prompt to get the prompt embeddings
|
||||
prompt_embeds, _ = pipe.encode_prompt(
|
||||
prompt=prompt, # The textual description for video generation
|
||||
@ -64,7 +84,7 @@ def generate_video(
|
||||
).frames[0]
|
||||
|
||||
# Export the generated frames to a video file. fps must be 8
|
||||
export_to_video(video, output_path, fps=8)
|
||||
export_to_video_imageio(video, output_path, fps=8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -36,16 +36,18 @@ def vae_demo(model_path, video_path, dtype, device):
|
||||
model = AutoencoderKLCogVideoX.from_pretrained(model_path, torch_dtype=dtype).to(device)
|
||||
|
||||
# Load video frames
|
||||
video_reader = imageio.get_reader(video_path, 'ffmpeg')
|
||||
video_reader = imageio.get_reader(video_path, "ffmpeg")
|
||||
frames = []
|
||||
for frame in video_reader:
|
||||
frames.append(frame)
|
||||
video_reader.close()
|
||||
|
||||
# Transform frames to Tensor
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
]
|
||||
)
|
||||
frames_tensor = torch.stack([transform(frame) for frame in frames]).to(device)
|
||||
|
||||
# Add batch dimension and reshape to [1, 3, 49, 480, 720]
|
||||
@ -84,9 +86,7 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert a CogVideoX model to Diffusers")
|
||||
parser.add_argument("--model_path", type=str, required=True, help="The path to the CogVideoX model")
|
||||
parser.add_argument("--video_path", type=str, required=True, help="The path to the video file")
|
||||
parser.add_argument(
|
||||
"--output_path", type=str, default="./", help="The path to save the output video"
|
||||
)
|
||||
parser.add_argument("--output_path", type=str, default="./", help="The path to save the output video")
|
||||
parser.add_argument(
|
||||
"--dtype", type=str, default="float16", help="The data type for computation (e.g., 'float16' or 'float32')"
|
||||
)
|
||||
@ -100,4 +100,4 @@ if __name__ == "__main__":
|
||||
dtype = torch.float16 if args.dtype == "float16" else torch.float32
|
||||
|
||||
output = vae_demo(args.model_path, args.video_path, dtype, device)
|
||||
save_video(output, args.output_path)
|
||||
save_video(output, args.output_path)
|
||||
|
@ -6,5 +6,5 @@ opencv-python>=4.10
|
||||
imageio-ffmpeg>=0.5.1
|
||||
openai>=1.38.0
|
||||
transformers>=4.43.3
|
||||
accelerate>=0.33.0
|
||||
pillow==9.5.0
|
||||
sentencepiece>=0.2.0
|
||||
pillow==9.5.0
|
Loading…
x
Reference in New Issue
Block a user