Mirror of https://github.com/THUDM/CogVideo.git

fix convert
Commit 487a815219, parent 8e8275d2e8
@@ -11,10 +11,30 @@ Run the script:
 """
 
 import argparse
+import tempfile
+from typing import Union, List
 
+import PIL
+import imageio
+import numpy as np
 import torch
 from diffusers import CogVideoXPipeline
-from diffusers.utils import export_to_video
 
 
+def export_to_video_imageio(
+    video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
+) -> str:
+    """
+    Export the video frames to a video file using the imageio library, to avoid the "green screen" issue (for example with CogVideoX)
+    """
+    if output_video_path is None:
+        output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
+    if isinstance(video_frames[0], PIL.Image.Image):
+        video_frames = [np.array(frame) for frame in video_frames]
+    with imageio.get_writer(output_video_path, fps=fps) as writer:
+        for frame in video_frames:
+            writer.append_data(frame)
+    return output_video_path
+
+
 def generate_video(
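For reference, a minimal usage sketch of the helper added above, assuming export_to_video_imageio is in scope (e.g. imported from the demo script); the frame data and output filename are placeholder values, not part of the commit:

    import numpy as np

    # Placeholder frames: 49 black 480x720 RGB images. A real call would pass
    # the list returned by pipe(...).frames[0] (PIL images or numpy arrays).
    frames = [np.zeros((480, 720, 3), dtype=np.uint8) for _ in range(49)]

    # fps stays at 8, matching the "fps must be 8" comment in the hunk below.
    out_path = export_to_video_imageio(frames, "output.mp4", fps=8)
    print(out_path)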
@@ -43,7 +63,7 @@ def generate_video(
 
     # Load the pre-trained CogVideoX pipeline with the specified precision (float16) and move it to the specified device
     pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
-    pipe.enable_sequential_cpu_offload()  # Enable sequential CPU offload for faster inference
+
     # Encode the prompt to get the prompt embeddings
     prompt_embeds, _ = pipe.encode_prompt(
         prompt=prompt,  # The textual description for video generation
@@ -64,7 +84,7 @@ def generate_video(
     ).frames[0]
 
     # Export the generated frames to a video file. fps must be 8
-    export_to_video(video, output_path, fps=8)
+    export_to_video_imageio(video, output_path, fps=8)
 
 
 if __name__ == "__main__":
@@ -36,16 +36,18 @@ def vae_demo(model_path, video_path, dtype, device):
     model = AutoencoderKLCogVideoX.from_pretrained(model_path, torch_dtype=dtype).to(device)
 
     # Load video frames
-    video_reader = imageio.get_reader(video_path, 'ffmpeg')
+    video_reader = imageio.get_reader(video_path, "ffmpeg")
     frames = []
     for frame in video_reader:
         frames.append(frame)
     video_reader.close()
 
     # Transform frames to Tensor
-    transform = transforms.Compose([
-        transforms.ToTensor(),
-    ])
+    transform = transforms.Compose(
+        [
+            transforms.ToTensor(),
+        ]
+    )
     frames_tensor = torch.stack([transform(frame) for frame in frames]).to(device)
 
     # Add batch dimension and reshape to [1, 3, 49, 480, 720]
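As a side note, a small sketch of the reshape that the "[1, 3, 49, 480, 720]" comment above refers to, assuming 49 RGB frames at 480x720; the tensor here is synthetic and the permute/unsqueeze calls are an illustration, not lines from this commit:

    import torch

    # Stand-in for the stacked ToTensor output: [num_frames, channels, height, width]
    frames_tensor = torch.rand(49, 3, 480, 720)

    # Move channels ahead of the frame axis, then add a batch dimension:
    # [49, 3, 480, 720] -> [3, 49, 480, 720] -> [1, 3, 49, 480, 720]
    frames_tensor = frames_tensor.permute(1, 0, 2, 3).unsqueeze(0)
    print(frames_tensor.shape)  # torch.Size([1, 3, 49, 480, 720])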
@@ -84,9 +86,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert a CogVideoX model to Diffusers")
     parser.add_argument("--model_path", type=str, required=True, help="The path to the CogVideoX model")
     parser.add_argument("--video_path", type=str, required=True, help="The path to the video file")
-    parser.add_argument(
-        "--output_path", type=str, default="./", help="The path to save the output video"
-    )
+    parser.add_argument("--output_path", type=str, default="./", help="The path to save the output video")
     parser.add_argument(
         "--dtype", type=str, default="float16", help="The data type for computation (e.g., 'float16' or 'float32')"
     )
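The --dtype flag is parsed as a string; a minimal sketch, under the assumption that the script maps it to a torch dtype before calling from_pretrained (the mapping below is illustrative, not shown in this hunk):

    import torch

    # Translate the --dtype string (e.g. args.dtype) into a torch dtype.
    dtype_str = "float16"
    dtype = torch.float16 if dtype_str == "float16" else torch.float32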
@@ -6,5 +6,5 @@ opencv-python>=4.10
 imageio-ffmpeg>=0.5.1
 openai>=1.38.0
 transformers>=4.43.3
-accelerate>=0.33.0
+sentencepiece>=0.2.0
 pillow==9.5.0