diff --git a/README.md b/README.md index 73705cd..b94bf00 100644 --- a/README.md +++ b/README.md @@ -194,7 +194,7 @@ models we currently offer, along with their foundational information. Inference Precision - BF16 + BF16 (Recommended), FP16, FP32, FP8*, INT8, Not supported: INT4 FP16*(Recommended), BF16, FP32, FP8*, INT8, Not supported: INT4 BF16 (Recommended), FP16, FP32, FP8*, INT8, Not supported: INT4 diff --git a/README_ja.md b/README_ja.md index 26b02c1..a7aa11b 100644 --- a/README_ja.md +++ b/README_ja.md @@ -186,7 +186,7 @@ CogVideoXは、[清影](https://chatglm.cn/video?fr=osm_cogvideox) と同源の 推論精度 - BF16 + BF16(推奨), FP16, FP32,FP8*,INT8,INT4非対応 FP16*(推奨), BF16, FP32,FP8*,INT8,INT4非対応 BF16(推奨), FP16, FP32,FP8*,INT8,INT4非対応 diff --git a/README_zh.md b/README_zh.md index f456376..704c467 100644 --- a/README_zh.md +++ b/README_zh.md @@ -176,7 +176,7 @@ CogVideoX是 [清影](https://chatglm.cn/video?fr=osm_cogvideox) 同源的开源 推理精度 - BF16 + BF16(推荐), FP16, FP32,FP8*,INT8,不支持INT4 FP16*(推荐), BF16, FP32,FP8*,INT8,不支持INT4 BF16(推荐), FP16, FP32,FP8*,INT8,不支持INT4 diff --git a/inference/cli_demo.py b/inference/cli_demo.py index bc97dd8..a211b4b 100644 --- a/inference/cli_demo.py +++ b/inference/cli_demo.py @@ -103,16 +103,13 @@ def generate_video( # turn off if you have multiple GPUs or enough GPU memory(such as H100) and it will cost less time in inference # and enable to("cuda") - pipe.to("cuda") - - # pipe.enable_sequential_cpu_offload() - + # pipe.to("cuda") + pipe.enable_sequential_cpu_offload() pipe.vae.enable_slicing() pipe.vae.enable_tiling() # 4. Generate the video frames based on the prompt. # `num_frames` is the Number of frames to generate. - # This is the default value for 6 seconds video and 8 fps and will plus 1 frame for the first frame and 49 frames. if generate_type == "i2v": video_generate = pipe( height=height, diff --git a/tools/convert_weight_sat2hf.py b/tools/convert_weight_sat2hf.py index f325018..b70af1a 100644 --- a/tools/convert_weight_sat2hf.py +++ b/tools/convert_weight_sat2hf.py @@ -92,6 +92,8 @@ TRANSFORMER_KEYS_RENAME_DICT = { "post_attn1_layernorm": "norm2.norm", "time_embed.0": "time_embedding.linear_1", "time_embed.2": "time_embedding.linear_2", + "ofs_embed.0": "ofs_embedding.linear_1", + "ofs_embed.2": "ofs_embedding.linear_2", "mixins.patch_embed": "patch_embed", "mixins.final_layer.norm_final": "norm_out.norm", "mixins.final_layer.linear": "proj_out", @@ -146,12 +148,13 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: def convert_transformer( - ckpt_path: str, - num_layers: int, - num_attention_heads: int, - use_rotary_positional_embeddings: bool, - i2v: bool, - dtype: torch.dtype, + ckpt_path: str, + num_layers: int, + num_attention_heads: int, + use_rotary_positional_embeddings: bool, + i2v: bool, + dtype: torch.dtype, + init_kwargs: Dict[str, Any], ): PREFIX_KEY = "model.diffusion_model." @@ -161,11 +164,13 @@ def convert_transformer( num_layers=num_layers, num_attention_heads=num_attention_heads, use_rotary_positional_embeddings=use_rotary_positional_embeddings, - use_learned_positional_embeddings=i2v, + ofs_embed_dim=512 if (i2v and init_kwargs["patch_size_t"] is not None) else None, # CogVideoX1.5-5B-I2V + use_learned_positional_embeddings=i2v and init_kwargs["patch_size_t"] is None, # CogVideoX-5B-I2V + **init_kwargs, ).to(dtype=dtype) for key in list(original_state_dict.keys()): - new_key = key[len(PREFIX_KEY):] + new_key = key[len(PREFIX_KEY) :] for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): new_key = new_key.replace(replace_key, rename_key) update_state_dict_inplace(original_state_dict, key, new_key) @@ -175,13 +180,18 @@ def convert_transformer( if special_key not in key: continue handler_fn_inplace(key, original_state_dict) + transformer.load_state_dict(original_state_dict, strict=True) return transformer -def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype): +def convert_vae(ckpt_path: str, scaling_factor: float, version: str, dtype: torch.dtype): + init_kwargs = {"scaling_factor": scaling_factor} + if version == "1.5": + init_kwargs.update({"invert_scale_latents": True}) + original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True)) - vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype) + vae = AutoencoderKLCogVideoX(**init_kwargs).to(dtype=dtype) for key in list(original_state_dict.keys()): new_key = key[:] @@ -199,6 +209,34 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype): return vae +def get_transformer_init_kwargs(version: str): + if version == "1.0": + vae_scale_factor_spatial = 8 + init_kwargs = { + "patch_size": 2, + "patch_size_t": None, + "patch_bias": True, + "sample_height": 480 // vae_scale_factor_spatial, + "sample_width": 720 // vae_scale_factor_spatial, + "sample_frames": 49, + } + + elif version == "1.5": + vae_scale_factor_spatial = 8 + init_kwargs = { + "patch_size": 2, + "patch_size_t": 2, + "patch_bias": False, + "sample_height": 768 // vae_scale_factor_spatial, + "sample_width": 1360 // vae_scale_factor_spatial, + "sample_frames": 81, + } + else: + raise ValueError("Unsupported version of CogVideoX.") + + return init_kwargs + + def get_args(): parser = argparse.ArgumentParser() parser.add_argument( @@ -214,6 +252,12 @@ def get_args(): parser.add_argument( "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory" ) + parser.add_argument( + "--typecast_text_encoder", + action="store_true", + default=False, + help="Whether or not to apply fp16/bf16 precision to text_encoder", + ) # For CogVideoX-2B, num_layers is 30. For 5B, it is 42 parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks") # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48 @@ -226,7 +270,18 @@ def get_args(): parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE") # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0 parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE") - parser.add_argument("--i2v", action="store_true", default=False, help="Whether to save the model weights in fp16") + parser.add_argument( + "--i2v", + action="store_true", + default=False, + help="Whether the model to be converted is the Image-to-Video version of CogVideoX.", + ) + parser.add_argument( + "--version", + choices=["1.0", "1.5"], + default="1.0", + help="Which version of CogVideoX to use for initializing default modeling parameters.", + ) return parser.parse_args() @@ -242,6 +297,7 @@ if __name__ == "__main__": dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32 if args.transformer_ckpt_path is not None: + init_kwargs = get_transformer_init_kwargs(args.version) transformer = convert_transformer( args.transformer_ckpt_path, args.num_layers, @@ -249,14 +305,19 @@ if __name__ == "__main__": args.use_rotary_positional_embeddings, args.i2v, dtype, + init_kwargs, ) if args.vae_ckpt_path is not None: - vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype) + # Keep VAE in float32 for better quality + vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, args.version, torch.float32) - text_encoder_id = "/share/official_pretrains/hf_home/t5-v1_1-xxl" + text_encoder_id = "google/t5-v1_1-xxl" tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir) + if args.typecast_text_encoder: + text_encoder = text_encoder.to(dtype=dtype) + # Apparently, the conversion does not work anymore without this :shrug: for param in text_encoder.parameters(): param.data = param.data.contiguous() @@ -288,11 +349,6 @@ if __name__ == "__main__": scheduler=scheduler, ) - if args.fp16: - pipe = pipe.to(dtype=torch.float16) - if args.bf16: - pipe = pipe.to(dtype=torch.bfloat16) - # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird # for users to specify variant when the default is not fp32 and they want to run with the correct default (which # is either fp16/bf16 here).