diff --git a/finetune/train_cogvideox_lora.py b/finetune/train_cogvideox_lora.py index 137f322..5ad4750 100644 --- a/finetune/train_cogvideox_lora.py +++ b/finetune/train_cogvideox_lora.py @@ -1042,7 +1042,7 @@ def main(args): "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"] ): - weight_dtype = torch.float16 + weight_dtype = torch.bfloat16 else: if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16