[Fix] Fix RoPE temporal patch size

spacegoing 2024-11-21 16:21:30 +00:00
parent 2fdc59c3ce
commit 2fb763d25f
2 changed files with 12 additions and 2 deletions
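
Both training scripts receive the same change: `prepare_rotary_positional_embeddings` gains a `patch_size_t` argument, and the temporal axis of the 3D rotary embedding is sized to the number of temporal patches, ceil(num_frames / patch_size_t), rather than to the raw latent frame count.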

View File

@@ -912,6 +912,7 @@ def prepare_rotary_positional_embeddings(
     num_frames: int,
     vae_scale_factor_spatial: int = 8,
     patch_size: int = 2,
+    patch_size_t: int = 1,
     attention_head_dim: int = 64,
     device: Optional[torch.device] = None,
     base_height: int = 480,
@@ -922,12 +923,15 @@ def prepare_rotary_positional_embeddings(
     base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
     base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
 
+    p_t = patch_size_t
+    base_num_frames = (num_frames + p_t - 1) // p_t
+
     grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
     freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
         embed_dim=attention_head_dim,
         crops_coords=grid_crops_coords,
         grid_size=(grid_height, grid_width),
-        temporal_size=num_frames,
+        temporal_size=base_num_frames,
     )
 
     freqs_cos = freqs_cos.to(device=device)
@@ -1482,6 +1486,7 @@ def main(args):
                 num_frames=num_frames,
                 vae_scale_factor_spatial=vae_scale_factor_spatial,
                 patch_size=model_config.patch_size,
+                patch_size_t=model_config.patch_size_t,
                 attention_head_dim=model_config.attention_head_dim,
                 device=accelerator.device,
             )
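
As a minimal standalone sketch (not code from this repo), the ceil division added above can be exercised on its own; the function name and the frame counts below are illustrative:

# Sketch of the fix's arithmetic: RoPE's temporal axis should count
# temporal patches, ceil(num_frames / patch_size_t), not raw latent frames.
def rope_temporal_size(num_frames: int, patch_size_t: int = 1) -> int:
    p_t = patch_size_t
    return (num_frames + p_t - 1) // p_t  # ceil division, as in base_num_frames

assert rope_temporal_size(49, 1) == 49  # p_t == 1: behaviour unchanged
assert rope_temporal_size(49, 2) == 25  # rounds up when frames don't divide evenly
assert rope_temporal_size(48, 2) == 24

Rounding up matters when num_frames is not a multiple of patch_size_t: the final, partially filled temporal patch still occupies a sequence position and needs a rotary frequency. Before the fix, get_3d_rotary_pos_embed was sized for num_frames temporal positions while a model with patch_size_t > 1 only produces ceil(num_frames / p_t) temporal patches, leaving the embedding misaligned with the transformer's sequence.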

View File

@@ -825,6 +825,7 @@ def prepare_rotary_positional_embeddings(
     num_frames: int,
     vae_scale_factor_spatial: int = 8,
     patch_size: int = 2,
+    patch_size_t: int = 1,
     attention_head_dim: int = 64,
     device: Optional[torch.device] = None,
     base_height: int = 480,
@@ -835,12 +836,15 @@ def prepare_rotary_positional_embeddings(
     base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
     base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
 
+    p_t = patch_size_t
+    base_num_frames = (num_frames + p_t - 1) // p_t
+
     grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
     freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
         embed_dim=attention_head_dim,
         crops_coords=grid_crops_coords,
         grid_size=(grid_height, grid_width),
-        temporal_size=num_frames,
+        temporal_size=base_num_frames,
     )
 
     freqs_cos = freqs_cos.to(device=device)
@@ -1346,6 +1350,7 @@ def main(args):
                 num_frames=num_frames,
                 vae_scale_factor_spatial=vae_scale_factor_spatial,
                 patch_size=model_config.patch_size,
+                patch_size_t=model_config.patch_size_t,
                 attention_head_dim=model_config.attention_head_dim,
                 device=accelerator.device,
             )