Merge pull request #253 from THUDM/CogVideoX_dev

Update convert_weight_sat2hf.py
This commit is contained in:
Yuxuan.Zhang 2024-09-08 09:29:48 +08:00 committed by GitHub
commit 98466e674c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -25,11 +25,11 @@ import argparse
from typing import Any, Dict
import torch
from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
# Function to reassign the query, key, and value weights in-place
def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
to_q_key = key.replace("query_key_value", "to_q")
to_k_key = key.replace("query_key_value", "to_k")
@ -41,7 +41,6 @@ def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
state_dict.pop(key)
# Function to reassign layer normalization for query and key in-place
def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
layer_id, weight_or_bias = key.split(".")[-2:]
@ -53,7 +52,6 @@ def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
state_dict[new_key] = state_dict.pop(key)
# Function to reassign adaptive layer normalization in-place
def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
layer_id, _, weight_or_bias = key.split(".")[-3:]
@ -70,12 +68,10 @@ def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
state_dict.pop(key)
# Function to remove keys from state_dict in-place
def remove_keys_inplace(key: str, state_dict: Dict[str, Any]):
state_dict.pop(key)
# Function to replace keys in the "up" block in-place
def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
key_split = key.split(".")
layer_index = int(key_split[2])
@ -88,7 +84,6 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
state_dict[new_key] = state_dict.pop(key)
# Dictionary for renaming transformer keys
TRANSFORMER_KEYS_RENAME_DICT = {
"transformer.final_layernorm": "norm_final",
"transformer": "transformer_blocks",
@ -108,16 +103,17 @@ TRANSFORMER_KEYS_RENAME_DICT = {
"mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
}
# Dictionary for handling special keys in transformer
TRANSFORMER_SPECIAL_KEYS_REMAP = {
"query_key_value": reassign_query_key_value_inplace,
"query_layernorm_list": reassign_query_key_layernorm_inplace,
"key_layernorm_list": reassign_query_key_layernorm_inplace,
"adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
"embed_tokens": remove_keys_inplace,
"freqs_sin": remove_keys_inplace,
"freqs_cos": remove_keys_inplace,
"position_embedding": remove_keys_inplace,
}
# Dictionary for renaming VAE keys
VAE_KEYS_RENAME_DICT = {
"block.": "resnets.",
"down.": "down_blocks.",
@ -130,17 +126,14 @@ VAE_KEYS_RENAME_DICT = {
"decoder.mid.block_2": "decoder.mid_block.resnets.1",
}
# Dictionary for handling special keys in VAE
VAE_SPECIAL_KEYS_REMAP = {
"loss": remove_keys_inplace,
"up.": replace_up_keys_inplace,
}
# Maximum length of the tokenizer (Must be 226)
TOKENIZER_MAX_LENGTH = 226
# Function to extract the state_dict from a saved checkpoint
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
state_dict = saved_dict
if "model" in saved_dict.keys():
@ -152,17 +145,25 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
return state_dict
# Function to update the state_dict with new key assignments in-place
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
state_dict[new_key] = state_dict.pop(old_key)
# Function to convert a transformer checkpoint to the CogVideoX format
def convert_transformer(ckpt_path: str):
def convert_transformer(
ckpt_path: str,
num_layers: int,
num_attention_heads: int,
use_rotary_positional_embeddings: bool,
dtype: torch.dtype,
):
PREFIX_KEY = "model.diffusion_model."
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
transformer = CogVideoXTransformer3DModel()
transformer = CogVideoXTransformer3DModel(
num_layers=num_layers,
num_attention_heads=num_attention_heads,
use_rotary_positional_embeddings=use_rotary_positional_embeddings,
).to(dtype=dtype)
for key in list(original_state_dict.keys()):
new_key = key[len(PREFIX_KEY) :]
@ -180,10 +181,9 @@ def convert_transformer(ckpt_path: str):
return transformer
# Function to convert a VAE checkpoint to the CogVideoX format
def convert_vae(ckpt_path: str):
def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
vae = AutoencoderKLCogVideoX()
vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)
for key in list(original_state_dict.keys()):
new_key = key[:]
@ -201,7 +201,6 @@ def convert_vae(ckpt_path: str):
return vae
# Function to parse command-line arguments for the script
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
@ -209,23 +208,26 @@ def get_args():
)
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument(
"--text_encoder_path",
type=str,
required=True,
default="google/t5-v1_1-xxl",
help="Path where converted model should be saved",
)
parser.add_argument(
"--text_encoder_cache_dir",
type=str,
default=None,
help="Path to text encoder cache directory. Not needed if text_encoder_path is in your local.",
)
parser.add_argument("--fp16", action="store_true", default=True, help="Whether to save the model weights in fp16")
parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
parser.add_argument(
"--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
)
parser.add_argument(
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
)
# For CogVideoX-2B, num_layers is 30. For 5B, it is 42
parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
# For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads")
# For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
parser.add_argument(
"--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not"
)
# For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
# For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
return parser.parse_args()
@ -235,17 +237,33 @@ if __name__ == "__main__":
transformer = None
vae = None
if args.transformer_ckpt_path is not None:
transformer = convert_transformer(args.transformer_ckpt_path)
if args.vae_ckpt_path is not None:
vae = convert_vae(args.vae_ckpt_path)
if args.fp16 and args.bf16:
raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")
tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_path, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, cache_dir=args.text_encoder_cache_dir)
dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
if args.transformer_ckpt_path is not None:
transformer = convert_transformer(
args.transformer_ckpt_path,
args.num_layers,
args.num_attention_heads,
args.use_rotary_positional_embeddings,
dtype,
)
if args.vae_ckpt_path is not None:
vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
text_encoder_id = "google/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
# Apparently, the conversion does not work any more without this :shrug:
for param in text_encoder.parameters():
param.data = param.data.contiguous()
scheduler = CogVideoXDDIMScheduler.from_config(
{
"snr_shift_scale": 3.0,
"snr_shift_scale": args.snr_shift_scale,
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
@ -254,7 +272,7 @@ if __name__ == "__main__":
"prediction_type": "v_prediction",
"rescale_betas_zero_snr": True,
"set_alpha_to_one": True,
"timestep_spacing": "linspace",
"timestep_spacing": "trailing",
}
)
@ -264,5 +282,10 @@ if __name__ == "__main__":
if args.fp16:
pipe = pipe.to(dtype=torch.float16)
if args.bf16:
pipe = pipe.to(dtype=torch.bfloat16)
# We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
# for users to specify variant when the default is not fp32 and they want to run with the correct default (which
# is either fp16/bf16 here).
pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)