diff --git a/inference/cli_vae_demo.py b/inference/cli_vae_demo.py
index b133f20..18b9a95 100644
--- a/inference/cli_vae_demo.py
+++ b/inference/cli_vae_demo.py
@@ -1,14 +1,24 @@
 """
-This script demonstrates how to encode video frames using a pre-trained CogVideoX model with 🤗 Huggingface Diffusers.
+This script demonstrates how to use the CogVideoX-2b VAE model for video encoding and decoding.
+It allows you to encode a video into a latent representation, decode it back into a video, or perform both operations sequentially.
+Before running the script, clone the CogVideoX Hugging Face model repository and replace `{your local diffusers path}` in the commands below with the path to the cloned repository.
 
-Note:
-    This script requires the `diffusers>=0.30.0` library to be installed.
-    If the video appears “completely green” and cannot be viewed, please switch to a different player to watch it. This is a normal phenomenon.
-    Cost 71GB of GPU memory for encoding a 6s video at 720p resolution.
+Command 1: Encoding Video
+Encodes the video located at ../resources/videos/1.mp4 using the CogVideoX-2b VAE model.
+Memory Usage: ~34GB of GPU memory for encoding.
+If you do not have enough GPU memory, a pre-encoded tensor file (encoded.pt) is provided in the resources folder, so you can still run the decoding command.
+$ python cli_vae_demo.py --model_path {your local diffusers path}/CogVideoX-2b/vae/ --video_path ../resources/videos/1.mp4 --mode encode
 
-Run the script:
-    $ python cli_demo.py --model_path THUDM/CogVideoX-2b --video_path path/to/video.mp4 --output_path path/to/output
+Command 2: Decoding Video
+Decodes the latent representation stored in encoded.pt back into a video.
+Memory Usage: ~19GB of GPU memory for decoding.
+$ python cli_vae_demo.py --model_path {your local diffusers path}/CogVideoX-2b/vae/ --encoded_path ./encoded.pt --mode decode
+
+Command 3: Encoding and Decoding Video
+Encodes the video located at ../resources/videos/1.mp4 and then immediately decodes it.
+Memory Usage: ~34GB for encoding, then ~19GB for decoding (the two stages run sequentially).
+$ python cli_vae_demo.py --model_path {your local diffusers path}/CogVideoX-2b/vae/ --video_path ../resources/videos/1.mp4 --mode both
 """
 
 import argparse
@@ -19,7 +29,7 @@
 from diffusers import AutoencoderKLCogVideoX
 from torchvision import transforms
 
 
