mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-12-02 18:52:08 +08:00
开启deepspeed时bfloat16类型错误
This commit is contained in:
parent
4a2af29867
commit
096fd2eab3
@ -1042,7 +1042,7 @@ def main(args):
|
|||||||
"bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
"bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||||
and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
|
and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
|
||||||
):
|
):
|
||||||
weight_dtype = torch.float16
|
weight_dtype = torch.bfloat16
|
||||||
else:
|
else:
|
||||||
if accelerator.mixed_precision == "fp16":
|
if accelerator.mixed_precision == "fp16":
|
||||||
weight_dtype = torch.float16
|
weight_dtype = torch.float16
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user