From 1c2e487820e35ac7f53d2634b69d48c1811f236c Mon Sep 17 00:00:00 2001
From: zR <2448370773@qq.com>
Date: Wed, 14 Aug 2024 22:09:03 +0800
Subject: [PATCH] Update dit_video_concat.py

---
 sat/dit_video_concat.py | 106 ++++------------------------------------
 1 file changed, 10 insertions(+), 96 deletions(-)

diff --git a/sat/dit_video_concat.py b/sat/dit_video_concat.py
index d42a92e..2a74114 100644
--- a/sat/dit_video_concat.py
+++ b/sat/dit_video_concat.py
@@ -240,17 +240,15 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
         text_length,
         theta=10000,
         rot_v=False,
-        pnp=False,
-        learnable_pos_embed=False,
     ):
         super().__init__()
         self.rot_v = rot_v
+        self.text_length = text_length

         dim_t = hidden_size_head // 4
         dim_h = hidden_size_head // 8 * 3
         dim_w = hidden_size_head // 8 * 3

-        # 'lang':
         freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
         freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
         freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
@@ -268,12 +266,7 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
         freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)

         freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
-        # (T H W D)
-
-        self.pnp = pnp
-
-        if not self.pnp:
-            freqs = rearrange(freqs, "t h w d -> (t h w) d")
+        freqs = rearrange(freqs, "t h w d -> (t h w) d")

         freqs = freqs.contiguous()
         freqs_sin = freqs.sin()
@@ -281,40 +274,15 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
         self.register_buffer("freqs_sin", freqs_sin)
         self.register_buffer("freqs_cos", freqs_cos)

-        self.text_length = text_length
-        if learnable_pos_embed:
-            num_patches = height * width * compressed_num_frames + text_length
-            self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True)
-        else:
-            self.pos_embedding = None
-
     def rotary(self, t, **kwargs):
-        if self.pnp:
-            t_coords = kwargs["rope_position_ids"][:, :, 0]
-            x_coords = kwargs["rope_position_ids"][:, :, 1]
-            y_coords = kwargs["rope_position_ids"][:, :, 2]
-            mask = (x_coords != -1) & (y_coords != -1) & (t_coords != -1)
-            freqs = torch.zeros([t.shape[0], t.shape[2], t.shape[3]], dtype=t.dtype, device=t.device)
-            freqs[mask] = self.freqs[t_coords[mask], x_coords[mask], y_coords[mask]]
-
-        else:
-
-            def reshape_freq(freqs):
-                frame = t.shape[2]
-                freqs = freqs[:frame].contiguous()
-                freqs = freqs.unsqueeze(0).unsqueeze(0)
-                return freqs
-
-            freqs_cos = reshape_freq(self.freqs_cos)
-            freqs_sin = reshape_freq(self.freqs_sin)
+        seq_len = t.shape[2]
+        freqs_cos = self.freqs_cos[:seq_len].unsqueeze(0).unsqueeze(0)
+        freqs_sin = self.freqs_sin[:seq_len].unsqueeze(0).unsqueeze(0)

         return t * freqs_cos + rotate_half(t) * freqs_sin

     def position_embedding_forward(self, position_ids, **kwargs):
-        if self.pos_embedding is not None:
-            return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]
-        else:
-            return None
+        return None

     def attention_fn(
         self,
@@ -329,64 +297,10 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
     ):
         attention_fn_default = HOOKS_DEFAULT["attention_fn"]

-        if self.pnp:
-            query_layer = self.rotary(query_layer, **kwargs)
-            key_layer = self.rotary(key_layer, **kwargs)
-            if self.rot_v:
-                value_layer = self.rotary(value_layer)
-        else:
-            query_layer = torch.cat(
-                (
-                    query_layer[
-                        :,
-                        :,
-                        : kwargs["text_length"],
-                    ],
-                    self.rotary(
-                        query_layer[
-                            :,
-                            :,
-                            kwargs["text_length"] :,
-                        ]
-                    ),
-                ),
-                dim=2,
-            )
-            key_layer = torch.cat(
-                (
-                    key_layer[
-                        :,
-                        :,
-                        : kwargs["text_length"],
-                    ],
-                    self.rotary(
-                        key_layer[
-                            :,
-                            :,
-                            kwargs["text_length"] :,
-                        ]
-                    ),
-                ),
-                dim=2,
-            )
-            if self.rot_v:
-                value_layer = torch.cat(
-                    (
-                        value_layer[
-                            :,
-                            :,
-                            : kwargs["text_length"],
-                        ],
-                        self.rotary(
-                            value_layer[
-                                :,
-                                :,
-                                kwargs["text_length"] :,
-                            ]
-                        ),
-                    ),
-                    dim=2,
-                )
+        query_layer[:, :, self.text_length :] = self.rotary(query_layer[:, :, self.text_length :])
+        key_layer[:, :, self.text_length :] = self.rotary(key_layer[:, :, self.text_length :])
+        if self.rot_v:
+            value_layer[:, :, self.text_length :] = self.rotary(value_layer[:, :, self.text_length :])

         return attention_fn_default(
             query_layer,
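Note on what the new code path computes: the patch drops the unused `pnp` branch and the learnable positional embedding, and replaces the torch.cat reassembly with in-place slice assignment, so rotary embeddings are applied only to tokens after `self.text_length` while the text prefix is left unrotated. Below is a minimal standalone sketch of that rotation, assuming a (batch, heads, seq, head_dim) layout and the pairwise-duplicated frequency layout implied by the `repeat(..., "... n -> ... (n r)", r=2)` calls in the diff; `rotate_half` and `apply_rope` here are illustrative stand-ins, not the repository's exact helpers, and the angles are built from 1D positions rather than the mixin's 3D t/h/w grid.

import torch
from einops import rearrange, repeat

def rotate_half(x):
    # Pair adjacent channels (x1, x2) and map them to (-x2, x1); combined with
    # the elementwise cos/sin below, each channel pair undergoes a 2D rotation.
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    return rearrange(torch.stack((-x2, x1), dim=-1), "... d r -> ... (d r)")

def apply_rope(t, freqs_cos, freqs_sin):
    # Same formula as the patched `rotary`: t * cos + rotate_half(t) * sin.
    return t * freqs_cos + rotate_half(t) * freqs_sin

# Toy shapes (assumed for illustration): 2 text tokens then 6 vision tokens.
text_length, vision_length, head_dim = 2, 6, 8
q = torch.randn(1, 1, text_length + vision_length, head_dim)

# Per-position angles with each frequency duplicated pairwise to match
# rotate_half's channel pairing.
inv_freq = 1.0 / 10000 ** (torch.arange(0, head_dim, 2).float() / head_dim)
angles = torch.outer(torch.arange(vision_length).float(), inv_freq)
angles = repeat(angles, "n d -> n (d r)", r=2)

# As in the patched attention_fn: rotate only the vision tokens in place,
# leaving the text prefix untouched.
q[:, :, text_length:] = apply_rope(q[:, :, text_length:], angles.cos(), angles.sin())
print(q.shape)  # torch.Size([1, 1, 8, 8])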