Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-06-17 05:42:14 +08:00)
chore: code cleanup and parameter optimization
- Remove redundant comments and debug information
- Adjust default parameters in training scripts
- Clean up code in lora_trainer and trainer implementations
Parent: 954ba28d3c
Commit: 455b44a7b5
@@ -197,7 +197,6 @@ class CogVideoXT2VLoraTrainer(Trainer):
             base_num_frames = num_frames
         else:
             base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t
-        breakpoint()
         freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
             embed_dim=transformer_config.attention_head_dim,
             crops_coords=None,
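The dropped line is a stray breakpoint() left over from debugging (the "debug information" the commit message refers to); the surrounding code builds the 3D rotary position embeddings for the CogVideoX transformer. As a rough, illustrative sketch of the call pattern (assuming the get_3d_rotary_pos_embed helper from diffusers.models.embeddings in a recent diffusers release, and with made-up grid sizes; this is not the repository's exact code):

    from diffusers.models.embeddings import get_3d_rotary_pos_embed

    # Hypothetical values for illustration; the trainer derives these from the
    # transformer config and the latent video shape.
    attention_head_dim = 64            # transformer_config.attention_head_dim
    grid_height, grid_width = 30, 45   # latent spatial size after patchification
    base_num_frames = 13               # latent temporal size after patchification

    # Returns per-position cosine/sine frequency tables used by rotary attention.
    freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
        embed_dim=attention_head_dim,
        crops_coords=None,
        grid_size=(grid_height, grid_width),
        temporal_size=base_num_frames,
        grid_type="slice",
        max_size=(grid_height, grid_width),
    )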
@@ -49,7 +49,7 @@ SYSTEM_ARGS=(
 CHECKPOINT_ARGS=(
     --checkpointing_steps 10
     --checkpointing_limit 2
-    --resume_from_checkpoint "/absolute/path/to/checkpoint_dir" # if you want to resume from a checkpoint, otherwise, comment this line
+    # --resume_from_checkpoint "/absolute/path/to/checkpoint_dir" # if you want to resume from a checkpoint, otherwise, comment this line
 )

 # Validation Configuration
@@ -48,7 +48,7 @@ SYSTEM_ARGS=(
 CHECKPOINT_ARGS=(
     --checkpointing_steps 10
     --checkpointing_limit 2
-    --resume_from_checkpoint "/absolute/path/to/checkpoint_dir" # if you want to resume from a checkpoint, otherwise, comment this line
+    # --resume_from_checkpoint "/absolute/path/to/checkpoint_dir" # if you want to resume from a checkpoint, otherwise, comment this line
 )

 # Validation Configuration
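Both launch scripts receive the same edit: the --resume_from_checkpoint entry is commented out so a fresh run is the default, while --checkpointing_steps 10 and --checkpointing_limit 2 stay active. These bash arrays are expanded onto the training command line, so each entry ends up as an ordinary CLI flag. A minimal, illustrative Python sketch of how such flags are typically parsed (the real project defines them in its own argument schema; only the flag names are taken from the scripts above):

    import argparse

    # Illustrative parser, not the repository's actual schema.
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpointing_steps", type=int, default=10,
                        help="Save a training checkpoint every N optimizer steps.")
    parser.add_argument("--checkpointing_limit", type=int, default=2,
                        help="Keep at most this many checkpoints; older ones are removed.")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None,
                        help="Absolute path of a checkpoint directory to resume from.")

    # Simulate what the CHECKPOINT_ARGS array expands to on the command line.
    args = parser.parse_args(["--checkpointing_steps", "10", "--checkpointing_limit", "2"])
    print(args.checkpointing_steps, args.checkpointing_limit, args.resume_from_checkpoint)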
@@ -758,16 +758,6 @@ class Trainer:
         self.accelerator.register_save_state_pre_hook(save_model_hook)
         self.accelerator.register_load_state_pre_hook(load_model_hook)
 
-    # def __maybe_save_checkpoint(self, global_step: int, must_save: bool = False):
-    #     if self.accelerator.distributed_type == DistributedType.DEEPSPEED or self.accelerator.is_main_process:
-    #         if must_save or global_step % self.args.checkpointing_steps == 0:
-    #             save_path = get_intermediate_ckpt_path(
-    #                 checkpointing_limit=self.args.checkpointing_limit,
-    #                 step=global_step,
-    #                 output_dir=self.args.output_dir,
-    #             )
-    #             self.accelerator.save_state(save_path, safe_serialization=True)
-
     def __maybe_save_checkpoint(self, global_step: int, must_save: bool = False):
         if self.accelerator.distributed_type == DistributedType.DEEPSPEED or self.accelerator.is_main_process:
             if must_save or global_step % self.args.checkpointing_steps == 0:
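The removed block was a commented-out duplicate of the live __maybe_save_checkpoint method just below it. The condition checks DistributedType.DEEPSPEED explicitly because, under DeepSpeed, accelerator.save_state has to run on every rank, while other backends save only from the main process. The checkpoint rotation implied by checkpointing_limit is presumably handled inside get_intermediate_ckpt_path, whose internals are not part of this diff; a standalone, illustrative sketch of that kind of rotation might look like:

    import shutil
    from pathlib import Path

    def rotate_and_get_ckpt_path(output_dir: str, step: int, checkpointing_limit: int) -> Path:
        # Illustrative stand-in for get_intermediate_ckpt_path: delete the oldest
        # checkpoints so at most (checkpointing_limit - 1) remain, then return the
        # directory to use for the new checkpoint.
        out = Path(output_dir)
        checkpoints = sorted(
            out.glob("checkpoint-*"),
            key=lambda p: int(p.name.split("-")[-1]),
        )
        if checkpointing_limit is not None and len(checkpoints) >= checkpointing_limit:
            for old_ckpt in checkpoints[: len(checkpoints) - checkpointing_limit + 1]:
                shutil.rmtree(old_ckpt)
        return out / f"checkpoint-{step}"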
@@ -783,4 +773,4 @@ class Trainer:
                 pipe_save_path.mkdir(parents=True, exist_ok=True)
                 pipe.save_pretrained(pipe_save_path)
                 del pipe
-                torch.cuda.empty_cache()
+                free_memory()
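The final change swaps the bare torch.cuda.empty_cache() call for the project's free_memory() helper. Its exact implementation is not shown in this diff; helpers with that name (for example diffusers.training_utils.free_memory) typically combine Python garbage collection with CUDA cache cleanup, roughly like this sketch:

    import gc
    import torch

    def free_memory() -> None:
        # Illustrative only: drop unreachable Python objects, then ask PyTorch
        # to release cached CUDA memory back to the driver.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()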