feat: add warning for fp16 mixed precision training

OleehyO 2025-01-07 06:00:38 +00:00
parent 36427274d6
commit 96e511b413


@@ -1,5 +1,6 @@
import argparse
import datetime
import logging
from pathlib import Path
from typing import Any, List, Literal, Tuple
@@ -157,6 +158,15 @@ class Args(BaseModel):
                raise ValueError("train_resolution must be in format 'frames x height x width'")
            raise e

    @field_validator("mixed_precision")
    def validate_mixed_precision(cls, v: str, info: ValidationInfo) -> str:
        if v == "fp16" and "cogvideox-2b" not in str(info.data.get("model_path", "")).lower():
            logging.warning(
                "All CogVideoX models except cogvideox-2b were trained with bfloat16. "
                "Using fp16 precision may lead to training instability."
            )
        return v

    @classmethod
    def parse_args(cls):
        """Parse command line arguments and return Args instance"""