From 10de04fc083f8f05ec5c2241a526f598d65c3827 Mon Sep 17 00:00:00 2001
From: OleehyO
Date: Wed, 8 Jan 2025 01:38:13 +0000
Subject: [PATCH] perf: cast VAE and text encoder to target dtype before
 precomputing cache

Before precomputing the latent cache and text embeddings, cast the VAE and
text encoder to the target training dtype (fp16/bf16) instead of keeping
them in fp32. This reduces memory usage during the precomputation phase.

The change occurs in prepare_dataset(), where the models are moved to the
device and cast to weight_dtype before being used to generate the cache.
---
 finetune/trainer.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/finetune/trainer.py b/finetune/trainer.py
index 981c500..53fc193 100644
--- a/finetune/trainer.py
+++ b/finetune/trainer.py
@@ -185,10 +185,12 @@ class Trainer:
             raise ValueError(f"Invalid model type: {self.args.model_type}")
 
         # Prepare VAE and text encoder for encoding
-        self.components.vae = self.components.vae.to(self.accelerator.device)
         self.components.vae.requires_grad_(False)
-        self.components.text_encoder = self.components.text_encoder.to(self.accelerator.device)
         self.components.text_encoder.requires_grad_(False)
+        self.components.vae = self.components.vae.to(self.accelerator.device, dtype=self.state.weight_dtype)
+        self.components.text_encoder = self.components.text_encoder.to(
+            self.accelerator.device, dtype=self.state.weight_dtype
+        )
 
         # Precompute latent for video and prompt embedding
         logger.info("Precomputing latent for video and prompt embedding ...")
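
A minimal standalone sketch of the pattern this patch applies, assuming
placeholder modules in place of the real VAE and text encoder; the names
Components, device, and weight_dtype are illustrative stand-ins for the
trainer's own state, not the repository's API:

    import torch

    class Components:
        def __init__(self):
            # Hypothetical stand-ins for the real VAE / text encoder modules.
            self.vae = torch.nn.Linear(8, 8)
            self.text_encoder = torch.nn.Linear(8, 8)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    weight_dtype = torch.bfloat16  # target training dtype (fp16/bf16)

    components = Components()

    # Freeze the encoders first, then move and cast them in a single .to()
    # call, so they land on the device already in reduced precision rather
    # than being transferred in fp32 and cast afterwards.
    components.vae.requires_grad_(False)
    components.text_encoder.requires_grad_(False)
    components.vae = components.vae.to(device, dtype=weight_dtype)
    components.text_encoder = components.text_encoder.to(device, dtype=weight_dtype)

Since both models are frozen during cache precomputation, no fp32 master
weights are needed for them, which is what makes the cast safe here.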