Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-04-05 19:41:59 +08:00)

Commit 98466e674c: Merge pull request #253 from THUDM/CogVideoX_dev
Update convert_weight_sat2hf.py
--- a/convert_weight_sat2hf.py
+++ b/convert_weight_sat2hf.py
@@ -25,11 +25,11 @@ import argparse
 from typing import Any, Dict
 
 import torch
-from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
 from transformers import T5EncoderModel, T5Tokenizer
 
+from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
+
 
-# Function to reassign the query, key, and value weights in-place
 def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
     to_q_key = key.replace("query_key_value", "to_q")
     to_k_key = key.replace("query_key_value", "to_k")
@@ -41,7 +41,6 @@ def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
     state_dict.pop(key)
 
 
-# Function to reassign layer normalization for query and key in-place
 def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
     layer_id, weight_or_bias = key.split(".")[-2:]
 
@@ -53,7 +52,6 @@ def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
     state_dict[new_key] = state_dict.pop(key)
 
 
-# Function to reassign adaptive layer normalization in-place
 def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
     layer_id, _, weight_or_bias = key.split(".")[-3:]
 
@@ -70,12 +68,10 @@ def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
     state_dict.pop(key)
 
 
-# Function to remove keys from state_dict in-place
 def remove_keys_inplace(key: str, state_dict: Dict[str, Any]):
     state_dict.pop(key)
 
 
-# Function to replace keys in the "up" block in-place
 def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
     key_split = key.split(".")
     layer_index = int(key_split[2])
@@ -88,7 +84,6 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
     state_dict[new_key] = state_dict.pop(key)
 
 
-# Dictionary for renaming transformer keys
 TRANSFORMER_KEYS_RENAME_DICT = {
     "transformer.final_layernorm": "norm_final",
     "transformer": "transformer_blocks",
@@ -108,16 +103,17 @@ TRANSFORMER_KEYS_RENAME_DICT = {
     "mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
 }
 
-# Dictionary for handling special keys in transformer
 TRANSFORMER_SPECIAL_KEYS_REMAP = {
     "query_key_value": reassign_query_key_value_inplace,
     "query_layernorm_list": reassign_query_key_layernorm_inplace,
     "key_layernorm_list": reassign_query_key_layernorm_inplace,
     "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
     "embed_tokens": remove_keys_inplace,
+    "freqs_sin": remove_keys_inplace,
+    "freqs_cos": remove_keys_inplace,
+    "position_embedding": remove_keys_inplace,
 }
 
-# Dictionary for renaming VAE keys
 VAE_KEYS_RENAME_DICT = {
     "block.": "resnets.",
     "down.": "down_blocks.",
@@ -130,17 +126,14 @@ VAE_KEYS_RENAME_DICT = {
     "decoder.mid.block_2": "decoder.mid_block.resnets.1",
 }
 
-# Dictionary for handling special keys in VAE
 VAE_SPECIAL_KEYS_REMAP = {
     "loss": remove_keys_inplace,
     "up.": replace_up_keys_inplace,
 }
 
-# Maximum length of the tokenizer (Must be 226)
 TOKENIZER_MAX_LENGTH = 226
 
 
-# Function to extract the state_dict from a saved checkpoint
 def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
     state_dict = saved_dict
     if "model" in saved_dict.keys():
@@ -152,17 +145,25 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
     return state_dict
 
 
-# Function to update the state_dict with new key assignments in-place
 def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
     state_dict[new_key] = state_dict.pop(old_key)
 
 
-# Function to convert a transformer checkpoint to the CogVideoX format
-def convert_transformer(ckpt_path: str):
+def convert_transformer(
+    ckpt_path: str,
+    num_layers: int,
+    num_attention_heads: int,
+    use_rotary_positional_embeddings: bool,
+    dtype: torch.dtype,
+):
     PREFIX_KEY = "model.diffusion_model."
 
     original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
-    transformer = CogVideoXTransformer3DModel()
+    transformer = CogVideoXTransformer3DModel(
+        num_layers=num_layers,
+        num_attention_heads=num_attention_heads,
+        use_rotary_positional_embeddings=use_rotary_positional_embeddings,
+    ).to(dtype=dtype)
 
     for key in list(original_state_dict.keys()):
         new_key = key[len(PREFIX_KEY) :]
@@ -180,10 +181,9 @@ def convert_transformer(ckpt_path: str):
     return transformer
 
 
-# Function to convert a VAE checkpoint to the CogVideoX format
-def convert_vae(ckpt_path: str):
+def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
     original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
-    vae = AutoencoderKLCogVideoX()
+    vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)
 
     for key in list(original_state_dict.keys()):
         new_key = key[:]
@@ -201,7 +201,6 @@ def convert_vae(ckpt_path: str):
     return vae
 
 
-# Function to parse command-line arguments for the script
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -209,23 +208,26 @@ def get_args():
     )
     parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
     parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
-    parser.add_argument(
-        "--text_encoder_path",
-        type=str,
-        required=True,
-        default="google/t5-v1_1-xxl",
-        help="Path where converted model should be saved",
-    )
-    parser.add_argument(
-        "--text_encoder_cache_dir",
-        type=str,
-        default=None,
-        help="Path to text encoder cache directory. Not needed if text_encoder_path is in your local.",
-    )
-    parser.add_argument("--fp16", action="store_true", default=True, help="Whether to save the model weights in fp16")
+    parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
+    parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
     parser.add_argument(
         "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
     )
+    parser.add_argument(
+        "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
+    )
+    # For CogVideoX-2B, num_layers is 30. For 5B, it is 42
+    parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
+    # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
+    parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads")
+    # For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
+    parser.add_argument(
+        "--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not"
+    )
+    # For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
+    parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
+    # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
+    parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
     return parser.parse_args()
 
 
@@ -235,17 +237,33 @@ if __name__ == "__main__":
     transformer = None
     vae = None
 
-    if args.transformer_ckpt_path is not None:
-        transformer = convert_transformer(args.transformer_ckpt_path)
-    if args.vae_ckpt_path is not None:
-        vae = convert_vae(args.vae_ckpt_path)
+    if args.fp16 and args.bf16:
+        raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")
 
-    tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_path, model_max_length=TOKENIZER_MAX_LENGTH)
-    text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, cache_dir=args.text_encoder_cache_dir)
+    dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
+
+    if args.transformer_ckpt_path is not None:
+        transformer = convert_transformer(
+            args.transformer_ckpt_path,
+            args.num_layers,
+            args.num_attention_heads,
+            args.use_rotary_positional_embeddings,
+            dtype,
+        )
+    if args.vae_ckpt_path is not None:
+        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
+
+    text_encoder_id = "google/t5-v1_1-xxl"
+    tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
+    text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
+
+    # Apparently, the conversion does not work any more without this :shrug:
+    for param in text_encoder.parameters():
+        param.data = param.data.contiguous()
 
     scheduler = CogVideoXDDIMScheduler.from_config(
         {
-            "snr_shift_scale": 3.0,
+            "snr_shift_scale": args.snr_shift_scale,
             "beta_end": 0.012,
             "beta_schedule": "scaled_linear",
             "beta_start": 0.00085,
@@ -254,7 +272,7 @@ if __name__ == "__main__":
             "prediction_type": "v_prediction",
             "rescale_betas_zero_snr": True,
             "set_alpha_to_one": True,
-            "timestep_spacing": "linspace",
+            "timestep_spacing": "trailing",
         }
     )
 
@@ -264,5 +282,10 @@ if __name__ == "__main__":
 
     if args.fp16:
         pipe = pipe.to(dtype=torch.float16)
+    if args.bf16:
+        pipe = pipe.to(dtype=torch.bfloat16)
 
+    # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
+    # for users to specify variant when the default is not fp32 and they want to run with the correct default (which
+    # is either fp16/bf16 here).
     pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
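
For reference, a minimal sketch of how the updated script might be invoked and its output loaded. The paths below are placeholders, the 5B flag values are taken from the argparse comments in the diff above (42 layers, 48 attention heads, RoPE enabled, scaling_factor 0.7, snr_shift_scale 1.0), and loading with CogVideoXPipeline.from_pretrained is standard diffusers usage rather than anything this commit itself prescribes:

    import torch

    from diffusers import CogVideoXPipeline

    # Hypothetical 5B conversion run using the new flags (placeholder paths):
    #   python convert_weight_sat2hf.py \
    #       --transformer_ckpt_path /path/to/sat/transformer.pt \
    #       --vae_ckpt_path /path/to/sat/vae.pt \
    #       --output_path ./cogvideox-5b-hf \
    #       --num_layers 42 --num_attention_heads 48 \
    #       --use_rotary_positional_embeddings \
    #       --scaling_factor 0.7 --snr_shift_scale 1.0 --bf16

    # Load the converted weights in bf16, matching the diff's note that the 5B
    # model is meant to run in bf16 (a converted 2B checkpoint would use fp16).
    pipe = CogVideoXPipeline.from_pretrained("./cogvideox-5b-hf", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    frames = pipe("a panda playing a guitar", num_inference_steps=50).frames[0]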