diff --git a/finetune/trainer.py b/finetune/trainer.py
index a6be790..5e2c40a 100644
--- a/finetune/trainer.py
+++ b/finetune/trainer.py
@@ -165,13 +165,13 @@ class Trainer:
 
         # self.state.train_frames includes one padding frame for image conditioning
         # so we only sample train_frames - 1 frames from the actual video
-        max_num_frames = self.state.train_frames - 1
+        sample_frames = self.state.train_frames - 1
 
         if self.args.model_type == "i2v":
             self.dataset = I2VDatasetWithResize(
                 **(self.args.model_dump()),
                 device=self.accelerator.device,
-                max_num_frames=max_num_frames,
+                max_num_frames=sample_frames,
                 height=self.state.train_height,
                 width=self.state.train_width,
                 trainer=self
@@ -180,7 +180,7 @@ class Trainer:
             self.dataset = T2VDatasetWithResize(
                 **(self.args.model_dump()),
                 device=self.accelerator.device,
-                max_num_frames=max_num_frames,
+                max_num_frames=sample_frames,
                 height=self.state.train_height,
                 width=self.state.train_width,
                 trainer=self
@@ -595,9 +595,12 @@ class Trainer:
                     step=step,
                 )
 
+        pipe.remove_all_hooks()
         del pipe
-        # Unload loaded models except those needed for training
+        # Unload models except those needed for training
         self.__unload_components()
+        # Load back the models needed for training
+        self.__load_components()
 
         # Change LoRA weights back to fp32
         cast_training_params([self.components.transformer], dtype=torch.float32)
@@ -610,7 +613,6 @@ class Trainer:
 
         torch.set_grad_enabled(True)
         self.components.transformer.train()
 
-
     def fit(self):
         self.check_setting()
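For readers following the second half of the patch, here is a minimal sketch of the teardown-and-restore sequence the `-595,9` hunk implements. It assumes `pipe` is a diffusers `DiffusionPipeline` (which provides `remove_all_hooks()`) and reuses the `__unload_components`, `__load_components`, and `components.transformer` names from the diff; the method name `_teardown_validation` and the explicit gc/cache flush are illustrative assumptions, not code from this repository:

```python
# Sketch only (not the repository's code): the restore sequence from the
# hunk at -595,9, written as a Trainer method. __unload_components,
# __load_components, and self.components.transformer are the names used
# in the diff; the gc/cache-flush lines are an added assumption.
import gc

import torch
from diffusers.training_utils import cast_training_params


class Trainer:
    def _teardown_validation(self, pipe) -> None:
        # CPU-offload hooks hold references to the pipeline's submodules;
        # remove them first so `del pipe` can actually release VRAM.
        pipe.remove_all_hooks()
        del pipe
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Drop validation-only models, then reload the ones training needs.
        self.__unload_components()
        self.__load_components()

        # Validation ran the LoRA weights in the inference dtype; cast
        # them back to fp32 so optimizer updates stay numerically stable.
        cast_training_params([self.components.transformer], dtype=torch.float32)

        # Re-enable autograd and return the transformer to train mode.
        torch.set_grad_enabled(True)
        self.components.transformer.train()
```

The ordering matters: `remove_all_hooks()` must run before `del pipe`, because the offload hooks (presumably installed via `enable_model_cpu_offload()`) keep references to the pipeline's modules, so without removing them the pipeline's memory is never reclaimed and reloading the training components can run out of VRAM.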