feat: add warning for fp16 mixed precision training

This commit is contained in:
OleehyO 2025-01-07 06:00:38 +00:00
parent 36427274d6
commit 96e511b413

View File

@ -1,5 +1,6 @@
import argparse import argparse
import datetime import datetime
import logging
from pathlib import Path from pathlib import Path
from typing import Any, List, Literal, Tuple from typing import Any, List, Literal, Tuple
@ -157,6 +158,15 @@ class Args(BaseModel):
raise ValueError("train_resolution must be in format 'frames x height x width'") raise ValueError("train_resolution must be in format 'frames x height x width'")
raise e raise e
@field_validator("mixed_precision")
def validate_mixed_precision(cls, v: str, info: ValidationInfo) -> str:
if v == "fp16" and "cogvideox-2b" not in str(info.data.get("model_path", "")).lower():
logging.warning(
"All CogVideoX models except cogvideox-2b were trained with bfloat16. "
"Using fp16 precision may lead to training instability."
)
return v
@classmethod @classmethod
def parse_args(cls): def parse_args(cls):
"""Parse command line arguments and return Args instance""" """Parse command line arguments and return Args instance"""