feat(cogvideox): add prompt embedding caching support
This change enables caching of prompt embeddings in the CogVideoX text-to-video LoRA trainer, which can improve training efficiency by avoiding redundant text encoding operations.
This commit is contained in: commit 7e1ac76847 (parent 66e4ba2592)
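The diff below adds an encode_text hook and moves the data path from raw prompt token ids to precomputed prompt embeddings. A minimal sketch of the caching pattern this enables is shown here; it is not code from this commit, and names such as get_prompt_embedding and cache_path are illustrative only:

from pathlib import Path

import torch


def get_prompt_embedding(trainer, prompt: str, cache_path: Path) -> torch.Tensor:
    # Reuse a previously saved embedding so the T5 text encoder is not run again
    # for the same prompt on later epochs.
    if cache_path.exists():
        return torch.load(cache_path, map_location="cpu")
    # First time this prompt is seen: encode it once via the trainer's new
    # encode_text() hook and persist the result. (Hypothetical helper, not part
    # of the commit.)
    embedding = trainer.encode_text(prompt)  # shape: [1, seq_len, hidden_size]
    torch.save(embedding.cpu(), cache_path)
    return embedding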
@@ -7,7 +7,6 @@ from PIL import Image

 from transformers import AutoTokenizer, T5EncoderModel

-from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
 from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers import (
     CogVideoXPipeline,
@@ -23,6 +22,7 @@ from ..utils import register


 class CogVideoXT2VLoraTrainer(Trainer):
+    UNLOAD_LIST = ["text_encoder", "vae"]

     @override
     def load_components(self) -> Components:
@@ -52,6 +52,17 @@ class CogVideoXT2VLoraTrainer(Trainer):
         )

         return components

+    @override
+    def initialize_pipeline(self) -> CogVideoXPipeline:
+        pipe = CogVideoXPipeline(
+            tokenizer=self.components.tokenizer,
+            text_encoder=self.components.text_encoder,
+            vae=self.components.vae,
+            transformer=unwrap_model(self.accelerator, self.components.transformer),
+            scheduler=self.components.scheduler
+        )
+        return pipe
+
     @override
     def encode_video(self, video: torch.Tensor) -> torch.Tensor:
@@ -61,49 +72,57 @@ class CogVideoXT2VLoraTrainer(Trainer):
         latent_dist = vae.encode(video).latent_dist
         latent = latent_dist.sample() * vae.config.scaling_factor
         return latent

+    @override
+    def encode_text(self, prompt: str) -> torch.Tensor:
+        prompt_token_ids = self.components.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.state.transformer_config.max_text_seq_length,
+            truncation=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        )
+        prompt_token_ids = prompt_token_ids.input_ids
+        prompt_embedding = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
+        return prompt_embedding
+
     @override
     def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
         ret = {
             "encoded_videos": [],
-            "prompt_token_ids": []
+            "prompt_embedding": []
         }

         for sample in samples:
             encoded_video = sample["encoded_video"]
-            prompt = sample["prompt"]
+            prompt_embedding = sample["prompt_embedding"]

-            # tokenize prompt
-            text_inputs = self.components.tokenizer(
-                prompt,
-                padding="max_length",
-                max_length=226,
-                truncation=True,
-                add_special_tokens=True,
-                return_tensors="pt",
-            )
-            text_input_ids = text_inputs.input_ids
-
             ret["encoded_videos"].append(encoded_video)
-            ret["prompt_token_ids"].append(text_input_ids[0])
+            ret["prompt_embedding"].append(prompt_embedding)

         ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
-        ret["prompt_token_ids"] = torch.stack(ret["prompt_token_ids"])
+        ret["prompt_embedding"] = torch.stack(ret["prompt_embedding"])

         return ret

     @override
     def compute_loss(self, batch) -> torch.Tensor:
-        prompt_token_ids = batch["prompt_token_ids"]
+        prompt_embedding = batch["prompt_embedding"]
         latent = batch["encoded_videos"]

+        # Shape of prompt_embedding: [B, seq_len, hidden_size]
+        # Shape of latent: [B, C, F, H, W]
+
+        patch_size_t = self.state.transformer_config.patch_size_t
+        if patch_size_t is not None and latent.shape[2] % patch_size_t != 0:
+            raise ValueError("Number of frames in latent must be divisible by patch size, please check your args for training.")
+
         batch_size, num_channels, num_frames, height, width = latent.shape

         # Get prompt embeddings
-        prompt_embeds = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
-        _, seq_len, _ = prompt_embeds.shape
-        prompt_embeds = prompt_embeds.view(batch_size, seq_len, -1)
-        assert prompt_embeds.requires_grad is False
+        _, seq_len, _ = prompt_embedding.shape
+        prompt_embedding = prompt_embedding.view(batch_size, seq_len, -1)

         # Sample a random timestep for each sample
         timesteps = torch.randint(
@@ -113,7 +132,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
         timesteps = timesteps.long()

         # Add noise to latent
-        latent = latent.permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
+        latent = latent.permute(0, 2, 1, 3, 4)  # from [B, C, F, H, W] to [B, F, C, H, W]
         noise = torch.randn_like(latent)
         latent_added_noise = self.components.scheduler.add_noise(latent, noise, timesteps)

@@ -136,7 +155,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
         # Predict noise
         predicted_noise = self.components.transformer(
             hidden_states=latent_added_noise,
-            encoder_hidden_states=prompt_embeds,
+            encoder_hidden_states=prompt_embedding,
             timestep=timesteps,
             image_rotary_emb=rotary_emb,
             return_dict=False,
@@ -157,7 +176,7 @@ class CogVideoXT2VLoraTrainer(Trainer):

     @override
     def validation_step(
-        self, eval_data: Dict[str, Any]
+        self, eval_data: Dict[str, Any], pipe: CogVideoXPipeline
     ) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
         """
         Return the data that needs to be saved. For videos, the data format is List[PIL],
@@ -165,15 +184,8 @@ class CogVideoXT2VLoraTrainer(Trainer):
         """
         prompt, image, video = eval_data["prompt"], eval_data["image"], eval_data["video"]

-        pipe = self.components.pipeline_cls(
-            tokenizer=self.components.tokenizer,
-            text_encoder=self.components.text_encoder,
-            vae=self.components.vae,
-            transformer=unwrap_model(self.accelerator, self.components.transformer),
-            scheduler=self.components.scheduler
-        )
         video_generate = pipe(
-            num_frames=self.state.train_frames,
+            num_frames=self.state.train_frames - 1,  # -1 is because t2v does not require adding an image frame like i2v does
             height=self.state.train_height,
             width=self.state.train_width,
             prompt=prompt,