diff --git a/finetune/models/cogvideox_i2v/lora_trainer.py b/finetune/models/cogvideox_i2v/lora_trainer.py
index 442f769..8b64497 100644
--- a/finetune/models/cogvideox_i2v/lora_trainer.py
+++ b/finetune/models/cogvideox_i2v/lora_trainer.py
@@ -6,7 +6,6 @@
 from PIL import Image
 from transformers import AutoTokenizer, T5EncoderModel

-from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
 from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers import (
     CogVideoXImageToVideoPipeline,
@@ -22,6 +21,7 @@ from ..utils import register


 class CogVideoXI2VLoraTrainer(Trainer):
+    UNLOAD_LIST = ["text_encoder"]

     @override
     def load_components(self) -> Dict[str, Any]:
@@ -51,6 +51,17 @@ class CogVideoXI2VLoraTrainer(Trainer):
         )

         return components
+
+    @override
+    def initialize_pipeline(self) -> CogVideoXImageToVideoPipeline:
+        pipe = CogVideoXImageToVideoPipeline(
+            tokenizer=self.components.tokenizer,
+            text_encoder=self.components.text_encoder,
+            vae=self.components.vae,
+            transformer=unwrap_model(self.accelerator, self.components.transformer),
+            scheduler=self.components.scheduler
+        )
+        return pipe

     @override
     def encode_video(self, video: torch.Tensor) -> torch.Tensor:
@@ -60,52 +71,67 @@ class CogVideoXI2VLoraTrainer(Trainer):
         latent_dist = vae.encode(video).latent_dist
         latent = latent_dist.sample() * vae.config.scaling_factor
         return latent
+
+    @override
+    def encode_text(self, prompt: str) -> torch.Tensor:
+        prompt_token_ids = self.components.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.state.transformer_config.max_text_seq_length,
+            truncation=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        )
+        prompt_token_ids = prompt_token_ids.input_ids
+        prompt_embedding = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
+        return prompt_embedding

     @override
     def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
         ret = {
             "encoded_videos": [],
-            "prompt_token_ids": [],
+            "prompt_embedding": [],
             "images": []
         }

         for sample in samples:
             encoded_video = sample["encoded_video"]
-            prompt = sample["prompt"]
+            prompt_embedding = sample["prompt_embedding"]
             image = sample["image"]

-            # tokenize prompt
-            text_inputs = self.components.tokenizer(
-                prompt,
-                padding="max_length",
-                max_length=self.state.transformer_config.max_text_seq_length,
-                truncation=True,
-                add_special_tokens=True,
-                return_tensors="pt",
-            )
-            text_input_ids = text_inputs.input_ids
             ret["encoded_videos"].append(encoded_video)
-            ret["prompt_token_ids"].append(text_input_ids[0])
+            ret["prompt_embedding"].append(prompt_embedding)
             ret["images"].append(image)

         ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
-        ret["prompt_token_ids"] = torch.stack(ret["prompt_token_ids"])
+        ret["prompt_embedding"] = torch.stack(ret["prompt_embedding"])
         ret["images"] = torch.stack(ret["images"])

         return ret

     @override
     def compute_loss(self, batch) -> torch.Tensor:
-        prompt_token_ids = batch["prompt_token_ids"]
+        prompt_embedding = batch["prompt_embedding"]
         latent = batch["encoded_videos"]
         images = batch["images"]

+        # Shape of prompt_embedding: [B, seq_len, hidden_size]
+        # Shape of latent: [B, C, F, H, W]
+        # Shape of images: [B, C, H, W]
+
+        patch_size_t = self.state.transformer_config.patch_size_t
+        if patch_size_t is not None:
+            ncopy = latent.shape[2] % patch_size_t
+            # Copy the first frame ncopy times to match patch_size_t
+            first_frame = latent[:, :, :1, :, :]  # Get first frame [B, C, 1, H, W]
+            latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
+            assert latent.shape[2] % patch_size_t == 0
+
         batch_size, num_channels, num_frames, height, width = latent.shape

         # Get prompt embeddings
-        prompt_embeds = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
-        _, seq_len, _ = prompt_embeds.shape
-        prompt_embeds = prompt_embeds.view(batch_size, seq_len, -1)
+        _, seq_len, _ = prompt_embedding.shape
+        prompt_embedding = prompt_embedding.view(batch_size, seq_len, -1)

         # Add frame dimension to images [B,C,H,W] -> [B,C,F,H,W]
         images = images.unsqueeze(2)
@@ -113,7 +139,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
         image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device)
         image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=images.dtype)
         noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
-        image_latent_dist = self.components.vae.encode(noisy_images).latent_dist
+        image_latent_dist = self.components.vae.encode(noisy_images.to(dtype=self.components.vae.dtype)).latent_dist
         image_latents = image_latent_dist.sample() * self.components.vae.config.scaling_factor

         # Sample a random timestep for each sample
@@ -160,7 +186,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
         ofs_emb = None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
         predicted_noise = self.components.transformer(
             hidden_states=latent_img_noisy,
-            encoder_hidden_states=prompt_embeds,
+            encoder_hidden_states=prompt_embedding,
             timestep=timesteps,
             ofs=ofs_emb,
             image_rotary_emb=rotary_emb,
@@ -182,7 +208,7 @@ class CogVideoXI2VLoraTrainer(Trainer):

     @override
     def validation_step(
-        self, eval_data: Dict[str, Any]
+        self, eval_data: Dict[str, Any], pipe: CogVideoXImageToVideoPipeline
    ) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
         """
         Return the data that needs to be saved. For videos, the data format is List[PIL],
@@ -190,13 +216,6 @@
         """
         prompt, image, video = eval_data["prompt"], eval_data["image"], eval_data["video"]

-        pipe = self.components.pipeline_cls(
-            tokenizer=self.components.tokenizer,
-            text_encoder=self.components.text_encoder,
-            vae=self.components.vae,
-            transformer=unwrap_model(self.accelerator, self.components.transformer),
-            scheduler=self.components.scheduler
-        )
         video_generate = pipe(
             num_frames=self.state.train_frames,
             height=self.state.train_height,