mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-09-21 22:00:00 +08:00
perf: cast VAE and text encoder to target dtype before precomputing cache
Before precomputing the latent cache and text embeddings, cast the VAE and text encoder to the target training dtype (fp16/bf16) instead of keeping them in fp32. This reduces memory usage during the precomputation phase. The change occurs in prepare_dataset() where the models are moved to device and cast to weight_dtype before being used to generate the cache.
This commit is contained in:
parent
0e21d41b12
commit
10de04fc08
@@ -185,10 +185,12 @@ class Trainer:
|
||||
raise ValueError(f"Invalid model type: {self.args.model_type}")
|
||||
|
||||
# Prepare VAE and text encoder for encoding
|
||||
self.components.vae = self.components.vae.to(self.accelerator.device)
|
||||
self.components.vae.requires_grad_(False)
|
||||
self.components.text_encoder = self.components.text_encoder.to(self.accelerator.device)
|
||||
self.components.text_encoder.requires_grad_(False)
|
||||
self.components.vae = self.components.vae.to(self.accelerator.device, dtype=self.state.weight_dtype)
|
||||
self.components.text_encoder = self.components.text_encoder.to(
|
||||
self.accelerator.device, dtype=self.state.weight_dtype
|
||||
)
|
||||
|
||||
# Precompute latent for video and prompt embedding
|
||||
logger.info("Precomputing latent for video and prompt embedding ...")
|
||||
|
Loading…
x
Reference in New Issue
Block a user