mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-06 03:57:56 +08:00
[Fix] Account for the temporal patch size (patch_size_t) when computing the RoPE temporal grid size
This commit is contained in:
parent
2fdc59c3ce
commit
2fb763d25f
@@ -912,6 +912,7 @@ def prepare_rotary_positional_embeddings(
|
|||||||
num_frames: int,
|
num_frames: int,
|
||||||
vae_scale_factor_spatial: int = 8,
|
vae_scale_factor_spatial: int = 8,
|
||||||
patch_size: int = 2,
|
patch_size: int = 2,
|
||||||
|
patch_size_t: int = 1,
|
||||||
attention_head_dim: int = 64,
|
attention_head_dim: int = 64,
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
base_height: int = 480,
|
base_height: int = 480,
|
||||||
@@ -922,12 +923,15 @@ def prepare_rotary_positional_embeddings(
|
|||||||
base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
|
base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
|
||||||
base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
|
base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
|
||||||
|
|
||||||
|
p_t = patch_size_t
|
||||||
|
base_num_frames = (num_frames + p_t - 1) // p_t
|
||||||
|
|
||||||
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
|
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
|
||||||
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
||||||
embed_dim=attention_head_dim,
|
embed_dim=attention_head_dim,
|
||||||
crops_coords=grid_crops_coords,
|
crops_coords=grid_crops_coords,
|
||||||
grid_size=(grid_height, grid_width),
|
grid_size=(grid_height, grid_width),
|
||||||
temporal_size=num_frames,
|
temporal_size=base_num_frames,
|
||||||
)
|
)
|
||||||
|
|
||||||
freqs_cos = freqs_cos.to(device=device)
|
freqs_cos = freqs_cos.to(device=device)
|
||||||
@@ -1482,6 +1486,7 @@ def main(args):
|
|||||||
num_frames=num_frames,
|
num_frames=num_frames,
|
||||||
vae_scale_factor_spatial=vae_scale_factor_spatial,
|
vae_scale_factor_spatial=vae_scale_factor_spatial,
|
||||||
patch_size=model_config.patch_size,
|
patch_size=model_config.patch_size,
|
||||||
|
patch_size_t=model_config.patch_size_t,
|
||||||
attention_head_dim=model_config.attention_head_dim,
|
attention_head_dim=model_config.attention_head_dim,
|
||||||
device=accelerator.device,
|
device=accelerator.device,
|
||||||
)
|
)
|
||||||
|
@@ -825,6 +825,7 @@ def prepare_rotary_positional_embeddings(
|
|||||||
num_frames: int,
|
num_frames: int,
|
||||||
vae_scale_factor_spatial: int = 8,
|
vae_scale_factor_spatial: int = 8,
|
||||||
patch_size: int = 2,
|
patch_size: int = 2,
|
||||||
|
patch_size_t: int = 1,
|
||||||
attention_head_dim: int = 64,
|
attention_head_dim: int = 64,
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
base_height: int = 480,
|
base_height: int = 480,
|
||||||
@@ -835,12 +836,15 @@ def prepare_rotary_positional_embeddings(
|
|||||||
base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
|
base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
|
||||||
base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
|
base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
|
||||||
|
|
||||||
|
p_t = patch_size_t
|
||||||
|
base_num_frames = (num_frames + p_t - 1) // p_t
|
||||||
|
|
||||||
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
|
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
|
||||||
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
||||||
embed_dim=attention_head_dim,
|
embed_dim=attention_head_dim,
|
||||||
crops_coords=grid_crops_coords,
|
crops_coords=grid_crops_coords,
|
||||||
grid_size=(grid_height, grid_width),
|
grid_size=(grid_height, grid_width),
|
||||||
temporal_size=num_frames,
|
temporal_size=base_num_frames,
|
||||||
)
|
)
|
||||||
|
|
||||||
freqs_cos = freqs_cos.to(device=device)
|
freqs_cos = freqs_cos.to(device=device)
|
||||||
@@ -1346,6 +1350,7 @@ def main(args):
|
|||||||
num_frames=num_frames,
|
num_frames=num_frames,
|
||||||
vae_scale_factor_spatial=vae_scale_factor_spatial,
|
vae_scale_factor_spatial=vae_scale_factor_spatial,
|
||||||
patch_size=model_config.patch_size,
|
patch_size=model_config.patch_size,
|
||||||
|
patch_size_t=model_config.patch_size_t,
|
||||||
attention_head_dim=model_config.attention_head_dim,
|
attention_head_dim=model_config.attention_head_dim,
|
||||||
device=accelerator.device,
|
device=accelerator.device,
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user