CogVideo/tools/convert_weight_sat2hf.py
Yuxuan Zhang 39c6562dc8 format
2025-03-22 15:14:06 +08:00

404 lines
14 KiB
Python

"""
The script demonstrates how to convert the weights of the CogVideoX model from SAT to Hugging Face format.
This script supports the conversion of the following models:
- CogVideoX-2B
- CogVideoX-5B, CogVideoX-5B-I2V
- CogVideoX1.1-5B, CogVideoX1.1-5B-I2V
Original Script:
https://github.com/huggingface/diffusers/blob/main/scripts/convert_cogvideox_to_diffusers.py
"""
import argparse
from typing import Any, Dict
import torch
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import (
AutoencoderKLCogVideoX,
CogVideoXDDIMScheduler,
CogVideoXImageToVideoPipeline,
CogVideoXPipeline,
CogVideoXTransformer3DModel,
)
def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
    """Split a fused ``query_key_value`` entry into ``to_q``/``to_k``/``to_v`` entries.

    The tensor stored under ``key`` is cut into three chunks along dim 0. Each
    chunk is stored under ``key`` with ``"query_key_value"`` replaced by the
    matching projection name, and the fused entry is then removed.

    Args:
        key: Name of the fused entry in ``state_dict``.
        state_dict: Mapping that is modified in place.
    """
    # Chunk first so that a failure here leaves ``state_dict`` untouched.
    query_part, key_part, value_part = torch.chunk(state_dict[key], chunks=3, dim=0)
    projections = (("to_q", query_part), ("to_k", key_part), ("to_v", value_part))
    for projection_name, tensor in projections:
        state_dict[key.replace("query_key_value", projection_name)] = tensor
    state_dict.pop(key)
def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
    """Move a query/key layernorm entry to its ``attn1.norm_q`` / ``attn1.norm_k`` name.

    The last two dot-separated components of ``key`` are taken as the layer id
    and the ``weight``/``bias`` suffix. The entry is re-registered as
    ``transformer_blocks.<layer_id>.attn1.norm_q.<suffix>`` when ``key``
    contains ``"query"``, or ``...norm_k.<suffix>`` when it contains ``"key"``.

    Args:
        key: Name of the layernorm entry in ``state_dict``.
        state_dict: Mapping that is modified in place.

    Raises:
        ValueError: If ``key`` contains neither ``"query"`` nor ``"key"``.
            ``state_dict`` is left unmodified in that case.
    """
    layer_id, weight_or_bias = key.split(".")[-2:]
    # "query" is tested first, mirroring the original precedence.
    if "query" in key:
        norm_name = "norm_q"
    elif "key" in key:
        norm_name = "norm_k"
    else:
        # Fail before mutating: previously the entry was popped and then lost
        # when the undefined target name raised UnboundLocalError.
        raise ValueError(f"Cannot map layernorm key to norm_q/norm_k: {key!r}")
    new_key = f"transformer_blocks.{layer_id}.attn1.{norm_name}.{weight_or_bias}"
    state_dict[new_key] = state_dict.pop(key)
def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
    """Split a fused adaLN modulation entry into ``norm1`` and ``norm2`` linear entries.

    The tensor under ``key`` is cut into 12 chunks along dim 0. Chunks 0-2 and
    6-8 are concatenated into ``transformer_blocks.<layer_id>.norm1.linear.<suffix>``;
    chunks 3-5 and 9-11 into ``transformer_blocks.<layer_id>.norm2.linear.<suffix>``.
    ``layer_id`` and ``suffix`` are the third-from-last and last dot-separated
    components of ``key``. The fused entry is removed afterwards.

    Args:
        key: Name of the fused entry in ``state_dict``.
        state_dict: Mapping that is modified in place.
    """
    layer_id, _, suffix = key.split(".")[-3:]
    pieces = state_dict[key].chunk(12, dim=0)
    # Each norm takes three consecutive chunks starting at ``start`` plus the
    # three chunks located six positions further on.
    starts = (("norm1", 0), ("norm2", 3))
    # Build both tensors before touching ``state_dict``.
    regrouped = {
        norm_name: torch.cat(pieces[start : start + 3] + pieces[start + 6 : start + 9])
        for norm_name, start in starts
    }
    for norm_name, tensor in regrouped.items():
        state_dict[f"transformer_blocks.{layer_id}.{norm_name}.linear.{suffix}"] = tensor
    state_dict.pop(key)
