开启deepspeed时bfloat16类型错误

This commit is contained in:
luyulei.233 2024-09-29 21:20:49 +08:00
parent 4a2af29867
commit 096fd2eab3

View File

@@ -1042,7 +1042,7 @@ def main(args):
         "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
         and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
     ):
-        weight_dtype = torch.float16
+        weight_dtype = torch.bfloat16
     else:
         if accelerator.mixed_precision == "fp16":
             weight_dtype = torch.float16