mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-06 03:57:56 +08:00
fix: remove pipeline hooks after validation
- Add pipe.remove_all_hooks() after validation to prevent memory leaks - Clean up validation pipeline properly to avoid potential issues in subsequent training steps
This commit is contained in:
parent
93b906b3fb
commit
49dc370de6
@ -165,13 +165,13 @@ class Trainer:
|
|||||||
|
|
||||||
# self.state.train_frames includes one padding frame for image conditioning
|
# self.state.train_frames includes one padding frame for image conditioning
|
||||||
# so we only sample train_frames - 1 frames from the actual video
|
# so we only sample train_frames - 1 frames from the actual video
|
||||||
max_num_frames = self.state.train_frames - 1
|
sample_frames = self.state.train_frames - 1
|
||||||
|
|
||||||
if self.args.model_type == "i2v":
|
if self.args.model_type == "i2v":
|
||||||
self.dataset = I2VDatasetWithResize(
|
self.dataset = I2VDatasetWithResize(
|
||||||
**(self.args.model_dump()),
|
**(self.args.model_dump()),
|
||||||
device=self.accelerator.device,
|
device=self.accelerator.device,
|
||||||
max_num_frames=max_num_frames,
|
max_num_frames=sample_frames,
|
||||||
height=self.state.train_height,
|
height=self.state.train_height,
|
||||||
width=self.state.train_width,
|
width=self.state.train_width,
|
||||||
trainer=self
|
trainer=self
|
||||||
@ -180,7 +180,7 @@ class Trainer:
|
|||||||
self.dataset = T2VDatasetWithResize(
|
self.dataset = T2VDatasetWithResize(
|
||||||
**(self.args.model_dump()),
|
**(self.args.model_dump()),
|
||||||
device=self.accelerator.device,
|
device=self.accelerator.device,
|
||||||
max_num_frames=max_num_frames,
|
max_num_frames=sample_frames,
|
||||||
height=self.state.train_height,
|
height=self.state.train_height,
|
||||||
width=self.state.train_width,
|
width=self.state.train_width,
|
||||||
trainer=self
|
trainer=self
|
||||||
@ -595,9 +595,12 @@ class Trainer:
|
|||||||
step=step,
|
step=step,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
pipe.remove_all_hooks()
|
||||||
del pipe
|
del pipe
|
||||||
# Unload loaded models except those needed for training
|
# Unload models except those needed for training
|
||||||
self.__unload_components()
|
self.__unload_components()
|
||||||
|
# Load models except those not needed for training
|
||||||
|
self.__load_components()
|
||||||
# Change LoRA weights back to fp32
|
# Change LoRA weights back to fp32
|
||||||
cast_training_params([self.components.transformer], dtype=torch.float32)
|
cast_training_params([self.components.transformer], dtype=torch.float32)
|
||||||
|
|
||||||
@ -610,7 +613,6 @@ class Trainer:
|
|||||||
|
|
||||||
torch.set_grad_enabled(True)
|
torch.set_grad_enabled(True)
|
||||||
self.components.transformer.train()
|
self.components.transformer.train()
|
||||||
|
|
||||||
|
|
||||||
def fit(self):
|
def fit(self):
|
||||||
self.check_setting()
|
self.check_setting()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user