mirror of https://github.com/THUDM/CogVideo.git
synced 2025-04-06 03:57:56 +08:00
[fix] fix deepspeed initialization issue in finetune examples
This commit is contained in:
parent 2fdc59c3ce
commit ac2f2c78f7
@@ -1242,11 +1242,11 @@ def main(args):
 
     use_deepspeed_optimizer = (
         accelerator.state.deepspeed_plugin is not None
-        and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
+        and accelerator.state.deepspeed_plugin.deepspeed_config.get("optimizer", "none").lower() == "none"
     )
     use_deepspeed_scheduler = (
         accelerator.state.deepspeed_plugin is not None
-        and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
+        and accelerator.state.deepspeed_plugin.deepspeed_config.get("scheduler", "none").lower() == "none"
     )
 
     optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
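This hunk changes how the script decides whether DeepSpeed should own the optimizer and scheduler: instead of testing key membership in deepspeed_config, both flags now read the entry's value and compare it against the string "none", so a config that explicitly writes e.g. scheduler: none is treated the same as one that omits the key. A minimal standalone sketch of the new check; the helper name and config dicts are invented for illustration, and the str() wrapper is an added guard for the case where the entry is a nested dict rather than a string:

def deepspeed_field_is_none(deepspeed_config: dict, field: str) -> bool:
    # A missing key and an explicit "none" value are treated identically.
    return str(deepspeed_config.get(field, "none")).lower() == "none"

print(deepspeed_field_is_none({}, "scheduler"))                         # True: key absent
print(deepspeed_field_is_none({"scheduler": "none"}, "scheduler"))      # True: explicitly "none"
print(deepspeed_field_is_none({"scheduler": "WarmupLR"}, "scheduler"))  # False: DeepSpeed manages it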
@@ -1357,10 +1357,9 @@ def main(args):
         from accelerate.utils import DummyScheduler
 
         lr_scheduler = DummyScheduler(
-            name=args.lr_scheduler,
             optimizer=optimizer,
             total_num_steps=args.max_train_steps * accelerator.num_processes,
-            num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+            warmup_num_steps=args.lr_warmup_steps * accelerator.num_processes,
         )
     else:
         lr_scheduler = get_scheduler(
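This hunk fixes the DummyScheduler call in two ways: it drops name=args.lr_scheduler, which belongs to get_scheduler rather than DummyScheduler, and it renames num_warmup_steps to warmup_num_steps, the keyword accelerate.utils.DummyScheduler actually defines (num_warmup_steps is the spelling used by get_scheduler, a likely source of the mix-up). A self-contained sketch of the corrected call; the numeric values are placeholders standing in for args.* and accelerator.num_processes:

import torch
from accelerate.utils import DummyOptim, DummyScheduler

# Placeholders standing in for args.* and accelerator.num_processes.
max_train_steps, lr_warmup_steps, num_processes = 1000, 100, 2

# A throwaway parameter so the dummy optimizer has something to hold.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = DummyOptim(params, lr=1e-4)

# DummyScheduler is only a placeholder; accelerate resolves it against the
# DeepSpeed config when accelerator.prepare() runs.
lr_scheduler = DummyScheduler(
    optimizer,
    total_num_steps=max_train_steps * num_processes,
    warmup_num_steps=lr_warmup_steps * num_processes,  # not num_warmup_steps
)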
@@ -1154,11 +1154,11 @@ def main(args):
 
     use_deepspeed_optimizer = (
         accelerator.state.deepspeed_plugin is not None
-        and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
+        and accelerator.state.deepspeed_plugin.deepspeed_config.get("optimizer", "none").lower() == "none"
     )
     use_deepspeed_scheduler = (
         accelerator.state.deepspeed_plugin is not None
-        and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
+        and accelerator.state.deepspeed_plugin.deepspeed_config.get("scheduler", "none").lower() == "none"
     )
 
     optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
@@ -1219,10 +1219,9 @@ def main(args):
         from accelerate.utils import DummyScheduler
 
         lr_scheduler = DummyScheduler(
-            name=args.lr_scheduler,
             optimizer=optimizer,
             total_num_steps=args.max_train_steps * accelerator.num_processes,
-            num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+            warmup_num_steps=args.lr_warmup_steps * accelerator.num_processes,
        )
     else:
         lr_scheduler = get_scheduler(
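get_optimizer itself is outside this diff, so the following is only a plausible reconstruction of how the use_deepspeed flag is consumed, following accelerate's DummyOptim pattern; the signature and the AdamW fallback are illustrative assumptions, not the script's actual code:

import torch
from accelerate.utils import DummyOptim

# Illustrative reconstruction -- not CogVideo's actual get_optimizer().
def get_optimizer(params_to_optimize, lr=1e-4, use_deepspeed=False):
    if use_deepspeed:
        # Hand accelerate a placeholder; DeepSpeed builds the real optimizer
        # from its own config during accelerator.prepare().
        return DummyOptim(params_to_optimize, lr=lr)
    return torch.optim.AdamW(params_to_optimize, lr=lr)

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = get_optimizer(params, use_deepspeed=True)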