Mirror of https://github.com/THUDM/CogVideo.git
commit 17996f11f8 (parent 5e3e3aabe0)

    update
@@ -194,7 +194,7 @@ models we currently offer, along with their foundational information.
 </tr>
 <tr>
 <td style="text-align: center;">Inference Precision</td>
-<td colspan="2" style="text-align: center;"><b>BF16</b></td>
+<td colspan="2" style="text-align: center;"><b>BF16 (Recommended)</b>, FP16, FP32, FP8*, INT8, Not supported: INT4</td>
 <td style="text-align: center;"><b>FP16*(Recommended)</b>, BF16, FP32, FP8*, INT8, Not supported: INT4</td>
 <td colspan="2" style="text-align: center;"><b>BF16 (Recommended)</b>, FP16, FP32, FP8*, INT8, Not supported: INT4</td>
 </tr>
@@ -186,7 +186,7 @@ CogVideoX shares the same origins as [清影](https://chatglm.cn/video?fr=osm_cogvideox)
 </tr>
 <tr>
 <td style="text-align: center;">Inference Precision</td>
-<td colspan="2" style="text-align: center;"><b>BF16</b></td>
+<td colspan="2" style="text-align: center;"><b>BF16 (Recommended)</b>, FP16, FP32, FP8*, INT8; INT4 not supported</td>
 <td style="text-align: center;"><b>FP16* (Recommended)</b>, BF16, FP32, FP8*, INT8; INT4 not supported</td>
 <td colspan="2" style="text-align: center;"><b>BF16 (Recommended)</b>, FP16, FP32, FP8*, INT8; INT4 not supported</td>
 </tr>
@@ -176,7 +176,7 @@ CogVideoX is an open-source model of the same origin as [清影](https://chatglm.cn/video?fr=osm_cogvideox)
 </tr>
 <tr>
 <td style="text-align: center;">Inference Precision</td>
-<td colspan="2" style="text-align: center;"><b>BF16</b></td>
+<td colspan="2" style="text-align: center;"><b>BF16 (Recommended)</b>, FP16, FP32, FP8*, INT8; INT4 not supported</td>
 <td style="text-align: center;"><b>FP16* (Recommended)</b>, BF16, FP32, FP8*, INT8; INT4 not supported</td>
 <td colspan="2" style="text-align: center;"><b>BF16 (Recommended)</b>, FP16, FP32, FP8*, INT8; INT4 not supported</td>
 </tr>
@@ -103,16 +103,13 @@ def generate_video(
     # turn off if you have multiple GPUs or enough GPU memory(such as H100) and it will cost less time in inference
     # and enable to("cuda")
 
-    pipe.to("cuda")
-
-    # pipe.enable_sequential_cpu_offload()
-
+    # pipe.to("cuda")
+    pipe.enable_sequential_cpu_offload()
     pipe.vae.enable_slicing()
     pipe.vae.enable_tiling()
 
     # 4. Generate the video frames based on the prompt.
     # `num_frames` is the Number of frames to generate.
-    # This is the default value for 6 seconds video and 8 fps and will plus 1 frame for the first frame and 49 frames.
     if generate_type == "i2v":
         video_generate = pipe(
             height=height,
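For context, a minimal sketch of the configuration this hunk switches to, written against the diffusers CogVideoXPipeline API; the model id and prompt are placeholders, and num_frames=49 matches the comment removed above (6 seconds at 8 fps plus one leading frame):

import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# Sequential CPU offload streams submodules to the GPU one at a time:
# much lower peak VRAM than pipe.to("cuda"), at the cost of slower inference.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.enable_sequential_cpu_offload()
pipe.vae.enable_slicing()  # decode the latent video in slices
pipe.vae.enable_tiling()   # decode each frame in spatial tiles

video = pipe(
    prompt="A panda playing guitar in a bamboo forest",  # placeholder prompt
    num_frames=49,  # 6 s at 8 fps, plus 1 leading frame
    num_inference_steps=50,
    guidance_scale=6.0,
).frames[0]
export_to_video(video, "output.mp4", fps=8)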
@@ -92,6 +92,8 @@ TRANSFORMER_KEYS_RENAME_DICT = {
     "post_attn1_layernorm": "norm2.norm",
     "time_embed.0": "time_embedding.linear_1",
     "time_embed.2": "time_embedding.linear_2",
+    "ofs_embed.0": "ofs_embedding.linear_1",
+    "ofs_embed.2": "ofs_embedding.linear_2",
     "mixins.patch_embed": "patch_embed",
     "mixins.final_layer.norm_final": "norm_out.norm",
     "mixins.final_layer.linear": "proj_out",
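The new entries follow the table's existing pattern: SAT checkpoint key fragments on the left, diffusers module paths on the right. As a quick illustration of how the table is consumed (the real loop is visible in the convert_transformer hunk below), assuming plain substring replacement; the helper name is illustrative:

# Entries copied from the hunk above.
RENAME = {
    "time_embed.0": "time_embedding.linear_1",
    "time_embed.2": "time_embedding.linear_2",
    "ofs_embed.0": "ofs_embedding.linear_1",
    "ofs_embed.2": "ofs_embedding.linear_2",
}

def rename_key(key: str) -> str:
    # Each (old, new) pair is applied as a substring replacement.
    for old, new in RENAME.items():
        key = key.replace(old, new)
    return key

assert rename_key("ofs_embed.0.weight") == "ofs_embedding.linear_1.weight"
assert rename_key("time_embed.2.bias") == "time_embedding.linear_2.bias"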
@@ -146,12 +148,13 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
 
 
 def convert_transformer(
     ckpt_path: str,
     num_layers: int,
     num_attention_heads: int,
     use_rotary_positional_embeddings: bool,
     i2v: bool,
     dtype: torch.dtype,
+    init_kwargs: Dict[str, Any],
 ):
     PREFIX_KEY = "model.diffusion_model."
 
@@ -161,11 +164,13 @@ def convert_transformer(
         num_layers=num_layers,
         num_attention_heads=num_attention_heads,
         use_rotary_positional_embeddings=use_rotary_positional_embeddings,
-        use_learned_positional_embeddings=i2v,
+        ofs_embed_dim=512 if (i2v and init_kwargs["patch_size_t"] is not None) else None,  # CogVideoX1.5-5B-I2V
+        use_learned_positional_embeddings=i2v and init_kwargs["patch_size_t"] is None,  # CogVideoX-5B-I2V
+        **init_kwargs,
     ).to(dtype=dtype)
 
     for key in list(original_state_dict.keys()):
-        new_key = key[len(PREFIX_KEY):]
+        new_key = key[len(PREFIX_KEY) :]
         for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
             new_key = new_key.replace(replace_key, rename_key)
         update_state_dict_inplace(original_state_dict, key, new_key)
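Taken together, the two new keyword arguments split the I2V handling by model generation, with patch_size_t (from init_kwargs, added later in this diff) as the discriminator. A sketch of the branch logic, with values copied from the hunk and a hypothetical helper name:

# patch_size_t is None -> CogVideoX 1.0 I2V: learned positional embeddings, no ofs embedding
# patch_size_t set     -> CogVideoX 1.5 I2V: ofs embedding of dim 512, RoPE only
def resolve_i2v_embeddings(i2v: bool, patch_size_t):
    ofs_embed_dim = 512 if (i2v and patch_size_t is not None) else None
    use_learned_positional_embeddings = i2v and patch_size_t is None
    return ofs_embed_dim, use_learned_positional_embeddings

assert resolve_i2v_embeddings(True, None) == (None, True)  # CogVideoX-5B-I2V
assert resolve_i2v_embeddings(True, 2) == (512, False)     # CogVideoX1.5-5B-I2V
assert resolve_i2v_embeddings(False, 2) == (None, False)   # text-to-video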
@@ -175,13 +180,18 @@ def convert_transformer(
         if special_key not in key:
             continue
         handler_fn_inplace(key, original_state_dict)
 
     transformer.load_state_dict(original_state_dict, strict=True)
     return transformer
 
 
-def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
+def convert_vae(ckpt_path: str, scaling_factor: float, version: str, dtype: torch.dtype):
+    init_kwargs = {"scaling_factor": scaling_factor}
+    if version == "1.5":
+        init_kwargs.update({"invert_scale_latents": True})
+
     original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
-    vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)
+    vae = AutoencoderKLCogVideoX(**init_kwargs).to(dtype=dtype)
 
     for key in list(original_state_dict.keys()):
         new_key = key[:]
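The threaded-through version argument only toggles one VAE flag; presumably invert_scale_latents marks the inverted latent-scaling convention of 1.5 checkpoints for AutoencoderKLCogVideoX. Isolated as a helper (a sketch, not part of the commit):

def build_vae_init_kwargs(scaling_factor: float, version: str) -> dict:
    # Mirrors the hunk above: 1.0 passes only the scaling factor,
    # 1.5 additionally sets invert_scale_latents=True.
    init_kwargs = {"scaling_factor": scaling_factor}
    if version == "1.5":
        init_kwargs["invert_scale_latents"] = True
    return init_kwargs

assert build_vae_init_kwargs(1.15258426, "1.5") == {
    "scaling_factor": 1.15258426,
    "invert_scale_latents": True,
}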
@@ -199,6 +209,34 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
     return vae
 
 
+def get_transformer_init_kwargs(version: str):
+    if version == "1.0":
+        vae_scale_factor_spatial = 8
+        init_kwargs = {
+            "patch_size": 2,
+            "patch_size_t": None,
+            "patch_bias": True,
+            "sample_height": 480 // vae_scale_factor_spatial,
+            "sample_width": 720 // vae_scale_factor_spatial,
+            "sample_frames": 49,
+        }
+
+    elif version == "1.5":
+        vae_scale_factor_spatial = 8
+        init_kwargs = {
+            "patch_size": 2,
+            "patch_size_t": 2,
+            "patch_bias": False,
+            "sample_height": 768 // vae_scale_factor_spatial,
+            "sample_width": 1360 // vae_scale_factor_spatial,
+            "sample_frames": 81,
+        }
+    else:
+        raise ValueError("Unsupported version of CogVideoX.")
+
+    return init_kwargs
+
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
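As a sanity check, what the new helper hands to the transformer constructor for each --version choice (values read directly off the hunk above; the asserts are illustrative, not part of the commit):

kw10 = get_transformer_init_kwargs("1.0")
assert kw10["patch_size_t"] is None  # spatial-only patching
assert (kw10["sample_height"], kw10["sample_width"]) == (60, 90)  # 480x720 over a VAE stride of 8
assert kw10["sample_frames"] == 49

kw15 = get_transformer_init_kwargs("1.5")
assert kw15["patch_size_t"] == 2  # adds temporal patching
assert (kw15["sample_height"], kw15["sample_width"]) == (96, 170)  # 768x1360 over a VAE stride of 8
assert kw15["sample_frames"] == 81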
@@ -214,6 +252,12 @@ def get_args():
     parser.add_argument(
         "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
     )
+    parser.add_argument(
+        "--typecast_text_encoder",
+        action="store_true",
+        default=False,
+        help="Whether or not to apply fp16/bf16 precision to text_encoder",
+    )
     # For CogVideoX-2B, num_layers is 30. For 5B, it is 42
     parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
     # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
@@ -226,7 +270,18 @@ def get_args():
     parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
     # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
     parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
-    parser.add_argument("--i2v", action="store_true", default=False, help="Whether to save the model weights in fp16")
+    parser.add_argument(
+        "--i2v",
+        action="store_true",
+        default=False,
+        help="Whether the model to be converted is the Image-to-Video version of CogVideoX.",
+    )
+    parser.add_argument(
+        "--version",
+        choices=["1.0", "1.5"],
+        default="1.0",
+        help="Which version of CogVideoX to use for initializing default modeling parameters.",
+    )
     return parser.parse_args()
 
 
@@ -242,6 +297,7 @@ if __name__ == "__main__":
     dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
 
     if args.transformer_ckpt_path is not None:
+        init_kwargs = get_transformer_init_kwargs(args.version)
         transformer = convert_transformer(
             args.transformer_ckpt_path,
             args.num_layers,
@@ -249,14 +305,19 @@ if __name__ == "__main__":
             args.use_rotary_positional_embeddings,
             args.i2v,
             dtype,
+            init_kwargs,
         )
     if args.vae_ckpt_path is not None:
-        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
+        # Keep VAE in float32 for better quality
+        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, args.version, torch.float32)
 
-    text_encoder_id = "/share/official_pretrains/hf_home/t5-v1_1-xxl"
+    text_encoder_id = "google/t5-v1_1-xxl"
     tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
     text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
 
+    if args.typecast_text_encoder:
+        text_encoder = text_encoder.to(dtype=dtype)
+
     # Apparently, the conversion does not work anymore without this :shrug:
     for param in text_encoder.parameters():
         param.data = param.data.contiguous()
@@ -288,11 +349,6 @@ if __name__ == "__main__":
         scheduler=scheduler,
     )
 
-    if args.fp16:
-        pipe = pipe.to(dtype=torch.float16)
-    if args.bf16:
-        pipe = pipe.to(dtype=torch.bfloat16)
-
     # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
     # for users to specify variant when the default is not fp32 and they want to run with the correct default (which
     # is either fp16/bf16 here).