fix: remove pipeline hooks after validation

- Add pipe.remove_all_hooks() after validation to prevent memory leaks
- Clean up validation pipeline properly to avoid potential issues in subsequent training steps
This commit is contained in:
OleehyO 2025-01-04 06:21:17 +00:00
parent 93b906b3fb
commit 49dc370de6

View File

@@ -165,13 +165,13 @@ class Trainer:
# self.state.train_frames includes one padding frame for image conditioning
# so we only sample train_frames - 1 frames from the actual video
-max_num_frames = self.state.train_frames - 1
+sample_frames = self.state.train_frames - 1
if self.args.model_type == "i2v":
self.dataset = I2VDatasetWithResize(
**(self.args.model_dump()),
device=self.accelerator.device,
-max_num_frames=max_num_frames,
+max_num_frames=sample_frames,
height=self.state.train_height,
width=self.state.train_width,
trainer=self
@@ -180,7 +180,7 @@ class Trainer:
self.dataset = T2VDatasetWithResize(
**(self.args.model_dump()),
device=self.accelerator.device,
-max_num_frames=max_num_frames,
+max_num_frames=sample_frames,
height=self.state.train_height,
width=self.state.train_width,
trainer=self
@@ -595,9 +595,12 @@ class Trainer:
step=step,
)
+pipe.remove_all_hooks()
del pipe
-# Unload loaded models except those needed for training
+# Unload models except those needed for training
self.__unload_components()
# Load models except those not needed for training
self.__load_components()
# Change LoRA weights back to fp32
cast_training_params([self.components.transformer], dtype=torch.float32)
@@ -610,7 +613,6 @@ class Trainer:
torch.set_grad_enabled(True)
self.components.transformer.train()
def fit(self):
self.check_setting()