Mirror of https://github.com/THUDM/CogVideo.git
Fix for deepspeed training
commit 36f1333788 (parent 4d1b9fd166)
@@ -1,21 +1,17 @@
 compute_environment: LOCAL_MACHINE
+gpu_ids: "0"
 debug: false
 deepspeed_config:
-  gradient_accumulation_steps: 1
-  gradient_clipping: 1.0
-  offload_optimizer_device: none
-  offload_param_device: none
+  deepspeed_config_file: ds_config.json
   zero3_init_flag: false
-  zero_stage: 2
 distributed_type: DEEPSPEED
 downcast_bf16: 'no'
 enable_cpu_affinity: false
 machine_rank: 0
 main_training_function: main
 dynamo_backend: 'no'
 mixed_precision: 'no'
 num_machines: 1
-num_processes: 8
+num_processes: 1
 rdzv_backend: static
 same_network: true
 tpu_env: []
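With this change the Accelerate config no longer spells out the ZeRO and offload options inline; it delegates them to the new ds_config.json via deepspeed_config_file and drops to a single process on GPU 0. For reference, a rough programmatic equivalent of that delegation, only a sketch assuming deepspeed is installed and the script is started with accelerate launch (the project itself keeps driving this through the YAML):

    from accelerate import Accelerator
    from accelerate.utils import DeepSpeedPlugin

    # Point Accelerate at the same DeepSpeed JSON the YAML references.
    ds_plugin = DeepSpeedPlugin(hf_ds_config="ds_config.json", zero3_init_flag=False)
    accelerator = Accelerator(deepspeed_plugin=ds_plugin)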
finetune/ds_config.json (new file, 20 lines added)
@@ -0,0 +1,20 @@
+{
+    "scheduler": {
+        "type": "WarmupDecayLR",
+        "params": {
+            "warmup_min_lr": "auto",
+            "warmup_max_lr": "auto",
+            "warmup_num_steps": "auto",
+            "total_num_steps": "auto"
+        }
+    },
+    "zero_optimization": {
+        "stage": 2,
+        "allgather_partitions": true,
+        "allgather_bucket_size": 2e8,
+        "overlap_comm": true,
+        "reduce_scatter": true,
+        "reduce_bucket_size": 1e8,
+        "contiguous_gradients": true
+    }
+}
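The scheduler block leaves warmup_min_lr, warmup_max_lr, warmup_num_steps and total_num_steps as "auto", so they have to be resolved to concrete numbers before DeepSpeed builds the WarmupDecayLR scheduler. A minimal sketch of that substitution, with made-up values standing in for what the training loop would normally supply:

    import json

    # Load the DeepSpeed config the accelerate YAML points to.
    with open("finetune/ds_config.json") as f:
        ds_config = json.load(f)

    # Hypothetical numbers, used only for this illustration.
    resolved = {
        "warmup_min_lr": 0.0,
        "warmup_max_lr": 1e-4,   # usually the base learning rate
        "warmup_num_steps": 100,
        "total_num_steps": 10_000,
    }

    # Swap every "auto" placeholder in the scheduler params for a real value.
    params = ds_config["scheduler"]["params"]
    for key, value in resolved.items():
        if params.get(key) == "auto":
            params[key] = value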
@@ -948,6 +948,10 @@ def main(args):
 
     logging_dir = Path(args.output_dir, args.logging_dir)
 
+    expected_midxed_precision = "bf16" if "5b" in args.pretrained_model_name_or_path.lower() else "fp16"
+    if args.mixed_precision != expected_midxed_precision:
+        raise ValueError(f"Mixed precision {args.mixed_precision} does not match the model precision, should be {expected_midxed_precision}")
+
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
     kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
     accelerator = Accelerator(
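The first added block fails fast when the requested --mixed_precision does not match what the checkpoint name implies (a "5b" checkpoint expects bf16, anything else fp16). A tiny sketch of the same inference; the helper name and model IDs below are illustrative only:

    def expected_precision(model_id: str) -> str:
        # Precision is inferred purely from the checkpoint name.
        return "bf16" if "5b" in model_id.lower() else "fp16"

    assert expected_precision("THUDM/CogVideoX-5b") == "bf16"
    assert expected_precision("THUDM/CogVideoX-2b") == "fp16"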
@@ -958,6 +962,28 @@ def main(args):
         kwargs_handlers=[kwargs],
     )
 
+    if accelerator.state.deepspeed_plugin:
+        # Set deepspeed config according to args
+        config = {
+            'optimizer': {
+                'type': args.optimizer,
+                'params': {
+                    'lr': args.learning_rate,
+                    'betas': [args.adam_beta1, args.adam_beta2]
+                },
+                'torch_adam': True
+            },
+            'bf16': {
+                'enabled': True if args.mixed_precision == "bf16" else False
+            },
+            'fp16': {
+                'enabled': True if args.mixed_precision == "fp16" else False
+            },
+            'gradient_accumulation_steps': args.gradient_accumulation_steps,
+            'train_batch_size': args.train_batch_size
+        }
+        accelerator.state.deepspeed_plugin.deepspeed_config.update(config)
+
     # Disable AMP for MPS.
     if torch.backends.mps.is_available():
         accelerator.native_amp = False
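This is the core of the fix: the optimizer choice, precision flags, gradient accumulation and train batch size from the CLI are pushed into the plugin's DeepSpeed config at runtime, so ds_config.json only has to carry the scheduler and ZeRO settings. Note that dict.update is a shallow merge, so each top-level key here replaces any same-named section from the JSON file. A sketch of the same override built standalone, with stand-in argument values (the Namespace fields simply mirror the ones the patch reads):

    from argparse import Namespace

    # Stand-in for the parsed CLI args; values here are made up for the example.
    args = Namespace(
        optimizer="AdamW",
        learning_rate=1e-4,
        adam_beta1=0.9,
        adam_beta2=0.95,
        mixed_precision="bf16",
        gradient_accumulation_steps=1,
        train_batch_size=1,
    )

    override = {
        "optimizer": {
            "type": args.optimizer,
            "params": {"lr": args.learning_rate, "betas": [args.adam_beta1, args.adam_beta2]},
            "torch_adam": True,
        },
        "bf16": {"enabled": args.mixed_precision == "bf16"},
        "fp16": {"enabled": args.mixed_precision == "fp16"},
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "train_batch_size": args.train_batch_size,
    }
    # In the patch this dict is merged into the live plugin config:
    # accelerator.state.deepspeed_plugin.deepspeed_config.update(override)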
@@ -1045,7 +1071,7 @@ def main(args):
             "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
             and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
         ):
-            weight_dtype = torch.float16
+            weight_dtype = torch.bfloat16
     else:
         if accelerator.mixed_precision == "fp16":
             weight_dtype = torch.float16
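This one-line change corrects the dtype picked when the DeepSpeed config enables bf16: the weights should be bfloat16, not float16. A condensed sketch of the selection logic (not the script's exact branching), using a hypothetical helper name:

    import torch

    def resolve_weight_dtype(ds_config: dict, mixed_precision: str) -> torch.dtype:
        # The DeepSpeed config wins when it explicitly enables a half-precision mode.
        if ds_config.get("fp16", {}).get("enabled"):
            return torch.float16
        if ds_config.get("bf16", {}).get("enabled"):
            return torch.bfloat16
        # Otherwise fall back to the accelerator's mixed-precision setting.
        return torch.float16 if mixed_precision == "fp16" else torch.float32

    assert resolve_weight_dtype({"bf16": {"enabled": True}}, "no") is torch.bfloat16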
@@ -1158,11 +1184,11 @@ def main(args):
 
     use_deepspeed_optimizer = (
         accelerator.state.deepspeed_plugin is not None
-        and accelerator.state.deepspeed_plugin.deepspeed_config.get("optimizer", "none").lower() == "none"
+        and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
     )
     use_deepspeed_scheduler = (
         accelerator.state.deepspeed_plugin is not None
-        and accelerator.state.deepspeed_plugin.deepspeed_config.get("scheduler", "none").lower() == "none"
+        and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
    )
 
     optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
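The old predicates treated a missing optimizer or scheduler entry as the signal to use DeepSpeed's, which is backwards; the new membership tests enable the DeepSpeed-managed path exactly when the config (including the runtime override added earlier) declares one. When that path is taken, Accelerate expects placeholder objects and lets DeepSpeed construct the real optimizer and scheduler. A hedged sketch of that pattern; get_optimizer in the script presumably encapsulates it, and the helper below is only an illustration:

    import torch
    from accelerate.utils import DummyOptim

    def build_optimizer(params, lr, use_deepspeed_optimizer):
        if use_deepspeed_optimizer:
            # The real optimizer is built from the DeepSpeed config when DeepSpeed
            # initializes; DummyOptim is a placeholder that accelerator.prepare
            # knows how to handle.
            return DummyOptim(params, lr=lr)
        return torch.optim.AdamW(params, lr=lr)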