Mirror of https://github.com/THUDM/CogVideo.git, synced 2025-06-17 13:59:17 +08:00
chore: code cleanup and parameter optimization
- Remove redundant comments and debug information
- Adjust default parameters in training scripts
- Clean up code in lora_trainer and trainer implementations
parent 954ba28d3c
commit 455b44a7b5
@@ -197,7 +197,6 @@ class CogVideoXT2VLoraTrainer(Trainer):
             base_num_frames = num_frames
         else:
             base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t
-        breakpoint()
         freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
             embed_dim=transformer_config.attention_head_dim,
             crops_coords=None,
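The retained context above rounds the frame count up to a whole number of temporal patches before the rotary embedding grid is built. A minimal sketch of that ceiling division, with hypothetical frame counts that are not taken from the repository:

def base_frames(num_frames, patch_size_t):
    # Integer ceiling division: round num_frames up to a whole number of temporal patches.
    if patch_size_t is None:
        return num_frames
    return (num_frames + patch_size_t - 1) // patch_size_t

print(base_frames(49, None))  # 49 -> RoPE grid uses the raw frame count
print(base_frames(49, 2))     # 25 -> 49 frames round up to 25 temporal patches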
@@ -49,7 +49,7 @@ SYSTEM_ARGS=(
 CHECKPOINT_ARGS=(
     --checkpointing_steps 10
     --checkpointing_limit 2
-    --resume_from_checkpoint "/absolute/path/to/checkpoint_dir" # if you want to resume from a checkpoint, otherwise, comment this line
+    # --resume_from_checkpoint "/absolute/path/to/checkpoint_dir" # if you want to resume from a checkpoint, otherwise, comment this line
 )

 # Validation Configuration
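For context on the retained flags: a checkpoint is written every --checkpointing_steps training steps, and --checkpointing_limit presumably caps how many intermediate checkpoints are kept on disk. A small illustrative simulation of that interaction (assumption: the oldest checkpoints are pruned first; this is not repository code):

from collections import deque

def simulate(total_steps, checkpointing_steps=10, checkpointing_limit=2):
    kept = deque()
    for step in range(1, total_steps + 1):
        if step % checkpointing_steps == 0:
            kept.append(f"checkpoint-{step}")
            while len(kept) > checkpointing_limit:
                kept.popleft()  # drop the oldest checkpoint
    return list(kept)

print(simulate(50))  # ['checkpoint-40', 'checkpoint-50']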
The same change is applied in a second training script:

@@ -48,7 +48,7 @@ SYSTEM_ARGS=(
 CHECKPOINT_ARGS=(
     --checkpointing_steps 10
     --checkpointing_limit 2
-    --resume_from_checkpoint "/absolute/path/to/checkpoint_dir" # if you want to resume from a checkpoint, otherwise, comment this line
+    # --resume_from_checkpoint "/absolute/path/to/checkpoint_dir" # if you want to resume from a checkpoint, otherwise, comment this line
 )

 # Validation Configuration
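In both scripts, the path passed to --resume_from_checkpoint should be a directory previously written by the trainer's checkpoint saves. A hedged sketch of the load side using Hugging Face Accelerate (illustrative call site, not the repository's exact code):

from accelerate import Accelerator

accelerator = Accelerator()
# model, optimizer and dataloaders would normally be wrapped via accelerator.prepare(...) first
resume_dir = "/absolute/path/to/checkpoint_dir"  # placeholder path from the script comment
accelerator.load_state(resume_dir)  # restores what accelerator.save_state(...) wrote earlier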
@@ -758,16 +758,6 @@ class Trainer:
         self.accelerator.register_save_state_pre_hook(save_model_hook)
         self.accelerator.register_load_state_pre_hook(load_model_hook)

-    # def __maybe_save_checkpoint(self, global_step: int, must_save: bool = False):
-    #     if self.accelerator.distributed_type == DistributedType.DEEPSPEED or self.accelerator.is_main_process:
-    #         if must_save or global_step % self.args.checkpointing_steps == 0:
-    #             save_path = get_intermediate_ckpt_path(
-    #                 checkpointing_limit=self.args.checkpointing_limit,
-    #                 step=global_step,
-    #                 output_dir=self.args.output_dir,
-    #             )
-    #             self.accelerator.save_state(save_path, safe_serialization=True)
-
     def __maybe_save_checkpoint(self, global_step: int, must_save: bool = False):
         if self.accelerator.distributed_type == DistributedType.DEEPSPEED or self.accelerator.is_main_process:
             if must_save or global_step % self.args.checkpointing_steps == 0:
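The live method retained above delegates the directory choice to get_intermediate_ckpt_path and then calls self.accelerator.save_state, as the removed commented block also did. A hedged sketch of such a helper, keeping the keyword arguments shown in that block; the checkpoint-<step> directory layout and the pruning logic are assumptions, not the repository's verified implementation:

import shutil
from pathlib import Path

def get_intermediate_ckpt_path(checkpointing_limit: int, step: int, output_dir: str) -> Path:
    out = Path(output_dir)
    # Existing intermediate checkpoints, oldest first (assumed "checkpoint-<step>" naming).
    existing = sorted(out.glob("checkpoint-*"), key=lambda p: int(p.name.split("-")[-1]))
    # Prune so at most `checkpointing_limit` checkpoints remain once the new one is written.
    while len(existing) >= checkpointing_limit:
        shutil.rmtree(existing.pop(0))
    return out / f"checkpoint-{step}"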
@@ -783,4 +773,4 @@ class Trainer:
         pipe_save_path.mkdir(parents=True, exist_ok=True)
         pipe.save_pretrained(pipe_save_path)
         del pipe
-        torch.cuda.empty_cache()
+        free_memory()
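The final hunk swaps a bare torch.cuda.empty_cache() for a free_memory() helper. A minimal sketch of what such a helper typically does, assuming this is not the repository's exact implementation:

import gc
import torch

def free_memory() -> None:
    # Run the Python garbage collector first so released tensors can actually be
    # returned to the CUDA caching allocator, then clear the allocator's cache.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()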