Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-04-05 19:41:59 +08:00)

Commit 18c1669a8e (parent 032180bb73): vae demo update
@@ -1,14 +1,24 @@
 """
-This script demonstrates how to encode video frames using a pre-trained CogVideoX model with 🤗 Huggingface Diffusers.
+This script is designed to demonstrate how to use the CogVideoX-2b VAE model for video encoding and decoding.
+It allows you to encode a video into a latent representation, decode it back into a video, or perform both operations sequentially.
+Before running the script, make sure to clone the CogVideoX Hugging Face model repository and set the `{your local diffusers path}` argument to the path of the cloned repository.
 
-Note:
-    This script requires the `diffusers>=0.30.0` library to be installed.
-    If the video appears “completely green” and cannot be viewed, please switch to a different player to watch it. This is a normal phenomenon.
-    Cost 71GB of GPU memory for encoding a 6s video at 720p resolution.
+Command 1: Encoding Video
+Encodes the video located at ../resources/videos/1.mp4 using the CogVideoX-2b VAE model.
+Memory Usage: ~34GB of GPU memory for encoding.
+If you do not have enough GPU memory, we provide a pre-encoded tensor file (encoded.pt) in the resources folder and you can still run the decoding command.
+$ python cli_vae_demo.py --model_path {your local diffusers path}/CogVideoX-2b/vae/ --video_path ../resources/videos/1.mp4 --mode encode
 
-Run the script:
-    $ python cli_demo.py --model_path THUDM/CogVideoX-2b --video_path path/to/video.mp4 --output_path path/to/output
+Command 2: Decoding Video
+
+Decodes the latent representation stored in encoded.pt back into a video.
+Memory Usage: ~19GB of GPU memory for decoding.
+$ python cli_vae_demo.py --model_path {your local diffusers path}/CogVideoX-2b/vae/ --encoded_path ./encoded.pt --mode decode
 
+Command 3: Encoding and Decoding Video
+Encodes the video located at ../resources/videos/1.mp4 and then immediately decodes it.
+Memory Usage: 34GB for encoding + 19GB for decoding (sequentially).
+$ python cli_vae_demo.py --model_path {your local diffusers path}/CogVideoX-2b/vae/ --video_path ../resources/videos/1.mp4 --mode both
 """
 
 import argparse
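A quick way to sanity-check the encode step is to load encoded.pt and print its shape. This is a hedged sketch rather than part of the commit; the exact latent shape depends on the input clip, but for the 49-frame 480x720 example used here, the CogVideoX VAE's 8x spatial and 4x temporal compression with 16 latent channels should give something on the order of [1, 16, 13, 60, 90]:

# Hypothetical sanity check for the tensor written by `--mode encode`.
# Assumes encoded.pt is in the current directory; the shape in the comment is
# an estimate, not something the script itself guarantees.
import torch

latents = torch.load("encoded.pt", weights_only=True)
print(latents.shape, latents.dtype)  # e.g. roughly torch.Size([1, 16, 13, 60, 90]) torch.float16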
@@ -19,7 +29,7 @@ from diffusers import AutoencoderKLCogVideoX
 from torchvision import transforms
 
 
-def vae_demo(model_path, video_path, dtype, device):
+def encode_video(model_path, video_path, dtype, device):
     """
     Loads a pre-trained AutoencoderKLCogVideoX model and encodes the video frames.
 
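The docstring points --model_path at a local clone's CogVideoX-2b/vae/ folder. If you would rather pull only the VAE weights from the Hugging Face Hub, a from_pretrained call with subfolder="vae" should construct the same AutoencoderKLCogVideoX object; this is a sketch under that assumption, not something the script itself does:

# Hypothetical alternative to cloning the full model repository locally:
# load only the VAE sub-module of CogVideoX-2b directly from the Hub.
import torch
from diffusers import AutoencoderKLCogVideoX

vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
).to("cuda")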
@@ -32,50 +42,58 @@ def vae_demo(model_path, video_path, dtype, device):
     Returns:
     - torch.Tensor: The encoded video frames.
     """
-    # Load the pre-trained model
     model = AutoencoderKLCogVideoX.from_pretrained(model_path, torch_dtype=dtype).to(device)
 
-    # Load video frames
     video_reader = imageio.get_reader(video_path, "ffmpeg")
-    frames = []
-    for frame in video_reader:
-        frames.append(frame)
+
+    frames = [transforms.ToTensor()(frame) for frame in video_reader]
     video_reader.close()
 
-    # Transform frames to Tensor
-    transform = transforms.Compose(
-        [
-            transforms.ToTensor(),
-        ]
-    )
-    frames_tensor = torch.stack([transform(frame) for frame in frames]).to(device)
-
-    # Add batch dimension and reshape to [1, 3, 49, 480, 720]
-    frames_tensor = frames_tensor.permute(1, 0, 2, 3).unsqueeze(0).to(dtype).to(device)
+    frames_tensor = torch.stack(frames).to(device).permute(1, 0, 2, 3).unsqueeze(0).to(dtype)
 
-    # Run the model with Encoder and Decoder
     with torch.no_grad():
-        output = model(frames_tensor)
+        encoded_frames = model.encode(frames_tensor)[0].sample()
+    return encoded_frames
 
-    return output
+
+def decode_video(model_path, encoded_tensor_path, dtype, device):
+    """
+    Loads a pre-trained AutoencoderKLCogVideoX model and decodes the encoded video frames.
+
+    Parameters:
+    - model_path (str): The path to the pre-trained model.
+    - encoded_tensor_path (str): The path to the encoded tensor file.
+    - dtype (torch.dtype): The data type for computation.
+    - device (str): The device to use for computation (e.g., "cuda" or "cpu").
+
+    Returns:
+    - torch.Tensor: The decoded video frames.
+    """
+    model = AutoencoderKLCogVideoX.from_pretrained(model_path, torch_dtype=dtype).to(device)
+    encoded_frames = torch.load(encoded_tensor_path, weights_only=True).to(device).to(dtype)
+    with torch.no_grad():
+        decoded_frames = []
+        for i in range(6):  # 6 seconds
+            start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3)
+            current_frames = model.decode(encoded_frames[:, :, start_frame:end_frame]).sample
+            decoded_frames.append(current_frames)
+        model.clear_fake_context_parallel_cache()
+
+        decoded_frames = torch.cat(decoded_frames, dim=2)
+    return decoded_frames
 
 
 def save_video(tensor, output_path):
     """
-    Saves the encoded video frames to a video file.
+    Saves the video frames to a video file.
 
     Parameters:
-    - tensor (torch.Tensor): The encoded video frames.
+    - tensor (torch.Tensor): The video frames tensor.
     - output_path (str): The path to save the output video.
     """
-    # Remove batch dimension and permute back to [49, 480, 720, 3]
     frames = tensor[0].squeeze(0).permute(1, 2, 3, 0).cpu().numpy()
-
-    # Clip values to [0, 1] and convert to uint8
-    frames = np.clip(frames, 0, 1)
-    frames = (frames * 255).astype(np.uint8)
-
-    # Save frames to video
+    frames = np.clip(frames, 0, 1) * 255
+    frames = frames.astype(np.uint8)
     writer = imageio.get_writer(output_path + "/output.mp4", fps=30)
     for frame in frames:
         writer.append_data(frame)
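The new decode_video walks through the latent video a few frames at a time instead of decoding everything in one pass: the first chunk covers latent frames [0, 3) and every later chunk the next two, ending at frame 13 for the 6-second example. A minimal stand-alone sketch (plain arithmetic, no model required) makes the index pattern explicit:

# Reproduce the (start_frame, end_frame) pattern used by decode_video above.
chunks = []
for i in range(6):  # 6 seconds
    start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3)
    chunks.append((start_frame, end_frame))

print(chunks)  # [(0, 3), (3, 5), (5, 7), (7, 9), (9, 11), (11, 13)]
assert all(a[1] == b[0] for a, b in zip(chunks, chunks[1:]))  # chunks are contiguous

Keeping the chunks contiguous presumably matters because the decoder carries state between calls in its fake context parallel cache, which is why model.clear_fake_context_parallel_cache() is only called once the loop has finished.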
@@ -83,10 +101,14 @@ def save_video(tensor, output_path):
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Convert a CogVideoX model to Diffusers")
+    parser = argparse.ArgumentParser(description="CogVideoX encode/decode demo")
     parser.add_argument("--model_path", type=str, required=True, help="The path to the CogVideoX model")
-    parser.add_argument("--video_path", type=str, required=True, help="The path to the video file")
-    parser.add_argument("--output_path", type=str, default="./", help="The path to save the output video")
+    parser.add_argument("--video_path", type=str, help="The path to the video file (for encoding)")
+    parser.add_argument("--encoded_path", type=str, help="The path to the encoded tensor file (for decoding)")
+    parser.add_argument("--output_path", type=str, default=".", help="The path to save the output file")
+    parser.add_argument(
+        "--mode", type=str, choices=["encode", "decode", "both"], required=True, help="Mode: encode, decode, or both"
+    )
     parser.add_argument(
         "--dtype", type=str, default="float16", help="The data type for computation (e.g., 'float16' or 'float32')"
     )
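Note that --video_path and --encoded_path are now optional at the parser level and only enforced per mode by the assertions further down. A --device flag is not visible in this diff but is read as args.device below, so a fully spelled-out invocation would presumably look like the following, where --dtype float16 is just the default made explicit and --device cuda is an assumed value for the unshown flag:

$ python cli_vae_demo.py --model_path {your local diffusers path}/CogVideoX-2b/vae/ --video_path ../resources/videos/1.mp4 --mode both --dtype float16 --device cuda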
@@ -95,9 +117,22 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
-    # Set device and dtype
     device = torch.device(args.device)
     dtype = torch.float16 if args.dtype == "float16" else torch.float32
 
-    output = vae_demo(args.model_path, args.video_path, dtype, device)
-    save_video(output, args.output_path)
+    if args.mode == "encode":
+        assert args.video_path, "Video path must be provided for encoding."
+        encoded_output = encode_video(args.model_path, args.video_path, dtype, device)
+        torch.save(encoded_output, args.output_path + "/encoded.pt")
+        print(f"Finished encoding the video to a tensor and saved it to a file at {args.output_path}/encoded.pt")
+    elif args.mode == "decode":
+        assert args.encoded_path, "Encoded tensor path must be provided for decoding."
+        decoded_output = decode_video(args.model_path, args.encoded_path, dtype, device)
+        save_video(decoded_output, args.output_path)
+        print(f"Finished decoding the video and saved it to a file at {args.output_path}/output.mp4")
+    elif args.mode == "both":
+        assert args.video_path, "Video path must be provided for encoding."
+        encoded_output = encode_video(args.model_path, args.video_path, dtype, device)
+        torch.save(encoded_output, args.output_path + "/encoded.pt")
+        decoded_output = decode_video(args.model_path, args.output_path + "/encoded.pt", dtype, device)
+        save_video(decoded_output, args.output_path)
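Because the update splits the demo into plain helper functions, the same flow can also be driven from Python instead of the command line. A minimal sketch, assuming the script is importable as cli_vae_demo and the placeholder model path is replaced with a real local clone; it simply mirrors what --mode both does:

# Hypothetical programmatic use of the refactored helpers (mirrors `--mode both`).
import torch
from cli_vae_demo import decode_video, encode_video, save_video

model_path = "{your local diffusers path}/CogVideoX-2b/vae/"  # placeholder from the docstring
device = torch.device("cuda")
dtype = torch.float16

latents = encode_video(model_path, "../resources/videos/1.mp4", dtype, device)
torch.save(latents, "./encoded.pt")

frames = decode_video(model_path, "./encoded.pt", dtype, device)
save_video(frames, ".")  # writes ./output.mp4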