Merge pull request #253 from THUDM/CogVideoX_dev

Update convert_weight_sat2hf.py
commit 98466e674c, authored by Yuxuan.Zhang on 2024-09-08 09:29:48 +08:00, committed by GitHub

@@ -25,11 +25,11 @@ import argparse
 from typing import Any, Dict

 import torch
-from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
 from transformers import T5EncoderModel, T5Tokenizer
+from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel

-# Function to reassign the query, key, and value weights in-place
 def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
     to_q_key = key.replace("query_key_value", "to_q")
     to_k_key = key.replace("query_key_value", "to_k")
@@ -41,7 +41,6 @@ def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
     state_dict.pop(key)

-# Function to reassign layer normalization for query and key in-place
 def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
     layer_id, weight_or_bias = key.split(".")[-2:]
@@ -53,7 +52,6 @@ def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
     state_dict[new_key] = state_dict.pop(key)

-# Function to reassign adaptive layer normalization in-place
 def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
     layer_id, _, weight_or_bias = key.split(".")[-3:]
@@ -70,12 +68,10 @@ def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
     state_dict.pop(key)

-# Function to remove keys from state_dict in-place
 def remove_keys_inplace(key: str, state_dict: Dict[str, Any]):
     state_dict.pop(key)

-# Function to replace keys in the "up" block in-place
 def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
     key_split = key.split(".")
     layer_index = int(key_split[2])
@@ -88,7 +84,6 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
     state_dict[new_key] = state_dict.pop(key)

-# Dictionary for renaming transformer keys
 TRANSFORMER_KEYS_RENAME_DICT = {
     "transformer.final_layernorm": "norm_final",
     "transformer": "transformer_blocks",
@@ -108,16 +103,17 @@ TRANSFORMER_KEYS_RENAME_DICT = {
     "mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
 }

-# Dictionary for handling special keys in transformer
 TRANSFORMER_SPECIAL_KEYS_REMAP = {
     "query_key_value": reassign_query_key_value_inplace,
     "query_layernorm_list": reassign_query_key_layernorm_inplace,
     "key_layernorm_list": reassign_query_key_layernorm_inplace,
     "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
     "embed_tokens": remove_keys_inplace,
+    "freqs_sin": remove_keys_inplace,
+    "freqs_cos": remove_keys_inplace,
+    "position_embedding": remove_keys_inplace,
 }

-# Dictionary for renaming VAE keys
 VAE_KEYS_RENAME_DICT = {
     "block.": "resnets.",
     "down.": "down_blocks.",
@@ -130,17 +126,14 @@ VAE_KEYS_RENAME_DICT = {
     "decoder.mid.block_2": "decoder.mid_block.resnets.1",
 }

-# Dictionary for handling special keys in VAE
 VAE_SPECIAL_KEYS_REMAP = {
     "loss": remove_keys_inplace,
     "up.": replace_up_keys_inplace,
 }

-# Maximum length of the tokenizer (Must be 226)
 TOKENIZER_MAX_LENGTH = 226

-# Function to extract the state_dict from a saved checkpoint
 def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
     state_dict = saved_dict
     if "model" in saved_dict.keys():
@@ -152,17 +145,25 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
     return state_dict

-# Function to update the state_dict with new key assignments in-place
 def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
     state_dict[new_key] = state_dict.pop(old_key)

-# Function to convert a transformer checkpoint to the CogVideoX format
-def convert_transformer(ckpt_path: str):
+def convert_transformer(
+    ckpt_path: str,
+    num_layers: int,
+    num_attention_heads: int,
+    use_rotary_positional_embeddings: bool,
+    dtype: torch.dtype,
+):
     PREFIX_KEY = "model.diffusion_model."

     original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
-    transformer = CogVideoXTransformer3DModel()
+    transformer = CogVideoXTransformer3DModel(
+        num_layers=num_layers,
+        num_attention_heads=num_attention_heads,
+        use_rotary_positional_embeddings=use_rotary_positional_embeddings,
+    ).to(dtype=dtype)

     for key in list(original_state_dict.keys()):
         new_key = key[len(PREFIX_KEY) :]
@@ -180,10 +181,9 @@ def convert_transformer(ckpt_path: str):
     return transformer

-# Function to convert a VAE checkpoint to the CogVideoX format
-def convert_vae(ckpt_path: str):
+def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
     original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
-    vae = AutoencoderKLCogVideoX()
+    vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)

     for key in list(original_state_dict.keys()):
         new_key = key[:]
@@ -201,7 +201,6 @@ def convert_vae(ckpt_path: str):
     return vae

-# Function to parse command-line arguments for the script
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -209,23 +208,26 @@ def get_args():
     )
     parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
     parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
-    parser.add_argument(
-        "--text_encoder_path",
-        type=str,
-        required=True,
-        default="google/t5-v1_1-xxl",
-        help="Path where converted model should be saved",
-    )
-    parser.add_argument(
-        "--text_encoder_cache_dir",
-        type=str,
-        default=None,
-        help="Path to text encoder cache directory. Not needed if text_encoder_path is in your local.",
-    )
-    parser.add_argument("--fp16", action="store_true", default=True, help="Whether to save the model weights in fp16")
+    parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
+    parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
     parser.add_argument(
         "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
     )
+    parser.add_argument(
+        "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
+    )
+    # For CogVideoX-2B, num_layers is 30. For 5B, it is 42
+    parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
+    # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
+    parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads")
+    # For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
+    parser.add_argument(
+        "--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not"
+    )
+    # For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
+    parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
+    # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
+    parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="SNR shift scale for the scheduler")
     return parser.parse_args()
@@ -235,17 +237,33 @@ if __name__ == "__main__":
     transformer = None
     vae = None

-    if args.transformer_ckpt_path is not None:
-        transformer = convert_transformer(args.transformer_ckpt_path)
-    if args.vae_ckpt_path is not None:
-        vae = convert_vae(args.vae_ckpt_path)
+    if args.fp16 and args.bf16:
+        raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")

-    tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_path, model_max_length=TOKENIZER_MAX_LENGTH)
-    text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, cache_dir=args.text_encoder_cache_dir)
+    dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
+
+    if args.transformer_ckpt_path is not None:
+        transformer = convert_transformer(
+            args.transformer_ckpt_path,
+            args.num_layers,
+            args.num_attention_heads,
+            args.use_rotary_positional_embeddings,
+            dtype,
+        )
+    if args.vae_ckpt_path is not None:
+        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
+
+    text_encoder_id = "google/t5-v1_1-xxl"
+    tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
+    text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
+
+    # Apparently, the conversion does not work any more without this :shrug:
+    for param in text_encoder.parameters():
+        param.data = param.data.contiguous()

     scheduler = CogVideoXDDIMScheduler.from_config(
         {
-            "snr_shift_scale": 3.0,
+            "snr_shift_scale": args.snr_shift_scale,
             "beta_end": 0.012,
             "beta_schedule": "scaled_linear",
             "beta_start": 0.00085,
@@ -254,7 +272,7 @@ if __name__ == "__main__":
             "prediction_type": "v_prediction",
             "rescale_betas_zero_snr": True,
             "set_alpha_to_one": True,
-            "timestep_spacing": "linspace",
+            "timestep_spacing": "trailing",
         }
     )
@@ -264,5 +282,10 @@ if __name__ == "__main__":
     if args.fp16:
         pipe = pipe.to(dtype=torch.float16)
+    if args.bf16:
+        pipe = pipe.to(dtype=torch.bfloat16)

+    # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
+    # for users to specify variant when the default is not fp32 and they want to run with the correct default (which
+    # is either fp16/bf16 here).
     pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
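
Not part of the commit: a minimal sketch of loading the pipeline this script saves, assuming an illustrative output directory (the value passed to --output_path) and the CogVideoXPipeline class already imported at the top of the script. The path, prompt, and step count are placeholders, and the dtype should match how the weights were saved (--fp16 for 2B, --bf16 for 5B).

import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# "./cogvideox-converted" is a hypothetical directory; use whatever was passed to --output_path.
pipe = CogVideoXPipeline.from_pretrained("./cogvideox-converted", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Illustrative generation call; the pipeline output exposes the decoded frames under `.frames`.
video = pipe(prompt="a panda playing guitar", num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)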