def remove_keys_inplace(key: str, state_dict: Dict[str, Any]):
    """Drop the entry named ``key`` from ``state_dict`` (raises ``KeyError`` if absent)."""
    del state_dict[key]
def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
    """Rename an ``up`` entry to ``up_blocks`` with a mirrored block index.

    ``key`` is split on dots. Component 1 becomes ``"up_blocks"`` and
    component 2 (an integer index ``i``) becomes ``4 - 1 - i``. The value is
    then moved to the rebuilt key.

    Args:
        key: Name of the entry in ``state_dict``.
        state_dict: Mapping that is modified in place.
    """
    parts = key.split(".")
    # Index order is reversed across the 4 block slots: 0 <-> 3, 1 <-> 2.
    mirrored_index = 4 - 1 - int(parts[2])
    parts[1] = "up_blocks"
    parts[2] = str(mirrored_index)
    state_dict[".".join(parts)] = state_dict.pop(key)
# Substring renames for transformer checkpoint keys. convert_transformer applies
# every pair with str.replace, in insertion order, to each key, so the order of
# the entries matters: later patterns see the output of earlier replacements
# (e.g. "post_attn1_layernorm" presumably only exists after "attention" has
# been rewritten to "attn1").
TRANSFORMER_KEYS_RENAME_DICT = {
    "transformer.final_layernorm": "norm_final",
    "transformer": "transformer_blocks",
    "attention": "attn1",
    "mlp": "ff.net",
    "dense_h_to_4h": "0.proj",
    "dense_4h_to_h": "2",
    ".layers": "",
    "dense": "to_out.0",
    "input_layernorm": "norm1.norm",
    "post_attn1_layernorm": "norm2.norm",
    "time_embed.0": "time_embedding.linear_1",
    "time_embed.2": "time_embedding.linear_2",
    "ofs_embed.0": "ofs_embedding.linear_1",
    "ofs_embed.2": "ofs_embedding.linear_2",
    "mixins.patch_embed": "patch_embed",
    "mixins.final_layer.norm_final": "norm_out.norm",
    "mixins.final_layer.linear": "proj_out",
    "mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
    "mixins.pos_embed.pos_embedding": "patch_embed.pos_embedding",  # Specific to CogVideoX-5b-I2V
}
# Substring -> handler table for transformer keys that need more than a rename.
# convert_transformer calls handler(key, state_dict) for every (already renamed)
# key that contains the substring; each handler mutates the state dict in place
# (splitting, re-grouping, renaming or removing the entry).
TRANSFORMER_SPECIAL_KEYS_REMAP = {
    "query_key_value": reassign_query_key_value_inplace,
    "query_layernorm_list": reassign_query_key_layernorm_inplace,
    "key_layernorm_list": reassign_query_key_layernorm_inplace,
    "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
    "embed_tokens": remove_keys_inplace,
    "freqs_sin": remove_keys_inplace,
    "freqs_cos": remove_keys_inplace,
    "position_embedding": remove_keys_inplace,
}
# Substring renames for VAE checkpoint keys. convert_vae applies every pair
# with str.replace, in insertion order, to each key.
VAE_KEYS_RENAME_DICT = {
    "block.": "resnets.",
    "down.": "down_blocks.",
    "downsample": "downsamplers.0",
    "upsample": "upsamplers.0",
    "nin_shortcut": "conv_shortcut",
    "encoder.mid.block_1": "encoder.mid_block.resnets.0",
    "encoder.mid.block_2": "encoder.mid_block.resnets.1",
    "decoder.mid.block_1": "decoder.mid_block.resnets.0",
    "decoder.mid.block_2": "decoder.mid_block.resnets.1",
}
# Substring -> handler table for VAE keys. convert_vae calls
# handler(key, state_dict) for every (already renamed) key containing the
# substring: "loss" entries are removed, "up." entries are renamed to
# "up_blocks" with a mirrored index.
VAE_SPECIAL_KEYS_REMAP = {
    "loss": remove_keys_inplace,
    "up.": replace_up_keys_inplace,
}
# Passed as model_max_length when the T5 tokenizer is loaded in the script entry point.
TOKENIZER_MAX_LENGTH = 226
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Return the innermost state dict of a loaded checkpoint.

    Checkpoints may wrap the actual weights under the keys ``"model"``,
    ``"module"`` and/or ``"state_dict"``. The wrappers are peeled off in that
    order, each test being made against the *current* nesting level so that
    nested wrappers (e.g. ``{"model": {"module": {...}}}``) are fully unwrapped.

    Args:
        saved_dict: Object returned by ``torch.load`` for a checkpoint file.

    Returns:
        The unwrapped state dict; ``saved_dict`` itself when no wrapper key is
        present.
    """
    state_dict = saved_dict
    for wrapper_key in ("model", "module", "state_dict"):
        # Look at the current level, not the outermost dict: testing the outer
        # dict missed nested wrappers and raised KeyError for sibling wrappers.
        if isinstance(state_dict, dict) and wrapper_key in state_dict:
            state_dict = state_dict[wrapper_key]
    return state_dict
def update_state_dict_inplace(
    state_dict: Dict[str, Any], old_key: str, new_key: str
) -> None:
    """Move the value stored under ``old_key`` to ``new_key`` in ``state_dict``.

    The mapping is modified in place and nothing is returned (the previous
    ``Dict[str, Any]`` return annotation did not match the behaviour).

    Args:
        state_dict: Mapping to modify.
        old_key: Existing key; ``KeyError`` is raised if it is missing.
        new_key: Key under which the value is stored afterwards.
    """
    state_dict[new_key] = state_dict.pop(old_key)
def convert_transformer(
    ckpt_path: str,
    num_layers: int,
    num_attention_heads: int,
    use_rotary_positional_embeddings: bool,
    i2v: bool,
    dtype: torch.dtype,
    init_kwargs: Dict[str, Any],
):
    """Build a ``CogVideoXTransformer3DModel`` and load a SAT checkpoint into it.

    Args:
        ckpt_path: Path of the original transformer checkpoint.
        num_layers: Forwarded to the model constructor.
        num_attention_heads: Forwarded to the model constructor.
        use_rotary_positional_embeddings: Forwarded to the model constructor.
        i2v: Whether the checkpoint is an image-to-video model. Selects 32
            instead of 16 input channels and, together with
            ``init_kwargs["patch_size_t"]``, the ``ofs_embed_dim`` and
            ``use_learned_positional_embeddings`` settings.
        dtype: dtype the freshly built model is cast to before loading.
        init_kwargs: Extra constructor keyword arguments; ``"patch_size_t"`` is
            read when ``i2v`` is true.

    Returns:
        The model with the converted weights loaded (``strict=True``).
    """
    PREFIX_KEY = "model.diffusion_model."
    # NOTE(review): torch.load unpickles the file; only convert checkpoints
    # from a trusted source.
    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
    transformer = CogVideoXTransformer3DModel(
        in_channels=32 if i2v else 16,
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        use_rotary_positional_embeddings=use_rotary_positional_embeddings,
        ofs_embed_dim=512
        if (i2v and init_kwargs["patch_size_t"] is not None)
        else None,  # CogVideoX1.5-5B-I2V
        use_learned_positional_embeddings=i2v
        and init_kwargs["patch_size_t"] is None,  # CogVideoX-5B-I2V
        **init_kwargs,
    ).to(dtype=dtype)

    # Pass 1: strip the prefix and apply the ordered substring renames.
    # The prefix is sliced off unconditionally, i.e. every key is assumed to
    # start with PREFIX_KEY (or with something of the same length).
    for key in list(original_state_dict.keys()):
        new_key = key[len(PREFIX_KEY) :]
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    # Pass 2: entries that must be split, regrouped or dropped.
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)
            # Every handler removes ``key``; a second matching handler would
            # therefore fail with KeyError, so stop after the first match.
            break

    transformer.load_state_dict(original_state_dict, strict=True)
    return transformer
def convert_vae(ckpt_path: str, scaling_factor: float, version: str, dtype: torch.dtype):
    """Build an ``AutoencoderKLCogVideoX`` and load an original VAE checkpoint into it.

    Args:
        ckpt_path: Path of the original VAE checkpoint.
        scaling_factor: Forwarded to the VAE constructor.
        version: CogVideoX version string; ``"1.5"`` additionally passes
            ``invert_scale_latents=True`` to the constructor.
        dtype: dtype the freshly built VAE is cast to before loading.

    Returns:
        The VAE with the converted weights loaded (``strict=True``).
    """
    init_kwargs = {"scaling_factor": scaling_factor}
    if version == "1.5":
        init_kwargs.update({"invert_scale_latents": True})

    # NOTE(review): torch.load unpickles the file; only convert checkpoints
    # from a trusted source.
    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
    vae = AutoencoderKLCogVideoX(**init_kwargs).to(dtype=dtype)

    # Pass 1: ordered substring renames.
    for key in list(original_state_dict.keys()):
        new_key = key
        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    # Pass 2: entries that must be dropped or re-indexed.
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)
            # Every handler removes ``key``; a second matching handler (e.g. a
            # "loss" key that also contains "up.") would fail, so stop here.
            break

    vae.load_state_dict(original_state_dict, strict=True)
    return vae
def get_transformer_init_kwargs(version: str):
    """Return the default transformer constructor kwargs for a CogVideoX version.

    Args:
        version: ``"1.0"`` or ``"1.5"``.

    Returns:
        A new dict with ``patch_size``, ``patch_size_t``, ``patch_bias``,
        ``sample_height``, ``sample_width`` and ``sample_frames``. The sample
        height/width are the pixel sizes divided by the spatial factor 8.

    Raises:
        ValueError: For any other version string.
    """
    vae_scale_factor_spatial = 8
    # version -> (patch_size_t, patch_bias, pixel height, pixel width, frames)
    presets = {
        "1.0": (None, True, 480, 720, 49),
        "1.5": (2, False, 768, 1360, 81),
    }
    if version not in presets:
        raise ValueError("Unsupported version of CogVideoX.")

    patch_size_t, patch_bias, pixel_height, pixel_width, sample_frames = presets[version]
    return {
        "patch_size": 2,
        "patch_size_t": patch_size_t,
        "patch_bias": patch_bias,
        "sample_height": pixel_height // vae_scale_factor_spatial,
        "sample_width": pixel_width // vae_scale_factor_spatial,
        "sample_frames": sample_frames,
    }
def get_args():
    """Parse the command-line arguments of the conversion script.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``. Only
        ``--output_path`` is required; the remaining defaults correspond to
        CogVideoX-2B (see the per-argument comments for the 5B values).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transformer_ckpt_path",
        type=str,
        default=None,
        help="Path to original transformer checkpoint",
    )
    parser.add_argument(
        "--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint"
    )
    parser.add_argument(
        "--output_path", type=str, required=True, help="Path where converted model should be saved"
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        default=False,
        help="Whether to save the model weights in fp16",
    )
    parser.add_argument(
        "--bf16",
        action="store_true",
        default=False,
        help="Whether to save the model weights in bf16",
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        default=False,
        help="Whether to push to HF Hub after saving",
    )
    parser.add_argument(
        "--text_encoder_cache_dir",
        type=str,
        default=None,
        help="Path to text encoder cache directory",
    )
    parser.add_argument(
        "--typecast_text_encoder",
        action="store_true",
        default=False,
        help="Whether or not to apply fp16/bf16 precision to text_encoder",
    )
    # For CogVideoX-2B, num_layers is 30. For 5B, it is 42
    parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
    # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
    parser.add_argument(
        "--num_attention_heads", type=int, default=30, help="Number of attention heads"
    )
    # For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
    parser.add_argument(
        "--use_rotary_positional_embeddings",
        action="store_true",
        default=False,
        help="Whether to use RoPE or not",
    )
    # For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
    parser.add_argument(
        "--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE"
    )
    # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
    # This value ends up in the scheduler config, not in the VAE; the help text
    # used to be a copy of the --scaling_factor one.
    parser.add_argument(
        "--snr_shift_scale",
        type=float,
        default=3.0,
        help="SNR shift scale used by the scheduler",
    )
    parser.add_argument(
        "--i2v",
        action="store_true",
        default=False,
        help="Whether the model to be converted is the Image-to-Video version of CogVideoX.",
    )
    parser.add_argument(
        "--version",
        choices=["1.0", "1.5"],
        default="1.0",
        help="Which version of CogVideoX to use for initializing default modeling parameters.",
    )
    return parser.parse_args()
