diff --git a/finetune/schemas/args.py b/finetune/schemas/args.py index b2fd9ce..3074872 100644 --- a/finetune/schemas/args.py +++ b/finetune/schemas/args.py @@ -1,5 +1,6 @@ import argparse import datetime +import logging from pathlib import Path from typing import Any, List, Literal, Tuple @@ -157,6 +158,15 @@ class Args(BaseModel): raise ValueError("train_resolution must be in format 'frames x height x width'") raise e + @field_validator("mixed_precision") + def validate_mixed_precision(cls, v: str, info: ValidationInfo) -> str: + if v == "fp16" and "cogvideox-2b" not in str(info.data.get("model_path", "")).lower(): + logging.warning( + "All CogVideoX models except cogvideox-2b were trained with bfloat16. " + "Using fp16 precision may lead to training instability." + ) + return v + @classmethod def parse_args(cls): """Parse command line arguments and return Args instance"""