Fix for deepspeed training

OleehyO 2024-12-12 13:38:25 +00:00
parent 4d1b9fd166
commit 36f1333788
3 changed files with 52 additions and 10 deletions


@@ -1,21 +1,17 @@
 compute_environment: LOCAL_MACHINE
+gpu_ids: "0"
 debug: false
 deepspeed_config:
-  gradient_accumulation_steps: 1
-  gradient_clipping: 1.0
-  offload_optimizer_device: none
-  offload_param_device: none
-  zero3_init_flag: false
-  zero_stage: 2
+  deepspeed_config_file: ds_config.json
 distributed_type: DEEPSPEED
 downcast_bf16: 'no'
 enable_cpu_affinity: false
 machine_rank: 0
 main_training_function: main
 dynamo_backend: 'no'
 mixed_precision: 'no'
 num_machines: 1
-num_processes: 8
+num_processes: 1
 rdzv_backend: static
 same_network: true
 tpu_env: []
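The updated accelerate config delegates all DeepSpeed settings to the new ds_config.json below and runs a single process on GPU 0. As an editor's illustration only (not part of the commit), the sketch below shows a roughly equivalent programmatic setup; it assumes accelerate's DeepSpeedPlugin accepts a config path via its hf_ds_config argument, which should be verified against the installed accelerate version.

# Sketch only: programmatic counterpart of the YAML config above.
# Assumes DeepSpeedPlugin(hf_ds_config=...) exists in the installed accelerate version.
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

deepspeed_plugin = DeepSpeedPlugin(hf_ds_config="ds_config.json")  # same file the YAML points at
accelerator = Accelerator(mixed_precision="no", deepspeed_plugin=deepspeed_plugin)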

finetune/ds_config.json (new file)

@@ -0,0 +1,20 @@
{
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 1e8,
        "contiguous_gradients": true
    }
}
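Note that ds_config.json defines only the LR scheduler and the ZeRO stage-2 settings; there is deliberately no "optimizer" section, because the training script (diff below) injects the optimizer, precision, and batch-size fields at runtime from the CLI arguments. The following editor's sketch is a quick sanity check of that split (the finetune/ path is taken from the file header above):

# Sketch: confirm which sections DeepSpeed owns from the file vs. what the script fills in later.
import json

with open("finetune/ds_config.json") as f:
    ds_config = json.load(f)

assert "scheduler" in ds_config                      # WarmupDecayLR, driven by "auto" values
assert ds_config["zero_optimization"]["stage"] == 2  # ZeRO stage-2 partitioning
assert "optimizer" not in ds_config                  # added at runtime via deepspeed_config.update(...)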


@@ -948,6 +948,10 @@ def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)
+
+    expected_midxed_precision = "bf16" if "5b" in args.pretrained_model_name_or_path.lower() else "fp16"
+    if args.mixed_precision != expected_midxed_precision:
+        raise ValueError(f"Mixed precision {args.mixed_precision} does not match the model precision, should be {expected_midxed_precision}")
 
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
     kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
 
     accelerator = Accelerator(
@@ -958,6 +962,28 @@ def main(args):
         kwargs_handlers=[kwargs],
     )
 
+    if accelerator.state.deepspeed_plugin:
+        # Set deepspeed config according to args
+        config = {
+            'optimizer': {
+                'type': args.optimizer,
+                'params': {
+                    'lr': args.learning_rate,
+                    'betas': [args.adam_beta1, args.adam_beta2]
+                },
+                'torch_adam': True
+            },
+            'bf16': {
+                'enabled': True if args.mixed_precision == "bf16" else False
+            },
+            'fp16': {
+                'enabled': True if args.mixed_precision == "fp16" else False
+            },
+            'gradient_accumulation_steps': args.gradient_accumulation_steps,
+            'train_batch_size': args.train_batch_size
+        }
+        accelerator.state.deepspeed_plugin.deepspeed_config.update(config)
+
     # Disable AMP for MPS.
     if torch.backends.mps.is_available():
         accelerator.native_amp = False
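The hunk above builds a DeepSpeed config fragment from the CLI arguments and overlays it on the plugin's existing config (i.e. the contents of ds_config.json). A minimal editor's sketch of that merge, with made-up values standing in for args.*, purely for illustration:

# Illustration only: hypothetical argument values; the point is the top-level dict merge.
import json

with open("ds_config.json") as f:
    ds_config = json.load(f)          # scheduler + zero_optimization from the file

runtime_config = {
    "optimizer": {
        "type": "AdamW",              # stands in for args.optimizer
        "params": {"lr": 1e-4, "betas": [0.9, 0.95]},
        "torch_adam": True,
    },
    "bf16": {"enabled": True},        # args.mixed_precision == "bf16"
    "fp16": {"enabled": False},
    "gradient_accumulation_steps": 1,
    "train_batch_size": 8,
}

ds_config.update(runtime_config)      # same top-level merge as deepspeed_config.update(config)
print(json.dumps(ds_config, indent=2))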
@@ -1045,7 +1071,7 @@ def main(args):
         "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
         and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
     ):
-        weight_dtype = torch.float16
+        weight_dtype = torch.bfloat16
     else:
         if accelerator.mixed_precision == "fp16":
             weight_dtype = torch.float16
@@ -1158,11 +1184,11 @@ def main(args):
 
     use_deepspeed_optimizer = (
         accelerator.state.deepspeed_plugin is not None
-        and accelerator.state.deepspeed_plugin.deepspeed_config.get("optimizer", "none").lower() == "none"
+        and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
     )
     use_deepspeed_scheduler = (
         accelerator.state.deepspeed_plugin is not None
-        and accelerator.state.deepspeed_plugin.deepspeed_config.get("scheduler", "none").lower() == "none"
+        and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
     )
 
     optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
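The final hunk fixes the optimizer/scheduler detection: the old expression was true only when the section was missing (and .lower() would raise once a dict was present), whereas the new membership test is true exactly when the DeepSpeed config contains that section. A toy editor's illustration, using a plain dict in place of accelerator.state.deepspeed_plugin.deepspeed_config:

# Toy dict standing in for the runtime DeepSpeed config after the update() above.
deepspeed_config = {"scheduler": {"type": "WarmupDecayLR"}, "optimizer": {"type": "AdamW"}}

# Old check: inverted, and crashes because a dict has no .lower().
try:
    old_flag = deepspeed_config.get("optimizer", "none").lower() == "none"
except AttributeError:
    old_flag = None

# New check: true exactly when DeepSpeed should own the optimizer/scheduler.
use_deepspeed_optimizer = "optimizer" in deepspeed_config
use_deepspeed_scheduler = "scheduler" in deepspeed_config
print(old_flag, use_deepspeed_optimizer, use_deepspeed_scheduler)  # None True True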