diff --git a/README.md b/README.md
index 73705cd..b94bf00 100644
--- a/README.md
+++ b/README.md
@@ -194,7 +194,7 @@ models we currently offer, along with their foundational information.
Inference Precision |
- BF16 |
+ BF16 (Recommended), FP16, FP32, FP8*, INT8, Not supported: INT4 |
FP16*(Recommended), BF16, FP32, FP8*, INT8, Not supported: INT4 |
BF16 (Recommended), FP16, FP32, FP8*, INT8, Not supported: INT4 |
diff --git a/README_ja.md b/README_ja.md
index 26b02c1..a7aa11b 100644
--- a/README_ja.md
+++ b/README_ja.md
@@ -186,7 +186,7 @@ CogVideoXは、[清影](https://chatglm.cn/video?fr=osm_cogvideox) と同源の
推論精度 |
- BF16 |
+ BF16(推奨), FP16, FP32,FP8*,INT8,INT4非対応 |
FP16*(推奨), BF16, FP32,FP8*,INT8,INT4非対応 |
BF16(推奨), FP16, FP32,FP8*,INT8,INT4非対応 |
diff --git a/README_zh.md b/README_zh.md
index f456376..704c467 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -176,7 +176,7 @@ CogVideoX是 [清影](https://chatglm.cn/video?fr=osm_cogvideox) 同源的开源
推理精度 |
- BF16 |
+ BF16(推荐), FP16, FP32,FP8*,INT8,不支持INT4 |
FP16*(推荐), BF16, FP32,FP8*,INT8,不支持INT4 |
BF16(推荐), FP16, FP32,FP8*,INT8,不支持INT4 |
diff --git a/inference/cli_demo.py b/inference/cli_demo.py
index bc97dd8..a211b4b 100644
--- a/inference/cli_demo.py
+++ b/inference/cli_demo.py
@@ -103,16 +103,13 @@ def generate_video(
# turn off if you have multiple GPUs or enough GPU memory(such as H100) and it will cost less time in inference
# and enable to("cuda")
- pipe.to("cuda")
-
- # pipe.enable_sequential_cpu_offload()
-
+ # pipe.to("cuda")
+ pipe.enable_sequential_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
# 4. Generate the video frames based on the prompt.
# `num_frames` is the Number of frames to generate.
- # This is the default value for 6 seconds video and 8 fps and will plus 1 frame for the first frame and 49 frames.
if generate_type == "i2v":
video_generate = pipe(
height=height,
diff --git a/tools/convert_weight_sat2hf.py b/tools/convert_weight_sat2hf.py
index f325018..b70af1a 100644
--- a/tools/convert_weight_sat2hf.py
+++ b/tools/convert_weight_sat2hf.py
@@ -92,6 +92,8 @@ TRANSFORMER_KEYS_RENAME_DICT = {
"post_attn1_layernorm": "norm2.norm",
"time_embed.0": "time_embedding.linear_1",
"time_embed.2": "time_embedding.linear_2",
+ "ofs_embed.0": "ofs_embedding.linear_1",
+ "ofs_embed.2": "ofs_embedding.linear_2",
"mixins.patch_embed": "patch_embed",
"mixins.final_layer.norm_final": "norm_out.norm",
"mixins.final_layer.linear": "proj_out",
@@ -146,12 +148,13 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
def convert_transformer(
- ckpt_path: str,
- num_layers: int,
- num_attention_heads: int,
- use_rotary_positional_embeddings: bool,
- i2v: bool,
- dtype: torch.dtype,
+ ckpt_path: str,
+ num_layers: int,
+ num_attention_heads: int,
+ use_rotary_positional_embeddings: bool,
+ i2v: bool,
+ dtype: torch.dtype,
+ init_kwargs: Dict[str, Any],
):
PREFIX_KEY = "model.diffusion_model."
@@ -161,11 +164,13 @@ def convert_transformer(
num_layers=num_layers,
num_attention_heads=num_attention_heads,
use_rotary_positional_embeddings=use_rotary_positional_embeddings,
- use_learned_positional_embeddings=i2v,
+ ofs_embed_dim=512 if (i2v and init_kwargs["patch_size_t"] is not None) else None, # CogVideoX1.5-5B-I2V
+ use_learned_positional_embeddings=i2v and init_kwargs["patch_size_t"] is None, # CogVideoX-5B-I2V
+ **init_kwargs,
).to(dtype=dtype)
for key in list(original_state_dict.keys()):
- new_key = key[len(PREFIX_KEY):]
+ new_key = key[len(PREFIX_KEY) :]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
@@ -175,13 +180,18 @@ def convert_transformer(
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
+
transformer.load_state_dict(original_state_dict, strict=True)
return transformer
-def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
+def convert_vae(ckpt_path: str, scaling_factor: float, version: str, dtype: torch.dtype):
+ init_kwargs = {"scaling_factor": scaling_factor}
+ if version == "1.5":
+ init_kwargs.update({"invert_scale_latents": True})
+
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
- vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)
+ vae = AutoencoderKLCogVideoX(**init_kwargs).to(dtype=dtype)
for key in list(original_state_dict.keys()):
new_key = key[:]
@@ -199,6 +209,34 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
return vae
+def get_transformer_init_kwargs(version: str):
+ if version == "1.0":
+ vae_scale_factor_spatial = 8
+ init_kwargs = {
+ "patch_size": 2,
+ "patch_size_t": None,
+ "patch_bias": True,
+ "sample_height": 480 // vae_scale_factor_spatial,
+ "sample_width": 720 // vae_scale_factor_spatial,
+ "sample_frames": 49,
+ }
+
+ elif version == "1.5":
+ vae_scale_factor_spatial = 8
+ init_kwargs = {
+ "patch_size": 2,
+ "patch_size_t": 2,
+ "patch_bias": False,
+ "sample_height": 768 // vae_scale_factor_spatial,
+ "sample_width": 1360 // vae_scale_factor_spatial,
+ "sample_frames": 81,
+ }
+ else:
+ raise ValueError("Unsupported version of CogVideoX.")
+
+ return init_kwargs
+
+
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -214,6 +252,12 @@ def get_args():
parser.add_argument(
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
)
+ parser.add_argument(
+ "--typecast_text_encoder",
+ action="store_true",
+ default=False,
+ help="Whether or not to apply fp16/bf16 precision to text_encoder",
+ )
# For CogVideoX-2B, num_layers is 30. For 5B, it is 42
parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
# For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
@@ -226,7 +270,18 @@ def get_args():
parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
# For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
- parser.add_argument("--i2v", action="store_true", default=False, help="Whether to save the model weights in fp16")
+ parser.add_argument(
+ "--i2v",
+ action="store_true",
+ default=False,
+ help="Whether the model to be converted is the Image-to-Video version of CogVideoX.",
+ )
+ parser.add_argument(
+ "--version",
+ choices=["1.0", "1.5"],
+ default="1.0",
+ help="Which version of CogVideoX to use for initializing default modeling parameters.",
+ )
return parser.parse_args()
@@ -242,6 +297,7 @@ if __name__ == "__main__":
dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
if args.transformer_ckpt_path is not None:
+ init_kwargs = get_transformer_init_kwargs(args.version)
transformer = convert_transformer(
args.transformer_ckpt_path,
args.num_layers,
@@ -249,14 +305,19 @@ if __name__ == "__main__":
args.use_rotary_positional_embeddings,
args.i2v,
dtype,
+ init_kwargs,
)
if args.vae_ckpt_path is not None:
- vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
+ # Keep VAE in float32 for better quality
+ vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, args.version, torch.float32)
- text_encoder_id = "/share/official_pretrains/hf_home/t5-v1_1-xxl"
+ text_encoder_id = "google/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
+ if args.typecast_text_encoder:
+ text_encoder = text_encoder.to(dtype=dtype)
+
# Apparently, the conversion does not work anymore without this :shrug:
for param in text_encoder.parameters():
param.data = param.data.contiguous()
@@ -288,11 +349,6 @@ if __name__ == "__main__":
scheduler=scheduler,
)
- if args.fp16:
- pipe = pipe.to(dtype=torch.float16)
- if args.bf16:
- pipe = pipe.to(dtype=torch.bfloat16)
-
# We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
# for users to specify variant when the default is not fp32 and they want to run with the correct default (which
# is either fp16/bf16 here).