diff --git a/inference/cli_demo.py b/inference/cli_demo.py index 7d56f17..1910a78 100644 --- a/inference/cli_demo.py +++ b/inference/cli_demo.py @@ -1,29 +1,46 @@ """ -This script demonstrates how to generate a video from a text prompt using CogVideoX with 🤗Huggingface Diffusers Pipeline. +This script demonstrates how to generate a video using the CogVideoX model with the Hugging Face `diffusers` pipeline. +The script supports different types of video generation, including text-to-video (t2v), image-to-video (i2v), +and video-to-video (v2v), depending on the input data and different weight. -Note: - This script requires the `diffusers>=0.30.0` library to be installed, after `diffusers 0.31.0` release, - need to update. +- text-to-video: THUDM/CogVideoX-5b or THUDM/CogVideoX-2b +- video-to-video: THUDM/CogVideoX-5b or THUDM/CogVideoX-2b +- image-to-video: THUDM/CogVideoX-5b-I2V -Run the script: - $ python cli_demo.py --prompt "A girl ridding a bike." --model_path THUDM/CogVideoX-2b +Running the Script: +To run the script, use the following command with appropriate arguments: +```bash +$ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX-5b --generate_type "t2v" +``` + +Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths. """ import argparse +from typing import Literal + import torch -from diffusers import CogVideoXPipeline, CogVideoXDDIMScheduler, CogVideoXDPMScheduler -from diffusers.utils import export_to_video +from diffusers import (CogVideoXPipeline, + CogVideoXDDIMScheduler, + CogVideoXDPMScheduler, + CogVideoXImageToVideoPipeline, + CogVideoXVideoToVideoPipeline) + +from diffusers.utils import export_to_video, load_image, load_video def generate_video( - prompt: str, - model_path: str, - output_path: str = "./output.mp4", - num_inference_steps: int = 50, - guidance_scale: float = 6.0, - num_videos_per_prompt: int = 1, - dtype: torch.dtype = torch.bfloat16, + prompt: str, + model_path: str, + output_path: str = "./output.mp4", + image_or_video_path: str = "", + num_inference_steps: int = 50, + guidance_scale: float = 6.0, + num_videos_per_prompt: int = 1, + dtype: torch.dtype = torch.bfloat16, + generate_type: str = Literal["t2v", "i2v", "v2v"], # i2v: image to video, v2v: video to video + seed: int = 42, ): """ Generates a video based on the given prompt and saves it to the specified path. @@ -36,14 +53,25 @@ def generate_video( - guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt. - num_videos_per_prompt (int): Number of videos to generate per prompt. - dtype (torch.dtype): The data type for computation (default is torch.bfloat16). - + - generate_type (str): The type of video generation (e.g., 't2v', 'i2v', 'v2v'). + - seed (int): The seed for reproducibility. """ # 1. Load the pre-trained CogVideoX pipeline with the specified precision (bfloat16). # add device_map="balanced" in the from_pretrained function and remove the enable_model_cpu_offload() # function to use Multi GPUs. - pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype) + image = None + video = None + + if generate_type == "i2v": + pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype) + image = load_image(image=image_or_video_path) + elif generate_type == "t2v": + pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype) + else: + pipe = CogVideoXVideoToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype) + video = load_video(image_or_video_path) # 2. Set Scheduler. # Can be changed to `CogVideoXDPMScheduler` or `CogVideoXDDIMScheduler`. @@ -51,63 +79,83 @@ def generate_video( # pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") - # 3. Enable CPU offload for the model, enable tiling. + # 3. Enable CPU offload for the model. # turn off if you have multiple GPUs or enough GPU memory(such as H100) and it will cost less time in inference - pipe.enable_sequential_cpu_offload() + # and enable to("cuda") + + # pipe.enable_sequential_cpu_offload() + pipe.to("cuda") pipe.vae.enable_slicing() pipe.vae.enable_tiling() # 4. Generate the video frames based on the prompt. # `num_frames` is the Number of frames to generate. - # This is the default value for 6 seconds video and 8 fps,so 48 frames and will plus 1 frame for the first frame. - # for diffusers `0.30.1` and after version, this should be 49. - - video = pipe( - prompt=prompt, - num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt - num_inference_steps=num_inference_steps, # Number of inference steps - num_frames=49, # Number of frames to generate,changed to 49 for diffusers version `0.31.0` and after. - use_dynamic_cfg=True, ## This id used for DPM Sechduler, for DDIM scheduler, it should be False - guidance_scale=guidance_scale, # Guidance scale for classifier-free guidance, can set to 7 for DPM scheduler - generator=torch.Generator().manual_seed(42), # Set the seed for reproducibility - ).frames[0] - + # This is the default value for 6 seconds video and 8 fps,so 48 frames and will plus 1 frame for the first frame and 49 frames. + if generate_type == "i2v": + video_generate = pipe( + prompt=prompt, + image=image, # The path of the image to be used as the background of the video + num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt + num_inference_steps=num_inference_steps, # Number of inference steps + num_frames=49, # Number of frames to generate,changed to 49 for diffusers version `0.31.0` and after. + use_dynamic_cfg=True, ## This id used for DPM Sechduler, for DDIM scheduler, it should be False + guidance_scale=guidance_scale, + generator=torch.Generator().manual_seed(seed), # Set the seed for reproducibility + ).frames[0] + elif generate_type == "t2v": + video_generate = pipe( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + num_inference_steps=num_inference_steps, + num_frames=49, + use_dynamic_cfg=True, + guidance_scale=guidance_scale, + generator=torch.Generator().manual_seed(seed), + ).frames[0] + else: + video_generate = pipe( + prompt=prompt, + video=video, # The path of the video to be used as the background of the video + num_videos_per_prompt=num_videos_per_prompt, + num_inference_steps=num_inference_steps, + num_frames=49, + use_dynamic_cfg=True, + guidance_scale=guidance_scale, + generator=torch.Generator().manual_seed(seed), # Set the seed for reproducibility + ).frames[0] # 5. Export the generated frames to a video file. fps must be 8 for original video. - export_to_video(video, output_path, fps=8) + export_to_video(video_generate, output_path, fps=8) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX") parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated") - parser.add_argument( - "--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used" - ) - parser.add_argument( - "--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved" - ) - parser.add_argument( - "--num_inference_steps", type=int, default=50, help="Number of steps for the inference process" - ) + parser.add_argument("--image_or_video_path", type=str, default=None, + help="The path of the image to be used as the background of the video") + parser.add_argument("--model_path", type=str, default="THUDM/CogVideoX-5b", + help="The path of the pre-trained model to be used") + parser.add_argument("--output_path", type=str, default="./output.mp4", + help="The path where the generated video will be saved") parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance") + parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of steps for the inference process") parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt") - parser.add_argument( - "--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')" - ) + parser.add_argument("--generate_type", type=str, default="t2v", + help="The type of video generation (e.g., 't2v', 'i2v', 'v2v')") + parser.add_argument("--dtype", type=str, default="bfloat16", + help="The data type for computation (e.g., 'float16' or 'bfloat16')") + parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility") args = parser.parse_args() - - # Convert dtype argument to torch.dtype. - # For CogVideoX-2B model, use torch.float16. - # For CogVideoX-5B model, use torch.bfloat16. dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16 - - # main function to generate video. generate_video( prompt=args.prompt, model_path=args.model_path, + image_or_video_path=args.image_or_video_path, output_path=args.output_path, num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale, num_videos_per_prompt=args.num_videos_per_prompt, dtype=dtype, + generate_type=args.generate_type, + seed=args.seed, ) diff --git a/sat/configs/cogvideox_5b.yaml b/sat/configs/cogvideox_5b.yaml index 6805dec..22ba694 100644 --- a/sat/configs/cogvideox_5b.yaml +++ b/sat/configs/cogvideox_5b.yaml @@ -1,5 +1,5 @@ model: - scale_factor: 0.7 # different from cogvideox_2b_infer.yaml + scale_factor: 0.7 disable_first_stage_autocast: true log_keys: - txt diff --git a/sat/configs/cogvideox_5b_i2v.yaml b/sat/configs/cogvideox_5b_i2v.yaml new file mode 100644 index 0000000..4baf963 --- /dev/null +++ b/sat/configs/cogvideox_5b_i2v.yaml @@ -0,0 +1,159 @@ +model: + scale_factor: 0.7 + disable_first_stage_autocast: true + latent_input: false + noised_image_input: true + noised_image_dropout: 0.05 + log_keys: + - txt + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + quantize_c_noise: False + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 1.0 # different from cogvideox_2b_infer.yaml + + network_config: + target: dit_video_concat.DiffusionTransformer + params: + time_embed_dim: 512 + elementwise_affine: True + num_frames: 49 + time_compressed_rate: 4 + latent_width: 90 + latent_height: 60 + num_layers: 42 + patch_size: 2 + in_channels: 32 #different from cogvideox_5b_infer.yaml + out_channels: 16 + hidden_size: 3072 + adm_in_channels: 256 + num_attention_heads: 48 + + transformer_args: + checkpoint_activations: True + vocab_size: 1 + max_sequence_length: 64 + layernorm_order: pre + skip_init: false + model_parallel_size: 1 + is_decoder: false + + modules: + pos_embed_config: + target: dit_video_concat.Rotary3DPositionEmbeddingMixin + params: + learnable_pos_embed: True + hidden_size_head: 64 + text_length: 226 + + patch_embed_config: + target: dit_video_concat.ImagePatchEmbeddingMixin + params: + text_hidden_size: 4096 + + adaln_layer_config: + target: dit_video_concat.AdaLNMixin + params: + qk_ln: True + + final_layer_config: + target: dit_video_concat.FinalLayerMixin + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: false + input_key: txt + ucg_rate: 0.1 + target: sgm.modules.encoders.modules.FrozenT5Embedder + params: + model_dir: "t5-v1_1-xxl" + max_length: 226 + + first_stage_config: + target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper + params: + cp_size: 1 + ckpt_path: "cogvideox-5b-i2v-sat/vae/3d-vae.pt" + ignore_keys: ['loss'] + + loss_config: + target: torch.nn.Identity + + regularizer_config: + target: vae_modules.regularizers.DiagonalGaussianRegularizer + + encoder_config: + target: vae_modules.cp_enc_dec.ContextParallelEncoder3D + params: + double_z: true + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 2, 4] + attn_resolutions: [] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: True + + decoder_config: + target: vae_modules.cp_enc_dec.ContextParallelDecoder3D + params: + double_z: True + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 2, 4] + attn_resolutions: [] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: True + + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss + params: + fixed_frames: 0 + offset_noise_level: 0 + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling + params: + uniform_sampling: True + num_idx: 1000 + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 1.0 + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler + params: + fixed_frames: 0 + num_steps: 50 + verbose: True + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 1.0 + + guider_config: + target: sgm.modules.diffusionmodules.guiders.DynamicCFG + params: + scale: 6 + exp: 5 + num_steps: 50 \ No newline at end of file diff --git a/sat/configs/cogvideox_5b_i2v_lora.yaml b/sat/configs/cogvideox_5b_i2v_lora.yaml new file mode 100644 index 0000000..e36aee7 --- /dev/null +++ b/sat/configs/cogvideox_5b_i2v_lora.yaml @@ -0,0 +1,165 @@ +model: + scale_factor: 0.7 + disable_first_stage_autocast: true + latent_input: false + noised_image_input: true + noised_image_dropout: 0.05 + not_trainable_prefixes: ['all'] ## Using Lora + log_keys: + - txt + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + quantize_c_noise: False + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 1.0 # different from cogvideox_2b_infer.yaml + + network_config: + target: dit_video_concat.DiffusionTransformer + params: + time_embed_dim: 512 + elementwise_affine: True + num_frames: 49 + time_compressed_rate: 4 + latent_width: 90 + latent_height: 60 + num_layers: 42 + patch_size: 2 + in_channels: 32 + out_channels: 16 + hidden_size: 3072 + adm_in_channels: 256 + num_attention_heads: 48 + + transformer_args: + checkpoint_activations: True + vocab_size: 1 + max_sequence_length: 64 + layernorm_order: pre + skip_init: false + model_parallel_size: 1 + is_decoder: false + + modules: + pos_embed_config: + target: dit_video_concat.Rotary3DPositionEmbeddingMixin + params: + learnable_pos_embed: True + hidden_size_head: 64 + text_length: 226 + + lora_config: + target: sat.model.finetune.lora2.LoraMixin + params: + r: 256 + + patch_embed_config: + target: dit_video_concat.ImagePatchEmbeddingMixin + params: + text_hidden_size: 4096 + + adaln_layer_config: + target: dit_video_concat.AdaLNMixin + params: + qk_ln: True + + final_layer_config: + target: dit_video_concat.FinalLayerMixin + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: false + input_key: txt + ucg_rate: 0.1 + target: sgm.modules.encoders.modules.FrozenT5Embedder + params: + model_dir: "t5-v1_1-xxl" + max_length: 226 + + first_stage_config: + target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper + params: + cp_size: 1 + ckpt_path: "cogvideox-5b-i2v-sat/vae/3d-vae.pt" + ignore_keys: [ 'loss' ] + + loss_config: + target: torch.nn.Identity + + regularizer_config: + target: vae_modules.regularizers.DiagonalGaussianRegularizer + + encoder_config: + target: vae_modules.cp_enc_dec.ContextParallelEncoder3D + params: + double_z: true + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 2, 4 ] + attn_resolutions: [ ] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: True + + decoder_config: + target: vae_modules.cp_enc_dec.ContextParallelDecoder3D + params: + double_z: True + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 2, 4 ] + attn_resolutions: [ ] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: True + + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss + params: + fixed_frames: 0 + offset_noise_level: 0 + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling + params: + uniform_sampling: True + num_idx: 1000 + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 1.0 + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler + params: + fixed_frames: 0 + num_steps: 50 + verbose: True + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 1.0 + + guider_config: + target: sgm.modules.diffusionmodules.guiders.DynamicCFG + params: + scale: 6 + exp: 5 + num_steps: 50 \ No newline at end of file diff --git a/sat/configs/inference.yaml b/sat/configs/inference.yaml index a745639..a93bb99 100644 --- a/sat/configs/inference.yaml +++ b/sat/configs/inference.yaml @@ -1,15 +1,16 @@ args: + image2video: False # True for image2video, False for text2video latent_channels: 16 mode: inference load: "{your CogVideoX SAT folder}/transformer" # This is for Full model without lora adapter # load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter - batch_size: 1 input_type: txt input_file: configs/test.txt + sampling_image_size: [480, 720] sampling_num_frames: 13 # Must be 13, 11 or 9 sampling_fps: 8 - fp16: True # For CogVideoX-2B -# bf16: True # For CogVideoX-5B +# fp16: True # For CogVideoX-2B + bf16: True # For CogVideoX-5B and CoGVideoX-5B-I2V output_dir: outputs/ force_inference: True \ No newline at end of file diff --git a/sat/configs/sft.yaml b/sat/configs/sft.yaml index bbdf1a7..971c521 100644 --- a/sat/configs/sft.yaml +++ b/sat/configs/sft.yaml @@ -1,15 +1,15 @@ args: - checkpoint_activations: True ## using gradient checkpointing + checkpoint_activations: True # using gradient checkpointing model_parallel_size: 1 experiment_name: lora-disney mode: finetune - load: "cogvideox-2b-sat/transformer" + load: "{your CogVideoX SAT folder}/transformer" no_load_rng: True train_iters: 1000 # Suggest more than 1000 For Lora and SFT For 500 is enough eval_iters: 1 eval_interval: 100 eval_batch_size: 1 - save: ckpts_2b_lora + save: ckpts_5b_lora save_interval: 500 log_interval: 20 train_data: [ "disney" ] # Train data path @@ -28,7 +28,7 @@ data: skip_frms_num: 3. deepspeed: - # Minimun for 16 videos per batch for ALL GPUs, This setting is for 8 x A100 GPUs + # Minimum for 16 videos per batch for ALL GPUs, This setting is for 8 x A100 GPUs train_micro_batch_size_per_gpu: 2 gradient_accumulation_steps: 1 steps_per_print: 50 @@ -44,9 +44,9 @@ deepspeed: load_from_fp32_weights: false zero_allow_untested_optimizer: true bf16: - enabled: False # For CogVideoX-2B Turn to False and For CogVideoX-5B Turn to True + enabled: True # For CogVideoX-2B Turn to False and For CogVideoX-5B Turn to True fp16: - enabled: True # For CogVideoX-2B Turn to True and For CogVideoX-5B Turn to False + enabled: False # For CogVideoX-2B Turn to True and For CogVideoX-5B Turn to False loss_scale: 0 loss_scale_window: 400 hysteresis: 2 @@ -55,7 +55,7 @@ deepspeed: optimizer: type: sat.ops.FusedEmaAdam params: - lr: 0.001 # Between 1E-3 and 5E-4 For Lora and 1E-5 For SFT + lr: 0.00001 # Between 1E-3 and 5E-4 For Lora and 1E-5 For SFT betas: [ 0.9, 0.95 ] eps: 1e-8 weight_decay: 1e-4 diff --git a/sat/data_video.py b/sat/data_video.py index 25d17ee..b572d83 100644 --- a/sat/data_video.py +++ b/sat/data_video.py @@ -362,7 +362,7 @@ class SFTDataset(Dataset): skip_frms_num: ignore the first and the last xx frames, avoiding transitions. """ super(SFTDataset, self).__init__() - + self.video_size = video_size self.fps = fps self.max_num_frames = max_num_frames @@ -385,7 +385,6 @@ class SFTDataset(Dataset): self.captions.append(caption) def __getitem__(self, index): - decord.bridge.set_bridge("torch") video_path = self.video_paths[index] @@ -411,9 +410,7 @@ class SFTDataset(Dataset): indices = np.arange(start, end, max((end - start) // num_frames, 1)).astype(int) temp_frms = vr.get_batch(np.arange(start, end)) assert temp_frms is not None - tensor_frms = ( - torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms - ) + tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())] else: @@ -426,15 +423,11 @@ class SFTDataset(Dataset): start = int(self.skip_frms_num) end = int(ori_vlen - self.skip_frms_num) - num_frames = nearest_smaller_4k_plus_1( - end - start - ) # 3D VAE requires the number of frames to be 4k+1 + num_frames = nearest_smaller_4k_plus_1(end - start) # 3D VAE requires the number of frames to be 4k+1 end = int(start + num_frames) temp_frms = vr.get_batch(np.arange(start, end)) assert temp_frms is not None - tensor_frms = ( - torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms - ) + tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms tensor_frms = pad_last_frame( tensor_frms, self.max_num_frames diff --git a/sat/diffusion_video.py b/sat/diffusion_video.py index 6bee4ce..8329e9d 100644 --- a/sat/diffusion_video.py +++ b/sat/diffusion_video.py @@ -1,3 +1,5 @@ +import random + import math from typing import Any, Dict, List, Tuple, Union from omegaconf import ListConfig @@ -130,6 +132,13 @@ class SATVideoDiffusionEngine(nn.Module): loss_dict = {"loss": loss_mean} return loss_mean, loss_dict + def add_noise_to_first_frame(self, image): + sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(self.device) + sigma = torch.exp(sigma).to(image.dtype) + image_noise = torch.randn_like(image) * sigma[:, None, None, None, None] + image = image + image_noise + return image + def shared_step(self, batch: Dict) -> Any: x = self.get_input(batch) if self.lr_scale is not None: @@ -139,8 +148,22 @@ class SATVideoDiffusionEngine(nn.Module): batch["lr_input"] = lr_z x = x.permute(0, 2, 1, 3, 4).contiguous() + if self.noised_image_input: + image = x[:, :, 0:1] + image = self.add_noise_to_first_frame(image) + image = self.encode_first_stage(image, batch) + x = self.encode_first_stage(x, batch) x = x.permute(0, 2, 1, 3, 4).contiguous() + if self.noised_image_input: + image = image.permute(0, 2, 1, 3, 4).contiguous() + if self.noised_image_all_concat: + image = image.repeat(1, x.shape[1], 1, 1, 1) + else: + image = torch.concat([image, torch.zeros_like(x[:, 1:])], dim=1) + if random.random() < self.noised_image_dropout: + image = torch.zeros_like(image) + batch["concat_images"] = image gc.collect() torch.cuda.empty_cache() @@ -300,8 +323,7 @@ class SATVideoDiffusionEngine(nn.Module): if isinstance(c[k], torch.Tensor): c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc)) - samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w - samples = samples.permute(0, 2, 1, 3, 4).contiguous() + if self.noised_image_input: image = x[:, :, 0:1] image = self.add_noise_to_first_frame(image) @@ -320,6 +342,8 @@ class SATVideoDiffusionEngine(nn.Module): samples = samples.permute(0, 2, 1, 3, 4).contiguous() log["samples"] = samples else: + samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w + samples = samples.permute(0, 2, 1, 3, 4).contiguous() if only_log_video_latents: latents = 1.0 / self.scale_factor * samples log["latents"] = latents diff --git a/sat/finetune_multi_gpus.sh b/sat/finetune_multi_gpus.sh index ef56701..a9a8ad2 100644 --- a/sat/finetune_multi_gpus.sh +++ b/sat/finetune_multi_gpus.sh @@ -1,8 +1,8 @@ #! /bin/bash -echo "RUN on $(hostname), CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +echo "RUN on $(hostname), CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True" -run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed $RANDOM" +run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_5b_i2v_lora.yaml configs/sft.yaml --seed $RANDOM" echo ${run_cmd} eval ${run_cmd} diff --git a/sat/sgm/models/autoencoder.py b/sat/sgm/models/autoencoder.py index 9ae44d0..0b21318 100644 --- a/sat/sgm/models/autoencoder.py +++ b/sat/sgm/models/autoencoder.py @@ -547,84 +547,3 @@ class VideoAutoencodingEngine(AutoencodingEngine): print("Missing keys: ", missing_keys) print("Unexpected keys: ", unexpected_keys) print(f"Restored from {path}") - - -class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine): - def __init__( - self, - cp_size=0, - *args, - **kwargs, - ): - self.cp_size = cp_size - return super().__init__(*args, **kwargs) - - def encode( - self, - x: torch.Tensor, - return_reg_log: bool = False, - unregularized: bool = False, - input_cp: bool = False, - output_cp: bool = False, - use_cp: bool = True, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: - if self.cp_size <= 1: - use_cp = False - if self.cp_size > 0 and use_cp and not input_cp: - if not is_context_parallel_initialized: - initialize_context_parallel(self.cp_size) - - global_src_rank = get_context_parallel_group_rank() * self.cp_size - torch.distributed.broadcast(x, src=global_src_rank, group=get_context_parallel_group()) - - x = _conv_split(x, dim=2, kernel_size=1) - - if return_reg_log: - z, reg_log = super().encode(x, return_reg_log, unregularized, use_cp=use_cp) - else: - z = super().encode(x, return_reg_log, unregularized, use_cp=use_cp) - - if self.cp_size > 0 and use_cp and not output_cp: - z = _conv_gather(z, dim=2, kernel_size=1) - - if return_reg_log: - return z, reg_log - return z - - def decode( - self, - z: torch.Tensor, - input_cp: bool = False, - output_cp: bool = False, - use_cp: bool = True, - **kwargs, - ): - if self.cp_size <= 1: - use_cp = False - if self.cp_size > 0 and use_cp and not input_cp: - if not is_context_parallel_initialized: - initialize_context_parallel(self.cp_size) - - global_src_rank = get_context_parallel_group_rank() * self.cp_size - torch.distributed.broadcast(z, src=global_src_rank, group=get_context_parallel_group()) - - z = _conv_split(z, dim=2, kernel_size=1) - - x = super().decode(z, use_cp=use_cp, **kwargs) - - if self.cp_size > 0 and use_cp and not output_cp: - x = _conv_gather(x, dim=2, kernel_size=1) - - return x - - def forward( - self, - x: torch.Tensor, - input_cp: bool = False, - latent_cp: bool = False, - output_cp: bool = False, - **additional_decode_kwargs, - ) -> Tuple[torch.Tensor, torch.Tensor, dict]: - z, reg_log = self.encode(x, return_reg_log=True, input_cp=input_cp, output_cp=latent_cp) - dec = self.decode(z, input_cp=latent_cp, output_cp=output_cp, **additional_decode_kwargs) - return z, dec, reg_log diff --git a/sat/sgm/modules/cp_enc_dec.py b/sat/sgm/modules/cp_enc_dec.py index 469595d..931e657 100644 --- a/sat/sgm/modules/cp_enc_dec.py +++ b/sat/sgm/modules/cp_enc_dec.py @@ -6,7 +6,6 @@ from ..util import ( get_context_parallel_group, get_context_parallel_rank, get_context_parallel_world_size, - ) _USE_CP = True @@ -179,4 +178,4 @@ def _conv_gather(input_, dim, kernel_size): # print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape) - return output \ No newline at end of file + return output diff --git a/sat/sgm/modules/diffusionmodules/loss.py b/sat/sgm/modules/diffusionmodules/loss.py index 48285da..66916c1 100644 --- a/sat/sgm/modules/diffusionmodules/loss.py +++ b/sat/sgm/modules/diffusionmodules/loss.py @@ -100,8 +100,9 @@ class VideoDiffusionLoss(StandardDiffusionLoss): ) if "concat_images" in batch.keys(): - additional_model_inputs["concat_images"] = batch["concat_images"] + cond["concat"] = batch["concat_images"] + # [2, 13, 16, 60, 90],[2] dict_keys(['crossattn', 'concat']) dict_keys(['idx']) model_output = denoiser(network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs) w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim) # v-pred @@ -117,11 +118,3 @@ class VideoDiffusionLoss(StandardDiffusionLoss): elif self.type == "lpips": loss = self.lpips(model_output, target).reshape(-1) return loss - - -def get_3d_position_ids(frame_len, h, w): - i = torch.arange(frame_len).view(frame_len, 1, 1).expand(frame_len, h, w) - j = torch.arange(h).view(1, h, 1).expand(frame_len, h, w) - k = torch.arange(w).view(1, 1, w).expand(frame_len, h, w) - position_ids = torch.stack([i, j, k], dim=-1).reshape(-1, 3) - return position_ids diff --git a/tools/convert_weight_sat2hf.py b/tools/convert_weight_sat2hf.py index 6ef8e7f..183be62 100644 --- a/tools/convert_weight_sat2hf.py +++ b/tools/convert_weight_sat2hf.py @@ -1,11 +1,7 @@ """ -This script demonstrates how to convert and generate video from a text prompt using CogVideoX with 🤗Huggingface Diffusers Pipeline. - -Note: - This script requires the `diffusers>=0.30.1` library to be installed. - -Run the script: - $ python convert_and_generate.py --transformer_ckpt_path --vae_ckpt_path --output_path --text_encoder_path +This script demonstrates how to convert and generate video from a text prompt +using CogVideoX with 🤗Huggingface Diffusers Pipeline. +This script requires the `diffusers>=0.30.2` library to be installed. Functions: - reassign_query_key_value_inplace: Reassigns the query, key, and value weights in-place. @@ -27,7 +23,13 @@ from typing import Any, Dict import torch from transformers import T5EncoderModel, T5Tokenizer -from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel +from diffusers import ( + AutoencoderKLCogVideoX, + CogVideoXDDIMScheduler, + CogVideoXImageToVideoPipeline, + CogVideoXPipeline, + CogVideoXTransformer3DModel, +) def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]): @@ -101,6 +103,7 @@ TRANSFORMER_KEYS_RENAME_DICT = { "mixins.final_layer.norm_final": "norm_out.norm", "mixins.final_layer.linear": "proj_out", "mixins.final_layer.adaLN_modulation.1": "norm_out.linear", + "mixins.pos_embed.pos_embedding": "patch_embed.pos_embedding", # Specific to CogVideoX-5b-I2V } TRANSFORMER_SPECIAL_KEYS_REMAP = { @@ -154,15 +157,18 @@ def convert_transformer( num_layers: int, num_attention_heads: int, use_rotary_positional_embeddings: bool, + i2v: bool, dtype: torch.dtype, ): PREFIX_KEY = "model.diffusion_model." original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True)) transformer = CogVideoXTransformer3DModel( + in_channels=32 if i2v else 16, num_layers=num_layers, num_attention_heads=num_attention_heads, use_rotary_positional_embeddings=use_rotary_positional_embeddings, + use_learned_positional_embeddings=i2v, ).to(dtype=dtype) for key in list(original_state_dict.keys()): @@ -176,7 +182,6 @@ def convert_transformer( if special_key not in key: continue handler_fn_inplace(key, original_state_dict) - transformer.load_state_dict(original_state_dict, strict=True) return transformer @@ -204,8 +209,7 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype): def get_args(): parser = argparse.ArgumentParser() parser.add_argument( - "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint" - ) + "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint") parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint") parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16") @@ -228,6 +232,7 @@ def get_args(): parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE") # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0 parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE") + parser.add_argument("--i2v", action="store_true", default=False, help="Whether to save the model weights in fp16") return parser.parse_args() @@ -248,6 +253,7 @@ if __name__ == "__main__": args.num_layers, args.num_attention_heads, args.use_rotary_positional_embeddings, + args.i2v, dtype, ) if args.vae_ckpt_path is not None: @@ -256,8 +262,7 @@ if __name__ == "__main__": text_encoder_id = "google/t5-v1_1-xxl" tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir) - - # Apparently, the conversion does not work any more without this :shrug: + # Apparently, the conversion does not work anymore without this :shrug: for param in text_encoder.parameters(): param.data = param.data.contiguous() @@ -275,9 +280,17 @@ if __name__ == "__main__": "timestep_spacing": "trailing", } ) + if args.i2v: + pipeline_cls = CogVideoXImageToVideoPipeline + else: + pipeline_cls = CogVideoXPipeline - pipe = CogVideoXPipeline( - tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + pipe = pipeline_cls( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + transformer=transformer, + scheduler=scheduler, ) if args.fp16: