diff --git a/tools/convert_weight_sat2hf.py b/tools/convert_weight_sat2hf.py
index 0cec3f5..45d3466 100644
--- a/tools/convert_weight_sat2hf.py
+++ b/tools/convert_weight_sat2hf.py
@@ -25,11 +25,11 @@
 import argparse
 from typing import Any, Dict
 
 import torch
-from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
 from transformers import T5EncoderModel, T5Tokenizer
+from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
+
 
-# Function to reassign the query, key, and value weights in-place
 def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
     to_q_key = key.replace("query_key_value", "to_q")
     to_k_key = key.replace("query_key_value", "to_k")
@@ -41,7 +41,6 @@ def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
     state_dict.pop(key)
 
 
-# Function to reassign layer normalization for query and key in-place
 def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
     layer_id, weight_or_bias = key.split(".")[-2:]
 
@@ -53,7 +52,6 @@ def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
     state_dict[new_key] = state_dict.pop(key)
 
 
-# Function to reassign adaptive layer normalization in-place
 def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
     layer_id, _, weight_or_bias = key.split(".")[-3:]
 
@@ -70,12 +68,10 @@ def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
     state_dict.pop(key)
 
 
-# Function to remove keys from state_dict in-place
 def remove_keys_inplace(key: str, state_dict: Dict[str, Any]):
     state_dict.pop(key)
 
 
-# Function to replace keys in the "up" block in-place
 def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
     key_split = key.split(".")
     layer_index = int(key_split[2])
@@ -88,7 +84,6 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
     state_dict[new_key] = state_dict.pop(key)
 
 
-# Dictionary for renaming transformer keys
 TRANSFORMER_KEYS_RENAME_DICT = {
     "transformer.final_layernorm": "norm_final",
     "transformer": "transformer_blocks",
@@ -108,16 +103,17 @@ TRANSFORMER_KEYS_RENAME_DICT = {
     "mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
 }
 
-# Dictionary for handling special keys in transformer
 TRANSFORMER_SPECIAL_KEYS_REMAP = {
     "query_key_value": reassign_query_key_value_inplace,
     "query_layernorm_list": reassign_query_key_layernorm_inplace,
     "key_layernorm_list": reassign_query_key_layernorm_inplace,
     "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
     "embed_tokens": remove_keys_inplace,
+    "freqs_sin": remove_keys_inplace,
+    "freqs_cos": remove_keys_inplace,
+    "position_embedding": remove_keys_inplace,
 }
 
-# Dictionary for renaming VAE keys
 VAE_KEYS_RENAME_DICT = {
     "block.": "resnets.",
     "down.": "down_blocks.",
@@ -130,17 +126,14 @@ VAE_KEYS_RENAME_DICT = {
     "decoder.mid.block_2": "decoder.mid_block.resnets.1",
 }
 
-# Dictionary for handling special keys in VAE
 VAE_SPECIAL_KEYS_REMAP = {
     "loss": remove_keys_inplace,
     "up.": replace_up_keys_inplace,
 }
 
-# Maximum length of the tokenizer (Must be 226)
 TOKENIZER_MAX_LENGTH = 226
 
 
-# Function to extract the state_dict from a saved checkpoint
 def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
     state_dict = saved_dict
     if "model" in saved_dict.keys():
@@ -152,17 +145,25 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
     return state_dict
 
 
-# Function to update the state_dict with new key assignments in-place
 def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
     state_dict[new_key] = state_dict.pop(old_key)
 
 
-# Function to convert a transformer checkpoint to the CogVideoX format
-def convert_transformer(ckpt_path: str):
+def convert_transformer(
+    ckpt_path: str,
+    num_layers: int,
+    num_attention_heads: int,
+    use_rotary_positional_embeddings: bool,
+    dtype: torch.dtype,
+):
     PREFIX_KEY = "model.diffusion_model."
 
     original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
-    transformer = CogVideoXTransformer3DModel()
+    transformer = CogVideoXTransformer3DModel(
+        num_layers=num_layers,
+        num_attention_heads=num_attention_heads,
+        use_rotary_positional_embeddings=use_rotary_positional_embeddings,
+    ).to(dtype=dtype)
 
     for key in list(original_state_dict.keys()):
         new_key = key[len(PREFIX_KEY) :]
@@ -180,10 +181,9 @@ def convert_transformer(ckpt_path: str):
     return transformer
 
 
-# Function to convert a VAE checkpoint to the CogVideoX format
-def convert_vae(ckpt_path: str):
+def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
     original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
-    vae = AutoencoderKLCogVideoX()
+    vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)
 
     for key in list(original_state_dict.keys()):
         new_key = key[:]
@@ -201,7 +201,6 @@ def convert_vae(ckpt_path: str):
     return vae
 
 
-# Function to parse command-line arguments for the script
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -209,23 +208,26 @@ def get_args():
     )
     parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
     parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
-    parser.add_argument(
-        "--text_encoder_path",
-        type=str,
-        required=True,
-        default="google/t5-v1_1-xxl",
-        help="Path where converted model should be saved",
-    )
-    parser.add_argument(
-        "--text_encoder_cache_dir",
-        type=str,
-        default=None,
-        help="Path to text encoder cache directory. Not needed if text_encoder_path is in your local.",
-    )
-    parser.add_argument("--fp16", action="store_true", default=True, help="Whether to save the model weights in fp16")
+    parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
+    parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
     parser.add_argument(
         "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
     )
+    parser.add_argument(
+        "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
+    )
+    # For CogVideoX-2B, num_layers is 30. For 5B, it is 42
+    parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
+    # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
+    parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads")
+    # For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
+    parser.add_argument(
+        "--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not"
+    )
+    # For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
+    parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
+    # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
+    parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="SNR shift scale in the scheduler")
     return parser.parse_args()
 
 
@@ -235,17 +237,33 @@ if __name__ == "__main__":
     transformer = None
     vae = None
 
-    if args.transformer_ckpt_path is not None:
-        transformer = convert_transformer(args.transformer_ckpt_path)
-    if args.vae_ckpt_path is not None:
-        vae = convert_vae(args.vae_ckpt_path)
+    if args.fp16 and args.bf16:
+        raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")
 
-    tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_path, model_max_length=TOKENIZER_MAX_LENGTH)
-    text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, cache_dir=args.text_encoder_cache_dir)
+    dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
+
+    if args.transformer_ckpt_path is not None:
+        transformer = convert_transformer(
+            args.transformer_ckpt_path,
+            args.num_layers,
+            args.num_attention_heads,
+            args.use_rotary_positional_embeddings,
+            dtype,
+        )
+    if args.vae_ckpt_path is not None:
+        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
+
+    text_encoder_id = "google/t5-v1_1-xxl"
+    tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
+    text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
+
+    # The conversion fails unless the text encoder parameters are made contiguous first.
+    for param in text_encoder.parameters():
+        param.data = param.data.contiguous()
 
     scheduler = CogVideoXDDIMScheduler.from_config(
         {
-            "snr_shift_scale": 3.0,
+            "snr_shift_scale": args.snr_shift_scale,
             "beta_end": 0.012,
             "beta_schedule": "scaled_linear",
             "beta_start": 0.00085,
@@ -254,7 +272,7 @@
             "prediction_type": "v_prediction",
             "rescale_betas_zero_snr": True,
             "set_alpha_to_one": True,
-            "timestep_spacing": "linspace",
+            "timestep_spacing": "trailing",
         }
     )
 
@@ -264,5 +282,10 @@
 
     if args.fp16:
         pipe = pipe.to(dtype=torch.float16)
+    if args.bf16:
+        pipe = pipe.to(dtype=torch.bfloat16)
 
+    # We don't use a variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be confusing
+    # for users to have to specify a variant when the default is not fp32 and they simply want to run the model
+    # with the correct default precision (fp16 or bf16 here).
     pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
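For reference, with the flags introduced above, a CogVideoX-5B conversion could be invoked roughly as follows; the SAT checkpoint locations and the output directory below are placeholders and depend on the local setup:

    python tools/convert_weight_sat2hf.py \
        --transformer_ckpt_path /path/to/sat/transformer.pt \
        --vae_ckpt_path /path/to/sat/vae.pt \
        --output_path ./cogvideox-5b \
        --num_layers 42 \
        --num_attention_heads 48 \
        --use_rotary_positional_embeddings \
        --scaling_factor 0.7 \
        --snr_shift_scale 1.0 \
        --bf16

For CogVideoX-2B the defaults already apply (30 layers, 30 attention heads, no RoPE, scaling_factor 1.15258426, snr_shift_scale 3.0), so only the checkpoint paths, --output_path, and optionally --fp16 need to be passed.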