diff --git a/finetune/trainer.py b/finetune/trainer.py index 981c500..53fc193 100644 --- a/finetune/trainer.py +++ b/finetune/trainer.py @@ -185,10 +185,12 @@ class Trainer: raise ValueError(f"Invalid model type: {self.args.model_type}") # Prepare VAE and text encoder for encoding - self.components.vae = self.components.vae.to(self.accelerator.device) self.components.vae.requires_grad_(False) - self.components.text_encoder = self.components.text_encoder.to(self.accelerator.device) self.components.text_encoder.requires_grad_(False) + self.components.vae = self.components.vae.to(self.accelerator.device, dtype=self.state.weight_dtype) + self.components.text_encoder = self.components.text_encoder.to( + self.accelerator.device, dtype=self.state.weight_dtype + ) # Precompute latent for video and prompt embedding logger.info("Precomputing latent for video and prompt embedding ...")