if __name__ == "__main__":
    # Script entry point: convert the requested checkpoints, assemble a
    # diffusers pipeline around them and save it to --output_path.
    args = get_args()
    # Components stay None when the corresponding checkpoint path is not given;
    # the pipeline is then built with None for that component.
    transformer = None
    vae = None
    if args.fp16 and args.bf16:
        raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")
    # float32 unless --fp16 or --bf16 was passed.
    dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
    if args.transformer_ckpt_path is not None:
        init_kwargs = get_transformer_init_kwargs(args.version)
        transformer = convert_transformer(
            args.transformer_ckpt_path,
            args.num_layers,
            args.num_attention_heads,
            args.use_rotary_positional_embeddings,
            args.i2v,
            dtype,
            init_kwargs,
        )
    if args.vae_ckpt_path is not None:
        # Keep VAE in float32 for better quality
        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, args.version, torch.float32)
    # The text encoder and tokenizer are not converted; they are loaded from
    # the pretrained T5 identifier below.
    text_encoder_id = "google/t5-v1_1-xxl"
    tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
    text_encoder = T5EncoderModel.from_pretrained(
        text_encoder_id, cache_dir=args.text_encoder_cache_dir
    )
    # The text encoder only follows --fp16/--bf16 when explicitly requested.
    if args.typecast_text_encoder:
        text_encoder = text_encoder.to(dtype=dtype)
    # Apparently, the conversion does not work anymore without this :shrug:
    for param in text_encoder.parameters():
        param.data = param.data.contiguous()
    # Scheduler settings are fixed except for snr_shift_scale, which comes from
    # the command line.
    scheduler = CogVideoXDDIMScheduler.from_config(
        {
            "snr_shift_scale": args.snr_shift_scale,
            "beta_end": 0.012,
            "beta_schedule": "scaled_linear",
            "beta_start": 0.00085,
            "clip_sample": False,
            "num_train_timesteps": 1000,
            "prediction_type": "v_prediction",
            "rescale_betas_zero_snr": True,
            "set_alpha_to_one": True,
            "timestep_spacing": "trailing",
        }
    )
    # Image-to-video checkpoints use the dedicated pipeline class.
    if args.i2v:
        pipeline_cls = CogVideoXImageToVideoPipeline
    else:
        pipeline_cls = CogVideoXPipeline
    pipe = pipeline_cls(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        transformer=transformer,
        scheduler=scheduler,
    )
    # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
    # for users to specify variant when the default is not fp32 and they want to run with the correct default (which
    # is either fp16/bf16 here).
    # This is necessary for users with insufficient memory, such as those using Colab and notebooks,
    # as it can save some memory used for model loading.
    # NOTE(review): "this" presumably refers to the sharded save (max_shard_size) below - confirm.
    pipe.save_pretrained(
        args.output_path,
        safe_serialization=True,
        max_shard_size="5GB",
        push_to_hub=args.push_to_hub,
    )