mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-06 03:57:56 +08:00
commit
01f19dad11
@ -1,29 +1,46 @@
|
|||||||
"""
|
"""
|
||||||
This script demonstrates how to generate a video from a text prompt using CogVideoX with 🤗Huggingface Diffusers Pipeline.
|
This script demonstrates how to generate a video using the CogVideoX model with the Hugging Face `diffusers` pipeline.
|
||||||
|
The script supports different types of video generation, including text-to-video (t2v), image-to-video (i2v),
|
||||||
|
and video-to-video (v2v), depending on the input data and different weight.
|
||||||
|
|
||||||
Note:
|
- text-to-video: THUDM/CogVideoX-5b or THUDM/CogVideoX-2b
|
||||||
This script requires the `diffusers>=0.30.0` library to be installed, after `diffusers 0.31.0` release,
|
- video-to-video: THUDM/CogVideoX-5b or THUDM/CogVideoX-2b
|
||||||
need to update.
|
- image-to-video: THUDM/CogVideoX-5b-I2V
|
||||||
|
|
||||||
Run the script:
|
Running the Script:
|
||||||
$ python cli_demo.py --prompt "A girl ridding a bike." --model_path THUDM/CogVideoX-2b
|
To run the script, use the following command with appropriate arguments:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX-5b --generate_type "t2v"
|
||||||
|
```
|
||||||
|
|
||||||
|
Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import CogVideoXPipeline, CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
from diffusers import (CogVideoXPipeline,
|
||||||
from diffusers.utils import export_to_video
|
CogVideoXDDIMScheduler,
|
||||||
|
CogVideoXDPMScheduler,
|
||||||
|
CogVideoXImageToVideoPipeline,
|
||||||
|
CogVideoXVideoToVideoPipeline)
|
||||||
|
|
||||||
|
from diffusers.utils import export_to_video, load_image, load_video
|
||||||
|
|
||||||
|
|
||||||
def generate_video(
|
def generate_video(
|
||||||
prompt: str,
|
prompt: str,
|
||||||
model_path: str,
|
model_path: str,
|
||||||
output_path: str = "./output.mp4",
|
output_path: str = "./output.mp4",
|
||||||
num_inference_steps: int = 50,
|
image_or_video_path: str = "",
|
||||||
guidance_scale: float = 6.0,
|
num_inference_steps: int = 50,
|
||||||
num_videos_per_prompt: int = 1,
|
guidance_scale: float = 6.0,
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
num_videos_per_prompt: int = 1,
|
||||||
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
|
generate_type: str = Literal["t2v", "i2v", "v2v"], # i2v: image to video, v2v: video to video
|
||||||
|
seed: int = 42,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Generates a video based on the given prompt and saves it to the specified path.
|
Generates a video based on the given prompt and saves it to the specified path.
|
||||||
@ -36,14 +53,25 @@ def generate_video(
|
|||||||
- guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
|
- guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
|
||||||
- num_videos_per_prompt (int): Number of videos to generate per prompt.
|
- num_videos_per_prompt (int): Number of videos to generate per prompt.
|
||||||
- dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
|
- dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
|
||||||
|
- generate_type (str): The type of video generation (e.g., 't2v', 'i2v', 'v2v').
|
||||||
|
- seed (int): The seed for reproducibility.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 1. Load the pre-trained CogVideoX pipeline with the specified precision (bfloat16).
|
# 1. Load the pre-trained CogVideoX pipeline with the specified precision (bfloat16).
|
||||||
# add device_map="balanced" in the from_pretrained function and remove the enable_model_cpu_offload()
|
# add device_map="balanced" in the from_pretrained function and remove the enable_model_cpu_offload()
|
||||||
# function to use Multi GPUs.
|
# function to use Multi GPUs.
|
||||||
|
|
||||||
pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype)
|
image = None
|
||||||
|
video = None
|
||||||
|
|
||||||
|
if generate_type == "i2v":
|
||||||
|
pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype)
|
||||||
|
image = load_image(image=image_or_video_path)
|
||||||
|
elif generate_type == "t2v":
|
||||||
|
pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype)
|
||||||
|
else:
|
||||||
|
pipe = CogVideoXVideoToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype)
|
||||||
|
video = load_video(image_or_video_path)
|
||||||
|
|
||||||
# 2. Set Scheduler.
|
# 2. Set Scheduler.
|
||||||
# Can be changed to `CogVideoXDPMScheduler` or `CogVideoXDDIMScheduler`.
|
# Can be changed to `CogVideoXDPMScheduler` or `CogVideoXDDIMScheduler`.
|
||||||
@ -51,63 +79,83 @@ def generate_video(
|
|||||||
# pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
|
# pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
|
||||||
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
|
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
|
||||||
|
|
||||||
# 3. Enable CPU offload for the model, enable tiling.
|
# 3. Enable CPU offload for the model.
|
||||||
# turn off if you have multiple GPUs or enough GPU memory(such as H100) and it will cost less time in inference
|
# turn off if you have multiple GPUs or enough GPU memory(such as H100) and it will cost less time in inference
|
||||||
pipe.enable_sequential_cpu_offload()
|
# and enable to("cuda")
|
||||||
|
|
||||||
|
# pipe.enable_sequential_cpu_offload()
|
||||||
|
pipe.to("cuda")
|
||||||
pipe.vae.enable_slicing()
|
pipe.vae.enable_slicing()
|
||||||
pipe.vae.enable_tiling()
|
pipe.vae.enable_tiling()
|
||||||
|
|
||||||
# 4. Generate the video frames based on the prompt.
|
# 4. Generate the video frames based on the prompt.
|
||||||
# `num_frames` is the Number of frames to generate.
|
# `num_frames` is the Number of frames to generate.
|
||||||
# This is the default value for 6 seconds video and 8 fps,so 48 frames and will plus 1 frame for the first frame.
|
# This is the default value for 6 seconds video and 8 fps,so 48 frames and will plus 1 frame for the first frame and 49 frames.
|
||||||
# for diffusers `0.30.1` and after version, this should be 49.
|
if generate_type == "i2v":
|
||||||
|
video_generate = pipe(
|
||||||
video = pipe(
|
prompt=prompt,
|
||||||
prompt=prompt,
|
image=image, # The path of the image to be used as the background of the video
|
||||||
num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt
|
num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt
|
||||||
num_inference_steps=num_inference_steps, # Number of inference steps
|
num_inference_steps=num_inference_steps, # Number of inference steps
|
||||||
num_frames=49, # Number of frames to generate,changed to 49 for diffusers version `0.31.0` and after.
|
num_frames=49, # Number of frames to generate,changed to 49 for diffusers version `0.31.0` and after.
|
||||||
use_dynamic_cfg=True, ## This id used for DPM Sechduler, for DDIM scheduler, it should be False
|
use_dynamic_cfg=True, ## This id used for DPM Sechduler, for DDIM scheduler, it should be False
|
||||||
guidance_scale=guidance_scale, # Guidance scale for classifier-free guidance, can set to 7 for DPM scheduler
|
guidance_scale=guidance_scale,
|
||||||
generator=torch.Generator().manual_seed(42), # Set the seed for reproducibility
|
generator=torch.Generator().manual_seed(seed), # Set the seed for reproducibility
|
||||||
).frames[0]
|
).frames[0]
|
||||||
|
elif generate_type == "t2v":
|
||||||
|
video_generate = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
num_videos_per_prompt=num_videos_per_prompt,
|
||||||
|
num_inference_steps=num_inference_steps,
|
||||||
|
num_frames=49,
|
||||||
|
use_dynamic_cfg=True,
|
||||||
|
guidance_scale=guidance_scale,
|
||||||
|
generator=torch.Generator().manual_seed(seed),
|
||||||
|
).frames[0]
|
||||||
|
else:
|
||||||
|
video_generate = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
video=video, # The path of the video to be used as the background of the video
|
||||||
|
num_videos_per_prompt=num_videos_per_prompt,
|
||||||
|
num_inference_steps=num_inference_steps,
|
||||||
|
num_frames=49,
|
||||||
|
use_dynamic_cfg=True,
|
||||||
|
guidance_scale=guidance_scale,
|
||||||
|
generator=torch.Generator().manual_seed(seed), # Set the seed for reproducibility
|
||||||
|
).frames[0]
|
||||||
# 5. Export the generated frames to a video file. fps must be 8 for original video.
|
# 5. Export the generated frames to a video file. fps must be 8 for original video.
|
||||||
export_to_video(video, output_path, fps=8)
|
export_to_video(video_generate, output_path, fps=8)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
|
parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
|
||||||
parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
|
parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
|
||||||
parser.add_argument(
|
parser.add_argument("--image_or_video_path", type=str, default=None,
|
||||||
"--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used"
|
help="The path of the image to be used as the background of the video")
|
||||||
)
|
parser.add_argument("--model_path", type=str, default="THUDM/CogVideoX-5b",
|
||||||
parser.add_argument(
|
help="The path of the pre-trained model to be used")
|
||||||
"--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
|
parser.add_argument("--output_path", type=str, default="./output.mp4",
|
||||||
)
|
help="The path where the generated video will be saved")
|
||||||
parser.add_argument(
|
|
||||||
"--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
|
|
||||||
)
|
|
||||||
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
|
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
|
||||||
|
parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of steps for the inference process")
|
||||||
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
|
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
|
||||||
parser.add_argument(
|
parser.add_argument("--generate_type", type=str, default="t2v",
|
||||||
"--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')"
|
help="The type of video generation (e.g., 't2v', 'i2v', 'v2v')")
|
||||||
)
|
parser.add_argument("--dtype", type=str, default="bfloat16",
|
||||||
|
help="The data type for computation (e.g., 'float16' or 'bfloat16')")
|
||||||
|
parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Convert dtype argument to torch.dtype.
|
|
||||||
# For CogVideoX-2B model, use torch.float16.
|
|
||||||
# For CogVideoX-5B model, use torch.bfloat16.
|
|
||||||
dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16
|
dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16
|
||||||
|
|
||||||
# main function to generate video.
|
|
||||||
generate_video(
|
generate_video(
|
||||||
prompt=args.prompt,
|
prompt=args.prompt,
|
||||||
model_path=args.model_path,
|
model_path=args.model_path,
|
||||||
|
image_or_video_path=args.image_or_video_path,
|
||||||
output_path=args.output_path,
|
output_path=args.output_path,
|
||||||
num_inference_steps=args.num_inference_steps,
|
num_inference_steps=args.num_inference_steps,
|
||||||
guidance_scale=args.guidance_scale,
|
guidance_scale=args.guidance_scale,
|
||||||
num_videos_per_prompt=args.num_videos_per_prompt,
|
num_videos_per_prompt=args.num_videos_per_prompt,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
generate_type=args.generate_type,
|
||||||
|
seed=args.seed,
|
||||||
)
|
)
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
model:
|
model:
|
||||||
scale_factor: 0.7 # different from cogvideox_2b_infer.yaml
|
scale_factor: 0.7
|
||||||
disable_first_stage_autocast: true
|
disable_first_stage_autocast: true
|
||||||
log_keys:
|
log_keys:
|
||||||
- txt
|
- txt
|
||||||
|
159
sat/configs/cogvideox_5b_i2v.yaml
Normal file
159
sat/configs/cogvideox_5b_i2v.yaml
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
model:
|
||||||
|
scale_factor: 0.7
|
||||||
|
disable_first_stage_autocast: true
|
||||||
|
latent_input: false
|
||||||
|
noised_image_input: true
|
||||||
|
noised_image_dropout: 0.05
|
||||||
|
log_keys:
|
||||||
|
- txt
|
||||||
|
|
||||||
|
denoiser_config:
|
||||||
|
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
||||||
|
params:
|
||||||
|
num_idx: 1000
|
||||||
|
quantize_c_noise: False
|
||||||
|
|
||||||
|
weighting_config:
|
||||||
|
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
|
||||||
|
scaling_config:
|
||||||
|
target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
|
||||||
|
discretization_config:
|
||||||
|
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
|
||||||
|
params:
|
||||||
|
shift_scale: 1.0 # different from cogvideox_2b_infer.yaml
|
||||||
|
|
||||||
|
network_config:
|
||||||
|
target: dit_video_concat.DiffusionTransformer
|
||||||
|
params:
|
||||||
|
time_embed_dim: 512
|
||||||
|
elementwise_affine: True
|
||||||
|
num_frames: 49
|
||||||
|
time_compressed_rate: 4
|
||||||
|
latent_width: 90
|
||||||
|
latent_height: 60
|
||||||
|
num_layers: 42
|
||||||
|
patch_size: 2
|
||||||
|
in_channels: 32 #different from cogvideox_5b_infer.yaml
|
||||||
|
out_channels: 16
|
||||||
|
hidden_size: 3072
|
||||||
|
adm_in_channels: 256
|
||||||
|
num_attention_heads: 48
|
||||||
|
|
||||||
|
transformer_args:
|
||||||
|
checkpoint_activations: True
|
||||||
|
vocab_size: 1
|
||||||
|
max_sequence_length: 64
|
||||||
|
layernorm_order: pre
|
||||||
|
skip_init: false
|
||||||
|
model_parallel_size: 1
|
||||||
|
is_decoder: false
|
||||||
|
|
||||||
|
modules:
|
||||||
|
pos_embed_config:
|
||||||
|
target: dit_video_concat.Rotary3DPositionEmbeddingMixin
|
||||||
|
params:
|
||||||
|
learnable_pos_embed: True
|
||||||
|
hidden_size_head: 64
|
||||||
|
text_length: 226
|
||||||
|
|
||||||
|
patch_embed_config:
|
||||||
|
target: dit_video_concat.ImagePatchEmbeddingMixin
|
||||||
|
params:
|
||||||
|
text_hidden_size: 4096
|
||||||
|
|
||||||
|
adaln_layer_config:
|
||||||
|
target: dit_video_concat.AdaLNMixin
|
||||||
|
params:
|
||||||
|
qk_ln: True
|
||||||
|
|
||||||
|
final_layer_config:
|
||||||
|
target: dit_video_concat.FinalLayerMixin
|
||||||
|
|
||||||
|
conditioner_config:
|
||||||
|
target: sgm.modules.GeneralConditioner
|
||||||
|
params:
|
||||||
|
emb_models:
|
||||||
|
- is_trainable: false
|
||||||
|
input_key: txt
|
||||||
|
ucg_rate: 0.1
|
||||||
|
target: sgm.modules.encoders.modules.FrozenT5Embedder
|
||||||
|
params:
|
||||||
|
model_dir: "t5-v1_1-xxl"
|
||||||
|
max_length: 226
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
|
||||||
|
params:
|
||||||
|
cp_size: 1
|
||||||
|
ckpt_path: "cogvideox-5b-i2v-sat/vae/3d-vae.pt"
|
||||||
|
ignore_keys: ['loss']
|
||||||
|
|
||||||
|
loss_config:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
regularizer_config:
|
||||||
|
target: vae_modules.regularizers.DiagonalGaussianRegularizer
|
||||||
|
|
||||||
|
encoder_config:
|
||||||
|
target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
|
||||||
|
params:
|
||||||
|
double_z: true
|
||||||
|
z_channels: 16
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult: [1, 2, 2, 4]
|
||||||
|
attn_resolutions: []
|
||||||
|
num_res_blocks: 3
|
||||||
|
dropout: 0.0
|
||||||
|
gather_norm: True
|
||||||
|
|
||||||
|
decoder_config:
|
||||||
|
target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
|
||||||
|
params:
|
||||||
|
double_z: True
|
||||||
|
z_channels: 16
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult: [1, 2, 2, 4]
|
||||||
|
attn_resolutions: []
|
||||||
|
num_res_blocks: 3
|
||||||
|
dropout: 0.0
|
||||||
|
gather_norm: True
|
||||||
|
|
||||||
|
loss_fn_config:
|
||||||
|
target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
|
||||||
|
params:
|
||||||
|
fixed_frames: 0
|
||||||
|
offset_noise_level: 0
|
||||||
|
sigma_sampler_config:
|
||||||
|
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
|
||||||
|
params:
|
||||||
|
uniform_sampling: True
|
||||||
|
num_idx: 1000
|
||||||
|
discretization_config:
|
||||||
|
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
|
||||||
|
params:
|
||||||
|
shift_scale: 1.0
|
||||||
|
|
||||||
|
sampler_config:
|
||||||
|
target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
|
||||||
|
params:
|
||||||
|
fixed_frames: 0
|
||||||
|
num_steps: 50
|
||||||
|
verbose: True
|
||||||
|
|
||||||
|
discretization_config:
|
||||||
|
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
|
||||||
|
params:
|
||||||
|
shift_scale: 1.0
|
||||||
|
|
||||||
|
guider_config:
|
||||||
|
target: sgm.modules.diffusionmodules.guiders.DynamicCFG
|
||||||
|
params:
|
||||||
|
scale: 6
|
||||||
|
exp: 5
|
||||||
|
num_steps: 50
|
165
sat/configs/cogvideox_5b_i2v_lora.yaml
Normal file
165
sat/configs/cogvideox_5b_i2v_lora.yaml
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
model:
|
||||||
|
scale_factor: 0.7
|
||||||
|
disable_first_stage_autocast: true
|
||||||
|
latent_input: false
|
||||||
|
noised_image_input: true
|
||||||
|
noised_image_dropout: 0.05
|
||||||
|
not_trainable_prefixes: ['all'] ## Using Lora
|
||||||
|
log_keys:
|
||||||
|
- txt
|
||||||
|
|
||||||
|
denoiser_config:
|
||||||
|
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
||||||
|
params:
|
||||||
|
num_idx: 1000
|
||||||
|
quantize_c_noise: False
|
||||||
|
|
||||||
|
weighting_config:
|
||||||
|
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
|
||||||
|
scaling_config:
|
||||||
|
target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
|
||||||
|
discretization_config:
|
||||||
|
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
|
||||||
|
params:
|
||||||
|
shift_scale: 1.0 # different from cogvideox_2b_infer.yaml
|
||||||
|
|
||||||
|
network_config:
|
||||||
|
target: dit_video_concat.DiffusionTransformer
|
||||||
|
params:
|
||||||
|
time_embed_dim: 512
|
||||||
|
elementwise_affine: True
|
||||||
|
num_frames: 49
|
||||||
|
time_compressed_rate: 4
|
||||||
|
latent_width: 90
|
||||||
|
latent_height: 60
|
||||||
|
num_layers: 42
|
||||||
|
patch_size: 2
|
||||||
|
in_channels: 32
|
||||||
|
out_channels: 16
|
||||||
|
hidden_size: 3072
|
||||||
|
adm_in_channels: 256
|
||||||
|
num_attention_heads: 48
|
||||||
|
|
||||||
|
transformer_args:
|
||||||
|
checkpoint_activations: True
|
||||||
|
vocab_size: 1
|
||||||
|
max_sequence_length: 64
|
||||||
|
layernorm_order: pre
|
||||||
|
skip_init: false
|
||||||
|
model_parallel_size: 1
|
||||||
|
is_decoder: false
|
||||||
|
|
||||||
|
modules:
|
||||||
|
pos_embed_config:
|
||||||
|
target: dit_video_concat.Rotary3DPositionEmbeddingMixin
|
||||||
|
params:
|
||||||
|
learnable_pos_embed: True
|
||||||
|
hidden_size_head: 64
|
||||||
|
text_length: 226
|
||||||
|
|
||||||
|
lora_config:
|
||||||
|
target: sat.model.finetune.lora2.LoraMixin
|
||||||
|
params:
|
||||||
|
r: 256
|
||||||
|
|
||||||
|
patch_embed_config:
|
||||||
|
target: dit_video_concat.ImagePatchEmbeddingMixin
|
||||||
|
params:
|
||||||
|
text_hidden_size: 4096
|
||||||
|
|
||||||
|
adaln_layer_config:
|
||||||
|
target: dit_video_concat.AdaLNMixin
|
||||||
|
params:
|
||||||
|
qk_ln: True
|
||||||
|
|
||||||
|
final_layer_config:
|
||||||
|
target: dit_video_concat.FinalLayerMixin
|
||||||
|
|
||||||
|
conditioner_config:
|
||||||
|
target: sgm.modules.GeneralConditioner
|
||||||
|
params:
|
||||||
|
emb_models:
|
||||||
|
- is_trainable: false
|
||||||
|
input_key: txt
|
||||||
|
ucg_rate: 0.1
|
||||||
|
target: sgm.modules.encoders.modules.FrozenT5Embedder
|
||||||
|
params:
|
||||||
|
model_dir: "t5-v1_1-xxl"
|
||||||
|
max_length: 226
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
|
||||||
|
params:
|
||||||
|
cp_size: 1
|
||||||
|
ckpt_path: "cogvideox-5b-i2v-sat/vae/3d-vae.pt"
|
||||||
|
ignore_keys: [ 'loss' ]
|
||||||
|
|
||||||
|
loss_config:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
regularizer_config:
|
||||||
|
target: vae_modules.regularizers.DiagonalGaussianRegularizer
|
||||||
|
|
||||||
|
encoder_config:
|
||||||
|
target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
|
||||||
|
params:
|
||||||
|
double_z: true
|
||||||
|
z_channels: 16
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult: [ 1, 2, 2, 4 ]
|
||||||
|
attn_resolutions: [ ]
|
||||||
|
num_res_blocks: 3
|
||||||
|
dropout: 0.0
|
||||||
|
gather_norm: True
|
||||||
|
|
||||||
|
decoder_config:
|
||||||
|
target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
|
||||||
|
params:
|
||||||
|
double_z: True
|
||||||
|
z_channels: 16
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult: [ 1, 2, 2, 4 ]
|
||||||
|
attn_resolutions: [ ]
|
||||||
|
num_res_blocks: 3
|
||||||
|
dropout: 0.0
|
||||||
|
gather_norm: True
|
||||||
|
|
||||||
|
loss_fn_config:
|
||||||
|
target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
|
||||||
|
params:
|
||||||
|
fixed_frames: 0
|
||||||
|
offset_noise_level: 0
|
||||||
|
sigma_sampler_config:
|
||||||
|
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
|
||||||
|
params:
|
||||||
|
uniform_sampling: True
|
||||||
|
num_idx: 1000
|
||||||
|
discretization_config:
|
||||||
|
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
|
||||||
|
params:
|
||||||
|
shift_scale: 1.0
|
||||||
|
|
||||||
|
sampler_config:
|
||||||
|
target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
|
||||||
|
params:
|
||||||
|
fixed_frames: 0
|
||||||
|
num_steps: 50
|
||||||
|
verbose: True
|
||||||
|
|
||||||
|
discretization_config:
|
||||||
|
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
|
||||||
|
params:
|
||||||
|
shift_scale: 1.0
|
||||||
|
|
||||||
|
guider_config:
|
||||||
|
target: sgm.modules.diffusionmodules.guiders.DynamicCFG
|
||||||
|
params:
|
||||||
|
scale: 6
|
||||||
|
exp: 5
|
||||||
|
num_steps: 50
|
@ -1,15 +1,16 @@
|
|||||||
args:
|
args:
|
||||||
|
image2video: False # True for image2video, False for text2video
|
||||||
latent_channels: 16
|
latent_channels: 16
|
||||||
mode: inference
|
mode: inference
|
||||||
load: "{your CogVideoX SAT folder}/transformer" # This is for Full model without lora adapter
|
load: "{your CogVideoX SAT folder}/transformer" # This is for Full model without lora adapter
|
||||||
# load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter
|
# load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter
|
||||||
|
|
||||||
batch_size: 1
|
batch_size: 1
|
||||||
input_type: txt
|
input_type: txt
|
||||||
input_file: configs/test.txt
|
input_file: configs/test.txt
|
||||||
|
sampling_image_size: [480, 720]
|
||||||
sampling_num_frames: 13 # Must be 13, 11 or 9
|
sampling_num_frames: 13 # Must be 13, 11 or 9
|
||||||
sampling_fps: 8
|
sampling_fps: 8
|
||||||
fp16: True # For CogVideoX-2B
|
# fp16: True # For CogVideoX-2B
|
||||||
# bf16: True # For CogVideoX-5B
|
bf16: True # For CogVideoX-5B and CoGVideoX-5B-I2V
|
||||||
output_dir: outputs/
|
output_dir: outputs/
|
||||||
force_inference: True
|
force_inference: True
|
@ -1,15 +1,15 @@
|
|||||||
args:
|
args:
|
||||||
checkpoint_activations: True ## using gradient checkpointing
|
checkpoint_activations: True # using gradient checkpointing
|
||||||
model_parallel_size: 1
|
model_parallel_size: 1
|
||||||
experiment_name: lora-disney
|
experiment_name: lora-disney
|
||||||
mode: finetune
|
mode: finetune
|
||||||
load: "cogvideox-2b-sat/transformer"
|
load: "{your CogVideoX SAT folder}/transformer"
|
||||||
no_load_rng: True
|
no_load_rng: True
|
||||||
train_iters: 1000 # Suggest more than 1000 For Lora and SFT For 500 is enough
|
train_iters: 1000 # Suggest more than 1000 For Lora and SFT For 500 is enough
|
||||||
eval_iters: 1
|
eval_iters: 1
|
||||||
eval_interval: 100
|
eval_interval: 100
|
||||||
eval_batch_size: 1
|
eval_batch_size: 1
|
||||||
save: ckpts_2b_lora
|
save: ckpts_5b_lora
|
||||||
save_interval: 500
|
save_interval: 500
|
||||||
log_interval: 20
|
log_interval: 20
|
||||||
train_data: [ "disney" ] # Train data path
|
train_data: [ "disney" ] # Train data path
|
||||||
@ -28,7 +28,7 @@ data:
|
|||||||
skip_frms_num: 3.
|
skip_frms_num: 3.
|
||||||
|
|
||||||
deepspeed:
|
deepspeed:
|
||||||
# Minimun for 16 videos per batch for ALL GPUs, This setting is for 8 x A100 GPUs
|
# Minimum for 16 videos per batch for ALL GPUs, This setting is for 8 x A100 GPUs
|
||||||
train_micro_batch_size_per_gpu: 2
|
train_micro_batch_size_per_gpu: 2
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
steps_per_print: 50
|
steps_per_print: 50
|
||||||
@ -44,9 +44,9 @@ deepspeed:
|
|||||||
load_from_fp32_weights: false
|
load_from_fp32_weights: false
|
||||||
zero_allow_untested_optimizer: true
|
zero_allow_untested_optimizer: true
|
||||||
bf16:
|
bf16:
|
||||||
enabled: False # For CogVideoX-2B Turn to False and For CogVideoX-5B Turn to True
|
enabled: True # For CogVideoX-2B Turn to False and For CogVideoX-5B Turn to True
|
||||||
fp16:
|
fp16:
|
||||||
enabled: True # For CogVideoX-2B Turn to True and For CogVideoX-5B Turn to False
|
enabled: False # For CogVideoX-2B Turn to True and For CogVideoX-5B Turn to False
|
||||||
loss_scale: 0
|
loss_scale: 0
|
||||||
loss_scale_window: 400
|
loss_scale_window: 400
|
||||||
hysteresis: 2
|
hysteresis: 2
|
||||||
@ -55,7 +55,7 @@ deepspeed:
|
|||||||
optimizer:
|
optimizer:
|
||||||
type: sat.ops.FusedEmaAdam
|
type: sat.ops.FusedEmaAdam
|
||||||
params:
|
params:
|
||||||
lr: 0.001 # Between 1E-3 and 5E-4 For Lora and 1E-5 For SFT
|
lr: 0.00001 # Between 1E-3 and 5E-4 For Lora and 1E-5 For SFT
|
||||||
betas: [ 0.9, 0.95 ]
|
betas: [ 0.9, 0.95 ]
|
||||||
eps: 1e-8
|
eps: 1e-8
|
||||||
weight_decay: 1e-4
|
weight_decay: 1e-4
|
||||||
|
@ -362,7 +362,7 @@ class SFTDataset(Dataset):
|
|||||||
skip_frms_num: ignore the first and the last xx frames, avoiding transitions.
|
skip_frms_num: ignore the first and the last xx frames, avoiding transitions.
|
||||||
"""
|
"""
|
||||||
super(SFTDataset, self).__init__()
|
super(SFTDataset, self).__init__()
|
||||||
|
|
||||||
self.video_size = video_size
|
self.video_size = video_size
|
||||||
self.fps = fps
|
self.fps = fps
|
||||||
self.max_num_frames = max_num_frames
|
self.max_num_frames = max_num_frames
|
||||||
@ -385,7 +385,6 @@ class SFTDataset(Dataset):
|
|||||||
self.captions.append(caption)
|
self.captions.append(caption)
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
|
|
||||||
decord.bridge.set_bridge("torch")
|
decord.bridge.set_bridge("torch")
|
||||||
|
|
||||||
video_path = self.video_paths[index]
|
video_path = self.video_paths[index]
|
||||||
@ -411,9 +410,7 @@ class SFTDataset(Dataset):
|
|||||||
indices = np.arange(start, end, max((end - start) // num_frames, 1)).astype(int)
|
indices = np.arange(start, end, max((end - start) // num_frames, 1)).astype(int)
|
||||||
temp_frms = vr.get_batch(np.arange(start, end))
|
temp_frms = vr.get_batch(np.arange(start, end))
|
||||||
assert temp_frms is not None
|
assert temp_frms is not None
|
||||||
tensor_frms = (
|
tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
|
||||||
torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
|
|
||||||
)
|
|
||||||
tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
|
tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
|
||||||
else:
|
else:
|
||||||
|
|
||||||
@ -426,15 +423,11 @@ class SFTDataset(Dataset):
|
|||||||
|
|
||||||
start = int(self.skip_frms_num)
|
start = int(self.skip_frms_num)
|
||||||
end = int(ori_vlen - self.skip_frms_num)
|
end = int(ori_vlen - self.skip_frms_num)
|
||||||
num_frames = nearest_smaller_4k_plus_1(
|
num_frames = nearest_smaller_4k_plus_1(end - start) # 3D VAE requires the number of frames to be 4k+1
|
||||||
end - start
|
|
||||||
) # 3D VAE requires the number of frames to be 4k+1
|
|
||||||
end = int(start + num_frames)
|
end = int(start + num_frames)
|
||||||
temp_frms = vr.get_batch(np.arange(start, end))
|
temp_frms = vr.get_batch(np.arange(start, end))
|
||||||
assert temp_frms is not None
|
assert temp_frms is not None
|
||||||
tensor_frms = (
|
tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
|
||||||
torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
|
|
||||||
)
|
|
||||||
|
|
||||||
tensor_frms = pad_last_frame(
|
tensor_frms = pad_last_frame(
|
||||||
tensor_frms, self.max_num_frames
|
tensor_frms, self.max_num_frames
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
import random
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Any, Dict, List, Tuple, Union
|
from typing import Any, Dict, List, Tuple, Union
|
||||||
from omegaconf import ListConfig
|
from omegaconf import ListConfig
|
||||||
@ -130,6 +132,13 @@ class SATVideoDiffusionEngine(nn.Module):
|
|||||||
loss_dict = {"loss": loss_mean}
|
loss_dict = {"loss": loss_mean}
|
||||||
return loss_mean, loss_dict
|
return loss_mean, loss_dict
|
||||||
|
|
||||||
|
def add_noise_to_first_frame(self, image):
|
||||||
|
sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(self.device)
|
||||||
|
sigma = torch.exp(sigma).to(image.dtype)
|
||||||
|
image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
|
||||||
|
image = image + image_noise
|
||||||
|
return image
|
||||||
|
|
||||||
def shared_step(self, batch: Dict) -> Any:
|
def shared_step(self, batch: Dict) -> Any:
|
||||||
x = self.get_input(batch)
|
x = self.get_input(batch)
|
||||||
if self.lr_scale is not None:
|
if self.lr_scale is not None:
|
||||||
@ -139,8 +148,22 @@ class SATVideoDiffusionEngine(nn.Module):
|
|||||||
batch["lr_input"] = lr_z
|
batch["lr_input"] = lr_z
|
||||||
|
|
||||||
x = x.permute(0, 2, 1, 3, 4).contiguous()
|
x = x.permute(0, 2, 1, 3, 4).contiguous()
|
||||||
|
if self.noised_image_input:
|
||||||
|
image = x[:, :, 0:1]
|
||||||
|
image = self.add_noise_to_first_frame(image)
|
||||||
|
image = self.encode_first_stage(image, batch)
|
||||||
|
|
||||||
x = self.encode_first_stage(x, batch)
|
x = self.encode_first_stage(x, batch)
|
||||||
x = x.permute(0, 2, 1, 3, 4).contiguous()
|
x = x.permute(0, 2, 1, 3, 4).contiguous()
|
||||||
|
if self.noised_image_input:
|
||||||
|
image = image.permute(0, 2, 1, 3, 4).contiguous()
|
||||||
|
if self.noised_image_all_concat:
|
||||||
|
image = image.repeat(1, x.shape[1], 1, 1, 1)
|
||||||
|
else:
|
||||||
|
image = torch.concat([image, torch.zeros_like(x[:, 1:])], dim=1)
|
||||||
|
if random.random() < self.noised_image_dropout:
|
||||||
|
image = torch.zeros_like(image)
|
||||||
|
batch["concat_images"] = image
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@ -300,8 +323,7 @@ class SATVideoDiffusionEngine(nn.Module):
|
|||||||
if isinstance(c[k], torch.Tensor):
|
if isinstance(c[k], torch.Tensor):
|
||||||
c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
|
c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
|
||||||
|
|
||||||
samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w
|
|
||||||
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
|
|
||||||
if self.noised_image_input:
|
if self.noised_image_input:
|
||||||
image = x[:, :, 0:1]
|
image = x[:, :, 0:1]
|
||||||
image = self.add_noise_to_first_frame(image)
|
image = self.add_noise_to_first_frame(image)
|
||||||
@ -320,6 +342,8 @@ class SATVideoDiffusionEngine(nn.Module):
|
|||||||
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
|
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
|
||||||
log["samples"] = samples
|
log["samples"] = samples
|
||||||
else:
|
else:
|
||||||
|
samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w
|
||||||
|
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
|
||||||
if only_log_video_latents:
|
if only_log_video_latents:
|
||||||
latents = 1.0 / self.scale_factor * samples
|
latents = 1.0 / self.scale_factor * samples
|
||||||
log["latents"] = latents
|
log["latents"] = latents
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
#! /bin/bash
|
#! /bin/bash
|
||||||
|
|
||||||
echo "RUN on $(hostname), CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
echo "RUN on $(hostname), CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
|
||||||
|
|
||||||
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed $RANDOM"
|
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_5b_i2v_lora.yaml configs/sft.yaml --seed $RANDOM"
|
||||||
|
|
||||||
echo ${run_cmd}
|
echo ${run_cmd}
|
||||||
eval ${run_cmd}
|
eval ${run_cmd}
|
||||||
|
@ -547,84 +547,3 @@ class VideoAutoencodingEngine(AutoencodingEngine):
|
|||||||
print("Missing keys: ", missing_keys)
|
print("Missing keys: ", missing_keys)
|
||||||
print("Unexpected keys: ", unexpected_keys)
|
print("Unexpected keys: ", unexpected_keys)
|
||||||
print(f"Restored from {path}")
|
print(f"Restored from {path}")
|
||||||
|
|
||||||
|
|
||||||
class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
cp_size=0,
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
self.cp_size = cp_size
|
|
||||||
return super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def encode(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
return_reg_log: bool = False,
|
|
||||||
unregularized: bool = False,
|
|
||||||
input_cp: bool = False,
|
|
||||||
output_cp: bool = False,
|
|
||||||
use_cp: bool = True,
|
|
||||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
|
||||||
if self.cp_size <= 1:
|
|
||||||
use_cp = False
|
|
||||||
if self.cp_size > 0 and use_cp and not input_cp:
|
|
||||||
if not is_context_parallel_initialized:
|
|
||||||
initialize_context_parallel(self.cp_size)
|
|
||||||
|
|
||||||
global_src_rank = get_context_parallel_group_rank() * self.cp_size
|
|
||||||
torch.distributed.broadcast(x, src=global_src_rank, group=get_context_parallel_group())
|
|
||||||
|
|
||||||
x = _conv_split(x, dim=2, kernel_size=1)
|
|
||||||
|
|
||||||
if return_reg_log:
|
|
||||||
z, reg_log = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
|
|
||||||
else:
|
|
||||||
z = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
|
|
||||||
|
|
||||||
if self.cp_size > 0 and use_cp and not output_cp:
|
|
||||||
z = _conv_gather(z, dim=2, kernel_size=1)
|
|
||||||
|
|
||||||
if return_reg_log:
|
|
||||||
return z, reg_log
|
|
||||||
return z
|
|
||||||
|
|
||||||
def decode(
|
|
||||||
self,
|
|
||||||
z: torch.Tensor,
|
|
||||||
input_cp: bool = False,
|
|
||||||
output_cp: bool = False,
|
|
||||||
use_cp: bool = True,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
if self.cp_size <= 1:
|
|
||||||
use_cp = False
|
|
||||||
if self.cp_size > 0 and use_cp and not input_cp:
|
|
||||||
if not is_context_parallel_initialized:
|
|
||||||
initialize_context_parallel(self.cp_size)
|
|
||||||
|
|
||||||
global_src_rank = get_context_parallel_group_rank() * self.cp_size
|
|
||||||
torch.distributed.broadcast(z, src=global_src_rank, group=get_context_parallel_group())
|
|
||||||
|
|
||||||
z = _conv_split(z, dim=2, kernel_size=1)
|
|
||||||
|
|
||||||
x = super().decode(z, use_cp=use_cp, **kwargs)
|
|
||||||
|
|
||||||
if self.cp_size > 0 and use_cp and not output_cp:
|
|
||||||
x = _conv_gather(x, dim=2, kernel_size=1)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
input_cp: bool = False,
|
|
||||||
latent_cp: bool = False,
|
|
||||||
output_cp: bool = False,
|
|
||||||
**additional_decode_kwargs,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
|
||||||
z, reg_log = self.encode(x, return_reg_log=True, input_cp=input_cp, output_cp=latent_cp)
|
|
||||||
dec = self.decode(z, input_cp=latent_cp, output_cp=output_cp, **additional_decode_kwargs)
|
|
||||||
return z, dec, reg_log
|
|
||||||
|
@ -6,7 +6,6 @@ from ..util import (
|
|||||||
get_context_parallel_group,
|
get_context_parallel_group,
|
||||||
get_context_parallel_rank,
|
get_context_parallel_rank,
|
||||||
get_context_parallel_world_size,
|
get_context_parallel_world_size,
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
_USE_CP = True
|
_USE_CP = True
|
||||||
@ -179,4 +178,4 @@ def _conv_gather(input_, dim, kernel_size):
|
|||||||
|
|
||||||
# print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape)
|
# print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
@ -100,8 +100,9 @@ class VideoDiffusionLoss(StandardDiffusionLoss):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if "concat_images" in batch.keys():
|
if "concat_images" in batch.keys():
|
||||||
additional_model_inputs["concat_images"] = batch["concat_images"]
|
cond["concat"] = batch["concat_images"]
|
||||||
|
|
||||||
|
# [2, 13, 16, 60, 90],[2] dict_keys(['crossattn', 'concat']) dict_keys(['idx'])
|
||||||
model_output = denoiser(network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs)
|
model_output = denoiser(network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs)
|
||||||
w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim) # v-pred
|
w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim) # v-pred
|
||||||
|
|
||||||
@ -117,11 +118,3 @@ class VideoDiffusionLoss(StandardDiffusionLoss):
|
|||||||
elif self.type == "lpips":
|
elif self.type == "lpips":
|
||||||
loss = self.lpips(model_output, target).reshape(-1)
|
loss = self.lpips(model_output, target).reshape(-1)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
def get_3d_position_ids(frame_len, h, w):
|
|
||||||
i = torch.arange(frame_len).view(frame_len, 1, 1).expand(frame_len, h, w)
|
|
||||||
j = torch.arange(h).view(1, h, 1).expand(frame_len, h, w)
|
|
||||||
k = torch.arange(w).view(1, 1, w).expand(frame_len, h, w)
|
|
||||||
position_ids = torch.stack([i, j, k], dim=-1).reshape(-1, 3)
|
|
||||||
return position_ids
|
|
||||||
|
@ -1,11 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
This script demonstrates how to convert and generate video from a text prompt using CogVideoX with 🤗Huggingface Diffusers Pipeline.
|
This script demonstrates how to convert and generate video from a text prompt
|
||||||
|
using CogVideoX with 🤗Huggingface Diffusers Pipeline.
|
||||||
Note:
|
This script requires the `diffusers>=0.30.2` library to be installed.
|
||||||
This script requires the `diffusers>=0.30.1` library to be installed.
|
|
||||||
|
|
||||||
Run the script:
|
|
||||||
$ python convert_and_generate.py --transformer_ckpt_path <path_to_transformer_checkpoint> --vae_ckpt_path <path_to_vae_checkpoint> --output_path <path_to_output_directory> --text_encoder_path <path_to_t5>
|
|
||||||
|
|
||||||
Functions:
|
Functions:
|
||||||
- reassign_query_key_value_inplace: Reassigns the query, key, and value weights in-place.
|
- reassign_query_key_value_inplace: Reassigns the query, key, and value weights in-place.
|
||||||
@ -27,7 +23,13 @@ from typing import Any, Dict
|
|||||||
import torch
|
import torch
|
||||||
from transformers import T5EncoderModel, T5Tokenizer
|
from transformers import T5EncoderModel, T5Tokenizer
|
||||||
|
|
||||||
from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
|
from diffusers import (
|
||||||
|
AutoencoderKLCogVideoX,
|
||||||
|
CogVideoXDDIMScheduler,
|
||||||
|
CogVideoXImageToVideoPipeline,
|
||||||
|
CogVideoXPipeline,
|
||||||
|
CogVideoXTransformer3DModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
|
def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
|
||||||
@ -101,6 +103,7 @@ TRANSFORMER_KEYS_RENAME_DICT = {
|
|||||||
"mixins.final_layer.norm_final": "norm_out.norm",
|
"mixins.final_layer.norm_final": "norm_out.norm",
|
||||||
"mixins.final_layer.linear": "proj_out",
|
"mixins.final_layer.linear": "proj_out",
|
||||||
"mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
|
"mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
|
||||||
|
"mixins.pos_embed.pos_embedding": "patch_embed.pos_embedding", # Specific to CogVideoX-5b-I2V
|
||||||
}
|
}
|
||||||
|
|
||||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||||
@ -154,15 +157,18 @@ def convert_transformer(
|
|||||||
num_layers: int,
|
num_layers: int,
|
||||||
num_attention_heads: int,
|
num_attention_heads: int,
|
||||||
use_rotary_positional_embeddings: bool,
|
use_rotary_positional_embeddings: bool,
|
||||||
|
i2v: bool,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
):
|
):
|
||||||
PREFIX_KEY = "model.diffusion_model."
|
PREFIX_KEY = "model.diffusion_model."
|
||||||
|
|
||||||
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
|
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
|
||||||
transformer = CogVideoXTransformer3DModel(
|
transformer = CogVideoXTransformer3DModel(
|
||||||
|
in_channels=32 if i2v else 16,
|
||||||
num_layers=num_layers,
|
num_layers=num_layers,
|
||||||
num_attention_heads=num_attention_heads,
|
num_attention_heads=num_attention_heads,
|
||||||
use_rotary_positional_embeddings=use_rotary_positional_embeddings,
|
use_rotary_positional_embeddings=use_rotary_positional_embeddings,
|
||||||
|
use_learned_positional_embeddings=i2v,
|
||||||
).to(dtype=dtype)
|
).to(dtype=dtype)
|
||||||
|
|
||||||
for key in list(original_state_dict.keys()):
|
for key in list(original_state_dict.keys()):
|
||||||
@ -176,7 +182,6 @@ def convert_transformer(
|
|||||||
if special_key not in key:
|
if special_key not in key:
|
||||||
continue
|
continue
|
||||||
handler_fn_inplace(key, original_state_dict)
|
handler_fn_inplace(key, original_state_dict)
|
||||||
|
|
||||||
transformer.load_state_dict(original_state_dict, strict=True)
|
transformer.load_state_dict(original_state_dict, strict=True)
|
||||||
return transformer
|
return transformer
|
||||||
|
|
||||||
@ -204,8 +209,7 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
|
|||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
|
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint")
|
||||||
)
|
|
||||||
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
|
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
|
||||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||||
parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
|
parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
|
||||||
@ -228,6 +232,7 @@ def get_args():
|
|||||||
parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
|
parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
|
||||||
# For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
|
# For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
|
||||||
parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
|
parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
|
||||||
|
parser.add_argument("--i2v", action="store_true", default=False, help="Whether to save the model weights in fp16")
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -248,6 +253,7 @@ if __name__ == "__main__":
|
|||||||
args.num_layers,
|
args.num_layers,
|
||||||
args.num_attention_heads,
|
args.num_attention_heads,
|
||||||
args.use_rotary_positional_embeddings,
|
args.use_rotary_positional_embeddings,
|
||||||
|
args.i2v,
|
||||||
dtype,
|
dtype,
|
||||||
)
|
)
|
||||||
if args.vae_ckpt_path is not None:
|
if args.vae_ckpt_path is not None:
|
||||||
@ -256,8 +262,7 @@ if __name__ == "__main__":
|
|||||||
text_encoder_id = "google/t5-v1_1-xxl"
|
text_encoder_id = "google/t5-v1_1-xxl"
|
||||||
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
|
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
|
||||||
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
|
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
|
||||||
|
# Apparently, the conversion does not work anymore without this :shrug:
|
||||||
# Apparently, the conversion does not work any more without this :shrug:
|
|
||||||
for param in text_encoder.parameters():
|
for param in text_encoder.parameters():
|
||||||
param.data = param.data.contiguous()
|
param.data = param.data.contiguous()
|
||||||
|
|
||||||
@ -275,9 +280,17 @@ if __name__ == "__main__":
|
|||||||
"timestep_spacing": "trailing",
|
"timestep_spacing": "trailing",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
if args.i2v:
|
||||||
|
pipeline_cls = CogVideoXImageToVideoPipeline
|
||||||
|
else:
|
||||||
|
pipeline_cls = CogVideoXPipeline
|
||||||
|
|
||||||
pipe = CogVideoXPipeline(
|
pipe = pipeline_cls(
|
||||||
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
tokenizer=tokenizer,
|
||||||
|
text_encoder=text_encoder,
|
||||||
|
vae=vae,
|
||||||
|
transformer=transformer,
|
||||||
|
scheduler=scheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user