mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-09-19 12:20:00 +08:00
Update dit_video_concat.py
This commit is contained in:
parent
25531820ea
commit
1c2e487820
@ -240,17 +240,15 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
|
|||||||
text_length,
|
text_length,
|
||||||
theta=10000,
|
theta=10000,
|
||||||
rot_v=False,
|
rot_v=False,
|
||||||
pnp=False,
|
|
||||||
learnable_pos_embed=False,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rot_v = rot_v
|
self.rot_v = rot_v
|
||||||
|
self.text_length = text_length
|
||||||
|
|
||||||
dim_t = hidden_size_head // 4
|
dim_t = hidden_size_head // 4
|
||||||
dim_h = hidden_size_head // 8 * 3
|
dim_h = hidden_size_head // 8 * 3
|
||||||
dim_w = hidden_size_head // 8 * 3
|
dim_w = hidden_size_head // 8 * 3
|
||||||
|
|
||||||
# 'lang':
|
|
||||||
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
|
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
|
||||||
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
|
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
|
||||||
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
|
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
|
||||||
@ -268,11 +266,6 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
|
|||||||
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
|
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
|
||||||
|
|
||||||
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
|
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
|
||||||
# (T H W D)
|
|
||||||
|
|
||||||
self.pnp = pnp
|
|
||||||
|
|
||||||
if not self.pnp:
|
|
||||||
freqs = rearrange(freqs, "t h w d -> (t h w) d")
|
freqs = rearrange(freqs, "t h w d -> (t h w) d")
|
||||||
|
|
||||||
freqs = freqs.contiguous()
|
freqs = freqs.contiguous()
|
||||||
@ -281,39 +274,14 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
|
|||||||
self.register_buffer("freqs_sin", freqs_sin)
|
self.register_buffer("freqs_sin", freqs_sin)
|
||||||
self.register_buffer("freqs_cos", freqs_cos)
|
self.register_buffer("freqs_cos", freqs_cos)
|
||||||
|
|
||||||
self.text_length = text_length
|
|
||||||
if learnable_pos_embed:
|
|
||||||
num_patches = height * width * compressed_num_frames + text_length
|
|
||||||
self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True)
|
|
||||||
else:
|
|
||||||
self.pos_embedding = None
|
|
||||||
|
|
||||||
def rotary(self, t, **kwargs):
|
def rotary(self, t, **kwargs):
|
||||||
if self.pnp:
|
seq_len = t.shape[2]
|
||||||
t_coords = kwargs["rope_position_ids"][:, :, 0]
|
freqs_cos = self.freqs_cos[:seq_len].unsqueeze(0).unsqueeze(0)
|
||||||
x_coords = kwargs["rope_position_ids"][:, :, 1]
|
freqs_sin = self.freqs_sin[:seq_len].unsqueeze(0).unsqueeze(0)
|
||||||
y_coords = kwargs["rope_position_ids"][:, :, 2]
|
|
||||||
mask = (x_coords != -1) & (y_coords != -1) & (t_coords != -1)
|
|
||||||
freqs = torch.zeros([t.shape[0], t.shape[2], t.shape[3]], dtype=t.dtype, device=t.device)
|
|
||||||
freqs[mask] = self.freqs[t_coords[mask], x_coords[mask], y_coords[mask]]
|
|
||||||
|
|
||||||
else:
|
|
||||||
|
|
||||||
def reshape_freq(freqs):
|
|
||||||
frame = t.shape[2]
|
|
||||||
freqs = freqs[:frame].contiguous()
|
|
||||||
freqs = freqs.unsqueeze(0).unsqueeze(0)
|
|
||||||
return freqs
|
|
||||||
|
|
||||||
freqs_cos = reshape_freq(self.freqs_cos)
|
|
||||||
freqs_sin = reshape_freq(self.freqs_sin)
|
|
||||||
|
|
||||||
return t * freqs_cos + rotate_half(t) * freqs_sin
|
return t * freqs_cos + rotate_half(t) * freqs_sin
|
||||||
|
|
||||||
def position_embedding_forward(self, position_ids, **kwargs):
|
def position_embedding_forward(self, position_ids, **kwargs):
|
||||||
if self.pos_embedding is not None:
|
|
||||||
return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]
|
|
||||||
else:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def attention_fn(
|
def attention_fn(
|
||||||
@ -329,64 +297,10 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
|
|||||||
):
|
):
|
||||||
attention_fn_default = HOOKS_DEFAULT["attention_fn"]
|
attention_fn_default = HOOKS_DEFAULT["attention_fn"]
|
||||||
|
|
||||||
if self.pnp:
|
query_layer[:, :, self.text_length :] = self.rotary(query_layer[:, :, self.text_length :])
|
||||||
query_layer = self.rotary(query_layer, **kwargs)
|
key_layer[:, :, self.text_length :] = self.rotary(key_layer[:, :, self.text_length :])
|
||||||
key_layer = self.rotary(key_layer, **kwargs)
|
|
||||||
if self.rot_v:
|
if self.rot_v:
|
||||||
value_layer = self.rotary(value_layer)
|
value_layer[:, :, self.text_length :] = self.rotary(value_layer[:, :, self.text_length :])
|
||||||
else:
|
|
||||||
query_layer = torch.cat(
|
|
||||||
(
|
|
||||||
query_layer[
|
|
||||||
:,
|
|
||||||
:,
|
|
||||||
: kwargs["text_length"],
|
|
||||||
],
|
|
||||||
self.rotary(
|
|
||||||
query_layer[
|
|
||||||
:,
|
|
||||||
:,
|
|
||||||
kwargs["text_length"] :,
|
|
||||||
]
|
|
||||||
),
|
|
||||||
),
|
|
||||||
dim=2,
|
|
||||||
)
|
|
||||||
key_layer = torch.cat(
|
|
||||||
(
|
|
||||||
key_layer[
|
|
||||||
:,
|
|
||||||
:,
|
|
||||||
: kwargs["text_length"],
|
|
||||||
],
|
|
||||||
self.rotary(
|
|
||||||
key_layer[
|
|
||||||
:,
|
|
||||||
:,
|
|
||||||
kwargs["text_length"] :,
|
|
||||||
]
|
|
||||||
),
|
|
||||||
),
|
|
||||||
dim=2,
|
|
||||||
)
|
|
||||||
if self.rot_v:
|
|
||||||
value_layer = torch.cat(
|
|
||||||
(
|
|
||||||
value_layer[
|
|
||||||
:,
|
|
||||||
:,
|
|
||||||
: kwargs["text_length"],
|
|
||||||
],
|
|
||||||
self.rotary(
|
|
||||||
value_layer[
|
|
||||||
:,
|
|
||||||
:,
|
|
||||||
kwargs["text_length"] :,
|
|
||||||
]
|
|
||||||
),
|
|
||||||
),
|
|
||||||
dim=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
return attention_fn_default(
|
return attention_fn_default(
|
||||||
query_layer,
|
query_layer,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user