mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-09-19 03:55:53 +08:00
Update dit_video_concat.py
This commit is contained in:
parent
25531820ea
commit
1c2e487820
@ -240,17 +240,15 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
|
||||
text_length,
|
||||
theta=10000,
|
||||
rot_v=False,
|
||||
pnp=False,
|
||||
learnable_pos_embed=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.rot_v = rot_v
|
||||
self.text_length = text_length
|
||||
|
||||
dim_t = hidden_size_head // 4
|
||||
dim_h = hidden_size_head // 8 * 3
|
||||
dim_w = hidden_size_head // 8 * 3
|
||||
|
||||
# 'lang':
|
||||
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
|
||||
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
|
||||
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
|
||||
@ -268,12 +266,7 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
|
||||
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
|
||||
|
||||
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
|
||||
# (T H W D)
|
||||
|
||||
self.pnp = pnp
|
||||
|
||||
if not self.pnp:
|
||||
freqs = rearrange(freqs, "t h w d -> (t h w) d")
|
||||
freqs = rearrange(freqs, "t h w d -> (t h w) d")
|
||||
|
||||
freqs = freqs.contiguous()
|
||||
freqs_sin = freqs.sin()
|
||||
@ -281,40 +274,15 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
|
||||
self.register_buffer("freqs_sin", freqs_sin)
|
||||
self.register_buffer("freqs_cos", freqs_cos)
|
||||
|
||||
self.text_length = text_length
|
||||
if learnable_pos_embed:
|
||||
num_patches = height * width * compressed_num_frames + text_length
|
||||
self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True)
|
||||
else:
|
||||
self.pos_embedding = None
|
||||
|
||||
def rotary(self, t, **kwargs):
|
||||
if self.pnp:
|
||||
t_coords = kwargs["rope_position_ids"][:, :, 0]
|
||||
x_coords = kwargs["rope_position_ids"][:, :, 1]
|
||||
y_coords = kwargs["rope_position_ids"][:, :, 2]
|
||||
mask = (x_coords != -1) & (y_coords != -1) & (t_coords != -1)
|
||||
freqs = torch.zeros([t.shape[0], t.shape[2], t.shape[3]], dtype=t.dtype, device=t.device)
|
||||
freqs[mask] = self.freqs[t_coords[mask], x_coords[mask], y_coords[mask]]
|
||||
|
||||
else:
|
||||
|
||||
def reshape_freq(freqs):
|
||||
frame = t.shape[2]
|
||||
freqs = freqs[:frame].contiguous()
|
||||
freqs = freqs.unsqueeze(0).unsqueeze(0)
|
||||
return freqs
|
||||
|
||||
freqs_cos = reshape_freq(self.freqs_cos)
|
||||
freqs_sin = reshape_freq(self.freqs_sin)
|
||||
seq_len = t.shape[2]
|
||||
freqs_cos = self.freqs_cos[:seq_len].unsqueeze(0).unsqueeze(0)
|
||||
freqs_sin = self.freqs_sin[:seq_len].unsqueeze(0).unsqueeze(0)
|
||||
|
||||
return t * freqs_cos + rotate_half(t) * freqs_sin
|
||||
|
||||
def position_embedding_forward(self, position_ids, **kwargs):
|
||||
if self.pos_embedding is not None:
|
||||
return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]
|
||||
else:
|
||||
return None
|
||||
return None
|
||||
|
||||
def attention_fn(
|
||||
self,
|
||||
@ -329,64 +297,10 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
|
||||
):
|
||||
attention_fn_default = HOOKS_DEFAULT["attention_fn"]
|
||||
|
||||
if self.pnp:
|
||||
query_layer = self.rotary(query_layer, **kwargs)
|
||||
key_layer = self.rotary(key_layer, **kwargs)
|
||||
if self.rot_v:
|
||||
value_layer = self.rotary(value_layer)
|
||||
else:
|
||||
query_layer = torch.cat(
|
||||
(
|
||||
query_layer[
|
||||
:,
|
||||
:,
|
||||
: kwargs["text_length"],
|
||||
],
|
||||
self.rotary(
|
||||
query_layer[
|
||||
:,
|
||||
:,
|
||||
kwargs["text_length"] :,
|
||||
]
|
||||
),
|
||||
),
|
||||
dim=2,
|
||||
)
|
||||
key_layer = torch.cat(
|
||||
(
|
||||
key_layer[
|
||||
:,
|
||||
:,
|
||||
: kwargs["text_length"],
|
||||
],
|
||||
self.rotary(
|
||||
key_layer[
|
||||
:,
|
||||
:,
|
||||
kwargs["text_length"] :,
|
||||
]
|
||||
),
|
||||
),
|
||||
dim=2,
|
||||
)
|
||||
if self.rot_v:
|
||||
value_layer = torch.cat(
|
||||
(
|
||||
value_layer[
|
||||
:,
|
||||
:,
|
||||
: kwargs["text_length"],
|
||||
],
|
||||
self.rotary(
|
||||
value_layer[
|
||||
:,
|
||||
:,
|
||||
kwargs["text_length"] :,
|
||||
]
|
||||
),
|
||||
),
|
||||
dim=2,
|
||||
)
|
||||
query_layer[:, :, self.text_length :] = self.rotary(query_layer[:, :, self.text_length :])
|
||||
key_layer[:, :, self.text_length :] = self.rotary(key_layer[:, :, self.text_length :])
|
||||
if self.rot_v:
|
||||
value_layer[:, :, self.text_length :] = self.rotary(value_layer[:, :, self.text_length :])
|
||||
|
||||
return attention_fn_default(
|
||||
query_layer,
|
||||
|
Loading…
x
Reference in New Issue
Block a user