Update dit_video_concat.py

Author: zR
Date: 2024-08-14 22:09:03 +08:00
Parent: 25531820ea
Commit: 1c2e487820


@@ -240,17 +240,15 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
text_length,
theta=10000,
rot_v=False,
pnp=False,
learnable_pos_embed=False,
):
super().__init__()
self.rot_v = rot_v
self.text_length = text_length
dim_t = hidden_size_head // 4
dim_h = hidden_size_head // 8 * 3
dim_w = hidden_size_head // 8 * 3
# 'lang':
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
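For reference, the rotary channels are split per axis as dim_t = hidden_size_head // 4 and dim_h = dim_w = hidden_size_head // 8 * 3, so the three parts sum back to the full head dimension. A minimal sketch of the same inverse-frequency band construction, assuming a hypothetical head size of 64:

import torch

hidden_size_head = 64                     # hypothetical head size, for illustration only
theta = 10000
dim_t = hidden_size_head // 4             # 16 channels for the temporal axis
dim_h = hidden_size_head // 8 * 3         # 24 channels for height
dim_w = hidden_size_head // 8 * 3         # 24 channels for width
assert dim_t + dim_h + dim_w == hidden_size_head

# One inverse frequency per channel pair, as in the constructor above.
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
print(freqs_t.shape, freqs_h.shape, freqs_w.shape)   # torch.Size([8]) torch.Size([12]) torch.Size([12])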
@@ -268,11 +266,6 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
# (T H W D)
self.pnp = pnp
if not self.pnp:
freqs = rearrange(freqs, "t h w d -> (t h w) d")
freqs = freqs.contiguous()
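After the per-axis frequencies are broadcast into a (T, H, W, D) grid, only the flattened layout is kept: the grid is rearranged to ((T*H*W), D) so a position can be looked up by its sequence offset. A small shape check of that einops rearrange, using a hypothetical 2x3x4 latent grid:

import torch
from einops import rearrange

freqs = torch.randn(2, 3, 4, 8)                    # (T, H, W, D) frequency grid
freqs = rearrange(freqs, "t h w d -> (t h w) d")   # flatten the 3D grid into a sequence
print(freqs.shape)                                 # torch.Size([24, 8]) == (T*H*W, D)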
@@ -281,39 +274,14 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
self.register_buffer("freqs_sin", freqs_sin)
self.register_buffer("freqs_cos", freqs_cos)
self.text_length = text_length
if learnable_pos_embed:
num_patches = height * width * compressed_num_frames + text_length
self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True)
else:
self.pos_embedding = None
def rotary(self, t, **kwargs):
if self.pnp:
t_coords = kwargs["rope_position_ids"][:, :, 0]
x_coords = kwargs["rope_position_ids"][:, :, 1]
y_coords = kwargs["rope_position_ids"][:, :, 2]
mask = (x_coords != -1) & (y_coords != -1) & (t_coords != -1)
freqs = torch.zeros([t.shape[0], t.shape[2], t.shape[3]], dtype=t.dtype, device=t.device)
freqs[mask] = self.freqs[t_coords[mask], x_coords[mask], y_coords[mask]]
else:
def reshape_freq(freqs):
frame = t.shape[2]
freqs = freqs[:frame].contiguous()
freqs = freqs.unsqueeze(0).unsqueeze(0)
return freqs
freqs_cos = reshape_freq(self.freqs_cos)
freqs_sin = reshape_freq(self.freqs_sin)
seq_len = t.shape[2]
freqs_cos = self.freqs_cos[:seq_len].unsqueeze(0).unsqueeze(0)
freqs_sin = self.freqs_sin[:seq_len].unsqueeze(0).unsqueeze(0)
return t * freqs_cos + rotate_half(t) * freqs_sin
def position_embedding_forward(self, position_ids, **kwargs):
if self.pos_embedding is not None:
return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]
else:
return None
def attention_fn(
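The simplified rotary() above applies the standard RoPE formula t * freqs_cos + rotate_half(t) * freqs_sin using the first seq_len cached angles; in attention_fn below it is applied only to the tokens after self.text_length, leaving the text prefix unrotated. A standalone sketch, assuming an interleaved rotate_half consistent with the "(n r)", r=2 repeat used when the frequencies were built (the real helper is defined elsewhere in dit_video_concat.py):

import torch
from einops import rearrange

def rotate_half(x):
    # Assumed interleaved variant: adjacent channel pairs (x1, x2) -> (-x2, x1).
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")

# Toy tensors shaped (batch, heads, seq, head_dim), as rotary() receives them.
t = torch.randn(1, 2, 6, 8)
freqs = torch.randn(6, 8)                  # per-position rotation angles
rotated = t * freqs.cos() + rotate_half(t) * freqs.sin()
print(rotated.shape)                       # torch.Size([1, 2, 6, 8])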
@@ -329,64 +297,10 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
):
attention_fn_default = HOOKS_DEFAULT["attention_fn"]
if self.pnp:
query_layer = self.rotary(query_layer, **kwargs)
key_layer = self.rotary(key_layer, **kwargs)
query_layer[:, :, self.text_length :] = self.rotary(query_layer[:, :, self.text_length :])
key_layer[:, :, self.text_length :] = self.rotary(key_layer[:, :, self.text_length :])
if self.rot_v:
value_layer = self.rotary(value_layer)
else:
query_layer = torch.cat(
(
query_layer[
:,
:,
: kwargs["text_length"],
],
self.rotary(
query_layer[
:,
:,
kwargs["text_length"] :,
]
),
),
dim=2,
)
key_layer = torch.cat(
(
key_layer[
:,
:,
: kwargs["text_length"],
],
self.rotary(
key_layer[
:,
:,
kwargs["text_length"] :,
]
),
),
dim=2,
)
if self.rot_v:
value_layer = torch.cat(
(
value_layer[
:,
:,
: kwargs["text_length"],
],
self.rotary(
value_layer[
:,
:,
kwargs["text_length"] :,
]
),
),
dim=2,
)
value_layer[:, :, self.text_length :] = self.rotary(value_layer[:, :, self.text_length :])
return attention_fn_default(
query_layer,