mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-06 03:57:56 +08:00
Fix for deepspeed training
This commit is contained in:
parent
4d1b9fd166
commit
36f1333788
@ -1,21 +1,17 @@
|
|||||||
compute_environment: LOCAL_MACHINE
|
compute_environment: LOCAL_MACHINE
|
||||||
|
gpu_ids: "0"
|
||||||
debug: false
|
debug: false
|
||||||
deepspeed_config:
|
deepspeed_config:
|
||||||
gradient_accumulation_steps: 1
|
deepspeed_config_file: ds_config.json
|
||||||
gradient_clipping: 1.0
|
|
||||||
offload_optimizer_device: none
|
|
||||||
offload_param_device: none
|
|
||||||
zero3_init_flag: false
|
zero3_init_flag: false
|
||||||
zero_stage: 2
|
|
||||||
distributed_type: DEEPSPEED
|
distributed_type: DEEPSPEED
|
||||||
downcast_bf16: 'no'
|
downcast_bf16: 'no'
|
||||||
enable_cpu_affinity: false
|
enable_cpu_affinity: false
|
||||||
machine_rank: 0
|
machine_rank: 0
|
||||||
main_training_function: main
|
main_training_function: main
|
||||||
dynamo_backend: 'no'
|
dynamo_backend: 'no'
|
||||||
mixed_precision: 'no'
|
|
||||||
num_machines: 1
|
num_machines: 1
|
||||||
num_processes: 8
|
num_processes: 1
|
||||||
rdzv_backend: static
|
rdzv_backend: static
|
||||||
same_network: true
|
same_network: true
|
||||||
tpu_env: []
|
tpu_env: []
|
||||||
|
20
finetune/ds_config.json
Normal file
20
finetune/ds_config.json
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
{
|
||||||
|
"scheduler": {
|
||||||
|
"type": "WarmupDecayLR",
|
||||||
|
"params": {
|
||||||
|
"warmup_min_lr": "auto",
|
||||||
|
"warmup_max_lr": "auto",
|
||||||
|
"warmup_num_steps": "auto",
|
||||||
|
"total_num_steps": "auto"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"zero_optimization": {
|
||||||
|
"stage": 2,
|
||||||
|
"allgather_partitions": true,
|
||||||
|
"allgather_bucket_size": 2e8,
|
||||||
|
"overlap_comm": true,
|
||||||
|
"reduce_scatter": true,
|
||||||
|
"reduce_bucket_size": 1e8,
|
||||||
|
"contiguous_gradients": true
|
||||||
|
}
|
||||||
|
}
|
@ -948,6 +948,10 @@ def main(args):
|
|||||||
|
|
||||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||||
|
|
||||||
|
expected_midxed_precision = "bf16" if "5b" in args.pretrained_model_name_or_path.lower() else "fp16"
|
||||||
|
if args.mixed_precision != expected_midxed_precision:
|
||||||
|
raise ValueError(f"Mixed precision {args.mixed_precision} does not match the model precision, should be {expected_midxed_precision}")
|
||||||
|
|
||||||
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
||||||
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||||
accelerator = Accelerator(
|
accelerator = Accelerator(
|
||||||
@ -958,6 +962,28 @@ def main(args):
|
|||||||
kwargs_handlers=[kwargs],
|
kwargs_handlers=[kwargs],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if accelerator.state.deepspeed_plugin:
|
||||||
|
# Set deepspeed config according to args
|
||||||
|
config = {
|
||||||
|
'optimizer': {
|
||||||
|
'type': args.optimizer,
|
||||||
|
'params': {
|
||||||
|
'lr': args.learning_rate,
|
||||||
|
'betas': [args.adam_beta1, args.adam_beta2]
|
||||||
|
},
|
||||||
|
'torch_adam': True
|
||||||
|
},
|
||||||
|
'bf16': {
|
||||||
|
'enabled': True if args.mixed_precision == "bf16" else False
|
||||||
|
},
|
||||||
|
'fp16': {
|
||||||
|
'enabled': True if args.mixed_precision == "fp16" else False
|
||||||
|
},
|
||||||
|
'gradient_accumulation_steps': args.gradient_accumulation_steps,
|
||||||
|
'train_batch_size': args.train_batch_size
|
||||||
|
}
|
||||||
|
accelerator.state.deepspeed_plugin.deepspeed_config.update(config)
|
||||||
|
|
||||||
# Disable AMP for MPS.
|
# Disable AMP for MPS.
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
accelerator.native_amp = False
|
accelerator.native_amp = False
|
||||||
@ -1045,7 +1071,7 @@ def main(args):
|
|||||||
"bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
"bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||||
and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
|
and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
|
||||||
):
|
):
|
||||||
weight_dtype = torch.float16
|
weight_dtype = torch.bfloat16
|
||||||
else:
|
else:
|
||||||
if accelerator.mixed_precision == "fp16":
|
if accelerator.mixed_precision == "fp16":
|
||||||
weight_dtype = torch.float16
|
weight_dtype = torch.float16
|
||||||
@ -1158,11 +1184,11 @@ def main(args):
|
|||||||
|
|
||||||
use_deepspeed_optimizer = (
|
use_deepspeed_optimizer = (
|
||||||
accelerator.state.deepspeed_plugin is not None
|
accelerator.state.deepspeed_plugin is not None
|
||||||
and accelerator.state.deepspeed_plugin.deepspeed_config.get("optimizer", "none").lower() == "none"
|
and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||||
)
|
)
|
||||||
use_deepspeed_scheduler = (
|
use_deepspeed_scheduler = (
|
||||||
accelerator.state.deepspeed_plugin is not None
|
accelerator.state.deepspeed_plugin is not None
|
||||||
and accelerator.state.deepspeed_plugin.deepspeed_config.get("scheduler", "none").lower() == "none"
|
and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||||
)
|
)
|
||||||
|
|
||||||
optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
|
optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user