perf: cast VAE and text encoder to target dtype before precomputing cache

Before precomputing the latent cache and text embeddings, cast the VAE and
text encoder to the target training dtype (fp16/bf16) instead of keeping them
in fp32. This reduces memory usage during the precomputation phase.

The change is in prepare_dataset(), where both models are moved to the
accelerator device and cast to weight_dtype before being used to generate the cache.
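
For illustration only (not code from this repository), a minimal standalone sketch of
the same pattern, assuming a diffusers AutoencoderKL as a stand-in for the project's
VAE and torch.bfloat16 as the training dtype:

    import torch
    from diffusers import AutoencoderKL  # stand-in for the project's VAE (assumption)

    weight_dtype = torch.bfloat16  # target training dtype (fp16/bf16)
    device = "cuda"

    # Frozen encoder: cast once to the target dtype before the caching loop,
    # so its weights and activations use half the memory of fp32.
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # example checkpoint
    vae.requires_grad_(False)
    vae = vae.to(device, dtype=weight_dtype)

    with torch.no_grad():
        # Dummy frames in place of real dataset samples (batch, channels, height, width).
        frames = torch.randn(1, 3, 256, 256, device=device, dtype=weight_dtype)
        latents = vae.encode(frames).latent_dist.sample()  # precomputed latent to cache

Because the encoders are frozen and only run under torch.no_grad() while the cache is
built, casting them to half precision does not affect training gradients.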
OleehyO 2025-01-08 01:38:13 +00:00
parent 0e21d41b12
commit 10de04fc08


@@ -185,10 +185,12 @@ class Trainer:
             raise ValueError(f"Invalid model type: {self.args.model_type}")
 
         # Prepare VAE and text encoder for encoding
-        self.components.vae = self.components.vae.to(self.accelerator.device)
         self.components.vae.requires_grad_(False)
-        self.components.text_encoder = self.components.text_encoder.to(self.accelerator.device)
         self.components.text_encoder.requires_grad_(False)
+        self.components.vae = self.components.vae.to(self.accelerator.device, dtype=self.state.weight_dtype)
+        self.components.text_encoder = self.components.text_encoder.to(
+            self.accelerator.device, dtype=self.state.weight_dtype
+        )
 
         # Precompute latent for video and prompt embedding
         logger.info("Precomputing latent for video and prompt embedding ...")