mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
[fix]fix deepspeed initialization issue in finetune examples
This commit is contained in:
parent
2fdc59c3ce
commit
ac2f2c78f7
@ -1242,11 +1242,11 @@ def main(args):
|
||||
|
||||
use_deepspeed_optimizer = (
|
||||
accelerator.state.deepspeed_plugin is not None
|
||||
and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||
and accelerator.state.deepspeed_plugin.deepspeed_config.get("optimizer", "none").lower() == "none"
|
||||
)
|
||||
use_deepspeed_scheduler = (
|
||||
accelerator.state.deepspeed_plugin is not None
|
||||
and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||
and accelerator.state.deepspeed_plugin.deepspeed_config.get("scheduler", "none").lower() == "none"
|
||||
)
|
||||
|
||||
optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
|
||||
@ -1357,10 +1357,9 @@ def main(args):
|
||||
from accelerate.utils import DummyScheduler
|
||||
|
||||
lr_scheduler = DummyScheduler(
|
||||
name=args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
total_num_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
warmup_num_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
)
|
||||
else:
|
||||
lr_scheduler = get_scheduler(
|
||||
|
@ -1154,11 +1154,11 @@ def main(args):
|
||||
|
||||
use_deepspeed_optimizer = (
|
||||
accelerator.state.deepspeed_plugin is not None
|
||||
and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||
and accelerator.state.deepspeed_plugin.deepspeed_config.get("optimizer", "none").lower() == "none"
|
||||
)
|
||||
use_deepspeed_scheduler = (
|
||||
accelerator.state.deepspeed_plugin is not None
|
||||
and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||
and accelerator.state.deepspeed_plugin.deepspeed_config.get("scheduler", "none").lower() == "none"
|
||||
)
|
||||
|
||||
optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
|
||||
@ -1219,10 +1219,9 @@ def main(args):
|
||||
from accelerate.utils import DummyScheduler
|
||||
|
||||
lr_scheduler = DummyScheduler(
|
||||
name=args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
total_num_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
warmup_num_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
)
|
||||
else:
|
||||
lr_scheduler = get_scheduler(
|
||||
|
Loading…
x
Reference in New Issue
Block a user