fix(cogvideox): add prompt embedding caching and fix frame padding

- Add support for cached prompt embeddings in dataset
- Fix bug where first frame wasn't properly padded in latent space
This commit is contained in:
OleehyO 2025-01-03 09:27:33 +00:00
parent de5bef6611
commit 66e4ba2592

View File

@ -6,7 +6,6 @@ from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers import (
CogVideoXImageToVideoPipeline,
@ -22,6 +21,7 @@ from ..utils import register
class CogVideoXI2VLoraTrainer(Trainer):
UNLOAD_LIST = ["text_encoder"]
@override
def load_components(self) -> Dict[str, Any]:
@ -51,6 +51,17 @@ class CogVideoXI2VLoraTrainer(Trainer):
)
return components
@override
def initialize_pipeline(self) -> CogVideoXImageToVideoPipeline:
pipe = CogVideoXImageToVideoPipeline(
tokenizer=self.components.tokenizer,
text_encoder=self.components.text_encoder,
vae=self.components.vae,
transformer=unwrap_model(self.accelerator, self.components.transformer),
scheduler=self.components.scheduler
)
return pipe
@override
def encode_video(self, video: torch.Tensor) -> torch.Tensor:
@ -60,52 +71,67 @@ class CogVideoXI2VLoraTrainer(Trainer):
latent_dist = vae.encode(video).latent_dist
latent = latent_dist.sample() * vae.config.scaling_factor
return latent
@override
def encode_text(self, prompt: str) -> torch.Tensor:
prompt_token_ids = self.components.tokenizer(
prompt,
padding="max_length",
max_length=self.state.transformer_config.max_text_seq_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
prompt_token_ids = prompt_token_ids.input_ids
prompt_embedding = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
return prompt_embedding
@override
def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
ret = {
"encoded_videos": [],
"prompt_token_ids": [],
"prompt_embedding": [],
"images": []
}
for sample in samples:
encoded_video = sample["encoded_video"]
prompt = sample["prompt"]
prompt_embedding = sample["prompt_embedding"]
image = sample["image"]
# tokenize prompt
text_inputs = self.components.tokenizer(
prompt,
padding="max_length",
max_length=self.state.transformer_config.max_text_seq_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
ret["encoded_videos"].append(encoded_video)
ret["prompt_token_ids"].append(text_input_ids[0])
ret["prompt_embedding"].append(prompt_embedding)
ret["images"].append(image)
ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
ret["prompt_token_ids"] = torch.stack(ret["prompt_token_ids"])
ret["prompt_embedding"] = torch.stack(ret["prompt_embedding"])
ret["images"] = torch.stack(ret["images"])
return ret
@override
def compute_loss(self, batch) -> torch.Tensor:
prompt_token_ids = batch["prompt_token_ids"]
prompt_embedding = batch["prompt_embedding"]
latent = batch["encoded_videos"]
images = batch["images"]
# Shape of prompt_embedding: [B, seq_len, hidden_size]
# Shape of latent: [B, C, F, H, W]
# Shape of images: [B, C, H, W]
patch_size_t = self.state.transformer_config.patch_size_t
if patch_size_t is not None:
ncopy = latent.shape[2] % patch_size_t
# Copy the first frame ncopy times to match patch_size_t
first_frame = latent[:, :, :1, :, :] # Get first frame [B, C, 1, H, W]
latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
assert latent.shape[2] % patch_size_t == 0
batch_size, num_channels, num_frames, height, width = latent.shape
# Get prompt embeddings
prompt_embeds = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.view(batch_size, seq_len, -1)
_, seq_len, _ = prompt_embedding.shape
prompt_embedding = prompt_embedding.view(batch_size, seq_len, -1)
# Add frame dimension to images [B,C,H,W] -> [B,C,F,H,W]
images = images.unsqueeze(2)
@ -113,7 +139,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device)
image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=images.dtype)
noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
image_latent_dist = self.components.vae.encode(noisy_images).latent_dist
image_latent_dist = self.components.vae.encode(noisy_images.to(dtype=self.components.vae.dtype)).latent_dist
image_latents = image_latent_dist.sample() * self.components.vae.config.scaling_factor
# Sample a random timestep for each sample
@ -160,7 +186,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
ofs_emb = None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
predicted_noise = self.components.transformer(
hidden_states=latent_img_noisy,
encoder_hidden_states=prompt_embeds,
encoder_hidden_states=prompt_embedding,
timestep=timesteps,
ofs=ofs_emb,
image_rotary_emb=rotary_emb,
@ -182,7 +208,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
@override
def validation_step(
self, eval_data: Dict[str, Any]
self, eval_data: Dict[str, Any], pipe: CogVideoXImageToVideoPipeline
) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
"""
Return the data that needs to be saved. For videos, the data format is List[PIL],
@ -190,13 +216,6 @@ class CogVideoXI2VLoraTrainer(Trainer):
"""
prompt, image, video = eval_data["prompt"], eval_data["image"], eval_data["video"]
pipe = self.components.pipeline_cls(
tokenizer=self.components.tokenizer,
text_encoder=self.components.text_encoder,
vae=self.components.vae,
transformer=unwrap_model(self.accelerator, self.components.transformer),
scheduler=self.components.scheduler
)
video_generate = pipe(
num_frames=self.state.train_frames,
height=self.state.train_height,