mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-12-02 18:43:38 +08:00
开启deepspeed时bfloat16类型错误
This commit is contained in:
parent
4a2af29867
commit
096fd2eab3
@ -1042,7 +1042,7 @@ def main(args):
|
||||
"bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||
and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
|
||||
):
|
||||
weight_dtype = torch.float16
|
||||
weight_dtype = torch.bfloat16
|
||||
else:
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user