Merge pull request #599 from THUDM/CogVideoX_dev

Cog video x dev
This commit is contained in:
Yuxuan.Zhang 2024-12-13 15:03:48 +08:00 committed by GitHub
commit 1605e95033
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 58 additions and 16 deletions

View File

@ -1,21 +1,17 @@
compute_environment: LOCAL_MACHINE
gpu_ids: "0"
debug: false
deepspeed_config:
gradient_accumulation_steps: 1
gradient_clipping: 1.0
offload_optimizer_device: none
offload_param_device: none
deepspeed_config_file: ds_config.json
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
dynamo_backend: 'no'
mixed_precision: 'no'
num_machines: 1
num_processes: 8
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []

20
finetune/ds_config.json Normal file
View File

@ -0,0 +1,20 @@
{
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 1e8,
"contiguous_gradients": true
}
}

View File

@ -16,7 +16,7 @@ accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gp
--enable_tiling \
--enable_slicing \
--instance_data_root $DATASET_PATH \
--caption_column prompts.txt \
--caption_column prompt.txt \
--video_column videos.txt \
--validation_prompt "DISNEY A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance" \
--validation_prompt_separator ::: \

View File

@ -1,6 +1,6 @@
#!/bin/bash
export MODEL_PATH="THUDM/CogVideoX-2b"
export MODEL_PATH="THUDM/CogVideoX-5b"
export CACHE_PATH="~/.cache"
export DATASET_PATH="Disney-VideoGeneration-Dataset"
export OUTPUT_PATH="cogvideox-lora-single-node"
@ -8,7 +8,7 @@ export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
# if you are not using wth 8 gus, change `accelerate_config_machine_single.yaml` num_processes as your gpu number
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \
accelerate launch --config_file accelerate_config_machine_single.yaml \
train_cogvideox_lora.py \
--gradient_checkpointing \
--pretrained_model_name_or_path $MODEL_PATH \
@ -16,7 +16,7 @@ accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gp
--enable_tiling \
--enable_slicing \
--instance_data_root $DATASET_PATH \
--caption_column prompts.txt \
--caption_column prompt.txt \
--video_column videos.txt \
--validation_prompt "DISNEY A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance" \
--validation_prompt_separator ::: \

View File

@ -948,6 +948,10 @@ def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)
expected_midxed_precision = "bf16" if "5b" in args.pretrained_model_name_or_path.lower() else "fp16"
if args.mixed_precision != expected_midxed_precision:
raise ValueError(f"Mixed precision {args.mixed_precision} does not match the model precision, should be {expected_midxed_precision}")
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(
@ -958,6 +962,28 @@ def main(args):
kwargs_handlers=[kwargs],
)
if accelerator.state.deepspeed_plugin:
# Set deepspeed config according to args
config = {
'optimizer': {
'type': args.optimizer,
'params': {
'lr': args.learning_rate,
'betas': [args.adam_beta1, args.adam_beta2]
},
'torch_adam': True
},
'bf16': {
'enabled': True if args.mixed_precision == "bf16" else False
},
'fp16': {
'enabled': True if args.mixed_precision == "fp16" else False
},
'gradient_accumulation_steps': args.gradient_accumulation_steps,
'train_batch_size': args.train_batch_size
}
accelerator.state.deepspeed_plugin.deepspeed_config.update(config)
# Disable AMP for MPS.
if torch.backends.mps.is_available():
accelerator.native_amp = False
@ -1045,7 +1071,7 @@ def main(args):
"bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
):
weight_dtype = torch.float16
weight_dtype = torch.bfloat16
else:
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
@ -1158,11 +1184,11 @@ def main(args):
use_deepspeed_optimizer = (
accelerator.state.deepspeed_plugin is not None
and accelerator.state.deepspeed_plugin.deepspeed_config.get("optimizer", "none").lower() == "none"
and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
)
use_deepspeed_scheduler = (
accelerator.state.deepspeed_plugin is not None
and accelerator.state.deepspeed_plugin.deepspeed_config.get("scheduler", "none").lower() == "none"
and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
)
optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
@ -1349,7 +1375,7 @@ def main(args):
num_frames=num_frames,
vae_scale_factor_spatial=vae_scale_factor_spatial,
patch_size=model_config.patch_size,
patch_size_t=model_config.patch_size_t,
patch_size_t=model_config.patch_size_t if model_config.patch_size_t is not None else 1,
attention_head_dim=model_config.attention_head_dim,
device=accelerator.device,
)

View File

@ -1,6 +1,6 @@
# SAT CogVideoX
[Read this in English.](./README_zh.md)
[中文阅读](./README_zh.md)
[日本語で読む](./README_ja.md)