-def vae_demo(model_path, video_path, dtype, device):
+def encode_video(model_path, video_path, dtype, device):
     """
     Loads a pre-trained AutoencoderKLCogVideoX model and encodes the video frames.
@@ -32,50 +42,58 @@
 
     Returns:
     - torch.Tensor: The encoded video frames.
""" - # Load the pre-trained model model = AutoencoderKLCogVideoX.from_pretrained(model_path, torch_dtype=dtype).to(device) - - # Load video frames video_reader = imageio.get_reader(video_path, "ffmpeg") - frames = [] - for frame in video_reader: - frames.append(frame) + + frames = [transforms.ToTensor()(frame) for frame in video_reader] video_reader.close() - # Transform frames to Tensor - transform = transforms.Compose( - [ - transforms.ToTensor(), - ] - ) - frames_tensor = torch.stack([transform(frame) for frame in frames]).to(device) + frames_tensor = torch.stack(frames).to(device).permute(1, 0, 2, 3).unsqueeze(0).to(dtype) - # Add batch dimension and reshape to [1, 3, 49, 480, 720] - frames_tensor = frames_tensor.permute(1, 0, 2, 3).unsqueeze(0).to(dtype).to(device) - - # Run the model with Encoder and Decoder with torch.no_grad(): - output = model(frames_tensor) + encoded_frames = model.encode(frames_tensor)[0].sample() + return encoded_frames - return output + +def decode_video(model_path, encoded_tensor_path, dtype, device): + """ + Loads a pre-trained AutoencoderKLCogVideoX model and decodes the encoded video frames. + + Parameters: + - model_path (str): The path to the pre-trained model. + - encoded_tensor_path (str): The path to the encoded tensor file. + - dtype (torch.dtype): The data type for computation. + - device (str): The device to use for computation (e.g., "cuda" or "cpu"). + + Returns: + - torch.Tensor: The decoded video frames. + """ + model = AutoencoderKLCogVideoX.from_pretrained(model_path, torch_dtype=dtype).to(device) + encoded_frames = torch.load(encoded_tensor_path, weights_only=True).to(device).to(dtype) + with torch.no_grad(): + decoded_frames = [] + for i in range(6): # 6 seconds + start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3) + current_frames = model.decode(encoded_frames[:, :, start_frame:end_frame]).sample + decoded_frames.append(current_frames) + model.clear_fake_context_parallel_cache() + + decoded_frames = torch.cat(decoded_frames, dim=2) + return decoded_frames def save_video(tensor, output_path): """ - Saves the encoded video frames to a video file. + Saves the video frames to a video file. Parameters: - - tensor (torch.Tensor): The encoded video frames. + - tensor (torch.Tensor): The video frames tensor. - output_path (str): The path to save the output video. 
""" - # Remove batch dimension and permute back to [49, 480, 720, 3] frames = tensor[0].squeeze(0).permute(1, 2, 3, 0).cpu().numpy() + frames = np.clip(frames, 0, 1) * 255 + frames = frames.astype(np.uint8) - # Clip values to [0, 1] and convert to uint8 - frames = np.clip(frames, 0, 1) - frames = (frames * 255).astype(np.uint8) - - # Save frames to video writer = imageio.get_writer(output_path + "/output.mp4", fps=30) for frame in frames: writer.append_data(frame) @@ -83,10 +101,14 @@ def save_video(tensor, output_path): if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Convert a CogVideoX model to Diffusers") + parser = argparse.ArgumentParser(description="CogVideoX encode/decode demo") parser.add_argument("--model_path", type=str, required=True, help="The path to the CogVideoX model") - parser.add_argument("--video_path", type=str, required=True, help="The path to the video file") - parser.add_argument("--output_path", type=str, default="./", help="The path to save the output video") + parser.add_argument("--video_path", type=str, help="The path to the video file (for encoding)") + parser.add_argument("--encoded_path", type=str, help="The path to the encoded tensor file (for decoding)") + parser.add_argument("--output_path", type=str, default=".", help="The path to save the output file") + parser.add_argument( + "--mode", type=str, choices=["encode", "decode", "both"], required=True, help="Mode: encode, decode, or both" + ) parser.add_argument( "--dtype", type=str, default="float16", help="The data type for computation (e.g., 'float16' or 'float32')" ) @@ -95,9 +117,21 @@ if __name__ == "__main__": ) args = parser.parse_args() - # Set device and dtype device = torch.device(args.device) dtype = torch.float16 if args.dtype == "float16" else torch.float32 - output = vae_demo(args.model_path, args.video_path, dtype, device) - save_video(output, args.output_path) + if args.mode == "encode": + assert args.video_path, "Video path must be provided for encoding." + encoded_output = encode_video(args.model_path, args.video_path, dtype, device) + torch.save(encoded_output, args.output_path + "/encoded.pt") + print(f"Finished encoding the video to a tensor, save it to a file at {encoded_output}/encoded.pt") + elif args.mode == "decode": + assert args.encoded_path, "Encoded tensor path must be provided for decoding." + decoded_output = decode_video(args.model_path, args.encoded_path, dtype, device) + save_video(decoded_output, args.output_path) + print(f"Finished decoding the video and saved it to a file at {args.output_path}/output.mp4") + elif args.mode == "both": + assert args.video_path, "Video path must be provided for encoding." + encoded_output = encode_video(args.model_path, args.video_path, dtype, device) + decoded_output = decode_video(args.model_path, args.output_path + "/encoded.pt", dtype, device) + save_video(decoded_output, args.output_path)