Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-09-18 03:10:00 +08:00)

Commit: 098640337d
Parent: 70ff60c925
Message: update I2V infer code and draft readme
@@ -21,26 +21,28 @@ import argparse
 from typing import Literal
 
 import torch
-from diffusers import (CogVideoXPipeline,
-                       CogVideoXDDIMScheduler,
-                       CogVideoXDPMScheduler,
-                       CogVideoXImageToVideoPipeline,
-                       CogVideoXVideoToVideoPipeline)
+from diffusers import (
+    CogVideoXPipeline,
+    CogVideoXDDIMScheduler,
+    CogVideoXDPMScheduler,
+    CogVideoXImageToVideoPipeline,
+    CogVideoXVideoToVideoPipeline,
+)
 
 from diffusers.utils import export_to_video, load_image, load_video
 
 
 def generate_video(
     prompt: str,
     model_path: str,
     output_path: str = "./output.mp4",
     image_or_video_path: str = "",
     num_inference_steps: int = 50,
     guidance_scale: float = 6.0,
     num_videos_per_prompt: int = 1,
     dtype: torch.dtype = torch.bfloat16,
     generate_type: str = Literal["t2v", "i2v", "v2v"],  # i2v: image to video, v2v: video to video
     seed: int = 42,
 ):
     """
     Generates a video based on the given prompt and saves it to the specified path.
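For orientation, the three pipeline classes imported above correspond one-to-one to the `generate_type` values used later in this file. The following is only a rough sketch of that mapping, not the file's actual construction logic; the checkpoint passed as `model_path` is whatever you use locally (the public hub ids are assumptions here):

```python
import torch
from diffusers import (
    CogVideoXImageToVideoPipeline,
    CogVideoXPipeline,
    CogVideoXVideoToVideoPipeline,
)


def load_pipeline(generate_type: str, model_path: str, dtype: torch.dtype = torch.bfloat16):
    # Illustrative mapping only: t2v -> text-to-video, i2v -> image-to-video, v2v -> video-to-video.
    if generate_type == "i2v":
        # e.g. "THUDM/CogVideoX-5b-I2V" (assumption, adjust to your checkpoint)
        return CogVideoXImageToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype)
    elif generate_type == "v2v":
        return CogVideoXVideoToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype)
    # default: plain text-to-video, e.g. "THUDM/CogVideoX-5b"
    return CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype)
```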
@@ -53,7 +55,7 @@ def generate_video(
     - guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
     - num_videos_per_prompt (int): Number of videos to generate per prompt.
     - dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
-    - generate_type (str): The type of video generation (e.g., 't2v', 'i2v', 'v2v').
+    - generate_type (str): The type of video generation (e.g., 't2v', 'i2v', 'v2v'). 
     - seed (int): The seed for reproducibility.
     """
 
@@ -97,13 +99,13 @@ def generate_video(
             image=image,  # The path of the image to be used as the background of the video
             num_videos_per_prompt=num_videos_per_prompt,  # Number of videos to generate per prompt
             num_inference_steps=num_inference_steps,  # Number of inference steps
-            num_frames=49,  # Number of frames to generate,changed to 49 for diffusers version `0.31.0` and after.
+            num_frames=49,  # Number of frames to generate,changed to 49 for diffusers version `0.30.3` and after.
             use_dynamic_cfg=True,  ## This id used for DPM Sechduler, for DDIM scheduler, it should be False
             guidance_scale=guidance_scale,
             generator=torch.Generator().manual_seed(seed),  # Set the seed for reproducibility
         ).frames[0]
     elif generate_type == "t2v":
         video_generate = pipe(
             prompt=prompt,
             num_videos_per_prompt=num_videos_per_prompt,
             num_inference_steps=num_inference_steps,
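The `use_dynamic_cfg=True` flag in this call is meant to be paired with the DPM scheduler; with the DDIM scheduler it should stay `False`, as the inline comment notes. A minimal sketch of the two scheduler setups (a hedged example, assuming a `pipe` built as above; `timestep_spacing="trailing"` mirrors the usual CogVideoX examples and is an assumption here):

```python
from diffusers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler


def configure_scheduler(pipe, use_dpm: bool = True) -> bool:
    """Swap the scheduler and return the matching use_dynamic_cfg flag."""
    if use_dpm:
        # Dynamic classifier-free guidance is intended for the DPM scheduler.
        pipe.scheduler = CogVideoXDPMScheduler.from_config(
            pipe.scheduler.config, timestep_spacing="trailing"
        )
        return True
    # With DDIM, dynamic CFG should be turned off.
    pipe.scheduler = CogVideoXDDIMScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing"
    )
    return False
```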
@@ -130,19 +132,29 @@ def generate_video(
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
     parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
-    parser.add_argument("--image_or_video_path", type=str, default=None,
-                        help="The path of the image to be used as the background of the video")
-    parser.add_argument("--model_path", type=str, default="THUDM/CogVideoX-5b",
-                        help="The path of the pre-trained model to be used")
-    parser.add_argument("--output_path", type=str, default="./output.mp4",
-                        help="The path where the generated video will be saved")
+    parser.add_argument(
+        "--image_or_video_path",
+        type=str,
+        default=None,
+        help="The path of the image to be used as the background of the video",
+    )
+    parser.add_argument(
+        "--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used"
+    )
+    parser.add_argument(
+        "--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
+    )
     parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
-    parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of steps for the inference process")
+    parser.add_argument(
+        "--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
+    )
     parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
-    parser.add_argument("--generate_type", type=str, default="t2v",
-                        help="The type of video generation (e.g., 't2v', 'i2v', 'v2v')")
-    parser.add_argument("--dtype", type=str, default="bfloat16",
-                        help="The data type for computation (e.g., 'float16' or 'bfloat16')")
+    parser.add_argument(
+        "--generate_type", type=str, default="t2v", help="The type of video generation (e.g., 't2v', 'i2v', 'v2v')"
+    )
+    parser.add_argument(
+        "--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')"
+    )
     parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
 
     args = parser.parse_args()
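The parsed arguments presumably feed `generate_video` just after `args = parser.parse_args()`; a sketch of that wiring, including mapping the `--dtype` string onto a torch dtype (the exact lines are not part of this hunk, so treat this as an approximation):

```python
    dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16
    generate_video(
        prompt=args.prompt,
        model_path=args.model_path,
        output_path=args.output_path,
        image_or_video_path=args.image_or_video_path,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        num_videos_per_prompt=args.num_videos_per_prompt,
        dtype=dtype,
        generate_type=args.generate_type,
        seed=args.seed,
    )
```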
@@ -8,7 +8,7 @@ import numpy as np
 import logging
 import skvideo.io
 from rife.RIFE_HDv3 import Model
-from huggingface_hub import hf_hub_download, snapshot_download
 
 logger = logging.getLogger(__name__)
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -19,9 +19,8 @@ def pad_image(img, scale):
     tmp = max(32, int(32 / scale))
     ph = ((h - 1) // tmp + 1) * tmp
     pw = ((w - 1) // tmp + 1) * tmp
-    padding = (0, pw - w, 0, ph - h)
-    return F.pad(img, padding), padding
+    padding = (0, 0, pw - w, ph - h)
+    return F.pad(img, padding)
 
 
 def make_inference(model, I0, I1, upscale_amount, n):
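Worth keeping in mind when reading this hunk: `torch.nn.functional.pad` consumes its padding tuple from the last dimension backwards, so for an NCHW tensor a 4-tuple means (width_left, width_right, height_top, height_bottom). A tiny check:

```python
import torch
import torch.nn.functional as F

x = torch.zeros(1, 3, 10, 12)   # N, C, H=10, W=12
y = F.pad(x, (0, 4, 0, 6))      # pad width on the right by 4, height at the bottom by 6
print(y.shape)                  # torch.Size([1, 3, 16, 16])
```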
@@ -45,15 +44,9 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
     for b in range(samples.shape[0]):
         frame = samples[b : b + 1]
         _, _, h, w = frame.shape
 
         I0 = samples[b : b + 1]
         I1 = samples[b + 1 : b + 2] if b + 2 < samples.shape[0] else samples[-1:]
-        I0, padding = pad_image(I0, upscale_amount)
-        I0 = I0.to(torch.float)
-        I1, _ = pad_image(I1, upscale_amount)
-        I1 = I1.to(torch.float)
+        I1 = pad_image(I1, upscale_amount)
 
         # [c, h, w]
         I0_small = F.interpolate(I0, (32, 32), mode="bilinear", align_corners=False)
         I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
@@ -61,24 +54,14 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
         ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
 
         if ssim > 0.996:
-            I1 = samples[b : b + 1]
-            # print(f'upscale_amount:{upscale_amount}')
-            # print(f'ssim:{upscale_amount}')
-            # print(f'I0 shape:{I0.shape}')
-            # print(f'I1 shape:{I1.shape}')
-            I1, padding = pad_image(I1, upscale_amount)
-            # print(f'I0 shape:{I0.shape}')
-            # print(f'I1 shape:{I1.shape}')
+            I1 = I0
+            I1 = pad_image(I1, upscale_amount)
             I1 = make_inference(model, I0, I1, upscale_amount, 1)
 
-            # print(f'I0 shape:{I0.shape}')
-            # print(f'I1[0] shape:{I1[0].shape}')
-            I1 = I1[0]
-
-            # print(f'I1[0] unpadded shape:{I1.shape}')
-            I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
+            I1_small = F.interpolate(I1[0], (32, 32), mode="bilinear", align_corners=False)
             ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
-            frame = I1[padding[0]:, padding[2]:, padding[3]:,padding[1]:]
+            frame = I1[0]
+            I1 = I1[0]
 
         tmp_output = []
         if ssim < 0.2:
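The two SSIM thresholds in this function act as a gate: neighbouring frames that are nearly identical (ssim > 0.996) are treated as duplicates and I1 is rebuilt from I0, while a very low score (ssim < 0.2) is treated as a scene cut and handled in the branch below instead of being interpolated normally. A hedged sketch of that decision, detached from the RIFE specifics:

```python
def interpolation_mode(ssim: float) -> str:
    """Rough reading of the thresholds used above (values taken from this diff)."""
    if ssim > 0.996:
        return "duplicate"    # near-identical neighbours: regenerate I1 from I0
    if ssim < 0.2:
        return "scene_cut"    # too dissimilar: handled separately, no normal interpolation
    return "interpolate"      # normal case: synthesize 2**exp - 1 in-between frames
```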
@@ -88,13 +71,9 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
         else:
             tmp_output = make_inference(model, I0, I1, upscale_amount, 2**exp - 1) if exp else []
 
-        frame, _ = pad_image(frame, upscale_amount)
-        print(f'frame shape:{frame.shape}')
-        print(f'tmp_output[0] shape:{tmp_output[0].shape}')
+        frame = pad_image(frame, upscale_amount)
         tmp_output = [frame] + tmp_output
-        for i, frame in enumerate(tmp_output):
-            frame = F.interpolate(frame, size=(h, w))
+        for i, frame in enumerate(tmp_output):
             output.append(frame.to(output_device))
     return output
 
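As a sanity check on the output length: each source frame is emitted followed by `2**exp - 1` synthesized frames, so `exp=1` roughly doubles the frame count (boundary handling aside). A minimal sketch of that arithmetic:

```python
def expected_output_frames(num_input_frames: int, exp: int = 1) -> int:
    # each source frame is followed by 2**exp - 1 interpolated frames
    return num_input_frames * (2 ** exp)

print(expected_output_frames(49, exp=1))  # 98
```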
@@ -117,26 +96,14 @@ def frame_generator(video_capture):
 
 
 def rife_inference_with_path(model, video_path):
-    # Open the video file
     video_capture = cv2.VideoCapture(video_path)
-    fps = video_capture.get(cv2.CAP_PROP_FPS)  # Get the frames per second
-    tot_frame = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))  # Total frames in the video
+    tot_frame = video_capture.get(cv2.CAP_PROP_FRAME_COUNT)
     pt_frame_data = []
     pt_frame = skvideo.io.vreader(video_path)
-    # Cyclic reading of the video frames
-    while video_capture.isOpened():
-        ret, frame = video_capture.read()
-
-        if not ret:
-            break
-
-        # BGR to RGB
-        frame_rgb = frame[..., ::-1]
-        frame_rgb = frame_rgb.copy()
-        tensor = torch.from_numpy(frame_rgb).float().to("cpu", non_blocking=True).float() / 255.0
+    for frame in pt_frame:
         pt_frame_data.append(
-            tensor.permute(2, 0, 1)
-        )  # to [c, h, w,]
+            torch.from_numpy(np.transpose(frame, (2, 0, 1))).to("cpu", non_blocking=True).float() / 255.0
+        )
 
     pt_frame = torch.from_numpy(np.stack(pt_frame_data))
     pt_frame = pt_frame.to(device)
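The rewritten reader drops the OpenCV BGR round-trip and takes RGB frames straight from `skvideo.io.vreader`. A standalone sketch of that loading pattern (the file name is a placeholder):

```python
import numpy as np
import skvideo.io
import torch


def load_video_as_tensor(video_path: str) -> torch.Tensor:
    frames = []
    for frame in skvideo.io.vreader(video_path):    # H x W x 3 uint8, RGB
        chw = np.transpose(frame, (2, 0, 1))        # -> C x H x W
        frames.append(torch.from_numpy(chw).float() / 255.0)
    return torch.stack(frames)                      # [T, C, H, W], values in [0, 1]


# video = load_video_as_tensor("video.mp4")
```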
@@ -164,11 +131,3 @@ def rife_inference_with_latents(model, latents):
         rife_results.append(pt_image)
 
     return torch.stack(rife_results)
-
-
-if __name__ == "__main__":
-    snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
-    model = load_rife_model("model_rife")
-
-    video_path = rife_inference_with_path(model, "/mnt/ceph/develop/jiawei/CogVideo/sat/configs/outputs/1_In_the_heart_of_a_bustling_city,_a_young_woman_with_long,_flowing_brown_hair_and_a_radiant_smile_stands_out._She's_donne/0/000000.mp4")
-    print(video_path)
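The deleted `__main__` block also served as a usage note for fetching the RIFE weights. If needed elsewhere, the equivalent setup looks roughly like this (a sketch based on the removed lines; `load_rife_model` and `rife_inference_with_path` are the functions from this module and are assumed to be in scope, the input path is a placeholder):

```python
from huggingface_hub import snapshot_download

# Fetch the RIFE checkpoint referenced by this module.
snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")

model = load_rife_model("model_rife")
result = rife_inference_with_path(model, "path/to/input.mp4")
print(result)
```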
@@ -19,8 +19,9 @@ from openai import OpenAI
 import moviepy.editor as mp
 
 dtype = torch.bfloat16
+device = "cuda"  # Need to use cuda
 
-pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=dtype)
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=dtype).to(device)
 pipe.enable_model_cpu_offload()
 pipe.enable_sequential_cpu_offload()
 pipe.vae.enable_slicing()
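This block stacks several memory and speed trade-offs on one pipeline; in practice you usually pick either full GPU residency with `.to("cuda")` or one of the CPU-offload modes, not all of them at once. A hedged sketch of a low-VRAM configuration (model id as in the diff, the combination of options is a suggestion rather than the demo's actual setup):

```python
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# Low-VRAM setup: let diffusers shuttle submodules to the GPU on demand
# instead of keeping the whole pipeline resident via .to("cuda").
pipe.enable_sequential_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
```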
@@ -1,15 +1,15 @@
-diffusers>=0.30.1 #git+https://github.com/huggingface/diffusers.git@main#egg=diffusers is suggested
-transformers>=4.44.2 # The development team is working on version 4.44.2
-accelerate>=0.33.0 #git+https://github.com/huggingface/accelerate.git@main#egg=accelerate is suggested
-sentencepiece>=0.2.0 # T5 used
-SwissArmyTransformer>=0.4.12
+diffusers>=0.30.3
+accelerate>=0.34.2
+transformers>=4.44.2
 numpy==1.26.0
-torch>=2.4.0 # Tested in 2.2 2.3 2.4 and 2.5, The development team is working on version 2.4.0.
-torchvision>=0.19.0 # The development team is working on version 0.19.0.
-gradio>=4.42.0 # For HF gradio demo
-streamlit>=1.38.0 # For streamlit web demo
-imageio==2.34.2 # For diffusers inference export video
-imageio-ffmpeg==0.5.1 # For diffusers inference export video
-openai>=1.42.0 # For prompt refiner
-moviepy==1.0.3 # For export video
+torch==2.4.0
+torchvision==0.19.0
+sentencepiece==0.2.0
+SwissArmyTransformer>=0.4.12
+gradio>=4.44.0
+streamlit>=1.38.0
+imageio>=2.35.1
+imageio-ffmpeg>=0.5.1
+openai>=1.45.0
+moviepy==1.0.3
 pillow==9.5.0
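The bump to `diffusers>=0.30.3` lines up with the `num_frames=49` comment in the inference hunk above. If you want a runtime guard for that assumption, something like the following works (a sketch; `packaging` ships alongside pip and the bound mirrors this commit's pin, not an official requirement):

```python
from importlib.metadata import version
from packaging.version import Version

installed = Version(version("diffusers"))
if installed < Version("0.30.3"):
    raise RuntimeError(f"diffusers {installed} found; this code path expects >= 0.30.3")
```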
@@ -323,7 +323,6 @@ class SATVideoDiffusionEngine(nn.Module):
             if isinstance(c[k], torch.Tensor):
                 c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
-
 
         if self.noised_image_input:
             image = x[:, :, 0:1]
             image = self.add_noise_to_first_frame(image)
@@ -1,8 +1,8 @@
 #! /bin/bash
 
-echo "RUN on $(hostname), CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
+echo "RUN on $(hostname), CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
 
-run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_5b_i2v_lora.yaml configs/sft.yaml --seed $RANDOM"
+run_cmd="PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True torchrun --standalone --nproc_per_node=8 train_video.py --base configs/test_cogvideox_5b_i2v_lora.yaml configs/test_sft.yaml --seed $RANDOM"
 
 echo ${run_cmd}
 eval ${run_cmd}