Merge pull request #567 from THUDM/main

New Finetune
This commit is contained in:
Yuxuan.Zhang 2024-12-02 11:30:20 +08:00 committed by GitHub
commit 87ccd38cea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 8 deletions

View File

@ -1246,11 +1246,11 @@ def main(args):
use_deepspeed_optimizer = (
accelerator.state.deepspeed_plugin is not None
and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
and accelerator.state.deepspeed_plugin.deepspeed_config.get("optimizer", "none").lower() == "none"
)
use_deepspeed_scheduler = (
accelerator.state.deepspeed_plugin is not None
and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
and accelerator.state.deepspeed_plugin.deepspeed_config.get("scheduler", "none").lower() == "none"
)
optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
@ -1361,10 +1361,9 @@ def main(args):
from accelerate.utils import DummyScheduler
lr_scheduler = DummyScheduler(
name=args.lr_scheduler,
optimizer=optimizer,
total_num_steps=args.max_train_steps * accelerator.num_processes,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
warmup_num_steps=args.lr_warmup_steps * accelerator.num_processes,
)
else:
lr_scheduler = get_scheduler(

View File

@ -1158,11 +1158,11 @@ def main(args):
use_deepspeed_optimizer = (
accelerator.state.deepspeed_plugin is not None
and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
and accelerator.state.deepspeed_plugin.deepspeed_config.get("optimizer", "none").lower() == "none"
)
use_deepspeed_scheduler = (
accelerator.state.deepspeed_plugin is not None
and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
and accelerator.state.deepspeed_plugin.deepspeed_config.get("scheduler", "none").lower() == "none"
)
optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
@ -1223,10 +1223,9 @@ def main(args):
from accelerate.utils import DummyScheduler
lr_scheduler = DummyScheduler(
name=args.lr_scheduler,
optimizer=optimizer,
total_num_steps=args.max_train_steps * accelerator.num_processes,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
warmup_num_steps=args.lr_warmup_steps * accelerator.num_processes,
)
else:
lr_scheduler = get_scheduler(