From 300fc75c4992b00b03369a1d71974cdd69d26d84 Mon Sep 17 00:00:00 2001
From: zR <2448370773@qq.com>
Date: Sat, 14 Sep 2024 15:35:29 +0800
Subject: [PATCH] diffusers converter update

---
 inference/cli_demo.py          | 150 ++++++++++++++++++++++-----------
 tools/convert_weight_sat2hf.py |  43 ++++++----
 2 files changed, 127 insertions(+), 66 deletions(-)

diff --git a/inference/cli_demo.py b/inference/cli_demo.py
index 7d56f17..1910a78 100644
--- a/inference/cli_demo.py
+++ b/inference/cli_demo.py
@@ -1,29 +1,46 @@
 """
-This script demonstrates how to generate a video from a text prompt using CogVideoX with 🤗Huggingface Diffusers Pipeline.
+This script demonstrates how to generate a video using the CogVideoX model with the Hugging Face `diffusers` pipeline.
+The script supports different types of video generation, including text-to-video (t2v), image-to-video (i2v),
+and video-to-video (v2v), depending on the input data and the model weights used.
 
-Note:
-    This script requires the `diffusers>=0.30.0` library to be installed, after `diffusers 0.31.0` release,
-    need to update.
+- text-to-video: THUDM/CogVideoX-5b or THUDM/CogVideoX-2b
+- video-to-video: THUDM/CogVideoX-5b or THUDM/CogVideoX-2b
+- image-to-video: THUDM/CogVideoX-5b-I2V
 
-Run the script:
-    $ python cli_demo.py --prompt "A girl ridding a bike." --model_path THUDM/CogVideoX-2b
+Running the Script:
+To run the script, use the following command with appropriate arguments:
+
+```bash
+$ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX-5b --generate_type "t2v"
+```
+
+Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths.
 """
 import argparse
+from typing import Literal
+
 import torch
-from diffusers import CogVideoXPipeline, CogVideoXDDIMScheduler, CogVideoXDPMScheduler
-from diffusers.utils import export_to_video
+from diffusers import (CogVideoXPipeline,
+                       CogVideoXDDIMScheduler,
+                       CogVideoXDPMScheduler,
+                       CogVideoXImageToVideoPipeline,
+                       CogVideoXVideoToVideoPipeline)
+
+from diffusers.utils import export_to_video, load_image, load_video
 
 
 def generate_video(
-    prompt: str,
-    model_path: str,
-    output_path: str = "./output.mp4",
-    num_inference_steps: int = 50,
-    guidance_scale: float = 6.0,
-    num_videos_per_prompt: int = 1,
-    dtype: torch.dtype = torch.bfloat16,
+        prompt: str,
+        model_path: str,
+        output_path: str = "./output.mp4",
+        image_or_video_path: str = "",
+        num_inference_steps: int = 50,
+        guidance_scale: float = 6.0,
+        num_videos_per_prompt: int = 1,
+        dtype: torch.dtype = torch.bfloat16,
+        generate_type: Literal["t2v", "i2v", "v2v"] = "t2v",  # i2v: image to video, v2v: video to video
+        seed: int = 42,
 ):
     """
     Generates a video based on the given prompt and saves it to the specified path.
@@ -36,14 +53,25 @@ def generate_video(
     - guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
     - num_videos_per_prompt (int): Number of videos to generate per prompt.
     - dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
-
+    - generate_type (str): The type of video generation (e.g., 't2v', 'i2v', 'v2v').
+    - seed (int): The seed for reproducibility.
     """
 
     # 1. Load the pre-trained CogVideoX pipeline with the specified precision (bfloat16).
     # add device_map="balanced" in the from_pretrained function and remove the enable_model_cpu_offload()
     # function to use Multi GPUs.
-    pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype)
+    image = None
+    video = None
+
+    if generate_type == "i2v":
+        pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype)
+        image = load_image(image=image_or_video_path)
+    elif generate_type == "t2v":
+        pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype)
+    else:
+        pipe = CogVideoXVideoToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype)
+        video = load_video(image_or_video_path)
 
     # 2. Set Scheduler.
     # Can be changed to `CogVideoXDPMScheduler` or `CogVideoXDDIMScheduler`.
@@ -51,63 +79,83 @@ def generate_video(
     # pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
     pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
 
-    # 3. Enable CPU offload for the model, enable tiling.
+    # 3. Enable CPU offload for the model.
     # turn off if you have multiple GPUs or enough GPU memory(such as H100) and it will cost less time in inference
-    pipe.enable_sequential_cpu_offload()
+    # and use pipe.to("cuda") instead.
+
+    # pipe.enable_sequential_cpu_offload()
+    pipe.to("cuda")
 
     pipe.vae.enable_slicing()
     pipe.vae.enable_tiling()
 
     # 4. Generate the video frames based on the prompt.
     # `num_frames` is the Number of frames to generate.
-    # This is the default value for 6 seconds video and 8 fps,so 48 frames and will plus 1 frame for the first frame.
-    # for diffusers `0.30.1` and after version, this should be 49.
-
-    video = pipe(
-        prompt=prompt,
-        num_videos_per_prompt=num_videos_per_prompt,  # Number of videos to generate per prompt
-        num_inference_steps=num_inference_steps,  # Number of inference steps
-        num_frames=49,  # Number of frames to generate,changed to 49 for diffusers version `0.31.0` and after.
-        use_dynamic_cfg=True,  ## This id used for DPM Sechduler, for DDIM scheduler, it should be False
-        guidance_scale=guidance_scale,  # Guidance scale for classifier-free guidance, can set to 7 for DPM scheduler
-        generator=torch.Generator().manual_seed(42),  # Set the seed for reproducibility
-    ).frames[0]
-
+    # The default is a 6-second video at 8 fps: 48 frames plus 1 initial frame, i.e. 49 frames in total.
+    if generate_type == "i2v":
+        video_generate = pipe(
+            prompt=prompt,
+            image=image,  # The input image used to condition the video generation
+            num_videos_per_prompt=num_videos_per_prompt,  # Number of videos to generate per prompt
+            num_inference_steps=num_inference_steps,  # Number of inference steps
+            num_frames=49,  # Number of frames to generate, changed to 49 for diffusers version `0.31.0` and after.
+            use_dynamic_cfg=True,  # This is used for the DPM scheduler; for the DDIM scheduler, it should be False
+            guidance_scale=guidance_scale,
+            generator=torch.Generator().manual_seed(seed),  # Set the seed for reproducibility
+        ).frames[0]
+    elif generate_type == "t2v":
+        video_generate = pipe(
+            prompt=prompt,
+            num_videos_per_prompt=num_videos_per_prompt,
+            num_inference_steps=num_inference_steps,
+            num_frames=49,
+            use_dynamic_cfg=True,
+            guidance_scale=guidance_scale,
+            generator=torch.Generator().manual_seed(seed),
+        ).frames[0]
+    else:
+        video_generate = pipe(
+            prompt=prompt,
+            video=video,  # The input video for video-to-video generation
+            num_videos_per_prompt=num_videos_per_prompt,
+            num_inference_steps=num_inference_steps,
+            num_frames=49,
+            use_dynamic_cfg=True,
+            guidance_scale=guidance_scale,
+            generator=torch.Generator().manual_seed(seed),  # Set the seed for reproducibility
+        ).frames[0]
     # 5. Export the generated frames to a video file. fps must be 8 for original video.
-    export_to_video(video, output_path, fps=8)
+    export_to_video(video_generate, output_path, fps=8)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
     parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
-    parser.add_argument(
-        "--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used"
-    )
-    parser.add_argument(
-        "--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
-    )
-    parser.add_argument(
-        "--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
-    )
+    parser.add_argument("--image_or_video_path", type=str, default=None,
+                        help="The path of the image (for i2v) or video (for v2v) used to condition the generation")
+    parser.add_argument("--model_path", type=str, default="THUDM/CogVideoX-5b",
+                        help="The path of the pre-trained model to be used")
+    parser.add_argument("--output_path", type=str, default="./output.mp4",
+                        help="The path where the generated video will be saved")
     parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
+    parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of steps for the inference process")
     parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
-    parser.add_argument(
-        "--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')"
-    )
+    parser.add_argument("--generate_type", type=str, default="t2v",
+                        help="The type of video generation (e.g., 't2v', 'i2v', 'v2v')")
+    parser.add_argument("--dtype", type=str, default="bfloat16",
+                        help="The data type for computation (e.g., 'float16' or 'bfloat16')")
+    parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
 
     args = parser.parse_args()
-
-    # Convert dtype argument to torch.dtype.
-    # For CogVideoX-2B model, use torch.float16.
-    # For CogVideoX-5B model, use torch.bfloat16.
     dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16
-
-    # main function to generate video.
     generate_video(
         prompt=args.prompt,
         model_path=args.model_path,
+        image_or_video_path=args.image_or_video_path,
         output_path=args.output_path,
         num_inference_steps=args.num_inference_steps,
         guidance_scale=args.guidance_scale,
         num_videos_per_prompt=args.num_videos_per_prompt,
         dtype=dtype,
+        generate_type=args.generate_type,
+        seed=args.seed,
     )
diff --git a/tools/convert_weight_sat2hf.py b/tools/convert_weight_sat2hf.py
index 6ef8e7f..183be62 100644
--- a/tools/convert_weight_sat2hf.py
+++ b/tools/convert_weight_sat2hf.py
@@ -1,11 +1,7 @@
 """
-This script demonstrates how to convert and generate video from a text prompt using CogVideoX with 🤗Huggingface Diffusers Pipeline.
-
-Note:
-    This script requires the `diffusers>=0.30.1` library to be installed.
-
-Run the script:
-    $ python convert_and_generate.py --transformer_ckpt_path --vae_ckpt_path --output_path --text_encoder_path
+This script demonstrates how to convert CogVideoX checkpoints from the SAT format
+to the 🤗 Hugging Face Diffusers format.
+This script requires the `diffusers>=0.30.2` library to be installed.
 
 Functions:
     - reassign_query_key_value_inplace: Reassigns the query, key, and value weights in-place.
@@ -27,7 +23,13 @@ from typing import Any, Dict
 
 import torch
 from transformers import T5EncoderModel, T5Tokenizer
-from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
+from diffusers import (
+    AutoencoderKLCogVideoX,
+    CogVideoXDDIMScheduler,
+    CogVideoXImageToVideoPipeline,
+    CogVideoXPipeline,
+    CogVideoXTransformer3DModel,
+)
 
 
 def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
@@ -101,6 +103,7 @@ TRANSFORMER_KEYS_RENAME_DICT = {
     "mixins.final_layer.norm_final": "norm_out.norm",
     "mixins.final_layer.linear": "proj_out",
     "mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
+    "mixins.pos_embed.pos_embedding": "patch_embed.pos_embedding",  # Specific to CogVideoX-5b-I2V
 }
 
 TRANSFORMER_SPECIAL_KEYS_REMAP = {
@@ -154,15 +157,18 @@ def convert_transformer(
     num_layers: int,
     num_attention_heads: int,
     use_rotary_positional_embeddings: bool,
+    i2v: bool,
     dtype: torch.dtype,
 ):
     PREFIX_KEY = "model.diffusion_model."
     original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
     transformer = CogVideoXTransformer3DModel(
+        in_channels=32 if i2v else 16,
         num_layers=num_layers,
         num_attention_heads=num_attention_heads,
         use_rotary_positional_embeddings=use_rotary_positional_embeddings,
+        use_learned_positional_embeddings=i2v,
     ).to(dtype=dtype)
 
     for key in list(original_state_dict.keys()):
@@ -176,7 +182,6 @@ def convert_transformer(
         if special_key not in key:
             continue
         handler_fn_inplace(key, original_state_dict)
-
     transformer.load_state_dict(original_state_dict, strict=True)
     return transformer
 
@@ -204,8 +209,7 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
-    )
+        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint")
     parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
     parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
     parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
@@ -228,6 +232,7 @@ def get_args():
     parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
     # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
    parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
+    parser.add_argument("--i2v", action="store_true", default=False, help="Whether the checkpoint is for the image-to-video (I2V) model")
     return parser.parse_args()
 
 
@@ -248,6 +253,7 @@ if __name__ == "__main__":
         args.num_layers,
         args.num_attention_heads,
         args.use_rotary_positional_embeddings,
+        args.i2v,
         dtype,
     )
     if args.vae_ckpt_path is not None:
@@ -256,8 +262,7 @@ if __name__ == "__main__":
     text_encoder_id = "google/t5-v1_1-xxl"
     tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
     text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
-
-    # Apparently, the conversion does not work any more without this :shrug:
+    # Apparently, the conversion does not work anymore without this :shrug:
     for param in text_encoder.parameters():
         param.data = param.data.contiguous()
 
@@ -275,9 +280,17 @@ if __name__ == "__main__":
             "timestep_spacing": "trailing",
         }
     )
+    if args.i2v:
+        pipeline_cls = CogVideoXImageToVideoPipeline
+    else:
+        pipeline_cls = CogVideoXPipeline
 
-    pipe = CogVideoXPipeline(
-        tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+    pipe = pipeline_cls(
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        vae=vae,
+        transformer=transformer,
+        scheduler=scheduler,
     )
 
     if args.fp16:
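
A rough sketch of how the updated converter might be invoked for the I2V weights, using only the flags defined in `tools/convert_weight_sat2hf.py`; the checkpoint paths and output directory below are placeholders, not values taken from this patch:

```bash
# Sketch only: adjust the placeholder paths to your local SAT checkpoint layout.
python tools/convert_weight_sat2hf.py \
    --transformer_ckpt_path /path/to/sat/transformer_ckpt.pt \
    --vae_ckpt_path /path/to/sat/vae_ckpt.pt \
    --output_path ./CogVideoX-5b-I2V-diffusers \
    --i2v
```

Converting a text-to-video checkpoint would be the same command without `--i2v`, since the flag defaults to False.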