mirror of https://github.com/THUDM/CogVideo.git (synced 2025-04-06 03:57:56 +08:00)
vae demo update

parent 032180bb73
commit 18c1669a8e
@@ -1,14 +1,24 @@
 """
-This script demonstrates how to encode video frames using a pre-trained CogVideoX model with 🤗 Huggingface Diffusers.
+This script is designed to demonstrate how to use the CogVideoX-2b VAE model for video encoding and decoding.
+It allows you to encode a video into a latent representation, decode it back into a video, or perform both operations sequentially.
+Before running the script, make sure to clone the CogVideoX Hugging Face model repository and set the `{your local diffusers path}` argument to the path of the cloned repository.
 
-Note:
-This script requires the `diffusers>=0.30.0` library to be installed.
-If the video appears “completely green” and cannot be viewed, please switch to a different player to watch it. This is a normal phenomenon.
-Cost 71GB of GPU memory for encoding a 6s video at 720p resolution.
+Command 1: Encoding Video
+Encodes the video located at ../resources/videos/1.mp4 using the CogVideoX-2b VAE model.
+Memory Usage: ~34GB of GPU memory for encoding.
+If you do not have enough GPU memory, we provide a pre-encoded tensor file (encoded.pt) in the resources folder and you can still run the decoding command.
+$ python cli_vae_demo.py --model_path {your local diffusers path}/CogVideoX-2b/vae/ --video_path ../resources/videos/1.mp4 --mode encode
 
-Run the script:
-$ python cli_demo.py --model_path THUDM/CogVideoX-2b --video_path path/to/video.mp4 --output_path path/to/output
+Command 2: Decoding Video
 
+Decodes the latent representation stored in encoded.pt back into a video.
+Memory Usage: ~19GB of GPU memory for decoding.
+$ python cli_vae_demo.py --model_path {your local diffusers path}/CogVideoX-2b/vae/ --encoded_path ./encoded.pt --mode decode
+
+Command 3: Encoding and Decoding Video
+Encodes the video located at ../resources/videos/1.mp4 and then immediately decodes it.
+Memory Usage: 34GB for encoding + 19GB for decoding (sequentially).
+$ python cli_vae_demo.py --model_path {your local diffusers path}/CogVideoX-2b/vae/ --video_path ../resources/videos/1.mp4 --mode both
 """
 
 import argparse
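The docstring's three commands all reduce to one encode call and one decode call on the VAE. For orientation, a minimal sketch of the encode half outside the CLI, assuming a hypothetical local clone at ./CogVideoX-2b and using a random tensor in place of real frames; the two model calls are exactly the ones this file uses:

import torch
from diffusers import AutoencoderKLCogVideoX

# Placeholder for {your local diffusers path}/CogVideoX-2b/vae/.
vae = AutoencoderKLCogVideoX.from_pretrained("./CogVideoX-2b/vae/", torch_dtype=torch.float16).to("cuda")

# Stand-in for 49 RGB frames at 480x720, laid out as [B, C, T, H, W] with values in [0, 1].
video = torch.rand(1, 3, 49, 480, 720, dtype=torch.float16, device="cuda")
with torch.no_grad():
    latents = vae.encode(video)[0].sample()  # the same call the script's encode mode makes
torch.save(latents, "./encoded.pt")          # the artifact the decode mode consumes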
@@ -19,7 +29,7 @@ from diffusers import AutoencoderKLCogVideoX
 from torchvision import transforms
 
 
-def vae_demo(model_path, video_path, dtype, device):
+def encode_video(model_path, video_path, dtype, device):
     """
     Loads a pre-trained AutoencoderKLCogVideoX model and encodes the video frames.
 
@@ -32,50 +42,58 @@ def vae_demo(model_path, video_path, dtype, device):
     Returns:
     - torch.Tensor: The encoded video frames.
     """
-    # Load the pre-trained model
     model = AutoencoderKLCogVideoX.from_pretrained(model_path, torch_dtype=dtype).to(device)
 
-    # Load video frames
     video_reader = imageio.get_reader(video_path, "ffmpeg")
-    frames = []
-    for frame in video_reader:
-        frames.append(frame)
+    frames = [transforms.ToTensor()(frame) for frame in video_reader]
     video_reader.close()
 
-    # Transform frames to Tensor
-    transform = transforms.Compose(
-        [
-            transforms.ToTensor(),
-        ]
-    )
-    frames_tensor = torch.stack([transform(frame) for frame in frames]).to(device)
-
-    # Add batch dimension and reshape to [1, 3, 49, 480, 720]
-    frames_tensor = frames_tensor.permute(1, 0, 2, 3).unsqueeze(0).to(dtype).to(device)
+    frames_tensor = torch.stack(frames).to(device).permute(1, 0, 2, 3).unsqueeze(0).to(dtype)
 
-    # Run the model with Encoder and Decoder
     with torch.no_grad():
-        output = model(frames_tensor)
+        encoded_frames = model.encode(frames_tensor)[0].sample()
+    return encoded_frames
 
-    return output
+
+def decode_video(model_path, encoded_tensor_path, dtype, device):
+    """
+    Loads a pre-trained AutoencoderKLCogVideoX model and decodes the encoded video frames.
+
+    Parameters:
+    - model_path (str): The path to the pre-trained model.
+    - encoded_tensor_path (str): The path to the encoded tensor file.
+    - dtype (torch.dtype): The data type for computation.
+    - device (str): The device to use for computation (e.g., "cuda" or "cpu").
+
+    Returns:
+    - torch.Tensor: The decoded video frames.
+    """
+    model = AutoencoderKLCogVideoX.from_pretrained(model_path, torch_dtype=dtype).to(device)
+    encoded_frames = torch.load(encoded_tensor_path, weights_only=True).to(device).to(dtype)
+    with torch.no_grad():
+        decoded_frames = []
+        for i in range(6):  # 6 seconds
+            start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3)
+            current_frames = model.decode(encoded_frames[:, :, start_frame:end_frame]).sample
+            decoded_frames.append(current_frames)
+        model.clear_fake_context_parallel_cache()
+
+        decoded_frames = torch.cat(decoded_frames, dim=2)
+    return decoded_frames
 
 
 def save_video(tensor, output_path):
     """
-    Saves the encoded video frames to a video file.
+    Saves the video frames to a video file.
 
     Parameters:
-    - tensor (torch.Tensor): The encoded video frames.
+    - tensor (torch.Tensor): The video frames tensor.
     - output_path (str): The path to save the output video.
     """
-    # Remove batch dimension and permute back to [49, 480, 720, 3]
     frames = tensor[0].squeeze(0).permute(1, 2, 3, 0).cpu().numpy()
-
-    # Clip values to [0, 1] and convert to uint8
-    frames = np.clip(frames, 0, 1)
-    frames = (frames * 255).astype(np.uint8)
+    frames = np.clip(frames, 0, 1) * 255
+    frames = frames.astype(np.uint8)
 
-    # Save frames to video
     writer = imageio.get_writer(output_path + "/output.mp4", fps=30)
    for frame in frames:
        writer.append_data(frame)
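The only non-obvious part of the new decode_video is its chunk schedule. Here is a standalone sketch of just that arithmetic; the per-window frame counts in the comments follow if the VAE's 4x temporal compression maps the first latent frame to 1 video frame and every later one to 4, which is an assumption to verify against your checkpoint:

# decode_video walks 13 latent frames (a 49-frame clip) in six windows so the
# decoder never holds the whole clip at once; the fake-context-parallel cache
# carries state across windows, hence clear_fake_context_parallel_cache() afterwards.
for i in range(6):
    start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3)
    print(f"window {i}: latents [{start_frame}:{end_frame})")
# window 0: latents [0:3)   -> 9 video frames (1 + 2 * 4)
# window 1: latents [3:5)   -> 8 video frames
# ...
# window 5: latents [11:13) -> 8 video frames; 9 + 5 * 8 = 49 frames total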
@@ -83,10 +101,14 @@ def save_video(tensor, output_path):
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Convert a CogVideoX model to Diffusers")
+    parser = argparse.ArgumentParser(description="CogVideoX encode/decode demo")
     parser.add_argument("--model_path", type=str, required=True, help="The path to the CogVideoX model")
-    parser.add_argument("--video_path", type=str, required=True, help="The path to the video file")
-    parser.add_argument("--output_path", type=str, default="./", help="The path to save the output video")
+    parser.add_argument("--video_path", type=str, help="The path to the video file (for encoding)")
+    parser.add_argument("--encoded_path", type=str, help="The path to the encoded tensor file (for decoding)")
+    parser.add_argument("--output_path", type=str, default=".", help="The path to save the output file")
+    parser.add_argument(
+        "--mode", type=str, choices=["encode", "decode", "both"], required=True, help="Mode: encode, decode, or both"
+    )
     parser.add_argument(
         "--dtype", type=str, default="float16", help="The data type for computation (e.g., 'float16' or 'float32')"
     )
@@ -95,9 +117,22 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
-    # Set device and dtype
     device = torch.device(args.device)
     dtype = torch.float16 if args.dtype == "float16" else torch.float32
 
-    output = vae_demo(args.model_path, args.video_path, dtype, device)
-    save_video(output, args.output_path)
+    if args.mode == "encode":
+        assert args.video_path, "Video path must be provided for encoding."
+        encoded_output = encode_video(args.model_path, args.video_path, dtype, device)
+        torch.save(encoded_output, args.output_path + "/encoded.pt")
+        print(f"Finished encoding the video to a tensor and saved it to {args.output_path}/encoded.pt")
+    elif args.mode == "decode":
+        assert args.encoded_path, "Encoded tensor path must be provided for decoding."
+        decoded_output = decode_video(args.model_path, args.encoded_path, dtype, device)
+        save_video(decoded_output, args.output_path)
+        print(f"Finished decoding the video and saved it to {args.output_path}/output.mp4")
+    elif args.mode == "both":
+        assert args.video_path, "Video path must be provided for encoding."
+        encoded_output = encode_video(args.model_path, args.video_path, dtype, device)
+        torch.save(encoded_output, args.output_path + "/encoded.pt")  # write the latents so decode_video can reload them
+        decoded_output = decode_video(args.model_path, args.output_path + "/encoded.pt", dtype, device)
+        save_video(decoded_output, args.output_path)
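Because the encode branch writes encoded.pt with torch.save and the decode branch reloads it with weights_only=True, the round trip can be sanity-checked without loading the model at all. The expected shape below assumes CogVideoX's VAE uses 16 latent channels with 8x spatial and 4x temporal compression; verify against your own tensor:

import torch

latents = torch.load("./encoded.pt", weights_only=True)
print(latents.shape)  # a 49-frame 480x720 clip should yield roughly [1, 16, 13, 60, 90]
print(latents.dtype)  # torch.float16 unless --dtype float32 was passed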