mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
fix: remove pipeline hooks after validation
- Add pipe.remove_all_hooks() after validation to prevent memory leaks - Clean up validation pipeline properly to avoid potential issues in subsequent training steps
This commit is contained in:
parent
93b906b3fb
commit
49dc370de6
@ -165,13 +165,13 @@ class Trainer:
|
||||
|
||||
# self.state.train_frames includes one padding frame for image conditioning
|
||||
# so we only sample train_frames - 1 frames from the actual video
|
||||
max_num_frames = self.state.train_frames - 1
|
||||
sample_frames = self.state.train_frames - 1
|
||||
|
||||
if self.args.model_type == "i2v":
|
||||
self.dataset = I2VDatasetWithResize(
|
||||
**(self.args.model_dump()),
|
||||
device=self.accelerator.device,
|
||||
max_num_frames=max_num_frames,
|
||||
max_num_frames=sample_frames,
|
||||
height=self.state.train_height,
|
||||
width=self.state.train_width,
|
||||
trainer=self
|
||||
@ -180,7 +180,7 @@ class Trainer:
|
||||
self.dataset = T2VDatasetWithResize(
|
||||
**(self.args.model_dump()),
|
||||
device=self.accelerator.device,
|
||||
max_num_frames=max_num_frames,
|
||||
max_num_frames=sample_frames,
|
||||
height=self.state.train_height,
|
||||
width=self.state.train_width,
|
||||
trainer=self
|
||||
@ -595,9 +595,12 @@ class Trainer:
|
||||
step=step,
|
||||
)
|
||||
|
||||
pipe.remove_all_hooks()
|
||||
del pipe
|
||||
# Unload loaded models except those needed for training
|
||||
# Unload models except those needed for training
|
||||
self.__unload_components()
|
||||
# Load models except those not needed for training
|
||||
self.__load_components()
|
||||
# Change LoRA weights back to fp32
|
||||
cast_training_params([self.components.transformer], dtype=torch.float32)
|
||||
|
||||
@ -610,7 +613,6 @@ class Trainer:
|
||||
|
||||
torch.set_grad_enabled(True)
|
||||
self.components.transformer.train()
|
||||
|
||||
|
||||
def fit(self):
|
||||
self.check_setting()
|
||||
|
Loading…
x
Reference in New Issue
Block a user